zwt123home123 committed (verified)
Commit 97d4b28 · Parent(s): 469195c

Update modeling_llama.py

Files changed (1):
  1. modeling_llama.py +34 -11
modeling_llama.py CHANGED

@@ -318,7 +318,7 @@ class GroupedAutoEncoder(nn.Module):
     def forward(self, x):
         # Split input into groups
         group_inputs = torch.split(x, self.group_input_dim, dim=2)
-        # import pdb; pdb.set_trace()
+
         # Apply group-wise encoding
         encoded_groups = [encoder(group) for group, encoder in zip(group_inputs, self.encoders)]

@@ -370,8 +370,8 @@ class LlamaAttention(nn.Module):
         input_dim = 5120
         hidden_dim = 320
         num_groups = 40
-        # self.ae_v = AutoEncoder(input_dim, hidden_dim)#.cuda()
-        self.ae_v = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups)# .cuda()
+        self.ae_v = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups)
+        self.load_ae_v = True
         #self.ae_v.eval()

     def _init_rope(self):
@@ -446,7 +446,16 @@
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+
+
+        # import pdb; pdb.set_trace()
+
+        if value_states.shape[2]>576:
+            reuse = True
+            value_states_ = value_states.clone()
+        else:
+            reuse = False
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
@@ -458,14 +467,15 @@
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+            # print(value_states.shape)
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -485,11 +495,19 @@
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

-        # import pdb; pdb.set_trace()
-        if attn_weights.shape[2]>576:
-            # print("loading ... ")
-            #print(value_states.shape)
+
+        #if self.layer_idx==5:
+        #    print(value_states[0,0,256,:])
+
+        if self.load_ae_v:
             self.ae_v.load_state_dict(torch.load("weights_group_320/"+"autoencoder_epoch_1_L1_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
+            self.load_ae_v = False
+        else:
+            pass
+
+        #if self.layer_idx==5:
+        #    print(value_states.shape)
+        if value_states.shape[2]>576:
             value_states_v = value_states[:,:,35:35+576,:]
             value_states_v = value_states_v.permute(0, 2, 1, 3)
             value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
@@ -498,7 +516,12 @@
             value_states_v = value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1], 40, 128)
             value_states_v = value_states_v.permute(0, 2, 1, 3)
             value_states[:,:,35:35+576,:] = value_states_v
+
+        # if reuse:
+        #     value_states = value_states_

+        #if self.layer_idx==5:
+        #    print(value_states[0,0,256,:])
         attn_output = torch.matmul(attn_weights, value_states)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -1480,4 +1503,4 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
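
Note on the module configured above: the commit only touches GroupedAutoEncoder.forward() and its construction in LlamaAttention.__init__; the class definition itself sits outside the changed hunks. The following is a minimal sketch of what such a module could look like, consistent with the forward() lines visible in the first hunk. The per-group nn.Linear encoder/decoder layers (5120/40 = 128 inputs and 320/40 = 8 latent dims per group) and the absence of activations are assumptions, not the committed implementation.

import torch
import torch.nn as nn

class GroupedAutoEncoder(nn.Module):
    # Hypothetical sketch: the 5120 input features are split into 40 groups of 128,
    # and each group gets its own small Linear encoder (128 -> 8) and decoder (8 -> 128),
    # spreading the 320-dim bottleneck evenly across the groups.
    def __init__(self, input_dim=5120, hidden_dim=320, num_groups=40):
        super().__init__()
        assert input_dim % num_groups == 0 and hidden_dim % num_groups == 0
        self.group_input_dim = input_dim // num_groups    # 128
        self.group_hidden_dim = hidden_dim // num_groups  # 8
        self.encoders = nn.ModuleList(
            [nn.Linear(self.group_input_dim, self.group_hidden_dim) for _ in range(num_groups)]
        )
        self.decoders = nn.ModuleList(
            [nn.Linear(self.group_hidden_dim, self.group_input_dim) for _ in range(num_groups)]
        )

    def forward(self, x):
        # Split input into groups along the feature dimension (as in the committed forward())
        group_inputs = torch.split(x, self.group_input_dim, dim=2)
        # Apply group-wise encoding
        encoded_groups = [encoder(group) for group, encoder in zip(group_inputs, self.encoders)]
        # Group-wise decoding, then re-concatenate back to (bsz, seq, input_dim)
        decoded_groups = [decoder(z) for z, decoder in zip(encoded_groups, self.decoders)]
        return torch.cat(decoded_groups, dim=2)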
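
For reference, the value-state round-trip added in the attention hunks can be read as the standalone sketch below. The shapes (40 heads x 128 head_dim = 5120) and the 576-token window starting at offset 35 come directly from the code above; the self.ae_v call is only indicated as a comment, because the lines between the two hunks are unchanged context and therefore not shown in the diff.

import torch

# Assumed geometry: 40 heads, head_dim 128, sequence longer than 576 tokens.
bsz, num_heads, seq_len, head_dim = 1, 40, 700, 128
value_states = torch.randn(bsz, num_heads, seq_len, head_dim)

if value_states.shape[2] > 576:
    v = value_states[:, :, 35:35 + 576, :]           # (bsz, 40, 576, 128): the fixed window
    v = v.permute(0, 2, 1, 3)                        # (bsz, 576, 40, 128)
    v = v.reshape(v.shape[0], v.shape[1], 5120)      # (bsz, 576, 5120): heads flattened per token

    # ... the unchanged lines hidden between the two hunks presumably run the
    # grouped autoencoder here, e.g. v = self.ae_v(v) ...

    v = v.reshape(v.shape[0], v.shape[1], 40, 128)   # (bsz, 576, 40, 128)
    v = v.permute(0, 2, 1, 3)                        # (bsz, 40, 576, 128)
    value_states[:, :, 35:35 + 576, :] = v           # write the reconstructed window back

On the control flow itself: the new load_ae_v flag loads the per-layer checkpoint (weights_group_320/autoencoder_epoch_1_L1_nonorm_layer_<layer_idx>.pth) only once, on the first forward pass of each layer, instead of on every step with attn_weights.shape[2] > 576 as in the removed branch. The reuse flag and the value_states_ clone are set up early in forward(), but the path that would restore the original values ("# if reuse: value_states = value_states_") is still commented out.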