Commit ee4aed2 (verified), committed by IlyasMoutawwakil (HF Staff)
Parent(s): 8ab1e26

Update modeling_internlm2.py

Files changed (1): modeling_internlm2.py (+4, -4)
modeling_internlm2.py CHANGED
@@ -360,12 +360,12 @@ class InternLM2Attention(nn.Module):
         value_states = value_states.transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
@@ -462,14 +462,14 @@ class InternLM2FlashAttention2(InternLM2Attention):
         value_states = value_states.transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             kv_seq_len += past_key_value[0].shape[-2]
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
-        if past_key_value is not None:
+        if past_key_value is not None and past_key_value[0] is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
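
A minimal sketch of what the strengthened guard buys, assuming a caller that primes past_key_value with None placeholders (a pattern seen, for example, when tracing or exporting with a legacy tuple cache; the commit message does not state the motivation, so this reading is an assumption). The helper name grow_cache is hypothetical and only mirrors the two edited code paths: a tuple like (None, None) passes the old `past_key_value is not None` check and then crashes on `past_key_value[0].shape`, whereas the new condition skips the reuse branch cleanly.

import torch

# Hypothetical stand-in for the cache handling in the two attention classes
# above; `key_states` plays the role of the freshly projected key tensor.
def grow_cache(past_key_value, key_states):
    kv_seq_len = key_states.shape[-2]
    # Old guard: `past_key_value is not None` alone.
    # New guard: additionally require a real tensor in slot 0.
    if past_key_value is not None and past_key_value[0] is not None:
        kv_seq_len += past_key_value[0].shape[-2]
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
    return key_states, kv_seq_len

key_states = torch.randn(1, 8, 4, 64)  # (batch, heads, seq_len, head_dim)

# A primed-but-empty cache: the tuple itself is not None, its entries are.
# With only `past_key_value is not None`, the shape access would raise
# AttributeError: 'NoneType' object has no attribute 'shape'.
ks, n = grow_cache((None, None), key_states)
assert n == 4 and ks.shape[-2] == 4

# A populated cache still grows exactly as before the change.
past = (torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64))
ks, n = grow_cache(past, key_states)
assert n == 14 and ks.shape[-2] == 14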