combi2k2 commited on
Commit
c47de87
·
1 Parent(s): f0f242a

Add functions used to divide the dataset into train and validation sets, and a post-processing function for the model's predictions

Browse files
Files changed (1) hide show
  1. utils_qa.py +157 -0
utils_qa.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import numpy as np
3
+ import datasets
4
+ import json
5
+
6
+ import os
7
+ from typing import Optional, Tuple
8
+ from tqdm.auto import tqdm
9
+
10
+ # the train data file is expected to have the format of dataset SQUAD v2.0
11
+
12
def load_dataset(dataset_path, split = 0.1, shuffle = True):
    """Load a SQuAD-v2.0-style JSON file and split it into train/valid sets.

    Args:
        dataset_path: Path to a JSON file in the SQuAD v2.0 layout: a
            top-level "data" list of topics, each with a "title" and
            "paragraphs"; each paragraph carries a "context" and a "qas"
            list whose entries hold "id", "question" and "answers".
        split: Fraction of samples assigned to the validation set
            (default 0.1).
        shuffle: If True (default), samples are randomly permuted before
            splitting; otherwise the original order is kept.

    Returns:
        A ``datasets.DatasetDict`` with 'train' and 'valid' splits, each
        containing the columns: id, title, context, question, answers.
    """
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)["data"]

    # Flatten the nested SQuAD structure into parallel column lists.
    dataset = {'id': [],
               'title': [],
               'context': [],
               'question': [],
               'answers': []}

    for topic in data:
        title = topic["title"]
        for p in topic["paragraphs"]:
            for qas in p['qas']:
                dataset['id'].append(qas['id'])
                dataset['title'].append(title)
                dataset['context'].append(p["context"])
                dataset['question'].append(qas["question"])
                dataset['answers'].append(qas["answers"])

    # Since there is no train/validation split beforehand, split manually.
    n_sample = len(dataset['id'])

    # Shuffle (or keep) the sample order before cutting.
    if shuffle:
        perms = np.random.permutation(n_sample)
    else:
        perms = list(range(n_sample))

    # Cut index between train (front) and validation (back). This is the
    # same for every column, so compute it once instead of per column.
    cut = n_sample - int(split * n_sample)

    train_ds = dict()
    valid_ds = dict()
    for name, assets in dataset.items():
        train_ds[name] = [assets[i] for i in perms[:cut]]
        valid_ds[name] = [assets[i] for i in perms[cut:]]

    raw_dataset = datasets.DatasetDict()
    raw_dataset['train'] = datasets.Dataset.from_dict(train_ds)
    raw_dataset['valid'] = datasets.Dataset.from_dict(valid_ds)

    return raw_dataset
54
def postprocess_qa_predictions(
    features,
    tokenizer,
    predictions: Tuple[np.ndarray, np.ndarray],
    n_best_size: int = 20,
    max_answer_length: int = 30
):
    '''
    Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
    original contexts. This is the base postprocessing functions for models that only return start and end logits.
    Args:
        features: The processed dataset (see the main script for more information).
        tokenizer: The tokenizer to decode ids of the answer back to text
        predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
            The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
            first dimension must match the number of elements of :obj:`features`.
        n_best_size (:obj:`int`, `optional`, defaults to 20):
            The total number of n-best predictions to generate when looking for an answer.
        max_answer_length (:obj:`int`, `optional`, defaults to 30):
            The maximum length of an answer that can be generated. This is needed because the start and end predictions
            are not conditioned on one another.
    Returns:
        An :obj:`collections.OrderedDict` mapping each feature's "id" to its
        predicted answer text ("" when the null prediction wins).
    '''
    if len(predictions) != 2: raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
    if len(predictions[0]) != len(features): raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")

    all_start_logits, all_end_logits = predictions
    # The dictionary we have to fill: feature id -> answer text.
    all_predictions = collections.OrderedDict()

    # Let's loop over all the examples!
    for index, feature in enumerate(tqdm(features)):
        prelim_predictions = []

        # The predictions of the model for this feature.
        start_logits = all_start_logits[index]
        end_logits = all_end_logits[index]
        # Hoisted out of the double loop below: invariant per feature.
        seq_len = len(feature['input_ids'])

        # Score of the "no answer" (null) prediction. Its ids (1, 0) have
        # start > end, so it is never decoded into a text span below.
        # NOTE(review): this scores position 1 for start and 0 for end;
        # confirm against how the features were tokenized.
        null_score = start_logits[1] + end_logits[0]
        # The original recomputed this under an always-true `is None`
        # check (it was reset each iteration); build it directly instead.
        min_null_prediction = {"ids": (1, 0), "score": null_score}

        # Go through all possibilities for the `n_best_size` greater start and end logits.
        start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
        end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

        for start_index in start_indexes:
            for end_index in end_indexes:
                # Don't consider out-of-scope answers: indices beyond the
                # tokenized input.
                if start_index >= seq_len or end_index >= seq_len:
                    continue
                # Don't consider answers with a negative length or longer
                # than max_answer_length.
                if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                    continue

                prelim_predictions.append(
                    {
                        "ids": (start_index, end_index),
                        "score": start_logits[start_index] + end_logits[end_index]
                    }
                )

        # Always include the null prediction among the candidates.
        prelim_predictions.append(min_null_prediction)

        # Only keep the best `n_best_size` candidates. Renamed from
        # `predictions`, which shadowed the function argument.
        nbest = sorted(prelim_predictions,
                       key = lambda x: x["score"],
                       reverse = True)[:n_best_size]

        # Add back the null prediction if truncation removed it.
        if not any(p["ids"] == (1, 0) for p in nbest):
            nbest.append(min_null_prediction)

        # Decode the highest-scoring non-null candidate into text.
        best_non_null_pred = None
        for pred in nbest:
            l, r = pred.pop("ids")
            if l <= r:
                pred_input_ids = feature['input_ids'][l: r + 1]
                pred_tokens = tokenizer.convert_ids_to_tokens(pred_input_ids)
                pred["text"] = tokenizer.convert_tokens_to_string(pred_tokens)
                best_non_null_pred = pred
                break

        # Emit "" (no answer) when the null score beats the best span.
        if best_non_null_pred is None or best_non_null_pred["score"] < null_score:
            all_predictions[feature["id"]] = ""
        else:
            all_predictions[feature["id"]] = best_non_null_pred["text"]

    return all_predictions