GPUburnout commited on
Commit
36bc78f
·
1 Parent(s): 514a6e1

feat: add app code, configs, and tokenizers

Browse files
README.md CHANGED
@@ -1,13 +1,21 @@
1
  ---
2
- title: Gpuburnout Models
3
- emoji: 📈
4
- colorFrom: purple
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
- pinned: false
10
- short_description: Compare language models trained from scratch
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: GPUburnout Models
3
+ emoji: 🔥
4
+ colorFrom: gray
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.12.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: mit
11
  ---
12
 
13
+ # GPUburnout Models Interactive Demo
14
+
15
+ Compare language models trained from scratch across two seasons:
16
+
17
+ - **Tiny Shakespeare** (3.2M params) — Character-level, trained on Shakespeare
18
+ - **GPT-2 Small** (134M params) — BPE tokenizer, trained on 2.8B tokens
19
+ - **Llama 1B** (1.04B params) — Llama architecture, trained on 30B tokens for $175
20
+
21
+ Built by [Jun Park](https://gpuburnout.com/about/) | [Read the blog](https://gpuburnout.com) | [GitHub](https://github.com/GPUburnout)
app.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPUburnout Models — Unified Demo
3
+ Compare models trained from scratch: Tiny (3.2M) → GPT-2 (134M) → Llama (1B)
4
+ """
5
+
6
+ import gc
7
+ import json
8
+ import os
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ # Add models directory to path
16
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "models"))
17
+
18
# ── Model Registry ──────────────────────────────────────────────────────────

# Registry of all demo models, keyed by the display name shown in the UI.
# Each entry carries:
#   path        — checkpoint directory holding config.json + pytorch_model.bin
#   arch        — loader family: "s1" (GPT-2 style) or "s2" (Llama style)
#   description — short blurb rendered in the status panel
#   examples    — per-model prompt suggestions surfaced below the prompt box
MODELS = {
    "Tiny Shakespeare (3.2M)": {
        "path": "checkpoints/tiny",
        "arch": "s1",
        "description": "Character-level model trained on Shakespeare. The very first step.",
        "examples": ["ROMEO:", "JULIET:", "To be, or not to be", "First Citizen:"],
    },
    "GPT-2 Small (134M)": {
        "path": "checkpoints/gpt2_small",
        "arch": "s1",
        "description": "Season 1 final model. BPE tokenizer, 2.8B tokens, 12 layers.",
        "examples": [
            "The capital of France is",
            "Explain machine learning in simple terms.",
            "def fibonacci(n):",
            "The meaning of life is",
        ],
    },
    "Llama 1B (1.04B)": {
        "path": "checkpoints/llama_1b",
        "arch": "s2",
        "description": "Season 2. Llama architecture, 30B tokens, $175 total. Final loss 2.494.",
        "examples": [
            "The capital of France is",
            "In a shocking discovery, scientists found that",
            "def fibonacci(n):",
            "Once upon a time, in a land far away,",
        ],
    },
}
# ── Current model state (one at a time) ─────────────────────────────────────

# Only one model is ever resident; this dict tracks it along with its
# tokenizer and parsed config.
current = {"name": None, "model": None, "tokenizer": None, "config": None}


def unload_current():
    """Drop the resident model (if any) and release its memory."""
    model = current.pop("model", None)
    if model is not None:
        del model
    current["model"] = None
    for key in ("tokenizer", "config", "name"):
        current[key] = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def load_model(model_name):
    """Load a model by name, unloading the previous one first.

    Args:
        model_name: Key into the MODELS registry.

    Returns:
        Tuple of (model, tokenizer, config dict).

    Raises:
        FileNotFoundError: If the checkpoint directory has no config.json.
        KeyError: If model_name is not a registered model.
    """
    # Fast path: the requested model is already resident.
    if current["name"] == model_name and current["model"] is not None:
        return current["model"], current["tokenizer"], current["config"]

    unload_current()

    info = MODELS[model_name]
    model_dir = info["path"]
    config_path = os.path.join(model_dir, "config.json")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Model not found: {model_dir}")

    # Explicit encoding: configs may contain non-ASCII text and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Dispatch to the loader family recorded in the registry.
    if info["arch"] == "s1":
        model, tokenizer = _load_s1(model_dir, config)
    else:
        model, tokenizer = _load_s2(model_dir, config)

    current["name"] = model_name
    current["model"] = model
    current["tokenizer"] = tokenizer
    current["config"] = config
    return model, tokenizer, config
def _load_s1(model_dir, config):
    """Build and restore a Season 1 (GPT-2 style) model plus its tokenizer."""
    from s1_model import TransformerLanguageModel

    # Dropout is forced to 0.0: this path is inference-only.
    model = TransformerLanguageModel(
        vocab_size=config["vocab_size"],
        embed_dim=config["embed_dim"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        ff_dim=config["ff_dim"],
        max_seq_len=config["max_seq_len"],
        dropout=0.0,
    )
    state = torch.load(
        os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu"
    )
    model.load_state_dict(state)
    model.eval()

    # Pick the tokenizer implementation recorded in the config
    # (character-level unless explicitly marked "bpe").
    tokenizer_path = os.path.join(model_dir, "tokenizer.json")
    if config.get("tokenizer_type", "character") == "bpe":
        from s1_tokenizer_bpe import BPETokenizer as TokenizerClass
    else:
        from s1_tokenizer_char import CharacterTokenizer as TokenizerClass
    tokenizer = TokenizerClass()
    tokenizer.load(tokenizer_path)

    return model, tokenizer
def _load_s2(model_dir, config):
    """Build and restore the Season 2 (Llama style) model plus its tokenizer."""
    from s2_model import LlamaModel, ModelConfig

    # Fall back to the 1B-model hyperparameters when a key is absent.
    defaults = {
        "vocab_size": 32005,
        "d_model": 2048,
        "n_layers": 16,
        "n_heads": 32,
        "n_kv_heads": 8,
        "d_ff": 8192,
        "max_seq_len": 2048,
    }
    model_config = ModelConfig(
        **{key: config.get(key, fallback) for key, fallback in defaults.items()}
    )

    model = LlamaModel(model_config).to("cpu")
    state_dict = torch.load(
        os.path.join(model_dir, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()

    # Season 2 ships a HuggingFace `tokenizers`-format tokenizer file.
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file("tokenizer/bpe_tokenizer.json")

    return model, tokenizer
# ── Generation ──────────────────────────────────────────────────────────────

def generate_s1(model, tokenizer, config, prompt, max_tokens, temperature, top_k):
    """Autoregressively sample text from an S1 (GPT-2 style) model.

    Args:
        model: Callable returning logits of shape (batch, seq, vocab).
        tokenizer: Object with encode(str) -> list[int] and
            decode(list[int]) -> str.
        config: Model config dict; "max_seq_len" bounds the context window.
        prompt: Seed text to continue.
        max_tokens: Number of new tokens to sample.
        temperature: Softmax temperature; clamped to a small positive value
            so a zero/negative input cannot divide by zero or flip logits.
        top_k: Keep only the k highest-probability tokens (0 disables).

    Returns:
        Decoded prompt plus continuation, or an error message if the
        prompt could not be encoded.
    """
    tokens = tokenizer.encode(prompt)
    if not tokens:
        return "Could not encode prompt."
    # Guard: temperature <= 0 would otherwise divide by zero below.
    temperature = max(float(temperature), 1e-6)
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
    max_seq_len = config.get("max_seq_len", 256)

    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop the context to the model's trained window.
            inp = tokens[:, -max_seq_len:] if tokens.size(1) > max_seq_len else tokens
            logits = model(inp)[:, -1, :] / temperature
            if top_k > 0:
                # Mask out everything below the k-th best logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token], dim=1)

    return tokenizer.decode(tokens[0].tolist())
def generate_s2(model, tokenizer, prompt, max_tokens, temperature, top_k):
    """Generate a continuation using the S2 (Llama) model's built-in sampler."""
    ids = tokenizer.encode(prompt).ids
    input_ids = torch.tensor([ids], dtype=torch.long)

    # The S2 model samples internally; top_k == 0 means "disabled" (None).
    effective_top_k = top_k if top_k > 0 else None
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=effective_top_k,
        )

    return tokenizer.decode(output_ids[0].tolist())
def generate_text(model_name, prompt, max_tokens, temperature, top_k):
    """Dispatch a generation request to the right backend for the model."""
    if not prompt.strip():
        return "Please enter a prompt."

    try:
        model, tokenizer, config = load_model(model_name)
    except FileNotFoundError as e:
        return f"Error: {e}"

    # Sliders deliver floats; the samplers expect integer counts.
    max_tokens, top_k = int(max_tokens), int(top_k)
    if MODELS[model_name]["arch"] == "s1":
        return generate_s1(model, tokenizer, config, prompt, max_tokens, temperature, top_k)
    return generate_s2(model, tokenizer, prompt, max_tokens, temperature, top_k)
def get_status(model_name):
    """Build the markdown status line shown under the model dropdown."""
    description = MODELS[model_name]["description"]
    if current["name"] == model_name:
        loaded = "Loaded"
    else:
        loaded = "Not loaded (will load on generate)"
    return f"**{model_name}** — {description}\n\nStatus: {loaded}"
def update_examples(model_name):
    """Swap the example-prompt dataset to match the newly selected model."""
    wrapped = [[prompt_text] for prompt_text in MODELS[model_name]["examples"]]
    return gr.update(samples=wrapped)
# ── Custom CSS ──────────────────────────────────────────────────────────────

# Injected via gr.Blocks(css=...). Styles the centered page container, the
# monospace cyan header (with amber links), and the bordered model-info panel
# used by the status markdown.
CUSTOM_CSS = """
.gradio-container {
    max-width: 900px !important;
    margin: auto;
}
.header-text {
    text-align: center;
    margin-bottom: 0.5em;
}
.header-text h1 {
    color: #22d3ee;
    font-family: 'Courier New', monospace;
}
.header-text a {
    color: #f59e0b;
}
.model-info {
    font-family: 'Courier New', monospace;
    font-size: 0.85em;
    padding: 10px;
    border-radius: 8px;
    background: rgba(34, 211, 238, 0.05);
    border: 1px solid rgba(34, 211, 238, 0.15);
}
"""
# ── Gradio UI ───────────────────────────────────────────────────────────────

with gr.Blocks(
    title="GPUburnout Models",
    theme=gr.themes.Base(
        primary_hue="cyan",
        neutral_hue="gray",
        font=gr.themes.GoogleFont("JetBrains Mono"),
    ),
    css=CUSTOM_CSS,
) as demo:

    # Page header with external links (raw HTML styled by CUSTOM_CSS).
    gr.HTML("""
    <div class="header-text">
        <h1>GPUburnout Models</h1>
        <p>Compare language models I trained from scratch — from 3.2M to 1 billion parameters.</p>
        <p>
            <a href="https://gpuburnout.com" target="_blank">Read the blog</a> ·
            <a href="https://github.com/GPUburnout" target="_blank">GitHub</a> ·
            <a href="https://gpuburnout.com/about/" target="_blank">About</a>
        </p>
    </div>
    """)

    with gr.Row():
        # Left column: model selection, prompt, and sampling controls.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="Llama 1B (1.04B)",
                label="Select Model",
            )

            # Status panel filled by get_status() on load and on change.
            model_status = gr.Markdown(elem_classes=["model-info"])

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Type something...",
                lines=2,
                value="The capital of France is",
            )

            with gr.Row():
                max_tokens = gr.Slider(50, 300, value=100, step=25, label="Max tokens")
                temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")

            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-K")

            generate_btn = gr.Button("Generate", variant="primary", size="lg")

        # Right column: generated text plus clickable example prompts.
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=15, show_copy_button=True)

            # Initial examples; replaced per-model by update_examples().
            examples = gr.Examples(
                examples=[["The capital of France is"], ["def fibonacci(n):"]],
                inputs=prompt,
                label="Example prompts",
            )

    # Events
    # Populate the status panel on page load and whenever the model changes;
    # also refresh the example prompts on model change.
    demo.load(get_status, inputs=model_selector, outputs=model_status)
    model_selector.change(get_status, inputs=model_selector, outputs=model_status)
    model_selector.change(update_examples, inputs=model_selector, outputs=examples.dataset)

    # Both the button and Enter in the prompt box trigger generation.
    generate_btn.click(
        generate_text,
        inputs=[model_selector, prompt, max_tokens, temperature, top_k],
        outputs=output,
    )
    prompt.submit(
        generate_text,
        inputs=[model_selector, prompt, max_tokens, temperature, top_k],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()
checkpoints/gpt2_small/NOTE.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ This model is from checkpoint_epoch_7 (not the final model).
2
+
3
+ Training was still in progress - this represents ~70% through training.
4
+ Final model would be checkpoint_epoch_10.
checkpoints/gpt2_small/config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "embed_dim": 768,
4
+ "num_heads": 12,
5
+ "num_layers": 12,
6
+ "ff_dim": 3072,
7
+ "max_seq_len": 512,
8
+ "dropout": 0.1,
9
+ "model_type": "TransformerLanguageModel",
10
+ "architecture": "gpt2_small",
11
+ "total_parameters": 134601216,
12
+ "tokenizer_type": "bpe"
13
+ }
checkpoints/gpt2_small/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/llama_1b/config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_size": "1B",
3
+ "vocab_size": 32005,
4
+ "d_model": 2048,
5
+ "n_layers": 16,
6
+ "n_heads": 32,
7
+ "n_kv_heads": 8,
8
+ "d_ff": 8192,
9
+ "max_seq_len": 2048,
10
+ "total_parameters": 1040000000,
11
+ "tokenizer_type": "bpe"
12
+ }
checkpoints/llama_1b/metadata.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "step": 90000,
3
+ "loss": 2.494209110736847,
4
+ "tokens_processed": 11796480000,
5
+ "best_val_loss": 2.539955945014954,
6
+ "phase_complete": true,
7
+ "source_file": "milestone_step_00090000.pt",
8
+ "export_date": "2026-03-03 02:24:18",
9
+ "model_weights_file": "pytorch_model.bin",
10
+ "model_weights_gb": 4.15,
11
+ "optimizer_state_file": "optimizer_state.bin",
12
+ "optimizer_state_gb": 8.31,
13
+ "original_checkpoint_gb": 12.46
14
+ }
checkpoints/tiny/config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 65,
3
+ "embed_dim": 256,
4
+ "num_heads": 4,
5
+ "num_layers": 4,
6
+ "ff_dim": 1024,
7
+ "max_seq_len": 256,
8
+ "dropout": 0.1,
9
+ "total_parameters": 3258368,
10
+ "tokenizer_type": "character",
11
+ "model_name": "tiny_shakespeare",
12
+ "description": "Phase 1 model trained on Shakespeare text (character-level)"
13
+ }
checkpoints/tiny/tokenizer.json ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "character",
3
+ "vocab_size": 65,
4
+ "char_to_idx": {
5
+ "\n": 0,
6
+ " ": 1,
7
+ "!": 2,
8
+ "$": 3,
9
+ "&": 4,
10
+ "'": 5,
11
+ ",": 6,
12
+ "-": 7,
13
+ ".": 8,
14
+ "3": 9,
15
+ ":": 10,
16
+ ";": 11,
17
+ "?": 12,
18
+ "A": 13,
19
+ "B": 14,
20
+ "C": 15,
21
+ "D": 16,
22
+ "E": 17,
23
+ "F": 18,
24
+ "G": 19,
25
+ "H": 20,
26
+ "I": 21,
27
+ "J": 22,
28
+ "K": 23,
29
+ "L": 24,
30
+ "M": 25,
31
+ "N": 26,
32
+ "O": 27,
33
+ "P": 28,
34
+ "Q": 29,
35
+ "R": 30,
36
+ "S": 31,
37
+ "T": 32,
38
+ "U": 33,
39
+ "V": 34,
40
+ "W": 35,
41
+ "X": 36,
42
+ "Y": 37,
43
+ "Z": 38,
44
+ "a": 39,
45
+ "b": 40,
46
+ "c": 41,
47
+ "d": 42,
48
+ "e": 43,
49
+ "f": 44,
50
+ "g": 45,
51
+ "h": 46,
52
+ "i": 47,
53
+ "j": 48,
54
+ "k": 49,
55
+ "l": 50,
56
+ "m": 51,
57
+ "n": 52,
58
+ "o": 53,
59
+ "p": 54,
60
+ "q": 55,
61
+ "r": 56,
62
+ "s": 57,
63
+ "t": 58,
64
+ "u": 59,
65
+ "v": 60,
66
+ "w": 61,
67
+ "x": 62,
68
+ "y": 63,
69
+ "z": 64
70
+ },
71
+ "idx_to_char": {
72
+ "0": "\n",
73
+ "1": " ",
74
+ "2": "!",
75
+ "3": "$",
76
+ "4": "&",
77
+ "5": "'",
78
+ "6": ",",
79
+ "7": "-",
80
+ "8": ".",
81
+ "9": "3",
82
+ "10": ":",
83
+ "11": ";",
84
+ "12": "?",
85
+ "13": "A",
86
+ "14": "B",
87
+ "15": "C",
88
+ "16": "D",
89
+ "17": "E",
90
+ "18": "F",
91
+ "19": "G",
92
+ "20": "H",
93
+ "21": "I",
94
+ "22": "J",
95
+ "23": "K",
96
+ "24": "L",
97
+ "25": "M",
98
+ "26": "N",
99
+ "27": "O",
100
+ "28": "P",
101
+ "29": "Q",
102
+ "30": "R",
103
+ "31": "S",
104
+ "32": "T",
105
+ "33": "U",
106
+ "34": "V",
107
+ "35": "W",
108
+ "36": "X",
109
+ "37": "Y",
110
+ "38": "Z",
111
+ "39": "a",
112
+ "40": "b",
113
+ "41": "c",
114
+ "42": "d",
115
+ "43": "e",
116
+ "44": "f",
117
+ "45": "g",
118
+ "46": "h",
119
+ "47": "i",
120
+ "48": "j",
121
+ "49": "k",
122
+ "50": "l",
123
+ "51": "m",
124
+ "52": "n",
125
+ "53": "o",
126
+ "54": "p",
127
+ "55": "q",
128
+ "56": "r",
129
+ "57": "s",
130
+ "58": "t",
131
+ "59": "u",
132
+ "60": "v",
133
+ "61": "w",
134
+ "62": "x",
135
+ "63": "y",
136
+ "64": "z"
137
+ }
138
+ }
models/__init__.py ADDED
File without changes
models/s1_model.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transformer Language Model Architecture
3
+ Modern architecture (GPT-style) scalable from tiny to large
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import json
10
+ import os
11
+ import math
12
+
13
+
14
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention mechanism with Flash Attention support"""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_p = dropout

        # Q, K, V projections fused into a single linear layer
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Check if Flash Attention is available (PyTorch 2.0+)
        self.use_flash = hasattr(F, 'scaled_dot_product_attention')

        # Fallback dropout for non-flash path
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # x: (batch, seq, embed_dim). NOTE: the flash path below always applies
        # a causal mask (is_causal=True) and ignores `mask`; `mask` is only
        # honored on the manual fallback path.
        batch_size, seq_len, embed_dim = x.shape

        # Compute Q, K, V in one matmul, then split per head
        qkv = self.qkv(x)  # (batch, seq, 3*embed_dim)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.use_flash:
            # Use PyTorch's scaled_dot_product_attention (Flash Attention when available)
            # This is 1.5-2x faster and more memory efficient
            # Dropout only during training, matching the fallback path.
            dropout_p = self.dropout_p if self.training else 0.0
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,  # We use is_causal instead
                dropout_p=dropout_p,
                is_causal=True  # Causal mask for autoregressive generation
            )
        else:
            # Fallback to manual attention for older PyTorch versions
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # Apply causal mask (for autoregressive generation)
            if mask is not None:
                scores = scores.masked_fill(mask == 0, float('-inf'))

            # Attention weights
            attn = F.softmax(scores, dim=-1)
            attn = self.dropout(attn)

            # Apply attention to values
            out = torch.matmul(attn, v)

        # Reshape: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, embed_dim)

        # Output projection
        out = self.out_proj(out)
        return out
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear → GELU → Dropout → Linear."""

    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Expand to ff_dim, apply the non-linearity, then project back.
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.fc2(hidden)
class TransformerBlock(nn.Module):
    """One pre-norm Transformer layer: self-attention then feed-forward."""

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.feed_forward = FeedForward(embed_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-norm residual: x + Dropout(Attention(LN(x)))
        x = x + self.dropout(self.attention(self.norm1(x), mask))
        # Pre-norm residual: x + Dropout(FeedForward(LN(x)))
        return x + self.dropout(self.feed_forward(self.norm2(x)))
class TransformerLanguageModel(nn.Module):
    """
    GPT-style Transformer Language Model
    Scalable from tiny (CPU) to large (GPU cluster)
    """

    def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=4,
                 ff_dim=None, max_seq_len=256, dropout=0.1):
        """
        Initialize Transformer model

        Args:
            vocab_size: Number of tokens in vocabulary
            embed_dim: Embedding dimension (must be divisible by num_heads)
            num_heads: Number of attention heads
            num_layers: Number of Transformer blocks
            ff_dim: Feed-forward dimension (default: 4 * embed_dim)
            max_seq_len: Maximum sequence length
            dropout: Dropout probability
        """
        super().__init__()

        if ff_dim is None:
            ff_dim = 4 * embed_dim

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.max_seq_len = max_seq_len
        self.dropout = dropout  # stored as a float for get_config()

        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)

        # Positional embeddings (learned)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(embed_dim)

        # Output projection (independent of token_embedding — no weight tying)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Dropout applied to the combined embeddings in forward()
        self.dropout_layer = nn.Dropout(dropout)

        # Initialize weights
        self._init_weights()

        # Create causal mask. NOTE(review): register_buffer is persistent by
        # default, so the mask is stored in state_dict/checkpoints.
        self.register_buffer("causal_mask", self._create_causal_mask(max_seq_len))

    def _init_weights(self):
        """Initialize weights (normal(0, 0.02) for linear/embedding, zero bias)"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _create_causal_mask(self, seq_len):
        """Create causal mask for autoregressive generation"""
        # Lower-triangular ones: position i may attend to positions <= i.
        mask = torch.tril(torch.ones(seq_len, seq_len))
        mask = mask.view(1, 1, seq_len, seq_len)
        return mask

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input tensor of shape (batch_size, seq_len)

        Returns:
            logits: Output logits of shape (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = x.shape
        device = x.device

        # Token embeddings
        token_emb = self.token_embedding(x)  # (batch, seq_len, embed_dim)

        # Positional embeddings
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        pos_emb = self.positional_embedding(positions)  # (1, seq_len, embed_dim)

        # Combine embeddings (pos_emb broadcasts over the batch dimension)
        x = self.dropout_layer(token_emb + pos_emb)

        # Get causal mask for this sequence length (only used by the
        # non-flash attention fallback inside the blocks)
        mask = self.causal_mask[:, :, :seq_len, :seq_len]

        # Apply Transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final layer norm
        x = self.ln_f(x)

        # Output logits
        logits = self.head(x)  # (batch, seq_len, vocab_size)

        return logits

    def count_parameters(self):
        """Count trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_config(self):
        """Get model configuration"""
        return {
            'model_type': 'Transformer',
            'architecture': 'GPT-style (decoder-only)',
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'num_layers': self.num_layers,
            'ff_dim': self.ff_dim,
            'max_seq_len': self.max_seq_len,
            'dropout': self.dropout,
            'total_parameters': self.count_parameters()
        }

    def save_config(self, filepath='models/model_config.json'):
        """Save model configuration as JSON; returns the path written."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        config = self.get_config()
        with open(filepath, 'w') as f:
            json.dump(config, f, indent=2)

        print(f"Model config saved to: {filepath}")
        return filepath
def create_tiny_transformer(vocab_size):
    """Create a tiny Transformer (fastest on CPU)"""
    preset = dict(embed_dim=128, num_heads=4, num_layers=2,
                  max_seq_len=128, dropout=0.1)
    return TransformerLanguageModel(vocab_size=vocab_size, **preset)
def create_small_transformer(vocab_size):
    """Create a small Transformer (recommended for first run)"""
    preset = dict(embed_dim=256, num_heads=4, num_layers=4,
                  max_seq_len=256, dropout=0.1)
    return TransformerLanguageModel(vocab_size=vocab_size, **preset)
def create_medium_transformer(vocab_size):
    """Create a medium Transformer (GPU recommended)"""
    preset = dict(embed_dim=512, num_heads=8, num_layers=6,
                  max_seq_len=512, dropout=0.1)
    return TransformerLanguageModel(vocab_size=vocab_size, **preset)
def create_large_transformer(vocab_size):
    """Create a large Transformer (GPU cluster)"""
    preset = dict(embed_dim=1024, num_heads=16, num_layers=12,
                  max_seq_len=1024, dropout=0.1)
    return TransformerLanguageModel(vocab_size=vocab_size, **preset)
def main():
    """Test model creation"""
    # Self-test / demo entry point: builds the preset model sizes, prints
    # their stats, smoke-tests a forward pass, and saves the config.
    print("\n" + "="*80)
    print("TRANSFORMER MODEL ARCHITECTURE")
    print("="*80)

    # Load tokenizer to get vocab size
    tokenizer_path = 'models/tokenizer.json'
    if not os.path.exists(tokenizer_path):
        print(f"\nError: Tokenizer not found at {tokenizer_path}")
        print("Please run tokenizer.py first.")
        return

    with open(tokenizer_path, 'r') as f:
        tokenizer_data = json.load(f)
    vocab_size = tokenizer_data['vocab_size']

    print(f"\nVocabulary size: {vocab_size}")
    print("Architecture: GPT-style Transformer (decoder-only)")

    # Create models of different sizes
    print("\n" + "-"*80)
    print("TINY TRANSFORMER (fastest on CPU)")
    print("-"*80)
    tiny_model = create_tiny_transformer(vocab_size)
    print(f"Parameters: {tiny_model.count_parameters():,}")
    print(f"Embed dim: {tiny_model.embed_dim}")
    print(f"Attention heads: {tiny_model.num_heads}")
    print(f"Layers: {tiny_model.num_layers}")
    print(f"Context length: {tiny_model.max_seq_len}")

    print("\n" + "-"*80)
    print("SMALL TRANSFORMER (recommended for first run)")
    print("-"*80)
    small_model = create_small_transformer(vocab_size)
    print(f"Parameters: {small_model.count_parameters():,}")
    print(f"Embed dim: {small_model.embed_dim}")
    print(f"Attention heads: {small_model.num_heads}")
    print(f"Layers: {small_model.num_layers}")
    print(f"Context length: {small_model.max_seq_len}")

    print("\n" + "-"*80)
    print("MEDIUM TRANSFORMER (GPU recommended)")
    print("-"*80)
    medium_model = create_medium_transformer(vocab_size)
    print(f"Parameters: {medium_model.count_parameters():,}")
    print(f"Embed dim: {medium_model.embed_dim}")
    print(f"Attention heads: {medium_model.num_heads}")
    print(f"Layers: {medium_model.num_layers}")
    print(f"Context length: {medium_model.max_seq_len}")

    # Use small model for our tiny LM
    print("\n" + "="*80)
    print("SELECTED MODEL: SMALL TRANSFORMER")
    print("="*80)
    print("Good balance for CPU training with modern architecture")
    model = small_model

    # Test forward pass with random token ids to verify the output shape
    print("\nTesting forward pass...")
    batch_size = 4
    seq_len = 32
    dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len))

    with torch.no_grad():
        logits = model(dummy_input)

    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {logits.shape}")
    print(f"Expected: (batch={batch_size}, seq_len={seq_len}, vocab={vocab_size})")
    assert logits.shape == (batch_size, seq_len, vocab_size), "Shape mismatch!"
    print("Forward pass test passed!")

    # Save configuration
    model.save_config()

    print("\n" + "="*80)
    print("MODEL CREATION COMPLETE")
    print("="*80)
    print(f"\nModel ready for training!")
    print(f"Architecture: {model.get_config()['model_type']}")
    print(f"Total parameters: {model.count_parameters():,}")
    print(f"Configuration saved to: models/model_config.json")
    print(f"\nNext step: Implement the training loop")
    print("="*80 + "\n")


if __name__ == "__main__":
    main()
models/s1_tokenizer_bpe.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BPE Tokenizer Wrapper
3
+ =====================
4
+ Wraps HuggingFace `tokenizers` library to provide the same interface
5
+ as CharacterTokenizer. Uses byte-level BPE (GPT-2 style).
6
+
7
+ Requires: pip install tokenizers
8
+ """
9
+
10
+ import json
11
+ import os
12
+
13
+ from tokenizers import Tokenizer
14
+ from tokenizers.models import BPE
15
+ from tokenizers.trainers import BpeTrainer
16
+ from tokenizers.pre_tokenizers import ByteLevel
17
+ from tokenizers.decoders import ByteLevel as ByteLevelDecoder
18
+
19
+
20
class BPETokenizer:
    """Byte-level BPE tokenizer compatible with CharacterTokenizer interface."""

    def __init__(self):
        # Wrapped HuggingFace tokenizer; stays None until trained or loaded.
        self.tokenizer = None
        self._vocab_size = 0

    def build_vocab_from_file(self, filepath, vocab_size=32000,
                              min_frequency=2, chunk_size=None):
        """Train BPE tokenizer on a text file.

        Args:
            filepath: Path to text file
            vocab_size: Target vocabulary size (default: 32000)
            min_frequency: Minimum token frequency (default: 2)
            chunk_size: Unused, kept for interface compatibility

        Returns:
            The final vocabulary size (may differ from the target).
        """
        # Byte-level BPE in the GPT-2 style: pre-tokenize into bytes,
        # decode back through the matching byte-level decoder.
        bpe = Tokenizer(BPE(unk_token="<|unk|>"))
        bpe.pre_tokenizer = ByteLevel(add_prefix_space=False)
        bpe.decoder = ByteLevelDecoder()

        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            special_tokens=["<|endoftext|>", "<|pad|>", "<|unk|>"],
            show_progress=True
        )

        file_size = os.path.getsize(filepath)
        print(f"\nTraining BPE tokenizer on: {filepath}")
        print(f"File size: {file_size / (1024**3):.2f} GB")
        print(f"Target vocab size: {vocab_size:,}")
        print(f"Min frequency: {min_frequency}")

        bpe.train(files=[filepath], trainer=trainer)

        self.tokenizer = bpe
        self._vocab_size = bpe.get_vocab_size()

        print(f"\nBPE vocabulary built: {self._vocab_size:,} tokens")
        # Preview the first 20 entries in token-id order.
        by_id = sorted(bpe.get_vocab().items(), key=lambda kv: kv[1])
        sample_str = ', '.join(f"'{tok}'" for tok, _ in by_id[:20])
        print(f"Sample tokens: {sample_str}")

        return self._vocab_size

    def encode(self, text):
        """Encode text to list of token IDs.

        Args:
            text: Input string

        Returns:
            List of integer token IDs

        Raises:
            ValueError: If the tokenizer has not been trained or loaded.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized. "
                             "Call build_vocab_from_file() or load() first.")
        return self.tokenizer.encode(text).ids

    def decode(self, tokens):
        """Decode token IDs back to text.

        Args:
            tokens: List of integer token IDs

        Returns:
            Decoded string

        Raises:
            ValueError: If the tokenizer has not been trained or loaded.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized.")
        return self.tokenizer.decode(tokens)

    @property
    def vocab_size(self):
        """Number of tokens in vocabulary."""
        return self._vocab_size

    def save(self, filepath):
        """Save tokenizer to a JSON file.

        Args:
            filepath: Path to save tokenizer (e.g. 'bpe_tokenizer.json')

        Raises:
            ValueError: If the tokenizer has not been trained or loaded.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized.")

        self.tokenizer.save(filepath)
        print(f"\nBPE tokenizer saved to: {filepath}")

    def load(self, filepath):
        """Load tokenizer from a JSON file.

        Args:
            filepath: Path to tokenizer JSON file

        Returns:
            self, to allow ``BPETokenizer().load(path)`` chaining.

        Raises:
            FileNotFoundError: If *filepath* does not exist.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Tokenizer file not found: {filepath}")

        self.tokenizer = Tokenizer.from_file(filepath)
        self._vocab_size = self.tokenizer.get_vocab_size()

        print(f"BPE tokenizer loaded: {self._vocab_size:,} tokens")
        return self
models/s1_tokenizer_char.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tokenizer for Language Model
3
+ Converts text to numbers (tokens) and back
4
+ """
5
+
6
+ import json
7
+ import os
8
+
9
+
10
class CharacterTokenizer:
    """Simple character-level tokenizer for tiny language models.

    Maps each unique character to an integer token ID and back. Unknown
    characters are silently dropped by encode()/decode() (intentional:
    keeps generation robust rather than raising on stray input).
    """

    def __init__(self):
        """Initialize an empty tokenizer (build or load a vocab before use)."""
        self.char_to_idx = {}   # char -> token id
        self.idx_to_char = {}   # token id -> char
        self.vocab_size = 0

    def build_vocab(self, text):
        """Build vocabulary from an in-memory string.

        Returns:
            The vocabulary size (number of unique characters).
        """
        print("\nBuilding character vocabulary...")

        # Get unique characters and sort them for a deterministic mapping
        chars = sorted(set(text))
        self.vocab_size = len(chars)

        # Create mappings
        self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(chars)}

        print(f"Vocabulary size: {self.vocab_size} characters")
        print(f"Characters: {''.join(chars[:50])}" + ("..." if len(chars) > 50 else ""))

        return self.vocab_size

    def build_vocab_from_file(self, filepath, chunk_size=100*1024*1024):
        """Build vocabulary from a large file using streaming (memory-efficient).

        Args:
            filepath: Path to text file
            chunk_size: Size of chunks to read (default: 100MB)

        Returns:
            The vocabulary size (number of unique characters).
        """
        print(f"\nBuilding character vocabulary from file: {filepath}")
        print(f"Chunk size: {chunk_size / (1024*1024):.0f}MB")

        # Get file size
        file_size = os.path.getsize(filepath)
        file_size_gb = file_size / (1024**3)
        print(f"File size: {file_size_gb:.2f} GB")

        # Collect unique characters by reading file in chunks
        unique_chars = set()
        total_read = 0

        with open(filepath, 'r', encoding='utf-8') as f:
            while chunk := f.read(chunk_size):
                # Add unique characters from this chunk
                unique_chars.update(chunk)
                total_read += len(chunk)

                # Progress update. NOTE: total_read counts characters while
                # file_size is bytes; the /1.5 factor is a rough bytes-per-char
                # heuristic, so the percentage is approximate by design.
                progress_pct = (total_read / (file_size / 1.5)) * 100
                if progress_pct <= 100:
                    print(f"  Progress: {progress_pct:.1f}% | Unique chars found: {len(unique_chars)}", end='\r')

        print()  # New line after progress

        # Sort characters and build mappings (sorted => deterministic ids)
        chars = sorted(unique_chars)
        self.vocab_size = len(chars)

        self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(chars)}

        print(f"\nVocabulary size: {self.vocab_size} characters")
        print(f"Sample characters: {''.join(chars[:50])}" + ("..." if len(chars) > 50 else ""))

        return self.vocab_size

    def encode(self, text):
        """Convert text to a list of token IDs (unknown chars are dropped)."""
        return [self.char_to_idx[ch] for ch in text if ch in self.char_to_idx]

    def decode(self, tokens):
        """Convert a list of token IDs back to text (unknown IDs are dropped)."""
        return ''.join([self.idx_to_char[idx] for idx in tokens if idx in self.idx_to_char])

    def save(self, filepath='models/tokenizer.json'):
        """Save tokenizer to a JSON file, creating parent directories as needed.

        Returns:
            The filepath the tokenizer was written to.
        """
        # BUGFIX: os.path.dirname() returns '' for a bare filename like
        # 'tokenizer.json', and os.makedirs('') raises FileNotFoundError.
        # Only create the directory when there actually is one.
        dirname = os.path.dirname(filepath)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        tokenizer_data = {
            'type': 'character',
            'vocab_size': self.vocab_size,
            'char_to_idx': self.char_to_idx,
            # JSON object keys must be strings, so int ids are stringified
            # here and converted back in load().
            'idx_to_char': {str(k): v for k, v in self.idx_to_char.items()}
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(tokenizer_data, f, indent=2, ensure_ascii=False)

        print(f"\nTokenizer saved to: {filepath}")
        return filepath

    def load(self, filepath='models/tokenizer.json'):
        """Load tokenizer from a JSON file.

        Returns:
            self, to allow ``CharacterTokenizer().load(path)`` chaining.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            tokenizer_data = json.load(f)

        self.vocab_size = tokenizer_data['vocab_size']
        self.char_to_idx = tokenizer_data['char_to_idx']
        # Restore integer keys that were stringified for JSON
        self.idx_to_char = {int(k): v for k, v in tokenizer_data['idx_to_char'].items()}

        print(f"\nTokenizer loaded from: {filepath}")
        print(f"Vocabulary size: {self.vocab_size}")
        return self

    def get_stats(self):
        """Print tokenizer statistics."""
        print("\n" + "="*80)
        print("TOKENIZER STATISTICS")
        print("="*80)
        print(f"Type: Character-level")
        print(f"Vocabulary size: {self.vocab_size}")
        print(f"Sample characters: {list(self.char_to_idx.keys())[:20]}")
        print("="*80)
+
133
+
134
def main():
    """Build a character tokenizer from the Shakespeare dataset, run
    round-trip sanity checks, and save it to models/tokenizer.json."""
    print("\n" + "="*80)
    print("TOKENIZER BUILDER")
    print("="*80)

    # Load dataset; bail out with guidance rather than raising if it is missing
    dataset_file = 'data/tiny_shakespeare.txt'
    if not os.path.exists(dataset_file):
        print(f"\nError: Dataset not found at {dataset_file}")
        print("Please run dataset_loader.py first.")
        return

    print(f"\nLoading text from: {dataset_file}")
    with open(dataset_file, 'r', encoding='utf-8') as f:
        text = f.read()

    print(f"Loaded {len(text):,} characters")

    # Build tokenizer over the full corpus (character-level, so this is cheap)
    tokenizer = CharacterTokenizer()
    tokenizer.build_vocab(text)

    # Test tokenizer
    print("\n" + "="*80)
    print("TESTING TOKENIZER")
    print("="*80)

    test_text = "Hello, World!"
    print(f"\nOriginal text: {test_text}")

    encoded = tokenizer.encode(test_text)
    print(f"Encoded: {encoded}")

    decoded = tokenizer.decode(encoded)
    print(f"Decoded: {decoded}")

    # NOTE: this round-trip only holds if every char of test_text appears in
    # the corpus; encode() silently drops out-of-vocabulary characters.
    if test_text == decoded:
        print("Test passed!")
    else:
        print("Test failed!")

    # Test with Shakespeare sample — drawn from the corpus itself, so the
    # round-trip must be exact; a mismatch indicates a tokenizer bug
    shakespeare_sample = text[:100]
    print(f"\nShakespeare sample: {shakespeare_sample}")
    encoded_sample = tokenizer.encode(shakespeare_sample)
    print(f"Encoded (first 20 tokens): {encoded_sample[:20]}")
    decoded_sample = tokenizer.decode(encoded_sample)
    assert shakespeare_sample == decoded_sample, "Encoding/decoding mismatch!"
    print("Shakespeare encoding test passed!")

    # Show statistics
    tokenizer.get_stats()

    # Save tokenizer (default path: models/tokenizer.json)
    tokenizer.save()

    print("\n" + "="*80)
    print("TOKENIZER BUILD COMPLETE")
    print("="*80)
    print(f"\nTokenizer ready for model training!")
    print(f"Vocabulary size: {tokenizer.vocab_size}")
    print(f"Saved to: models/tokenizer.json")
    print(f"\nNext step: Build the model architecture")
    print("="*80 + "\n")


if __name__ == "__main__":
    main()
models/s2_model.py ADDED
@@ -0,0 +1,785 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Llama-Style Transformer Model
3
+ =============================
4
+ Modern transformer architecture with all Tier 1 and Tier 2 optimizations:
5
+
6
+ Architecture (Tier 1):
7
+ - RMSNorm (faster than LayerNorm, no mean calculation)
8
+ - RoPE (Rotary Position Embedding, better length generalization)
9
+ - SwiGLU activation (gated FFN, consistently outperforms GELU)
10
+ - Pre-norm (apply norm before attention/FFN, more stable training)
11
+
12
+ Optimizations (Tier 2):
13
+ - GQA (Grouped Query Attention, fewer KV heads = faster + less memory)
14
+ - Weight tying (share embedding and output projection)
15
+ - Flash Attention via F.scaled_dot_product_attention
16
+ - Gradient checkpointing support (trade compute for memory)
17
+
18
+ Compatible with:
19
+ - liger-kernel (fused RMSNorm, SwiGLU, RoPE, cross-entropy)
20
+ - bf16/fp16 mixed precision training
21
+ - torch.compile for additional speedups
22
+
23
+ Model Sizes:
24
+ - tiny: ~15M params (for testing)
25
+ - small: ~125M params
26
+ - medium: ~350M params
27
+ - large: ~760M params
28
+ - 1B: ~1.1B params (Llama 3.2 1B style)
29
+ """
30
+
31
+ import math
32
+ from dataclasses import dataclass
33
+ from typing import Optional, Tuple
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+
40
+ # ============================================================================
41
+ # Model Configuration
42
+ # ============================================================================
43
+
44
@dataclass
class ModelConfig:
    """Configuration for Llama-style transformer model.

    Derived fields set by __post_init__: d_ff (when left as None),
    n_kv_groups, and head_dim — do not set those two directly.
    """

    # Model architecture
    vocab_size: int = 32000
    d_model: int = 2048          # Hidden dimension
    n_layers: int = 16           # Number of transformer blocks
    n_heads: int = 32            # Number of attention heads
    n_kv_heads: int = 8          # Number of KV heads (for GQA)
    d_ff: Optional[int] = None   # FFN intermediate dim (default: 8/3 * d_model)

    # Sequence
    max_seq_len: int = 2048      # Maximum sequence length

    # RoPE
    rope_theta: float = 500000.0  # RoPE base frequency

    # Regularization
    dropout: float = 0.0          # Dropout (0 for pretraining)

    # Options
    tie_weights: bool = True      # Tie embedding and output weights
    use_flash_attn: bool = True   # Use Flash Attention (SDPA)

    def __post_init__(self):
        # SwiGLU uses 8/3 * d_model for FFN, rounded UP to a multiple of 256
        # (multiples of 256 keep matmul shapes hardware-friendly)
        if self.d_ff is None:
            self.d_ff = int(8 / 3 * self.d_model)
            self.d_ff = ((self.d_ff + 255) // 256) * 256

        # Validate GQA configuration: each KV head serves an equal group of
        # query heads, so the head counts must divide evenly
        assert self.n_heads % self.n_kv_heads == 0, \
            f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"

        # Derived values: query heads per KV head, and per-head width
        self.n_kv_groups = self.n_heads // self.n_kv_heads
        self.head_dim = self.d_model // self.n_heads
81
+
82
+
83
+ # Predefined model configurations
84
+ MODEL_CONFIGS = {
85
+ "tiny": ModelConfig(
86
+ d_model=256,
87
+ n_layers=6,
88
+ n_heads=8,
89
+ n_kv_heads=4,
90
+ max_seq_len=1024,
91
+ ),
92
+ "small": ModelConfig(
93
+ d_model=768,
94
+ n_layers=12,
95
+ n_heads=12,
96
+ n_kv_heads=4,
97
+ max_seq_len=2048,
98
+ ),
99
+ "medium": ModelConfig(
100
+ d_model=1024,
101
+ n_layers=16,
102
+ n_heads=16,
103
+ n_kv_heads=4,
104
+ max_seq_len=2048,
105
+ ),
106
+ "large": ModelConfig(
107
+ d_model=1536,
108
+ n_layers=20,
109
+ n_heads=24,
110
+ n_kv_heads=8,
111
+ max_seq_len=2048,
112
+ ),
113
+ "1B": ModelConfig(
114
+ d_model=2048,
115
+ n_layers=16,
116
+ n_heads=32,
117
+ n_kv_heads=8,
118
+ d_ff=8192, # Llama 3.2 1B uses 4x hidden, not 8/3x
119
+ max_seq_len=2048,
120
+ ),
121
+ }
122
+
123
+
124
def get_model_config(size: str, **overrides) -> ModelConfig:
    """Get a predefined model configuration with optional overrides.

    Args:
        size: Preset name — one of the keys of MODEL_CONFIGS.
        **overrides: ModelConfig attribute values to override.

    Returns:
        A fresh ModelConfig instance with overrides applied and derived
        fields (d_ff, n_kv_groups, head_dim) recomputed.

    Raises:
        ValueError: If size is unknown or an override names a
            non-existent config attribute.
    """
    import copy

    if size not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model size: {size}. Choose from: {list(MODEL_CONFIGS.keys())}")

    # BUGFIX: copy the preset instead of mutating it. The original code did
    # setattr() on the shared MODEL_CONFIGS entry, so overrides from one call
    # permanently corrupted the preset for every later caller. A shallow copy
    # is sufficient because all fields are immutable scalars.
    config = copy.copy(MODEL_CONFIGS[size])

    # Apply overrides, rejecting unknown attribute names
    for key, value in overrides.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            raise ValueError(f"Unknown config parameter: {key}")

    # Recompute derived values (n_kv_groups, head_dim; d_ff only if None)
    config.__post_init__()
    return config
141
+
142
+
143
+ # ============================================================================
144
+ # RMSNorm (Tier 1)
145
+ # ============================================================================
146
+
147
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (Llama / Mistral style).

    Scales the input by the reciprocal root-mean-square of its last
    dimension — no mean subtraction, unlike LayerNorm — then applies a
    learned per-channel gain. The normalization is computed in float32
    for stability and cast back to the input dtype.

    Can be swapped for liger_kernel.transformers.LigerRMSNorm to get a
    fused-kernel speedup with identical semantics.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # eps keeps rsqrt finite for all-zero rows
        self.eps = eps
        # Learned gain, initialized to identity
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), via rsqrt for a single fused op
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32, then return to the caller's dtype before scaling
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
169
+
170
+
171
+ # ============================================================================
172
+ # Rotary Position Embedding (RoPE) (Tier 1)
173
+ # ============================================================================
174
+
175
def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the cos and sin tables for rotary position embedding.

    Each even channel index 2i gets angular frequency theta^(-2i/dim);
    position p then rotates by angle p * freq. Values are duplicated
    across each (2i, 2i+1) channel pair so the tables broadcast directly
    against the interleaved rotation in apply_rotary_emb.

    Args:
        dim: Head dimension (d_model // n_heads)
        max_seq_len: Maximum sequence length
        theta: Base frequency (Llama 3 uses 500000)
        device: Target device

    Returns:
        cos, sin tensors of shape (max_seq_len, dim)
    """
    # Per-pair inverse frequencies: theta^(-2i/dim) for i = 0 .. dim/2 - 1
    exponents = torch.arange(0, dim, 2, device=device).float() / dim
    inv_freq = theta ** -exponents

    # Angle table: position p, pair i -> p * inv_freq[i], shape (seq_len, dim/2)
    positions = torch.arange(max_seq_len, device=device)
    angles = torch.outer(positions, inv_freq)

    # Duplicate each pair's angle across both of its channels -> (seq_len, dim)
    cos_table = angles.cos().repeat_interleave(2, dim=-1)
    sin_table = angles.sin().repeat_interleave(2, dim=-1)

    return cos_table, sin_table
207
+
208
+
209
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary position embedding to input tensor.

    Treats each adjacent channel pair (x0, x1) as a 2-D vector and rotates
    it by the position-dependent angle encoded in cos/sin:
    out = x * cos + rotate90(x) * sin, where rotate90 maps
    [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...].

    Args:
        x: Input tensor of shape (batch, n_heads, seq_len, head_dim)
        cos: Cosine frequencies of shape (seq_len, head_dim)
        sin: Sine frequencies of shape (seq_len, head_dim)

    Returns:
        Tensor with rotary embedding applied (same shape as x)
    """
    # Trim tables to the actual sequence length, then add leading
    # singleton dims so they broadcast over (batch, heads)
    seq_len = x.shape[2]
    cos_b = cos[:seq_len][None, None, :, :]
    sin_b = sin[:seq_len][None, None, :, :]

    # 90-degree rotation of each channel pair: interleave (-odd, even)
    even = x[..., 0::2]
    odd = x[..., 1::2]
    rotated = torch.stack((-odd, even), dim=-1).reshape(x.shape)

    return x * cos_b + rotated * sin_b
240
+
241
+
242
+ # ============================================================================
243
+ # Grouped Query Attention (GQA) with Flash Attention (Tier 1 + Tier 2)
244
+ # ============================================================================
245
+
246
class Attention(nn.Module):
    """
    Multi-head attention with Grouped Query Attention (GQA) and Flash Attention.

    GQA uses fewer key-value heads than query heads, reducing memory and
    compute while maintaining quality. For example, with 32 query heads and
    8 KV heads, each KV head is shared by 4 query heads.

    Flash Attention is used via PyTorch's scaled_dot_product_attention,
    which provides O(N) memory complexity instead of O(N^2).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Cached config values used in forward()
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_kv_groups = config.n_kv_groups
        self.head_dim = config.head_dim

        # Query projection: full heads
        self.wq = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)

        # Key and Value projections: fewer heads for GQA
        self.wk = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        self.wv = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)

        # Output projection
        self.wo = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.use_flash_attn = config.use_flash_attn

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply GQA self-attention with RoPE.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            cos, sin: Precomputed RoPE tables of shape (seq_len, head_dim).
            mask: Optional attention mask. When None, causal masking is
                applied. In the manual path the mask is ADDED to the logits
                (additive mask); in the SDPA path it is passed as attn_mask
                — callers should supply a mask compatible with both.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        q = self.wq(x)  # (B, T, n_heads * head_dim)
        k = self.wk(x)  # (B, T, n_kv_heads * head_dim)
        v = self.wv(x)  # (B, T, n_kv_heads * head_dim)

        # Reshape to (B, n_heads, T, head_dim)
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K (positions encoded by rotation, not added)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # Expand KV heads for GQA: (B, n_kv_heads, T, head_dim) -> (B, n_heads, T, head_dim)
        # NOTE: materializes the repeated heads; SDPA's enable_gqa could avoid
        # this copy on newer PyTorch versions.
        if self.n_kv_groups > 1:
            k = k.repeat_interleave(self.n_kv_groups, dim=1)
            v = v.repeat_interleave(self.n_kv_groups, dim=1)

        # Attention
        if self.use_flash_attn:
            # Use PyTorch's optimized SDPA (Flash Attention when available)
            attn_out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=mask is None,  # Use causal mask if no explicit mask
            )
        else:
            # Manual attention (for debugging or when SDPA unavailable)
            scale = 1.0 / math.sqrt(self.head_dim)
            attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale

            if mask is not None:
                attn_weights = attn_weights + mask
            else:
                # Causal mask: -inf above the diagonal blocks future positions
                causal_mask = torch.triu(
                    torch.full((seq_len, seq_len), float('-inf'), device=x.device),
                    diagonal=1
                )
                attn_weights = attn_weights + causal_mask

            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_out = torch.matmul(attn_weights, v)

        # Reshape back: (B, n_heads, T, head_dim) -> (B, T, d_model)
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.wo(attn_out)
340
+
341
+
342
+ # ============================================================================
343
+ # SwiGLU Feed-Forward Network (Tier 1)
344
+ # ============================================================================
345
+
346
class FeedForward(nn.Module):
    """SwiGLU feed-forward network.

    Gated MLP used in Llama-family models: the hidden activation is the
    elementwise product of a SiLU-gated projection and a linear "up"
    projection, followed by a "down" projection back to d_model:

        out = W_down( SiLU(W_gate(x)) * W_up(x) )

    Three bias-free weight matrices replace the usual two of a GELU FFN.
    liger_kernel.transformers.LigerSwiGLUMLP offers a fused drop-in.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        inner = config.d_ff

        # Gate and up projections (candidates for kernel fusion)
        self.w_gate = nn.Linear(config.d_model, inner, bias=False)
        self.w_up = nn.Linear(config.d_model, inner, bias=False)

        # Down projection back to the model width
        self.w_down = nn.Linear(inner, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w_gate(x)) * self.w_up(x)
        return self.dropout(self.w_down(gated))
376
+
377
+
378
+ # ============================================================================
379
+ # Transformer Block (Pre-norm)
380
+ # ============================================================================
381
+
382
class TransformerBlock(nn.Module):
    """
    Single pre-norm transformer block.

    Normalization is applied BEFORE each sub-layer (attention, FFN)
    rather than after, which keeps gradients well-behaved at depth:

        x = x + Attention(RMSNorm(x))
        x = x + FFN(RMSNorm(x))
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-norm layers (attribute order preserved for state_dict layout)
        self.attn_norm = RMSNorm(config.d_model)
        self.ffn_norm = RMSNorm(config.d_model)

        # Sub-layers
        self.attn = Attention(config)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer: normalize, attend, add residual
        attn_out = self.attn(self.attn_norm(x), cos, sin, mask)
        x = x + attn_out

        # FFN sub-layer: normalize, transform, add residual
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
420
+
421
+
422
+ # ============================================================================
423
+ # Complete Llama Model
424
+ # ============================================================================
425
+
426
+ class LlamaModel(nn.Module):
427
+ """
428
+ Complete Llama-style transformer model for language modeling.
429
+
430
+ Features:
431
+ - RMSNorm, RoPE, SwiGLU, GQA (Tier 1)
432
+ - Weight tying, Flash Attention (Tier 2)
433
+ - Gradient checkpointing support
434
+ - Compatible with liger-kernel fused ops
435
+
436
+ Usage:
437
+ config = get_model_config("1B", vocab_size=32000)
438
+ model = LlamaModel(config)
439
+
440
+ # Enable gradient checkpointing for memory savings
441
+ model.gradient_checkpointing_enable()
442
+
443
+ # Forward pass
444
+ logits = model(input_ids)
445
+ loss = model(input_ids, targets=targets)
446
+ """
447
+
448
+ def __init__(self, config: ModelConfig):
449
+ super().__init__()
450
+ self.config = config
451
+
452
+ # Token embedding
453
+ self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
454
+
455
+ # Transformer blocks
456
+ self.layers = nn.ModuleList([
457
+ TransformerBlock(config, layer_idx=i)
458
+ for i in range(config.n_layers)
459
+ ])
460
+
461
+ # Final normalization
462
+ self.norm = RMSNorm(config.d_model)
463
+
464
+ # Output projection (language model head)
465
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
466
+
467
+ # Weight tying: share embedding and output weights
468
+ if config.tie_weights:
469
+ self.lm_head.weight = self.tok_emb.weight
470
+
471
+ # Precompute RoPE frequencies
472
+ self.register_buffer(
473
+ "rope_cos",
474
+ torch.zeros(config.max_seq_len, config.head_dim),
475
+ persistent=False
476
+ )
477
+ self.register_buffer(
478
+ "rope_sin",
479
+ torch.zeros(config.max_seq_len, config.head_dim),
480
+ persistent=False
481
+ )
482
+
483
+ # Gradient checkpointing flag
484
+ self._gradient_checkpointing = False
485
+
486
+ # Initialize weights
487
+ self.apply(self._init_weights)
488
+
489
+ # Apply special initialization for output projection
490
+ self._init_output_weights()
491
+
492
+ def _init_weights(self, module: nn.Module):
493
+ """Initialize weights using Llama-style initialization."""
494
+ if isinstance(module, nn.Linear):
495
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
496
+ if module.bias is not None:
497
+ torch.nn.init.zeros_(module.bias)
498
+ elif isinstance(module, nn.Embedding):
499
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
500
+
501
+ def _init_output_weights(self):
502
+ """Apply scaled initialization to output projections for stability."""
503
+ # Scale down residual projections by 1/sqrt(2*n_layers)
504
+ scale = (2 * self.config.n_layers) ** -0.5
505
+ for layer in self.layers:
506
+ torch.nn.init.normal_(layer.attn.wo.weight, mean=0.0, std=0.02 * scale)
507
+ torch.nn.init.normal_(layer.ffn.w_down.weight, mean=0.0, std=0.02 * scale)
508
+
509
+ def _init_rope(self, device: torch.device):
510
+ """Initialize RoPE frequencies on the correct device."""
511
+ cos, sin = precompute_rope_freqs(
512
+ dim=self.config.head_dim,
513
+ max_seq_len=self.config.max_seq_len,
514
+ theta=self.config.rope_theta,
515
+ device=device,
516
+ )
517
+ self.rope_cos = cos
518
+ self.rope_sin = sin
519
+
520
+ def gradient_checkpointing_enable(self):
521
+ """Enable gradient checkpointing for memory-efficient training."""
522
+ self._gradient_checkpointing = True
523
+
524
+ def gradient_checkpointing_disable(self):
525
+ """Disable gradient checkpointing."""
526
+ self._gradient_checkpointing = False
527
+
528
def forward(
    self,
    input_ids: torch.Tensor,
    targets: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Forward pass.

    Args:
        input_ids: Token IDs of shape (batch_size, seq_len)
        targets: Optional target IDs for loss computation; positions
            holding -100 are excluded from the loss (padding).
        mask: Optional attention mask, passed unchanged to every layer.

    Returns:
        If targets provided: scalar loss
        Otherwise: logits of shape (batch_size, seq_len, vocab_size)
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # Initialize RoPE on first forward pass (ensures correct device).
    # NOTE(review): the `.sum() == 0` probe re-builds the tables whenever the
    # buffer is still all-zero, and forces a host-device sync on every step
    # when running on GPU — presumably cheap relative to the model, but verify.
    if self.rope_cos.device != device or self.rope_cos.sum() == 0:
        self._init_rope(device)

    # Token embeddings: (batch, seq) -> (batch, seq, d_model)
    x = self.tok_emb(input_ids)

    # Slice the precomputed RoPE tables to this sequence length.
    cos = self.rope_cos[:seq_len]
    sin = self.rope_sin[:seq_len]

    # Transformer blocks. When checkpointing is enabled (and only in
    # training mode), activations are recomputed on backward to save memory.
    for layer in self.layers:
        if self._gradient_checkpointing and self.training:
            x = torch.utils.checkpoint.checkpoint(
                layer, x, cos, sin, mask,
                use_reentrant=False
            )
        else:
            x = layer(x, cos, sin, mask)

    # Final norm
    x = self.norm(x)

    # Compute logits
    logits = self.lm_head(x)

    # Compute loss if targets provided
    if targets is not None:
        # NOTE: No shift here — the DataLoader already provides
        # pre-shifted targets (x = tokens[:-1], y = tokens[1:]),
        # so logits[k] should predict targets[k] directly.
        loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            targets.view(-1),
            ignore_index=-100,  # Ignore padding
        )
        return loss

    return logits
590
@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> torch.Tensor:
    """
    Generate tokens autoregressively.

    Args:
        input_ids: Starting token IDs (batch_size, seq_len)
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (1.0 = neutral). Logits are
            divided by this value, so it must be > 0.
        top_k: If set, only sample from top k tokens
        top_p: If set, use nucleus sampling with this probability mass

    Returns:
        Generated token IDs (batch_size, seq_len + max_new_tokens)

    NOTE(review): this switches the model into eval mode as a side effect
    and does not restore training mode — callers that resume training must
    call ``model.train()`` themselves.
    """
    self.eval()

    for _ in range(max_new_tokens):
        # Keep only the last max_seq_len tokens as context (no KV cache,
        # so the full window is re-run each step).
        idx_cond = input_ids if input_ids.size(1) <= self.config.max_seq_len else \
                   input_ids[:, -self.config.max_seq_len:]

        # Forward pass
        logits = self(idx_cond)

        # Get logits for last position, scaled by temperature
        logits = logits[:, -1, :] / temperature

        # Apply top-k filtering: mask everything below the k-th logit.
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')

        # Apply top-p (nucleus) filtering
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above threshold.
            # The right-shift below guarantees the first token that crosses
            # top_p is kept, so at least one token always survives.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Scatter the mask back to vocabulary order before applying it.
            indices_to_remove = sorted_indices_to_remove.scatter(
                1, sorted_indices, sorted_indices_to_remove
            )
            logits[indices_to_remove] = float('-inf')

        # Sample one token per batch row from the filtered distribution.
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append
        input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids
653
+
654
def count_parameters(self, trainable_only: bool = True) -> int:
    """Return the number of model parameters.

    Args:
        trainable_only: When True (default), count only parameters with
            ``requires_grad`` set; otherwise count every parameter.
    """
    if trainable_only:
        params = (p for p in self.parameters() if p.requires_grad)
    else:
        params = self.parameters()
    return sum(p.numel() for p in params)
659
+
660
def estimate_flops(self, seq_len: int, batch_size: int = 1) -> int:
    """Estimate FLOPs for a single forward pass.

    Uses the standard approximation FLOPs ≈ 2 * params * tokens, where a
    fused multiply-add counts as two operations. All parameters (trainable
    or not) are included.

    Args:
        seq_len: Sequence length of the forward pass.
        batch_size: Number of sequences in the batch.

    Returns:
        Approximate FLOP count as an integer.
    """
    total_params = self.count_parameters(trainable_only=False)
    return 2 * total_params * batch_size * seq_len
670
+
671
+
672
+ # ============================================================================
673
+ # Utility Functions
674
+ # ============================================================================
675
+
676
def create_model(
    size: str = "1B",
    vocab_size: int = 32000,
    max_seq_len: int = 2048,
    **kwargs
) -> LlamaModel:
    """Build a LlamaModel from one of the named size presets.

    Args:
        size: Preset name ("tiny", "small", "medium", "large", "1B").
        vocab_size: Tokenizer vocabulary size.
        max_seq_len: Maximum sequence length the model supports.
        **kwargs: Extra overrides forwarded verbatim to the config factory.

    Returns:
        A freshly initialized LlamaModel.
    """
    cfg = get_model_config(size, vocab_size=vocab_size, max_seq_len=max_seq_len, **kwargs)
    return LlamaModel(cfg)
701
+
702
+
703
def print_model_summary(model: LlamaModel):
    """Print a human-readable summary of the model's architecture,
    optimizations, parameter count, and estimated weight memory."""
    config = model.config
    # NOTE(review): default is trainable-only; with weight tying this usually
    # equals the total, but the memory figures below assume it does — verify.
    params = model.count_parameters()

    print("\n" + "=" * 60)
    print("LLAMA MODEL SUMMARY")
    print("=" * 60)
    print(f"\nArchitecture:")
    print(f" Hidden dim: {config.d_model}")
    print(f" Layers: {config.n_layers}")
    print(f" Attention heads: {config.n_heads}")
    print(f" KV heads (GQA): {config.n_kv_heads}")
    print(f" Head dim: {config.head_dim}")
    print(f" FFN dim: {config.d_ff}")
    print(f" Vocab size: {config.vocab_size}")
    print(f" Max seq len: {config.max_seq_len}")

    print(f"\nOptimizations:")
    print(f" RMSNorm: Yes")
    print(f" RoPE: Yes (theta={config.rope_theta})")
    print(f" SwiGLU: Yes")
    print(f" GQA: Yes ({config.n_heads}/{config.n_kv_heads} = {config.n_kv_groups}x)")
    print(f" Weight tying: {config.tie_weights}")
    print(f" Flash Attention: {config.use_flash_attn}")

    print(f"\nParameters:")
    print(f" Total: {params:,}")
    # Report in billions above 1e9 params, otherwise in millions.
    print(f" Size: ~{params / 1e9:.2f}B" if params > 1e9 else f" Size: ~{params / 1e6:.0f}M")

    # Estimate weight memory only (no activations / optimizer state).
    param_bytes = params * 4  # fp32
    print(f" FP32 memory: ~{param_bytes / 1e9:.2f} GB")
    print(f" BF16 memory: ~{param_bytes / 2 / 1e9:.2f} GB")

    print("=" * 60 + "\n")
739
+
740
+
741
+ # ============================================================================
742
+ # Main (for testing)
743
+ # ============================================================================
744
+
745
if __name__ == "__main__":
    # Smoke test: instantiate every preset and report parameter counts.
    print("Testing Llama model creation...\n")

    for size in ["tiny", "small", "medium", "large", "1B"]:
        model = create_model(size)
        params = model.count_parameters()
        print(f"{size:8s}: {params:>12,} parameters ({params/1e6:>7.1f}M)")

    print("\n" + "-" * 60)

    # Detailed summary for 1B
    model = create_model("1B")
    print_model_summary(model)

    # Forward-pass smoke test on GPU when available, else CPU.
    print("Testing forward pass...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    batch_size = 2
    seq_len = 128
    input_ids = torch.randint(0, 32000, (batch_size, seq_len), device=device)

    # Forward without targets (returns logits)
    logits = model(input_ids)
    print(f"Logits shape: {logits.shape}")

    # Forward with targets (returns loss)
    targets = torch.randint(0, 32000, (batch_size, seq_len), device=device)
    loss = model(input_ids, targets=targets)
    print(f"Loss: {loss.item():.4f}")

    # Verify a full forward + backward works with checkpointing enabled.
    print("\nTesting gradient checkpointing...")
    model.gradient_checkpointing_enable()
    loss = model(input_ids, targets=targets)
    loss.backward()
    print(f"Gradient checkpointing loss: {loss.item():.4f}")

    print("\nAll tests passed!")
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=2.0.0
2
+ tokenizers>=0.13.0
3
+ gradio>=4.0.0
4
+ numpy
tokenizer/bpe_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff