fix param
dnaflash.py  +24 -3
@@ -416,8 +416,18 @@ class FLASHTransformerForPretrained(PreTrainedModel):
             reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
         )

-    def forward(
-
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None
+    ) -> Union[Tuple, MaskedLMOutput]:
+        logits, x = self.model(input_ids, mask=attention_mask)
         return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)

 class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):

@@ -438,7 +448,12 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""

@@ -450,7 +465,13 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):

         # Get the base model's outputs
         outputs = super().forward(
-            input_ids
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )
         hidden_states = outputs["hidden_states"]
         input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size())  # match dimensions
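The last context line of the diff builds input_mask_expanded from the attention mask, the usual first step of masked mean pooling over hidden_states. Note that input_ids["attention_mask"] indexes a tensor with a string key; presumably the newly threaded-through attention_mask argument is what is meant. A sketch of that pooling step under this assumption (masked_mean_pool is a hypothetical helper name, not a function from the diff):

import torch

def masked_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1.
    # Assumption: the mask comes from the forward kwarg, not from indexing input_ids.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)   # add up embeddings of real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)     # per-row count of real tokens
    return summed / counts                       # (batch, hidden) pooled representation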