Commit 6bd3c24 by michaelbzhu (verified) · 1 Parent(s): ff7fd3c

Upload model

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "architectures": [
+     "MBZTestModelForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration.MBZTestConfig",
+     "AutoModelForCausalLM": "modeling.MBZTestModelForCausalLM"
+   },
+   "d_head": 128,
+   "d_model": 4096,
+   "dtype": "float32",
+   "model_type": "mbz-test",
+   "n_heads": 32,
+   "n_layers": 36,
+   "n_vocab": 50257,
+   "transformers_version": "4.56.0"
+ }
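The hyperparameters in config.json can be cross-checked against the `total_parameters` and `total_size` values reported in model.safetensors.index.json later in this commit. The sketch below is a back-of-the-envelope count, not the author's code. The weight map shows bias-free attention projections, MLP layers with biases, and weight-only norms. Everything else is an assumption made for illustration: an MLP hidden width of 4 × `d_model`, untied embedding and unembedding matrices, a bias on the unembedding, and one final norm. The function name `count_parameters` is hypothetical.

```python
# Values copied from config.json in this commit.
config = {"d_head": 128, "d_model": 4096, "n_heads": 32, "n_layers": 36, "n_vocab": 50257}


def count_parameters(cfg, mlp_factor=4):
    """Estimate the parameter count under the assumptions stated above.

    ASSUMPTIONS (not confirmed by the commit): MLP width = mlp_factor * d_model,
    untied embedding/unembedding, unembedding bias, one final norm weight.
    """
    d_model = cfg["d_model"]
    d_attn = cfg["n_heads"] * cfg["d_head"]  # 32 * 128 = 4096, equal to d_model
    d_mlp = mlp_factor * d_model

    attn = 4 * d_model * d_attn                   # Wq, Wk, Wv, Wo (weights only)
    mlp = (d_model * d_mlp + d_mlp) + (d_mlp * d_model + d_model)  # mlp.0 and mlp.2 with biases
    norms = 2 * d_model                           # norm1.weight and norm2.weight
    per_block = attn + mlp + norms

    embed = cfg["n_vocab"] * d_model              # assumed token embedding
    unembed = d_model * cfg["n_vocab"] + cfg["n_vocab"]  # assumed output projection with bias
    final_norm = d_model                          # assumed final norm weight
    return cfg["n_layers"] * per_block + embed + unembed + final_norm


total = count_parameters(config)
print(total)      # 7660549201, the same as metadata.total_parameters
print(total * 4)  # 30642196804 bytes at 4 bytes per float32, the same as metadata.total_size
```

Under these assumptions the count matches the index metadata exactly, which suggests (but does not prove) that the unseen parts of the model have this shape.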
configuration.py ADDED
@@ -0,0 +1,21 @@
+ from transformers import PretrainedConfig
+
+ class MBZTestConfig(PretrainedConfig):
+     model_type = 'mbz-test'
+
+     def __init__(
+         self,
+         n_layers=36,
+         d_model=4096,
+         n_heads=32,
+         n_vocab=50257,
+         d_head=128,
+         **kwargs
+     ):
+         self.n_layers = n_layers
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_vocab = n_vocab
+         self.d_head = d_head
+
+         super().__init__(**kwargs)
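The `auto_map` block in config.json is what ties the Auto classes to this file. As a rough illustration, and assuming the usual Transformers custom-code convention that each entry has the form `<module>.<ClassName>` relative to the repository root, the entries can be split into a file name and a class name as follows. The helper `split_auto_map_entry` is hypothetical and is not part of Transformers.

```python
# auto_map copied from config.json in this commit.
AUTO_MAP = {
    "AutoConfig": "configuration.MBZTestConfig",
    "AutoModelForCausalLM": "modeling.MBZTestModelForCausalLM",
}


def split_auto_map_entry(entry):
    """Return (python_file, class_name) for a '<module>.<ClassName>' entry.

    Hypothetical helper for illustration only; it assumes the module lives
    at the repository root as '<module>.py'.
    """
    module_name, _, class_name = entry.rpartition(".")
    return f"{module_name}.py", class_name


for auto_class, entry in AUTO_MAP.items():
    filename, class_name = split_auto_map_entry(entry)
    print(f"{auto_class}: class {class_name} in {filename}")
```

Because the classes live in the repository and not in the library, loading such a model through the Auto classes generally requires opting in with `trust_remote_code=True`.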
model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b083a85825b7c07cb1e158bc0e48c47843dfb05b743bdc21ada4d264792aab1f
+ size 4984738880
model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d2a3509e808b4ae2e75fa8cc38b3d01cbe6cce37aae283c9190a8be905e50b4f
+ size 4966750152
model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bc96857ce845ba7f3586a1243dcc26bafdc349e3800f2282d633490e5955f7e
+ size 4832532264
model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a274ca806522e62ae56b99f44dd9e1e60360e2e8a02e46ca11123e7ecbd4a713
+ size 4832532264
model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba9ee4bfa6bc20ba52c6471cf3713eee9eea912565b0329a23aa5cd6b52cd94e
+ size 4832532264
model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6e3084031f7745e1932e4afe4dfd384a44b9c56868e05c1b4d5fbf085c76473
+ size 4832532264
model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a6ab5dea9b901e40d43399e2ea89cabca3fd40eb4ad29ec9dc67074ccf3c7b96
+ size 1360614556
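Each shard above is stored as a Git LFS pointer: three `key value` lines giving the spec version, the SHA-256 object id, and the size in bytes. The sketch below parses one pointer and sums the seven shard sizes for comparison with `total_size` from model.safetensors.index.json. The helper name `parse_lfs_pointer` is hypothetical. Reading the leftover bytes as per-shard safetensors header overhead is an assumption, not something the commit states.

```python
# Pointer text for the last shard, copied from the diff above.
POINTER_00007 = """\
version https://git-lfs.github.com/spec/v1
oid sha256:a6ab5dea9b901e40d43399e2ea89cabca3fd40eb4ad29ec9dc67074ccf3c7b96
size 1360614556
"""


def parse_lfs_pointer(text):
    """Split each 'key value' line of a Git LFS pointer into a dict (hypothetical helper)."""
    fields = dict(line.split(" ", 1) for line in text.splitlines() if line)
    return {"version": fields["version"], "oid": fields["oid"], "size": int(fields["size"])}


# Sizes of shards 1-6 copied from their pointers; shard 7 is parsed from the text above.
SHARD_SIZES = [
    4984738880,
    4966750152,
    4832532264,
    4832532264,
    4832532264,
    4832532264,
    parse_lfs_pointer(POINTER_00007)["size"],
]
INDEX_TOTAL_SIZE = 30642196804  # metadata.total_size in model.safetensors.index.json

on_disk = sum(SHARD_SIZES)
overhead = on_disk - INDEX_TOTAL_SIZE
print(on_disk)   # 30642232644 bytes across the seven shard files
print(overhead)  # 35840 bytes beyond the tensor data (assumed to be file headers)
```

The files on disk are slightly larger than `total_size`, which is consistent with `total_size` counting only tensor bytes.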
model.safetensors.index.json ADDED
@@ -0,0 +1,372 @@
+ {
+   "metadata": {
+     "total_parameters": 7660549201,
+     "total_size": 30642196804
+   },
+   "weight_map": {
+     "blocks.0.attn.Wk.weight": "model-00001-of-00007.safetensors",
+     "blocks.0.attn.Wo.weight": "model-00001-of-00007.safetensors",
+     "blocks.0.attn.Wq.weight": "model-00001-of-00007.safetensors",
+     "blocks.0.attn.Wv.weight": "model-00001-of-00007.safetensors",
+     "blocks.0.mlp.0.bias": "model-00001-of-00007.safetensors",
+     "blocks.0.mlp.0.weight": "model-00001-of-00007.safetensors",
+     "blocks.0.mlp.2.bias": "model-00001-of-00007.safetensors",
+     "blocks.0.mlp.2.weight": "model-00001-of-00007.safetensors",
+     "blocks.0.norm1.weight": "model-00001-of-00007.safetensors",
+     "blocks.0.norm2.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.attn.Wk.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.attn.Wo.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.attn.Wq.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.attn.Wv.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.mlp.0.bias": "model-00001-of-00007.safetensors",
+     "blocks.1.mlp.0.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.mlp.2.bias": "model-00001-of-00007.safetensors",
+     "blocks.1.mlp.2.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.norm1.weight": "model-00001-of-00007.safetensors",
+     "blocks.1.norm2.weight": "model-00001-of-00007.safetensors",
+     "blocks.10.attn.Wk.weight": "model-00002-of-00007.safetensors",
+     "blocks.10.attn.Wo.weight": "model-00002-of-00007.safetensors",
+     "blocks.10.attn.Wq.weight": "model-00002-of-00007.safetensors",
+     "blocks.10.attn.Wv.weight": "model-00002-of-00007.safetensors",
+     "blocks.10.mlp.0.bias": "model-00002-of-00007.safetensors",
+     "blocks.10.mlp.0.weight": "model-00002-of-00007.safetensors",
+     "blocks.10.mlp.2.bias": "model-00002-of-00007.safetensors",
+     "blocks.10.mlp.2.weight": "model-00002-of-00007.safetensors",
+     "blocks.10.norm1.weight": "model-00002-of-00007.safetensors",
+     "blocks.10.norm2.weight": "model-00002-of-00007.safetensors",
+     "blocks.11.attn.Wk.weight": "model-00002-of-00007.safetensors",
+     "blocks.11.attn.Wo.weight": "model-00002-of-00007.safetensors",
+     "blocks.11.attn.Wq.weight": "model-00002-of-00007.safetensors",
+     "blocks.11.attn.Wv.weight": "model-00002-of-00007.safetensors",
+     "blocks.11.mlp.0.bias": "model-00003-of-00007.safetensors",
+     "blocks.11.mlp.0.weight": "model-00003-of-00007.safetensors",
+     "blocks.11.mlp.2.bias": "model-00003-of-00007.safetensors",
+     "blocks.11.mlp.2.weight": "model-00003-of-00007.safetensors",
+     "blocks.11.norm1.weight": "model-00003-of-00007.safetensors",
+     "blocks.11.norm2.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.attn.Wk.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.attn.Wo.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.attn.Wq.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.attn.Wv.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.mlp.0.bias": "model-00003-of-00007.safetensors",
+     "blocks.12.mlp.0.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.mlp.2.bias": "model-00003-of-00007.safetensors",
+     "blocks.12.mlp.2.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.norm1.weight": "model-00003-of-00007.safetensors",
+     "blocks.12.norm2.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.attn.Wk.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.attn.Wo.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.attn.Wq.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.attn.Wv.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.mlp.0.bias": "model-00003-of-00007.safetensors",
+     "blocks.13.mlp.0.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.mlp.2.bias": "model-00003-of-00007.safetensors",
+     "blocks.13.mlp.2.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.norm1.weight": "model-00003-of-00007.safetensors",
+     "blocks.13.norm2.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.attn.Wk.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.attn.Wo.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.attn.Wq.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.attn.Wv.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.mlp.0.bias": "model-00003-of-00007.safetensors",
+     "blocks.14.mlp.0.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.mlp.2.bias": "model-00003-of-00007.safetensors",
+     "blocks.14.mlp.2.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.norm1.weight": "model-00003-of-00007.safetensors",
+     "blocks.14.norm2.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.attn.Wk.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.attn.Wo.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.attn.Wq.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.attn.Wv.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.mlp.0.bias": "model-00003-of-00007.safetensors",
+     "blocks.15.mlp.0.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.mlp.2.bias": "model-00003-of-00007.safetensors",
+     "blocks.15.mlp.2.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.norm1.weight": "model-00003-of-00007.safetensors",
+     "blocks.15.norm2.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.attn.Wk.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.attn.Wo.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.attn.Wq.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.attn.Wv.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.mlp.0.bias": "model-00003-of-00007.safetensors",
+     "blocks.16.mlp.0.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.mlp.2.bias": "model-00003-of-00007.safetensors",
+     "blocks.16.mlp.2.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.norm1.weight": "model-00003-of-00007.safetensors",
+     "blocks.16.norm2.weight": "model-00003-of-00007.safetensors",
+     "blocks.17.attn.Wk.weight": "model-00003-of-00007.safetensors",
+     "blocks.17.attn.Wo.weight": "model-00003-of-00007.safetensors",
+     "blocks.17.attn.Wq.weight": "model-00003-of-00007.safetensors",
+     "blocks.17.attn.Wv.weight": "model-00003-of-00007.safetensors",
+     "blocks.17.mlp.0.bias": "model-00004-of-00007.safetensors",
+     "blocks.17.mlp.0.weight": "model-00004-of-00007.safetensors",
+     "blocks.17.mlp.2.bias": "model-00004-of-00007.safetensors",
+     "blocks.17.mlp.2.weight": "model-00004-of-00007.safetensors",
+     "blocks.17.norm1.weight": "model-00004-of-00007.safetensors",
+     "blocks.17.norm2.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.attn.Wk.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.attn.Wo.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.attn.Wq.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.attn.Wv.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.mlp.0.bias": "model-00004-of-00007.safetensors",
+     "blocks.18.mlp.0.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.mlp.2.bias": "model-00004-of-00007.safetensors",
+     "blocks.18.mlp.2.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.norm1.weight": "model-00004-of-00007.safetensors",
+     "blocks.18.norm2.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.attn.Wk.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.attn.Wo.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.attn.Wq.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.attn.Wv.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.mlp.0.bias": "model-00004-of-00007.safetensors",
+     "blocks.19.mlp.0.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.mlp.2.bias": "model-00004-of-00007.safetensors",
+     "blocks.19.mlp.2.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.norm1.weight": "model-00004-of-00007.safetensors",
+     "blocks.19.norm2.weight": "model-00004-of-00007.safetensors",
+     "blocks.2.attn.Wk.weight": "model-00001-of-00007.safetensors",
+     "blocks.2.attn.Wo.weight": "model-00001-of-00007.safetensors",
+     "blocks.2.attn.Wq.weight": "model-00001-of-00007.safetensors",
+     "blocks.2.attn.Wv.weight": "model-00001-of-00007.safetensors",
+     "blocks.2.mlp.0.bias": "model-00001-of-00007.safetensors",
+     "blocks.2.mlp.0.weight": "model-00001-of-00007.safetensors",
+     "blocks.2.mlp.2.bias": "model-00001-of-00007.safetensors",
+     "blocks.2.mlp.2.weight": "model-00001-of-00007.safetensors",
+     "blocks.2.norm1.weight": "model-00001-of-00007.safetensors",
+     "blocks.2.norm2.weight": "model-00001-of-00007.safetensors",
+     "blocks.20.attn.Wk.weight": "model-00004-of-00007.safetensors",
+     "blocks.20.attn.Wo.weight": "model-00004-of-00007.safetensors",
+     "blocks.20.attn.Wq.weight": "model-00004-of-00007.safetensors",
+     "blocks.20.attn.Wv.weight": "model-00004-of-00007.safetensors",
+     "blocks.20.mlp.0.bias": "model-00004-of-00007.safetensors",
+     "blocks.20.mlp.0.weight": "model-00004-of-00007.safetensors",
+     "blocks.20.mlp.2.bias": "model-00004-of-00007.safetensors",
+     "blocks.20.mlp.2.weight": "model-00004-of-00007.safetensors",
+     "blocks.20.norm1.weight": "model-00004-of-00007.safetensors",
+     "blocks.20.norm2.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.attn.Wk.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.attn.Wo.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.attn.Wq.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.attn.Wv.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.mlp.0.bias": "model-00004-of-00007.safetensors",
+     "blocks.21.mlp.0.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.mlp.2.bias": "model-00004-of-00007.safetensors",
+     "blocks.21.mlp.2.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.norm1.weight": "model-00004-of-00007.safetensors",
+     "blocks.21.norm2.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.attn.Wk.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.attn.Wo.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.attn.Wq.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.attn.Wv.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.mlp.0.bias": "model-00004-of-00007.safetensors",
+     "blocks.22.mlp.0.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.mlp.2.bias": "model-00004-of-00007.safetensors",
+     "blocks.22.mlp.2.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.norm1.weight": "model-00004-of-00007.safetensors",
+     "blocks.22.norm2.weight": "model-00004-of-00007.safetensors",
+     "blocks.23.attn.Wk.weight": "model-00004-of-00007.safetensors",
+     "blocks.23.attn.Wo.weight": "model-00004-of-00007.safetensors",
+     "blocks.23.attn.Wq.weight": "model-00004-of-00007.safetensors",
+     "blocks.23.attn.Wv.weight": "model-00004-of-00007.safetensors",
+     "blocks.23.mlp.0.bias": "model-00005-of-00007.safetensors",
+     "blocks.23.mlp.0.weight": "model-00005-of-00007.safetensors",
+     "blocks.23.mlp.2.bias": "model-00005-of-00007.safetensors",
+     "blocks.23.mlp.2.weight": "model-00005-of-00007.safetensors",
+     "blocks.23.norm1.weight": "model-00005-of-00007.safetensors",
+     "blocks.23.norm2.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.attn.Wk.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.attn.Wo.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.attn.Wq.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.attn.Wv.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.mlp.0.bias": "model-00005-of-00007.safetensors",
+     "blocks.24.mlp.0.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.mlp.2.bias": "model-00005-of-00007.safetensors",
+     "blocks.24.mlp.2.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.norm1.weight": "model-00005-of-00007.safetensors",
+     "blocks.24.norm2.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.attn.Wk.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.attn.Wo.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.attn.Wq.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.attn.Wv.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.mlp.0.bias": "model-00005-of-00007.safetensors",
+     "blocks.25.mlp.0.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.mlp.2.bias": "model-00005-of-00007.safetensors",
+     "blocks.25.mlp.2.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.norm1.weight": "model-00005-of-00007.safetensors",
+     "blocks.25.norm2.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.attn.Wk.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.attn.Wo.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.attn.Wq.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.attn.Wv.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.mlp.0.bias": "model-00005-of-00007.safetensors",
+     "blocks.26.mlp.0.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.mlp.2.bias": "model-00005-of-00007.safetensors",
+     "blocks.26.mlp.2.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.norm1.weight": "model-00005-of-00007.safetensors",
+     "blocks.26.norm2.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.attn.Wk.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.attn.Wo.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.attn.Wq.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.attn.Wv.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.mlp.0.bias": "model-00005-of-00007.safetensors",
+     "blocks.27.mlp.0.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.mlp.2.bias": "model-00005-of-00007.safetensors",
+     "blocks.27.mlp.2.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.norm1.weight": "model-00005-of-00007.safetensors",
+     "blocks.27.norm2.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.attn.Wk.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.attn.Wo.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.attn.Wq.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.attn.Wv.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.mlp.0.bias": "model-00005-of-00007.safetensors",
+     "blocks.28.mlp.0.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.mlp.2.bias": "model-00005-of-00007.safetensors",
+     "blocks.28.mlp.2.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.norm1.weight": "model-00005-of-00007.safetensors",
+     "blocks.28.norm2.weight": "model-00005-of-00007.safetensors",
+     "blocks.29.attn.Wk.weight": "model-00005-of-00007.safetensors",
+     "blocks.29.attn.Wo.weight": "model-00005-of-00007.safetensors",
+     "blocks.29.attn.Wq.weight": "model-00005-of-00007.safetensors",
+     "blocks.29.attn.Wv.weight": "model-00005-of-00007.safetensors",
+     "blocks.29.mlp.0.bias": "model-00006-of-00007.safetensors",
+     "blocks.29.mlp.0.weight": "model-00006-of-00007.safetensors",
+     "blocks.29.mlp.2.bias": "model-00006-of-00007.safetensors",
+     "blocks.29.mlp.2.weight": "model-00006-of-00007.safetensors",
+     "blocks.29.norm1.weight": "model-00006-of-00007.safetensors",
+     "blocks.29.norm2.weight": "model-00006-of-00007.safetensors",
+     "blocks.3.attn.Wk.weight": "model-00001-of-00007.safetensors",
+     "blocks.3.attn.Wo.weight": "model-00001-of-00007.safetensors",
+     "blocks.3.attn.Wq.weight": "model-00001-of-00007.safetensors",
+     "blocks.3.attn.Wv.weight": "model-00001-of-00007.safetensors",
+     "blocks.3.mlp.0.bias": "model-00001-of-00007.safetensors",
+     "blocks.3.mlp.0.weight": "model-00001-of-00007.safetensors",
+     "blocks.3.mlp.2.bias": "model-00001-of-00007.safetensors",
+     "blocks.3.mlp.2.weight": "model-00001-of-00007.safetensors",
+     "blocks.3.norm1.weight": "model-00001-of-00007.safetensors",
+     "blocks.3.norm2.weight": "model-00001-of-00007.safetensors",
+     "blocks.30.attn.Wk.weight": "model-00006-of-00007.safetensors",
+     "blocks.30.attn.Wo.weight": "model-00006-of-00007.safetensors",
+     "blocks.30.attn.Wq.weight": "model-00006-of-00007.safetensors",
+     "blocks.30.attn.Wv.weight": "model-00006-of-00007.safetensors",
+     "blocks.30.mlp.0.bias": "model-00006-of-00007.safetensors",
+     "blocks.30.mlp.0.weight": "model-00006-of-00007.safetensors",
+     "blocks.30.mlp.2.bias": "model-00006-of-00007.safetensors",
+     "blocks.30.mlp.2.weight": "model-00006-of-00007.safetensors",
+     "blocks.30.norm1.weight": "model-00006-of-00007.safetensors",
+     "blocks.30.norm2.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.attn.Wk.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.attn.Wo.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.attn.Wq.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.attn.Wv.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.mlp.0.bias": "model-00006-of-00007.safetensors",
+     "blocks.31.mlp.0.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.mlp.2.bias": "model-00006-of-00007.safetensors",
+     "blocks.31.mlp.2.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.norm1.weight": "model-00006-of-00007.safetensors",
+     "blocks.31.norm2.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.attn.Wk.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.attn.Wo.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.attn.Wq.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.attn.Wv.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.mlp.0.bias": "model-00006-of-00007.safetensors",
+     "blocks.32.mlp.0.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.mlp.2.bias": "model-00006-of-00007.safetensors",
+     "blocks.32.mlp.2.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.norm1.weight": "model-00006-of-00007.safetensors",
+     "blocks.32.norm2.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.attn.Wk.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.attn.Wo.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.attn.Wq.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.attn.Wv.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.mlp.0.bias": "model-00006-of-00007.safetensors",
+     "blocks.33.mlp.0.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.mlp.2.bias": "model-00006-of-00007.safetensors",
+     "blocks.33.mlp.2.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.norm1.weight": "model-00006-of-00007.safetensors",
+     "blocks.33.norm2.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.attn.Wk.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.attn.Wo.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.attn.Wq.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.attn.Wv.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.mlp.0.bias": "model-00006-of-00007.safetensors",
+     "blocks.34.mlp.0.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.mlp.2.bias": "model-00006-of-00007.safetensors",
+     "blocks.34.mlp.2.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.norm1.weight": "model-00006-of-00007.safetensors",
+     "blocks.34.norm2.weight": "model-00006-of-00007.safetensors",
+     "blocks.35.attn.Wk.weight": "model-00006-of-00007.safetensors",
+     "blocks.35.attn.Wo.weight": "model-00006-of-00007.safetensors",
+     "blocks.35.attn.Wq.weight": "model-00006-of-00007.safetensors",
+     "blocks.35.attn.Wv.weight": "model-00006-of-00007.safetensors",
+     "blocks.35.mlp.0.bias": "model-00007-of-00007.safetensors",
+     "blocks.35.mlp.0.weight": "model-00007-of-00007.safetensors",
+     "blocks.35.mlp.2.bias": "model-00007-of-00007.safetensors",
+     "blocks.35.mlp.2.weight": "model-00007-of-00007.safetensors",
+     "blocks.35.norm1.weight": "model-00007-of-00007.safetensors",
+     "blocks.35.norm2.weight": "model-00007-of-00007.safetensors",
+     "blocks.4.attn.Wk.weight": "model-00001-of-00007.safetensors",
+     "blocks.4.attn.Wo.weight": "model-00001-of-00007.safetensors",
+     "blocks.4.attn.Wq.weight": "model-00001-of-00007.safetensors",
+     "blocks.4.attn.Wv.weight": "model-00001-of-00007.safetensors",
+     "blocks.4.mlp.0.bias": "model-00001-of-00007.safetensors",
+     "blocks.4.mlp.0.weight": "model-00001-of-00007.safetensors",
+     "blocks.4.mlp.2.bias": "model-00001-of-00007.safetensors",
+     "blocks.4.mlp.2.weight": "model-00001-of-00007.safetensors",
+     "blocks.4.norm1.weight": "model-00001-of-00007.safetensors",
+     "blocks.4.norm2.weight": "model-00001-of-00007.safetensors",
+     "blocks.5.attn.Wk.weight": "model-00001-of-00007.safetensors",
+     "blocks.5.attn.Wo.weight": "model-00002-of-00007.safetensors",
+     "blocks.5.attn.Wq.weight": "model-00001-of-00007.safetensors",
+     "blocks.5.attn.Wv.weight": "model-00002-of-00007.safetensors",
+     "blocks.5.mlp.0.bias": "model-00002-of-00007.safetensors",
+     "blocks.5.mlp.0.weight": "model-00002-of-00007.safetensors",
+     "blocks.5.mlp.2.bias": "model-00002-of-00007.safetensors",
+     "blocks.5.mlp.2.weight": "model-00002-of-00007.safetensors",
+     "blocks.5.norm1.weight": "model-00002-of-00007.safetensors",
+     "blocks.5.norm2.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.attn.Wk.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.attn.Wo.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.attn.Wq.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.attn.Wv.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.mlp.0.bias": "model-00002-of-00007.safetensors",
+     "blocks.6.mlp.0.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.mlp.2.bias": "model-00002-of-00007.safetensors",
+     "blocks.6.mlp.2.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.norm1.weight": "model-00002-of-00007.safetensors",
+     "blocks.6.norm2.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.attn.Wk.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.attn.Wo.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.attn.Wq.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.attn.Wv.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.mlp.0.bias": "model-00002-of-00007.safetensors",
+     "blocks.7.mlp.0.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.mlp.2.bias": "model-00002-of-00007.safetensors",
+     "blocks.7.mlp.2.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.norm1.weight": "model-00002-of-00007.safetensors",
+     "blocks.7.norm2.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.attn.Wk.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.attn.Wo.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.attn.Wq.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.attn.Wv.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.mlp.0.bias": "model-00002-of-00007.safetensors",
+     "blocks.8.mlp.0.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.mlp.2.bias": "model-00002-of-00007.safetensors",
+     "blocks.8.mlp.2.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.norm1.weight": "model-00002-of-00007.safetensors",
+     "blocks.8.norm2.weight": "model-00002-of-00007.safetensors",
+     "blocks.9.attn.Wk.weight": "model-00002-of-00007.safetensors",
+     "blocks.9.attn.Wo.weight": "model-00002-of-00007.safetensors",
+     "blocks.9.attn.Wq.weight": "model-00002-of-00007.safetensors",
360
+ "blocks.9.attn.Wv.weight": "model-00002-of-00007.safetensors",
361
+ "blocks.9.mlp.0.bias": "model-00002-of-00007.safetensors",
362
+ "blocks.9.mlp.0.weight": "model-00002-of-00007.safetensors",
363
+ "blocks.9.mlp.2.bias": "model-00002-of-00007.safetensors",
364
+ "blocks.9.mlp.2.weight": "model-00002-of-00007.safetensors",
365
+ "blocks.9.norm1.weight": "model-00002-of-00007.safetensors",
366
+ "blocks.9.norm2.weight": "model-00002-of-00007.safetensors",
367
+ "embed.weight": "model-00001-of-00007.safetensors",
368
+ "norm.weight": "model-00007-of-00007.safetensors",
369
+ "out_head.bias": "model-00007-of-00007.safetensors",
370
+ "out_head.weight": "model-00007-of-00007.safetensors"
371
+ }
372
+ }
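The entries above map each parameter name to the shard file that stores it, and a single block can straddle two shards (block 5 has `Wk` and `Wq` in shard 1 but `Wo` and `Wv` in shard 2). As a minimal sketch, assuming the index uses the usual layout with a top-level `weight_map` object (only its closing braces are visible here), the mapping can be inverted to list what each shard holds. The `index_text` excerpt and the `group_by_shard` helper are illustrative names, not part of the commit.

```python
import json
from collections import defaultdict

# Three entries copied from the index above; the real file has many more.
# The "weight_map" key is an assumption about the enclosing object.
index_text = """
{
  "weight_map": {
    "blocks.5.attn.Wq.weight": "model-00001-of-00007.safetensors",
    "blocks.5.attn.Wo.weight": "model-00002-of-00007.safetensors",
    "out_head.weight": "model-00007-of-00007.safetensors"
  }
}
"""

def group_by_shard(weight_map):
    """Invert a name -> shard mapping into shard -> sorted list of names."""
    shards = defaultdict(list)
    for name, shard in weight_map.items():
        shards[shard].append(name)
    return {shard: sorted(names) for shard, names in shards.items()}

weight_map = json.loads(index_text)["weight_map"]
shards = group_by_shard(weight_map)

for shard, names in sorted(shards.items()):
    print(shard, names)
```

Grouping by shard makes it easy to see which file must be opened to inspect a given tensor without loading the whole checkpoint.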
modeling.py ADDED
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.attention import sdpa_kernel, SDPBackend
+from transformers import PreTrainedModel
+from .configuration import MBZTestConfig
+from transformers.modeling_outputs import CausalLMOutput
+
+class RotaryPositionalEncoding(nn.Module):
+    """
+    Rotary Position Embeddings (RoPE) - efficient implementation
+    """
+    def __init__(self, d_head, max_seq_len=8192, base=10000.0):
+        super().__init__()
+        self.d_head = d_head
+        self.max_seq_len = max_seq_len
+        self.base = base
+
+        # Precompute inverse frequencies
+        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
+        self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        # Precompute cos and sin for maximum sequence length
+        self._precompute_freqs(max_seq_len)
+
+    def _precompute_freqs(self, seq_len):
+        """Precompute cos and sin values for positions"""
+        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
+        freqs = torch.outer(t, self.inv_freq) # (seq_len, d_head/2)
+
+        # Create cos and sin embeddings
+        freqs_cos = torch.cos(freqs)
+        freqs_sin = torch.sin(freqs)
+
+        # Interleave to match the dimension (seq_len, d_head)
+        self.register_buffer('freqs_cos', freqs_cos.repeat_interleave(2, dim=-1), persistent=False)
+        self.register_buffer('freqs_sin', freqs_sin.repeat_interleave(2, dim=-1), persistent=False)
+
+    def rotate_half(self, x):
+        """Rotate half the hidden dims of the input"""
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
+        return torch.stack([-x2, x1], dim=-1).flatten(-2)
+
+    def forward(self, q, k, start_pos=0):
+        """
+        Apply rotary embeddings to query and key tensors
+        Args:
+            q: (batch_size, n_heads, seq_len, d_head)
+            k: (batch_size, n_heads, seq_len, d_head)
+            start_pos: starting position for caching scenarios
+        Returns:
+            q_rot, k_rot with rotary embeddings applied
+        """
+        seq_len = q.shape[2]
+
+        # Get the precomputed frequencies for this sequence length
+        freqs_cos = self.freqs_cos[start_pos:start_pos + seq_len]
+        freqs_sin = self.freqs_sin[start_pos:start_pos + seq_len]
+
+        # Apply rotary embeddings
+        q_rot = q * freqs_cos + self.rotate_half(q) * freqs_sin
+        k_rot = k * freqs_cos + self.rotate_half(k) * freqs_sin
+
+        return q_rot, k_rot
+
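The class above rotates consecutive channel pairs `(x[0], x[1]), (x[2], x[3]), ...` by an angle proportional to the token position. The following pure-Python sketch mirrors that arithmetic on a single vector to show the property RoPE is used for: the dot product of a rotated query and key depends only on the distance between their positions. The helper names `rope_angles`, `apply_rope`, and `dot`, and the sample vectors, are illustrative and not part of the commit.

```python
import math

def rope_angles(pos, d_head, base=10000.0):
    """Angle for each channel pair at position `pos`: pos * base**(-2i/d_head)."""
    return [pos * base ** (-(2 * i) / d_head) for i in range(d_head // 2)]

def apply_rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[0], x[1]), (x[2], x[3]), ... by their angle."""
    out = []
    for i, theta in enumerate(rope_angles(pos, len(x), base)):
        a, b = x[2 * i], x[2 * i + 1]
        # Same result as x * cos + rotate_half(x) * sin for this pair.
        out += [a * math.cos(theta) - b * math.sin(theta),
                a * math.sin(theta) + b * math.cos(theta)]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical query and key vectors with d_head = 4.
q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.5, -0.75, 1.0]

# Both pairs of positions are two tokens apart, so the scores agree.
near = dot(apply_rope(q, 3), apply_rope(k, 1))
far = dot(apply_rope(q, 12), apply_rope(k, 10))
print(near, far)
```

Because each pair is only rotated, the vector norm is unchanged, which is why the rotation can be applied to `q` and `k` without rescaling the attention scores.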
+class Attention(nn.Module):
+    def __init__(self, d_model, n_heads, d_head):
+        super().__init__()
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.d_head = d_head
+
+        self.Wq = nn.Linear(d_model, n_heads * d_head, bias=False)
+        self.Wk = nn.Linear(d_model, n_heads * d_head, bias=False)
+        self.Wv = nn.Linear(d_model, n_heads * d_head, bias=False)
+        self.Wo = nn.Linear(n_heads * d_head, d_model, bias=False)
+
+        # Initialize RoPE
+        self.rope = RotaryPositionalEncoding(d_head)
+
+    def forward(self, x):
+        # x is shape batch_size, seq_len, d_model
+        batch_size, seq_len, d_model = x.shape
+        q = self.Wq(x) # q is shape batch_size, seq_len, n_heads * d_head
+        k = self.Wk(x)
+        v = self.Wv(x)
+
+        # reshape to batch_size, n_heads, seq_len, d_head
+        q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1,2)
+        k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1,2)
+        v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1,2)
+
+        q, k = self.rope(q, k)
+        with sdpa_kernel(SDPBackend.FLASH_ATTENTION): # ensure use flash attention
+            a = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)# a is (batch_size, n_heads, seq_len, d_head)
+        a = a.transpose(1,2) # change a to (batch_size, seq_len, n_heads, d_head)
+        a = a.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+        out = self.Wo(a) # out is (batch_size, seq_len, d_model)
+        return out
+
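The `is_causal=True` call above hides the masking and softmax inside a fused kernel. As a readable reference, here is a pure-Python sketch of what one head computes for one sequence: each position attends only to itself and earlier positions, with scores scaled by `1/sqrt(d_head)`. The function name `causal_attention` and the sample inputs are illustrative and not part of the commit; this is a slow reference for understanding, not a replacement for the fused call.

```python
import math

def causal_attention(q, k, v):
    """Single-head causal attention over lists of vectors (seq_len x d_head)."""
    d_head = len(q[0])
    out = []
    for i in range(len(q)):
        # Position i may only attend to positions 0..i.
        scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_head)
                  for j in range(i + 1)]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([sum(w / total * v[j][c] for j, w in enumerate(weights))
                    for c in range(len(v[0]))])
    return out

# Hypothetical inputs: three positions, d_head = 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out)
```

The first output row equals `v[0]` because position 0 can see only itself, and changing later keys or values never alters earlier rows.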
+class TransformerBlock(nn.Module):
+    def __init__(self, d_model, n_heads, d_head):
+        super().__init__()
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.d_head = d_head
+
+        self.attn = Attention(d_model, n_heads, d_head)
+        self.mlp = nn.Sequential(nn.Linear(d_model, 4*d_model), nn.ReLU(), nn.Linear(4*d_model, d_model))
+
+        self.norm1 = nn.RMSNorm(d_model)
+        self.norm2 = nn.RMSNorm(d_model)
+
+    def forward(self, x):
+        x = self.attn(self.norm1(x)) + x
+        x = self.mlp(self.norm2(x)) + x
+        return x
+
+class MBZTestModelForCausalLM(PreTrainedModel):
+    config_class = MBZTestConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        d_model = config.d_model
+        n_heads = config.n_heads
+        d_head = config.d_head
+        n_vocab = config.n_vocab
+        n_layers = config.n_layers
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.d_head = d_head
+        self.n_vocab = n_vocab
+
+        self.embed = nn.Embedding(n_vocab, d_model)
+
+        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_head) for _ in range(n_layers)])
+
+        self.norm = nn.RMSNorm(d_model)
+        self.out_head = nn.Linear(d_model, n_vocab)
+
+    def forward(self, x):
+        x = self.embed(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.out_head(self.norm(x))
+        return CausalLMOutput(logits=x)
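The module definitions above fix the parameter count once the five config fields are known. The sketch below tallies it by hand: four bias-free attention projections, a two-layer MLP with biases, two RMSNorm weights per block, plus the embedding, final norm, and output head. The RoPE tables are registered as non-persistent buffers, so they are not counted and would not be expected in the checkpoint. The function name `count_parameters` and the config values passed to it are hypothetical; the actual values live in the configuration file, which is not shown here.

```python
def count_parameters(d_model, n_heads, d_head, n_vocab, n_layers):
    """Count trainable parameters implied by the module definitions above."""
    d_attn = n_heads * d_head
    attn = 4 * d_model * d_attn                  # Wq, Wk, Wv, Wo (no biases)
    mlp = (d_model * 4 * d_model + 4 * d_model   # mlp.0 weight + bias
           + 4 * d_model * d_model + d_model)    # mlp.2 weight + bias
    norms = 2 * d_model                          # norm1 and norm2 weights
    per_block = attn + mlp + norms
    embed = n_vocab * d_model                    # embed.weight
    final_norm = d_model                         # norm.weight
    out_head = d_model * n_vocab + n_vocab       # out_head weight + bias
    return n_layers * per_block + embed + final_norm + out_head

# Hypothetical small configuration, chosen only to keep the numbers readable.
total = count_parameters(d_model=64, n_heads=4, d_head=16, n_vocab=1000, n_layers=2)
print(total)
```

The same tally, multiplied by the bytes per element of the stored dtype, gives a rough estimate of the combined size of the shard files.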