wangleiofficial commited on
Commit
ef5bad5
·
verified ·
1 Parent(s): cc2b29e

Update dnaflash.py

Browse files
Files changed (1) hide show
  1. dnaflash.py +2 -2
dnaflash.py CHANGED
@@ -408,7 +408,7 @@ class FLASHTransformerForPretrained(PreTrainedModel):
408
  reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
409
  )
410
 
411
- def forward(self, input_ids, mask=None):
412
- logits, x = self.model(input_ids, mask=mask)
413
  return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
414
 
 
408
  reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
409
  )
410
 
411
+ def forward(self,inputs):
412
+ logits, x = self.model(inputs["input_ids"], mask=inputs["attention_mask"])
413
  return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
414