import torch
import torch.nn as nn
from transformers import LlamaForCausalLM

from .configuration_pruned_llama import LlamaPrunedConfig
class LlamaPrunedForCausalLM(LlamaForCausalLM):
    config_class = LlamaPrunedConfig

    def __init__(self, config: LlamaPrunedConfig):
        super().__init__(config)
        # Swap every decoder layer's projections for their pruned shapes:
        # attention keeps head_dim=128 but is reduced to 8 query heads and
        # 2 key/value heads, and the MLP intermediate size is reduced to 2007.
        for layer in self.model.layers:
            attn = layer.self_attn
            attn.q_proj = nn.Linear(4096, 1024, bias=False)
            attn.k_proj = nn.Linear(4096, 256, bias=False)
            attn.v_proj = nn.Linear(4096, 256, bias=False)
            attn.o_proj = nn.Linear(1024, 4096, bias=False)
            # Keep the attention bookkeeping consistent with the new shapes
            # (hidden_size must equal num_heads * head_dim for the output reshape).
            attn.hidden_size = attn.q_proj.out_features
            attn.num_heads = attn.q_proj.out_features // attn.head_dim
            attn.num_key_value_heads = attn.k_proj.out_features // attn.head_dim

            layer.mlp.gate_proj = nn.Linear(4096, 2007, bias=False)
            layer.mlp.up_proj = nn.Linear(4096, 2007, bias=False)
            layer.mlp.down_proj = nn.Linear(2007, 4096, bias=False)
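
# Minimal usage sketch (illustrative, not part of the module): constructing the
# pruned model directly from a config and checking the head counts implied by the
# pruned projections. The LlamaPrunedConfig arguments and the module names below
# are assumptions (the config is expected to mirror LlamaConfig's fields).
#
#     from configuration_pruned_llama import LlamaPrunedConfig
#     from modeling_pruned_llama import LlamaPrunedForCausalLM  # hypothetical module name
#
#     config = LlamaPrunedConfig(hidden_size=4096, num_hidden_layers=32)
#     model = LlamaPrunedForCausalLM(config)      # randomly initialised, pruned shapes
#     attn = model.model.layers[0].self_attn
#     print(attn.num_heads, attn.num_key_value_heads)  # 8 and 2 with head_dim=128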