combi2k2 commited on
Commit
a00ac7f
·
1 Parent(s): 846f8be

Running this file starts training the model

Browse files
Files changed (1) hide show
  1. run_qa.py +152 -0
run_qa.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import numpy as np
3
+ import string
4
+
5
+ import logging
6
+ import json
7
+ import os
8
+ import sys
9
+ import evaluate
10
+
11
+ from dataclasses import dataclass, field
12
+ from typing import Optional
13
+
14
+ from transformers import (
15
+ AutoModelForQuestionAnswering,
16
+ AutoTokenizer,
17
+ EvalPrediction,
18
+ TrainingArguments,
19
+ DefaultDataCollator,
20
+ )
21
+
22
+ from utils_qa import load_dataset
23
+ from utils_qa import postprocess_qa_predictions
24
+
25
+ from trainer_qa import QuestionAnsweringTrainer
26
+
27
# Location of the raw SQuAD-style JSON training data and the multilingual
# checkpoint to fine-tune for extractive question answering.
dataset_path = 'data/train.json'
model_checkpoint = 'xlm-roberta-base'

if __name__ == '__main__':
    # Load the raw dataset which contains context, question and answers.
    # NOTE(review): load_dataset comes from the local utils_qa module, not the
    # `datasets` library — presumably it returns a DatasetDict with "train"
    # and "valid" splits (both are indexed below); verify against utils_qa.
    raw_dataset = load_dataset(dataset_path)

    # Load the pretrained tokenizer and model from huggingface.co
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
+ # Validation preprocessing
39
+ def preprocess_function(examples):
40
+ '''
41
+ help to create a tokenized dataset which should finally be used to train the my question answering model
42
+ '''
43
+ examples['question'] = [q.lstrip() for q in examples['question']]
44
+ tokenized_examples = tokenizer(
45
+ examples['question'],
46
+ examples['context'],
47
+ truncation = "only_second",
48
+ max_length = tokenizer.model_max_length,
49
+ return_offsets_mapping = True,
50
+ padding = "max_length",
51
+ )
52
+ # The offset mappings will give us a map from token to character position in the original context. This will
53
+ # help us compute the start_positions and end_positions.
54
+ offset_mapping = tokenized_examples.pop("offset_mapping")
55
+
56
+ assert(len(offset_mapping) == len(tokenized_examples['input_ids']))
57
+
58
+ # Let's label those examples!
59
+ tokenized_examples["start_positions"] = []
60
+ tokenized_examples["end_positions"] = []
61
+
62
+ for i, offset in enumerate(offset_mapping):
63
+ input_ids = tokenized_examples["input_ids"][i]
64
+ cls_index = input_ids.index(tokenizer.cls_token_id)
65
+
66
+ # Grab the sequence corresponding to that example (to know what is the context and what is the question).
67
+ sequence_ids = tokenized_examples.sequence_ids(i)
68
+ answers = examples['answers'][i]
69
+
70
+ # If no answers are given, set the cls_index as answer.
71
+
72
+ tokenized_examples["start_positions"].append(cls_index)
73
+ tokenized_examples["end_positions"].append(cls_index)
74
+
75
+ if len(answers) == 0:
76
+ continue
77
+
78
+ # Find the start and end of the context
79
+ context_start = sequence_ids.index(1)
80
+ context_end = sequence_ids[context_start:].index(None) + context_start - 1
81
+
82
+ start_char = answers[0]["answer_start"]
83
+ end_char = start_char + len(answers[0]["text"])
84
+
85
+ # If the answer is not fully inside the context, label it (0, 0)
86
+ if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
87
+ continue
88
+
89
+ # Otherwise it's the start and end token positions
90
+ token_start_index = context_start
91
+ token_end_index = context_end
92
+
93
+ while token_start_index < len(offset) and offset[token_start_index][0] <= start_char: token_start_index += 1
94
+ while token_end_index >= 0 and offset[token_end_index][1] >= end_char: token_end_index -= 1
95
+
96
+ tokenized_examples["start_positions"][-1] = token_start_index - 1
97
+ tokenized_examples["end_positions"][-1] = token_end_index + 1
98
+
99
+ return tokenized_examples
100
+
101
    # Create train features from the raw dataset. The 'answers' column is
    # deliberately kept so evaluation post-processing can read the references.
    tokenized_dataset = raw_dataset.map(preprocess_function, batched = True, remove_columns = ['title', 'context', 'question'])
103
+
104
+ # Post-processing:
105
+ def post_processing_function(features, tokenizer, predictions, stage = "eval"):
106
+ # Post-processing: we match the start logits and end logits to answers in the original context.
107
+ predictions = postprocess_qa_predictions(
108
+ features = features,
109
+ tokenizer = tokenizer,
110
+ predictions = predictions
111
+ )
112
+ formatted_predictions = [
113
+ {"id": k,
114
+ "prediction_text": v,
115
+ "no_answer_probability": 0.0
116
+ } for k, v in predictions.items()
117
+ ]
118
+ references = [{"id": ft["id"], "answers": ft["answers"]} for ft in features]
119
+
120
+ return EvalPrediction(predictions = formatted_predictions, label_ids = references)
121
+
122
    # SQuAD v2 metric: exact-match and F1, with support for unanswerable questions.
    metric = evaluate.load("squad_v2")

    def compute_metrics(p: EvalPrediction):
        # p carries the formatted prediction/reference dicts produced by
        # post_processing_function.
        return metric.compute(predictions = p.predictions,
                              references = p.label_ids)
127
+
128
    # Features are already padded to max_length, so the default collator
    # (simple tensor stacking) is sufficient — no dynamic padding needed.
    data_collator = DefaultDataCollator()

    training_args = TrainingArguments(
        output_dir = "./results",
        evaluation_strategy = 'steps',        # evaluate every eval_steps optimizer steps
        learning_rate = 2e-5,
        per_device_train_batch_size = 16,
        per_device_eval_batch_size = 16,
        save_total_limit = 1,                 # keep only the most recent checkpoint
        save_steps = 1000,
        eval_steps = 1000,
        num_train_epochs = 10,
        weight_decay = 0.01,
    )
    # QuestionAnsweringTrainer (local trainer_qa module) extends Trainer with
    # QA-specific evaluation via post_process_function.
    # NOTE(review): assumes raw_dataset/tokenized_dataset has "train" and
    # "valid" splits — verify against load_dataset in utils_qa.
    trainer = QuestionAnsweringTrainer(
        model=model,
        args=training_args,
        train_dataset = tokenized_dataset["train"],
        eval_dataset = tokenized_dataset["valid"],
        tokenizer = tokenizer,
        data_collator = data_collator,
        post_process_function=post_processing_function,
        compute_metrics = compute_metrics,
    )
    trainer.train()