wangleiofficial commited on
Commit
14c7bbd
·
verified ·
1 Parent(s): 9b747e5

Update dnaflash.py

Browse files
Files changed (1) hide show
  1. dnaflash.py +1 -1
dnaflash.py CHANGED
@@ -474,7 +474,7 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
474
  return_dict=return_dict,
475
  )
476
  hidden_states = outputs["hidden_states"]
477
- input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size()) # 维度匹配
478
  mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1) # 计算加权平均
479
  logits = self.score(mean_pooled)
480
 
 
474
  return_dict=return_dict,
475
  )
476
  hidden_states = outputs["hidden_states"]
477
+ input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1) # 维度匹配
478
  mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1) # 计算加权平均
479
  logits = self.score(mean_pooled)
480