liuyanyi committed on
Commit
265bc01
·
verified ·
1 Parent(s): 1a1dd41

Upload inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +410 -0
inferencer.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import warning
2
+ from typing import List
3
+
4
+ import spacy
5
+ import torch
6
+ import torch.nn as nn
7
+ from nltk.tokenize import sent_tokenize
8
+ from tqdm import tqdm
9
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
10
+
11
+
12
class Inferencer:
    """Batched scoring front-end for an AlignScore sequence-classification checkpoint.

    The wrapped model is expected to expose three outputs per (premise,
    hypothesis) pair: ``reg_label_logits``, ``seq_relationship_logits`` and
    ``tri_label_logits``.  The methods below tokenize text pairs, run the
    model in mini-batches, and aggregate the per-pair scores in several
    styles (plain, sentence-split "SummaC style", and SMART-L / SMART-N).

    Configuration attributes (set after construction):
        smart_type: ``"smart-n"`` or ``"smart-l"``; selects the aggregation
            used by ``smart_doc``.
        smart_n_metric: ``"f1"``, ``"precision"`` or ``"recall"``; selects
            what ``smart_n`` returns.
        nlg_eval_mode: ``None`` or one of ``"nli"``, ``"bin"``, ``"reg"``,
            ``"nli_sp"``, ``"bin_sp"``, ``"reg_sp"``.  It must be set to a
            non-``None`` value before calling ``nlg_eval`` / ``score``.
    """

    # Copied from Alignscore Github

    # Word budget used to decide how many chunks a premise is split into.
    _CHUNK_WORDS = 350
    # Modes answered directly by ``inference`` with a single head.
    _DIRECT_MODES = ("nli", "bin", "reg")
    # Modes that split texts into sentences/chunks before scoring.
    _SPLIT_MODES = ("nli_sp", "bin_sp", "reg_sp")
    _KNOWN_MODES = _DIRECT_MODES + _SPLIT_MODES

    def __init__(
        self,
        ckpt_path,
        batch_size=32,
        device="cuda:0",
        verbose=True,
    ) -> None:
        """Load the model, tokenizer and spaCy pipeline.

        Args:
            ckpt_path: Checkpoint location passed to ``from_pretrained`` for
                both the model and the tokenizer.
            batch_size: Number of text pairs tokenized into one mini-batch.
            device: Device the model is moved to.
            verbose: When false, progress bars are suppressed.

        Raises:
            AssertionError: If the loaded config's ``model_type`` is not
                ``"alignscore"``.
        """
        self.device = device
        # NOTE(review): trust_remote_code=True executes Python code shipped
        # with the checkpoint; only load checkpoints from a trusted source.
        model = AutoModelForSequenceClassification.from_pretrained(
            ckpt_path, trust_remote_code=True
        )
        # Validate before moving the weights to the device so a wrong
        # checkpoint fails fast.
        assert model.config.model_type == "alignscore", (
            "The model type must be alignscore, please check the ckpt_path."
        )
        self.model = model.to(self.device)
        self.model.eval()
        self.batch_size = batch_size

        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
        self.spacy = spacy.load("en_core_web_sm")

        self.softmax = nn.Softmax(dim=-1)

        self.smart_type = "smart-n"
        self.smart_n_metric = "f1"

        self.disable_progress_bar_in_inference = False

        self.nlg_eval_mode = None  # bin, bin_sp, nli, nli_sp, reg, reg_sp
        self.verbose = verbose

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _cartesian_pairs(first, second):
        """Return two parallel lists enumerating every (first, second) pair.

        Pairs are ordered row-major: all of ``second`` for ``first[0]``, then
        all of ``second`` for ``first[1]``, and so on.
        """
        left = [a for a in first for _ in second]
        right = [b for _ in first for b in second]
        return left, right

    def _split_sentences(self, text):
        """Split ``text`` into sentence strings with the spaCy pipeline."""
        return [sent.text for sent in self.spacy(text).sents]

    def _pairwise_matrix(self, row_texts, col_texts):
        """Score every (row, col) text pair and return a 2-D score tensor.

        Uses column 0 of the third output of ``inference`` and reshapes it to
        ``(len(row_texts), len(col_texts))``.
        """
        left, right = self._cartesian_pairs(row_texts, col_texts)
        scores = self.inference(left, right)[2][:, 0]
        return scores.view(len(row_texts), len(col_texts))

    @staticmethod
    def _soft_lcs(score_matrix):
        """Run the SMART-L dynamic programme over a 2-D score matrix.

        Returns the bottom-right cell of the DP table (not yet normalised).
        """
        n_rows, n_cols = score_matrix.shape
        # Each row must be its own list: ``[[0] * k] * n`` would alias one
        # row n times and corrupt the table.
        table = [[0.0] * (n_cols + 1) for _ in range(n_rows + 1)]
        for i in range(1, n_rows + 1):
            for j in range(1, n_cols + 1):
                match = score_matrix[i - 1, j - 1]
                table[i][j] = max(
                    table[i - 1][j - 1] + match,
                    table[i - 1][j] + match,
                    table[i][j - 1],
                )
        return table[-1][-1]

    @staticmethod
    def _ngram_overlap(score_matrix, n_gram):
        """Average, over column windows, of the best-matching row window.

        For every window of ``n_gram`` consecutive columns, take the maximum
        over row windows of the mean diagonal score, then average over the
        column windows.  Returns ``0.0`` when either side has no window.
        """
        n_rows, n_cols = score_matrix.shape
        row_windows = n_rows - n_gram + 1
        col_windows = n_cols - n_gram + 1
        if row_windows <= 0 or col_windows <= 0:
            return 0.0
        total = sum(
            max(
                sum(score_matrix[i + n, j + n] / n_gram for n in range(n_gram))
                for i in range(row_windows)
            )
            for j in range(col_windows)
        )
        return total / col_windows

    @classmethod
    def _smart_n_scores(cls, forward_matrix, backward_matrix, n_gram):
        """Return ``(precision, recall, f1)`` for SMART-N.

        ``forward_matrix`` is (premise sentences x hypothesis sentences);
        ``backward_matrix`` has the two roles swapped.
        """
        prec = cls._ngram_overlap(forward_matrix, n_gram)
        recall = cls._ngram_overlap(backward_matrix, n_gram)
        denom = prec + recall
        # Guard the harmonic mean against a zero denominator.
        f1 = 2 * prec * recall / denom if denom > 0 else 0.0
        return prec, recall, f1

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def inference_example_batch(self, premise: list, hypo: list):
        """Score aligned lists of premises and hypotheses, one example at a time.

        Each pair is scored with ``inference_per_example`` (sentence-split,
        "SummaC style" aggregation).  The inner per-batch progress bar is
        suppressed while this runs and restored afterwards.

        Returns:
            ``(None, scores, None)`` where ``scores`` is a 1-D tensor with one
            value per example.

        Raises:
            AssertionError: If the two lists differ in length.
        """
        assert len(premise) == len(hypo), (
            "Premise must has the same length with Hypothesis!"
        )

        previous = self.disable_progress_bar_in_inference
        self.disable_progress_bar_in_inference = True
        try:
            out_score = [
                self.inference_per_example(one_pre, one_hypo)
                for one_pre, one_hypo in tqdm(
                    zip(premise, hypo),
                    desc="Evaluating",
                    total=len(premise),
                    disable=(not self.verbose),
                )
            ]
        finally:
            # Do not leak the "silent" flag into later direct inference calls.
            self.disable_progress_bar_in_inference = previous

        return None, torch.tensor(out_score), None

    def inference_per_example(self, premise: str, hypo: str):
        """Score one (premise, hypothesis) pair with sentence-level splitting.

        The premise is sentence-tokenized and regrouped into chunks (the
        chunk count is derived from a 350-word budget); the hypothesis is
        sentence-tokenized.  Every (chunk, sentence) pair is scored, the
        maximum over chunks is taken per hypothesis sentence, and those
        maxima are averaged.

        The head used depends on ``self.nlg_eval_mode``: ``None`` or
        ``"nli_sp"`` uses column 0 of the 3-way output, ``"bin_sp"`` the
        binary output and ``"reg_sp"`` the regression output.

        Returns:
            A Python float.

        Raises:
            ValueError: If ``nlg_eval_mode`` is set to a mode this method
                cannot serve (previously an ``UnboundLocalError``).
        """
        mode = self.nlg_eval_mode
        if mode is not None and mode not in self._SPLIT_MODES:
            raise ValueError(
                f"nlg_eval_mode {mode!r} is not supported by "
                f"inference_per_example; use None or one of {self._SPLIT_MODES}"
            )

        premise_sents = sent_tokenize(premise) or [""]

        # First the number of chunks from the word budget, then how many
        # sentences go into each chunk (at least one).
        n_chunks = len(premise.strip().split()) // self._CHUNK_WORDS + 1
        sents_per_chunk = max(len(premise_sents) // n_chunks, 1)
        premise_chunks = [
            " ".join(group) for group in self.chunks(premise_sents, sents_per_chunk)
        ]

        hypo_sents = sent_tokenize(hypo)

        premise_col, hypo_col = self._cartesian_pairs(premise_chunks, hypo_sents)
        reg_scores, bin_scores, tri_scores = self.inference(premise_col, hypo_col)

        ### use NLI head OR ALIGN head
        if mode == "bin_sp":
            pair_scores = bin_scores
        elif mode == "reg_sp":
            pair_scores = reg_scores
        else:
            pair_scores = tri_scores[:, 0]

        ### sum or mean depends on the task/aspect
        return (
            pair_scores.view(len(premise_chunks), len(hypo_sents))
            .max(dim=0)
            .values.mean()
            .item()
        )

    def inference(self, premise, hypo):
        """Run the model on parallel premise/hypothesis lists (or two strings).

        Returns:
            * ``nlg_eval_mode`` in ``("nli", "bin", "reg")``:
              ``(None, scores, None)`` with the selected head only.
            * otherwise (``None`` or a ``*_sp`` mode):
              ``(reg, bin, tri)`` where ``reg`` is column 0 of the regression
              logits, ``bin`` is column 1 of the softmaxed binary logits and
              ``tri`` is the full softmaxed 3-way output.

        Raises:
            ValueError: If ``nlg_eval_mode`` is not ``None`` and not one of
                the six known modes.
        """
        mode = self.nlg_eval_mode
        # The ``*_sp`` modes must keep falling through to the full triple:
        # ``inference_per_example`` indexes into it.
        if mode is not None and mode not in self._KNOWN_MODES:
            raise ValueError("unrecognized nlg eval mode")

        if isinstance(premise, str) and isinstance(hypo, str):
            premise = [premise]
            hypo = [hypo]

        reg_parts = []
        bin_parts = []
        tri_parts = []

        for mini_batch in tqdm(
            self.batch_tokenize(premise, hypo),
            desc="Evaluating",
            disable=not self.verbose or self.disable_progress_bar_in_inference,
        ):
            mini_batch = mini_batch.to(self.model.device)
            with torch.no_grad():
                model_output = self.model(**mini_batch)
                reg_out = model_output.reg_label_logits.cpu()
                # Temperature Scaling / 2.5
                bin_out = self.softmax(model_output.seq_relationship_logits).cpu()
                tri_out = self.softmax(model_output.tri_label_logits).cpu()
            reg_parts.append(reg_out[:, 0])
            bin_parts.append(bin_out[:, 1])
            tri_parts.append(tri_out)

        output_score_reg = torch.cat(reg_parts)
        output_score_bin = torch.cat(bin_parts)
        output_score_tri = torch.cat(tri_parts)

        if mode == "nli":
            return None, output_score_tri[:, 0], None
        if mode == "bin":
            return None, output_score_bin, None
        if mode == "reg":
            return None, output_score_reg, None

        return output_score_reg, output_score_bin, output_score_tri

    def inference_reg(self, premise, hypo):
        """Return the flattened raw ``seq_relationship_logits`` for each pair.

        Sets ``self.model.is_reg_finetune = True`` as a side effect before
        running the model.

        Returns:
            A 1-D tensor of concatenated logits.
        """
        self.model.is_reg_finetune = True
        if isinstance(premise, str) and isinstance(hypo, str):
            premise = [premise]
            hypo = [hypo]

        output_score = []
        for mini_batch in tqdm(
            self.batch_tokenize(premise, hypo),
            desc="Evaluating",
            # Same rule as ``inference`` so ``verbose=False`` is honoured.
            disable=not self.verbose or self.disable_progress_bar_in_inference,
        ):
            mini_batch = mini_batch.to(self.model.device)
            with torch.no_grad():
                logits = self.model(**mini_batch).seq_relationship_logits.cpu().view(-1)
            output_score.append(logits)
        return torch.cat(output_score)

    def batch_tokenize(self, premise, hypo):
        """Tokenize parallel premise/hypothesis lists into mini-batches.

        Each mini-batch holds up to ``self.batch_size`` pairs, padded to the
        tokenizer's ``model_max_length``.  Truncation first tries to shorten
        only the premise; if the tokenizer rejects that, it falls back to the
        tokenizer's default truncation strategy and logs a warning.

        Returns:
            A list of tokenizer outputs (PyTorch tensors), one per mini-batch.

        Raises:
            AssertionError: If the inputs are not lists of equal length.
        """
        assert isinstance(premise, list) and isinstance(hypo, list)
        assert len(premise) == len(hypo), (
            "premise and hypo should be in the same length."
        )

        shared_kwargs = {
            "padding": "max_length",
            "max_length": self.tokenizer.model_max_length,
            "return_tensors": "pt",
        }

        batch = []
        for mini_batch_pre, mini_batch_hypo in zip(
            self.chunks(premise, self.batch_size), self.chunks(hypo, self.batch_size)
        ):
            try:
                mini_batch = self.tokenizer(
                    mini_batch_pre,
                    mini_batch_hypo,
                    truncation="only_first",
                    **shared_kwargs,
                )
            except Exception:
                # Not a bare ``except``: KeyboardInterrupt / SystemExit must
                # still propagate.
                warning("text_b too long...")
                mini_batch = self.tokenizer(
                    mini_batch_pre,
                    mini_batch_hypo,
                    truncation=True,
                    **shared_kwargs,
                )
            batch.append(mini_batch)

        return batch

    def smart_doc(self, premise: list, hypo: list):
        """Score aligned lists of documents with SMART-style aggregation.

        Uses ``smart_l`` when ``self.smart_type == "smart-l"`` and
        ``smart_n`` otherwise.  The inner per-batch progress bar is suppressed
        while this runs and restored afterwards.

        Returns:
            ``(None, scores, None)`` where ``scores`` is a 1-D tensor with one
            value per example.

        Raises:
            AssertionError: On length mismatch or an unknown ``smart_type``.
        """
        assert len(premise) == len(hypo), (
            "Premise must has the same length with Hypothesis!"
        )
        assert self.smart_type in ["smart-n", "smart-l"]

        scorer = self.smart_l if self.smart_type == "smart-l" else self.smart_n

        previous = self.disable_progress_bar_in_inference
        self.disable_progress_bar_in_inference = True
        try:
            out_score = [
                scorer(one_pre, one_hypo)[1]
                for one_pre, one_hypo in tqdm(
                    zip(premise, hypo),
                    desc="Evaluating SMART",
                    total=len(premise),
                    disable=(not self.verbose),
                )
            ]
        finally:
            self.disable_progress_bar_in_inference = previous

        return None, torch.tensor(out_score), None

    def smart_l(self, premise, hypo):
        """SMART-L score for one (premise, hypothesis) pair.

        Both texts are split into sentences with spaCy, every sentence pair
        is scored, and a soft longest-common-subsequence DP is run over the
        score matrix.  The result is normalised by the number of premise
        sentences.

        Returns:
            ``(None, score, None)``.
        """
        premise_sents = self._split_sentences(premise)
        hypo_sents = self._split_sentences(hypo)

        score_matrix = self._pairwise_matrix(premise_sents, hypo_sents)

        ### smart-l
        return None, self._soft_lcs(score_matrix) / len(premise_sents), None

    def smart_n(self, premise, hypo):
        """SMART-N (n_gram = 1) score for one (premise, hypothesis) pair.

        Precision is computed from the (premise x hypothesis) sentence score
        matrix and recall from the matrix with the roles swapped; F1 is their
        harmonic mean.  ``self.smart_n_metric`` selects which one is returned.

        Returns:
            ``(None, value, None)``.

        Raises:
            ValueError: If ``smart_n_metric`` is not ``"f1"``,
                ``"precision"`` or ``"recall"``.
        """
        metric = self.smart_n_metric
        if metric not in ("f1", "precision", "recall"):
            raise ValueError("SMART return type error")

        ### smart-n
        n_gram = 1

        # Parse each text once and reuse the sentences for both directions.
        premise_sents = self._split_sentences(premise)
        hypo_sents = self._split_sentences(hypo)

        forward = self._pairwise_matrix(premise_sents, hypo_sents)
        backward = self._pairwise_matrix(hypo_sents, premise_sents)

        prec, recall, f1 = self._smart_n_scores(forward, backward, n_gram)

        if metric == "f1":
            return None, f1, None
        if metric == "precision":
            return None, prec, None
        return None, recall, None

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for start in range(0, len(lst), n):
            yield lst[start : start + n]

    def nlg_eval(self, premise, hypo):
        """Dispatch to the scorer selected by ``self.nlg_eval_mode``.

        ``"bin"``, ``"nli"`` and ``"reg"`` go to ``inference``; the ``*_sp``
        variants go to ``inference_example_batch``.

        Returns:
            ``(None, scores, None)``.

        Raises:
            AssertionError: If ``nlg_eval_mode`` is ``None``.
            ValueError: If ``nlg_eval_mode`` is not a known mode.
        """
        assert self.nlg_eval_mode is not None, "Select NLG Eval mode!"
        mode = self.nlg_eval_mode

        if mode in self._DIRECT_MODES:
            return self.inference(premise, hypo)
        if mode in self._SPLIT_MODES:
            return self.inference_example_batch(premise, hypo)

        raise ValueError("Unrecognized NLG Eval mode!")

    # COPIED from Alignscore class
    def score(self, contexts: List[str], claims: List[str]) -> List[float]:
        """Return one score per (context, claim) pair as a plain list.

        ``self.nlg_eval_mode`` must be set beforehand; see ``nlg_eval``.
        """
        return self.nlg_eval(contexts, claims)[1].tolist()