ziadrone committed · verified
Commit 1408be7 · 1 Parent(s): 4840367

Upload SHIVIK-M3 FP32 (2.43B params, 28 layers, 200K vocab)

config.json ADDED
@@ -0,0 +1,21 @@
+{
+  "architectures": [
+    "ShivikM3Model"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_shivikM3.ShivikM3Config",
+    "AutoModelForCausalLM": "modeling_shivikM3.ShivikM3Model"
+  },
+  "dtype": "float32",
+  "hidden_size": 2048,
+  "intermediate_size": 7168,
+  "kv_head_split_layer": 14,
+  "model_type": "shivik-m3",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 28,
+  "num_kv_heads": 8,
+  "num_kv_heads_high": 32,
+  "rms_norm_eps": 1e-05,
+  "transformers_version": "4.57.3",
+  "vocab_size": 200018
+}
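Note: the auto_map above wires AutoConfig and AutoModelForCausalLM to the custom classes shipped in this commit, so the checkpoint is meant to be loaded with trust_remote_code=True. A minimal loading sketch, not part of this commit; the Hub repository id below is a placeholder inferred from the committer and model name:

# Loading sketch; "ziadrone/SHIVIK-M3" is an assumed repository id, adjust as needed.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("ziadrone/SHIVIK-M3", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("ziadrone/SHIVIK-M3", trust_remote_code=True)
print(config.model_type, sum(p.numel() for p in model.parameters()))  # shivik-m3, ~2.43B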
configuration_shivikM3.py ADDED
@@ -0,0 +1,30 @@
+from transformers import PretrainedConfig
+
+class ShivikM3Config(PretrainedConfig):
+    model_type = "shivik-m3"
+
+    def __init__(
+        self,
+        vocab_size=200018,
+        hidden_size=2048,
+        num_hidden_layers=28,
+        num_attention_heads=32,
+        intermediate_size=7168,
+        kv_head_split_layer=14,
+        num_kv_heads=8,
+        num_kv_heads_high=32,
+        tie_word_embeddings=True,
+        rms_norm_eps=1e-5,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.kv_head_split_layer = kv_head_split_layer
+        self.num_kv_heads = num_kv_heads
+        self.num_kv_heads_high = num_kv_heads_high
+        self.tie_word_embeddings = tie_word_embeddings
+        self.rms_norm_eps = rms_norm_eps
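Note: kv_head_split_layer, num_kv_heads, and num_kv_heads_high drive a per-layer KV-head schedule in the attention module of modeling_shivikM3.py below. A small sketch, not part of this commit, that reproduces the schedule from the defaults (it assumes configuration_shivikM3.py is importable from the working directory):

# Reproduce the per-layer KV-head count used by ShivikM3Attention:
# layers below kv_head_split_layer use num_kv_heads, the rest use num_kv_heads_high.
from configuration_shivikM3 import ShivikM3Config

cfg = ShivikM3Config()
kv_heads = [
    cfg.num_kv_heads if i < cfg.kv_head_split_layer else cfg.num_kv_heads_high
    for i in range(cfg.num_hidden_layers)
]
print(kv_heads)  # layers 0-13 -> 8 KV heads (grouped-query attention), layers 14-27 -> 32 (full MHA)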
generation_config.json ADDED
@@ -0,0 +1 @@
+{"bos_token_id": 0, "eos_token_id": 0, "pad_token_id": 0}
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6675682aa98bfe37cc356d779b696bbd8ca769b7649148f0405bbdd6d4f3a968
+size 4985878640
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76b23ca3982f2e0bf05d6210ebb21127937ec6defe039435b65d7df6c25a37a1
+size 4750937512
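Note: both shards are stored as git-lfs pointer files, each recording the sha256 digest and byte size of the real payload. A sketch, not part of this commit, for checking a downloaded shard against its pointer (paths assume the shards sit in the current directory):

# Verify a downloaded shard against the oid/size recorded in its LFS pointer.
import hashlib, os

def verify_shard(path, expected_sha256, expected_size):
    assert os.path.getsize(path) == expected_size, "size mismatch"
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    assert digest.hexdigest() == expected_sha256, "sha256 mismatch"

verify_shard(
    "model-00001-of-00002.safetensors",
    "6675682aa98bfe37cc356d779b696bbd8ca769b7649148f0405bbdd6d4f3a968",
    4985878640,
)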
model.safetensors.index.json ADDED
@@ -0,0 +1,263 @@
+{
+  "metadata": {
+    "total_parameters": 2434197504,
+    "total_size": 9736790016
+  },
+  "weight_map": {
+    "embed_tokens.weight": "model-00001-of-00002.safetensors",
+    "layers.0.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.0.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.0.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.0.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.0.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.0.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.0.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.0.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.0.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.1.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.1.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.1.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.1.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.1.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.1.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.1.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.1.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.1.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.10.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.10.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.10.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.10.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.10.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.10.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.10.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.10.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.10.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.11.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.11.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.11.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.11.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.11.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.11.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.11.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.11.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.11.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.12.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.12.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.12.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.12.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.12.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.12.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.12.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.12.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.12.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.13.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.13.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.13.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.13.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.13.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.13.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.13.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.13.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.13.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.14.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.14.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.14.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.14.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.14.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.14.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.14.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.14.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.14.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.15.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.15.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.15.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.15.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.15.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.15.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.15.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.15.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.15.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.16.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.16.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.16.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.16.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.16.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.16.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.16.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.16.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.16.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.17.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.17.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.17.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.17.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.17.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.17.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.17.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.17.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.17.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.18.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.18.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.18.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.18.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.18.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.18.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.18.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.18.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.18.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.19.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.19.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.19.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.19.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.19.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.19.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.19.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.19.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.19.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.2.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.2.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.2.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.2.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.2.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.2.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.2.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.2.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.2.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.20.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.20.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.20.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.20.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.20.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.20.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.20.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.20.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.20.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.21.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.21.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.21.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.21.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.21.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.21.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.21.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.21.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.21.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.22.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.22.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.22.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.22.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.22.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.22.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.22.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.22.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.22.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.23.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.23.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.23.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.23.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.23.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.23.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.23.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.23.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.23.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.24.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.24.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.24.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.24.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.24.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.24.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.24.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.24.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.24.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.25.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.25.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.25.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.25.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.25.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.25.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.25.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.25.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.25.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.26.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.26.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.26.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.26.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.26.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.26.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.26.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.26.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.26.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.27.attn.k_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.27.attn.o_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.27.attn.q_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.27.attn.v_proj.weight": "model-00002-of-00002.safetensors",
+    "layers.27.attn_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.27.mlp.fc.weight": "model-00002-of-00002.safetensors",
+    "layers.27.mlp.gate.weight": "model-00002-of-00002.safetensors",
+    "layers.27.mlp.proj.weight": "model-00002-of-00002.safetensors",
+    "layers.27.mlp_norm.weight": "model-00002-of-00002.safetensors",
+    "layers.3.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.3.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.3.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.3.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.3.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.3.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.3.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.3.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.3.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.4.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.4.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.4.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.4.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.4.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.4.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.4.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.4.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.4.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.5.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.5.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.5.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.5.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.5.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.5.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.5.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.5.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.5.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.6.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.6.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.6.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.6.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.6.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.6.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.6.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.6.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.6.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.7.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.7.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.7.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.7.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.7.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.7.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.7.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.7.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.7.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.8.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.8.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.8.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.8.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.8.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.8.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.8.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.8.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.8.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.9.attn.k_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.9.attn.o_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.9.attn.q_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.9.attn.v_proj.weight": "model-00001-of-00002.safetensors",
+    "layers.9.attn_norm.weight": "model-00001-of-00002.safetensors",
+    "layers.9.mlp.fc.weight": "model-00001-of-00002.safetensors",
+    "layers.9.mlp.gate.weight": "model-00001-of-00002.safetensors",
+    "layers.9.mlp.proj.weight": "model-00001-of-00002.safetensors",
+    "layers.9.mlp_norm.weight": "model-00001-of-00002.safetensors",
+    "lm_head.weight": "model-00002-of-00002.safetensors",
+    "norm.weight": "model-00002-of-00002.safetensors"
+  }
+}
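Note: the index maps every tensor name to one of the two shards; layer 15 straddles the boundary (its o_proj and MLP weights sit on shard 2, the rest on shard 1). A sketch, not part of this commit, for reading a single tensor via the weight_map; it assumes the repository files are in the working directory and the safetensors package is installed:

# Look up which shard holds a tensor and read only that shard.
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "layers.15.attn.o_proj.weight"  # one of the layer-15 tensors stored on shard 2
shard_file = index["weight_map"][name]
with safe_open(shard_file, framework="pt") as f:
    tensor = f.get_tensor(name)
print(name, shard_file, tuple(tensor.shape), tensor.dtype)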
modeling_shivikM3.py ADDED
@@ -0,0 +1,135 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from configuration_shivikM3 import ShivikM3Config
+
+class RMSNorm(nn.Module):
+    def __init__(self, d, eps=1e-5):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(d))
+        self.eps = eps
+
+    def forward(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat([-x2, x1], dim=-1)
+
+def apply_rope(q, k, cos, sin):
+    q = (q * cos) + (rotate_half(q) * sin)
+    k = (k * cos) + (rotate_half(k) * sin)
+    return q, k
+
+class ShivikM3Attention(nn.Module):
+    def __init__(self, c, idx):
+        super().__init__()
+        self.num_q = c.num_attention_heads
+        self.head_dim = c.hidden_size // self.num_q
+        self.num_kv = c.num_kv_heads if idx < c.kv_head_split_layer else c.num_kv_heads_high
+        self.q_proj = nn.Linear(c.hidden_size, self.num_q * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(c.hidden_size, self.num_kv * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(c.hidden_size, self.num_kv * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(c.hidden_size, c.hidden_size, bias=False)
+
+    def forward(self, x, cos, sin, mask=None):
+        B, T, C = x.size()
+        q = self.q_proj(x).view(B, T, self.num_q, self.head_dim).transpose(1, 2)
+        k = self.k_proj(x).view(B, T, self.num_kv, self.head_dim).transpose(1, 2)
+        v = self.v_proj(x).view(B, T, self.num_kv, self.head_dim).transpose(1, 2)
+
+        if self.num_kv != self.num_q:
+            k = k.repeat_interleave(self.num_q // self.num_kv, dim=1)
+            v = v.repeat_interleave(self.num_q // self.num_kv, dim=1)
+
+        q, k = apply_rope(q, k, cos, sin)
+
+        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        if mask is not None:
+            attn = attn + mask
+        attn = F.softmax(attn, dim=-1)
+        out = attn @ v
+
+        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
+
+class ShivikM3MLP(nn.Module):
+    def __init__(self, c):
+        super().__init__()
+        self.gate = nn.Linear(c.hidden_size, c.intermediate_size, bias=False)
+        self.fc = nn.Linear(c.hidden_size, c.intermediate_size, bias=False)
+        self.proj = nn.Linear(c.intermediate_size, c.hidden_size, bias=False)
+
+    def forward(self, x):
+        return self.proj(F.silu(self.gate(x)) * self.fc(x))
+
+class ShivikM3Block(nn.Module):
+    def __init__(self, c, idx):
+        super().__init__()
+        self.attn_norm = RMSNorm(c.hidden_size, c.rms_norm_eps)
+        self.mlp_norm = RMSNorm(c.hidden_size, c.rms_norm_eps)
+        self.attn = ShivikM3Attention(c, idx)
+        self.mlp = ShivikM3MLP(c)
+
+    def forward(self, x, cos, sin, mask=None):
+        h = x + self.attn(self.attn_norm(x), cos, sin, mask)
+        h = h + self.mlp(self.mlp_norm(h))
+        return h
+
+class ShivikM3Model(PreTrainedModel):
+    config_class = ShivikM3Config
+    supports_gradient_checkpointing = True
+
+    def __init__(self, c):
+        super().__init__(c)
+        self.config = c
+        self.gradient_checkpointing = False
+        self.embed_tokens = nn.Embedding(c.vocab_size, c.hidden_size)
+        self.layers = nn.ModuleList([ShivikM3Block(c, i) for i in range(c.num_hidden_layers)])
+        self.norm = RMSNorm(c.hidden_size, c.rms_norm_eps)
+        self.lm_head = nn.Linear(c.hidden_size, c.vocab_size, bias=False)
+        self.post_init()
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        B, T = input_ids.size()
+        x = self.embed_tokens(input_ids)
+
+        # Create RoPE - match model dtype
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        device = x.device
+        dtype = x.dtype
+
+        pos = torch.arange(T, device=device, dtype=dtype)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device, dtype=dtype) / head_dim))
+        freqs = torch.outer(pos, inv_freq)
+        freqs = torch.cat((freqs, freqs), dim=-1)
+        cos = freqs.cos().view(1, 1, T, head_dim)
+        sin = freqs.sin().view(1, 1, T, head_dim)
+
+        # Create causal mask - match model dtype
+        mask = torch.triu(
+            torch.full((T, T), float('-inf'), device=device, dtype=dtype),
+            diagonal=1
+        ).view(1, 1, T, T)
+
+        if attention_mask is not None:
+            mask = mask + (1.0 - attention_mask[:, None, None, :]) * torch.finfo(dtype).min
+
+        for block in self.layers:
+            if self.gradient_checkpointing and self.training:
+                x = torch.utils.checkpoint.checkpoint(block, x, cos, sin, mask, use_reentrant=False)
+            else:
+                x = block(x, cos, sin, mask)
+
+        x = self.norm(x)
+        logits = self.lm_head(x)
+
+        loss = None
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+        return CausalLMOutputWithPast(loss=loss, logits=logits)
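Note: a smoke-test sketch for the modeling code above, not part of this commit. It instantiates a shrunken, randomly initialised config (so it runs quickly without the 9.7 GB checkpoint) and performs one forward pass with labels to exercise the loss path; it assumes both .py files from this commit are importable from the working directory.

# Build a small randomly initialised ShivikM3Model and run one forward pass.
import torch
from configuration_shivikM3 import ShivikM3Config
from modeling_shivikM3 import ShivikM3Model

cfg = ShivikM3Config(num_hidden_layers=2, vocab_size=1000)  # shrunk for a quick test
model = ShivikM3Model(cfg).eval()

input_ids = torch.randint(0, cfg.vocab_size, (1, 16))
with torch.no_grad():
    out = model(input_ids, labels=input_ids)
print(out.logits.shape, float(out.loss))  # torch.Size([1, 16, 1000]) and a finite loss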
special_tokens_map.json ADDED
@@ -0,0 +1,13 @@
+{
+  "unk_token": "<unk>",
+  "pad_token": "<pad>",
+  "bos_token": "<think>",
+  "eos_token": "</think>",
+  "additional_special_tokens": [
+    "<answer>",
+    "</answer>",
+    "<context>",
+    "</context>",
+    "<end>"
+  ]
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
+{
+  "model_type": "wordlevel",
+  "unk_token": "<unk>",
+  "pad_token": "<pad>",
+  "bos_token": "<think>",
+  "eos_token": "</think>",
+  "special_tokens": [
+    "<unk>",
+    "<pad>",
+    "<think>",
+    "</think>",
+    "<answer>",
+    "</answer>",
+    "<context>",
+    "</context>",
+    "<end>"
+  ]
+}
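Note: tokenizer_config.json declares a word-level tokenizer but no tokenizer_class, so AutoTokenizer may not resolve the custom "shivik-m3" model type on its own. A sketch, not part of this commit and only an assumption about the intended setup, that loads the fast tokenizer directly from tokenizer.json with the special tokens declared above:

# Load the tokenizer explicitly from tokenizer.json with the declared special tokens.
from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",
    unk_token="<unk>",
    pad_token="<pad>",
    bos_token="<think>",
    eos_token="</think>",
    additional_special_tokens=["<answer>", "</answer>", "<context>", "</context>", "<end>"],
)
print(tok.encode("<think> hello world </think>"))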
vocab.json ADDED
The diff for this file is too large to render. See raw diff