Update modeling_internlm2.py
modeling_internlm2.py (+4 -4)
--- a/modeling_internlm2.py
+++ b/modeling_internlm2.py
@@ -360,12 +360,12 @@ class InternLM2Attention(nn.Module):
         value_states = value_states.transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
@@ -462,14 +462,14 @@ class InternLM2FlashAttention2(InternLM2Attention):
         value_states = value_states.transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             kv_seq_len += past_key_value[0].shape[-2]
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
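
For context, both hunks apply the same fix, once in InternLM2Attention and once in InternLM2FlashAttention2: the old guard `past_key_value is not None` also passes when a caller hands in a placeholder cache whose slots are empty, such as a `(None, None)` tuple, and the following `past_key_value[0].shape[-2]` then raises AttributeError on NoneType. Below is a minimal, self-contained sketch of that failure mode and the strengthened guard; the placeholder-cache convention and tensor shapes are illustrative assumptions, not taken from the repository.

import torch

# Illustrative shapes (assumption): key/value states are
# (batch, num_heads, seq_len, head_dim).
key_states = torch.randn(1, 8, 4, 64)

# Hypothetical caller behavior: a placeholder cache with no cached
# tensors yet, instead of passing None outright.
past_key_value = (None, None)

kv_seq_len = key_states.shape[-2]

# Old guard: `past_key_value is not None` is True for (None, None),
# so `past_key_value[0].shape[-2]` would raise
# AttributeError: 'NoneType' object has no attribute 'shape'.
# New guard: additionally require a real tensor in slot 0.
if past_key_value is not None and past_key_value[0] is not None:
    kv_seq_len += past_key_value[0].shape[-2]

print(kv_seq_len)  # 4 -- the empty placeholder contributes no cached length

With a populated cache, e.g. `past_key_value = (torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64))`, the old and new guards behave identically and `kv_seq_len` becomes 20, so the change only affects the empty-placeholder case.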