dogtooth commited on
Commit
1659bfa
·
verified ·
1 Parent(s): cf2c76f

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -30,23 +30,25 @@ trained with a knowledge cutoff of **January 2017**, from the
30
  ## Usage
31
 
32
  ```python
33
- from transformers import AutoModelForCausalLM
34
  import torch
35
 
36
  model = AutoModelForCausalLM.from_pretrained(
37
  "dogtooth/open-lm-1b-201701",
38
  dtype=torch.bfloat16,
39
  device_map="auto",
 
40
  )
 
41
  ```
42
 
43
  ## Conversion Notes
44
 
45
- - Converted from the original Open LM `.pt` checkpoint to HuggingFace `LlamaForCausalLM` format.
 
 
46
  - Architecture dimensions are auto-detected from checkpoint weights.
47
- - The original model uses **QK norm** (RMSNorm on Q and K projections), which is not natively
48
- supported by HF Llama. QK norm weights are dropped during conversion. For exact numerical
49
- equivalence, use the [open_lm](https://github.com/mlfoundations/open_lm) library.
50
 
51
  ## Citation
52
 
 
30
  ## Usage
31
 
32
  ```python
33
+ from transformers import AutoModelForCausalLM, AutoTokenizer
34
  import torch
35
 
36
  model = AutoModelForCausalLM.from_pretrained(
37
  "dogtooth/open-lm-1b-201701",
38
  dtype=torch.bfloat16,
39
  device_map="auto",
40
+ trust_remote_code=True,
41
  )
42
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
43
  ```
44
 
45
  ## Conversion Notes
46
 
47
+ - Converted from the original Open LM `.pt` checkpoint to a custom `OpenLMForCausalLM` format.
48
+ - Uses **LayerNorm** (not RMSNorm) to match the original Open LM training.
49
+ - Includes **QK norm** (LayerNorm on Q and K projections before attention).
50
  - Architecture dimensions are auto-detected from checkpoint weights.
51
+ - Requires `trust_remote_code=True` when loading.
 
 
52
 
53
  ## Citation
54
 
config.json CHANGED
@@ -1,9 +1,13 @@
1
  {
2
  "architectures": [
3
- "LlamaForCausalLM"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
 
 
 
 
7
  "bos_token_id": 1,
8
  "dtype": "float32",
9
  "eos_token_id": 2,
@@ -14,11 +18,12 @@
14
  "intermediate_size": 5632,
15
  "max_position_embeddings": 2048,
16
  "mlp_bias": false,
17
- "model_type": "llama",
18
  "num_attention_heads": 16,
19
  "num_hidden_layers": 24,
20
  "num_key_value_heads": 16,
21
  "pretraining_tp": 1,
 
22
  "rms_norm_eps": 1e-05,
23
  "rope_scaling": null,
24
  "rope_theta": 10000.0,
 
1
  {
2
  "architectures": [
3
+ "OpenLMForCausalLM"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "modeling_open_lm_hf.OpenLMConfig",
9
+ "AutoModelForCausalLM": "modeling_open_lm_hf.OpenLMForCausalLM"
10
+ },
11
  "bos_token_id": 1,
12
  "dtype": "float32",
13
  "eos_token_id": 2,
 
18
  "intermediate_size": 5632,
19
  "max_position_embeddings": 2048,
20
  "mlp_bias": false,
21
+ "model_type": "open_lm",
22
  "num_attention_heads": 16,
23
  "num_hidden_layers": 24,
24
  "num_key_value_heads": 16,
25
  "pretraining_tp": 1,
26
+ "qk_norm": true,
27
  "rms_norm_eps": 1e-05,
28
  "rope_scaling": null,
29
  "rope_theta": 10000.0,
model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:382e04d8fb9e3dd2d511c2bc0d56216bfe53763e46d9ab13de9d2aaaf7014aa6
3
- size 4985313960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:013ff505ba55eaa060105291399de41b923fd62cc5ee28c7f27c1c72c21762bd
3
+ size 4985679304
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d01d139603008cf5eb4f1df2c1276eecb4925885b8c956888801e33c06025072
3
- size 773891960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d78d215b3f01a75a49518123fc5aa2322bd94e4c56549234ec584bcaaaa3e1eb
3
+ size 773925168
model.safetensors.index.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "metadata": {
3
- "total_parameters": 1439795200,
4
- "total_size": 5759180800
5
  },
6
  "weight_map": {
7
  "lm_head.weight": "model-00002-of-00002.safetensors",
@@ -11,8 +11,10 @@
11
  "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
12
  "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
13
  "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
14
  "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
15
  "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
16
  "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
17
  "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
18
  "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -20,8 +22,10 @@
20
  "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
21
  "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
22
  "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
23
  "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
24
  "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
25
  "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
26
  "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
27
  "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -29,8 +33,10 @@
29
  "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
30
  "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
31
  "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
32
  "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
33
  "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
34
  "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
35
  "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
36
  "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -38,8 +44,10 @@
38
  "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
39
  "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
40
  "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
41
  "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
42
  "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
43
  "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
44
  "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
45
  "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -47,8 +55,10 @@
47
  "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
48
  "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
49
  "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
50
  "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
51
  "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
52
  "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
53
  "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
54
  "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -56,8 +66,10 @@
56
  "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
57
  "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
58
  "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
59
  "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
60
  "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
61
  "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
62
  "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
63
  "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -65,8 +77,10 @@
65
  "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
66
  "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
67
  "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
68
  "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
69
  "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
70
  "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
71
  "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
72
  "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -74,8 +88,10 @@
74
  "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
75
  "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
76
  "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
77
  "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
78
  "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
79
  "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
80
  "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
81
  "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -83,8 +99,10 @@
83
  "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
84
  "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
85
  "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
86
  "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
87
  "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
88
  "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
89
  "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
90
  "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -92,8 +110,10 @@
92
  "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
93
  "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
94
  "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
95
  "model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
96
  "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
97
  "model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
98
  "model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
99
  "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -101,8 +121,10 @@
101
  "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
102
  "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
103
  "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
104
  "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
105
  "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
106
  "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
107
  "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
108
  "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -110,8 +132,10 @@
110
  "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
111
  "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
112
  "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
113
  "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
114
  "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
115
  "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
116
  "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
117
  "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -119,8 +143,10 @@
119
  "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
120
  "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
121
  "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
122
  "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
123
  "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
124
  "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
125
  "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
126
  "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -128,8 +154,10 @@
128
  "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
129
  "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
130
  "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
131
  "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
132
  "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
133
  "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
134
  "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
135
  "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -137,8 +165,10 @@
137
  "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
138
  "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
139
  "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
140
  "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
141
  "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
142
  "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
143
  "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
144
  "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
@@ -146,8 +176,10 @@
146
  "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
147
  "model.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
148
  "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
 
149
  "model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
150
  "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
 
151
  "model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
152
  "model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
153
  "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
@@ -155,8 +187,10 @@
155
  "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
156
  "model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
157
  "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
 
158
  "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
159
  "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
 
160
  "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
161
  "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
162
  "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -164,8 +198,10 @@
164
  "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
165
  "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
166
  "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
167
  "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
168
  "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
169
  "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
170
  "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
171
  "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -173,8 +209,10 @@
173
  "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
174
  "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
175
  "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
176
  "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
177
  "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
178
  "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
179
  "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
180
  "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -182,8 +220,10 @@
182
  "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
183
  "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
184
  "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
185
  "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
186
  "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
187
  "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
188
  "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
189
  "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -191,8 +231,10 @@
191
  "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
192
  "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
193
  "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
194
  "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
195
  "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
196
  "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
197
  "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
198
  "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -200,8 +242,10 @@
200
  "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
201
  "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
202
  "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
203
  "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
204
  "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
205
  "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
206
  "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
207
  "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -209,8 +253,10 @@
209
  "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
210
  "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
211
  "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
212
  "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
213
  "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
214
  "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
215
  "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
216
  "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -218,8 +264,10 @@
218
  "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
219
  "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
220
  "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
 
221
  "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
222
  "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
 
223
  "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
224
  "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
225
  "model.norm.weight": "model-00002-of-00002.safetensors"
 
1
  {
2
  "metadata": {
3
+ "total_parameters": 1439893504,
4
+ "total_size": 5759574016
5
  },
6
  "weight_map": {
7
  "lm_head.weight": "model-00002-of-00002.safetensors",
 
11
  "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
12
  "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
13
  "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
15
  "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
16
  "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
18
  "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
19
  "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
20
  "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
22
  "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
23
  "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
24
  "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
26
  "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
27
  "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
28
+ "model.layers.1.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
29
  "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
30
  "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
31
  "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
33
  "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
34
  "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
35
  "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.10.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
37
  "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
38
  "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.10.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
40
  "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
41
  "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
42
  "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
44
  "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
45
  "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
46
  "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.11.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
48
  "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
49
  "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.11.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
51
  "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
52
  "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
53
  "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
55
  "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
56
  "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
57
  "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.12.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
59
  "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
60
  "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.12.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
62
  "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
63
  "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
64
  "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
66
  "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
67
  "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
68
  "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.13.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
70
  "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
71
  "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.13.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
73
  "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
74
  "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
75
  "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
77
  "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
78
  "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
79
  "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.14.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
81
  "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
82
  "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.14.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
84
  "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
85
  "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
86
  "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
88
  "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
89
  "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
90
  "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.15.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
92
  "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
93
  "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.15.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
95
  "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
96
  "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
97
  "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
99
  "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
100
  "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
101
  "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
102
+ "model.layers.16.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
103
  "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
104
  "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.16.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
106
  "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
107
  "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
108
  "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
110
  "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
111
  "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
112
  "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.17.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
114
  "model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
115
  "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.17.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
117
  "model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
118
  "model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
119
  "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
121
  "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
122
  "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
123
  "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.18.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
125
  "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
126
  "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.18.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
128
  "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
129
  "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
130
  "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
132
  "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
133
  "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
134
  "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
135
+ "model.layers.19.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
136
  "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
137
  "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.19.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
139
  "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
140
  "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
141
  "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
143
  "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
144
  "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
145
  "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
146
+ "model.layers.2.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
147
  "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
148
  "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.2.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
150
  "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
151
  "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
152
  "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
154
  "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
155
  "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
156
  "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
157
+ "model.layers.20.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
158
  "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
159
  "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
160
+ "model.layers.20.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
161
  "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
162
  "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
163
  "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
165
  "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
166
  "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
167
  "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
168
+ "model.layers.21.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
169
  "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
170
  "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
171
+ "model.layers.21.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
172
  "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
173
  "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
174
  "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
 
176
  "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
177
  "model.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
178
  "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
179
+ "model.layers.22.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
180
  "model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
181
  "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
182
+ "model.layers.22.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
183
  "model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
184
  "model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
185
  "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
 
187
  "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
188
  "model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
189
  "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
190
+ "model.layers.23.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
191
  "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
192
  "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
193
+ "model.layers.23.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
194
  "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
195
  "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
196
  "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
198
  "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
199
  "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
200
  "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
201
+ "model.layers.3.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
202
  "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
203
  "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
204
+ "model.layers.3.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
205
  "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
206
  "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
207
  "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
209
  "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
210
  "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
211
  "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
212
+ "model.layers.4.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
213
  "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
214
  "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
215
+ "model.layers.4.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
216
  "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
217
  "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
218
  "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
220
  "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
221
  "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
222
  "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
223
+ "model.layers.5.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
224
  "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
225
  "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
226
+ "model.layers.5.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
227
  "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
228
  "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
229
  "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
231
  "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
232
  "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
233
  "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
234
+ "model.layers.6.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
235
  "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
236
  "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
237
+ "model.layers.6.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
238
  "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
239
  "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
240
  "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
242
  "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
243
  "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
244
  "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
245
+ "model.layers.7.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
246
  "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
247
  "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
248
+ "model.layers.7.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
249
  "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
250
  "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
251
  "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
253
  "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
254
  "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
255
  "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
256
+ "model.layers.8.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
257
  "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
258
  "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
259
+ "model.layers.8.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
260
  "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
261
  "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
262
  "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
 
264
  "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
265
  "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
266
  "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
267
+ "model.layers.9.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
268
  "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
269
  "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
270
+ "model.layers.9.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
271
  "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
272
  "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
273
  "model.norm.weight": "model-00002-of-00002.safetensors"
modeling_open_lm_hf.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom HuggingFace model for Open LM checkpoints.
3
+
4
+ Open LM uses LayerNorm (not RMSNorm) and QK norm, which standard
5
+ LlamaForCausalLM does not support. This module provides:
6
+ - OpenLMConfig: LlamaConfig subclass with qk_norm flag
7
+ - OpenLMForCausalLM: LlamaForCausalLM subclass with LayerNorm + QK norm
8
+
9
+ Usage:
10
+ model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
11
+ """
12
+
13
+ from typing import Callable, Optional
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from transformers import LlamaConfig, LlamaForCausalLM
18
+ from transformers.models.llama.modeling_llama import (
19
+ ALL_ATTENTION_FUNCTIONS,
20
+ LlamaAttention,
21
+ LlamaRMSNorm,
22
+ apply_rotary_pos_emb,
23
+ eager_attention_forward,
24
+ )
25
+
26
+ try:
27
+ from typing import Unpack
28
+ from transformers.utils.generic import TransformersKwargs
29
+ except ImportError:
30
+ pass
31
+
32
+ from transformers.cache_utils import Cache
33
+
34
+
35
+ class OpenLMConfig(LlamaConfig):
36
+ model_type = "open_lm"
37
+
38
+ def __init__(self, qk_norm: bool = True, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.qk_norm = qk_norm
41
+
42
+
43
+ class OpenLMAttention(LlamaAttention):
44
+ """LlamaAttention with QK norm applied before reshape (matching Open LM)."""
45
+
46
+ def __init__(self, config: OpenLMConfig, layer_idx: int):
47
+ super().__init__(config, layer_idx)
48
+ if getattr(config, "qk_norm", False):
49
+ self.q_norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=False)
50
+ self.k_norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=False)
51
+ else:
52
+ self.q_norm = nn.Identity()
53
+ self.k_norm = nn.Identity()
54
+
55
+ def forward(
56
+ self,
57
+ hidden_states: torch.Tensor,
58
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
59
+ attention_mask: Optional[torch.Tensor],
60
+ past_key_values: Optional[Cache] = None,
61
+ cache_position: Optional[torch.LongTensor] = None,
62
+ **kwargs,
63
+ ) -> tuple[torch.Tensor, torch.Tensor]:
64
+ input_shape = hidden_states.shape[:-1]
65
+ hidden_shape = (*input_shape, -1, self.head_dim)
66
+
67
+ # QK norm applied to flat projected vectors BEFORE reshape (matches Open LM)
68
+ query_states = self.q_norm(self.q_proj(hidden_states)).view(hidden_shape).transpose(1, 2)
69
+ key_states = self.k_norm(self.k_proj(hidden_states)).view(hidden_shape).transpose(1, 2)
70
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
71
+
72
+ cos, sin = position_embeddings
73
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
74
+
75
+ if past_key_values is not None:
76
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
77
+ key_states, value_states = past_key_values.update(
78
+ key_states, value_states, self.layer_idx, cache_kwargs
79
+ )
80
+
81
+ attention_interface: Callable = eager_attention_forward
82
+ if self.config._attn_implementation != "eager":
83
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
84
+
85
+ attn_output, attn_weights = attention_interface(
86
+ self,
87
+ query_states,
88
+ key_states,
89
+ value_states,
90
+ attention_mask,
91
+ dropout=0.0 if not self.training else self.attention_dropout,
92
+ scaling=self.scaling,
93
+ **kwargs,
94
+ )
95
+
96
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
97
+ attn_output = self.o_proj(attn_output)
98
+ return attn_output, attn_weights
99
+
100
+
101
+ class OpenLMForCausalLM(LlamaForCausalLM):
102
+ """LlamaForCausalLM with LayerNorm (instead of RMSNorm) and QK norm support."""
103
+
104
+ config_class = OpenLMConfig
105
+
106
+ def __init__(self, config: OpenLMConfig):
107
+ super().__init__(config)
108
+
109
+ # Replace all LlamaRMSNorm with nn.LayerNorm(bias=False)
110
+ eps = config.rms_norm_eps
111
+ hidden_size = config.hidden_size
112
+
113
+ self.model.norm = nn.LayerNorm(hidden_size, eps=eps, bias=False)
114
+ for layer in self.model.layers:
115
+ layer.input_layernorm = nn.LayerNorm(hidden_size, eps=eps, bias=False)
116
+ layer.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=eps, bias=False)
117
+
118
+ # Replace attention module with QK norm version
119
+ layer.self_attn = OpenLMAttention(config, layer.self_attn.layer_idx)
120
+
121
+ # Re-run post_init to tie weights etc.
122
+ self.post_init()