| | import torch |
| | import random |
| | from transformers import AutoTokenizer, BertForSequenceClassification |
| | from datasets import load_dataset |
| | from transformers import LukePreTrainedModel, LukeModel, AutoTokenizer, TrainingArguments, default_data_collator, Trainer, AutoModelForQuestionAnswering |
| | from transformers.modeling_outputs import ModelOutput |
| | from typing import Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | from tqdm import tqdm |
| | import evaluate |
| | import torch |
| | from dataclasses import dataclass |
| | from datasets import load_dataset, concatenate_datasets, load_metric |
| | from torch import nn |
| | from torch.nn import CrossEntropyLoss |
| | import collections |
| | import re |
| |
|
# Allow TensorFloat-32 matmuls/convolutions on Ampere+ GPUs: faster training at
# slightly reduced float precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Prefer GPU when available; fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Binary sequence-pair classifier: scores whether a candidate answer span is the
# gold span for a question (labels are built in preprocess_training_examples).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)
| |
|
def preprocess_training_examples(examples):
    """Build a binary span-verification dataset from a batch of SQuAD examples.

    For every (question, context, answer) triple, emit the gold answer span
    with label 1, plus (usually) one randomly corrupted span with label 0.
    Corruptions: prepend the context word preceding the span, append the word
    following it, drop the span's first word, or drop its last word — the two
    shrinking corruptions only when the answer has more than one word.

    Args:
        examples: batched ``datasets`` slice with SQuAD-schema columns
            "question", "context" and "answers".

    Returns:
        Tokenized (question, candidate-span) pairs padded to max length,
        with an added "labels" column (1 = gold span, 0 = corrupted span).
    """
    questions = [q.strip() for q in examples["question"]]
    answers = []
    labels = []
    final_questions = []

    for i, question in enumerate(questions):
        context = examples["context"][i]
        original_answer = examples["answers"][i]["text"][0]
        original_words = original_answer.split()

        # Positive example: the gold answer span.
        final_questions.append(question)
        answers.append(original_answer)
        labels.append(1)

        answer_start = examples["answers"][i]["answer_start"][0]
        answer_end = answer_start + len(original_answer)
        begin_context = context[:answer_start]
        end_context = context[answer_end:]

        # Single-word answers can only be corrupted by extending the span
        # (cases 0-1); longer answers may also be shrunk (cases 2-3).
        end = 1 if len(original_words) == 1 else 3
        case_ind = random.randint(0, end)
        pre_words = begin_context.rsplit(maxsplit=1)   # last elem = word before span
        post_words = end_context.split(maxsplit=1)     # first elem = word after span
        wrong_span = None
        if case_ind == 0 and pre_words:
            # Extend left: include the word preceding the answer.
            wrong_span = " ".join([pre_words[-1], original_answer])
        elif case_ind == 1 and post_words:
            # Extend right: include the word following the answer.
            wrong_span = " ".join([original_answer, post_words[0]])
        elif case_ind == 2:
            # Shrink: drop the first word.
            # BUG FIX: this branch was keyed on 3 and drop-last on 4, but
            # randint(0, 3) never yields 4 — drop-last was dead code and a
            # draw of 2 silently produced no negative example.
            wrong_span = " ".join(original_words[1:])
        elif case_ind == 3:
            # Shrink: drop the last word.
            wrong_span = " ".join(original_words[:-1])
        # If case 0/1 was drawn but the answer sits at a context boundary
        # (no neighbouring word), no negative is produced — as before.

        if wrong_span is not None:
            final_questions.append(question)
            answers.append(wrong_span)
            labels.append(0)

    inputs = tokenizer(
        final_questions,
        answers,
        padding="max_length",
    )
    inputs["labels"] = labels
    return inputs
| |
|
# SQuAD v1.1: columns "question", "context", "answers".
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
raw_eval = raw_datasets["validation"]

train_dataset = raw_train.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_train.column_names,
)

# NOTE(review): the eval split is preprocessed with the same *random* span
# corruption as training, so eval metrics vary run to run unless a seed is set.
eval_dataset = raw_eval.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_train.column_names,  # same column names as the validation split
)

batch_size = 8
| |
|
| | |
| |
|
# Fine-tuning config: 2 epochs, mixed precision, one checkpoint per epoch,
# pushed to the Hub under "right_span_bert". evaluation_strategy is "no"
# because evaluation is run once manually after training (trainer.evaluate()).
args = TrainingArguments(
    "right_span_bert",  # output directory / Hub repo name
    evaluation_strategy = "no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,  # requires an authenticated Hugging Face login
    fp16=True  # mixed precision; requires CUDA
)
| |
|
def compute_metrics(eval_pred):
    """Compute accuracy and binary F1 from a ``Trainer`` evaluation batch.

    Args:
        eval_pred: ``(logits, labels)`` pair as supplied by ``Trainer``.

    Returns:
        dict with "accuracy" and "f1" floats.
    """
    accuracy_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"],
        "f1": f1_metric.compute(predictions=predictions, references=labels)["f1"],
    }
| |
|
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,  # inputs are already padded to max_length
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

trainer.train()

# Single post-training evaluation (evaluation_strategy is "no" during training).
res = trainer.evaluate()
print(res)
| |
|