dasdasddds committed on
Commit 93783dd · verified · 1 Parent(s): 057e18a

Upload 16 files

Files changed (16)
  1. .gitattributes +1 -32
  2. README.md +193 -0
  3. WEIGHTS_GO_HERE.txt +3 -0
  4. chat.py +339 -0
  5. config.json +20 -0
  6. config.py +157 -0
  7. dataset.py +269 -0
  8. model.py +513 -0
  9. requirements.txt +3 -0
  10. special_tokens_map.json +12 -0
  11. tokenizer.py +344 -0
  12. tokenizer_config.json +11 -0
  13. train.py +456 -0
  14. visual_nn_3d.py +387 -0
  15. visual_nn_nodes.py +395 -0
  16. visualize_nn.py +472 -0
.gitattributes CHANGED
@@ -1,35 +1,4 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,193 @@
+ ---
+ license: mit
+ language:
+ - en
+ tags:
+ - text-generation
+ - from-scratch
+ - transformer
+ - gpt
+ - pytorch
+ - chatbot
+ pipeline_tag: text-generation
+ model-index:
+ - name: GPT-300M
+   results: []
+ ---
+
+ # GPT-300M
+
+ A **334,808,064 parameter** autoregressive transformer language model built **entirely from scratch** in PyTorch. No pretrained weights. No fine-tuning. Everything from zero.
+
+ ## Architecture
+
+ ```
+ Input Token IDs
+        ↓
+ Token Embedding (32,000 × 1,024) — 32.8M params
+        ↓
+ Rotary Position Embeddings (RoPE) — 0 learned params
+        ↓
+ ┌───────────────────────────────────────────────┐
+ │  Transformer Block × 24 layers (12.6M each)   │
+ │                                               │
+ │  RMSNorm → Multi-Head Attention → ⊕ Residual  │
+ │            16 heads × 64d                     │
+ │            4,194,304 params                   │
+ │                                               │
+ │  RMSNorm → FFN (GELU) → ⊕ Residual            │
+ │            1,024 → 4,096 → 1,024              │
+ │            8,388,608 params                   │
+ └───────────────────────────────────────────────┘
+        ↓
+ Final RMSNorm
+        ↓
+ LM Head (weight-tied with embedding) — 0 extra params
+        ↓
+ Softmax → Next Token Probabilities
+ ```
+
+ ## Parameter Breakdown
+
+ | Component | Parameters | Percentage |
+ |---|---:|---:|
+ | Token Embedding | 32,768,000 | 9.8% |
+ | Attention Layers (×24) | 100,663,296 | 30.1% |
+ | Feed-Forward Layers (×24) | 201,326,592 | 60.1% |
+ | RMSNorm (2 per layer ×24, + final) | 50,176 | 0.0% |
+ | LM Head | 0 (tied) | — |
+ | **TOTAL** | **334,808,064** | **100%** |
+
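+ As a quick sanity check, the total can be reproduced from the hyperparameters alone. A minimal sketch, assuming weight tying, no biases, and RoPE (no learned positions), exactly as configured above:
+
+ ```python
+ # Illustrative sanity check — not part of the training code.
+ d_model, n_layers, d_ff, vocab = 1024, 24, 4096, 32000
+
+ embedding = vocab * d_model                   # 32,768,000 (shared with the LM head)
+ attention = 4 * d_model * d_model * n_layers  # Q/K/V/O projections: 100,663,296
+ ffn       = 2 * d_model * d_ff * n_layers     # up + down projections: 201,326,592
+ norms     = (2 * n_layers + 1) * d_model      # two RMSNorms per block + final: 50,176
+
+ print(f"{embedding + attention + ffn + norms:,}")  # 334,808,064
+ ```
+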
+ ## Model Details
+
+ | Hyperparameter | Value |
+ |---|---|
+ | Hidden dimension (d_model) | 1,024 |
+ | Attention heads | 16 |
+ | Head dimension | 64 |
+ | Transformer layers | 24 |
+ | FFN dimension (d_ff) | 4,096 |
+ | Vocabulary size | 32,000 |
+ | Max sequence length | 2,048 |
+ | Position encoding | RoPE (θ=10,000) |
+ | Activation | GELU |
+ | Normalization | RMSNorm (ε=1e-5) |
+ | Weight tying | Yes (Embed ↔ LM Head) |
+ | Bias | None |
+
+ ## Training Configuration
+
+ | Setting | Value |
+ |---|---|
+ | Optimizer | AdamW (β₁=0.9, β₂=0.95) |
+ | Peak learning rate | 3e-4 |
+ | Min learning rate | 3e-5 |
+ | Schedule | Cosine decay + linear warmup |
+ | Warmup steps | 2,000 |
+ | Weight decay | 0.1 |
+ | Batch size | 32 × 8 gradient accumulation |
+ | Max training steps | 600,000 |
+ | Precision | bfloat16 |
+ | Gradient clipping | 1.0 |
+
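+ The schedule above warms up linearly for 2,000 steps, then follows a cosine curve from the peak rate down to the minimum. A minimal sketch of that curve (the standalone function below is illustrative; the real implementation lives in `train.py`):
+
+ ```python
+ import math
+
+ def lr_at(step, peak=3e-4, floor=3e-5, warmup=2_000, max_steps=600_000):
+     """Linear warmup to `peak`, then cosine decay to `floor`."""
+     if step < warmup:
+         return peak * step / warmup
+     progress = (step - warmup) / (max_steps - warmup)
+     return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))
+ ```
+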
+ ## Usage
+
+ ### Loading the Model
+
+ ```python
+ from model import GPT300M
+ from config import GPT300MConfig
+ from tokenizer import BPETokenizer
+ import torch
+
+ # Load config, model, and tokenizer
+ config = GPT300MConfig()
+ model = GPT300M(config)
+
+ # Load trained weights
+ checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
+ model.load_state_dict(checkpoint)
+ model.eval()
+
+ # Load tokenizer
+ tokenizer = BPETokenizer.load("tokenizer.json")
+ ```
+
+ ### Chat with the Model
+
+ ```python
+ from chat import ChatBot
+
+ chatbot = ChatBot(model, tokenizer, config)
+ response = chatbot.chat("Hello! What is machine learning?")
+ print(response)
+ ```
+
+ ### Interactive Chat
+
+ ```bash
+ python chat.py --checkpoint pytorch_model.bin
+ ```
+
+ ### Training from Scratch
+
+ ```bash
+ # Quick test (tiny model)
+ python train.py --tiny
+
+ # Full 300M model
+ python train.py --data your_training_data.txt
+
+ # Multi-GPU
+ torchrun --nproc_per_node=4 train.py --data your_data.txt
+ ```
+
+ ## Files
+
+ | File | Description |
+ |---|---|
+ | `config.json` | Model configuration (HuggingFace format) |
+ | `config.py` | Python config class with all hyperparameters |
+ | `model.py` | Full transformer architecture (RoPE, MHA, FFN, KV-cache) |
+ | `tokenizer.py` | BPE tokenizer built from scratch |
+ | `tokenizer_config.json` | Tokenizer settings |
+ | `special_tokens_map.json` | Special token definitions |
+ | `dataset.py` | Dataset classes and data loading |
+ | `train.py` | Training loop (DDP, mixed precision, scheduling) |
+ | `chat.py` | Interactive chatbot with streaming generation |
+ | `visual_nn_3d.py` | 3D matplotlib architecture visualization |
+ | `requirements.txt` | Python dependencies |
+ | `pytorch_model.bin` | Trained weights *(upload after training)* |
+ | `tokenizer.json` | Trained tokenizer *(upload after training)* |
+
+ ## Hardware Requirements
+
+ | Config | GPU Memory | Est. Training Time |
+ |---|---|---|
+ | Tiny (debug) | ~1 GB | Minutes |
+ | Full 300M | ~24 GB | ~3-5 days (4×A100) |
+
+ ## Key Features
+
+ - **100% from scratch** — no pretrained weights, no HuggingFace Transformers dependency
+ - **Rotary Position Embeddings** — better length generalization than learned positions
+ - **RMSNorm** — faster than LayerNorm, equally effective
+ - **Flash Attention** — via PyTorch 2.0 SDPA (see the sketch after this list)
+ - **KV-Cache** — efficient autoregressive generation
+ - **Weight tying** — saves ~33M parameters
+ - **Chat template** — built-in support for multi-turn conversations
+ - **torch.compile** — ready for PyTorch 2.0+ compilation
+
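+ The Flash Attention bullet refers to PyTorch's fused `scaled_dot_product_attention` kernel, which `model.py` calls when it is available. The core call looks like this (a standalone sketch with toy shapes, not a line from the repo):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # [batch, heads, seq_len, head_dim] — toy shapes for illustration
+ q = k = v = torch.randn(1, 16, 128, 64)
+ out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+ print(out.shape)  # torch.Size([1, 16, 128, 64])
+ ```
+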
+ ## Citation
+
+ ```bibtex
+ @misc{gpt300m,
+   title={GPT-300M: A 300-Million Parameter Language Model From Scratch},
+   year={2025},
+   url={https://huggingface.co/YOUR_USERNAME/gpt-300m}
+ }
+ ```
+
+ ## License
+
+ MIT
WEIGHTS_GO_HERE.txt ADDED
@@ -0,0 +1,3 @@
+ PLACEHOLDER - Replace this file with your trained model weights after training.
+ Run: python train.py --data your_data.txt
+ Then: torch.save(checkpoint['model_state_dict'], 'pytorch_model.bin')
chat.py ADDED
@@ -0,0 +1,339 @@
+ """
+ GPT-300M Chatbot Interface
+ ============================
+ Interactive terminal chatbot using a trained GPT-300M model.
+
+ Usage:
+     python chat.py --checkpoint ./checkpoints/best_model.pt
+
+     # Or with custom generation parameters:
+     python chat.py --checkpoint ./checkpoints/best_model.pt \
+         --temperature 0.8 --top_k 40 --max_tokens 256
+ """
+
+ import argparse
+ import os
+ import sys
+ import time
+ from typing import List, Dict, Optional
+
+ import torch
+
+ from config import GPT300MConfig
+ from model import GPT300M
+ from tokenizer import BPETokenizer
+
+
+ class ChatBot:
+     """
+     Interactive chatbot powered by GPT-300M.
+
+     Maintains conversation history, handles tokenization/detokenization,
+     and performs autoregressive generation with KV-caching.
+     """
+
+     def __init__(
+         self,
+         model: GPT300M,
+         tokenizer: BPETokenizer,
+         config: GPT300MConfig,
+         device: str = "auto",
+     ):
+         self.config = config
+         self.tokenizer = tokenizer
+
+         # Device
+         if device == "auto":
+             if torch.cuda.is_available():
+                 self.device = "cuda"
+             elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                 self.device = "mps"
+             else:
+                 self.device = "cpu"
+         else:
+             self.device = device
+
+         self.model = model.to(self.device)
+         self.model.eval()
+
+         # Conversation state
+         self.history: List[Dict[str, str]] = []
+         self.system_prompt = config.system_prompt
+
+     def set_system_prompt(self, prompt: str):
+         """Set the system prompt for the conversation."""
+         self.system_prompt = prompt
+
+     def reset(self):
+         """Clear conversation history."""
+         self.history = []
+         print("\n✦ Conversation reset.\n")
+
+     def chat(
+         self,
+         user_message: str,
+         temperature: Optional[float] = None,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         max_new_tokens: Optional[int] = None,
+         stream: bool = True,
+     ) -> str:
+         """
+         Send a message and get a response.
+
+         Args:
+             user_message: The user's input
+             temperature: Override sampling temperature
+             top_k: Override top-k
+             top_p: Override top-p
+             max_new_tokens: Override max generation length
+             stream: Whether to stream tokens to stdout
+
+         Returns:
+             The assistant's response text
+         """
+         # Explicit None checks so a legitimate zero (e.g. temperature=0 for
+         # greedy decoding) is not silently replaced by the config default.
+         temp = temperature if temperature is not None else self.config.temperature
+         k = top_k if top_k is not None else self.config.top_k
+         p = top_p if top_p is not None else self.config.top_p
+         max_tokens = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
+
+         # Build conversation messages
+         messages = []
+         if self.system_prompt:
+             messages.append({"role": "system", "content": self.system_prompt})
+         messages.extend(self.history)
+         messages.append({"role": "user", "content": user_message})
+
+         # Tokenize
+         input_ids = self.tokenizer.encode_chat(messages, add_generation_prompt=True)
+         input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
+
+         # Check sequence length
+         if input_tensor.size(1) > self.config.max_seq_len - max_tokens:
+             # Truncate history if needed
+             while (
+                 len(self.history) > 0
+                 and input_tensor.size(1) > self.config.max_seq_len - max_tokens
+             ):
+                 self.history.pop(0)
+                 messages = []
+                 if self.system_prompt:
+                     messages.append({"role": "system", "content": self.system_prompt})
+                 messages.extend(self.history)
+                 messages.append({"role": "user", "content": user_message})
+                 input_ids = self.tokenizer.encode_chat(messages, add_generation_prompt=True)
+                 input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
+
+         # Generate
+         t0 = time.time()
+
+         if stream:
+             response_text = self._generate_streaming(
+                 input_tensor, max_tokens, temp, k, p
+             )
+         else:
+             with torch.no_grad():
+                 output_ids = self.model.generate(
+                     input_tensor,
+                     max_new_tokens=max_tokens,
+                     temperature=temp,
+                     top_k=k,
+                     top_p=p,
+                     repetition_penalty=self.config.repetition_penalty,
+                     eos_token_id=self.tokenizer.special_tokens.get("<|end|>"),
+                 )
+                 # Decode only the new tokens
+                 new_ids = output_ids[0, input_tensor.size(1):].tolist()
+                 response_text = self.tokenizer.decode(new_ids, skip_special=True)
+
+         dt = time.time() - t0
+         n_tokens = len(self.tokenizer.encode(response_text))
+
+         # Update history
+         self.history.append({"role": "user", "content": user_message})
+         self.history.append({"role": "assistant", "content": response_text.strip()})
+
+         if stream:
+             print(f"\n  [{n_tokens} tokens, {dt:.1f}s, {n_tokens/dt:.1f} tok/s]")
+
+         return response_text.strip()
+
+     @torch.no_grad()
+     def _generate_streaming(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+     ) -> str:
+         """Generate tokens one at a time, printing as we go."""
+         import torch.nn.functional as F
+
+         model = self.model
+         model.eval()
+
+         eos_id = self.tokenizer.special_tokens.get("<|end|>")
+         end_id = self.tokenizer.special_tokens.get("<eos>")
+
+         # Initial forward pass
+         logits, _, kv_caches = model(input_ids, use_cache=True)
+
+         generated_ids = []
+         buffer = b""
+
+         for step in range(max_new_tokens):
+             next_logits = logits[:, -1, :]
+
+             # Repetition penalty
+             if self.config.repetition_penalty != 1.0:
+                 for tid in set(generated_ids):
+                     if next_logits[0, tid] > 0:
+                         next_logits[0, tid] /= self.config.repetition_penalty
+                     else:
+                         next_logits[0, tid] *= self.config.repetition_penalty
+
+             # Temperature + sampling
+             if temperature > 0:
+                 next_logits = next_logits / temperature
+                 if top_k > 0:
+                     topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                     next_logits[next_logits < topk_vals[:, -1:]] = float("-inf")
+                 probs = F.softmax(next_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 next_token = next_logits.argmax(dim=-1, keepdim=True)
+
+             token_id = next_token.item()
+
+             # Check for stop tokens
+             if token_id in (eos_id, end_id):
+                 break
+
+             generated_ids.append(token_id)
+
+             # Decode and print the new token
+             token_bytes = self.tokenizer.vocab.get(token_id, b"")
+             buffer += token_bytes
+             try:
+                 text = buffer.decode("utf-8")
+                 sys.stdout.write(text)
+                 sys.stdout.flush()
+                 buffer = b""
+             except UnicodeDecodeError:
+                 pass  # Wait for more bytes
+
+             # Forward with KV-cache
+             position_offset = input_ids.size(1) + step
+             logits, _, kv_caches = model(
+                 next_token,
+                 kv_caches=kv_caches,
+                 use_cache=True,
+                 position_offset=position_offset,
+             )
+
+         # Flush remaining buffer
+         if buffer:
+             text = buffer.decode("utf-8", errors="replace")
+             sys.stdout.write(text)
+             sys.stdout.flush()
+
+         return self.tokenizer.decode(generated_ids, skip_special=True)
+
+
+ def interactive_chat(chatbot: ChatBot):
+     """Run an interactive chat session in the terminal."""
+     print("=" * 60)
+     print("  GPT-300M Chatbot")
+     print("  Type 'quit' to exit, 'reset' to clear history")
+     print("  Type 'system: <prompt>' to set system prompt")
+     print("=" * 60)
+     print()
+
+     while True:
+         try:
+             user_input = input("You: ").strip()
+         except (KeyboardInterrupt, EOFError):
+             print("\n\nGoodbye!")
+             break
+
+         if not user_input:
+             continue
+
+         if user_input.lower() == "quit":
+             print("Goodbye!")
+             break
+
+         if user_input.lower() == "reset":
+             chatbot.reset()
+             continue
+
+         if user_input.lower().startswith("system:"):
+             prompt = user_input[7:].strip()
+             chatbot.set_system_prompt(prompt)
+             print(f"✦ System prompt set: {prompt}\n")
+             continue
+
+         print("\nAssistant: ", end="", flush=True)
+         chatbot.chat(user_input, stream=True)
+         print()
+
+
+ def load_model(checkpoint_path: str, device: str = "auto"):
+     """Load a trained model from checkpoint."""
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+     # Reconstruct config
+     config = GPT300MConfig(**checkpoint["config"])
+
+     # Load model
+     model = GPT300M(config)
+     model.load_state_dict(checkpoint["model_state_dict"])
+
+     # Load tokenizer (os is imported at module level so this also works on import)
+     tokenizer_path = os.path.join(
+         os.path.dirname(checkpoint_path), "tokenizer.json"
+     )
+     if os.path.exists(tokenizer_path):
+         tokenizer = BPETokenizer.load(tokenizer_path)
+     else:
+         tokenizer = BPETokenizer(vocab_size=config.vocab_size)
+         print("Warning: Tokenizer not found, using untrained tokenizer")
+
+     return model, tokenizer, config
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # MAIN
+ # ═══════════════════════════════════════════════════════════════════════
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="GPT-300M Chatbot")
+     parser.add_argument("--checkpoint", type=str, default=None,
+                         help="Path to model checkpoint")
+     parser.add_argument("--temperature", type=float, default=0.7)
+     parser.add_argument("--top_k", type=int, default=50)
+     parser.add_argument("--top_p", type=float, default=0.9)
+     parser.add_argument("--max_tokens", type=int, default=512)
+     parser.add_argument("--device", type=str, default="auto")
+     args = parser.parse_args()
+
+     if args.checkpoint and os.path.exists(args.checkpoint):
+         model, tokenizer, config = load_model(args.checkpoint, args.device)
+     else:
+         print("No checkpoint provided. Initializing random model for demo...")
+         from config import gpt_tiny
+         config = gpt_tiny()
+         model = GPT300M(config)
+         tokenizer = BPETokenizer(vocab_size=config.vocab_size)
+         # Quick train on minimal data
+         tokenizer.train("Hello! How are you? I am fine. " * 100)
+
+     config.temperature = args.temperature
+     config.top_k = args.top_k
+     config.top_p = args.top_p
+     config.max_new_tokens = args.max_tokens
+
+     chatbot = ChatBot(model, tokenizer, config, device=args.device)
+     interactive_chat(chatbot)
config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "architectures": ["GPT300M"],
+   "model_type": "gpt-300m",
+   "vocab_size": 32000,
+   "max_position_embeddings": 2048,
+   "hidden_size": 1024,
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "intermediate_size": 4096,
+   "hidden_act": "gelu",
+   "dropout": 0.1,
+   "attention_dropout": 0.1,
+   "use_bias": false,
+   "tie_word_embeddings": true,
+   "rope_theta": 10000.0,
+   "rms_norm_eps": 1e-5,
+   "torch_dtype": "bfloat16",
+   "total_params": 334808064,
+   "total_params_trainable": 334808064
+ }
config.py ADDED
@@ -0,0 +1,157 @@
+ """
+ GPT-300M Configuration
+ ======================
+ A ~300 million parameter autoregressive transformer language model.
+ Built entirely from scratch — no pretrained weights, no fine-tuning.
+
+ Parameter budget breakdown (with RoPE and weight tying, as configured):
+   - Token Embeddings: vocab_size × d_model = 32,000 × 1,024 = 32.8M
+   - Position Embeddings: none with RoPE (a learned table would add
+     max_seq_len × d_model = 2,048 × 1,024 = 2.1M)
+   - Transformer Layers (×24):
+       - Multi-Head Attention (Q/K/V/O): 4 × d_model² = 4 × 1,048,576 = 4.2M each
+       - Feed-Forward Network: 2 × d_model × d_ff = 2 × 1,024 × 4,096 = 8.4M each
+       - RMSNorms: negligible
+       - Per layer total: ~12.6M
+       - All 24 layers: ~302M
+   - Final RMSNorm + LM Head (tied with embeddings): ~0 extra
+   ─────────────────────────────────────────────────────
+   TOTAL: ~334.8M parameters (334,808,064 exactly; the tied LM head adds none)
+ """
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+ import json
+ import os
+
+
+ @dataclass
+ class GPT300MConfig:
+     """Configuration for a ~300M parameter GPT model."""
+
+     # ── Model Architecture ──────────────────────────────────────────────
+     vocab_size: int = 32_000       # BPE vocabulary size
+     max_seq_len: int = 2_048       # Maximum sequence length (context window)
+     d_model: int = 1_024           # Hidden dimension / embedding size
+     n_heads: int = 16              # Number of attention heads
+     n_layers: int = 24             # Number of transformer blocks
+     d_ff: int = 4_096              # Feed-forward intermediate dimension
+     dropout: float = 0.1           # Dropout probability
+     bias: bool = False             # Use bias in linear layers (modern GPTs skip this)
+     tie_weights: bool = True       # Tie token embedding and LM head weights
+     activation: str = "gelu"       # Activation function: "gelu" or "swiglu"
+     norm_eps: float = 1e-5         # RMSNorm epsilon
+     rope: bool = True              # Use Rotary Position Embeddings (RoPE)
+     rope_theta: float = 10_000.0   # RoPE base frequency
+
+     # ── Training Hyperparameters ────────────────────────────────────────
+     batch_size: int = 32                   # Micro-batch size per GPU
+     gradient_accumulation_steps: int = 8   # Effective batch = batch_size × grad_accum × n_gpus
+     learning_rate: float = 3e-4            # Peak learning rate
+     min_learning_rate: float = 3e-5        # Minimum LR after cosine decay
+     weight_decay: float = 0.1              # AdamW weight decay
+     beta1: float = 0.9                     # Adam beta1
+     beta2: float = 0.95                    # Adam beta2
+     max_grad_norm: float = 1.0             # Gradient clipping norm
+     warmup_steps: int = 2_000              # Linear warmup steps
+     max_steps: int = 600_000               # Total training steps
+     eval_interval: int = 1_000             # Evaluate every N steps
+     save_interval: int = 5_000             # Save checkpoint every N steps
+     log_interval: int = 10                 # Log metrics every N steps
+
+     # ── Data ────────────────────────────────────────────────────────────
+     data_dir: str = "./data"       # Directory containing tokenized .bin shards
+     train_split: float = 0.98      # Train/val split ratio
+     num_workers: int = 4           # DataLoader workers
+
+     # ── System ──────────────────────────────────────────────────────────
+     device: str = "auto"                # "auto", "cuda", "cpu", "mps"
+     dtype: str = "bfloat16"             # "float32", "float16", "bfloat16"
+     compile_model: bool = True          # Use torch.compile (PyTorch 2.0+)
+     output_dir: str = "./checkpoints"   # Where to save checkpoints
+     wandb_project: str = "gpt-300m"     # Weights & Biases project name
+     wandb_run_name: Optional[str] = None
+     seed: int = 42
+
+     # ── Chat / Inference ────────────────────────────────────────────────
+     temperature: float = 0.7            # Sampling temperature
+     top_k: int = 50                     # Top-k sampling
+     top_p: float = 0.9                  # Nucleus sampling threshold
+     max_new_tokens: int = 512           # Max tokens to generate per turn
+     repetition_penalty: float = 1.1     # Penalize repeated tokens
+     chat_template: str = (
+         "<|system|>{system}<|end|>"
+         "<|user|>{user}<|end|>"
+         "<|assistant|>"
+     )
+     system_prompt: str = (
+         "You are a helpful, harmless, and honest AI assistant. "
+         "Respond naturally and conversationally."
+     )
+
+     # ── Special Token IDs (set during tokenizer init) ───────────────────
+     pad_token_id: int = 0
+     bos_token_id: int = 1
+     eos_token_id: int = 2
+
+     @property
+     def head_dim(self) -> int:
+         assert self.d_model % self.n_heads == 0
+         return self.d_model // self.n_heads
+
+     @property
+     def total_params_estimate(self) -> int:
+         emb = self.vocab_size * self.d_model
+         pos = self.max_seq_len * self.d_model if not self.rope else 0
+         attn = 4 * self.d_model * self.d_model * self.n_layers
+         ffn = 2 * self.d_model * self.d_ff * self.n_layers
+         ln = 2 * self.d_model * self.n_layers + self.d_model
+         tied = 0 if self.tie_weights else self.vocab_size * self.d_model
+         return emb + pos + attn + ffn + ln + tied
+
+     def save(self, path: str):
+         os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+         with open(path, "w") as f:
+             json.dump(self.__dict__, f, indent=2)
+
+     @classmethod
+     def load(cls, path: str) -> "GPT300MConfig":
+         with open(path) as f:
+             return cls(**json.load(f))
+
+     def __post_init__(self):
+         assert self.d_model % self.n_heads == 0, (
+             f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
+         )
+
+
+ # ── Preset Configs ──────────────────────────────────────────────────────
+
+ def gpt_300m() -> GPT300MConfig:
+     """Default 300M config."""
+     return GPT300MConfig()
+
+ def gpt_125m() -> GPT300MConfig:
+     """Smaller 125M config for testing."""
+     return GPT300MConfig(
+         d_model=768, n_heads=12, n_layers=12, d_ff=3072,
+         max_seq_len=1024, batch_size=64
+     )
+
+ def gpt_tiny() -> GPT300MConfig:
+     """Tiny config for debugging."""
+     return GPT300MConfig(
+         d_model=128, n_heads=4, n_layers=4, d_ff=512,
+         vocab_size=1000, max_seq_len=256, batch_size=8
+     )
+
+
+ if __name__ == "__main__":
+     cfg = gpt_300m()
+     print("GPT-300M Configuration")
+     print(f"  Estimated parameters: {cfg.total_params_estimate:,}")
+     print(f"  d_model:     {cfg.d_model}")
+     print(f"  n_heads:     {cfg.n_heads}")
+     print(f"  n_layers:    {cfg.n_layers}")
+     print(f"  d_ff:        {cfg.d_ff}")
+     print(f"  vocab_size:  {cfg.vocab_size}")
+     print(f"  max_seq_len: {cfg.max_seq_len}")
dataset.py ADDED
@@ -0,0 +1,269 @@
+ """
+ Dataset & DataLoader for GPT-300M
+ ==================================
+ Handles loading, tokenizing, and batching text data for training.
+
+ Supports two modes:
+   1. Pre-tokenized binary shards (fast, for large-scale training)
+   2. Raw text files (convenient, for small datasets)
+ """
+
+ import glob
+ import os
+ import random
+ from typing import List, Optional
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
+
+ from config import GPT300MConfig
+
+
+ class TextDataset(Dataset):
+     """
+     Simple dataset that loads raw text, tokenizes it, and creates
+     fixed-length training sequences.
+     """
+
+     def __init__(
+         self,
+         text: str,
+         tokenizer,
+         seq_len: int,
+         stride: Optional[int] = None,
+     ):
+         """
+         Args:
+             text: Raw text data
+             tokenizer: BPETokenizer instance
+             seq_len: Sequence length for training
+             stride: Sliding window stride (default: seq_len // 2)
+         """
+         self.seq_len = seq_len
+         self.stride = stride or seq_len // 2
+
+         # Tokenize the entire text
+         self.token_ids = tokenizer.encode(text, add_special_tokens=False)
+         self.token_ids = torch.tensor(self.token_ids, dtype=torch.long)
+
+         # Calculate number of sequences
+         self.n_sequences = max(0, (len(self.token_ids) - seq_len - 1) // self.stride + 1)
+
+     def __len__(self):
+         return self.n_sequences
+
+     def __getitem__(self, idx):
+         start = idx * self.stride
+         end = start + self.seq_len + 1  # +1 for target offset
+         chunk = self.token_ids[start:end]
+
+         x = chunk[:-1]  # Input: tokens[0..seq_len-1]
+         y = chunk[1:]   # Target: tokens[1..seq_len]
+         return x, y
+
+
+ class ChatDataset(Dataset):
+     """
+     Dataset for chat/conversation data.
+     Each sample is a multi-turn conversation formatted with special tokens.
+     """
+
+     def __init__(
+         self,
+         conversations: List[List[dict]],
+         tokenizer,
+         max_seq_len: int,
+     ):
+         """
+         Args:
+             conversations: List of conversations, each a list of
+                 {"role": "user"|"assistant"|"system", "content": "..."}
+             tokenizer: BPETokenizer instance
+             max_seq_len: Maximum sequence length
+         """
+         self.max_seq_len = max_seq_len
+         self.samples = []
+
+         for conv in conversations:
+             ids = tokenizer.encode_chat(conv, add_generation_prompt=False)
+             ids.append(tokenizer.special_tokens["<eos>"])
+
+             # Truncate if needed
+             if len(ids) > max_seq_len + 1:
+                 ids = ids[:max_seq_len + 1]
+
+             if len(ids) >= 4:  # Minimum meaningful length
+                 self.samples.append(torch.tensor(ids, dtype=torch.long))
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         tokens = self.samples[idx]
+         x = tokens[:-1]
+         y = tokens[1:]
+         return x, y
+
+
+ class ShardedDataset(IterableDataset):
+     """
+     Efficient iterable dataset that streams from pre-tokenized binary shards.
+     Used for large-scale training where data doesn't fit in memory.
+     """
+
+     def __init__(
+         self,
+         data_dir: str,
+         seq_len: int,
+         split: str = "train",
+         seed: int = 42,
+     ):
+         super().__init__()
+         self.seq_len = seq_len
+         self.seed = seed
+
+         # Find shard files
+         pattern = os.path.join(data_dir, f"{split}_*.bin")
+         self.shards = sorted(glob.glob(pattern))
+         if not self.shards:
+             raise FileNotFoundError(f"No shards found matching: {pattern}")
+
+         print(f"Found {len(self.shards)} {split} shards")
+
+     def __iter__(self):
+         rng = random.Random(self.seed)
+         shards = list(self.shards)
+         rng.shuffle(shards)
+
+         for shard_path in shards:
+             # Memory-map the shard for efficiency
+             data = np.memmap(shard_path, dtype=np.uint16, mode="r")
+             n_tokens = len(data)
+             n_chunks = n_tokens // (self.seq_len + 1)
+
+             # Random order within shard
+             indices = list(range(n_chunks))
+             rng.shuffle(indices)
+
+             for idx in indices:
+                 start = idx * (self.seq_len + 1)
+                 chunk = torch.from_numpy(
+                     data[start : start + self.seq_len + 1].astype(np.int64)
+                 )
+                 x = chunk[:-1]
+                 y = chunk[1:]
+                 yield x, y
+
+
+ def collate_fn(batch, pad_id: int = 0):
+     """
+     Collate function that pads sequences to the same length within a batch.
+     """
+     xs, ys = zip(*batch)
+     max_len = max(x.size(0) for x in xs)
+
+     padded_x = torch.full((len(xs), max_len), pad_id, dtype=torch.long)
+     padded_y = torch.full((len(ys), max_len), pad_id, dtype=torch.long)
+
+     for i, (x, y) in enumerate(zip(xs, ys)):
+         padded_x[i, :x.size(0)] = x
+         padded_y[i, :y.size(0)] = y
+
+     return padded_x, padded_y
+
+
+ def create_dataloaders(
+     config: GPT300MConfig,
+     tokenizer,
+     text: Optional[str] = None,
+     conversations: Optional[List[List[dict]]] = None,
+ ) -> tuple:
+     """
+     Create train and validation DataLoaders.
+
+     Supply either `text` for raw text training or `conversations` for chat training.
+     """
+     if text is not None:
+         # Split into train/val
+         split = int(len(text) * config.train_split)
+         train_text = text[:split]
+         val_text = text[split:]
+
+         train_ds = TextDataset(train_text, tokenizer, config.max_seq_len)
+         val_ds = TextDataset(val_text, tokenizer, config.max_seq_len)
+
+     elif conversations is not None:
+         split = int(len(conversations) * config.train_split)
+         train_convs = conversations[:split]
+         val_convs = conversations[split:]
+
+         train_ds = ChatDataset(train_convs, tokenizer, config.max_seq_len)
+         val_ds = ChatDataset(val_convs, tokenizer, config.max_seq_len)
+     else:
+         raise ValueError("Provide either `text` or `conversations`")
+
+     train_dl = DataLoader(
+         train_ds,
+         batch_size=config.batch_size,
+         shuffle=True,
+         collate_fn=lambda b: collate_fn(b, config.pad_token_id),
+         num_workers=config.num_workers,
+         pin_memory=True,
+         drop_last=True,
+     )
+
+     val_dl = DataLoader(
+         val_ds,
+         batch_size=config.batch_size,
+         shuffle=False,
+         collate_fn=lambda b: collate_fn(b, config.pad_token_id),
+         num_workers=config.num_workers,
+         pin_memory=True,
+     )
+
+     return train_dl, val_dl
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # UTILITIES: Tokenize and save to binary shards
+ # ═══════════════════════════════════════════════════════════════════════
+
+ def tokenize_to_shards(
+     text: str,
+     tokenizer,
+     output_dir: str,
+     shard_size: int = 100_000_000,  # ~100M tokens per shard
+     split: str = "train",
+ ):
+     """
+     Tokenize text and save to binary shards for efficient loading.
+     """
+     os.makedirs(output_dir, exist_ok=True)
+     tokens = tokenizer.encode(text, add_special_tokens=False)
+
+     shard_idx = 0
+     for start in range(0, len(tokens), shard_size):
+         end = min(start + shard_size, len(tokens))
+         chunk = np.array(tokens[start:end], dtype=np.uint16)
+         path = os.path.join(output_dir, f"{split}_{shard_idx:04d}.bin")
+         chunk.tofile(path)
+         shard_idx += 1
+
+     print(f"Saved {shard_idx} shards ({len(tokens):,} tokens) to {output_dir}")
+
+
+ if __name__ == "__main__":
+     from tokenizer import BPETokenizer
+
+     # Quick test with synthetic data
+     tok = BPETokenizer(vocab_size=500)
+     sample_text = "Hello world! " * 1000
+     tok.train(sample_text)
+
+     ds = TextDataset(sample_text, tok, seq_len=64)
+     print(f"Dataset: {len(ds)} sequences of length 64")
+
+     x, y = ds[0]
+     print(f"Sample x: {x[:10]}")
+     print(f"Sample y: {y[:10]}")
model.py ADDED
@@ -0,0 +1,513 @@
+ """
+ GPT-300M Model Architecture
+ ============================
+ A decoder-only transformer built entirely from scratch in PyTorch.
+
+ Architecture features:
+   - Pre-norm transformer blocks (RMSNorm)
+   - Rotary Position Embeddings (RoPE)
+   - Multi-Head Self-Attention with causal masking
+   - GELU activation in feed-forward layers
+   - Optional weight tying (token embeddings ↔ LM head)
+   - KV-Cache for efficient autoregressive generation
+   - Flash Attention support (PyTorch 2.0+)
+ """
+
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from config import GPT300MConfig
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # ROTARY POSITION EMBEDDINGS (RoPE)
+ # ═══════════════════════════════════════════════════════════════════════
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (Su et al., 2021)."""
+
+     def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Precompute cos/sin tables
+         t = torch.arange(max_seq_len, dtype=torch.float32)
+         freqs = torch.outer(t, inv_freq)
+         emb = torch.cat([freqs, freqs], dim=-1)
+         self.register_buffer("cos_cached", emb.cos(), persistent=False)
+         self.register_buffer("sin_cached", emb.sin(), persistent=False)
+
+     def forward(self, seq_len: int, offset: int = 0):
+         return (
+             self.cos_cached[offset : offset + seq_len],
+             self.sin_cached[offset : offset + seq_len],
+         )
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     """Rotate the second half of the last dimension."""
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat([-x2, x1], dim=-1)
+
+
+ def apply_rotary_emb(
+     q: torch.Tensor, k: torch.Tensor,
+     cos: torch.Tensor, sin: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Apply rotary embeddings to query and key tensors."""
+     # cos/sin shape: [seq_len, head_dim] → [1, 1, seq_len, head_dim]
+     cos = cos.unsqueeze(0).unsqueeze(0)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     q_rot = q * cos + rotate_half(q) * sin
+     k_rot = k * cos + rotate_half(k) * sin
+     return q_rot, k_rot
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # RMSNORM (faster alternative to LayerNorm)
+ # ═══════════════════════════════════════════════════════════════════════
+
+ class RMSNorm(nn.Module):
+     """Root Mean Square Layer Normalization."""
+
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
+         return (x.float() * norm).type_as(x) * self.weight
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # MULTI-HEAD SELF-ATTENTION
+ # ═══════════════════════════════════════════════════════════════════════
+
+ class MultiHeadAttention(nn.Module):
+     """Multi-Head Self-Attention with causal masking and optional KV-cache."""
+
+     def __init__(self, config: GPT300MConfig):
+         super().__init__()
+         self.n_heads = config.n_heads
+         self.head_dim = config.head_dim
+         self.d_model = config.d_model
+         self.dropout = config.dropout
+
+         # Q, K, V projections (fused for efficiency)
+         self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
+         # Output projection
+         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
+
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+
+         # Check for Flash Attention support
+         self.flash_attn = hasattr(F, "scaled_dot_product_attention")
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         mask: Optional[torch.Tensor] = None,
+         kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         B, T, C = x.shape
+
+         # Project to Q, K, V
+         qkv = self.qkv_proj(x)
+         q, k, v = qkv.split(self.d_model, dim=-1)
+
+         # Reshape: [B, T, n_heads, head_dim] → [B, n_heads, T, head_dim]
+         q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+
+         # Apply RoPE
+         if cos is not None and sin is not None:
+             q, k = apply_rotary_emb(q, k, cos, sin)
+
+         # KV-Cache for generation
+         if kv_cache is not None:
+             k_prev, v_prev = kv_cache
+             k = torch.cat([k_prev, k], dim=2)
+             v = torch.cat([v_prev, v], dim=2)
+
+         new_cache = (k, v) if use_cache else None
+
+         # Attention
+         if self.flash_attn and not use_cache:
+             # Use PyTorch's efficient SDPA
+             attn_out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask=mask,
+                 dropout_p=self.dropout if self.training else 0.0,
+                 is_causal=True if mask is None else False,
+             )
+         else:
+             # Manual attention for compatibility / KV-cache
+             scale = 1.0 / math.sqrt(self.head_dim)
+             scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+             if mask is not None:
+                 scores = scores.masked_fill(mask == 0, float("-inf"))
+             else:
+                 # Causal mask
+                 T_q, T_k = q.size(2), k.size(2)
+                 causal = torch.tril(torch.ones(T_q, T_k, device=x.device, dtype=torch.bool))
+                 # For KV-cache, the causal mask must align with key length
+                 causal = causal[-T:, :]  # last T rows
+                 scores = scores.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+             attn_weights = F.softmax(scores, dim=-1)
+             attn_weights = self.attn_dropout(attn_weights)
+             attn_out = torch.matmul(attn_weights, v)
+
+         # Reshape back and project
+         attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
+         out = self.resid_dropout(self.out_proj(attn_out))
+
+         return out, new_cache
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # FEED-FORWARD NETWORK
+ # ═══════════════════════════════════════════════════════════════════════
+
+ class FeedForward(nn.Module):
+     """Position-wise Feed-Forward Network with GELU activation."""
+
+     def __init__(self, config: GPT300MConfig):
+         super().__init__()
+         self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
+         self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+
+         if config.activation == "gelu":
+             self.act = nn.GELU()
+         elif config.activation == "swiglu":
+             self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
+             self.act = nn.SiLU()
+         else:
+             raise ValueError(f"Unknown activation: {config.activation}")
+
+         self.use_swiglu = config.activation == "swiglu"
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.use_swiglu:
+             return self.dropout(self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)))
+         else:
+             return self.dropout(self.down_proj(self.act(self.up_proj(x))))
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # TRANSFORMER BLOCK
+ # ═══════════════════════════════════════════════════════════════════════
+
+ class TransformerBlock(nn.Module):
+     """Pre-norm Transformer block: RMSNorm → Attention → Residual → RMSNorm → FFN → Residual."""
+
+     def __init__(self, config: GPT300MConfig, layer_idx: int):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.ln1 = RMSNorm(config.d_model, eps=config.norm_eps)
+         self.attn = MultiHeadAttention(config)
+         self.ln2 = RMSNorm(config.d_model, eps=config.norm_eps)
+         self.ffn = FeedForward(config)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         mask: Optional[torch.Tensor] = None,
+         kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         # Pre-norm attention with residual
+         residual = x
+         x = self.ln1(x)
+         attn_out, new_cache = self.attn(x, cos, sin, mask, kv_cache, use_cache)
+         x = residual + attn_out
+
+         # Pre-norm FFN with residual
+         x = x + self.ffn(self.ln2(x))
+
+         return x, new_cache
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # GPT-300M: THE FULL MODEL
+ # ═══════════════════════════════════════════════════════════════════════
+
+ class GPT300M(nn.Module):
+     """
+     GPT-300M: A 300-million parameter autoregressive language model.
+
+     Architecture:
+         Token Embedding → [Transformer Block × 24] → RMSNorm → LM Head
+
+     Each Transformer Block:
+         RMSNorm → Multi-Head Attention (+ RoPE) → Residual
+         → RMSNorm → Feed-Forward (GELU) → Residual
+     """
+
+     def __init__(self, config: GPT300MConfig):
+         super().__init__()
+         self.config = config
+
+         # ── Embeddings ───────────────────────────────────────────────
+         self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
+         self.drop = nn.Dropout(config.dropout)
+
+         # Rotary embeddings
+         if config.rope:
+             self.rotary = RotaryEmbedding(
+                 config.head_dim, config.max_seq_len, config.rope_theta
+             )
+         else:
+             self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
+
+         # ── Transformer Blocks ───────────────────────────────────────
+         self.layers = nn.ModuleList([
+             TransformerBlock(config, layer_idx=i)
+             for i in range(config.n_layers)
+         ])
+
+         # ── Output ───────────────────────────────────────────────────
+         self.ln_f = RMSNorm(config.d_model, eps=config.norm_eps)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+         # Weight tying
+         if config.tie_weights:
+             self.lm_head.weight = self.token_emb.weight
+
+         # Initialize weights
+         self.apply(self._init_weights)
+         # Scale residual projections
+         for pn, p in self.named_parameters():
+             if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
+                 nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))
+
+     def _init_weights(self, module: nn.Module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         targets: Optional[torch.Tensor] = None,
+         kv_caches: Optional[list] = None,
+         use_cache: bool = False,
+         position_offset: int = 0,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list]]:
+         """
+         Forward pass.
+
+         Args:
+             input_ids: [B, T] token indices
+             targets: [B, T] target token indices for loss computation
+             kv_caches: List of KV-cache tuples, one per layer
+             use_cache: Whether to return updated KV-caches
+             position_offset: Offset for position embeddings (for KV-cache generation)
+
+         Returns:
+             logits: [B, T, vocab_size]
+             loss: scalar loss if targets provided, else None
+             new_caches: Updated KV-caches if use_cache=True
+         """
+         B, T = input_ids.shape
+         assert T <= self.config.max_seq_len, (
+             f"Sequence length {T} exceeds max {self.config.max_seq_len}"
+         )
+
+         # Token embeddings
+         x = self.token_emb(input_ids)  # [B, T, d_model]
+
+         # Position information
+         if self.config.rope:
+             cos, sin = self.rotary(T, offset=position_offset)
+         else:
+             positions = torch.arange(position_offset, position_offset + T, device=input_ids.device)
+             x = x + self.pos_emb(positions)
+             cos, sin = None, None
+
+         x = self.drop(x)
+
+         # Transformer blocks
+         new_caches = [] if use_cache else None
+         for i, layer in enumerate(self.layers):
+             cache_i = kv_caches[i] if kv_caches is not None else None
+             x, new_cache = layer(x, cos, sin, kv_cache=cache_i, use_cache=use_cache)
+             if use_cache:
+                 new_caches.append(new_cache)
+
+         # Final norm and LM head
+         x = self.ln_f(x)
+         logits = self.lm_head(x)  # [B, T, vocab_size]
+
+         # Loss
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, self.config.vocab_size),
+                 targets.view(-1),
+                 ignore_index=self.config.pad_token_id,
+             )
+
+         return logits, loss, new_caches
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 256,
+         temperature: float = 0.7,
+         top_k: int = 50,
+         top_p: float = 0.9,
+         repetition_penalty: float = 1.1,
+         eos_token_id: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Autoregressive generation with KV-cache.
+
+         Args:
+             input_ids: [B, T] prompt token IDs
+             max_new_tokens: Maximum number of tokens to generate
+             temperature: Sampling temperature
+             top_k: Top-k sampling
+             top_p: Nucleus sampling threshold
+             repetition_penalty: Penalty for repeated tokens
+             eos_token_id: Stop generation when this token is produced
+
+         Returns:
+             [B, T + max_new_tokens] generated token IDs
+         """
+         self.eval()
+         B, T = input_ids.shape
+         device = input_ids.device
+
+         # Initial forward pass to populate KV-cache
+         logits, _, kv_caches = self.forward(input_ids, use_cache=True)
+
+         generated = input_ids
+         all_token_ids = input_ids.tolist()[0] if B == 1 else []
+
+         for step in range(max_new_tokens):
+             # Get logits for the last token
+             next_logits = logits[:, -1, :]  # [B, vocab_size]
+
+             # Repetition penalty
+             if repetition_penalty != 1.0 and B == 1:
+                 for token_id in set(all_token_ids):
+                     if next_logits[0, token_id] > 0:
+                         next_logits[0, token_id] /= repetition_penalty
+                     else:
+                         next_logits[0, token_id] *= repetition_penalty
+
+             # Temperature
+             if temperature > 0:
+                 next_logits = next_logits / temperature
+
+                 # Top-k filtering
+                 if top_k > 0:
+                     topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                     next_logits[next_logits < topk_vals[:, -1:]] = float("-inf")
+
+                 # Top-p (nucleus) filtering
+                 if top_p < 1.0:
+                     sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
+                     cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                     # Shifted mask so the first token above the threshold is kept
+                     sorted_mask = cumprobs - F.softmax(sorted_logits, dim=-1) >= top_p
+                     sorted_logits[sorted_mask] = float("-inf")
+                     # Un-sort back to vocabulary order (sorted_idx is a full
+                     # permutation, so every position is overwritten)
+                     next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
+
+                 probs = F.softmax(next_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 # Greedy
+                 next_token = next_logits.argmax(dim=-1, keepdim=True)
+
+             generated = torch.cat([generated, next_token], dim=1)
+
+             if B == 1:
+                 all_token_ids.append(next_token.item())
+
+             # Stop on EOS
+             if eos_token_id is not None and next_token.item() == eos_token_id:
+                 break
+
+             # Forward pass with KV-cache (only the new token)
+             position_offset = generated.size(1) - 1
+             logits, _, kv_caches = self.forward(
+                 next_token,
+                 kv_caches=kv_caches,
+                 use_cache=True,
+                 position_offset=position_offset,
+             )
+
+         return generated
+
+     def count_parameters(self, trainable_only: bool = True) -> int:
+         """Count model parameters."""
+         if trainable_only:
+             return sum(p.numel() for p in self.parameters() if p.requires_grad)
+         return sum(p.numel() for p in self.parameters())
+
+     def model_summary(self) -> str:
+         """Return a human-readable model summary."""
+         total = self.count_parameters(trainable_only=False)
+         trainable = self.count_parameters(trainable_only=True)
+         lines = [
+             "=" * 60,
+             " GPT-300M Model Summary",
+             "=" * 60,
+             f" Total parameters:     {total:>15,}",
+             f" Trainable parameters: {trainable:>15,}",
+             f" d_model:              {self.config.d_model:>15}",
+             f" n_heads:              {self.config.n_heads:>15}",
+             f" n_layers:             {self.config.n_layers:>15}",
+             f" d_ff:                 {self.config.d_ff:>15}",
+             f" vocab_size:           {self.config.vocab_size:>15}",
+             f" max_seq_len:          {self.config.max_seq_len:>15}",
+             f" RoPE:                 {'Yes':>15}",
+             f" Weight tying:         {'Yes' if self.config.tie_weights else 'No':>15}",
+             f" Flash Attention:      {'Yes' if self.layers[0].attn.flash_attn else 'No':>15}",
+             "=" * 60,
+         ]
+         return "\n".join(lines)
+
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # QUICK TEST
+ # ═══════════════════════════════════════════════════════════════════════
+
+ if __name__ == "__main__":
+     from config import gpt_tiny
+
+     # Use tiny config for testing
+     cfg = gpt_tiny()
+     model = GPT300M(cfg)
+     print(model.model_summary())
+
+     # Test forward pass
+     x = torch.randint(0, cfg.vocab_size, (2, 32))
+     targets = torch.randint(0, cfg.vocab_size, (2, 32))
+     logits, loss, _ = model(x, targets=targets)
+     print(f"\nForward pass OK: logits={logits.shape}, loss={loss.item():.4f}")
+
+     # Test generation
+     prompt = torch.randint(0, cfg.vocab_size, (1, 8))
+     gen = model.generate(prompt, max_new_tokens=16, temperature=0.8)
+     print(f"Generation OK: {gen.shape}")
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch>=2.0.0
+ numpy>=1.24.0
+ matplotlib>=3.7.0
special_tokens_map.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "pad_token": "<pad>",
+   "bos_token": "<bos>",
+   "eos_token": "<eos>",
+   "unk_token": "<unk>",
+   "additional_special_tokens": [
+     "<|system|>",
+     "<|user|>",
+     "<|assistant|>",
+     "<|end|>"
+   ]
+ }
tokenizer.py ADDED
@@ -0,0 +1,344 @@
1
+ """
2
+ Byte-Pair Encoding (BPE) Tokenizer β€” Built From Scratch
3
+ ========================================================
4
+ A minimal but complete BPE tokenizer implementation.
5
+ Supports training from raw text, encoding/decoding, and special chat tokens.
6
+
7
+ For production use, you'd typically use SentencePiece or tiktoken,
8
+ but this demonstrates the full tokenizer pipeline.
9
+ """
10
+
11
+ import json
12
+ import os
13
+ import re
14
+ from collections import Counter
15
+ from typing import Dict, List, Optional, Tuple
16
+
17
+
18
+ class BPETokenizer:
19
+ """
20
+ Byte-Pair Encoding tokenizer with special token support.
21
+
22
+ Special tokens:
23
+ <pad> = 0 Padding token
24
+ <bos> = 1 Beginning of sequence
25
+ <eos> = 2 End of sequence
26
+ <unk> = 3 Unknown token
27
+ <|system|> = 4 System prompt delimiter
28
+ <|user|> = 5 User turn delimiter
29
+ <|assistant|> = 6 Assistant turn delimiter
30
+ <|end|> = 7 End of turn
31
+ """
32
+
33
+ SPECIAL_TOKENS = {
34
+ "<pad>": 0,
35
+ "<bos>": 1,
36
+ "<eos>": 2,
37
+ "<unk>": 3,
38
+ "<|system|>": 4,
39
+ "<|user|>": 5,
40
+ "<|assistant|>": 6,
41
+ "<|end|>": 7,
42
+ }
43
+
44
+ # Pre-tokenization regex (GPT-2 style; \w stands in for \p{L}/\p{N}, which re lacks)
45
+ PAT = re.compile(
46
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""",
47
+ re.UNICODE,
48
+ )
49
+
50
+ def __init__(self, vocab_size: int = 32_000):
51
+ self.target_vocab_size = vocab_size
52
+ self.special_tokens = dict(self.SPECIAL_TOKENS)
53
+ self.num_special = len(self.special_tokens)
54
+
55
+ # Byte-level base vocab: map each byte (0-255) to a token ID
56
+ self.byte_to_id: Dict[int, int] = {
57
+ b: b + self.num_special for b in range(256)
58
+ }
59
+ self.id_to_byte: Dict[int, int] = {v: k for k, v in self.byte_to_id.items()}
60
+
61
+ # Merge rules learned during training
62
+ self.merges: List[Tuple[int, int]] = []
63
+ self.merge_to_id: Dict[Tuple[int, int], int] = {}
64
+
65
+ # Full vocab (built after training)
66
+ self.vocab: Dict[int, bytes] = {}
67
+ self._build_vocab()
68
+
69
+ def _build_vocab(self):
70
+ """Reconstruct the full vocabulary from merges."""
71
+ self.vocab = {}
72
+ # Special tokens
73
+ for tok, idx in self.special_tokens.items():
74
+ self.vocab[idx] = tok.encode("utf-8")
75
+ # Byte-level tokens
76
+ for b in range(256):
77
+ self.vocab[self.num_special + b] = bytes([b])
78
+ # Merged tokens
79
+ for (a, b), idx in self.merge_to_id.items():
80
+ self.vocab[idx] = self.vocab[a] + self.vocab[b]
81
+
82
+ @property
83
+ def vocab_size(self) -> int:
84
+ return len(self.vocab)
85
+
86
+ # ── Training ────────────────────────────────────────────────────
87
+
88
+ def train(self, text: str, verbose: bool = True):
89
+ """
90
+ Train BPE merges from raw text.
91
+
92
+ Args:
93
+ text: Raw training text
94
+ verbose: Print progress
95
+ """
96
+ if verbose:
97
+ print(f"Training BPE tokenizer (target vocab: {self.target_vocab_size:,})...")
98
+
99
+ # Pre-tokenize into words
100
+ words = re.findall(self.PAT, text)
101
+
102
+ # Convert each word to a tuple of byte token IDs
103
+ word_freqs: Counter = Counter()
104
+ for word in words:
105
+ byte_ids = tuple(self.byte_to_id[b] for b in word.encode("utf-8"))
106
+ word_freqs[byte_ids] += 1
107
+
108
+ current_vocab_size = self.num_special + 256
109
+ num_merges = self.target_vocab_size - current_vocab_size
110
+
111
+ for i in range(num_merges):
112
+ # Count adjacent pairs
113
+ pair_counts: Counter = Counter()
114
+ for word, freq in word_freqs.items():
115
+ for j in range(len(word) - 1):
116
+ pair_counts[(word[j], word[j + 1])] += freq
117
+
118
+ if not pair_counts:
119
+ break
120
+
121
+ # Find most frequent pair
122
+ best_pair = pair_counts.most_common(1)[0][0]
123
+ new_id = current_vocab_size
124
+
125
+ # Register merge
126
+ self.merges.append(best_pair)
127
+ self.merge_to_id[best_pair] = new_id
128
+
129
+ # Apply merge to all words
130
+ new_word_freqs: Counter = Counter()
131
+ for word, freq in word_freqs.items():
132
+ new_word = self._apply_merge(word, best_pair, new_id)
133
+ new_word_freqs[new_word] += freq
134
+ word_freqs = new_word_freqs
135
+
136
+ current_vocab_size += 1
137
+
138
+ if verbose and (i + 1) % 1000 == 0:
139
+ print(f" Merge {i + 1}/{num_merges}: "
140
+ f"({best_pair[0]}, {best_pair[1]}) β†’ {new_id}, "
141
+ f"freq={pair_counts[best_pair]}")
142
+
143
+ self._build_vocab()
144
+ if verbose:
145
+ print(f"Done! Final vocab size: {self.vocab_size:,}")
146
+
147
+ @staticmethod
148
+ def _apply_merge(
149
+ word: Tuple[int, ...], pair: Tuple[int, int], new_id: int
150
+ ) -> Tuple[int, ...]:
151
+ """Apply a single merge rule to a word."""
152
+ result = []
153
+ i = 0
154
+ while i < len(word):
155
+ if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
156
+ result.append(new_id)
157
+ i += 2
158
+ else:
159
+ result.append(word[i])
160
+ i += 1
161
+ return tuple(result)
162
+
163
+ # ── Encoding ────────────────────────────────────────────────────
164
+
165
+ def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
166
+ """
167
+ Encode text to token IDs.
168
+
169
+ Args:
170
+ text: Input text
171
+ add_special_tokens: Whether to wrap with <bos>/<eos>
172
+
173
+ Returns:
174
+ List of token IDs
175
+ """
176
+ tokens = []
177
+
178
+ # Check for special tokens in the text
179
+ parts = self._split_special_tokens(text)
180
+
181
+ for part, is_special in parts:
182
+ if is_special:
183
+ tokens.append(self.special_tokens[part])
184
+ else:
185
+ # Pre-tokenize
186
+ words = re.findall(self.PAT, part)
187
+ for word in words:
188
+ # Convert to byte IDs
189
+ byte_ids = list(self.byte_to_id[b] for b in word.encode("utf-8"))
190
+ # Apply merges in order
191
+ for pair, new_id in zip(self.merges, range(self.num_special + 256, self.vocab_size)):
192
+ i = 0
193
+ while i < len(byte_ids) - 1:
194
+ if (byte_ids[i], byte_ids[i + 1]) == pair:
195
+ byte_ids[i] = new_id
196
+ del byte_ids[i + 1]
197
+ else:
198
+ i += 1
199
+ tokens.extend(byte_ids)
200
+
201
+ if add_special_tokens:
202
+ tokens = [self.special_tokens["<bos>"]] + tokens + [self.special_tokens["<eos>"]]
203
+
204
+ return tokens
205
+
206
+ def _split_special_tokens(self, text: str) -> List[Tuple[str, bool]]:
207
+ """Split text on special token boundaries."""
208
+ # Build regex to match special tokens
209
+ pattern = "|".join(re.escape(tok) for tok in self.special_tokens.keys())
210
+ if not pattern:
211
+ return [(text, False)]
212
+
213
+ parts = []
214
+ last_end = 0
215
+ for match in re.finditer(pattern, text):
216
+ if match.start() > last_end:
217
+ parts.append((text[last_end:match.start()], False))
218
+ parts.append((match.group(), True))
219
+ last_end = match.end()
220
+ if last_end < len(text):
221
+ parts.append((text[last_end:], False))
222
+ return parts
223
+
224
+ # ── Decoding ────────────────────────────────────────────────────
225
+
226
+ def decode(self, ids: List[int], skip_special: bool = True) -> str:
227
+ """
228
+ Decode token IDs to text.
229
+
230
+ Args:
231
+ ids: List of token IDs
232
+ skip_special: Whether to skip special tokens
233
+
234
+ Returns:
235
+ Decoded text string
236
+ """
237
+ byte_chunks = []
238
+ for idx in ids:
239
+ if idx in self.special_tokens.values():
240
+ if not skip_special:
241
+ # Find the special token string
242
+ for tok, tid in self.special_tokens.items():
243
+ if tid == idx:
244
+ byte_chunks.append(tok.encode("utf-8"))
245
+ break
246
+ elif idx in self.vocab:
247
+ byte_chunks.append(self.vocab[idx])
248
+ return b"".join(byte_chunks).decode("utf-8", errors="replace")
249
+
250
+ # ── Chat Formatting ─────────────────────────────────────────────
251
+
252
+ def encode_chat(
253
+ self,
254
+ messages: List[Dict[str, str]],
255
+ add_generation_prompt: bool = True,
256
+ ) -> List[int]:
257
+ """
258
+ Encode a chat conversation into token IDs.
259
+
260
+ Args:
261
+ messages: List of {"role": "system"|"user"|"assistant", "content": "..."}
262
+ add_generation_prompt: Add the assistant turn start token at the end
263
+
264
+ Returns:
265
+ List of token IDs
266
+ """
267
+ tokens = [self.special_tokens["<bos>"]]
268
+
269
+ for msg in messages:
270
+ role = msg["role"]
271
+ content = msg["content"]
272
+
273
+ if role == "system":
274
+ tokens.append(self.special_tokens["<|system|>"])
275
+ elif role == "user":
276
+ tokens.append(self.special_tokens["<|user|>"])
277
+ elif role == "assistant":
278
+ tokens.append(self.special_tokens["<|assistant|>"])
279
+
280
+ tokens.extend(self.encode(content))
281
+ tokens.append(self.special_tokens["<|end|>"])
282
+
283
+ if add_generation_prompt:
284
+ tokens.append(self.special_tokens["<|assistant|>"])
285
+
286
+ return tokens
287
+
288
+ # ── Save / Load ─────────────────────────────────────────────────
289
+
290
+ def save(self, path: str):
291
+ """Save tokenizer to JSON."""
292
+ os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
293
+ data = {
294
+ "target_vocab_size": self.target_vocab_size,
295
+ "merges": self.merges,
296
+ }
297
+ with open(path, "w") as f:
298
+ json.dump(data, f)
299
+
300
+ @classmethod
301
+ def load(cls, path: str) -> "BPETokenizer":
302
+ """Load tokenizer from JSON."""
303
+ with open(path) as f:
304
+ data = json.load(f)
305
+ tok = cls(vocab_size=data["target_vocab_size"])
306
+ tok.merges = [tuple(m) for m in data["merges"]]
307
+ tok.merge_to_id = {
308
+ tuple(pair): idx
309
+ for idx, pair in enumerate(tok.merges, start=tok.num_special + 256)
310
+ }
311
+ tok._build_vocab()
312
+ return tok
313
+
314
+
315
+ # ═══════════════════════════════════════════════════════════════════════
316
+ # QUICK TEST
317
+ # ═══════════════════════════════════════════════════════════════════════
318
+
319
+ if __name__ == "__main__":
320
+ tok = BPETokenizer(vocab_size=500)
321
+
322
+ sample = (
323
+ "Hello, world! This is a test of the BPE tokenizer. "
324
+ "The quick brown fox jumps over the lazy dog. "
325
+ "Machine learning is fascinating and powerful. " * 20
326
+ )
327
+
328
+ tok.train(sample, verbose=True)
329
+
330
+ text = "Hello, world! Machine learning is great."
331
+ ids = tok.encode(text)
332
+ decoded = tok.decode(ids)
333
+ print(f"\nOriginal: {text}")
334
+ print(f"Token IDs: {ids[:20]}...")
335
+ print(f"Decoded: {decoded}")
336
+
337
+ # Test chat encoding
338
+ chat = [
339
+ {"role": "system", "content": "You are helpful."},
340
+ {"role": "user", "content": "Hello!"},
341
+ ]
342
+ chat_ids = tok.encode_chat(chat)
343
+ print(f"\nChat IDs: {chat_ids[:20]}...")
344
+ print(f"Chat decoded: {tok.decode(chat_ids, skip_special=False)}")
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "tokenizer_class": "BPETokenizer",
3
+ "model_type": "gpt-300m",
4
+ "vocab_size": 32000,
5
+ "model_max_length": 2048,
6
+ "padding_side": "right",
7
+ "bos_token": "<bos>",
8
+ "eos_token": "<eos>",
9
+ "pad_token": "<pad>",
10
+ "unk_token": "<unk>"
11
+ }
train.py ADDED
@@ -0,0 +1,456 @@
1
+ """
2
+ GPT-300M Training Script
3
+ =========================
4
+ Full training pipeline with:
5
+ - Mixed-precision training (bf16/fp16)
6
+ - Gradient accumulation
7
+ - Cosine learning rate schedule with warmup
8
+ - Gradient clipping
9
+ - Periodic evaluation & checkpointing
10
+ - Distributed Data Parallel (DDP) support
11
+ - Weights & Biases logging
12
+ - torch.compile support
13
+
14
+ Usage:
15
+ # Single GPU
16
+ python train.py
17
+
18
+ # Multi-GPU with DDP
19
+ torchrun --nproc_per_node=4 train.py
20
+
21
+ # With custom config
22
+ python train.py --d_model 768 --n_layers 12 --batch_size 64
23
+ """
24
+
25
+ import argparse
26
+ import math
27
+ import os
28
+ import sys
29
+ import time
30
+ from contextlib import nullcontext
31
+ from typing import Optional
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.distributed as dist
36
+ from torch.nn.parallel import DistributedDataParallel as DDP
37
+
38
+ from config import GPT300MConfig, gpt_300m, gpt_tiny
39
+ from model import GPT300M
40
+ from tokenizer import BPETokenizer
41
+ from dataset import TextDataset, ChatDataset, create_dataloaders, collate_fn
42
+
43
+
44
+ # ═══════════════════════════════════════════════════════════════════════
45
+ # LEARNING RATE SCHEDULER
46
+ # ═══════════════════════════════════════════════════════════════════════
47
+
48
+ def get_lr(step: int, config: GPT300MConfig) -> float:
49
+ """Cosine decay with linear warmup."""
50
+ # Linear warmup
51
+ if step < config.warmup_steps:
52
+ return config.learning_rate * (step + 1) / config.warmup_steps  # +1 avoids lr=0 at step 0
53
+
54
+ # Cosine decay
55
+ if step > config.max_steps:
56
+ return config.min_learning_rate
57
+
58
+ decay_ratio = (step - config.warmup_steps) / (config.max_steps - config.warmup_steps)
59
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
60
+ return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate)
61
+
62
+
63
+ # ═══════════════════════════════════════════════════════════════════════
64
+ # TRAINING LOOP
65
+ # ═══════════════════════════════════════════════════════════════════════
66
+
67
+ class Trainer:
68
+ """
69
+ Full-featured training loop for GPT-300M.
70
+ """
71
+
72
+ def __init__(self, config: GPT300MConfig, resume_from: Optional[str] = None):
73
+ self.config = config
74
+ self.setup_distributed()
75
+ self.setup_device()
76
+ self.setup_model()
77
+ self.setup_optimizer()
78
+ self.global_step = 0
79
+ self.best_val_loss = float("inf")
80
+
81
+ if resume_from:
82
+ self.load_checkpoint(resume_from)
83
+
84
+ def setup_distributed(self):
85
+ """Setup DDP if running in distributed mode."""
86
+ self.ddp = int(os.environ.get("RANK", -1)) != -1
87
+ if self.ddp:
88
+ dist.init_process_group(backend="nccl")
89
+ self.ddp_rank = int(os.environ["RANK"])
90
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
91
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"])
92
+ self.master_process = self.ddp_rank == 0
93
+ else:
94
+ self.ddp_rank = 0
95
+ self.ddp_local_rank = 0
96
+ self.ddp_world_size = 1
97
+ self.master_process = True
98
+
99
+ def setup_device(self):
100
+ """Configure device and mixed precision."""
101
+ cfg = self.config
102
+
103
+ if cfg.device == "auto":
104
+ if torch.cuda.is_available():
105
+ self.device = f"cuda:{self.ddp_local_rank}" if self.ddp else "cuda"
106
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
107
+ self.device = "mps"
108
+ else:
109
+ self.device = "cpu"
110
+ else:
111
+ self.device = cfg.device
112
+
113
+ # Mixed precision context
114
+ if "cuda" in self.device:
115
+ if cfg.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
116
+ self.dtype = torch.bfloat16
117
+ self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
118
+ elif cfg.dtype == "float16":
119
+ self.dtype = torch.float16
120
+ self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
121
+ else:
122
+ self.dtype = torch.float32
123
+ self.amp_ctx = nullcontext()
124
+ self.scaler = torch.amp.GradScaler("cuda", enabled=(cfg.dtype == "float16"))
125
+ else:
126
+ self.dtype = torch.float32
127
+ self.amp_ctx = nullcontext()
128
+ self.scaler = torch.amp.GradScaler(enabled=False)
129
+
130
+ if self.master_process:
131
+ print(f"Device: {self.device}, dtype: {cfg.dtype}")
132
+
133
+ def setup_model(self):
134
+ """Initialize or load model."""
135
+ self.model = GPT300M(self.config).to(self.device)
136
+
137
+ if self.master_process:
138
+ print(self.model.model_summary())
139
+
140
+ # Compile model (PyTorch 2.0+)
141
+ if self.config.compile_model and hasattr(torch, "compile"):
142
+ if self.master_process:
143
+ print("Compiling model with torch.compile...")
144
+ self.model = torch.compile(self.model)
145
+
146
+ # Wrap in DDP
147
+ if self.ddp:
148
+ self.model = DDP(self.model, device_ids=[self.ddp_local_rank])
149
+
150
+ self.raw_model = self.model.module if self.ddp else self.model
151
+
152
+ def setup_optimizer(self):
153
+ """Configure AdamW optimizer with weight decay."""
154
+ cfg = self.config
155
+
156
+ # Separate parameters: decay vs no-decay
157
+ decay_params = []
158
+ nodecay_params = []
159
+ for name, param in self.raw_model.named_parameters():
160
+ if not param.requires_grad:
161
+ continue
162
+ if param.dim() >= 2:
163
+ decay_params.append(param)
164
+ else:
165
+ nodecay_params.append(param)
166
+
167
+ optim_groups = [
168
+ {"params": decay_params, "weight_decay": cfg.weight_decay},
169
+ {"params": nodecay_params, "weight_decay": 0.0},
170
+ ]
171
+
172
+ # Use fused AdamW if available (faster on CUDA)
173
+ use_fused = "cuda" in self.device and hasattr(torch.optim, "_multi_tensor")
174
+ self.optimizer = torch.optim.AdamW(
175
+ optim_groups,
176
+ lr=cfg.learning_rate,
177
+ betas=(cfg.beta1, cfg.beta2),
178
+ fused="cuda" in self.device,
179
+ )
180
+
181
+ if self.master_process:
182
+ n_decay = sum(p.numel() for p in decay_params)
183
+ n_nodecay = sum(p.numel() for p in nodecay_params)
184
+ print(f"Optimizer: {n_decay:,} decay params, {n_nodecay:,} no-decay params")
185
+
186
+ @torch.no_grad()
187
+ def evaluate(self, val_loader) -> float:
188
+ """Run evaluation and return average loss."""
189
+ self.model.eval()
190
+ total_loss = 0.0
191
+ n_batches = 0
192
+
193
+ for x, y in val_loader:
194
+ x, y = x.to(self.device), y.to(self.device)
195
+ with self.amp_ctx:
196
+ _, loss, _ = self.model(x, targets=y)
197
+ total_loss += loss.item()
198
+ n_batches += 1
199
+
200
+ if n_batches >= 50: # Limit eval batches
201
+ break
202
+
203
+ self.model.train()
204
+ return total_loss / max(n_batches, 1)
205
+
206
+ def save_checkpoint(self, path: Optional[str] = None):
207
+ """Save model checkpoint."""
208
+ if not self.master_process:
209
+ return
210
+
211
+ if path is None:
212
+ path = os.path.join(
213
+ self.config.output_dir,
214
+ f"checkpoint_step_{self.global_step}.pt",
215
+ )
216
+
217
+ os.makedirs(os.path.dirname(path), exist_ok=True)
218
+ checkpoint = {
219
+ "model_state_dict": self.raw_model.state_dict(),
220
+ "optimizer_state_dict": self.optimizer.state_dict(),
221
+ "config": self.config.__dict__,
222
+ "global_step": self.global_step,
223
+ "best_val_loss": self.best_val_loss,
224
+ }
225
+ torch.save(checkpoint, path)
226
+ print(f" Saved checkpoint: {path}")
227
+
228
+ def load_checkpoint(self, path: str):
229
+ """Load model checkpoint."""
230
+ checkpoint = torch.load(path, map_location=self.device)
231
+ self.raw_model.load_state_dict(checkpoint["model_state_dict"])
232
+ self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
233
+ self.global_step = checkpoint.get("global_step", 0)
234
+ self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
235
+ if self.master_process:
236
+ print(f"Resumed from step {self.global_step}")
237
+
238
+ def train(self, train_loader, val_loader):
239
+ """
240
+ Main training loop.
241
+ """
242
+ cfg = self.config
243
+ model = self.model
244
+ optimizer = self.optimizer
245
+
246
+ model.train()
247
+ train_iter = iter(train_loader)
248
+
249
+ if self.master_process:
250
+ print(f"\n{'='*60}")
251
+ print(f" Starting training")
252
+ print(f" Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps * self.ddp_world_size}")
253
+ print(f" Max steps: {cfg.max_steps:,}")
254
+ print(f"{'='*60}\n")
255
+
256
+ t0 = time.time()
257
+
258
+ for step in range(self.global_step, cfg.max_steps):
259
+ self.global_step = step
260
+
261
+ # Update learning rate
262
+ lr = get_lr(step, cfg)
263
+ for param_group in optimizer.param_groups:
264
+ param_group["lr"] = lr
265
+
266
+ # ── Gradient Accumulation Loop ──────────────────────────
267
+ optimizer.zero_grad(set_to_none=True)
268
+ accumulated_loss = 0.0
269
+
270
+ for micro_step in range(cfg.gradient_accumulation_steps):
271
+ # Get next batch (cycle through data)
272
+ try:
273
+ x, y = next(train_iter)
274
+ except StopIteration:
275
+ train_iter = iter(train_loader)
276
+ x, y = next(train_iter)
277
+
278
+ x, y = x.to(self.device), y.to(self.device)
279
+
280
+ # DDP sync only on last micro-step
281
+ if self.ddp:
282
+ model.require_backward_grad_sync = (
283
+ micro_step == cfg.gradient_accumulation_steps - 1
284
+ )
285
+
286
+ # Forward pass with mixed precision
287
+ with self.amp_ctx:
288
+ _, loss, _ = model(x, targets=y)
289
+ loss = loss / cfg.gradient_accumulation_steps
290
+
291
+ accumulated_loss += loss.item()
292
+
293
+ # Backward pass
294
+ self.scaler.scale(loss).backward()
295
+
296
+ # Gradient clipping
297
+ if cfg.max_grad_norm > 0:
298
+ self.scaler.unscale_(optimizer)
299
+ grad_norm = nn.utils.clip_grad_norm_(
300
+ model.parameters(), cfg.max_grad_norm
301
+ )
302
+ else:
303
+ grad_norm = 0.0
304
+
305
+ # Optimizer step
306
+ self.scaler.step(optimizer)
307
+ self.scaler.update()
308
+
309
+ # ── Logging ─────────────────────────────────────────────
310
+ if step % cfg.log_interval == 0 and self.master_process:
311
+ dt = time.time() - t0
312
+ tokens_per_sec = (
313
+ cfg.batch_size * cfg.max_seq_len
314
+ * cfg.gradient_accumulation_steps
315
+ * self.ddp_world_size
316
+ * max(cfg.log_interval, 1)  # dt spans log_interval steps between logs
+ / dt
317
+ )
318
+ print(
319
+ f"step {step:>6d} | "
320
+ f"loss {accumulated_loss:.4f} | "
321
+ f"lr {lr:.2e} | "
322
+ f"grad_norm {grad_norm:.2f} | "
323
+ f"tok/s {tokens_per_sec:.0f} | "
324
+ f"dt {dt:.2f}s"
325
+ )
326
+ t0 = time.time()
327
+
328
+ # ── Evaluation ──────────────────────────────────────────
329
+ if step > 0 and step % cfg.eval_interval == 0 and self.master_process:
330
+ val_loss = self.evaluate(val_loader)
331
+ print(f" ✦ Validation loss: {val_loss:.4f}")
332
+
333
+ if val_loss < self.best_val_loss:
334
+ self.best_val_loss = val_loss
335
+ self.save_checkpoint(
336
+ os.path.join(cfg.output_dir, "best_model.pt")
337
+ )
338
+ print(f" ✦ New best! Saved best_model.pt")
339
+
340
+ # ── Checkpointing ───────────────────────────────────────
341
+ if step > 0 and step % cfg.save_interval == 0 and self.master_process:
342
+ self.save_checkpoint()
343
+
344
+ # Final save
345
+ if self.master_process:
346
+ self.save_checkpoint(
347
+ os.path.join(cfg.output_dir, "final_model.pt")
348
+ )
349
+ print("\n✦ Training complete!")
350
+
351
+ # Cleanup DDP
352
+ if self.ddp:
353
+ dist.destroy_process_group()
354
+
355
+
356
+ # ═══════════════════════════════════════════════════════════════════════
357
+ # MAIN
358
+ # ═══════════════════════════════════════════════════════════════════════
359
+
360
+ def main():
361
+ parser = argparse.ArgumentParser(description="Train GPT-300M")
362
+ parser.add_argument("--tiny", action="store_true", help="Use tiny config for debugging")
363
+ parser.add_argument("--data", type=str, default=None, help="Path to training text file")
364
+ parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
365
+ parser.add_argument("--d_model", type=int, default=None)
366
+ parser.add_argument("--n_layers", type=int, default=None)
367
+ parser.add_argument("--n_heads", type=int, default=None)
368
+ parser.add_argument("--batch_size", type=int, default=None)
369
+ parser.add_argument("--learning_rate", type=float, default=None)
370
+ parser.add_argument("--max_steps", type=int, default=None)
371
+ args = parser.parse_args()
372
+
373
+ # Config
374
+ config = gpt_tiny() if args.tiny else gpt_300m()
375
+
376
+ # Override config from CLI
377
+ for key in ["d_model", "n_layers", "n_heads", "batch_size", "learning_rate", "max_steps"]:
378
+ val = getattr(args, key, None)
379
+ if val is not None:
380
+ setattr(config, key, val)
381
+
382
+ # Seed
383
+ torch.manual_seed(config.seed)
384
+ if torch.cuda.is_available():
385
+ torch.cuda.manual_seed_all(config.seed)
386
+
387
+ # Tokenizer
388
+ tokenizer = BPETokenizer(vocab_size=config.vocab_size)
389
+
390
+ # Load data
391
+ if args.data and os.path.exists(args.data):
392
+ print(f"Loading data from {args.data}...")
393
+ with open(args.data, "r") as f:
394
+ text = f.read()
395
+ else:
396
+ # Generate synthetic data for demonstration
397
+ print("No data file provided. Generating synthetic training data...")
398
+ text = generate_synthetic_data()
399
+
400
+ # Train tokenizer on data
401
+ print("Training tokenizer...")
402
+ tokenizer.train(text, verbose=True)
403
+ tokenizer.save(os.path.join(config.output_dir, "tokenizer.json"))
404
+
405
+ # Create dataloaders
406
+ train_loader, val_loader = create_dataloaders(config, tokenizer, text=text)
407
+ print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
408
+
409
+ # Train!
410
+ trainer = Trainer(config, resume_from=args.resume)
411
+ trainer.train(train_loader, val_loader)
412
+
413
+
414
+ def generate_synthetic_data(n_samples: int = 10_000) -> str:
415
+ """Generate synthetic conversational data for demonstration."""
416
+ import random
417
+ random.seed(42)
418
+
419
+ greetings = ["Hello!", "Hi there!", "Hey!", "Good morning!", "Greetings!"]
420
+ questions = [
421
+ "What is machine learning?",
422
+ "How does gravity work?",
423
+ "What is the meaning of life?",
424
+ "Can you explain photosynthesis?",
425
+ "What are neural networks?",
426
+ "How do computers work?",
427
+ "What is quantum physics?",
428
+ "Tell me about the solar system.",
429
+ "How does the internet work?",
430
+ "What is artificial intelligence?",
431
+ ]
432
+ answers = [
433
+ "That's a great question! Machine learning is a subset of AI that enables systems to learn from data.",
434
+ "Gravity is a fundamental force that attracts objects with mass toward each other.",
435
+ "The meaning of life is a deeply philosophical question that has been debated for centuries.",
436
+ "Photosynthesis is the process by which plants convert sunlight into chemical energy.",
437
+ "Neural networks are computing systems inspired by biological neural networks in the brain.",
438
+ "Computers work by processing binary data through electronic circuits called transistors.",
439
+ "Quantum physics describes the behavior of matter and energy at the atomic scale.",
440
+ "The solar system consists of the Sun and everything that orbits around it.",
441
+ "The internet is a global network of interconnected computers that communicate using protocols.",
442
+ "Artificial intelligence is the simulation of human intelligence by computer systems.",
443
+ ]
444
+
445
+ lines = []
446
+ for _ in range(n_samples):
447
+ g = random.choice(greetings)
448
+ idx = random.randrange(len(questions))
449
+ q, a = questions[idx], answers[idx]  # keep each answer aligned with its question
450
+ lines.append(f"User: {g} {q}\nAssistant: {a}\n")
451
+
452
+ return "\n".join(lines)
453
+
454
+
455
+ if __name__ == "__main__":
456
+ main()
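A quick way to eyeball the warmup-plus-cosine schedule before launching a run. This is a sketch, assuming the tiny config defines the same schedule fields (`warmup_steps`, `max_steps`, `learning_rate`, `min_learning_rate`) that `get_lr` reads:

```python
# Sketch: print the LR at a few milestones of the schedule.
from config import gpt_tiny
from train import get_lr

cfg = gpt_tiny()
for s in (0, cfg.warmup_steps // 2, cfg.warmup_steps, cfg.max_steps // 2, cfg.max_steps):
    print(f"step {s:>7d} -> lr {get_lr(s, cfg):.3e}")
# Expected shape: linear ramp to learning_rate at warmup_steps,
# then cosine decay to min_learning_rate at max_steps.
```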
visual_nn_3d.py ADDED
@@ -0,0 +1,387 @@
1
+ """
2
+ GPT-300M 3D Neural Network Visualization
3
+ ==========================================
4
+ A 3D node-and-connection neural network diagram with depth,
5
+ perspective, and accurate parameter counts.
6
+ """
7
+
8
+ import matplotlib
9
+ matplotlib.use("Agg")
10
+
11
+ import matplotlib.pyplot as plt
12
+ from mpl_toolkits.mplot3d import Axes3D
13
+ from mpl_toolkits.mplot3d.art3d import Line3DCollection
14
+ import numpy as np
15
+
16
+ # ═══════════════════════════════════════════════════════════════════════
17
+ # ACCURATE GPT-300M PARAMETERS
18
+ # ═══════════════════════════════════════════════════════════════════════
19
+
20
+ VOCAB = 32_000
21
+ D = 1_024
22
+ HEADS = 16
23
+ HEAD_D = 64
24
+ D_FF = 4_096
25
+ N_LAYERS = 24
26
+
27
+ embed_p = VOCAB * D # 32,768,000
28
+ attn_p = 4 * D * D # 4,194,304 per layer
29
+ ffn_p = 2 * D * D_FF # 8,388,608 per layer
30
+ norm_p = 2 * D # 2,048 per layer
31
+ layer_p = attn_p + ffn_p + norm_p # 12,584,960 per layer
32
+ all_layers_p = layer_p * N_LAYERS # 302,039,040
33
+ final_norm_p = D # 1,024
34
+ TOTAL = embed_p + all_layers_p + final_norm_p # 334,808,064
35
+
36
+ # Layer definitions: (name, num_display_nodes, actual_neurons, params, color_hex)
37
+ LAYERS = [
38
+ ("Input Tokens", 10, VOCAB, 0, "#4CAF50"),
39
+ ("Token Embedding", 12, D, embed_p, "#2196F3"),
40
+ ("RoPE Positions", 12, D, 0, "#00BCD4"),
41
+ ("Layer 1: Attention QKV", 14, D, attn_p * 3 // 4, "#FF9800"),
42
+ ("Layer 1: Attention Out", 12, D, attn_p * 1 // 4, "#FF9800"),
43
+ ("Layer 1: FFN Up (GELU)", 16, D_FF, ffn_p // 2, "#8BC34A"),
44
+ ("Layer 1: FFN Down", 12, D, ffn_p // 2, "#8BC34A"),
45
+ ("Layers 2–23 (Γ—22)", 14, D, layer_p * 22, "#9C27B0"),
46
+ ("Layer 24: Attention", 14, D, attn_p, "#FF5722"),
47
+ ("Layer 24: FFN", 16, D_FF, ffn_p, "#009688"),
48
+ ("Layer 24: Norm + Out", 12, D, norm_p + final_norm_p, "#E91E63"),
49
+ ("LM Head (weight-tied)", 12, VOCAB, 0, "#F44336"),
50
+ ("Output Probabilities", 1, VOCAB, 0, "#FF1744"),
51
+ ]
52
+
53
+
54
+ def hex_to_rgb(h):
55
+ h = h.lstrip("#")
56
+ return tuple(int(h[i:i+2], 16) / 255.0 for i in (0, 2, 4))
57
+
58
+
59
+ def generate_3d_network(save_path="neural_network_3d.png", elev=22, azim=-65):
60
+ """Generate a 3D neural network with nodes, connections, and parameter labels."""
61
+
62
+ fig = plt.figure(figsize=(28, 28), facecolor="#0a0e17")
63
+ ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
64
+
65
+ # Dark theme for 3D axes
66
+ ax.set_facecolor("#0a0e17")
67
+ ax.xaxis.pane.fill = False
68
+ ax.yaxis.pane.fill = False
69
+ ax.zaxis.pane.fill = False
70
+ ax.xaxis.pane.set_edgecolor("#0a0e17")
71
+ ax.yaxis.pane.set_edgecolor("#0a0e17")
72
+ ax.zaxis.pane.set_edgecolor("#0a0e17")
73
+ ax.grid(False)
74
+ ax.set_axis_off()
75
+
76
+ ax.view_init(elev=elev, azim=azim)
77
+
78
+ n_layers = len(LAYERS)
79
+ y_positions = np.linspace(0, n_layers * 4.0, n_layers) # depth (layer position)
80
+
81
+ all_positions = [] # list of (xs, ys_unused, zs, y_layer)
82
+ running_params = 0
83
+
84
+ for i, (name, n_nodes, actual, params, color_hex) in enumerate(LAYERS):
85
+ y = y_positions[i]
86
+ running_params += params
87
+
88
+ rgb = hex_to_rgb(color_hex)
89
+
90
+ # Arrange nodes in a circle/arc for 3D effect
91
+ if n_nodes == 1:
92
+ xs = np.array([0.0])
93
+ zs = np.array([0.0])
94
+ else:
95
+ # Spread nodes along x
96
+ spread = min(n_nodes * 0.5, 7.0)
97
+ xs = np.linspace(-spread, spread, n_nodes)
98
+ # Slight arc for 3D depth perception
99
+ zs = -0.1 * (xs ** 2)
100
+
101
+ ys = np.full_like(xs, y)
102
+ all_positions.append((xs, ys, zs))
103
+
104
+ # ── Draw connections to previous layer ──────────────────
105
+ if i > 0:
106
+ prev_xs, prev_ys, prev_zs = all_positions[i - 1]
107
+
108
+ # Sample connections to avoid clutter
109
+ n_prev = len(prev_xs)
110
+ n_curr = len(xs)
111
+ step_p = max(1, n_prev // 8)
112
+ step_c = max(1, n_curr // 8)
113
+
114
+ lines = []
115
+ colors_lines = []
116
+ for pi in range(0, n_prev, step_p):
117
+ for ci in range(0, n_curr, step_c):
118
+ lines.append([
119
+ (prev_xs[pi], prev_ys[pi], prev_zs[pi]),
120
+ (xs[ci], ys[ci], zs[ci]),
121
+ ])
122
+ colors_lines.append((*rgb, 0.18))
123
+
124
+ if lines:
125
+ lc = Line3DCollection(lines, colors=colors_lines, linewidths=0.7)
126
+ ax.add_collection3d(lc)
127
+
128
+ # ── Draw nodes (spheres) ────────────────────────────────
129
+ node_size = 200 if n_nodes > 12 else 280
130
+ if n_nodes == 1:
131
+ node_size = 600
132
+
133
+ ax.scatter(
134
+ xs, ys, zs,
135
+ c=[color_hex], s=node_size,
136
+ alpha=0.95, edgecolors="white", linewidths=0.5,
137
+ depthshade=True, zorder=5,
138
+ )
139
+
140
+ # ── Glow effect (larger transparent scatter behind) ─────
141
+ ax.scatter(
142
+ xs, ys, zs,
143
+ c=[color_hex], s=node_size * 3,
144
+ alpha=0.08, edgecolors="none",
145
+ depthshade=True, zorder=4,
146
+ )
147
+
148
+ # ── Labels ──────────────────────────────────────────────
149
+ label_x = xs[-1] + 1.8 if n_nodes > 1 else 2.0
150
+ ax.text(
151
+ label_x, y, 0,
152
+ name,
153
+ fontsize=9.5, fontweight="bold",
154
+ color="#E6EDF3", fontfamily="monospace",
155
+ zorder=10,
156
+ )
157
+
158
+ # Param count
159
+ if params > 0:
160
+ if params >= 1_000_000:
161
+ ptxt = f"{params/1e6:.1f}M params"
162
+ else:
163
+ ptxt = f"{params:,} params"
164
+ ax.text(
165
+ label_x, y, -1.0,
166
+ ptxt,
167
+ fontsize=8, color=color_hex,
168
+ fontfamily="monospace", fontweight="bold",
169
+ zorder=10,
170
+ )
171
+
172
+ # Running total
173
+ if running_params > 0:
174
+ ax.text(
175
+ label_x, y, -1.8,
176
+ f"Ξ£ {running_params/1e6:.1f}M",
177
+ fontsize=6, color="#8B949E",
178
+ fontfamily="monospace",
179
+ zorder=10,
180
+ )
181
+
182
+ # Overflow indicator
183
+ if actual > n_nodes and n_nodes > 1:
184
+ ax.text(
185
+ xs[-1] + 0.5, y, zs[-1],
186
+ f"(+{actual - n_nodes:,})",
187
+ fontsize=6, color="#8B949E",
188
+ fontfamily="monospace",
189
+ zorder=10,
190
+ )
191
+
192
+ # ── Title ──────────────────────────────────────────────────────
193
+ ax.text2D(
194
+ 0.5, 0.96,
195
+ "GPT-300M β€’ 3D Neural Network Architecture",
196
+ transform=fig.transFigure,
197
+ fontsize=22, fontweight="bold", color="#E6EDF3",
198
+ ha="center", fontfamily="monospace",
199
+ )
200
+ ax.text2D(
201
+ 0.5, 0.94,
202
+ f"{TOTAL:,} parameters | {N_LAYERS} layers | {HEADS} heads | d_model={D} | d_ff={D_FF}",
203
+ transform=fig.transFigure,
204
+ fontsize=10, color="#8B949E",
205
+ ha="center", fontfamily="monospace",
206
+ )
207
+
208
+ # ── Parameter summary ──────────────────────────────────────────
209
+ summary = (
210
+ f"Parameter Breakdown:\n"
211
+ f" Embedding: {embed_p/1e6:>7.1f}M ({embed_p/TOTAL*100:.1f}%)\n"
212
+ f" Attention Γ—24: {attn_p*N_LAYERS/1e6:>7.1f}M ({attn_p*N_LAYERS/TOTAL*100:.1f}%)\n"
213
+ f" FFN Γ—24: {ffn_p*N_LAYERS/1e6:>7.1f}M ({ffn_p*N_LAYERS/TOTAL*100:.1f}%)\n"
214
+ f" Norms: {(norm_p*N_LAYERS+final_norm_p)/1e6:>7.3f}M ({(norm_p*N_LAYERS+final_norm_p)/TOTAL*100:.1f}%)\n"
215
+ f" LM Head: tied (0 extra)\n"
216
+ f" ───────────────────────\n"
217
+ f" TOTAL: {TOTAL/1e6:>7.1f}M"
218
+ )
219
+ ax.text2D(
220
+ 0.02, 0.06, summary,
221
+ transform=fig.transFigure,
222
+ fontsize=8, color="#58A6FF",
223
+ fontfamily="monospace", verticalalignment="bottom",
224
+ bbox=dict(boxstyle="round,pad=0.6", facecolor="#161B22",
225
+ edgecolor="#30363D", linewidth=1),
226
+ )
227
+
228
+ # ── Legend ──────────────────────────────────────────────────────
229
+ legend_items = [
230
+ ("#4CAF50", "Input"), ("#2196F3", "Embeddings"), ("#FF9800", "Attention"),
231
+ ("#8BC34A", "FFN"), ("#9C27B0", "Γ—22 Layers"), ("#E91E63", "Norm"),
232
+ ("#F44336", "Output"),
233
+ ]
234
+ for j, (c, l) in enumerate(legend_items):
235
+ ax.text2D(
236
+ 0.92, 0.30 - j * 0.025, f"● {l}",
237
+ transform=fig.transFigure,
238
+ fontsize=8, color=c, fontfamily="monospace",
239
+ )
240
+
241
+ # Set axis limits
242
+ all_x = np.concatenate([p[0] for p in all_positions])
243
+ all_y = np.concatenate([p[1] for p in all_positions])
244
+ all_z = np.concatenate([p[2] for p in all_positions])
245
+ margin = 4
246
+ ax.set_xlim(all_x.min() - margin, all_x.max() + margin + 8)
247
+ ax.set_ylim(all_y.min() - margin, all_y.max() + margin)
248
+ ax.set_zlim(all_z.min() - margin, all_z.max() + margin)
249
+
250
+ plt.savefig(save_path, dpi=200, bbox_inches="tight",
251
+ facecolor="#0a0e17", edgecolor="none")
252
+ print(f"Saved: {save_path}")
253
+ plt.close()
254
+
255
+
256
+ def generate_3d_single_layer(save_path="layer_3d.png", elev=18, azim=-55):
257
+ """3D view of a single transformer layer internals."""
258
+
259
+ fig = plt.figure(figsize=(22, 18), facecolor="#0a0e17")
260
+ ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
261
+
262
+ ax.set_facecolor("#0a0e17")
263
+ ax.xaxis.pane.fill = False
264
+ ax.yaxis.pane.fill = False
265
+ ax.zaxis.pane.fill = False
266
+ ax.xaxis.pane.set_edgecolor("#0a0e17")
267
+ ax.yaxis.pane.set_edgecolor("#0a0e17")
268
+ ax.zaxis.pane.set_edgecolor("#0a0e17")
269
+ ax.grid(False)
270
+ ax.set_axis_off()
271
+ ax.view_init(elev=elev, azim=azim)
272
+
273
+ sub_layers = [
274
+ ("Input (d=1024)", 10, D, 0, "#2196F3"),
275
+ ("Query (d=1024)", 10, D, D*D, "#FF6B6B"),
276
+ ("Key (d=1024)", 10, D, D*D, "#4ECDC4"),
277
+ ("Value (d=1024)", 10, D, D*D, "#45B7D1"),
278
+ ("16 Attention Heads", 16, D, 0, "#FF9800"),
279
+ ("Attn Output (d=1024)", 10, D, D*D, "#FFA726"),
280
+ ("βŠ• Residual + RMSNorm", 10, D, D, "#E91E63"),
281
+ ("FFN Up β†’ GELU (d=4096)", 16, D_FF, D*D_FF, "#8BC34A"),
282
+ ("FFN Down (d=1024)", 10, D, D_FF*D, "#7CB342"),
283
+ ("βŠ• Residual + RMSNorm", 10, D, D, "#E91E63"),
284
+ ("Layer Output (d=1024)", 10, D, 0, "#2196F3"),
285
+ ]
286
+
287
+ n = len(sub_layers)
288
+ y_positions = np.linspace(0, n * 3, n)
289
+ all_pos = []
290
+
291
+ for i, (name, n_nodes, actual, params, chex) in enumerate(sub_layers):
292
+ y = y_positions[i]
293
+ rgb = hex_to_rgb(chex)
294
+
295
+ spread = min(n_nodes * 0.45, 5.5)
296
+ xs = np.linspace(-spread, spread, n_nodes)
297
+ zs = -0.12 * (xs ** 2)
298
+ ys = np.full_like(xs, y)
299
+ all_pos.append((xs, ys, zs))
300
+
301
+ # Connections
302
+ if i > 0:
303
+ pxs, pys, pzs = all_pos[i - 1]
304
+ sp = max(1, len(pxs) // 8)
305
+ sc = max(1, len(xs) // 8)
306
+ lines = []
307
+ cols = []
308
+ for pi in range(0, len(pxs), sp):
309
+ for ci in range(0, len(xs), sc):
310
+ lines.append([(pxs[pi], pys[pi], pzs[pi]), (xs[ci], ys[ci], zs[ci])])
311
+ cols.append((*rgb, 0.15))
312
+ if lines:
313
+ ax.add_collection3d(Line3DCollection(lines, colors=cols, linewidths=0.6))
314
+
315
+ # Nodes
316
+ sz = 130 if n_nodes > 12 else 180
317
+ ax.scatter(xs, ys, zs, c=[chex], s=sz, alpha=0.95,
318
+ edgecolors="white", linewidths=0.5, depthshade=True, zorder=5)
319
+ ax.scatter(xs, ys, zs, c=[chex], s=sz * 3, alpha=0.07,
320
+ edgecolors="none", depthshade=True, zorder=4)
321
+
322
+ # Labels
323
+ lx = xs[-1] + 1.0
324
+ ax.text(lx, y, 0, name, fontsize=9, fontweight="bold",
325
+ color="#E6EDF3", fontfamily="monospace", zorder=10)
326
+ if params > 0:
327
+ ax.text(lx, y, -0.8, f"{params:,} params",
328
+ fontsize=7, color=chex, fontfamily="monospace",
329
+ fontweight="bold", zorder=10)
330
+
331
+ if actual > n_nodes:
332
+ ax.text(xs[-1] + 0.4, y, zs[-1], f"(+{actual-n_nodes:,})",
333
+ fontsize=6, color="#8B949E", fontfamily="monospace", zorder=10)
334
+
335
+ ax.text2D(0.5, 0.96, "Single Transformer Layer — 3D View",
336
+ transform=fig.transFigure, fontsize=20, fontweight="bold",
337
+ color="#E6EDF3", ha="center", fontfamily="monospace")
338
+ ax.text2D(0.5, 0.935,
339
+ f"12,584,960 params/layer Γ— 24 layers = 302,039,040 total",
340
+ transform=fig.transFigure, fontsize=10, color="#8B949E",
341
+ ha="center", fontfamily="monospace")
342
+
343
+ all_x = np.concatenate([p[0] for p in all_pos])
344
+ all_y = np.concatenate([p[1] for p in all_pos])
345
+ all_z = np.concatenate([p[2] for p in all_pos])
346
+ ax.set_xlim(all_x.min() - 2, all_x.max() + 8)
347
+ ax.set_ylim(all_y.min() - 2, all_y.max() + 2)
348
+ ax.set_zlim(all_z.min() - 2, all_z.max() + 2)
349
+
350
+ plt.savefig(save_path, dpi=200, bbox_inches="tight",
351
+ facecolor="#0a0e17", edgecolor="none")
352
+ print(f"Saved: {save_path}")
353
+ plt.close()
354
+
355
+
356
+ def generate_3d_rotating_views(base_path="viz"):
357
+ """Generate multiple angle views."""
358
+ import os
359
+ os.makedirs(base_path, exist_ok=True)
360
+
361
+ # Main dramatic angle — more front-facing
362
+ generate_3d_network(f"{base_path}/nn_3d_main.png", elev=12, azim=-15)
363
+
364
+ # Angled view
365
+ generate_3d_network(f"{base_path}/nn_3d_top.png", elev=35, azim=-25)
366
+
367
+ # Side angle
368
+ generate_3d_network(f"{base_path}/nn_3d_side.png", elev=8, azim=-45)
369
+
370
+ # Single layer detail
371
+ generate_3d_single_layer(f"{base_path}/nn_3d_layer.png", elev=18, azim=-55)
372
+
373
+
374
+ if __name__ == "__main__":
375
+ import os
376
+ os.makedirs("viz", exist_ok=True)
377
+
378
+ print("=" * 55)
379
+ print(" GPT-300M β€’ 3D Visualization Generator")
380
+ print("=" * 55)
381
+ print(f" Total parameters: {TOTAL:,}")
382
+ print(f" Per layer: {layer_p:,}")
383
+ print(f" Layers: {N_LAYERS}")
384
+ print("=" * 55)
385
+
386
+ generate_3d_rotating_views("viz")
387
+ print("\nAll 3D views generated!")
visual_nn_nodes.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ GPT-300M Visual Neural Network — Node & Connection Style
3
+ ==========================================================
4
+ Generates a classic neural network diagram (like the user's reference)
5
+ with nodes and connection lines, accurately showing the GPT-300M architecture
6
+ with correct parameter calculations at each layer.
7
+ """
8
+
9
+ import matplotlib
10
+ matplotlib.use("Agg")
11
+
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+ import numpy as np
15
+
16
+ # ═══════════════════════════════════════════════════════════════════════
17
+ # GPT-300M ARCHITECTURE β€” ACCURATE PARAMETER COUNTS
18
+ # ═══════════════════════════════════════════════════════════════════════
19
+
20
+ # All layer definitions with EXACT parameter counts
21
+ # Format: (layer_name, display_nodes, actual_neurons, params_in_layer, color)
22
+
23
+ VOCAB_SIZE = 32_000
24
+ D_MODEL = 1_024
25
+ N_HEADS = 16
26
+ HEAD_DIM = 64
27
+ D_FF = 4_096
28
+ N_LAYERS = 24
29
+
30
+ # Parameter calculations per component:
31
+ embed_params = VOCAB_SIZE * D_MODEL # 32,768,000
32
+ # RoPE has no learned parameters (precomputed sin/cos)
33
+ rope_params = 0
34
+
35
+ # Per transformer layer:
36
+ qkv_params = 3 * D_MODEL * D_MODEL # 3,145,728 (Q, K, V projections)
37
+ out_proj_params = D_MODEL * D_MODEL # 1,048,576 (output projection)
38
+ attn_total = qkv_params + out_proj_params # 4,194,304
39
+
40
+ ffn_up_params = D_MODEL * D_FF # 4,194,304 (up projection)
41
+ ffn_down_params = D_FF * D_MODEL # 4,194,304 (down projection)
42
+ ffn_total = ffn_up_params + ffn_down_params # 8,388,608
43
+
44
+ rmsnorm_params = D_MODEL * 2 # 2,048 (2 norms per layer)
45
+ layer_total = attn_total + ffn_total + rmsnorm_params # 12,584,960
46
+
47
+ all_layers_total = layer_total * N_LAYERS # 302,039,040
48
+
49
+ final_norm_params = D_MODEL # 1,024
50
+ # LM Head is weight-tied with embedding, so 0 extra params
51
+ lm_head_params = 0 # (tied)
52
+
53
+ TOTAL_PARAMS = embed_params + all_layers_total + final_norm_params + lm_head_params
54
+ # = 32,768,000 + 302,039,040 + 1,024 = 334,808,064
55
+ # With weight tying, unique params β‰ˆ 334,808,064
56
+
57
+ # ═══════════════════════════════════════════════════════════════════════
58
+ # LAYER DEFINITIONS FOR VISUALIZATION
59
+ # ═══════════════════════════════════════════════════════════════════════
60
+
61
+ # (name, nodes_to_display, actual_size, params_to_this_layer, color)
62
+ LAYERS = [
63
+ ("Input Tokens", 10, VOCAB_SIZE, 0, "#4CAF50"), # Green
64
+ ("Token Embedding", 10, D_MODEL, embed_params, "#2196F3"), # Blue
65
+ ("RoPE Positions", 10, D_MODEL, 0, "#00BCD4"), # Cyan
66
+
67
+ # Show 3 representative transformer layers (of 24)
68
+ ("Layer 1: Attention Q,K,V", 12, D_MODEL, qkv_params, "#FF9800"), # Orange
69
+ ("Layer 1: Attention Out", 10, D_MODEL, out_proj_params, "#FF9800"),
70
+ ("Layer 1: FFN Up", 14, D_FF, ffn_up_params, "#8BC34A"), # Light green
71
+ ("Layer 1: FFN Down", 10, D_MODEL, ffn_down_params, "#8BC34A"),
72
+
73
+ ("Layer 2–23: Γ—22 Blocks", 12, D_MODEL, layer_total * 22, "#9C27B0"), # Purple
74
+
75
+ ("Layer 24: Attention", 12, D_MODEL, attn_total, "#FF5722"), # Deep orange
76
+ ("Layer 24: FFN", 14, D_FF, ffn_total, "#009688"), # Teal
77
+ ("Layer 24: Output", 10, D_MODEL, rmsnorm_params, "#009688"),
78
+
79
+ ("Final RMSNorm", 10, D_MODEL, final_norm_params, "#E91E63"), # Pink
80
+ ("LM Head (tied)", 10, VOCAB_SIZE, lm_head_params, "#F44336"), # Red
81
+ ("Output Probabilities", 1, VOCAB_SIZE, 0, "#F44336"), # Red
82
+ ]
83
+
84
+
85
+ def draw_neural_network(save_path="neural_network.png"):
86
+ fig, ax = plt.subplots(figsize=(22, 30), facecolor="#0D1117")
87
+ ax.set_facecolor("#0D1117")
88
+
89
+ n_layers = len(LAYERS)
90
+ y_positions = np.linspace(0.92, 0.04, n_layers)
91
+
92
+ # Spacing
93
+ x_center = 0.5
94
+ max_spread = 0.38
95
+
96
+ all_node_positions = [] # Store (x_list, y) for connections
97
+
98
+ running_params = 0
99
+
100
+ for i, (name, n_display, actual_size, params, color) in enumerate(LAYERS):
101
+ y = y_positions[i]
102
+ running_params += params
103
+
104
+ # Calculate x positions for nodes
105
+ if n_display == 1:
106
+ xs = [x_center]
107
+ else:
108
+ xs = np.linspace(x_center - max_spread, x_center + max_spread, n_display)
109
+
110
+ all_node_positions.append((xs, y))
111
+
112
+ # Draw connections to previous layer
113
+ if i > 0:
114
+ prev_xs, prev_y = all_node_positions[i - 1]
115
+
116
+ # Limit connections for readability
117
+ max_connections = 200
118
+ step_curr = max(1, len(xs) // 12)
119
+ step_prev = max(1, len(prev_xs) // 12)
120
+
121
+ conn_count = 0
122
+ for px in prev_xs[::step_prev]:
123
+ for cx in xs[::step_curr]:
124
+ if conn_count > max_connections:
125
+ break
126
+ ax.plot(
127
+ [px, cx], [prev_y, y],
128
+ color=color, alpha=0.22, linewidth=0.6,
129
+ transform=ax.transAxes, zorder=1,
130
+ )
131
+ conn_count += 1
132
+
133
+ # Draw nodes
134
+ node_radius = 0.01 if n_display <= 12 else 0.008
135
+ if n_display == 1:
136
+ node_radius = 0.016
137
+
138
+ for x in xs:
139
+ circle = plt.Circle(
140
+ (x, y), node_radius,
141
+ facecolor=color, edgecolor="white",
142
+ linewidth=0.6, alpha=0.95,
143
+ transform=ax.transAxes, zorder=3,
144
+ )
145
+ ax.add_patch(circle)
146
+
147
+ # Draw "+N" indicator if actual size > displayed
148
+ if actual_size > n_display and n_display > 1:
149
+ extra = actual_size - n_display
150
+ if extra > 0:
151
+ ax.text(
152
+ xs[-1] + 0.03, y,
153
+ f"(+{extra:,})",
154
+ transform=ax.transAxes,
155
+ fontsize=7, color="#8B949E",
156
+ ha="left", va="center",
157
+ fontfamily="monospace",
158
+ )
159
+
160
+ # Layer label (left side)
161
+ ax.text(
162
+ 0.02, y,
163
+ name,
164
+ transform=ax.transAxes,
165
+ fontsize=9, fontweight="bold",
166
+ color="#E6EDF3",
167
+ ha="left", va="center",
168
+ fontfamily="monospace",
169
+ )
170
+
171
+ # Parameter count (right side)
172
+ if params > 0:
173
+ param_text = f"{params:,} params"
174
+ ax.text(
175
+ 0.98, y,
176
+ param_text,
177
+ transform=ax.transAxes,
178
+ fontsize=8,
179
+ color=color,
180
+ ha="right", va="center",
181
+ fontfamily="monospace",
182
+ fontweight="bold",
183
+ )
184
+
185
+ # Running total (far right, smaller)
186
+ if running_params > 0:
187
+ ax.text(
188
+ 0.98, y - 0.012,
189
+ f"Ξ£ {running_params / 1e6:.1f}M",
190
+ transform=ax.transAxes,
191
+ fontsize=6.5,
192
+ color="#8B949E",
193
+ ha="right", va="center",
194
+ fontfamily="monospace",
195
+ )
196
+
197
+ # ── Title ──────────────────────────────────────────────────────
198
+ ax.text(
199
+ 0.5, 0.97,
200
+ "GPT-300M Neural Network",
201
+ transform=ax.transAxes,
202
+ fontsize=24, fontweight="bold",
203
+ color="#E6EDF3", ha="center", va="center",
204
+ fontfamily="monospace",
205
+ )
206
+ ax.text(
207
+ 0.5, 0.955,
208
+ f"Total: {TOTAL_PARAMS:,} parameters β€’ {N_LAYERS} transformer layers β€’ "
209
+ f"{N_HEADS} attention heads β€’ d_model={D_MODEL}",
210
+ transform=ax.transAxes,
211
+ fontsize=9, color="#8B949E", ha="center", va="center",
212
+ fontfamily="monospace",
213
+ )
214
+
215
+ # ── Parameter Summary Box ──────────────────────────────────────
216
+ summary_y = 0.005
217
+ summary_text = (
218
+ f"β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ Parameter Summary ───────────────┐\n"
219
+ f"β”‚ Token Embedding: {embed_params:>13,} ({embed_params/TOTAL_PARAMS*100:4.1f}%) β”‚\n"
220
+ f"β”‚ Attention (Γ—{N_LAYERS}): {attn_total*N_LAYERS:>13,} ({attn_total*N_LAYERS/TOTAL_PARAMS*100:4.1f}%) β”‚\n"
221
+ f"β”‚ Feed-Forward (Γ—{N_LAYERS}): {ffn_total*N_LAYERS:>13,} ({ffn_total*N_LAYERS/TOTAL_PARAMS*100:4.1f}%) β”‚\n"
222
+ f"β”‚ RMSNorm (Γ—{N_LAYERS}+1): {rmsnorm_params*N_LAYERS+final_norm_params:>13,} ({(rmsnorm_params*N_LAYERS+final_norm_params)/TOTAL_PARAMS*100:4.1f}%) β”‚\n"
223
+ f"β”‚ LM Head (tied): {'0 (shared)':>13} β”‚\n"
224
+ f"β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n"
225
+ f"β”‚ TOTAL: {TOTAL_PARAMS:>13,} (100%) β”‚\n"
226
+ f"β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜"
227
+ )
228
+ ax.text(
229
+ 0.5, summary_y,
230
+ summary_text,
231
+ transform=ax.transAxes,
232
+ fontsize=8, color="#58A6FF",
233
+ ha="center", va="bottom",
234
+ fontfamily="monospace",
235
+ bbox=dict(boxstyle="round,pad=0.8", facecolor="#161B22",
236
+ edgecolor="#30363D", linewidth=1),
237
+ )
238
+
239
+ # ── Legend ──────────────────────────────────────────────────────
240
+ legend_items = [
241
+ ("#4CAF50", "Input / Tokenization"),
242
+ ("#2196F3", "Embeddings"),
243
+ ("#FF9800", "Self-Attention"),
244
+ ("#8BC34A", "Feed-Forward (GELU)"),
245
+ ("#9C27B0", "Collapsed Layers (Γ—22)"),
246
+ ("#E91E63", "Normalization"),
247
+ ("#F44336", "Output / LM Head"),
248
+ ]
249
+ for j, (c, label) in enumerate(legend_items):
250
+ lx = 0.02
251
+ ly = 0.035 - j * 0.015
252
+ circle = plt.Circle(
253
+ (lx, ly), 0.004,
254
+ facecolor=c, edgecolor="white", linewidth=0.3,
255
+ transform=ax.transAxes, zorder=5,
256
+ )
257
+ ax.add_patch(circle)
258
+ ax.text(
259
+ lx + 0.012, ly, label,
260
+ transform=ax.transAxes,
261
+ fontsize=7, color="#C9D1D9", va="center",
262
+ fontfamily="monospace",
263
+ )
264
+
265
+ ax.set_xlim(0, 1)
266
+ ax.set_ylim(0, 1)
267
+ ax.axis("off")
268
+
269
+ plt.savefig(save_path, dpi=200, bbox_inches="tight",
270
+ facecolor="#0D1117", edgecolor="none")
271
+ print(f"Saved: {save_path}")
272
+ plt.close()
273
+
274
+
275
+ # ═══════════════════════════════════════════════════════════════════════
276
+ # ALSO: A cleaner "zoomed in" single-layer view
277
+ # ═══════════════════════════════════════════════════════════════════════
278
+
279
+ def draw_single_layer_detail(save_path="layer_detail.png"):
280
+ """Draw a detailed view of one transformer layer with node connections."""
281
+ fig, ax = plt.subplots(figsize=(20, 14), facecolor="#0D1117")
282
+ ax.set_facecolor("#0D1117")
283
+
284
+ # One transformer layer breakdown:
285
+ # Input (1024) → Q,K,V (3×1024) → Attention Heads (16×64) → Output Proj (1024)
287
+ # → RMSNorm (1024) → FFN Up (4096) → GELU → FFN Down (1024) → Output (1024)
287
+
288
+ sub_layers = [
289
+ ("Input\n(d=1,024)", 8, D_MODEL, 0, "#2196F3"),
290
+ ("Query\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#FF6B6B"),
291
+ ("Key\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#4ECDC4"),
292
+ ("Value\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#45B7D1"),
293
+ ("Attention Heads\n(16Γ—64)", 16, D_MODEL, 0, "#FF9800"),
294
+ ("Attn Output\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#FF9800"),
295
+ ("βŠ• Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
296
+ ("FFN Up (GELU)\n(d=4,096)", 14, D_FF, D_MODEL*D_FF, "#8BC34A"),
297
+ ("FFN Down\n(d=1,024)", 8, D_MODEL, D_FF*D_MODEL, "#8BC34A"),
298
+ ("βŠ• Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
299
+ ("Layer Output\n(d=1,024)", 8, D_MODEL, 0, "#2196F3"),
300
+ ]
301
+
302
+ n = len(sub_layers)
303
+ y_positions = np.linspace(0.9, 0.08, n)
304
+ x_center = 0.5
305
+ max_spread = 0.32
306
+
307
+ all_pos = []
308
+
309
+ for i, (name, n_nodes, actual, params, color) in enumerate(sub_layers):
310
+ y = y_positions[i]
311
+ xs = np.linspace(x_center - max_spread, x_center + max_spread, n_nodes)
312
+ all_pos.append((xs, y))
313
+
314
+ # Connections
315
+ if i > 0:
316
+ prev_xs, prev_y = all_pos[i-1]
317
+ step_c = max(1, len(xs) // 10)
318
+ step_p = max(1, len(prev_xs) // 10)
319
+ for px in prev_xs[::step_p]:
320
+ for cx in xs[::step_c]:
321
+ ax.plot([px, cx], [prev_y, y],
322
+ color=color, alpha=0.2, linewidth=0.7,
323
+ transform=ax.transAxes, zorder=1)
324
+
325
+ # Nodes
326
+ r = 0.011 if n_nodes <= 10 else 0.009
327
+ for x in xs:
328
+ c = plt.Circle((x, y), r, facecolor=color, edgecolor="white",
329
+ linewidth=0.6, alpha=0.95,
330
+ transform=ax.transAxes, zorder=3)
331
+ ax.add_patch(c)
332
+
333
+ # Overflow indicator
334
+ if actual > n_nodes:
335
+ ax.text(xs[-1] + 0.025, y, f"(+{actual - n_nodes:,})",
336
+ transform=ax.transAxes, fontsize=7, color="#8B949E",
337
+ ha="left", va="center", fontfamily="monospace")
338
+
339
+ # Label
340
+ ax.text(0.03, y, name, transform=ax.transAxes,
341
+ fontsize=9, fontweight="bold", color="#E6EDF3",
342
+ ha="left", va="center", fontfamily="monospace")
343
+
344
+ # Params
345
+ if params > 0:
346
+ ax.text(0.97, y, f"{params:,}", transform=ax.transAxes,
347
+ fontsize=8, color=color, ha="right", va="center",
348
+ fontfamily="monospace", fontweight="bold")
349
+
350
+ # Title
351
+ ax.text(0.5, 0.96, "Single Transformer Layer — Detailed View",
352
+ transform=ax.transAxes, fontsize=18, fontweight="bold",
353
+ color="#E6EDF3", ha="center", fontfamily="monospace")
354
+ ax.text(0.5, 0.935,
355
+ f"Parameters per layer: {layer_total:,} β€’ Γ—{N_LAYERS} layers = {all_layers_total:,} total",
356
+ transform=ax.transAxes, fontsize=9, color="#8B949E",
357
+ ha="center", fontfamily="monospace")
358
+
359
+ ax.set_xlim(0, 1)
360
+ ax.set_ylim(0, 1)
361
+ ax.axis("off")
362
+
363
+ plt.savefig(save_path, dpi=200, bbox_inches="tight",
364
+ facecolor="#0D1117", edgecolor="none")
365
+ print(f"Saved: {save_path}")
366
+ plt.close()
367
+
368
+
369
+ if __name__ == "__main__":
370
+ import os
371
+ os.makedirs("viz", exist_ok=True)
372
+
+    print("=" * 50)
+    print(" GPT-300M Parameter Verification")
+    print("=" * 50)
+    print(f" Token Embedding:     {embed_params:>13,}")
+    print(f" Per-layer Attention: {attn_total:>13,}")
+    print(f" Per-layer FFN:       {ffn_total:>13,}")
+    print(f" Per-layer Norm:      {rmsnorm_params:>13,}")
+    print(f" Per-layer Total:     {layer_total:>13,}")
+    print(f" All {N_LAYERS} layers:        {all_layers_total:>13,}")
+    print(f" Final Norm:          {final_norm_params:>13,}")
+    print(f" LM Head (tied):      {'0 (shared)':>13}")
+    print(" " + "─" * 33)
+    print(f" TOTAL:               {TOTAL_PARAMS:>13,}")
+    print(f" ≈ {TOTAL_PARAMS / 1e6:.1f}M parameters")
+    print("=" * 50)
+
+    print("\nGenerating full network diagram...")
+    draw_neural_network("viz/neural_network_full.png")
+
+    print("Generating single-layer detail...")
+    draw_single_layer_detail("viz/neural_network_layer.png")
+
+    print("\nDone!")
visualize_nn.py ADDED
@@ -0,0 +1,472 @@
+"""
+GPT-300M Neural Network Visualizer
+====================================
+Generates detailed architectural diagrams of the GPT-300M model
+using matplotlib, showing:
+- Full model architecture flow
+- Detailed transformer block internals
+- Attention head visualization
+- Parameter distribution charts
+
+Usage:
+    python visualize_nn.py
+    python visualize_nn.py --output ./my_viz_dir
+"""
+
+import argparse
+import matplotlib
+matplotlib.use("Agg")
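+# ("Agg" is a non-interactive backend; selecting it before pyplot is
+# imported lets the script run headless, e.g. on a remote server.)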
+
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+from matplotlib.patches import FancyBboxPatch, FancyArrowPatch
+import numpy as np
+
+from config import GPT300MConfig, gpt_300m
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# COLOR SCHEME
+# ═══════════════════════════════════════════════════════════════════════
+
+COLORS = {
+    "bg": "#0D1117",
+    "text": "#E6EDF3",
+    "text_dim": "#8B949E",
+    "embed": "#58A6FF",      # Blue
+    "attn": "#F78166",       # Orange
+    "ffn": "#7EE787",        # Green
+    "norm": "#D2A8FF",       # Purple
+    "residual": "#FFA657",   # Yellow-orange
+    "output": "#FF7B72",     # Red
+    "arrow": "#484F58",
+    "highlight": "#1F6FEB",
+    "border": "#30363D",
+    "card_bg": "#161B22",
+    "accent1": "#79C0FF",
+    "accent2": "#BB9AF7",
+}
+
+
+def draw_rounded_box(ax, x, y, w, h, color, label, fontsize=10,
+                     text_color=None, alpha=0.9, sublabel=None):
+    """Draw a rounded rectangle with label."""
+    box = FancyBboxPatch(
+        (x - w/2, y - h/2), w, h,
+        boxstyle="round,pad=0.1",
+        facecolor=color,
+        edgecolor="white",
+        linewidth=0.5,
+        alpha=alpha,
+        zorder=3,
+    )
+    ax.add_patch(box)
+    ax.text(
+        x, y + (0.15 if sublabel else 0),
+        label,
+        ha="center", va="center",
+        fontsize=fontsize,
+        fontweight="bold",
+        color=text_color or COLORS["text"],
+        zorder=4,
+    )
+    if sublabel:
+        ax.text(
+            x, y - 0.25,
+            sublabel,
+            ha="center", va="center",
+            fontsize=fontsize - 2,
+            color=COLORS["text_dim"],
+            zorder=4,
+        )
+
+
+def draw_arrow(ax, x1, y1, x2, y2, color=None):
+    """Draw an arrow between two points."""
+    ax.annotate(
+        "",
+        xy=(x2, y2), xytext=(x1, y1),
+        arrowprops=dict(
+            arrowstyle="->",
+            color=color or COLORS["arrow"],
+            lw=1.5,
+            connectionstyle="arc3,rad=0",
+        ),
+        zorder=2,
+    )
+
+
+def draw_residual_connection(ax, x_start, y_start, x_end, y_end, offset=1.8):
+    """Draw a residual/skip connection arc."""
+    ax.annotate(
+        "",
+        xy=(x_end, y_end), xytext=(x_start, y_start),
+        arrowprops=dict(
+            arrowstyle="->",
+            color=COLORS["residual"],
+            lw=1.2,
+            linestyle="--",
+            connectionstyle="arc3,rad=0.3",
+        ),
+        zorder=1,
+    )
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# FULL ARCHITECTURE DIAGRAM
+# ═══════════════════════════════════════════════════════════════════════
+
+def draw_full_architecture(config: GPT300MConfig, save_path: str = None):
+    """Draw the complete GPT-300M architecture."""
+    fig, ax = plt.subplots(1, 1, figsize=(14, 24), facecolor=COLORS["bg"])
+    ax.set_facecolor(COLORS["bg"])
+    ax.set_xlim(-4, 4)
+    ax.set_ylim(-1, 22)
+    ax.axis("off")
+
+    # Title
+    ax.text(0, 21.5, "GPT-300M Architecture", ha="center", va="center",
+            fontsize=22, fontweight="bold", color=COLORS["text"],
+            fontfamily="monospace")
+    ax.text(0, 21.0,
+            f"{config.total_params_estimate:,} parameters • "
+            f"{config.n_layers} layers • "
+            f"{config.n_heads} heads • "
+            f"d={config.d_model}",
+            ha="center", va="center", fontsize=10, color=COLORS["text_dim"],
+            fontfamily="monospace")
+
+    y = 19.5  # Starting y position
+
+    # ── Input ──────────────────────────────────────────────────────
+    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["card_bg"], "Input Token IDs",
+                     sublabel="[batch, seq_len]", fontsize=11)
+    y -= 1.1
+    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)
+
+    # ── Token Embedding ────────────────────────────────────────────
+    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["embed"],
+                     "Token Embedding", text_color="#000",
+                     sublabel=f"{config.vocab_size:,} × {config.d_model}")
+    y -= 1.1
+    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)
+
+    # ── RoPE ───────────────────────────────────────────────────────
+    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["accent2"],
+                     "Rotary Position Embeddings (RoPE)",
+                     text_color="#000", fontsize=9,
+                     sublabel=f"θ = {config.rope_theta:.0f}")
+    y -= 1.0
+    draw_arrow(ax, 0, y + 0.7, 0, y + 0.4)
+
+    # ── Dropout ────────────────────────────────────────────────────
+    draw_rounded_box(ax, 0, y, 2.5, 0.5, COLORS["border"],
+                     f"Dropout (p={config.dropout})", fontsize=9)
+    y -= 1.0
+    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)
+
+    # ── Transformer Blocks ─────────────────────────────────────────
+    block_height = 3.2
+
+    # Draw detailed first block
+    block_y_start = y
+    block_y_end = y - block_height
+
+    # Block container
+    block_box = FancyBboxPatch(
+        (-3.3, block_y_end - 0.1), 6.6, block_height + 0.2,
+        boxstyle="round,pad=0.15",
+        facecolor=COLORS["card_bg"],
+        edgecolor=COLORS["highlight"],
+        linewidth=1.5,
+        alpha=0.8,
+        zorder=1,
+    )
+    ax.add_patch(block_box)
+    ax.text(-3.0, block_y_start + 0.05,
+            f"Transformer Block × {config.n_layers}",
+            fontsize=10, fontweight="bold", color=COLORS["highlight"],
+            fontfamily="monospace", zorder=5)
+
+    # Inside the block
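+    # (pre-norm ordering: each RMSNorm feeds its sub-block, while the
+    # residual path skips around it)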
+    by = block_y_start - 0.4
+
+    # RMSNorm 1
+    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
+                     "RMSNorm", text_color="#000", fontsize=9)
+    by -= 0.7
+    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)
+
+    # Multi-Head Attention
+    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["attn"],
+                     "Multi-Head Attention", text_color="#000", fontsize=10,
+                     sublabel=f"{config.n_heads} heads × {config.head_dim}d")
+    # Residual connection
+    draw_residual_connection(ax, -1.6, block_y_start - 0.2, -1.6, by)
+    ax.text(-2.5, by + 0.3, "⊕ residual", fontsize=7,
+            color=COLORS["residual"], ha="center")
+
+    by -= 0.8
+    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)
+
+    # RMSNorm 2
+    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
+                     "RMSNorm", text_color="#000", fontsize=9)
+    by -= 0.7
+    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)
+
+    # Feed-Forward Network
+    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["ffn"],
+                     "Feed-Forward Network", text_color="#000", fontsize=10,
+                     sublabel=f"{config.d_model} → {config.d_ff} → {config.d_model}")
+    # Residual connection
+    draw_residual_connection(ax, 1.6, by + 1.5, 1.6, by)
+    ax.text(2.5, by + 0.7, "⊕ residual", fontsize=7,
+            color=COLORS["residual"], ha="center")
+
+    y = block_y_end - 0.4
+
+    # ── Repeated blocks indicator ──────────────────────────────────
+    draw_arrow(ax, 0, y + 0.2, 0, y - 0.1)
+    ax.text(0, y - 0.3, f"× {config.n_layers} layers", ha="center",
+            fontsize=11, fontweight="bold", color=COLORS["text_dim"],
+            fontfamily="monospace",
+            bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS["card_bg"],
+                      edgecolor=COLORS["border"]))
+    y -= 0.9
+    draw_arrow(ax, 0, y + 0.3, 0, y + 0.05)
+
+    # ── Final RMSNorm ──────────────────────────────────────────────
+    draw_rounded_box(ax, 0, y - 0.2, 3.5, 0.5, COLORS["norm"],
+                     "Final RMSNorm", text_color="#000", fontsize=10)
+    y -= 1.0
+    draw_arrow(ax, 0, y + 0.5, 0, y + 0.2)
+
+    # ── LM Head ────────────────────────────────────────────────────
+    draw_rounded_box(ax, 0, y - 0.1, 3.5, 0.7, COLORS["output"],
+                     "Linear (LM Head)", text_color="#000", fontsize=11,
+                     sublabel=f"{config.d_model} → {config.vocab_size:,} (weight-tied)")
+    y -= 1.1
+    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)
+
+    # ── Softmax / Output ───────────────────────────────────────────
+    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["card_bg"],
+                     "Softmax → Next Token Probabilities", fontsize=10,
+                     sublabel=f"[batch, seq_len, {config.vocab_size:,}]")
+
+    plt.tight_layout()
+
+    if save_path:
+        fig.savefig(save_path, dpi=200, bbox_inches="tight",
+                    facecolor=COLORS["bg"], edgecolor="none")
+        print(f"Saved architecture diagram: {save_path}")
+
+    return fig
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# PARAMETER DISTRIBUTION CHART
+# ═══════════════════════════════════════════════════════════════════════
+
+def draw_parameter_chart(config: GPT300MConfig, save_path: str = None):
+    """Draw a parameter distribution breakdown."""
+    fig, axes = plt.subplots(1, 2, figsize=(16, 7), facecolor=COLORS["bg"])
+
+    # Calculate parameter counts per component
+    emb_params = config.vocab_size * config.d_model
+    attn_params = 4 * config.d_model * config.d_model * config.n_layers
+    ffn_params = 2 * config.d_model * config.d_ff * config.n_layers
+    norm_params = 2 * config.d_model * config.n_layers + config.d_model
+    total = emb_params + attn_params + ffn_params + norm_params
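+    # Attention counts the four bias-free d_model × d_model projections
+    # (Q, K, V, O); the FFN counts its up- and down-projections
+    # (2 · d_model · d_ff); the norm term covers two RMSNorms per layer
+    # plus the final norm. The tied LM head contributes nothing beyond
+    # the embedding counted above.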
+
+    # ── Pie Chart ──────────────────────────────────────────────────
+    ax = axes[0]
+    ax.set_facecolor(COLORS["bg"])
+    labels = ["Token\nEmbedding", "Attention\nLayers", "Feed-Forward\nLayers", "RMSNorm"]
+    sizes = [emb_params, attn_params, ffn_params, norm_params]
+    colors = [COLORS["embed"], COLORS["attn"], COLORS["ffn"], COLORS["norm"]]
+
+    wedges, texts, autotexts = ax.pie(
+        sizes, labels=None, autopct=lambda p: f"{p:.1f}%",
+        colors=colors, startangle=90, pctdistance=0.7,
+        wedgeprops=dict(width=0.5, edgecolor=COLORS["bg"], linewidth=2),
+        textprops=dict(color=COLORS["text"], fontsize=10),
+    )
+    for at in autotexts:
+        at.set_fontweight("bold")
+        at.set_color("#000")
+
+    # Legend
+    legend_labels = [
+        f"{l}\n({s/1e6:.1f}M)" for l, s in zip(
+            ["Token Embedding", "Attention", "Feed-Forward", "RMSNorm"],
+            sizes
+        )
+    ]
+    ax.legend(
+        wedges, legend_labels, loc="center left", bbox_to_anchor=(1.05, 0.5),
+        fontsize=9, frameon=False, labelcolor=COLORS["text"],
+    )
+    ax.set_title("Parameter Distribution", fontsize=14, fontweight="bold",
+                 color=COLORS["text"], pad=15)
+
+    # ── Per-Layer Breakdown Bar Chart ──────────────────────────────
+    ax = axes[1]
+    ax.set_facecolor(COLORS["bg"])
+
+    layer_attn = 4 * config.d_model * config.d_model
+    layer_ffn = 2 * config.d_model * config.d_ff
+    layer_norm = 2 * config.d_model
+
+    layers = range(1, config.n_layers + 1)
+    bar_width = 0.8
+
+    ax.bar(layers, [layer_attn / 1e6] * config.n_layers, bar_width,
+           label="Attention", color=COLORS["attn"], alpha=0.9)
+    ax.bar(layers, [layer_ffn / 1e6] * config.n_layers, bar_width,
+           bottom=[layer_attn / 1e6] * config.n_layers,
+           label="Feed-Forward", color=COLORS["ffn"], alpha=0.9)
+    ax.bar(layers, [layer_norm / 1e6] * config.n_layers, bar_width,
+           bottom=[(layer_attn + layer_ffn) / 1e6] * config.n_layers,
+           label="Norm", color=COLORS["norm"], alpha=0.9)
+
+    ax.set_xlabel("Layer", fontsize=11, color=COLORS["text"])
+    ax.set_ylabel("Parameters (M)", fontsize=11, color=COLORS["text"])
+    ax.set_title("Parameters Per Layer", fontsize=14, fontweight="bold",
+                 color=COLORS["text"], pad=15)
+    ax.legend(fontsize=9, frameon=False, labelcolor=COLORS["text"])
+    ax.tick_params(colors=COLORS["text_dim"])
+    ax.spines["bottom"].set_color(COLORS["border"])
+    ax.spines["left"].set_color(COLORS["border"])
+    ax.spines["top"].set_visible(False)
+    ax.spines["right"].set_visible(False)
+
+    # Overall title
+    fig.suptitle(
+        f"GPT-300M • {total:,} Total Parameters",
+        fontsize=16, fontweight="bold", color=COLORS["text"],
+        fontfamily="monospace", y=1.02,
+    )
+
+    plt.tight_layout()
+
+    if save_path:
+        fig.savefig(save_path, dpi=200, bbox_inches="tight",
+                    facecolor=COLORS["bg"], edgecolor="none")
+        print(f"Saved parameter chart: {save_path}")
+
+    return fig
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# ATTENTION HEAD VISUALIZATION
+# ═══════════════════════════════════════════════════════════════════════
+
+def draw_attention_heads(config: GPT300MConfig, save_path: str = None):
+    """Visualize the multi-head attention mechanism."""
+    fig, ax = plt.subplots(1, 1, figsize=(14, 10), facecolor=COLORS["bg"])
+    ax.set_facecolor(COLORS["bg"])
+    ax.set_xlim(-1, 11)
+    ax.set_ylim(-1, 8)
+    ax.axis("off")
+
+    ax.text(5, 7.5, "Multi-Head Self-Attention", ha="center",
+            fontsize=18, fontweight="bold", color=COLORS["text"],
+            fontfamily="monospace")
+    ax.text(5, 7.0,
+            f"{config.n_heads} heads × {config.head_dim}d per head = {config.d_model}d total",
+            ha="center", fontsize=10, color=COLORS["text_dim"])
+
+    # Input
+    draw_rounded_box(ax, 5, 6.2, 4, 0.5, COLORS["embed"],
+                     f"Input: [B, T, {config.d_model}]", text_color="#000", fontsize=9)
+
+    # Q, K, V projections
+    for i, (name, color) in enumerate(zip(["Q", "K", "V"],
+                                          ["#FF6B6B", "#4ECDC4", "#45B7D1"])):
+        x = 2 + i * 3
+        draw_arrow(ax, 5, 5.9, x, 5.4)
+        draw_rounded_box(ax, x, 5.1, 1.8, 0.5, color,
+                         f"W_{name}", text_color="#000", fontsize=10,
+                         sublabel=f"{config.d_model}×{config.d_model}")
+
+    # Heads
+    head_y = 3.8
+    n_show = min(config.n_heads, 8)
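+    # Render at most 8 head boxes; the ellipsis below stands in for the rest.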
+    head_spacing = 9.0 / n_show
+
+    for h in range(n_show):
+        hx = 1 + h * head_spacing
+        # Head box
+        box = FancyBboxPatch(
+            (hx - 0.4, head_y - 0.3), 0.8, 0.6,
+            boxstyle="round,pad=0.05",
+            facecolor=COLORS["attn"],
+            edgecolor="white",
+            linewidth=0.5,
+            alpha=0.8,
+            zorder=3,
+        )
+        ax.add_patch(box)
+        ax.text(hx, head_y, f"H{h+1}", ha="center", va="center",
+                fontsize=8, fontweight="bold", color="#000", zorder=4)
+
+        # Arrows from Q, K, V to heads
+        for qx in [2, 5, 8]:
+            ax.annotate("", xy=(hx, head_y + 0.3), xytext=(qx, 4.8),
+                        arrowprops=dict(arrowstyle="-", color=COLORS["arrow"],
+                                        lw=0.3, alpha=0.3), zorder=1)
+
+    if config.n_heads > 8:
+        ax.text(5, head_y - 0.6, f"... ({config.n_heads} heads total)",
+                ha="center", fontsize=9, color=COLORS["text_dim"])
+
+    # Attention computation
+    draw_rounded_box(ax, 5, 2.5, 6, 0.6, COLORS["card_bg"],
+                     "Scaled Dot-Product: softmax(QK^T / √d_k) × V",
+                     fontsize=10)
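+    # With head_dim = 64, the scale is 1/√64 = 1/8 per attention score.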
+    for h in range(n_show):
+        hx = 1 + h * head_spacing
+        draw_arrow(ax, hx, head_y - 0.3, 5, 2.85)
+
+    # Concatenate
+    draw_arrow(ax, 5, 2.15, 5, 1.75)
+    draw_rounded_box(ax, 5, 1.5, 4, 0.5, COLORS["accent1"],
+                     "Concat → W_O projection", text_color="#000", fontsize=10)
+
+    # Output
+    draw_arrow(ax, 5, 1.2, 5, 0.8)
+    draw_rounded_box(ax, 5, 0.5, 4, 0.5, COLORS["ffn"],
+                     f"Output: [B, T, {config.d_model}]", text_color="#000", fontsize=9)
+
+    plt.tight_layout()
+
+    if save_path:
+        fig.savefig(save_path, dpi=200, bbox_inches="tight",
+                    facecolor=COLORS["bg"], edgecolor="none")
+        print(f"Saved attention diagram: {save_path}")
+
+    return fig
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# MAIN
+# ═══════════════════════════════════════════════════════════════════════
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Visualize GPT-300M Architecture")
+    parser.add_argument("--output", type=str, default="./viz",
+                        help="Output directory for images")
+    args = parser.parse_args()
+
+    import os
+    os.makedirs(args.output, exist_ok=True)
+
+    config = gpt_300m()
+    print(f"Generating visualizations for GPT-300M ({config.total_params_estimate:,} params)...")
+
+    draw_full_architecture(config, os.path.join(args.output, "architecture.png"))
+    draw_parameter_chart(config, os.path.join(args.output, "parameters.png"))
+    draw_attention_heads(config, os.path.join(args.output, "attention.png"))
+
+    print("Done! All visualizations saved.")