fix param err
Browse files- dnaflash.py +1 -1
dnaflash.py
CHANGED
|
@@ -450,7 +450,7 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
|
|
| 450 |
|
| 451 |
# 获取基模型输出
|
| 452 |
outputs = super().forward(
|
| 453 |
-
input_ids
|
| 454 |
)
|
| 455 |
hidden_states = outputs["hidden_states"]
|
| 456 |
input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size()) # 维度匹配
|
|
|
|
| 450 |
|
| 451 |
# 获取基模型输出
|
| 452 |
outputs = super().forward(
|
| 453 |
+
input_ids
|
| 454 |
)
|
| 455 |
hidden_states = outputs["hidden_states"]
|
| 456 |
input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size()) # 维度匹配
|