wangleiofficial committed
Commit 72ed94e · verified · 1 Parent(s): 82dc5f9
Files changed (1)
  1. dnaflash.py +24 -3
dnaflash.py CHANGED
@@ -416,8 +416,18 @@ class FLASHTransformerForPretrained(PreTrainedModel):
             reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
         )
 
-    def forward(self, inputs):
-        logits, x = self.model(inputs["input_ids"], mask=inputs["attention_mask"])
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None
+    ) -> Union[Tuple, MaskedLMOutput]:
+        logits, x = self.model(input_ids, mask=attention_mask)
         return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
 
 class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
@@ -438,7 +448,12 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
@@ -450,7 +465,13 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
 
         # Get the base model output
         outputs = super().forward(
-            input_ids
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )
        hidden_states = outputs["hidden_states"]
        input_mask_expanded = input_ids["attention_mask"].unsqueeze(-1).expand(hidden_states.size())  # match dimensions
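For reference, a minimal usage sketch of the interface after this commit. The checkpoint path, tokenizer, and input sequence below are assumptions for illustration, not part of the commit; only the keyword arguments (input_ids, attention_mask, ...) come from the new forward signatures. The change replaces the single inputs dict with the standard transformers-style keyword arguments, so the models can be driven directly from a tokenizer's output:

# Sketch only: assumes dnaflash.py is importable and a compatible checkpoint/tokenizer exist.
import torch
from transformers import AutoTokenizer
from dnaflash import FLASHTransformerForPretrained, FLASHTransformerForSequenceClassification

checkpoint = "path/to/dnaflash-checkpoint"  # hypothetical checkpoint id
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = FLASHTransformerForPretrained.from_pretrained(checkpoint)

enc = tokenizer("ACGTACGTACGT", return_tensors="pt")

# Before this commit: model({"input_ids": ..., "attention_mask": ...})
# After this commit: standard keyword arguments.
out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.logits.shape)

# The sequence classification head accepts the same keyword arguments
# and forwards them to the base model via super().forward(...).
clf = FLASHTransformerForSequenceClassification.from_pretrained(checkpoint)
clf_out = clf(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

Matching the transformers keyword-argument convention also lets utilities such as Trainer and pipelines pass attention_mask, labels, and return_dict to these models without a custom collator.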