Update modeling_internlm2.py
Browse files- modeling_internlm2.py +1 -1
modeling_internlm2.py
CHANGED
|
@@ -1088,7 +1088,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|
| 1088 |
def prepare_inputs_for_generation(
|
| 1089 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1090 |
):
|
| 1091 |
-
if past_key_values is not None:
|
| 1092 |
past_length = past_key_values[0][0].shape[2]
|
| 1093 |
|
| 1094 |
# Some generation methods already pass only the last input ID
|
|
|
|
| 1088 |
def prepare_inputs_for_generation(
|
| 1089 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1090 |
):
|
| 1091 |
+
if past_key_values is not None and past_key_values[0][0] is not None:
|
| 1092 |
past_length = past_key_values[0][0].shape[2]
|
| 1093 |
|
| 1094 |
# Some generation methods already pass only the last input ID
|