aynetdia commited on
Commit
0b695c3
·
1 Parent(s): e07ce7c

fix mean_pool reference

Browse files
Files changed (1) hide show
  1. semscore.py +2 -2
semscore.py CHANGED
@@ -122,8 +122,8 @@ class SemScore(evaluate.Metric):
122
  encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
123
  model_output_refs = self.model(**encoded_refs.to(device))
124
  model_output_preds = self.model(**encoded_preds.to(device))
125
- batch_pooled_refs = mean_pooling(model_output_refs, encoded_refs['attention_mask'])
126
- batch_pooled_preds = mean_pooling(model_output_preds, encoded_preds['attention_mask'])
127
  pooled_refs.append(batch_pooled_refs)
128
  pooled_preds.append(batch_pooled_preds)
129
  pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
 
122
  encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
123
  model_output_refs = self.model(**encoded_refs.to(device))
124
  model_output_preds = self.model(**encoded_preds.to(device))
125
+ batch_pooled_refs = self.mean_pooling(model_output_refs, encoded_refs['attention_mask'])
126
+ batch_pooled_preds = self.mean_pooling(model_output_preds, encoded_preds['attention_mask'])
127
  pooled_refs.append(batch_pooled_refs)
128
  pooled_preds.append(batch_pooled_preds)
129
  pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)