| | from transformers import AutoTokenizer, AutoModelForQuestionAnswering |
| | import numpy as np |
| | from tqdm import tqdm |
| | import torch |
| | import collections |
| |
|
# --- Inference configuration -------------------------------------------------
luke_beam_size = 5     # number of answer candidates kept per example
n_best = 20            # top-k start/end logit positions considered per feature
                       # (bug fix: n_best was assigned twice — 30 then 20 — the
                       # first value was dead code, so the effective 20 is kept)
max_length = 512       # max tokens per tokenized feature
stride = 128           # overlap between overflowing context windows
batch_size = 8         # features per forward pass
max_answer_length = 30 # max answer span length, in tokens

# --- Model / tokenizer -------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
luke_model = AutoModelForQuestionAnswering.from_pretrained("botcon/LUKE_squadshift_finetuned_large").to(device)
# NOTE(review): the LUKE checkpoint is paired with the plain "roberta-base"
# tokenizer here — presumably the fine-tune used it; confirm against training code.
luke_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
| |
|
def compute_beam(start_logits, end_logits, features, examples):
    """Turn per-feature start/end logits into top-`luke_beam_size` answer beams.

    Args:
        start_logits: array-like, one row of start-position logits per feature.
        end_logits: array-like, one row of end-position logits per feature.
        features: tokenized features; each has "example_id" and "offset_mapping"
            (offsets are None for non-context tokens).
        examples: original examples; each has "id" and "context".

    Returns:
        A list of dicts, one per example, each with:
            "id"              - the example id,
            "prediction_text" - list of luke_beam_size answer strings,
            "logits"          - list of luke_beam_size scores.
        Short lists are padded with "" / 1e-5 so every entry is beam-sized.

    Bug fix: the original returned ``{"id": ..., "prediction_text": ""}`` (a
    bare string, no "logits" key) for examples with no valid span, which broke
    the schema that the non-empty branch established. Both branches now emit
    the same beam-sized structure.
    """
    # Map each example id to the indices of its (possibly overflowed) features.
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            # Top n_best candidate positions, highest logit first.
            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip positions that fall outside the context.
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip inverted or over-long spans.
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue
                    answers.append(
                        {
                            "text": context[offsets[start_index][0] : offsets[end_index][1]],
                            "logit_score": start_logit[start_index] + end_logit[end_index],
                        }
                    )

        # Keep the top beam_size candidates, then pad to a fixed width so the
        # output schema is uniform across examples.
        best = sorted(answers, key=lambda a: a["logit_score"], reverse=True)[:luke_beam_size]
        best_ans = [a["text"] for a in best]
        best_logits = [a["logit_score"] for a in best]
        pad = luke_beam_size - len(best_ans)
        best_ans.extend([""] * pad)
        best_logits.extend([1e-5] * pad)

        predicted_answers.append(
            {"id": example_id, "prediction_text": best_ans, "logits": best_logits}
        )

    return predicted_answers
| |
|
def preprocess_validation_examples(examples):
    """Tokenize question/context pairs for evaluation.

    Long contexts overflow into multiple features (windowed with `stride`);
    each feature records which original example it came from ("example_id"),
    and its "offset_mapping" is masked to None everywhere outside the context
    so answer extraction can ignore question/special tokens.
    """
    stripped_questions = [question.strip() for question in examples["question"]]
    tokenized = luke_tokenizer(
        stripped_questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # One entry per produced feature, pointing back at its source example.
    overflow_map = tokenized.pop("overflow_to_sample_mapping")

    example_ids = []
    for feature_idx, sample_idx in enumerate(overflow_map):
        example_ids.append(examples["id"][sample_idx])

        # Blank out offsets for everything that is not a context token
        # (sequence id 1 == second sequence == context).
        seq_ids = tokenized.sequence_ids(feature_idx)
        feature_offsets = tokenized["offset_mapping"][feature_idx]
        tokenized["offset_mapping"][feature_idx] = [
            offset if seq_ids[pos] == 1 else None
            for pos, offset in enumerate(feature_offsets)
        ]

    tokenized["example_id"] = example_ids
    return tokenized
| |
|
def generate(dataset):
    """Run the LUKE QA model over `dataset` and return beam predictions.

    Tokenizes the dataset into fixed-length features, runs the model in
    mini-batches, and decodes the logits with `compute_beam`.

    Bug fix: the original fed EVERY feature to the model in a single forward
    pass, ignoring the module-level `batch_size` and guaranteeing an OOM on
    any non-trivial dataset. Inference now runs in `batch_size` chunks and
    the per-chunk logits are concatenated before decoding, which produces
    identical logits to the single-pass version.
    """
    luke_model.eval()
    preprocessed = dataset.map(
        preprocess_validation_examples,
        batched=True,
        remove_columns=dataset.column_names,
    )
    eval_set_for_model = preprocessed.remove_columns(["example_id", "offset_mapping"])
    eval_set_for_model.set_format("torch")

    n_features = len(eval_set_for_model)
    start_chunks = []
    end_chunks = []
    with torch.no_grad():
        for lo in range(0, n_features, batch_size):
            batch = {
                k: eval_set_for_model[k][lo : lo + batch_size].to(device)
                for k in eval_set_for_model.column_names
            }
            outputs = luke_model(**batch)
            start_chunks.append(outputs.start_logits.cpu())
            end_chunks.append(outputs.end_logits.cpu())

    start_logits = torch.cat(start_chunks).numpy()
    end_logits = torch.cat(end_chunks).numpy()
    return compute_beam(start_logits, end_logits, preprocessed, dataset)