fix mean_pooling reference: call it as the instance method self.mean_pooling
Browse files- semscore.py +2 -2
semscore.py
CHANGED
|
@@ -122,8 +122,8 @@ class SemScore(evaluate.Metric):
|
|
| 122 |
encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
|
| 123 |
model_output_refs = self.model(**encoded_refs.to(device))
|
| 124 |
model_output_preds = self.model(**encoded_preds.to(device))
|
| 125 |
-
batch_pooled_refs = mean_pooling(model_output_refs, encoded_refs['attention_mask'])
|
| 126 |
-
batch_pooled_preds = mean_pooling(model_output_preds, encoded_preds['attention_mask'])
|
| 127 |
pooled_refs.append(batch_pooled_refs)
|
| 128 |
pooled_preds.append(batch_pooled_preds)
|
| 129 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|
|
|
|
| 122 |
encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
|
| 123 |
model_output_refs = self.model(**encoded_refs.to(device))
|
| 124 |
model_output_preds = self.model(**encoded_preds.to(device))
|
| 125 |
+
batch_pooled_refs = self.mean_pooling(model_output_refs, encoded_refs['attention_mask'])
|
| 126 |
+
batch_pooled_preds = self.mean_pooling(model_output_preds, encoded_preds['attention_mask'])
|
| 127 |
pooled_refs.append(batch_pooled_refs)
|
| 128 |
pooled_preds.append(batch_pooled_preds)
|
| 129 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|