Update modeling.py
Browse files- modeling.py +3 -0
modeling.py
CHANGED
|
@@ -185,6 +185,9 @@ class RecombinationTransformerForCausalLM(PreTrainedModel):
|
|
| 185 |
if past:
|
| 186 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 187 |
|
|
|
|
|
|
|
|
|
|
| 188 |
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
| 189 |
|
| 190 |
def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
|
|
|
|
| 185 |
if past:
|
| 186 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 187 |
|
| 188 |
+
if attention_mask is None:
|
| 189 |
+
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
|
| 190 |
+
|
| 191 |
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
| 192 |
|
| 193 |
def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
|