Xuezha commited on
Commit
5ba2653
·
verified ·
1 Parent(s): 57c1748

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +3 -0
modeling.py CHANGED
@@ -185,6 +185,9 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
185
  if past:
186
  input_ids = input_ids[:, -1].unsqueeze(-1)
187
 
 
 
 
188
  return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
189
 
190
  def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
 
185
  if past:
186
  input_ids = input_ids[:, -1].unsqueeze(-1)
187
 
188
+ if attention_mask is None:
189
+ attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
190
+
191
  return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
192
 
193
  def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):