Spaces:
Runtime error
Runtime error
| # coding=utf8 | |
| from transformers import AutoModel, AutoTokenizer, AutoConfig | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| import streamlit as st | |
| import gdown | |
| import numpy as np | |
| import pandas as pd | |
| import collections | |
| from string import punctuation | |
class CONFIG:
    """Static settings shared by the model, the tokenization step and the app."""

    # --- backbone / tokenization -------------------------------------------
    # Hub identifier handed to AutoModel, AutoConfig and AutoTokenizer.
    model = 'deepset/xlm-roberta-large-squad2'
    # Passed to the tokenizer as ``max_length``; hyperparameter to be tuned,
    # following the guide from huggingface.
    max_input_length = 384
    # Passed to the tokenizer as ``stride``; hyperparameter to be tuned,
    # following the guide from huggingface.
    doc_stride = 128

    # --- files ---------------------------------------------------------------
    # Not referenced by the code visible in this file.
    model_checkpoint = "pytorch_model.pth"
    # Download location of the fine-tuned weights and the local path they are
    # written to / loaded from.
    trained_model_url = 'https://drive.google.com/uc?id=16Vp918RglyLEFEyDlFuRD1HeNZ8SI7P5'
    trained_model_output_fp = 'trained_pytorch.pth'
    # JSON file the "Random" button samples a context/question pair from.
    sample_df_fp = "sample_qa.json"
| # model class | |
class ChaiModel(nn.Module):
    """Pretrained transformer backbone with a 2-output linear span head.

    The head maps every token representation to two scores, which are
    returned separately as start logits and end logits.
    """

    def __init__(self, model_config):
        """Build the backbone from ``CONFIG.model`` and the linear head.

        Args:
            model_config: Object exposing ``hidden_size``, the input width of
                the linear head.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained(CONFIG.model)
        self.linear = nn.Linear(model_config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Score every token as a possible answer start and answer end.

        Args:
            input_ids: Token ids fed to the backbone.
            attention_mask: Attention mask forwarded to the backbone.

        Returns:
            A ``(start_logits, end_logits)`` tuple; each tensor has the head
            output's shape without its trailing dimension of size 2.
        """
        # First backbone output; per the original author's notes this is
        # (batchsize, sequencelength, hidden_dim) - TODO confirm.
        hidden_states = self.backbone(input_ids, attention_mask=attention_mask)[0]
        span_logits = self.linear(hidden_states)  # (..., 2)
        # Peel the trailing dimension apart into its two score planes.
        start_logits, end_logits = span_logits.unbind(dim=-1)
        return start_logits, end_logits
| # dataset class | |
class ChaiDataset(Dataset):
    """Map-style dataset over a list of tokenized feature dicts.

    In training mode each item carries the answer span positions; otherwise
    it carries the bookkeeping fields needed to map predictions back to text.
    """

    def __init__(self, dataset, is_train=True):
        """Store the feature list and the mode flag.

        Args:
            dataset: Indexable collection of feature dicts.
            is_train: Selects which fields ``__getitem__`` returns.
        """
        super().__init__()
        self.dataset = dataset
        self.is_train = is_train

    def __len__(self):
        """Return the number of stored features."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the feature at ``index`` as a dict of tensors and metadata."""
        record = self.dataset[index]

        # Fields shared by both modes, always converted to int64 tensors.
        item = {
            name: torch.tensor(record[name], dtype=torch.long)
            for name in ('input_ids', 'attention_mask', 'offset_mapping')
        }

        if self.is_train:
            for name in ('start_position', 'end_position'):
                item[name] = torch.tensor(record[name], dtype=torch.long)
            return item

        # Inference items keep raw (non-tensor) metadata alongside the tensors.
        item['sequence_ids'] = record['sequence_ids']
        item['id'] = record['example_id']
        item['context'] = record['context']
        item['question'] = record['question']
        return item
def _tokenize_row(row, tokenizer):
    """Tokenize one question/context row with the settings from ``CONFIG``.

    Tokenizer parameters are documented at
    https://huggingface.co/transformers/internal/tokenization_utils.html#transformers.tokenization_utils_base.PreTrainedTokenizerBase
    The result holds 'input_ids', 'attention_mask', 'offset_mapping' and
    'overflow_to_sample_mapping'.
    """
    return tokenizer(
        row['question'],
        row['context'],
        padding='max_length',
        max_length=CONFIG.max_input_length,
        truncation='only_second',
        stride=CONFIG.doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
    )


def _cls_only_labels(cls_index):
    """Return the label fields used when a window holds no (complete) answer."""
    return {
        'start_position': cls_index,
        'end_position': cls_index,
        'tokenized_answer': "",
        'answer_text': "",
    }


def _span_labels(row, tokenizer, input_ids, offset_map, sequence_ids):
    """Compute the answer-span label fields for one tokenized window."""
    # Most of the time cls_index is 0.
    cls_index = input_ids.index(tokenizer.cls_token_id)

    # No answer for this row: start and end both point at the CLS token.
    if np.isnan(row["answer_start"]):
        return _cls_only_labels(cls_index)

    start_char = row["answer_start"]
    end_char = start_char + len(row["answer_text"])

    # sequence_ids looks like: None, 0, 0, ..., None, None, 1, 1, ...
    # Entries equal to 1 mark the context tokens.
    context_first = sequence_ids.index(1)
    context_last = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

    # The answer is not fully contained in this window.
    if not (offset_map[context_first][0] <= start_char
            and offset_map[context_last][1] >= end_char):
        return _cls_only_labels(cls_index)

    # Walk forward to the last context token that starts at or before the
    # answer. The scan is limited to the context span: the special/padding
    # tokens after it are assumed to carry offset (0, 0) (Hugging Face fast
    # tokenizers - confirm) and would otherwise satisfy the comparison and
    # drag the start position to the end of the window.
    token_start_index = context_first
    while token_start_index <= context_last and offset_map[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1

    # Walk backward to the first context token that ends at or after the
    # answer; the end index is inclusive, so slices need ``end_position + 1``.
    # The lower bound keeps the scan from running into the question tokens.
    token_end_index = context_last
    while token_end_index >= context_first and offset_map[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1

    answer_ids = input_ids[start_position:end_position + 1]
    tokenized_answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(answer_ids)
    )
    return {
        'start_position': start_position,
        'end_position': end_position,
        'tokenized_answer': tokenized_answer,
        'answer_text': row['answer_text'],
    }


def break_long_context(df, tokenizer, train=True):
    """Turn question/context rows into fixed-length tokenized features.

    Every row is tokenized with ``CONFIG.max_input_length`` and
    ``CONFIG.doc_stride``; a long context may therefore produce several
    features (windows) for the same row.

    Args:
        df: DataFrame with 'question' and 'context' columns. Training mode
            also reads 'answer_start', 'answer_text' and 'fold'; inference
            mode also reads 'id'.
        tokenizer: Tokenizer called on (question, context). It must return
            'overflow_to_sample_mapping' and 'offset_mapping', and expose
            ``sequence_ids(i)``; training mode additionally uses
            ``cls_token_id``, ``convert_ids_to_tokens`` and
            ``convert_tokens_to_string``.
        train: When True, attach answer-span labels to every feature;
            otherwise attach the metadata needed for post-processing.

    Returns:
        A flat list of feature dicts, in row order and then window order.
    """
    full_set = []
    for _, row in df.iterrows():
        tokenized_examples = _tokenize_row(row, tokenizer)
        sample_mappings = tokenized_examples.pop("overflow_to_sample_mapping")
        offset_mappings = tokenized_examples.pop("offset_mapping")

        for j in range(len(sample_mappings)):
            input_ids = tokenized_examples["input_ids"][j]
            attention_mask = tokenized_examples["attention_mask"][j]

            if train:
                sliced_text = tokenizer.convert_tokens_to_string(
                    tokenizer.convert_ids_to_tokens(input_ids)
                )
                final_example = dict(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    sliced_text=sliced_text,
                    offset_mapping=offset_mappings[j],
                    fold=row['fold'],
                )
                final_example.update(
                    _span_labels(
                        row,
                        tokenizer,
                        input_ids,
                        offset_mappings[j],
                        tokenized_examples.sequence_ids(j),
                    )
                )
            else:
                final_example = dict(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    offset_mapping=offset_mappings[j],
                    example_id=row['id'],
                    context=row['context'],
                    question=row['question'],
                    # None entries are stored as 0 so only context tokens
                    # keep the value 1.
                    sequence_ids=[
                        0 if value is None else value
                        for value in tokenized_examples.sequence_ids(j)
                    ],
                )
            full_set.append(final_example)
    return full_set
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    """Convert per-feature start/end logits into one answer string per example.

    For every example, the ``n_best_size`` highest start and end positions of
    each of its features are paired. Pairs that fall outside the context, run
    backwards, or span more than ``max_answer_length`` tokens are discarded,
    and the text of the highest-scoring remaining pair is kept.

    Args:
        examples: DataFrame with 'id' and 'context' columns; any index works.
        features: Sequence of feature dicts with 'example_id',
            'sequence_ids' (1 marks context tokens) and 'offset_mapping'
            (character offsets into the context). The dicts are not modified.
        raw_predictions: ``(all_start_logits, all_end_logits)``, each indexable
            by feature position and yielding a 1-D score array.
        n_best_size: Number of top start and top end positions to combine.
        max_answer_length: Maximum answer length in tokens.

    Returns:
        ``collections.OrderedDict`` mapping each example id to its best answer
        text, or ``""`` when no valid span exists. Features whose
        ``example_id`` matches no row of ``examples`` are ignored.
    """
    all_start_logits, all_end_logits = raw_predictions

    # Group feature positions by example id. Keying on the id itself (rather
    # than on a positional index) keeps the lookup independent of the
    # DataFrame's index labels.
    features_per_example = collections.defaultdict(list)
    for feature_index, feature in enumerate(features):
        features_per_example[feature["example_id"]].append(feature_index)

    predictions = collections.OrderedDict()
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    context_index = 1
    for _, example in examples.iterrows():
        context = example["context"]
        valid_answers = []

        for feature_index in features_per_example.get(example["id"], ()):
            feature = features[feature_index]
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            sequence_ids = feature["sequence_ids"]

            # Local masked copy: offsets of non-context tokens become None so
            # they can never be selected. The caller's feature is untouched.
            offset_mapping = [
                (offset if sequence_ids[k] == context_index else None)
                for k, offset in enumerate(feature["offset_mapping"])
            ]

            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip positions outside the window or outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0
                    # or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char:end_char],
                        }
                    )

        if valid_answers:
            best_answer = max(valid_answers, key=lambda answer: answer["score"])
        else:
            best_answer = {"text": "", "score": 0.0}
        predictions[example["id"]] = best_answer["text"]
    return predictions
def download_finetuned_model():
    """Download the fine-tuned weights to ``CONFIG.trained_model_output_fp``.

    The file is fetched with gdown from ``CONFIG.trained_model_url`` with
    progress output enabled.
    """
    source_url = CONFIG.trained_model_url
    destination = CONFIG.trained_model_output_fp
    gdown.download(url=source_url, output=destination, quiet=False)
def get_prediction(context:str, question:str, model, tokenizer) -> str:
    """Answer ``question`` from ``context`` and return the extracted text.

    Args:
        context: Passage to search; surrounding whitespace is stripped.
        question: Question to answer; surrounding whitespace is stripped.
        model: Callable taking ``(input_ids, attention_mask)`` and returning
            ``(start_logits, end_logits)`` tensors.
        tokenizer: Tokenizer forwarded to ``break_long_context``.

    Returns:
        The predicted answer with leading/trailing ASCII punctuation removed.
    """
    # Single-row frame so inference reuses the same pipeline as training.
    frame = pd.DataFrame({"id":[1], "context":[context.strip()], "question":[question.strip()]})
    windows = break_long_context(frame, tokenizer, train=False)

    # One window per batch to prevent OOM.
    loader = DataLoader(
        ChaiDataset(windows, is_train=False),
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )

    start_chunks = []
    end_chunks = []
    with torch.no_grad():
        for batch in loader:
            # The model returns one score row per window for start and end.
            start_logit, end_logit = model(batch['input_ids'], batch['attention_mask'])
            start_chunks.append(start_logit.to("cpu").numpy())
            end_chunks.append(end_logit.to("cpu").numpy())

    # Stack the per-batch rows so row i lines up with windows[i].
    raw_predictions = (np.vstack(start_chunks), np.vstack(end_chunks))
    answers = postprocess_qa_predictions(frame, windows, raw_predictions)

    # The frame has exactly one row, so take its only answer.
    only_answer = list(answers.values())[0]
    return only_answer.strip(punctuation)
def load_model():
    """Load the fine-tuned model, its tokenizer and the sample questions.

    The checkpoint is downloaded only when ``CONFIG.trained_model_output_fp``
    is not already on disk, so repeated calls do not fetch it again. Delete
    the local file to force a fresh download.

    Returns:
        ``(model, tokenizer, sample_df)``: the ``ChaiModel`` in eval mode with
        the fine-tuned weights loaded on CPU, the tokenizer for
        ``CONFIG.model``, and the DataFrame read from ``CONFIG.sample_df_fp``.
    """
    import os  # function-scope import keeps this block self-contained

    if not os.path.isfile(CONFIG.trained_model_output_fp):
        download_finetuned_model()
        print("Downloaded pretrained model")

    config = AutoConfig.from_pretrained(CONFIG.model)
    model = ChaiModel(config)

    # NOTE(review): torch.load unpickles the downloaded file; only point
    # CONFIG.trained_model_url at a source you trust.
    state_dict = torch.load(CONFIG.trained_model_output_fp, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(CONFIG.model)
    sample_df = pd.read_json(CONFIG.sample_df_fp)
    return model, tokenizer, sample_df
# Streamlit re-executes this script on every interaction. Wrap load_model in
# Streamlit's resource cache (or its older singleton equivalent) when the
# installed version provides one, so the model is built once, not on every
# rerun; otherwise fall back to a plain call.
_resource_cache = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
_load_model_once = _resource_cache(load_model) if _resource_cache is not None else load_model
model, tokenizer, sample_df = _load_model_once()

## initialize session_state
for _state_key in ("context", "question", "answer"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""

## Layout
st.sidebar.title("Hindi/Tamil Extractive Question Answering")
st.sidebar.markdown("---")
random_button = st.sidebar.button("Random")
st.sidebar.write("Randomly Generates a Hindi/Tamil Context and Question")
st.sidebar.markdown("---")
answer_button = st.sidebar.button("Answer!")

# Keyed state must be written before the widgets bound to it are created in
# this run, so the random sample is applied ahead of the text areas.
if random_button:
    sample = sample_df.sample(1)
    st.session_state['context'] = sample['context'].item()
    st.session_state['question'] = sample['question'].item()
    st.session_state['answer'] = ""

# Binding the inputs with ``key=`` makes session_state hold the text that is
# currently in the widgets, including edits submitted together with a click.
st.text_area("Context", key="context", height=300)
with st.container():
    col_1, col_2 = st.columns(2)
    with col_1:
        st.text_area("Question", key="question", height=200)

    # Answer only after both inputs exist for this run, so the prediction
    # uses the up-to-date context and question.
    if answer_button:
        # if question or context is empty text
        if not st.session_state['context'] or not st.session_state['question']:
            st.session_state['answer'] = " "
        else:
            st.session_state['answer'] = get_prediction(st.session_state['context'], st.session_state['question'], model, tokenizer)

    with col_2:
        st.text_area("Answer", value=st.session_state['answer'], height=200)