Update modeling_llama.py
Browse files
- modeling_llama.py +34 -11
modeling_llama.py
CHANGED
|
@@ -318,7 +318,7 @@ class GroupedAutoEncoder(nn.Module):
|
|
| 318 |
def forward(self, x):
|
| 319 |
# Split input into groups
|
| 320 |
group_inputs = torch.split(x, self.group_input_dim, dim=2)
|
| 321 |
-
|
| 322 |
# Apply group-wise encoding
|
| 323 |
encoded_groups = [encoder(group) for group, encoder in zip(group_inputs, self.encoders)]
|
| 324 |
|
|
@@ -370,8 +370,8 @@ class LlamaAttention(nn.Module):
|
|
| 370 |
input_dim = 5120
|
| 371 |
hidden_dim = 320
|
| 372 |
num_groups = 40
|
| 373 |
-
|
| 374 |
-
self.
|
| 375 |
#self.ae_v.eval()
|
| 376 |
|
| 377 |
def _init_rope(self):
|
|
@@ -446,7 +446,16 @@ class LlamaAttention(nn.Module):
|
|
| 446 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 447 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 448 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
kv_seq_len = key_states.shape[-2]
|
| 451 |
if past_key_value is not None:
|
| 452 |
if self.layer_idx is None:
|
|
@@ -458,14 +467,15 @@ class LlamaAttention(nn.Module):
|
|
| 458 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 459 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 460 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 461 |
-
|
| 462 |
if past_key_value is not None:
|
| 463 |
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
|
|
|
| 464 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 465 |
|
| 466 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 467 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 468 |
-
|
| 469 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 470 |
|
| 471 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
@@ -485,11 +495,19 @@ class LlamaAttention(nn.Module):
|
|
| 485 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 486 |
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 487 |
|
| 488 |
-
|
| 489 |
-
if
|
| 490 |
-
|
| 491 |
-
|
|
|
|
| 492 |
self.ae_v.load_state_dict(torch.load("weights_group_320/"+"autoencoder_epoch_1_L1_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
value_states_v = value_states[:,:,35:35+576,:]
|
| 494 |
value_states_v = value_states_v.permute(0, 2, 1, 3)
|
| 495 |
value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
|
|
@@ -498,7 +516,12 @@ class LlamaAttention(nn.Module):
|
|
| 498 |
value_states_v = value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1], 40, 128)
|
| 499 |
value_states_v = value_states_v.permute(0, 2, 1, 3)
|
| 500 |
value_states[:,:,35:35+576,:] = value_states_v
|
|
|
|
|
|
|
|
|
|
| 501 |
|
|
|
|
|
|
|
| 502 |
attn_output = torch.matmul(attn_weights, value_states)
|
| 503 |
|
| 504 |
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
@@ -1480,4 +1503,4 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
|
| 1480 |
past_key_values=transformer_outputs.past_key_values,
|
| 1481 |
hidden_states=transformer_outputs.hidden_states,
|
| 1482 |
attentions=transformer_outputs.attentions,
|
| 1483 |
-
)
|
|
|
|
| 318 |
def forward(self, x):
|
| 319 |
# Split input into groups
|
| 320 |
group_inputs = torch.split(x, self.group_input_dim, dim=2)
|
| 321 |
+
|
| 322 |
# Apply group-wise encoding
|
| 323 |
encoded_groups = [encoder(group) for group, encoder in zip(group_inputs, self.encoders)]
|
| 324 |
|
|
|
|
| 370 |
input_dim = 5120
|
| 371 |
hidden_dim = 320
|
| 372 |
num_groups = 40
|
| 373 |
+
self.ae_v = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups)
|
| 374 |
+
self.load_ae_v = True
|
| 375 |
#self.ae_v.eval()
|
| 376 |
|
| 377 |
def _init_rope(self):
|
|
|
|
| 446 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 447 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 448 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
# import pdb; pdb.set_trace()
|
| 452 |
+
|
| 453 |
+
if value_states.shape[2]>576:
|
| 454 |
+
reuse = True
|
| 455 |
+
value_states_ = value_states.clone()
|
| 456 |
+
else:
|
| 457 |
+
reuse = False
|
| 458 |
+
|
| 459 |
kv_seq_len = key_states.shape[-2]
|
| 460 |
if past_key_value is not None:
|
| 461 |
if self.layer_idx is None:
|
|
|
|
| 467 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 468 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 469 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 470 |
+
|
| 471 |
if past_key_value is not None:
|
| 472 |
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
| 473 |
+
# print(value_states.shape)
|
| 474 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 475 |
|
| 476 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 477 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 478 |
+
|
| 479 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 480 |
|
| 481 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
|
|
| 495 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 496 |
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 497 |
|
| 498 |
+
|
| 499 |
+
#if self.layer_idx==5:
|
| 500 |
+
# print(value_states[0,0,256,:])
|
| 501 |
+
|
| 502 |
+
if self.load_ae_v:
|
| 503 |
self.ae_v.load_state_dict(torch.load("weights_group_320/"+"autoencoder_epoch_1_L1_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
|
| 504 |
+
self.load_ae_v = False
|
| 505 |
+
else:
|
| 506 |
+
pass
|
| 507 |
+
|
| 508 |
+
#if self.layer_idx==5:
|
| 509 |
+
# print(value_states.shape)
|
| 510 |
+
if value_states.shape[2]>576:
|
| 511 |
value_states_v = value_states[:,:,35:35+576,:]
|
| 512 |
value_states_v = value_states_v.permute(0, 2, 1, 3)
|
| 513 |
value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
|
|
|
|
| 516 |
value_states_v = value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1], 40, 128)
|
| 517 |
value_states_v = value_states_v.permute(0, 2, 1, 3)
|
| 518 |
value_states[:,:,35:35+576,:] = value_states_v
|
| 519 |
+
|
| 520 |
+
# if reuse:
|
| 521 |
+
# value_states = value_states_
|
| 522 |
|
| 523 |
+
#if self.layer_idx==5:
|
| 524 |
+
# print(value_states[0,0,256,:])
|
| 525 |
attn_output = torch.matmul(attn_weights, value_states)
|
| 526 |
|
| 527 |
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
|
|
| 1503 |
past_key_values=transformer_outputs.past_key_values,
|
| 1504 |
hidden_states=transformer_outputs.hidden_states,
|
| 1505 |
attentions=transformer_outputs.attentions,
|
| 1506 |
+
)
|