Update dnaflash.py
Browse files- dnaflash.py +2 -2
dnaflash.py
CHANGED
|
@@ -408,7 +408,7 @@ class FLASHTransformerForPretrained(PreTrainedModel):
|
|
| 408 |
reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
|
| 409 |
)
|
| 410 |
|
| 411 |
-
def forward(self,
|
| 412 |
-
logits, x = self.model(input_ids, mask=
|
| 413 |
return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
|
| 414 |
|
|
|
|
| 408 |
reduce_group_non_causal_attn=config.reduce_group_non_causal_attn
|
| 409 |
)
|
| 410 |
|
| 411 |
+
def forward(self,inputs):
|
| 412 |
+
logits, x = self.model(inputs["input_ids"], mask=inputs["attention_mask"])
|
| 413 |
return MaskedLMOutput(logits=logits, hidden_states=x, loss=None, attentions=None)
|
| 414 |
|