Update modeling_llama.py
Browse files
- modeling_llama.py +1 -1
modeling_llama.py
CHANGED
|
@@ -457,7 +457,7 @@ class LlamaAttention(nn.Module):
|
|
| 457 |
if attn_weights.shape[2]>576:
|
| 458 |
# print("loading ... ")
|
| 459 |
#print(value_states.shape)
|
| 460 |
-
self.ae_v.load_state_dict(torch.load("weights_480/"+"
|
| 461 |
value_states_v = value_states[:,:,35:35+576,:]
|
| 462 |
value_states_v = value_states_v.permute(0, 2, 1, 3)
|
| 463 |
value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
|
|
|
|
| 457 |
if attn_weights.shape[2]>576:
|
| 458 |
# print("loading ... ")
|
| 459 |
#print(value_states.shape)
|
| 460 |
+
self.ae_v.load_state_dict(torch.load("weights_480/"+"autoencoder_epoch_1_L1_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
|
| 461 |
value_states_v = value_states[:,:,35:35+576,:]
|
| 462 |
value_states_v = value_states_v.permute(0, 2, 1, 3)
|
| 463 |
value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
|