spachava committed on
Commit
8b9ffa2
·
1 Parent(s): 7d1f047

Upload CompressedLlamaForCausalLM

Browse files
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CompressedLlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_compressed_llama.CompressedLlamaConfig",
8
+ "AutoModelForCausalLM": "modeling_compressed_llama.CompressedLlamaForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 2,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 3200,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 8640,
16
+ "max_position_embeddings": 2048,
17
+ "model_type": "llama",
18
+ "num_attention_heads": 32,
19
+ "num_hidden_layers": 26,
20
+ "num_key_value_heads": 32,
21
+ "pad_token_id": 0,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-06,
24
+ "rope_scaling": null,
25
+ "rope_theta": 10000.0,
26
+ "share_layers": "none",
27
+ "tie_word_embeddings": false,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.35.2",
30
+ "use_cache": true,
31
+ "vocab_size": 32000
32
+ }
configuration_compressed_llama.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import LlamaConfig
2
+ from typing import List, Union
3
+
4
+
5
class CompressedLlamaConfig(LlamaConfig):
    """Configuration for CompressedLlama models.

    Extends `LlamaConfig` with one extra option controlling MLP weight
    sharing across decoder layers.

    Args:
        share_layers: Either the string ``"none"`` (no sharing), the string
            ``"all"`` (every decoder layer shares a single MLP), or a list of
            lists of layer indices where each inner list is a group of layers
            that share weights. A layer index may appear in at most one group.
        **kwargs: Forwarded unchanged to `LlamaConfig.__init__`.

    Raises:
        ValueError: If ``share_layers`` is a string other than "none"/"all",
            contains a non-list group, contains a non-int layer index, or
            lists the same layer index in more than one group.
    """

    def __init__(
        self,
        share_layers: Union[List[List[int]], str] = "none",
        **kwargs,
    ):
        if isinstance(share_layers, str) and share_layers not in ("none", "all"):
            raise ValueError(f"`share_layers` must be 'none' or 'all', got {share_layers}.")
        if isinstance(share_layers, list):
            already_shared = []
            for shared_group in share_layers:
                # Every group must itself be a list of layer indices.
                if not isinstance(shared_group, list):
                    raise ValueError(f"`share_layers` must contain a list of lists of ints, got {share_layers}.")
                for layer in shared_group:
                    if not isinstance(layer, int):
                        raise ValueError(f"`share_layers` must contain a list of lists of ints, got {share_layers}.")
                    # A layer can belong to at most one sharing group.
                    if layer in already_shared:
                        raise ValueError(f"you can only share a layer once, got {share_layers}.")
                    already_shared.append(layer)

        self.share_layers = share_layers
        super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.35.2"
7
+ }
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cde221462593019af025faf5cce19f0bcf432f6106d8d4f6193ef65dde77dec
3
+ size 4993264136
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da989a50e26cc2e59c9c546b439ea6d3cedbf2090068f1159b61bfde743970ba
3
+ size 4997386488
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77e75f2f2edceb7bee4131d7e189af5f09aaa91d3f9b95b62b53811fb671293b
3
+ size 3715271008
model.safetensors.index.json ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 13705894400
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00003-of-00003.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00003.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
35
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
36
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
37
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
38
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
39
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
40
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
41
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
42
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
43
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
44
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
45
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
46
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
47
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
48
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
49
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
50
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
51
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
52
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
53
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
54
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
55
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
56
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
57
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
58
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
59
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
60
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
61
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
62
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
63
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
64
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
65
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
66
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
67
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
68
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
69
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
70
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
71
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
72
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
73
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
74
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
75
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
76
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
77
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
80
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
81
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
90
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
91
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
92
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
93
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
94
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
95
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
96
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
97
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
98
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
99
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
100
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
101
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
102
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
103
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
104
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
105
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
106
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
107
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00003.safetensors",
108
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
109
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
110
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
111
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
112
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
113
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
114
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
115
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
116
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
117
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
118
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
119
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
120
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
121
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
122
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
123
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
124
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
125
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00003.safetensors",
126
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
127
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
128
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
129
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
130
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
131
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
132
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
133
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
134
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00003.safetensors",
135
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
136
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
137
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
139
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
140
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
141
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
142
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
143
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00003.safetensors",
144
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
145
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
146
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
147
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
148
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
149
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
150
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
151
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
152
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
153
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
154
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
155
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
156
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
157
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
158
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
159
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
160
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
161
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
162
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
163
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
164
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
165
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
166
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
167
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
168
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
169
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
170
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
171
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
172
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
173
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
174
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
175
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
176
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
177
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
178
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
179
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
180
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
181
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
182
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
183
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
184
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
185
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
186
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
187
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
188
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
189
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
190
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
191
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
192
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
193
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
194
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
195
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
196
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
197
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
198
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
199
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
200
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
201
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
202
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
203
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
204
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
205
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
206
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
207
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
208
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
209
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
210
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
211
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
212
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
213
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
214
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
215
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
216
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
217
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
218
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
219
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
220
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
221
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
222
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
223
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
224
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
225
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
226
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
227
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
228
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
229
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
230
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
231
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
232
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
233
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00003.safetensors",
234
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
235
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
236
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
237
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
238
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
239
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
240
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
241
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
242
+ "model.norm.weight": "model-00003-of-00003.safetensors"
243
+ }
244
+ }
modeling_compressed_llama.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.utils import logging

from .configuration_compressed_llama import CompressedLlamaConfig
16
+ logger = logging.get_logger(__name__)
17
+
18
class CompressedLlamaPreTrainedModel(PreTrainedModel):
    """Base class wiring CompressedLlama models into the HF PreTrainedModel machinery."""

    config_class = CompressedLlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range); zero biases and padding rows."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
36
+
37
class CompressedLlamaModel(CompressedLlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`].
    MLP sub-modules may be tied across layers according to `config.share_layers`.

    Args:
        config: CompressedLlamaConfig
    """

    def __init__(self, config: CompressedLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Tie the MLP sub-modules across decoder layers based on the config.
        if isinstance(config.share_layers, str):
            if config.share_layers == "all":
                # Every decoder layer points at one shared MLP module;
                # "none" keeps the independent per-layer MLPs created above.
                shared_mlp = self.layers[0].mlp
                for layer in self.layers:
                    layer.mlp = shared_mlp
        elif isinstance(config.share_layers, list):
            # NOTE: use the module-level `logger`; `transformers.utils.logging`
            # has no module-level `critical` function.
            logger.critical("fine-grained layer sharing not yet supported!")
            raise NotImplementedError(f"fine-grained layer sharing not yet supported, config: {config.share_layers}")
        else:
            # Unreachable when the config was validated, but fail loudly
            # instead of printing and silently continuing.
            raise ValueError(f"Unexpected value for `share_layers`: {config.share_layers!r}")

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        Returns a `BaseModelOutputWithPast` (or an equivalent tuple when
        ``return_dict`` is falsy) with the final hidden states, the updated
        KV cache, and optionally all hidden states / attentions.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if past_key_values is not None:
            # Legacy tuple cache: shape (batch, heads, past_seq, head_dim).
            past_key_values_length = past_key_values[0][0].shape[2]

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers; an all-ones mask means
            # "no padding" and can be dropped entirely.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers (requires the
            # `_prepare_4d_causal_attention_mask` import from
            # transformers.modeling_attn_mask_utils).
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache position in layer_outputs depends on whether
                # attentions were also requested.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
199
+
200
+
201
class CompressedLlamaForCausalLM(CompressedLlamaPreTrainedModel):
    """Causal language model: a ``CompressedLlamaModel`` decoder backbone with a
    linear LM head projecting hidden states to vocabulary logits.

    Mirrors the stock ``transformers`` ``LlamaForCausalLM`` interface so it can be
    loaded through ``AutoModelForCausalLM`` (see ``auto_map`` in the config).
    """

    # When `tie_word_embeddings` is enabled, this weight is tied to the input
    # embedding table instead of being stored separately in the checkpoint.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build the decoder and the LM head (untied by default, bias-free)."""
        super().__init__(config)
        self.model = CompressedLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding table lives on the inner decoder model.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        # Swap out the decoder backbone (used e.g. by encoder-decoder wrappers).
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the decoder, project to logits, and optionally compute the shifted
        next-token cross-entropy loss.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        # Fall back to config-level defaults when flags are not given explicitly
        # (`use_cache` defaulting is delegated to the inner model).
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Tensor-parallel pretraining mode: split the LM head row-wise into
            # `pretraining_tp` shards and concatenate the per-shard logits along
            # the vocabulary dimension. Numerically equivalent to the single
            # matmul below.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        # Upcast logits (and therefore the loss) to float32 for numerical stability.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            # Legacy tuple output: (loss?, logits, *remaining model outputs).
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim/augment inputs for incremental decoding inside ``generate()``.

        With a cache present, only the not-yet-processed suffix of `input_ids`
        is forwarded; `position_ids` are derived from the attention mask so
        left-padded batches get correct positions.
        """
        if past_key_values is not None:
            # Legacy cache format: tuple of per-layer (key, value) tensors with
            # the cached sequence length at dim 2.
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # Keep only the positions matching the trimmed input_ids suffix.
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder cached key/value states along the batch dim to follow the
        beams selected at each beam-search step."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                # index_select on dim 0 (batch); move beam_idx to the cache's
                # device to support model parallelism.
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past