krotima1 commited on
Commit
6885f5c
·
1 Parent(s): 3dd0723

Add AlignScoreCS.py: an easy-to-use transformer model class

Browse files
Files changed (1) hide show
  1. AlignScoreCS.py +634 -0
AlignScoreCS.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import transformers
# nltk's sent_tokenize is required at runtime by score(), chunk_inputs(),
# chunk_text() and chunk_sentences(); without it those methods raise NameError.
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from transformers import PretrainedConfig
12
+
13
class AlignScoreCS(transformers.XLMRobertaModel):
    """
    AlignScoreCS: factual-consistency (alignment) scoring model.

    Description:
        Trained following the AlignScore paper for 3 days on 4 GPUs
        (3 epochs, 1e-5 learning rate, 1e-6 AdamW eps, batch size 32,
        warmup ratio 0.06, 0.1 weight decay).
        - XLM-RoBERTa model with 3 classification heads {regression, binary,
          3-way} sharing a single encoder.

    Usage (AlignScoreCS.py):
        - from_pretrained - loads the model, usable like a transformers model
        - .score(context, claim) - returns the probability of the ALIGNED
          class using the 3-way classification head, as in the paper.

        alignScoreCS = AlignScoreCS.from_pretrained("krotima1/AlignScoreCS")
        alignScoreCS.score(context, claim)

    To try a different classification head use the task_name parameter:
        - task_name = "re"   : regression head
        - task_name = "bin"  : binary classification head
        - task_name = "3way" : 3-way classification head
    """
    # Sub-directory names under which each task-specific head is stored
    # (task name is the part before "_", i.e. "re" / "bin" / "3way").
    _regression_model = "re_model"
    _binary_class_model = "bin_model"
    _3way_class_model = "3way_model"
39
+
40
    def __init__(self, encoder, taskmodels_dict, model_name= "xlm-roberta-large", **kwargs):
        # encoder: shared XLM-RoBERTa encoder used by every task head.
        # taskmodels_dict: mapping task name ("re" | "bin" | "3way") ->
        #   XLMRobertaForSequenceClassification head sharing `encoder`.
        # model_name: tokenizer/base checkpoint name; tokenizer is loaded lazily.
        super().__init__(transformers.XLMRobertaConfig(), **kwargs)
        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)
        # Tokenizer and paper-style inferencer are created on first use
        # (see init_inferencer()).
        self.tokenizer = None
        self.model_name = model_name
        self.inferencer = None
47
+
48
+ def init_inferencer(self, device = "cuda"):
49
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
50
+ self.inferencer = self.InferenceHandler(self, self.tokenizer, device)
51
+
52
+
53
+
54
+ """
55
+ Score: scores the context and claim with Aligned probabitlity of 3way classification head
56
+ - using paper code inferencer from ALignScore
57
+
58
+ """
59
+ def score(self, context, claim, **kwargs):
60
+ if self.inferencer is None:
61
+ self.init_inferencer()
62
+ scores = self.inferencer.nlg_eval(context, claim)
63
+ return scores
64
+
65
+ """
66
+ Score: scores the context and claim with ALIGNED probability (wrt task_name ["re" | "bin" | "3way"])
67
+
68
+ Returns the probability of the ALIGNED CLASS between context text and claim text
69
+ - chunks text by 350 tokens and splits claim into sentences
70
+ - using 3way classification head
71
+ """
72
    def score_sentences(self, context :str, claim :str, task_name = "3way", batch_size = 2, return_all_outputs = False, **kwargs):
        """
        Score context vs. claim with the ALIGNED probability (wrt task_name
        ["re" | "bin" | "3way"]).

        The context is chunked into ~350-token blocks and the claim is split
        into sentence chunks; every (claim chunk, context chunk) pair is
        scored, and alignscore_input() aggregates: max over context chunks,
        mean over claim chunks.

        Returns a 1-element tensor with the score, or a dict with the score
        and the raw head outputs when return_all_outputs is True.
        """
        # Tokenizer is loaded lazily on first use.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
        chunked_inputs = self.chunk_sent_input(context,claim, chunk_size=350,chunk_claim_size=150)
        nclaims, ncontexts = (chunked_inputs["n_claims"],chunked_inputs["n_contexts"])
        with torch.no_grad():
            # Drop the "n_*" bookkeeping entries and move tensors to the model device.
            chunked_inputs = {key : torch.tensor(item).to(self.device) for key, item in chunked_inputs.items() if not key.startswith("n_")}
            chunked_outputs = {}
            # Run all (claim, context) pairs through the selected head in mini-batches.
            for i in range(0,len(chunked_inputs["input_ids"]),batch_size):
                tmp = self.forward(task_name = task_name,**{"input_ids":chunked_inputs["input_ids"][i:i+batch_size],"attention_mask" :chunked_inputs["attention_mask"][i:i+batch_size]}, **kwargs)
                for k, item in tmp.items():
                    chunked_outputs[k] = chunked_outputs.get(k, []) + [item]
            logits = torch.vstack(chunked_outputs["logits"]).cpu()
            outputs = {"score" : self.alignscore_input(logits,nclaims=nclaims,ncontexts=ncontexts, task_name=task_name)}
            outputs["outputs"] = chunked_outputs
            return torch.tensor([outputs["score"]]) if not return_all_outputs else outputs
87
+
88
+
89
+ """
90
+ Score: scores the context and claim with ALIGNED probability (wrt task_name ["re" | "bin" | "3way"])
91
+
92
+ Returns the probability of the ALIGNED CLASS between context text and claim text
93
+ - chunks text into 350 tolens and chunks claim into 150 tokens
94
+ - using 3way classification head
95
+ """
96
    def score_chunks(self, context :str, claim :str, task_name = "3way", batch_size = 2, return_all_outputs = False, **kwargs):
        """
        Score context vs. claim with the ALIGNED probability (wrt task_name
        ["re" | "bin" | "3way"]).

        The context is chunked sentence-wise via chunk_inputs() so each chunk
        plus the full claim fits the model; the maximum ALIGNED probability
        over the chunks is returned (alignscore_input_deprecated()).
        NOTE(review): all chunks are run in a single forward pass
        (batch_size is currently unused) and without torch.no_grad().
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
        chunked_inputs = self.chunk_inputs(context,claim, chunk_size=350)
        chunked_inputs = {key : torch.tensor(item).to(self.device) for key, item in chunked_inputs.items()}
        chunked_outputs = self.forward(task_name = task_name, **chunked_inputs, **kwargs)
        outputs = {"score" : self.alignscore_input_deprecated(chunked_outputs.logits.cpu(), task_name=task_name)}
        outputs["outputs"] = chunked_outputs
        return outputs["score"] if not return_all_outputs else outputs
104
+
105
+ """
106
+ Classify: classify the context and claim to the class label given the task_name ["re" | "bin" | "3way"]
107
+
108
+ Returns the class of {Neutral, contradict, aligned} between context text and claim text
109
+ - using 3way classification head
110
+ """
111
+ def classify(self, context :str, claim :str, task_name = "3way", return_all_outputs = False, **kwargs):
112
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
113
+ chunked_inputs = self.chunk_inputs(context,claim, chunk_size=350)
114
+ chunked_inputs = {key : torch.tensor(item).to(self.device) for key, item in chunked_inputs.items()}
115
+ chunked_outputs = self.forward(task_name = task_name, **chunked_inputs, **kwargs)
116
+ outputs = {"class" : self.get_system_label(chunked_outputs.logits.cpu(), task_name=task_name)}
117
+ outputs["outputs"] = chunked_outputs
118
+ return outputs["class"] if not return_all_outputs else outputs
119
+
120
+
121
+ def score_truncated(self, context :str, claim :str, task_name = "3way", return_all_outputs = False, **kwargs):
122
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
123
+ tokenized_inputs = self.tokenizer(list(zip([context], [claim])), padding = "max_length", truncation = True, max_length = 512, return_tensors="pt")
124
+ tokenized_inputs = {key : torch.tensor(item).to(self.device) for key, item in tokenized_inputs.items()}
125
+ with torch.no_grad():
126
+ model_outputs = self.forward(task_name=task_name, **tokenized_inputs, **kwargs)
127
+ outputs = {"score" : self.alignscore_input(model_outputs["logits"].cpu(),nclaims=1, ncontexts=1, task_name=task_name)}
128
+ outputs["outputs"] = model_outputs
129
+ return torch.tensor([outputs["score"]]) if not return_all_outputs else outputs
130
+
131
+ def forward(self, task_name = "3way", **kwargs):
132
+ return self.taskmodels_dict[task_name](**kwargs)
133
+
134
+ def __call__(self, task_name, **kwargs):
135
+ return self.taskmodels_dict[task_name](**kwargs)
136
+
137
+ """
138
+ Get the probability of the ALIGNED label from input
139
+ """
140
+ def alignscore_input(self, chunked_logits, nclaims, ncontexts, task_name = "3way"):
141
+ if task_name == "re":
142
+ ouptuts = chunked_logits.detach()
143
+ # Reshape the tensor to separate each block of n rows
144
+ reshaped_tensor = ouptuts.view(nclaims, ncontexts)
145
+
146
+ # Extract the maximum values from the first column (index 0) within each block of n rows
147
+ max_values, _ = reshaped_tensor.max(dim=1)
148
+
149
+ # Calculate the mean of the max values for each block of n rows
150
+ mean_of_maxes = torch.mean(max_values, dim=0)
151
+ return mean_of_maxes.tolist()
152
+ else:
153
+ nlabels = {"3way" : 3, "re" : 1, "2way" : 2}[task_name]
154
+ ouptuts = chunked_logits.softmax(1).detach()
155
+ # Reshape the tensor to separate each block of n rows
156
+ reshaped_tensor = ouptuts.view(nclaims, ncontexts, nlabels)
157
+
158
+ # Extract the maximum values from the first column (index 0) within each block of n rows
159
+ max_values, _ = torch.max(reshaped_tensor[:, :, 1], dim=1)
160
+
161
+ # Calculate the mean of the max values for each block of n rows
162
+ mean_of_maxes = torch.mean(max_values, dim=0)
163
+ return mean_of_maxes.tolist()
164
+
165
+
166
+ def alignscore_input_deprecated(self, chunked_logits, task_name = "3way"):
167
+ if task_name == "re":
168
+ return chunked_logits.detach().amax(0).tolist()
169
+ else:
170
+ return chunked_logits.softmax(1).detach()[:, 1].amax(0).tolist() # return max probability over the ALIGNED class
171
+
172
+
173
+ """
174
+ get the label from the input
175
+ """
176
+ def get_system_label(self, chunked_logits, task_name):
177
+ if task_name == "re":
178
+ return (chunked_logits.sum(0) / chunked_logits.size()[0]).detach().tolist()
179
+ else:
180
+ avg_probs = chunked_logits.softmax(1).sum(0) / chunked_logits.size()[0]
181
+ numpy_array = chunked_logits.softmax(1).argmax(1).detach().numpy()
182
+ # Calculate the frequencies of each value
183
+ unique_values, counts = np.unique(numpy_array, return_counts=True)
184
+ # Find the maximum count
185
+ max_count = np.max(counts)
186
+ # Find all values with the maximum count
187
+ most_frequent_values = unique_values[counts == max_count]
188
+ return most_frequent_values[0] if most_frequent_values.size == 1 else avg_probs.detach().argmax().tolist()
189
+
190
+ """
191
+ Chunks input context and claim - context is chunked into 350 tokens
192
+ - claim is chunked into sentences
193
+ - using stride for overflowing tokens
194
+ """
195
+ def chunk_sent_input(self, context, claim, max_length = 512, chunk_size = 350, chunk_claim_size = 150):
196
+ assert chunk_size <= max_length, "Chunk size {} cannot be greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
197
+ chunk_claim_size = max_length - chunk_size if chunk_claim_size is None else chunk_claim_size
198
+ assert chunk_size + chunk_claim_size <= max_length, "Chunk size {} and Chunk claim size {} cannot be together greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
199
+ return_chunked_inputs = {}
200
+ context_chunks = self.chunk_text(context, chunk_size=chunk_size, overflowing_tokens_stride = 25, first_special_token=[0])
201
+ claim_chunks = self.chunk_sentences(claim, chunk_size=chunk_claim_size,overflowing_tokens_stride=int(chunk_claim_size/3), first_special_token=[2])
202
+ for claim_chunk in claim_chunks:
203
+ for context_chunk in context_chunks:
204
+ inputs,attention =self.fill_with_pad_tokens(context_chunk,claim_chunk )
205
+ return_chunked_inputs["input_ids"] = return_chunked_inputs.get("input_ids",[]) + [inputs]
206
+ return_chunked_inputs["attention_mask"] = return_chunked_inputs.get("attention_mask",[]) + [attention]
207
+ return_chunked_inputs["n_claims"] = len(claim_chunks)
208
+ return_chunked_inputs["n_contexts"] = len(context_chunks)
209
+ return return_chunked_inputs
210
+
211
+ """
212
+ According to paper - chunk the text into smaller parts (350tokens + claim_tokens) when the tokenized inputs exceed the max_length
213
+ returns chunked input
214
+ """
215
    def chunk_inputs(self, context, claim, max_length = 512, chunk_size = 512, first_fit_within_max_length = True):
        """
        Paper-style chunking: split the tokenized context into sentence-aligned
        chunks so that each chunk plus the full claim fits within chunk_size
        (and never exceeds max_length).

        Returns {"input_ids": [...], "attention_mask": [...]} ready for the
        model.  Requires nltk's sent_tokenize (Czech sentence splitting).
        """
        assert chunk_size <= max_length, "Chunk size {} cannot be greater than max size {}".format(chunk_size, max_length)

        tokenized_claim = self.tokenizer(claim, return_length=True)
        tokenized_claim["input_ids"][0] = 2 # </s> token according to pair tokenization where the separator of the context and claim is </s></s>
        tokenized_context = self.tokenizer(context, return_length = True)
        assert tokenized_claim["length"][0] < max_length*4/5, "Create chunks of claim sentences. Claim is too long {} which is more than 4/5 from {}.".format(tokenized_claim["length"][0], max_length)

        # set chunk size to incorporate the claim size as well
        chunk_size = min(max_length, chunk_size + tokenized_claim["length"][0])

        first_check_max_size = max_length if first_fit_within_max_length else chunk_size

        if tokenized_claim["length"][0] + tokenized_context["length"][0] <= first_check_max_size: #if it fits within max_length
            input_ids, attention_mask = self.fill_with_pad_tokens(tokenized_context["input_ids"],tokenized_claim["input_ids"])
            return {"input_ids" : [input_ids], "attention_mask" : [attention_mask]}
        else: # make chunks
            return_chunked_inputs = {}
            current_chunk = {}
            # Greedily pack whole sentences into the current chunk; flush a
            # chunk whenever adding the next sentence (plus the claim) would
            # overflow chunk_size.
            for sentence in sent_tokenize(context, language="czech"):
                tok_sent = self.tokenizer(sentence, return_length=True)
                if len(current_chunk.get("input_ids",[0])) + tok_sent["length"][0] - 1 + tokenized_claim["length"][0] <= chunk_size:
                    # Strip the sentence's <s>/</s> and append its tokens.
                    current_chunk["input_ids"] = current_chunk.get("input_ids",[0]) + tok_sent["input_ids"][1:-1]
                else:
                    return_chunked_inputs = self._update_chunked_inputs(tokenized_claim, current_chunk, return_chunked_inputs, max_length, tok_sent)
                    # Start a new chunk: <s> followed by this sentence's tokens.
                    current_chunk["input_ids"] = [0] + tok_sent["input_ids"][1:-1]
            if current_chunk != {}: # add the rest
                return_chunked_inputs = self._update_chunked_inputs(tokenized_claim, current_chunk, return_chunked_inputs, max_length)
                current_chunk = {}
            return return_chunked_inputs
245
+
246
+ """
247
+ Chunks input context and claim - context is chunked into 350 tokens
248
+ - claim is chunked into 150 tokens
249
+ - using stride for overflowing tokens
250
+ """
251
+ def chunk_input_deprecated(self, context, claim, max_length = 512, chunk_size = 350, chunk_claim_size = 150):
252
+ assert chunk_size <= max_length, "Chunk size {} cannot be greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
253
+ chunk_claim_size = max_length - chunk_size if chunk_claim_size is None else chunk_claim_size
254
+ assert chunk_size + chunk_claim_size <= max_length, "Chunk size {} and Chunk claim size {} cannot be together greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
255
+ return_chunked_inputs = {}
256
+ context_chunks = self.chunk_text(context, chunk_size=chunk_size, overflowing_tokens_stride = 25, first_special_token=[0])
257
+ claim_chunks = self.chunk_text(claim, chunk_size=chunk_claim_size,overflowing_tokens_stride=int(chunk_claim_size/3), first_special_token=[2])
258
+ for claim_chunk in claim_chunks:
259
+ for context_chunk in context_chunks:
260
+ inputs,attention =self.fill_with_pad_tokens(context_chunk,claim_chunk )
261
+ return_chunked_inputs["input_ids"] = return_chunked_inputs.get("input_ids",[]) + [inputs]
262
+ return_chunked_inputs["attention_mask"] = return_chunked_inputs.get("attention_mask",[]) + [attention]
263
+ return_chunked_inputs["n_claims"] = len(claim_chunks)
264
+ return_chunked_inputs["n_contexts"] = len(context_chunks)
265
+ return return_chunked_inputs
266
+
267
+
268
+ """
269
+ Chunk texts into blocks of chunk_size tokens
270
+
271
+ """
272
    def chunk_text(self, text, chunk_size = 350, overflowing_tokens_stride = 25, language="czech", first_special_token = [0]):
        """
        Chunk text into token-id blocks of at most ~chunk_size tokens, keeping
        whole sentences together where possible.

        Each chunk starts with first_special_token and ends with </s> (id 2).
        Sentences longer than chunk_size are split into overlapping blocks
        (overflowing_tokens_stride tokens of overlap between blocks).
        NOTE: the mutable default for first_special_token is never mutated.
        """
        sentences = sent_tokenize(text, language=language)
        tokenized = self.tokenizer(sentences if sentences != [] else [""], return_length=True)
        chunks = []
        chunk, current_chunk_size = ([], 0)
        for i, length in enumerate(tokenized["length"]):

            # Wrap the tokenized sentence into a list so overflowing splits
            # can be handled uniformly below.
            # Case: a single sentence is longer than the chunk size - split it
            # into overlapping blocks of chunk_size.
            if length > chunk_size:
                splits = [first_special_token + tokenized["input_ids"][i][max(1,cs):min(cs + chunk_size - 2, length - 1)] + [2] for cs in range(0, length , chunk_size-(2+overflowing_tokens_stride))]
            # Case: sentence fits within the chunk size - keep it whole.
            else:
                splits = [first_special_token + tokenized["input_ids"][i][1:]]

            # Pack each sentence (or sentence split) into the running chunk.
            for subsentence in splits:
                up_length = len(subsentence) - 2

                # Case: current chunk is empty - start it with this subsentence.
                if current_chunk_size == 0:
                    current_chunk_size = up_length + 2 # First include <s> and </s> tokens
                    chunk = subsentence[:-1]
                # Case: subsentence still fits - append it (without its
                # leading special token and trailing </s>).
                elif current_chunk_size + up_length <= chunk_size:
                    current_chunk_size += up_length
                    chunk += subsentence[1:-1]
                # Case: subsentence would overflow - close the current chunk
                # and start a new one.
                else:
                    chunks += [chunk + [2]]
                    current_chunk_size = up_length + 2 # First include <s> and </s> tokens
                    chunk = subsentence[:-1]
        # Flush the last (unsaved) chunk.
        if chunk != []:
            chunks += [chunk + [2]]
        return chunks
310
+
311
+ """
312
+ Chunks text into sentences using nlt.sent_tokenize
313
+ """
314
+ def chunk_sentences(self, text, chunk_size, overflowing_tokens_stride = 0, language="czech", sentence_window = 2, first_special_token = [2]):
315
+ sentences = sent_tokenize(text, language=language)
316
+ tokenized = self.tokenizer(sentences if sentences != [] else [""], return_length=True)
317
+ chunks = []
318
+ current_chunk = []
319
+ for i, length in enumerate(tokenized["length"]):
320
+ # WRAP THE TOKENIZED SENTNECE INTO LIST TO HANDLE OVERFLOWING TOKENS EASILY
321
+ # Case when length of one sentence is longer than the chunk size - split the sentence into chunks of chunk size
322
+ if length > chunk_size:
323
+ splits = [first_special_token + tokenized["input_ids"][i][max(1,cs):min(cs + chunk_size - 2, length - 1)] + [2] for cs in range(0, length , chunk_size-(2+overflowing_tokens_stride))]
324
+ # Case when lenght of sequence is equal or smaller than the chunk size - only continue
325
+ else:
326
+ splits = [first_special_token + tokenized["input_ids"][i][1:]]
327
+
328
+ #Go through sentence or parts of sentence and create chunks
329
+ for split in splits:
330
+ chunks += [split]
331
+ # if len(current_chunk) == sentence_window:
332
+ # chunks += [first_special_token + [item for row in current_chunk for item in row] + [2]]
333
+ # current_chunk = current_chunk[1:] + [split[1:-1]]
334
+ # else:
335
+ # current_chunk += [split[1:-1]]
336
+
337
+ # if chunks == []:
338
+ # chunks += [first_special_token + [item for row in current_chunk for item in row] + [2]]
339
+ return chunks
340
+
341
+ """
342
+ join context and claim tokens as input_ids and create attention_mask
343
+ """
344
+ def fill_with_pad_tokens(self, first, second, max_length=512, pad_token = 1):
345
+ return first + second + [pad_token]*max(max_length-len(first)-len(second),0), [1]*(len(first)+len(second)) + [0]*max(max_length-len(first)-len(second),0)
346
+
347
+
348
    def _update_chunked_inputs(self, tokenized_claim, current_chunk, return_chunked_inputs, max_length, tok_sent = {"input_ids" : []}):
        # Finalize `current_chunk`: close it with </s>, append the claim
        # tokens, pad to max_length, and store the resulting
        # input_ids/attention_mask into return_chunked_inputs (returned
        # updated).  tok_sent is the sentence that overflowed the chunk.
        # NOTE: the mutable default for tok_sent is safe - it is never mutated.
        # truncate if there is a long sentence (rare occurrences)
        if len(current_chunk.get("input_ids",[0])) + tokenized_claim["length"][0] >= max_length:
            chunk = current_chunk["input_ids"].copy()[:max_length-tokenized_claim["length"][0]-1] + [2]
        elif not current_chunk.get("input_ids",False):
            # Empty chunk: fall back to the (truncated) overflowing sentence itself.
            chunk = tok_sent["input_ids"][: max_length - tokenized_claim["length"][0] -1] + [2]
        else:
            chunk = current_chunk["input_ids"].copy() + [2] # add </s> end of sentence
        claim_ids = tokenized_claim["input_ids"].copy()
        inputs, attention = self.fill_with_pad_tokens(chunk,claim_ids )
        return_chunked_inputs["input_ids"] = return_chunked_inputs.get("input_ids",[]) + [inputs]
        return_chunked_inputs["attention_mask"] = return_chunked_inputs.get("attention_mask",[]) + [attention]
        return return_chunked_inputs
361
+
362
+ @classmethod
363
+ def get_encoder_attr_name(cls, model):
364
+ """
365
+ The encoder transformer is named differently in each model "architecture".
366
+ This method lets us get the name of the encoder attribute
367
+ """
368
+ model_class_name = model.__class__.__name__
369
+ if model_class_name.startswith("XLMRoberta"):
370
+ return "roberta"
371
+ else:
372
+ raise KeyError(f"Add support for new model {model_class_name}")
373
+
374
+
375
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        model_name : str = "xlm-roberta-large",
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """
        Load the three task heads ("re", "bin", "3way") and wire them to one
        shared encoder.

        When pretrained_model_name_or_path is a local directory containing
        the re_model/bin_model/3way_model sub-directories, weights are loaded
        from those files; otherwise the sub-folders are fetched from the
        Hugging Face hub under the given repository id.
        """
        # Check if the required model directories exist, then load from disk.
        if all(os.path.exists(os.path.join(pretrained_model_name_or_path, model_dir)) for model_dir in [cls._3way_class_model, cls._regression_model, cls._binary_class_model]):
            # Disable the warning about newly initialized weights
            transformers.logging.set_verbosity_error()

            shared_encoder = None
            taskmodels_dict = {}
            for path_name in [cls._regression_model, cls._binary_class_model, cls._3way_class_model]:
                task_name = path_name.split("_")[0]

                # Load the configuration for the task-specific model
                task_config = transformers.XLMRobertaConfig.from_json_file("{}/{}/config.json".format(pretrained_model_name_or_path,path_name))
                # Create the task-specific model
                model = transformers.XLMRobertaForSequenceClassification.from_pretrained(model_name, config=task_config,*model_args,**kwargs)
                # Load the weights for the task-specific model
                model.load_state_dict(torch.load("{}/{}/pytorch_model.bin".format(pretrained_model_name_or_path,path_name), map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
                # The first head keeps its encoder; every later head shares it.
                if shared_encoder is None:
                    shared_encoder = getattr(model, AlignScoreCS.get_encoder_attr_name(model))
                else:
                    setattr(model, AlignScoreCS.get_encoder_attr_name(model), shared_encoder)
                taskmodels_dict[task_name] = model

            # Create the AlignScoreCS with the shared encoder and loaded task-specific models
            alignScoreCS = AlignScoreCS(encoder=shared_encoder, taskmodels_dict=taskmodels_dict, model_name=model_name)
        # Otherwise load the model from the Hugging Face hub.
        else:
            shared_encoder = None
            taskmodels_dict = {}
            for model_dir in [cls._regression_model, cls._binary_class_model, cls._3way_class_model]:
                task_name = model_dir.split("_")[0]
                config = transformers.XLMRobertaConfig.from_pretrained(f"{pretrained_model_name_or_path}", subfolder=model_dir)
                model = transformers.XLMRobertaForSequenceClassification.from_pretrained(f"{pretrained_model_name_or_path}",config=config, subfolder=model_dir)
                # Same encoder-sharing scheme as in the local-directory branch.
                if shared_encoder is None:
                    shared_encoder = getattr(model, AlignScoreCS.get_encoder_attr_name(model))
                else:
                    setattr(model, AlignScoreCS.get_encoder_attr_name(model), shared_encoder)
                taskmodels_dict[task_name] = model
            alignScoreCS = AlignScoreCS(encoder=shared_encoder, taskmodels_dict=taskmodels_dict, model_name=model_name)

        return alignScoreCS
437
+
438
+
439
+ def save_pretrained(
440
+ self,
441
+ save_directory: Union[str, os.PathLike],
442
+ is_main_process: bool = True,
443
+ state_dict: Optional[dict] = None,
444
+ save_function: Callable = torch.save,
445
+ push_to_hub: bool = False,
446
+ max_shard_size: Union[int, str] = "10GB",
447
+ safe_serialization: bool = False,
448
+ variant: Optional[str] = None,
449
+ token: Optional[Union[str, bool]] = None,
450
+ save_peft_format: bool = True,
451
+ **kwargs,
452
+ ):
453
+ for task_name, model_type in self.taskmodels_dict.items():
454
+ model_type.save_pretrained(save_directory = Path(save_directory,task_name+"_model"),
455
+ is_main_process = is_main_process,
456
+ state_dict = state_dict,
457
+ save_function = save_function,
458
+ push_to_hub = push_to_hub,
459
+ max_shard_size = max_shard_size,
460
+ safe_serialization = safe_serialization,
461
+ variant = variant,
462
+ token = token,
463
+ save_peft_format = save_peft_format,
464
+ **kwargs)
465
+
466
+ # This piece of code is copied from AlignScore github repository
467
+ # if you want to use different nlg_eval_mode you have to fix errors on your own
468
+ class InferenceHandler:
469
+ def __init__(self, model, tokenizer, device = "cuda"):
470
+ self.model = model
471
+ self.device = device
472
+ self.tokenizer = tokenizer
473
+ self.model.to(self.device)
474
+ self.model.eval()
475
+ self.batch_size = 32
476
+ self.nlg_eval_mode = "nli_sp"
477
+ self.verbose = False
478
+ self.task_name = "3way"
479
+ self.softmax = nn.Softmax(dim=-1)
480
+
481
+ def nlg_eval(self, premise, hypo):
482
+ if isinstance(premise, str) and isinstance(hypo, str):
483
+ premise = [premise]
484
+ hypo = [hypo]
485
+ return self.inference_example_batch(premise, hypo)
486
+
487
+ def inference_example_batch(self, premise: list, hypo: list):
488
+ """
489
+ inference a example,
490
+ premise: list
491
+ hypo: list
492
+ using self.inference to batch the process
493
+
494
+ SummaC Style aggregation
495
+ """
496
+ self.disable_progress_bar_in_inference = True
497
+ assert len(premise) == len(hypo), "Premise must has the same length with Hypothesis!"
498
+
499
+ out_score = []
500
+ for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating", total=len(premise), disable=(not self.verbose)):
501
+ out_score.append(self.inference_per_example(one_pre, one_hypo))
502
+
503
+ return torch.tensor(out_score)
504
+
505
+ def inference_per_example(self, premise:str, hypo: str):
506
+ """
507
+ inference a example,
508
+ premise: string
509
+ hypo: string
510
+ using self.inference to batch the process
511
+ """
512
+ def chunks(lst, n):
513
+ """Yield successive n-sized chunks from lst."""
514
+ for i in range(0, len(lst), n):
515
+ yield ' '.join(lst[i:i + n])
516
+
517
+ premise_sents = sent_tokenize(premise)
518
+ premise_sents = premise_sents or ['']
519
+
520
+ n_chunk = len(premise.strip().split()) // 350 + 1
521
+ n_chunk = max(len(premise_sents) // n_chunk, 1)
522
+ premise_sents = [each for each in chunks(premise_sents, n_chunk)]
523
+
524
+ hypo_sents = sent_tokenize(hypo)
525
+
526
+ premise_sent_mat = []
527
+ hypo_sents_mat = []
528
+ for i in range(len(premise_sents)):
529
+ for j in range(len(hypo_sents)):
530
+ premise_sent_mat.append(premise_sents[i])
531
+ hypo_sents_mat.append(hypo_sents[j])
532
+
533
+ if self.nlg_eval_mode is not None:
534
+ if self.nlg_eval_mode == 'nli_sp':
535
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat)[:,1] ### use NLI head OR ALIGN head
536
+ output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect
537
+
538
+ return output_score
539
+
540
+
541
+ output_score = self.inference(premise_sent_mat, hypo_sents_mat) ### use NLI head OR ALIGN head
542
+ output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item() ### sum or mean depends on the task/aspect
543
+
544
+ return output_score
545
+
546
+ def inference(self, premise, hypo, task_name = None):
547
+ """
548
+ inference a list of premise and hypo
549
+
550
+ Standard aggregation
551
+ """
552
+ task_name = self.task_name if task_name is None else task_name
553
+ if isinstance(premise, str) and isinstance(hypo, str):
554
+ premise = [premise]
555
+ hypo = [hypo]
556
+
557
+ batch = self.batch_tokenize(premise, hypo)
558
+ output_score = []
559
+
560
+ for mini_batch in tqdm(batch, desc="Evaluating", disable=not self.verbose or self.disable_progress_bar_in_inference):
561
+ mini_batch = mini_batch.to(self.device)
562
+ with torch.no_grad():
563
+ model_output = self.model.forward(task_name=task_name, **mini_batch)
564
+ model_output = model_output.logits
565
+ if task_name == "re":
566
+ model_output = model_output.cpu()
567
+ else:
568
+ model_output = self.softmax(model_output).cpu()
569
+ output_score.append(model_output[:,:])
570
+
571
+ output_score = torch.cat(output_score)
572
+
573
+ if self.nlg_eval_mode is not None:
574
+ if self.nlg_eval_mode == 'nli':
575
+ output_score_nli = output_score[:,1]
576
+ return output_score_nli
577
+ elif self.nlg_eval_mode == 'bin':
578
+ return output_score
579
+ elif self.nlg_eval_mode == 'reg':
580
+ return output_score
581
+ else:
582
+ ValueError("unrecognized nlg eval mode")
583
+
584
+
585
+ return output_score
586
+
587
+ def batch_tokenize(self, premise, hypo):
588
+ """
589
+ input premise and hypos are lists
590
+ """
591
+ assert isinstance(premise, list) and isinstance(hypo, list)
592
+ assert len(premise) == len(hypo), "premise and hypo should be in the same length."
593
+
594
+ batch = []
595
+ for mini_batch_pre, mini_batch_hypo in zip(self.chunks(premise, self.batch_size), self.chunks(hypo, self.batch_size)):
596
+ try:
597
+ mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation='only_first', padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
598
+ except:
599
+ print('text_b too long...')
600
+ mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation=True, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
601
+ batch.append(mini_batch)
602
+
603
+ return batch
604
+
605
+ def chunks(self, lst, n):
606
+ """Yield successive n-sized chunks from lst."""
607
+ for i in range(0, len(lst), n):
608
+ yield lst[i:i + n]
609
+
610
+
611
+
612
if __name__ == "__main__":
    # Smoke-test: score a few Czech premise/hypothesis pairs and print the
    # alignment score for each. Driven by a data table so every printed
    # label is exactly the string that was scored (the original duplicated
    # each call and several printed labels disagreed with the scored text).
    alignScore = AlignScoreCS.from_pretrained("krotima1/AlignScoreCS")
    alignScore.to("cuda" if torch.cuda.is_available() else "cpu")

    tomas = "Tomáš miluje Zuzku!"
    karel_q = "Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.\nKdo nechal vystavět katedrálu?"
    karel = "Kdo je Karel IV.? Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta."
    karkulka = "Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!"

    examples = [
        (tomas, "Tomáš miluje Petru."),
        (tomas, "Tomáš miluje Zuzku!"),
        (tomas, "Zuzka miluje Tomáše."),
        (tomas, "Zuzka nemiluje Tomáše."),
        (tomas, "Tomáš nemiluje Zuzku."),
        ("Dva chlapi se perou.", "Je tu bitka."),
        ("Dva chlapi se perou.", "Je tu láska."),
        (karel_q, "Byl to Karel."),
        (karel_q, "Byl to Vít."),
        (karel_q, "Byla to katedrála."),
        (karel, "Je Otec."),
        (karel, "Je Otec vlasti."),
        (karel, "Je katedrála svatého Víta."),
        (karkulka, "Karkulka utekla vklovi."),
        (karkulka, "Karkulka neutekla vklovi."),
        (karkulka, "Vlk snědl karkulku."),
        (karkulka, "Vlk nesnědl karkulku."),
        (karkulka, "Karkulka snědla vlka."),
        (karkulka, "Karkulka dala vlkovi jablko."),
    ]

    for context, claim in examples:
        # Print exactly the strings passed to score(), then the score itself.
        print(context, "|", claim, alignScore.score(context, claim))