akkiisfrommars commited on
Commit
e226163
·
verified ·
1 Parent(s): b7345fe

Upload 7 files

Browse files
Files changed (7) hide show
  1. LICENSE +201 -0
  2. README.md +241 -3
  3. config.json +25 -0
  4. example_usage.py +57 -0
  5. modeling_cosmicfish.py +290 -0
  6. tokenizer_config.json +11 -0
  7. vocab_info.json +5 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2026 Mistyoz AI Private Limited
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,241 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - text-generation
5
+ - language-model
6
+ - causal-lm
7
+ - cosmicfish
8
+ - 300m
9
+ - transformer
10
+ - rope
11
+ - gqa
12
+ - swiglu
13
+ - rmsnorm
14
+ language: en
15
+ datasets:
16
+ - CosmicSet-2.0
17
+ - akkiisfrommars/TreeCorpusCleanedmodel
18
+ model_type: CosmicFish
19
+ pipeline_tag: text-generation
20
+ ---
21
+
22
+ # CosmicFish-300M
23
+
24
+ A 300M parameter language model with modern architecture improvements developed by Mistyoz AI.
25
+
26
+ ## Quick Start
27
+
28
+ **The easiest way to chat with CosmicFish is using our chat.py script:**
29
+
30
+ ```bash
31
+ # Download the chat script from this repository
32
+ wget https://huggingface.co/MistyozAI/CosmicFish-300M/resolve/main/chat.py
33
+
34
+ # Install dependencies
35
+ pip install transformers huggingface-hub termcolor safetensors
36
+
37
+ # Run the chat interface (automatically downloads model)
38
+ python chat.py
39
+ ```
40
+
41
+ The `chat.py` script handles all model loading, generation, and provides the best chat experience with live streaming, repetition penalty, and conversation commands.
42
+
43
+ ## Model Details
44
+
45
+ - **Parameters**: 369M
46
+ - **Architecture**: CosmicFish (RoPE, GQA, SwiGLU, RMSNorm)
47
+ - **Context Length**: 2048 tokens
48
+ - **Vocabulary**: 50,257 tokens
49
+ - **Training Data**: CosmicSet 2.0
50
+ - **Developer**: Mistyoz AI
51
+ - **Repository**: MistyozAI/CosmicFish-300M
52
+ - **Format**: Safetensors
53
+
54
+ ## Usage
55
+
56
+ ### Installation
57
+
58
+ ```bash
59
+ pip install transformers huggingface-hub termcolor safetensors torch
60
+ ```
61
+
62
+ ### Quick Chat Interface
63
+
64
+ ```python
65
+ from transformers import GPT2Tokenizer
66
+ from huggingface_hub import snapshot_download
67
+ from safetensors.torch import load_file
68
+ import torch
69
+ import json
70
+ import os
71
+
72
+ # Download model from Hugging Face Hub
73
+ cache_dir = snapshot_download(repo_id="MistyozAI/CosmicFish-300M")
74
+
75
+ # Load tokenizer
76
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
77
+
78
+ # Load config
79
+ with open(os.path.join(cache_dir, "config.json")) as f:
80
+ config_dict = json.load(f)
81
+
82
+ # Load model weights from safetensors
83
+ state_dict = load_file(os.path.join(cache_dir, "model.safetensors"))
84
+
85
+ # Note: Full model class available in the repository
86
+ print("Model downloaded and ready for use!")
87
+ ```
88
+
89
+ ### Advanced Generation with Repetition Penalty
90
+
91
+ ```python
92
+ def generate_with_repetition_penalty(model, tokenizer, prompt, max_tokens=100, temperature=0.5, penalty=1.2):
93
+ input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
94
+ generated = input_ids.clone()
95
+
96
+ for _ in range(max_tokens):
97
+ with torch.no_grad():
98
+ logits, _ = model(generated)
99
+
100
+ next_token_logits = logits[:, -1, :] / temperature
101
+
102
+ # Apply repetition penalty
103
+ if penalty > 1.0:
104
+ for token_id in set(generated[0].tolist()):
105
+ if next_token_logits[0, token_id] > 0:
106
+ next_token_logits[0, token_id] /= penalty
107
+ else:
108
+ next_token_logits[0, token_id] *= penalty
109
+
110
+ probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
111
+ next_token = torch.multinomial(probs, num_samples=1)
112
+
113
+ if next_token.item() == tokenizer.eos_token_id:
114
+ break
115
+
116
+ generated = torch.cat([generated, next_token], dim=1)
117
+
118
+ return tokenizer.decode(generated[0], skip_special_tokens=True)
119
+ ```
120
+
121
+ ### Loading Model with Safetensors
122
+
123
+ ```python
124
+ from safetensors.torch import load_file
125
+ from modeling_cosmicfish import CosmicFish, CosmicConfig
126
+ import json
127
+
128
+ def load_cosmicfish_model(model_path):
129
+ # Load config
130
+ with open(os.path.join(model_path, "config.json")) as f:
131
+ config_dict = json.load(f)
132
+
133
+ # Create model config
134
+ config = CosmicConfig(
135
+ vocab_size=config_dict["vocab_size"],
136
+ block_size=config_dict["block_size"],
137
+ n_layer=config_dict["n_layer"],
138
+ n_head=config_dict["n_head"],
139
+ n_embd=config_dict["n_embd"],
140
+ bias=config_dict["bias"],
141
+ dropout=0.0,
142
+ use_rotary=config_dict["use_rotary"],
143
+ use_swiglu=config_dict["use_swiglu"],
144
+ use_gqa=config_dict["use_gqa"],
145
+ n_query_groups=config_dict["n_query_groups"]
146
+ )
147
+
148
+ # Create model
149
+ model = CosmicFish(config)
150
+
151
+ # Load weights from safetensors (secure format)
152
+ state_dict = load_file(os.path.join(model_path, "model.safetensors"))
153
+
154
+ # Handle weight sharing (lm_head.weight shares with transformer.wte.weight)
155
+ if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
156
+ state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']
157
+
158
+ model.load_state_dict(state_dict)
159
+ model.eval()
160
+
161
+ return model
162
+ ```
163
+
164
+ ### Chat Interface
165
+
166
+ ```python
167
+ def chat_with_model():
168
+ conversation = []
169
+
170
+ while True:
171
+ user_input = input("You: ")
172
+ if user_input.lower() in ['quit', 'exit']:
173
+ break
174
+
175
+ context = "Below is a conversation between a human and an AI assistant.\n\n"
176
+ for human, ai in conversation:
177
+ context += f"Human: {human}\nAssistant: {ai}\n\n"
178
+ context += f"Human: {user_input}\nAssistant:"
179
+
180
+ # Generate response with repetition penalty
181
+ response = generate_with_repetition_penalty(
182
+ model, tokenizer, context,
183
+ max_tokens=150, temperature=0.7, penalty=1.2
184
+ )
185
+
186
+ # Extract just the assistant's response
187
+ response = response.split("Assistant:")[-1].split('\n')[0].strip()
188
+ print(f"CosmicFish: {response}")
189
+
190
+ conversation.append((user_input, response))
191
+
192
+ chat_with_model()
193
+ ```
194
+
195
+ ## Architecture
196
+
197
+ CosmicFish uses several modern improvements over standard transformers:
198
+
199
+ - **RoPE (Rotary Position Embeddings)**: Better position encoding than absolute positions
200
+ - **GQA (Grouped-Query Attention)**: Reduces memory usage with 4 query groups
201
+ - **SwiGLU**: More effective activation function than ReLU/GELU
202
+ - **RMSNorm**: Simpler, more stable normalization than LayerNorm
203
+
204
+ ## Training
205
+
206
+ - **Dataset**: CosmicSet 2.0
207
+ - **Sequence Length**: 2048 tokens
208
+ - **Training Steps**: ~130K iterations
209
+ - **Hardware**: Nvidia RTX Pro 6000 x1
210
+
211
+ ## Performance
212
+
213
+ - **Speed**: Varies by hardware (not benchmarked)
214
+ - **Memory**: ~1GB RAM
215
+ - **File Size**: 738.6MB
216
+ - **Loading**: Fast and secure with safetensors
217
+
218
+ ## Model Format
219
+
220
+ This model uses **safetensors** format for:
221
+ - **Security**: Safe loading without arbitrary code execution
222
+ - **Performance**: Faster loading compared to pickle-based formats
223
+ - **Memory efficiency**: Zero-copy loading when possible
224
+ - **Cross-platform compatibility**: Works consistently across different environments
225
+
226
+ ## Limitations
227
+
228
+ - May produce less accurate responses
229
+ - 2048 token context limit
230
+ - English only
231
+ - Training data cutoff applies
232
+ - May generate incorrect information
233
+ - Cannot browse internet or access real-time data
234
+
235
+ ## License
236
+
237
+ Apache 2.0 - see LICENSE file.
238
+
239
+ ## Credit
240
+
241
+ If you use CosmicFish-300M, please credit Mistyoz AI.
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "cosmicfish",
3
+ "architectures": [
4
+ "CosmicFish"
5
+ ],
6
+ "vocab_size": 50257,
7
+ "n_embd": 960,
8
+ "n_layer": 24,
9
+ "n_head": 24,
10
+ "block_size": 2048,
11
+ "bias": true,
12
+ "dropout": 0.1,
13
+ "eps": 1e-06,
14
+ "use_rotary": true,
15
+ "use_swiglu": true,
16
+ "use_gqa": true,
17
+ "use_qk_norm": false,
18
+ "n_query_groups": 4,
19
+ "torch_dtype": "float16",
20
+ "transformers_version": "4.36.0",
21
+ "use_cache": true,
22
+ "pad_token_id": 50256,
23
+ "bos_token_id": 50256,
24
+ "eos_token_id": 50256
25
+ }
example_usage.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example usage of CosmicFish model (using safetensors)
3
+ """
4
+ import torch
5
+ from transformers import GPT2Tokenizer
6
+ from modeling_cosmicfish import CosmicFish, CosmicConfig
7
+ from safetensors.torch import load_file
8
+ import json
9
+
10
+ def load_cosmicfish(model_dir):
11
+ """Load CosmicFish model and tokenizer"""
12
+ # Load config
13
+ with open(f"{model_dir}/config.json", "r") as f:
14
+ config_dict = json.load(f)
15
+
16
+ # Create CosmicConfig
17
+ config = CosmicConfig(
18
+ vocab_size=config_dict["vocab_size"],
19
+ block_size=config_dict["block_size"],
20
+ n_layer=config_dict["n_layer"],
21
+ n_head=config_dict["n_head"],
22
+ n_embd=config_dict["n_embd"],
23
+ bias=config_dict["bias"],
24
+ dropout=0.0, # Set to 0 for inference
25
+ use_rotary=config_dict["use_rotary"],
26
+ use_swiglu=config_dict["use_swiglu"],
27
+ use_gqa=config_dict["use_gqa"],
28
+ n_query_groups=config_dict["n_query_groups"],
29
+ use_qk_norm=config_dict["use_qk_norm"]
30
+ )
31
+
32
+ # Create model
33
+ model = CosmicFish(config)
34
+
35
+ # Load weights from safetensors (safer and faster)
36
+ state_dict = load_file(f"{model_dir}/model.safetensors")
37
+
38
+ # Handle weight sharing: lm_head.weight shares with transformer.wte.weight
39
+ if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
40
+ print("Weight sharing detected: tying lm_head.weight to transformer.wte.weight")
41
+ state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']
42
+
43
+ model.load_state_dict(state_dict)
44
+ model.eval()
45
+
46
+ # Load tokenizer
47
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
48
+
49
+ return model, tokenizer
50
+
51
+ # Example usage:
52
+ # model, tokenizer = load_cosmicfish("./")
53
+ # input_text = "The future of AI is"
54
+ # inputs = tokenizer.encode(input_text, return_tensors="pt")
55
+ # outputs = model.generate(inputs, max_length=50, temperature=0.7, do_sample=True)
56
+ # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
+ # print(response)
modeling_cosmicfish.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+
7
+ class CosmicConfig:
8
+ """Configuration class for CosmicFish."""
9
+
10
+ def __init__(self,
11
+ vocab_size=50257,
12
+ block_size=2048,
13
+ n_layer=24,
14
+ n_head=16,
15
+ n_embd=960,
16
+ bias=True,
17
+ dropout=0.0, # Always 0 for inference
18
+ n_query_groups=4,
19
+ eps=1e-6,
20
+ use_rotary=True,
21
+ use_swiglu=True,
22
+ use_qk_norm=False,
23
+ use_gqa=True):
24
+ self.vocab_size = vocab_size
25
+ self.block_size = block_size
26
+ self.n_layer = n_layer
27
+ self.n_head = n_head
28
+ self.n_embd = n_embd
29
+ self.bias = bias
30
+ self.dropout = dropout
31
+ self.eps = eps
32
+ self.use_rotary = use_rotary
33
+ self.use_swiglu = use_swiglu
34
+ self.use_qk_norm = use_qk_norm
35
+ self.use_gqa = use_gqa
36
+ self.n_query_groups = n_query_groups if use_gqa else n_head
37
+ # Ensure n_head is divisible by n_query_groups
38
+ assert n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"
39
+
40
+
41
+ class RMSNorm(nn.Module):
42
+ """Root Mean Square Normalization"""
43
+
44
+ def __init__(self, dim, eps=1e-6):
45
+ super().__init__()
46
+ self.eps = eps
47
+ self.weight = nn.Parameter(torch.ones(dim))
48
+
49
+ def forward(self, x):
50
+ rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
51
+ return self.weight * (x / rms)
52
+
53
+
54
+ def precompute_freqs_cis(dim, end, theta=10000.0):
55
+ """Precompute the frequency tensor for complex exponentials (cis)"""
56
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
57
+ t = torch.arange(end, device=freqs.device)
58
+ freqs = torch.outer(t, freqs)
59
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
60
+ return freqs_cis
61
+
62
+
63
+ def apply_rotary_emb(xq, xk, freqs_cis):
64
+ """Apply rotary embeddings to input tensors"""
65
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
66
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
67
+
68
+ seq_len = xq_.size(2)
69
+ if freqs_cis.size(0) < seq_len:
70
+ raise ValueError(f"freqs_cis has only {freqs_cis.size(0)} values but sequence length is {seq_len}")
71
+
72
+ freqs_cis_seq = freqs_cis[:seq_len]
73
+ xq_out = torch.view_as_real(xq_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
74
+ xk_out = torch.view_as_real(xk_ * freqs_cis_seq.unsqueeze(0)).flatten(3)
75
+
76
+ return xq_out.type_as(xq), xk_out.type_as(xk)
77
+
78
+
79
+ class GroupedQueryAttention(nn.Module):
80
+ """Grouped Query Attention (GQA) implementation"""
81
+
82
+ def __init__(self, config):
83
+ super().__init__()
84
+ assert config.n_embd % config.n_head == 0
85
+
86
+ head_dim = config.n_embd // config.n_head
87
+ self.head_dim = head_dim
88
+ self.n_head = config.n_head
89
+ self.n_embd = config.n_embd
90
+ self.n_query_groups = config.n_query_groups
91
+
92
+ self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
93
+ qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim
94
+
95
+ self.c_attn = nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
96
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
97
+
98
+ # Flash attention support
99
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
100
+ if not self.flash:
101
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
102
+ .view(1, 1, config.block_size, config.block_size))
103
+
104
+ # Query-key normalization
105
+ self.qk_norm = getattr(config, 'use_qk_norm', False)
106
+ if self.qk_norm:
107
+ self.q_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
108
+ self.k_norm = RMSNorm(head_dim, eps=getattr(config, 'eps', 1e-6))
109
+
110
+ def forward(self, x, freqs_cis=None):
111
+ B, T, C = x.size()
112
+ qkv = self.c_attn(x)
113
+ head_dim = C // self.n_head
114
+
115
+ q_size = self.n_head * head_dim
116
+ k_size = self.kv_heads * head_dim
117
+ v_size = self.kv_heads * head_dim
118
+
119
+ q, k, v = qkv.split([q_size, k_size, v_size], dim=2)
120
+
121
+ q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
122
+ k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
123
+ v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
124
+
125
+ # Repeat k and v if needed for GQA
126
+ if self.kv_heads < self.n_head:
127
+ repeats = self.n_head // self.kv_heads
128
+ k = k.repeat_interleave(repeats, dim=1)
129
+ v = v.repeat_interleave(repeats, dim=1)
130
+
131
+ # Apply rotary embeddings
132
+ if freqs_cis is not None:
133
+ q, k = apply_rotary_emb(q, k, freqs_cis)
134
+
135
+ # Apply query-key normalization
136
+ if self.qk_norm:
137
+ q = self.q_norm(q)
138
+ k = self.k_norm(k)
139
+
140
+ # Compute attention
141
+ if self.flash:
142
+ y = torch.nn.functional.scaled_dot_product_attention(
143
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
144
+ )
145
+ else:
146
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
147
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
148
+ att = F.softmax(att, dim=-1)
149
+ y = att @ v
150
+
151
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
152
+ y = self.c_proj(y)
153
+ return y
154
+
155
+
156
+ class Block(nn.Module):
157
+ """Transformer block"""
158
+
159
+ def __init__(self, config):
160
+ super().__init__()
161
+ self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
162
+ self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
163
+ self.attn = GroupedQueryAttention(config)
164
+
165
+ # MLP implementation based on configuration
166
+ if config.use_swiglu:
167
+ # SwiGLU MLP
168
+ self.mlp = nn.ModuleDict(dict(
169
+ gate=nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
170
+ up=nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
171
+ down=nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
172
+ act=nn.SiLU(),
173
+ ))
174
+ m = self.mlp
175
+ self.mlpf = lambda x: m.down(m.act(m.up(x)) * m.gate(x))
176
+ else:
177
+ # Traditional MLP
178
+ self.mlp = nn.ModuleDict(dict(
179
+ c_fc=nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
180
+ c_proj=nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
181
+ act=nn.GELU(),
182
+ ))
183
+ m = self.mlp
184
+ self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))
185
+
186
+ def forward(self, x, freqs_cis=None):
187
+ x = x + self.attn(self.ln_1(x), freqs_cis)
188
+ x = x + self.mlpf(self.ln_2(x))
189
+ return x
190
+
191
+
192
+ class CosmicFish(nn.Module):
193
+ """
194
+ CosmicFish model for inference only.
195
+ Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
196
+ """
197
+
198
+ def __init__(self, config):
199
+ super().__init__()
200
+ self.config = config
201
+
202
+ self.transformer = nn.ModuleDict(dict(
203
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
204
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
205
+ ln_f=RMSNorm(config.n_embd, eps=config.eps),
206
+ ))
207
+
208
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
209
+
210
+ # Share weights between embedding and output
211
+ self.transformer.wte.weight = self.lm_head.weight
212
+
213
+ # Precompute rotary embedding frequencies
214
+ if config.use_rotary:
215
+ head_dim = config.n_embd // config.n_head
216
+ self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
217
+ else:
218
+ self.freqs_cis = None
219
+ self.transformer.wpe = nn.Embedding(config.block_size, config.n_embd)
220
+
221
+ def get_num_params(self, non_embedding=True):
222
+ """Return the number of parameters in the model."""
223
+ n_params = sum(p.numel() for p in self.parameters())
224
+ if non_embedding and hasattr(self.transformer, 'wpe'):
225
+ n_params -= self.transformer.wpe.weight.numel()
226
+ return n_params
227
+
228
+ def forward(self, idx, targets=None):
229
+ """Forward pass through the model."""
230
+ device = idx.device
231
+ b, t = idx.size()
232
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
233
+
234
+ # Get token embeddings
235
+ tok_emb = self.transformer.wte(idx)
236
+
237
+ # Handle positional embeddings
238
+ if self.config.use_rotary:
239
+ x = tok_emb
240
+ freqs_cis = self.freqs_cis.to(device) if self.freqs_cis is not None else None
241
+ else:
242
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
243
+ pos_emb = self.transformer.wpe(pos)
244
+ x = tok_emb + pos_emb
245
+ freqs_cis = None
246
+
247
+ # Apply transformer blocks
248
+ for block in self.transformer.h:
249
+ x = block(x, freqs_cis)
250
+
251
+ # Apply final normalization
252
+ x = self.transformer.ln_f(x)
253
+
254
+ # Calculate outputs
255
+ if targets is not None:
256
+ logits = self.lm_head(x)
257
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
258
+ else:
259
+ # For inference, only compute logits for the last token
260
+ logits = self.lm_head(x[:, [-1], :])
261
+ loss = None
262
+
263
+ return logits, loss
264
+
265
+ @torch.no_grad()
266
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
267
+ """
268
+ Generate text by sampling from the model, token by token.
269
+ """
270
+ for _ in range(max_new_tokens):
271
+ # Crop sequence to block size if needed
272
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
273
+
274
+ # Forward pass
275
+ logits, _ = self(idx_cond)
276
+ logits = logits[:, -1, :] / temperature
277
+
278
+ # Apply top-k sampling
279
+ if top_k is not None:
280
+ v, _ = torch.topk(logits, top_k)
281
+ logits[logits < v[:, [-1]]] = -float('Inf')
282
+
283
+ # Sample next token
284
+ probs = F.softmax(logits, dim=-1)
285
+ idx_next = torch.multinomial(probs, num_samples=1)
286
+
287
+ # Append to sequence
288
+ idx = torch.cat((idx, idx_next), dim=1)
289
+
290
+ return idx
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "GPT2Tokenizer",
3
+ "vocab_size": 50257,
4
+ "model_max_length": 2048,
5
+ "bos_token": "<|endoftext|>",
6
+ "eos_token": "<|endoftext|>",
7
+ "unk_token": "<|endoftext|>",
8
+ "pad_token": "<|endoftext|>",
9
+ "add_prefix_space": false,
10
+ "do_lower_case": false
11
+ }
vocab_info.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "note": "This model uses GPT-2 tokenizer. Please use: tokenizer = GPT2Tokenizer.from_pretrained('gpt2')",
3
+ "vocab_size": 50257,
4
+ "encoding": "gpt2"
5
+ }