Add support for last-token pooling
Browse files · semscore.py (+20 −3)
semscore.py
CHANGED
|
@@ -87,6 +87,7 @@ class SemScore(evaluate.Metric):
|
|
| 87 |
# Load model and tokenizer from HuggingFace Hub
|
| 88 |
self.model = AutoModel.from_pretrained(checkpoint)
|
| 89 |
self.model.eval()
|
|
|
|
| 90 |
self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
| 91 |
|
| 92 |
@staticmethod
|
|
@@ -95,6 +96,16 @@ class SemScore(evaluate.Metric):
|
|
| 95 |
token_embeddings = model_output[0]
|
| 96 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 97 |
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
def _compute(
|
| 100 |
self,
|
|
@@ -102,10 +113,12 @@ class SemScore(evaluate.Metric):
|
|
| 102 |
references,
|
| 103 |
batch_size=32,
|
| 104 |
device=None,
|
|
|
|
| 105 |
):
|
| 106 |
"""Returns the scores"""
|
| 107 |
|
| 108 |
assert len(predictions) == len(references), "predictions and references should have the same length."
|
|
|
|
| 109 |
if device is not None:
|
| 110 |
if "cuda" in device:
|
| 111 |
assert torch.cuda.is_available()
|
|
@@ -123,8 +136,12 @@ class SemScore(evaluate.Metric):
|
|
| 123 |
encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
|
| 124 |
model_output_refs = self.model(**encoded_refs.to(device))
|
| 125 |
model_output_preds = self.model(**encoded_preds.to(device))
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
pooled_refs.append(batch_pooled_refs)
|
| 129 |
pooled_preds.append(batch_pooled_preds)
|
| 130 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|
|
@@ -136,4 +153,4 @@ class SemScore(evaluate.Metric):
|
|
| 136 |
return {
|
| 137 |
"semscore": round(semscore.item(), 2),
|
| 138 |
"similarities": similarities.tolist()
|
| 139 |
-
}
|
|
|
|
| 87 |
# Load model and tokenizer from HuggingFace Hub
|
| 88 |
self.model = AutoModel.from_pretrained(checkpoint)
|
| 89 |
self.model.eval()
|
| 90 |
+
padding_side = "left" if self  [NOTE(review): line truncated in this diff capture — presumably something like `padding_side = "left" if <last-token pooling> else "right"`; confirm against the full commit]
|
| 91 |
self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
| 92 |
|
| 93 |
@staticmethod
|
|
|
|
| 96 |
token_embeddings = model_output[0]
|
| 97 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 98 |
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def _last_token_pooling(last_hidden_states, attention_mask):
|
| 102 |
+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
| 103 |
+
if left_padding:
|
| 104 |
+
return last_hidden_states[:, -1]
|
| 105 |
+
else:
|
| 106 |
+
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 107 |
+
batch_size = last_hidden_states.shape[0]
|
| 108 |
+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
| 109 |
|
| 110 |
def _compute(
|
| 111 |
self,
|
|
|
|
| 113 |
references,
|
| 114 |
batch_size=32,
|
| 115 |
device=None,
|
| 116 |
+
pooling="mean"
|
| 117 |
):
|
| 118 |
"""Returns the scores"""
|
| 119 |
|
| 120 |
assert len(predictions) == len(references), "predictions and references should have the same length."
|
| 121 |
+
assert pooling in ["mean", "last"]
|
| 122 |
if device is not None:
|
| 123 |
if "cuda" in device:
|
| 124 |
assert torch.cuda.is_available()
|
|
|
|
| 136 |
encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
|
| 137 |
model_output_refs = self.model(**encoded_refs.to(device))
|
| 138 |
model_output_preds = self.model(**encoded_preds.to(device))
|
| 139 |
+
if pooling == "mean":
|
| 140 |
+
batch_pooled_refs = self._mean_pooling(model_output_refs, encoded_refs['attention_mask'])
|
| 141 |
+
batch_pooled_preds = self._mean_pooling(model_output_preds, encoded_preds['attention_mask'])
|
| 142 |
+
elif pooling == "last":
|
| 143 |
+
batch_pooled_refs = self._last_token_pooling(model_output_refs, encoded_refs['attention_mask'])
|
| 144 |
+
batch_pooled_preds = self._last_token_pooling(model_output_preds, encoded_preds['attention_mask'])
|
| 145 |
pooled_refs.append(batch_pooled_refs)
|
| 146 |
pooled_preds.append(batch_pooled_preds)
|
| 147 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|
|
|
|
| 153 |
return {
|
| 154 |
"semscore": round(semscore.item(), 2),
|
| 155 |
"similarities": similarities.tolist()
|
| 156 |
+
}
|