# MRC001 / run_qa.py — running this file starts training the model
# (author: combi2k2, commit a00ac7f)
import collections
import numpy as np
import string
import logging
import json
import os
import sys
import evaluate
from dataclasses import dataclass, field
from typing import Optional
from transformers import (
AutoModelForQuestionAnswering,
AutoTokenizer,
EvalPrediction,
TrainingArguments,
DefaultDataCollator,
)
from utils_qa import load_dataset
from utils_qa import postprocess_qa_predictions
from trainer_qa import QuestionAnsweringTrainer
# Path to the raw training data consumed by utils_qa.load_dataset.
dataset_path = 'data/train.json'
# Hugging Face checkpoint used for both the tokenizer and the QA model.
model_checkpoint = 'xlm-roberta-base'
if __name__ == '__main__':
    # Load the raw dataset which contains context, question and answers.
    raw_dataset = load_dataset(dataset_path)

    # Load the pretrained tokenizer and model from huggingface.co.
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

    def preprocess_function(examples):
        '''
        Tokenize question/context pairs and label the answer spans, producing
        the features used to train the question answering model.

        Args:
            examples: a batch from the raw dataset with 'question', 'context'
                and 'answers' columns (each a list, batched=True mapping).

        Returns:
            The tokenized batch extended with 'start_positions' and
            'end_positions' (token indices; the CLS index means "no answer").
        '''
        # Some questions carry leading whitespace; strip it so it does not
        # eat into the context budget under truncation.
        examples['question'] = [q.lstrip() for q in examples['question']]
        tokenized_examples = tokenizer(
            examples['question'],
            examples['context'],
            truncation = "only_second",            # only ever truncate the context
            max_length = tokenizer.model_max_length,
            return_offsets_mapping = True,
            padding = "max_length",
        )
        # The offset mappings give a map from token to character position in
        # the original context; used to compute start/end token positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")
        assert len(offset_mapping) == len(tokenized_examples['input_ids'])

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []
        for i, offset in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            # Segment ids: None = special token, 0 = question, 1 = context.
            sequence_ids = tokenized_examples.sequence_ids(i)
            answers = examples['answers'][i]
            # Default label is the CLS token ("no answer"); overwritten below
            # when a fully-contained answer span is found.
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
            if len(answers) == 0:
                continue
            # Find the token span of the context segment.
            context_start = sequence_ids.index(1)
            context_end = sequence_ids[context_start:].index(None) + context_start - 1
            start_char = answers[0]["answer_start"]
            end_char = start_char + len(answers[0]["text"])
            # If the answer is not FULLY inside the (possibly truncated)
            # context, keep the CLS label.  BUG FIX: the original compared
            # start_char against the context END and end_char against the
            # context START, which only rejected answers lying completely
            # outside the context; answers cut off by truncation slipped
            # through and were given out-of-context labels.
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                continue
            # Otherwise walk inward to the exact token boundaries.
            token_start_index = context_start
            while token_start_index <= context_end and offset[token_start_index][0] <= start_char:
                token_start_index += 1
            token_end_index = context_end
            while token_end_index >= context_start and offset[token_end_index][1] >= end_char:
                token_end_index -= 1
            tokenized_examples["start_positions"][-1] = token_start_index - 1
            tokenized_examples["end_positions"][-1] = token_end_index + 1
        return tokenized_examples

    # Create train features from the raw dataset.
    tokenized_dataset = raw_dataset.map(
        preprocess_function,
        batched = True,
        remove_columns = ['title', 'context', 'question'],
    )
# Post-processing:
def post_processing_function(features, tokenizer, predictions, stage = "eval"):
# Post-processing: we match the start logits and end logits to answers in the original context.
predictions = postprocess_qa_predictions(
features = features,
tokenizer = tokenizer,
predictions = predictions
)
formatted_predictions = [
{"id": k,
"prediction_text": v,
"no_answer_probability": 0.0
} for k, v in predictions.items()
]
references = [{"id": ft["id"], "answers": ft["answers"]} for ft in features]
return EvalPrediction(predictions = formatted_predictions, label_ids = references)
metric = evaluate.load("squad_v2")
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions = p.predictions,
references = p.label_ids)
data_collator = DefaultDataCollator()
training_args = TrainingArguments(
output_dir = "./results",
evaluation_strategy = 'steps',
learning_rate = 2e-5,
per_device_train_batch_size = 16,
per_device_eval_batch_size = 16,
save_total_limit = 1,
save_steps = 1000,
eval_steps = 1000,
num_train_epochs = 10,
weight_decay = 0.01,
)
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset = tokenized_dataset["train"],
eval_dataset = tokenized_dataset["valid"],
tokenizer = tokenizer,
data_collator = data_collator,
post_process_function=post_processing_function,
compute_metrics = compute_metrics,
)
trainer.train()