Update dnaflash.py
Browse files- dnaflash.py +1 -1
dnaflash.py
CHANGED
|
@@ -475,7 +475,7 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
|
|
| 475 |
)
|
| 476 |
hidden_states = outputs["hidden_states"]
|
| 477 |
input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size()) # 维度匹配
|
| 478 |
-
mean_pooled = torch.sum(
|
| 479 |
logits = self.score(mean_pooled)
|
| 480 |
|
| 481 |
loss = None
|
|
|
|
| 475 |
)
|
| 476 |
hidden_states = outputs["hidden_states"]
|
| 477 |
input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size()) # 维度匹配
|
| 478 |
+
mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1) # 计算加权平均
|
| 479 |
logits = self.score(mean_pooled)
|
| 480 |
|
| 481 |
loss = None
|