krotima1
commited on
Commit
·
6885f5c
1
Parent(s):
3dd0723
Add AlignScore.py class of transformer model - easy to use
Browse files- AlignScoreCS.py +634 -0
AlignScoreCS.py
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import transformers
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch
|
| 9 |
+
# This include should be add when using different AlignScoreFunction methods instead of score()
|
| 10 |
+
# from nltk.tokenize import sent_tokenize
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
class AlignScoreCS(transformers.XLMRobertaModel):
    """AlignScoreCS: multi-task alignment scorer.

    A single XLM-RoBERTa encoder shared by three sequence-classification heads
    (regression, binary, 3-way), following the AlignScore paper. According to the
    original notes the model was trained for 3 days on 4 GPUs
    (3 epochs, lr 1e-5, AdamW eps 1e-6, batch size 32, warmup ratio 0.06,
    weight decay 0.1).

    Usage:
        alignScoreCS = AlignScoreCS.from_pretrained(path_or_hub_id)
        alignScoreCS.score(context, claim)

    - ``score(context, claim)`` returns the probability of the ALIGNED class
      using the 3-way head, via the paper's inference code.
    - Other entry points accept ``task_name``:
        * ``"re"``   : regression head
        * ``"bin"``  : binary classification head
        * ``"3way"`` : 3-way classification head
    """
    # Sub-directory names under which the three task-specific checkpoints live.
    _regression_model = "re_model"
    _binary_class_model = "bin_model"
    _3way_class_model = "3way_model"
|
| 39 |
+
|
| 40 |
+
    def __init__(self, encoder, taskmodels_dict, model_name= "xlm-roberta-large", **kwargs):
        # Shared-encoder multi-task wrapper: one XLM-RoBERTa encoder backing the
        # task heads in taskmodels_dict (keys "re" | "bin" | "3way").
        super().__init__(transformers.XLMRobertaConfig(), **kwargs)
        self.encoder = encoder
        # ModuleDict registers the heads as submodules so .to()/.eval()/state_dict
        # see them.
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)
        # Tokenizer and inferencer are created lazily on first use (see
        # init_inferencer / score*).
        self.tokenizer = None
        self.model_name = model_name
        self.inferencer = None
|
| 47 |
+
|
| 48 |
+
def init_inferencer(self, device = "cuda"):
|
| 49 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
|
| 50 |
+
self.inferencer = self.InferenceHandler(self, self.tokenizer, device)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
"""
|
| 55 |
+
Score: scores the context and claim with Aligned probabitlity of 3way classification head
|
| 56 |
+
- using paper code inferencer from ALignScore
|
| 57 |
+
|
| 58 |
+
"""
|
| 59 |
+
def score(self, context, claim, **kwargs):
|
| 60 |
+
if self.inferencer is None:
|
| 61 |
+
self.init_inferencer()
|
| 62 |
+
scores = self.inferencer.nlg_eval(context, claim)
|
| 63 |
+
return scores
|
| 64 |
+
|
| 65 |
+
"""
|
| 66 |
+
Score: scores the context and claim with ALIGNED probability (wrt task_name ["re" | "bin" | "3way"])
|
| 67 |
+
|
| 68 |
+
Returns the probability of the ALIGNED CLASS between context text and claim text
|
| 69 |
+
- chunks text by 350 tokens and splits claim into sentences
|
| 70 |
+
- using 3way classification head
|
| 71 |
+
"""
|
| 72 |
+
    def score_sentences(self, context :str, claim :str, task_name = "3way", batch_size = 2, return_all_outputs = False, **kwargs):
        """Score (context, claim): context chunked to ~350 tokens, claim split into
        sentence chunks; the ALIGNED probability is aggregated over the
        (claim chunk x context chunk) cross-product by alignscore_input.

        Returns a 1-element tensor with the score, or a dict with raw per-chunk
        outputs when return_all_outputs is True.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
        chunked_inputs = self.chunk_sent_input(context,claim, chunk_size=350,chunk_claim_size=150)
        # Chunk counts are needed later to reshape the flat logits back into a
        # (claim, context) grid.
        nclaims, ncontexts = (chunked_inputs["n_claims"],chunked_inputs["n_contexts"])
        with torch.no_grad():
            # Drop the bookkeeping "n_*" keys; tensorize the rest and move to device.
            chunked_inputs = {key : torch.tensor(item).to(self.device) for key, item in chunked_inputs.items() if not key.startswith("n_")}
            chunked_outputs = {}
            # Mini-batch over the chunk cross-product to bound memory.
            for i in range(0,len(chunked_inputs["input_ids"]),batch_size):
                tmp = self.forward(task_name = task_name,**{"input_ids":chunked_inputs["input_ids"][i:i+batch_size],"attention_mask" :chunked_inputs["attention_mask"][i:i+batch_size]}, **kwargs)
                # Accumulate each output field (e.g. "logits") across batches.
                for k, item in tmp.items():
                    chunked_outputs[k] = chunked_outputs.get(k, []) + [item]
        logits = torch.vstack(chunked_outputs["logits"]).cpu()
        outputs = {"score" : self.alignscore_input(logits,nclaims=nclaims,ncontexts=ncontexts, task_name=task_name)}
        outputs["outputs"] = chunked_outputs
        return torch.tensor([outputs["score"]]) if not return_all_outputs else outputs
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
"""
|
| 90 |
+
Score: scores the context and claim with ALIGNED probability (wrt task_name ["re" | "bin" | "3way"])
|
| 91 |
+
|
| 92 |
+
Returns the probability of the ALIGNED CLASS between context text and claim text
|
| 93 |
+
- chunks text into 350 tolens and chunks claim into 150 tokens
|
| 94 |
+
- using 3way classification head
|
| 95 |
+
"""
|
| 96 |
+
def score_chunks(self, context :str, claim :str, task_name = "3way", batch_size = 2, return_all_outputs = False, **kwargs):
|
| 97 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
|
| 98 |
+
chunked_inputs = self.chunk_inputs(context,claim, chunk_size=350)
|
| 99 |
+
chunked_inputs = {key : torch.tensor(item).to(self.device) for key, item in chunked_inputs.items()}
|
| 100 |
+
chunked_outputs = self.forward(task_name = task_name, **chunked_inputs, **kwargs)
|
| 101 |
+
outputs = {"score" : self.alignscore_input_deprecated(chunked_outputs.logits.cpu(), task_name=task_name)}
|
| 102 |
+
outputs["outputs"] = chunked_outputs
|
| 103 |
+
return outputs["score"] if not return_all_outputs else outputs
|
| 104 |
+
|
| 105 |
+
"""
|
| 106 |
+
Classify: classify the context and claim to the class label given the task_name ["re" | "bin" | "3way"]
|
| 107 |
+
|
| 108 |
+
Returns the class of {Neutral, contradict, aligned} between context text and claim text
|
| 109 |
+
- using 3way classification head
|
| 110 |
+
"""
|
| 111 |
+
def classify(self, context :str, claim :str, task_name = "3way", return_all_outputs = False, **kwargs):
|
| 112 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
|
| 113 |
+
chunked_inputs = self.chunk_inputs(context,claim, chunk_size=350)
|
| 114 |
+
chunked_inputs = {key : torch.tensor(item).to(self.device) for key, item in chunked_inputs.items()}
|
| 115 |
+
chunked_outputs = self.forward(task_name = task_name, **chunked_inputs, **kwargs)
|
| 116 |
+
outputs = {"class" : self.get_system_label(chunked_outputs.logits.cpu(), task_name=task_name)}
|
| 117 |
+
outputs["outputs"] = chunked_outputs
|
| 118 |
+
return outputs["class"] if not return_all_outputs else outputs
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def score_truncated(self, context :str, claim :str, task_name = "3way", return_all_outputs = False, **kwargs):
|
| 122 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) if not self.tokenizer else self.tokenizer
|
| 123 |
+
tokenized_inputs = self.tokenizer(list(zip([context], [claim])), padding = "max_length", truncation = True, max_length = 512, return_tensors="pt")
|
| 124 |
+
tokenized_inputs = {key : torch.tensor(item).to(self.device) for key, item in tokenized_inputs.items()}
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
model_outputs = self.forward(task_name=task_name, **tokenized_inputs, **kwargs)
|
| 127 |
+
outputs = {"score" : self.alignscore_input(model_outputs["logits"].cpu(),nclaims=1, ncontexts=1, task_name=task_name)}
|
| 128 |
+
outputs["outputs"] = model_outputs
|
| 129 |
+
return torch.tensor([outputs["score"]]) if not return_all_outputs else outputs
|
| 130 |
+
|
| 131 |
+
    def forward(self, task_name = "3way", **kwargs):
        # Dispatch to the selected task head ("re" | "bin" | "3way"); the heads
        # share one encoder, so only the classification layer differs.
        return self.taskmodels_dict[task_name](**kwargs)
|
| 133 |
+
|
| 134 |
+
def __call__(self, task_name, **kwargs):
|
| 135 |
+
return self.taskmodels_dict[task_name](**kwargs)
|
| 136 |
+
|
| 137 |
+
"""
|
| 138 |
+
Get the probability of the ALIGNED label from input
|
| 139 |
+
"""
|
| 140 |
+
def alignscore_input(self, chunked_logits, nclaims, ncontexts, task_name = "3way"):
|
| 141 |
+
if task_name == "re":
|
| 142 |
+
ouptuts = chunked_logits.detach()
|
| 143 |
+
# Reshape the tensor to separate each block of n rows
|
| 144 |
+
reshaped_tensor = ouptuts.view(nclaims, ncontexts)
|
| 145 |
+
|
| 146 |
+
# Extract the maximum values from the first column (index 0) within each block of n rows
|
| 147 |
+
max_values, _ = reshaped_tensor.max(dim=1)
|
| 148 |
+
|
| 149 |
+
# Calculate the mean of the max values for each block of n rows
|
| 150 |
+
mean_of_maxes = torch.mean(max_values, dim=0)
|
| 151 |
+
return mean_of_maxes.tolist()
|
| 152 |
+
else:
|
| 153 |
+
nlabels = {"3way" : 3, "re" : 1, "2way" : 2}[task_name]
|
| 154 |
+
ouptuts = chunked_logits.softmax(1).detach()
|
| 155 |
+
# Reshape the tensor to separate each block of n rows
|
| 156 |
+
reshaped_tensor = ouptuts.view(nclaims, ncontexts, nlabels)
|
| 157 |
+
|
| 158 |
+
# Extract the maximum values from the first column (index 0) within each block of n rows
|
| 159 |
+
max_values, _ = torch.max(reshaped_tensor[:, :, 1], dim=1)
|
| 160 |
+
|
| 161 |
+
# Calculate the mean of the max values for each block of n rows
|
| 162 |
+
mean_of_maxes = torch.mean(max_values, dim=0)
|
| 163 |
+
return mean_of_maxes.tolist()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def alignscore_input_deprecated(self, chunked_logits, task_name = "3way"):
|
| 167 |
+
if task_name == "re":
|
| 168 |
+
return chunked_logits.detach().amax(0).tolist()
|
| 169 |
+
else:
|
| 170 |
+
return chunked_logits.softmax(1).detach()[:, 1].amax(0).tolist() # return max probability over the ALIGNED class
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
"""
|
| 174 |
+
get the label from the input
|
| 175 |
+
"""
|
| 176 |
+
    def get_system_label(self, chunked_logits, task_name):
        """Aggregate per-chunk logits into one predicted label.

        "re": returns the mean regression output across chunks. Otherwise:
        majority vote over the chunks' argmax classes; on a tie, falls back to
        the argmax of the chunk-averaged probabilities.
        """
        if task_name == "re":
            # Mean of the regression outputs over chunks.
            return (chunked_logits.sum(0) / chunked_logits.size()[0]).detach().tolist()
        else:
            # Chunk-averaged class probabilities (used only as the tie-break).
            avg_probs = chunked_logits.softmax(1).sum(0) / chunked_logits.size()[0]
            # Per-chunk hard predictions.
            numpy_array = chunked_logits.softmax(1).argmax(1).detach().numpy()
            # Calculate the frequencies of each value
            unique_values, counts = np.unique(numpy_array, return_counts=True)
            # Find the maximum count
            max_count = np.max(counts)
            # Find all values with the maximum count
            most_frequent_values = unique_values[counts == max_count]
            # NOTE(review): unique winner is returned as a numpy scalar, the tie
            # fallback as a Python int — callers comparing types should beware.
            return most_frequent_values[0] if most_frequent_values.size == 1 else avg_probs.detach().argmax().tolist()
|
| 189 |
+
|
| 190 |
+
"""
|
| 191 |
+
Chunks input context and claim - context is chunked into 350 tokens
|
| 192 |
+
- claim is chunked into sentences
|
| 193 |
+
- using stride for overflowing tokens
|
| 194 |
+
"""
|
| 195 |
+
def chunk_sent_input(self, context, claim, max_length = 512, chunk_size = 350, chunk_claim_size = 150):
|
| 196 |
+
assert chunk_size <= max_length, "Chunk size {} cannot be greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
|
| 197 |
+
chunk_claim_size = max_length - chunk_size if chunk_claim_size is None else chunk_claim_size
|
| 198 |
+
assert chunk_size + chunk_claim_size <= max_length, "Chunk size {} and Chunk claim size {} cannot be together greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
|
| 199 |
+
return_chunked_inputs = {}
|
| 200 |
+
context_chunks = self.chunk_text(context, chunk_size=chunk_size, overflowing_tokens_stride = 25, first_special_token=[0])
|
| 201 |
+
claim_chunks = self.chunk_sentences(claim, chunk_size=chunk_claim_size,overflowing_tokens_stride=int(chunk_claim_size/3), first_special_token=[2])
|
| 202 |
+
for claim_chunk in claim_chunks:
|
| 203 |
+
for context_chunk in context_chunks:
|
| 204 |
+
inputs,attention =self.fill_with_pad_tokens(context_chunk,claim_chunk )
|
| 205 |
+
return_chunked_inputs["input_ids"] = return_chunked_inputs.get("input_ids",[]) + [inputs]
|
| 206 |
+
return_chunked_inputs["attention_mask"] = return_chunked_inputs.get("attention_mask",[]) + [attention]
|
| 207 |
+
return_chunked_inputs["n_claims"] = len(claim_chunks)
|
| 208 |
+
return_chunked_inputs["n_contexts"] = len(context_chunks)
|
| 209 |
+
return return_chunked_inputs
|
| 210 |
+
|
| 211 |
+
"""
|
| 212 |
+
According to paper - chunk the text into smaller parts (350tokens + claim_tokens) when the tokenized inputs exceed the max_length
|
| 213 |
+
returns chunked input
|
| 214 |
+
"""
|
| 215 |
+
def chunk_inputs(self, context, claim, max_length = 512, chunk_size = 512, first_fit_within_max_length = True):
|
| 216 |
+
assert chunk_size <= max_length, "Chunk size {} cannot be greater than max size {}".format(chunk_size, max_length)
|
| 217 |
+
|
| 218 |
+
tokenized_claim = self.tokenizer(claim, return_length=True)
|
| 219 |
+
tokenized_claim["input_ids"][0] = 2 # </s> token according to pair tokenization where the separator of the context and claim is </s></s>
|
| 220 |
+
tokenized_context = self.tokenizer(context, return_length = True)
|
| 221 |
+
assert tokenized_claim["length"][0] < max_length*4/5, "Create chunks of claim sentences. Claim is too long {} which is more than 4/5 from {}.".format(tokenized_claim["length"][0], max_length)
|
| 222 |
+
|
| 223 |
+
# set chunk size to incorporate the claim size as well
|
| 224 |
+
chunk_size = min(max_length, chunk_size + tokenized_claim["length"][0])
|
| 225 |
+
|
| 226 |
+
first_check_max_size = max_length if first_fit_within_max_length else chunk_size
|
| 227 |
+
|
| 228 |
+
if tokenized_claim["length"][0] + tokenized_context["length"][0] <= first_check_max_size: #if it fits within max_length
|
| 229 |
+
input_ids, attention_mask = self.fill_with_pad_tokens(tokenized_context["input_ids"],tokenized_claim["input_ids"])
|
| 230 |
+
return {"input_ids" : [input_ids], "attention_mask" : [attention_mask]}
|
| 231 |
+
else: # make chunks
|
| 232 |
+
return_chunked_inputs = {}
|
| 233 |
+
current_chunk = {}
|
| 234 |
+
for sentence in sent_tokenize(context, language="czech"):
|
| 235 |
+
tok_sent = self.tokenizer(sentence, return_length=True)
|
| 236 |
+
if len(current_chunk.get("input_ids",[0])) + tok_sent["length"][0] - 1 + tokenized_claim["length"][0] <= chunk_size:
|
| 237 |
+
current_chunk["input_ids"] = current_chunk.get("input_ids",[0]) + tok_sent["input_ids"][1:-1]
|
| 238 |
+
else:
|
| 239 |
+
return_chunked_inputs = self._update_chunked_inputs(tokenized_claim, current_chunk, return_chunked_inputs, max_length, tok_sent)
|
| 240 |
+
current_chunk["input_ids"] = [0] + tok_sent["input_ids"][1:-1]
|
| 241 |
+
if current_chunk != {}: # add the rest
|
| 242 |
+
return_chunked_inputs = self._update_chunked_inputs(tokenized_claim, current_chunk, return_chunked_inputs, max_length)
|
| 243 |
+
current_chunk = {}
|
| 244 |
+
return return_chunked_inputs
|
| 245 |
+
|
| 246 |
+
"""
|
| 247 |
+
Chunks input context and claim - context is chunked into 350 tokens
|
| 248 |
+
- claim is chunked into 150 tokens
|
| 249 |
+
- using stride for overflowing tokens
|
| 250 |
+
"""
|
| 251 |
+
def chunk_input_deprecated(self, context, claim, max_length = 512, chunk_size = 350, chunk_claim_size = 150):
|
| 252 |
+
assert chunk_size <= max_length, "Chunk size {} cannot be greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
|
| 253 |
+
chunk_claim_size = max_length - chunk_size if chunk_claim_size is None else chunk_claim_size
|
| 254 |
+
assert chunk_size + chunk_claim_size <= max_length, "Chunk size {} and Chunk claim size {} cannot be together greater than max size {}".format(chunk_size, chunk_claim_size, max_length)
|
| 255 |
+
return_chunked_inputs = {}
|
| 256 |
+
context_chunks = self.chunk_text(context, chunk_size=chunk_size, overflowing_tokens_stride = 25, first_special_token=[0])
|
| 257 |
+
claim_chunks = self.chunk_text(claim, chunk_size=chunk_claim_size,overflowing_tokens_stride=int(chunk_claim_size/3), first_special_token=[2])
|
| 258 |
+
for claim_chunk in claim_chunks:
|
| 259 |
+
for context_chunk in context_chunks:
|
| 260 |
+
inputs,attention =self.fill_with_pad_tokens(context_chunk,claim_chunk )
|
| 261 |
+
return_chunked_inputs["input_ids"] = return_chunked_inputs.get("input_ids",[]) + [inputs]
|
| 262 |
+
return_chunked_inputs["attention_mask"] = return_chunked_inputs.get("attention_mask",[]) + [attention]
|
| 263 |
+
return_chunked_inputs["n_claims"] = len(claim_chunks)
|
| 264 |
+
return_chunked_inputs["n_contexts"] = len(context_chunks)
|
| 265 |
+
return return_chunked_inputs
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
"""
|
| 269 |
+
Chunk texts into blocks of chunk_size tokens
|
| 270 |
+
|
| 271 |
+
"""
|
| 272 |
+
def chunk_text(self, text, chunk_size = 350, overflowing_tokens_stride = 25, language="czech", first_special_token = [0]):
|
| 273 |
+
sentences = sent_tokenize(text, language=language)
|
| 274 |
+
tokenized = self.tokenizer(sentences if sentences != [] else [""], return_length=True)
|
| 275 |
+
chunks = []
|
| 276 |
+
chunk, current_chunk_size = ([], 0)
|
| 277 |
+
for i, length in enumerate(tokenized["length"]):
|
| 278 |
+
|
| 279 |
+
# WRAP THE TOKENIZED SENTNECE INTO LIST TO HANDLE OVERFLOWING TOKENS EASILY
|
| 280 |
+
# Case when length of one sentence is longer than the chunk size - split the sentence into chunks of chunk size
|
| 281 |
+
if length > chunk_size:
|
| 282 |
+
splits = [first_special_token + tokenized["input_ids"][i][max(1,cs):min(cs + chunk_size - 2, length - 1)] + [2] for cs in range(0, length , chunk_size-(2+overflowing_tokens_stride))]
|
| 283 |
+
# Case when lenght of sequence is equal or smaller than the chunk size - only continue
|
| 284 |
+
else:
|
| 285 |
+
splits = [first_special_token + tokenized["input_ids"][i][1:]]
|
| 286 |
+
|
| 287 |
+
# Go through sentence or splits of sentence
|
| 288 |
+
for subsentence in splits:
|
| 289 |
+
up_length = len(subsentence) - 2
|
| 290 |
+
|
| 291 |
+
# Case when the current chunk = 0
|
| 292 |
+
if current_chunk_size == 0:
|
| 293 |
+
current_chunk_size = up_length + 2 # First include <s> and </s> tokens
|
| 294 |
+
chunk = subsentence[:-1]
|
| 295 |
+
# Case when the current chunk + length of new subsentence <= chunk_size - only add
|
| 296 |
+
elif current_chunk_size + up_length <= chunk_size:
|
| 297 |
+
current_chunk_size += up_length
|
| 298 |
+
chunk += subsentence[1:-1]
|
| 299 |
+
# Case when the current chunk + length of new subsentence > chunk_size - create chunk
|
| 300 |
+
else:
|
| 301 |
+
chunks += [chunk + [2]]
|
| 302 |
+
current_chunk_size = up_length + 2 # First include <s> and </s> tokens
|
| 303 |
+
chunk = subsentence[:-1]
|
| 304 |
+
#Case when the loop ended but the current chunk isnt saved in the chunks
|
| 305 |
+
if chunk != []:
|
| 306 |
+
chunks += [chunk + [2]]
|
| 307 |
+
# lengths = [len(ch) for ch in chunks]
|
| 308 |
+
# print("Lenght in tokens of ",len(lengths)," chunks (AVG=",np.mean(lengths),",MAX=",np.max(lengths),",MIN=", np.min(lengths),")")
|
| 309 |
+
return chunks
|
| 310 |
+
|
| 311 |
+
"""
|
| 312 |
+
Chunks text into sentences using nlt.sent_tokenize
|
| 313 |
+
"""
|
| 314 |
+
def chunk_sentences(self, text, chunk_size, overflowing_tokens_stride = 0, language="czech", sentence_window = 2, first_special_token = [2]):
|
| 315 |
+
sentences = sent_tokenize(text, language=language)
|
| 316 |
+
tokenized = self.tokenizer(sentences if sentences != [] else [""], return_length=True)
|
| 317 |
+
chunks = []
|
| 318 |
+
current_chunk = []
|
| 319 |
+
for i, length in enumerate(tokenized["length"]):
|
| 320 |
+
# WRAP THE TOKENIZED SENTNECE INTO LIST TO HANDLE OVERFLOWING TOKENS EASILY
|
| 321 |
+
# Case when length of one sentence is longer than the chunk size - split the sentence into chunks of chunk size
|
| 322 |
+
if length > chunk_size:
|
| 323 |
+
splits = [first_special_token + tokenized["input_ids"][i][max(1,cs):min(cs + chunk_size - 2, length - 1)] + [2] for cs in range(0, length , chunk_size-(2+overflowing_tokens_stride))]
|
| 324 |
+
# Case when lenght of sequence is equal or smaller than the chunk size - only continue
|
| 325 |
+
else:
|
| 326 |
+
splits = [first_special_token + tokenized["input_ids"][i][1:]]
|
| 327 |
+
|
| 328 |
+
#Go through sentence or parts of sentence and create chunks
|
| 329 |
+
for split in splits:
|
| 330 |
+
chunks += [split]
|
| 331 |
+
# if len(current_chunk) == sentence_window:
|
| 332 |
+
# chunks += [first_special_token + [item for row in current_chunk for item in row] + [2]]
|
| 333 |
+
# current_chunk = current_chunk[1:] + [split[1:-1]]
|
| 334 |
+
# else:
|
| 335 |
+
# current_chunk += [split[1:-1]]
|
| 336 |
+
|
| 337 |
+
# if chunks == []:
|
| 338 |
+
# chunks += [first_special_token + [item for row in current_chunk for item in row] + [2]]
|
| 339 |
+
return chunks
|
| 340 |
+
|
| 341 |
+
"""
|
| 342 |
+
join context and claim tokens as input_ids and create attention_mask
|
| 343 |
+
"""
|
| 344 |
+
def fill_with_pad_tokens(self, first, second, max_length=512, pad_token = 1):
|
| 345 |
+
return first + second + [pad_token]*max(max_length-len(first)-len(second),0), [1]*(len(first)+len(second)) + [0]*max(max_length-len(first)-len(second),0)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def _update_chunked_inputs(self, tokenized_claim, current_chunk, return_chunked_inputs, max_length, tok_sent = {"input_ids" : []}):
|
| 349 |
+
# truncate if there is a long sentence (rare occurrences)
|
| 350 |
+
if len(current_chunk.get("input_ids",[0])) + tokenized_claim["length"][0] >= max_length:
|
| 351 |
+
chunk = current_chunk["input_ids"].copy()[:max_length-tokenized_claim["length"][0]-1] + [2]
|
| 352 |
+
elif not current_chunk.get("input_ids",False):
|
| 353 |
+
chunk = tok_sent["input_ids"][: max_length - tokenized_claim["length"][0] -1] + [2]
|
| 354 |
+
else:
|
| 355 |
+
chunk = current_chunk["input_ids"].copy() + [2] # add </s> end of sentence
|
| 356 |
+
claim_ids = tokenized_claim["input_ids"].copy()
|
| 357 |
+
inputs, attention = self.fill_with_pad_tokens(chunk,claim_ids )
|
| 358 |
+
return_chunked_inputs["input_ids"] = return_chunked_inputs.get("input_ids",[]) + [inputs]
|
| 359 |
+
return_chunked_inputs["attention_mask"] = return_chunked_inputs.get("attention_mask",[]) + [attention]
|
| 360 |
+
return return_chunked_inputs
|
| 361 |
+
|
| 362 |
+
@classmethod
|
| 363 |
+
def get_encoder_attr_name(cls, model):
|
| 364 |
+
"""
|
| 365 |
+
The encoder transformer is named differently in each model "architecture".
|
| 366 |
+
This method lets us get the name of the encoder attribute
|
| 367 |
+
"""
|
| 368 |
+
model_class_name = model.__class__.__name__
|
| 369 |
+
if model_class_name.startswith("XLMRoberta"):
|
| 370 |
+
return "roberta"
|
| 371 |
+
else:
|
| 372 |
+
raise KeyError(f"Add support for new model {model_class_name}")
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        model_name : str = "xlm-roberta-large",
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Load the three task heads and wire them to one shared encoder.

        Two paths: (1) a local directory containing re_model/, bin_model/ and
        3way_model/ sub-directories (config.json + pytorch_model.bin each), or
        (2) a Hugging Face Hub repo with the same sub-folders. Most keyword
        arguments mirror transformers' signature for compatibility; note that in
        path (2) only config/subfolder are actually forwarded — the rest are
        accepted but unused there.
        """
        # Path (1): all three task sub-directories exist locally.
        if all(os.path.exists(os.path.join(pretrained_model_name_or_path, model_dir)) for model_dir in [cls._3way_class_model, cls._regression_model, cls._binary_class_model]):
            # Disable the warning about newly initialized weights (we overwrite
            # them with load_state_dict below anyway).
            transformers.logging.set_verbosity_error()

            shared_encoder = None
            taskmodels_dict = {}
            for path_name in [cls._regression_model, cls._binary_class_model, cls._3way_class_model]:
                # "re_model" -> task key "re", etc.
                task_name = path_name.split("_")[0]

                # Load the configuration for the task-specific model
                task_config = transformers.XLMRobertaConfig.from_json_file("{}/{}/config.json".format(pretrained_model_name_or_path,path_name))
                # Create the task-specific model
                model = transformers.XLMRobertaForSequenceClassification.from_pretrained(model_name, config=task_config,*model_args,**kwargs)
                # Load the weights for the task-specific model
                model.load_state_dict(torch.load("{}/{}/pytorch_model.bin".format(pretrained_model_name_or_path,path_name), map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
                # First head donates its encoder; the others have theirs replaced
                # so all heads share the same encoder weights.
                if shared_encoder is None:
                    shared_encoder = getattr(model, AlignScoreCS.get_encoder_attr_name(model))
                else:
                    setattr(model, AlignScoreCS.get_encoder_attr_name(model), shared_encoder)
                taskmodels_dict[task_name] = model

            # Create the AlignScoreCS with the shared encoder and loaded task-specific models
            alignScoreCS = AlignScoreCS(encoder=shared_encoder, taskmodels_dict=taskmodels_dict, model_name=model_name)
        # Path (2): load each head from the Hugging Face Hub sub-folders.
        else:
            shared_encoder = None
            taskmodels_dict = {}
            for model_dir in [cls._regression_model, cls._binary_class_model, cls._3way_class_model]:
                task_name = model_dir.split("_")[0]
                config = transformers.XLMRobertaConfig.from_pretrained(f"{pretrained_model_name_or_path}", subfolder=model_dir)
                model = transformers.XLMRobertaForSequenceClassification.from_pretrained(f"{pretrained_model_name_or_path}",config=config, subfolder=model_dir)
                # Same encoder-sharing scheme as the local path.
                if shared_encoder is None:
                    shared_encoder = getattr(model, AlignScoreCS.get_encoder_attr_name(model))
                else:
                    setattr(model, AlignScoreCS.get_encoder_attr_name(model), shared_encoder)
                taskmodels_dict[task_name] = model
            alignScoreCS = AlignScoreCS(encoder=shared_encoder, taskmodels_dict=taskmodels_dict, model_name=model_name)

        return alignScoreCS
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        state_dict: Optional[dict] = None,
        save_function: Callable = torch.save,
        push_to_hub: bool = False,
        max_shard_size: Union[int, str] = "10GB",
        safe_serialization: bool = False,
        variant: Optional[str] = None,
        token: Optional[Union[str, bool]] = None,
        save_peft_format: bool = True,
        **kwargs,
    ):
        """Save every task head under <save_directory>/<task>_model.

        The signature mirrors transformers' PreTrainedModel.save_pretrained; all
        arguments are forwarded to each head's own save_pretrained call. The
        layout produced here is exactly what from_pretrained expects back.
        """
        for task_name, model_type in self.taskmodels_dict.items():
            # Each head is saved as a complete model, so the shared encoder's
            # weights appear in every head directory.
            model_type.save_pretrained(save_directory = Path(save_directory,task_name+"_model"),
                                       is_main_process = is_main_process,
                                       state_dict = state_dict,
                                       save_function = save_function,
                                       push_to_hub = push_to_hub,
                                       max_shard_size = max_shard_size,
                                       safe_serialization = safe_serialization,
                                       variant = variant,
                                       token = token,
                                       save_peft_format = save_peft_format,
                                       **kwargs)
|
| 465 |
+
|
| 466 |
+
# This piece of code is copied from AlignScore github repository
|
| 467 |
+
# if you want to use different nlg_eval_mode you have to fix errors on your own
|
| 468 |
+
class InferenceHandler:
    """NLI-style alignment scorer around a multi-task transformer model.

    Adapted from the AlignScore repository. Only the default ``nli_sp``
    evaluation mode is fully supported; other modes follow the upstream
    code paths.
    """

    def __init__(self, model, tokenizer, device = "cuda"):
        self.model = model
        self.device = device
        self.tokenizer = tokenizer
        self.model.to(self.device)
        self.model.eval()
        self.batch_size = 32
        self.nlg_eval_mode = "nli_sp"
        self.verbose = False
        self.task_name = "3way"
        self.softmax = nn.Softmax(dim=-1)
        # Fix: previously this flag was only assigned inside
        # inference_example_batch(), so calling inference() directly first
        # raised AttributeError.
        self.disable_progress_bar_in_inference = False

    def nlg_eval(self, premise, hypo):
        """Score premise/hypothesis pairs; bare strings become 1-element lists."""
        if isinstance(premise, str) and isinstance(hypo, str):
            premise = [premise]
            hypo = [hypo]
        return self.inference_example_batch(premise, hypo)

    def inference_example_batch(self, premise: list, hypo: list):
        """
        inference a example,
        premise: list
        hypo: list
        using self.inference to batch the process

        SummaC Style aggregation
        """
        self.disable_progress_bar_in_inference = True
        assert len(premise) == len(hypo), "Premise must has the same length with Hypothesis!"

        out_score = []
        for one_pre, one_hypo in tqdm(zip(premise, hypo), desc="Evaluating", total=len(premise), disable=(not self.verbose)):
            out_score.append(self.inference_per_example(one_pre, one_hypo))

        return torch.tensor(out_score)

    def inference_per_example(self, premise: str, hypo: str):
        """
        inference a example,
        premise: string
        hypo: string
        using self.inference to batch the process
        """
        def chunks(lst, n):
            """Yield successive n-sized chunks from lst, joined into one string."""
            for i in range(0, len(lst), n):
                yield ' '.join(lst[i:i + n])

        premise_sents = sent_tokenize(premise)
        # Guard against an empty premise producing zero chunks.
        premise_sents = premise_sents or ['']

        # Group whole sentences into roughly 350-word premise chunks.
        n_chunk = len(premise.strip().split()) // 350 + 1
        n_chunk = max(len(premise_sents) // n_chunk, 1)
        premise_sents = list(chunks(premise_sents, n_chunk))

        hypo_sents = sent_tokenize(hypo)

        # Cartesian product: every premise chunk against every hypothesis sentence.
        premise_sent_mat = []
        hypo_sents_mat = []
        for i in range(len(premise_sents)):
            for j in range(len(hypo_sents)):
                premise_sent_mat.append(premise_sents[i])
                hypo_sents_mat.append(hypo_sents[j])

        if self.nlg_eval_mode == 'nli_sp':
            # Keep only the "aligned" class probability (column 1).
            output_score = self.inference(premise_sent_mat, hypo_sents_mat)[:, 1]
        else:
            output_score = self.inference(premise_sent_mat, hypo_sents_mat)
        # SummaC-style aggregation: best-supporting premise chunk per
        # hypothesis sentence, averaged over hypothesis sentences.
        output_score = output_score.view(len(premise_sents), len(hypo_sents)).max(dim=0).values.mean().item()

        return output_score

    def inference(self, premise, hypo, task_name = None):
        """
        inference a list of premise and hypo

        Standard aggregation
        """
        task_name = self.task_name if task_name is None else task_name
        if isinstance(premise, str) and isinstance(hypo, str):
            premise = [premise]
            hypo = [hypo]

        batch = self.batch_tokenize(premise, hypo)
        output_score = []

        for mini_batch in tqdm(batch, desc="Evaluating", disable=not self.verbose or self.disable_progress_bar_in_inference):
            mini_batch = mini_batch.to(self.device)
            with torch.no_grad():
                model_output = self.model.forward(task_name=task_name, **mini_batch)
                model_output = model_output.logits
                # The regression head ("re") returns raw scores; classification
                # heads are converted to probabilities.
                if task_name == "re":
                    model_output = model_output.cpu()
                else:
                    model_output = self.softmax(model_output).cpu()
            output_score.append(model_output[:, :])

        output_score = torch.cat(output_score)

        if self.nlg_eval_mode is not None:
            if self.nlg_eval_mode == 'nli':
                return output_score[:, 1]
            elif self.nlg_eval_mode in ('bin', 'reg', 'nli_sp'):
                # 'nli_sp' is sliced by the caller (inference_per_example).
                return output_score
            else:
                # Fix: the original constructed this ValueError without raising it.
                raise ValueError("unrecognized nlg eval mode")

        return output_score

    def batch_tokenize(self, premise, hypo):
        """
        input premise and hypos are lists
        """
        assert isinstance(premise, list) and isinstance(hypo, list)
        assert len(premise) == len(hypo), "premise and hypo should be in the same length."

        batch = []
        for mini_batch_pre, mini_batch_hypo in zip(self.chunks(premise, self.batch_size), self.chunks(hypo, self.batch_size)):
            try:
                mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation='only_first', padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
            except Exception:
                # Fix: was a bare except, which also swallowed KeyboardInterrupt.
                # Fall back to truncating both segments when the second is too long.
                print('text_b too long...')
                mini_batch = self.tokenizer(mini_batch_pre, mini_batch_hypo, truncation=True, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
            batch.append(mini_batch)

        return batch

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
# Manual smoke test: load the published checkpoint and print alignment scores
# for a set of Czech premise/hypothesis pairs (entailment, negation, QA,
# and story-consistency examples).
# NOTE(review): several printed labels differ slightly from the strings that
# are actually scored (punctuation, capitalization, e.g. "Petru!" vs
# "Petru.") — confirm whether that is intentional.
if __name__ == "__main__":
    alignScore = AlignScoreCS.from_pretrained("krotima1/AlignScoreCS")
    alignScore.to("cuda" if torch.cuda.is_available() else "cpu")

    # Simple entailment / contradiction pairs.
    print("Tomáš miluje Zuzku!", "|", "Tomáš miluje Petru!",alignScore.score("Tomáš miluje Zuzku!", "Tomáš miluje Petru."))
    print("Tomáš miluje Zuzku!", "|", "Tomáš miluje Zuzku!",alignScore.score("Tomáš miluje Zuzku!", "Tomáš miluje Zuzku!"))
    print("Tomáš miluje Zuzku.", "|", "Zuzka miluje Tomáše.",alignScore.score("Tomáš miluje Zuzku!", "Zuzka miluje Tomáše."))
    print("Tomáš miluje Zuzku.", "|", "Zuzka nemiluje Tomáše.",alignScore.score("Tomáš miluje Zuzku!", "Zuzka nemiluje Tomáše."))
    print("Tomáš miluje Zuzku.", "|", "Tomáš nemiluje Zuzku.",alignScore.score("Tomáš miluje Zuzku!", "Tomáš nemiluje Zuzku."))
    # Paraphrase vs. unrelated statement.
    print("Dva chlapi se perou.", "|", "Je tu bitka.",alignScore.score("Dva chlapi se perou.", "Je tu bitka."))
    print("Dva chlapi se perou.", "|", "Je tu láska.",alignScore.score("Dva chlapi se perou.", "Je tu láska."))
    # Question-answering style premises (context + question).
    print("Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta. \n Kdo nechal vystavět katedrálu?", "|", "Byl to Karel.",alignScore.score("Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.\nKdo nechal vystavět katedrálu?", "Byl to Karel."))
    print("Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta. \n Kdo nechal vystavět katedrálu?", "|", "Byl to Vít.",alignScore.score("Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.\nKdo nechal vystavět katedrálu?", "Byl to Vít."))
    print("Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta. \n Kdo nechal vystavět katedrálu?", "|", "Byla to katedrála.",alignScore.score("Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.\nKdo nechal vystavět katedrálu?", "Byla to katedrála."))
    print("Kdo je Karel IV.? Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.", "|", "Je Otec.",alignScore.score("Kdo je Karel IV.? Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.", "Je Otec."))
    print("Kdo je Karel IV.? Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.", "|", "Je Otec vlasti.",alignScore.score("Kdo je Karel IV.? Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.", "Je Otec vlasti."))
    print("Kdo je Karel IV.? Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.", "|", "Je katedrála svatého Víta.",alignScore.score("Kdo je Karel IV.? Karel IV. je Otec vlasti. Nechal postavit katedrálu svatého Víta.", "Je katedrála svatého Víta."))
    # Story-consistency checks against a short fairy-tale premise.
    print("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "|", "Karkulka utekla vklovi.",alignScore.score("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "Karkulka utekla vklovi."))
    print("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "|", "Karkulka neutekla vklovi.",alignScore.score("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "Karkulka neutekla vklovi."))
    print("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "|", "Vlk snědl Karkulku.",alignScore.score("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "Vlk snědl karkulku."))
    print("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "|", "Vlk nesnědl Karkulku.",alignScore.score("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "Vlk nesnědl karkulku."))
    print("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "|", "Karkulka snědla vlka.",alignScore.score("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "Karkulka snědla vlka."))
    print("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "|", "Karkulka dala vlkovi jablko.",alignScore.score("Karkulka šla do lesa. V lese potkala vlka. Vlk ji zkoušel sníst, ale Karkulka se nedala a Vlkovi utekla!", "Karkulka dala vlkovi jablko."))
|