Upload 16 files
Browse files- .gitattributes +1 -32
- README.md +193 -0
- WEIGHTS_GO_HERE.txt +3 -0
- chat.py +339 -0
- config.json +20 -0
- config.py +157 -0
- dataset.py +269 -0
- model.py +513 -0
- requirements.txt +3 -0
- special_tokens_map.json +12 -0
- tokenizer.py +344 -0
- tokenizer_config.json +11 -0
- train.py +456 -0
- visual_nn_3d.py +387 -0
- visual_nn_nodes.py +395 -0
- visualize_nn.py +472 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,4 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.
|
| 5 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
| 1 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
tags:
|
| 6 |
+
- text-generation
|
| 7 |
+
- from-scratch
|
| 8 |
+
- transformer
|
| 9 |
+
- gpt
|
| 10 |
+
- pytorch
|
| 11 |
+
- chatbot
|
| 12 |
+
pipeline_tag: text-generation
|
| 13 |
+
model-index:
|
| 14 |
+
- name: GPT-300M
|
| 15 |
+
results: []
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
# GPT-300M
|
| 19 |
+
|
| 20 |
+
A **334,808,064 parameter** autoregressive transformer language model built **entirely from scratch** in PyTorch. No pretrained weights. No fine-tuning. Everything from zero.
|
| 21 |
+
|
| 22 |
+
## Architecture
|
| 23 |
+
|
| 24 |
+
```
|
| 25 |
+
Input Token IDs
|
| 26 |
+
β
|
| 27 |
+
Token Embedding (32,000 Γ 1,024) β 32.8M params
|
| 28 |
+
β
|
| 29 |
+
Rotary Position Embeddings (RoPE) β 0 learned params
|
| 30 |
+
β
|
| 31 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
β Transformer Block Γ 24 layers (12.6M each) β
|
| 33 |
+
β β
|
| 34 |
+
β RMSNorm β Multi-Head Attention β β Residual β
|
| 35 |
+
β 16 heads Γ 64d β
|
| 36 |
+
β 4,194,304 params β
|
| 37 |
+
β β
|
| 38 |
+
β RMSNorm β FFN (GELU) β β Residual β
|
| 39 |
+
β 1,024 β 4,096 β 1,024 β
|
| 40 |
+
β 8,388,608 params β
|
| 41 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
β
|
| 43 |
+
Final RMSNorm
|
| 44 |
+
β
|
| 45 |
+
LM Head (weight-tied with embedding) β 0 extra params
|
| 46 |
+
β
|
| 47 |
+
Softmax β Next Token Probabilities
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Parameter Breakdown
|
| 51 |
+
|
| 52 |
+
| Component | Parameters | Percentage |
|
| 53 |
+
|---|---:|---:|
|
| 54 |
+
| Token Embedding | 32,768,000 | 9.8% |
|
| 55 |
+
| Attention Layers (Γ24) | 100,663,296 | 30.1% |
|
| 56 |
+
| Feed-Forward Layers (Γ24) | 201,326,592 | 60.1% |
|
| 57 |
+
| RMSNorm (Γ24 + final) | 50,176 | 0.0% |
|
| 58 |
+
| LM Head | 0 (tied) | β |
|
| 59 |
+
| **TOTAL** | **334,808,064** | **100%** |
|
| 60 |
+
|
| 61 |
+
## Model Details
|
| 62 |
+
|
| 63 |
+
| Hyperparameter | Value |
|
| 64 |
+
|---|---|
|
| 65 |
+
| Hidden dimension (d_model) | 1,024 |
|
| 66 |
+
| Attention heads | 16 |
|
| 67 |
+
| Head dimension | 64 |
|
| 68 |
+
| Transformer layers | 24 |
|
| 69 |
+
| FFN dimension (d_ff) | 4,096 |
|
| 70 |
+
| Vocabulary size | 32,000 |
|
| 71 |
+
| Max sequence length | 2,048 |
|
| 72 |
+
| Position encoding | RoPE (ΞΈ=10,000) |
|
| 73 |
+
| Activation | GELU |
|
| 74 |
+
| Normalization | RMSNorm (Ξ΅=1e-5) |
|
| 75 |
+
| Weight tying | Yes (Embed β LM Head) |
|
| 76 |
+
| Bias | None |
|
| 77 |
+
|
| 78 |
+
## Training Configuration
|
| 79 |
+
|
| 80 |
+
| Setting | Value |
|
| 81 |
+
|---|---|
|
| 82 |
+
| Optimizer | AdamW (Ξ²β=0.9, Ξ²β=0.95) |
|
| 83 |
+
| Peak learning rate | 3e-4 |
|
| 84 |
+
| Min learning rate | 3e-5 |
|
| 85 |
+
| Schedule | Cosine decay + linear warmup |
|
| 86 |
+
| Warmup steps | 2,000 |
|
| 87 |
+
| Weight decay | 0.1 |
|
| 88 |
+
| Batch size | 32 Γ 8 gradient accumulation |
|
| 89 |
+
| Max training steps | 600,000 |
|
| 90 |
+
| Precision | bfloat16 |
|
| 91 |
+
| Gradient clipping | 1.0 |
|
| 92 |
+
|
| 93 |
+
## Usage
|
| 94 |
+
|
| 95 |
+
### Loading the Model
|
| 96 |
+
|
| 97 |
+
```python
|
| 98 |
+
from model import GPT300M
|
| 99 |
+
from config import GPT300MConfig
|
| 100 |
+
from tokenizer import BPETokenizer
|
| 101 |
+
import torch
|
| 102 |
+
|
| 103 |
+
# Load config, model, and tokenizer
|
| 104 |
+
config = GPT300MConfig()
|
| 105 |
+
model = GPT300M(config)
|
| 106 |
+
|
| 107 |
+
# Load trained weights
|
| 108 |
+
checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
|
| 109 |
+
model.load_state_dict(checkpoint)
|
| 110 |
+
model.eval()
|
| 111 |
+
|
| 112 |
+
# Load tokenizer
|
| 113 |
+
tokenizer = BPETokenizer.load("tokenizer.json")
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
### Chat with the Model
|
| 117 |
+
|
| 118 |
+
```python
|
| 119 |
+
from chat import ChatBot
|
| 120 |
+
|
| 121 |
+
chatbot = ChatBot(model, tokenizer, config)
|
| 122 |
+
response = chatbot.chat("Hello! What is machine learning?")
|
| 123 |
+
print(response)
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### Interactive Chat
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
python chat.py --checkpoint pytorch_model.bin
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
### Training from Scratch
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
# Quick test (tiny model)
|
| 136 |
+
python train.py --tiny
|
| 137 |
+
|
| 138 |
+
# Full 300M model
|
| 139 |
+
python train.py --data your_training_data.txt
|
| 140 |
+
|
| 141 |
+
# Multi-GPU
|
| 142 |
+
torchrun --nproc_per_node=4 train.py --data your_data.txt
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
## Files
|
| 146 |
+
|
| 147 |
+
| File | Description |
|
| 148 |
+
|---|---|
|
| 149 |
+
| `config.json` | Model configuration (HuggingFace format) |
|
| 150 |
+
| `config.py` | Python config class with all hyperparameters |
|
| 151 |
+
| `model.py` | Full transformer architecture (RoPE, MHA, FFN, KV-cache) |
|
| 152 |
+
| `tokenizer.py` | BPE tokenizer built from scratch |
|
| 153 |
+
| `tokenizer_config.json` | Tokenizer settings |
|
| 154 |
+
| `special_tokens_map.json` | Special token definitions |
|
| 155 |
+
| `dataset.py` | Dataset classes and data loading |
|
| 156 |
+
| `train.py` | Training loop (DDP, mixed precision, scheduling) |
|
| 157 |
+
| `chat.py` | Interactive chatbot with streaming generation |
|
| 158 |
+
| `visual_nn_3d.py` | 3D matplotlib architecture visualization |
|
| 159 |
+
| `requirements.txt` | Python dependencies |
|
| 160 |
+
| `pytorch_model.bin` | Trained weights *(upload after training)* |
|
| 161 |
+
| `tokenizer.json` | Trained tokenizer *(upload after training)* |
|
| 162 |
+
|
| 163 |
+
## Hardware Requirements
|
| 164 |
+
|
| 165 |
+
| Config | GPU Memory | Est. Training Time |
|
| 166 |
+
|---|---|---|
|
| 167 |
+
| Tiny (debug) | ~1 GB | Minutes |
|
| 168 |
+
| Full 300M | ~24 GB | ~3-5 days (4ΓA100) |
|
| 169 |
+
|
| 170 |
+
## Key Features
|
| 171 |
+
|
| 172 |
+
- **100% from scratch** β no pretrained weights, no HuggingFace Transformers dependency
|
| 173 |
+
- **Rotary Position Embeddings** β better length generalization than learned positions
|
| 174 |
+
- **RMSNorm** β faster than LayerNorm, equally effective
|
| 175 |
+
- **Flash Attention** β via PyTorch 2.0 SDPA
|
| 176 |
+
- **KV-Cache** β efficient autoregressive generation
|
| 177 |
+
- **Weight tying** β saves ~33M parameters
|
| 178 |
+
- **Chat template** β built-in support for multi-turn conversations
|
| 179 |
+
- **torch.compile** β ready for PyTorch 2.0+ compilation
|
| 180 |
+
|
| 181 |
+
## Citation
|
| 182 |
+
|
| 183 |
+
```bibtex
|
| 184 |
+
@misc{gpt300m,
|
| 185 |
+
title={GPT-300M: A 300-Million Parameter Language Model From Scratch},
|
| 186 |
+
year={2025},
|
| 187 |
+
url={https://huggingface.co/YOUR_USERNAME/gpt-300m}
|
| 188 |
+
}
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
## License
|
| 192 |
+
|
| 193 |
+
MIT
|
WEIGHTS_GO_HERE.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PLACEHOLDER - Replace this file with your trained model weights after training.
|
| 2 |
+
Run: python train.py --data your_data.txt
|
| 3 |
+
Then: torch.save(checkpoint['model_state_dict'], 'pytorch_model.bin')
|
chat.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT-300M Chatbot Interface
|
| 3 |
+
============================
|
| 4 |
+
Interactive terminal chatbot using a trained GPT-300M model.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python chat.py --checkpoint ./checkpoints/best_model.pt
|
| 8 |
+
|
| 9 |
+
# Or with custom generation parameters:
|
| 10 |
+
python chat.py --checkpoint ./checkpoints/best_model.pt \
|
| 11 |
+
--temperature 0.8 --top_k 40 --max_tokens 256
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import sys
|
| 16 |
+
import time
|
| 17 |
+
from typing import List, Dict, Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from config import GPT300MConfig
|
| 22 |
+
from model import GPT300M
|
| 23 |
+
from tokenizer import BPETokenizer
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ChatBot:
|
| 27 |
+
"""
|
| 28 |
+
Interactive chatbot powered by GPT-300M.
|
| 29 |
+
|
| 30 |
+
Maintains conversation history, handles tokenization/detokenization,
|
| 31 |
+
and performs autoregressive generation with KV-caching.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
model: GPT300M,
|
| 37 |
+
tokenizer: BPETokenizer,
|
| 38 |
+
config: GPT300MConfig,
|
| 39 |
+
device: str = "auto",
|
| 40 |
+
):
|
| 41 |
+
self.config = config
|
| 42 |
+
self.tokenizer = tokenizer
|
| 43 |
+
|
| 44 |
+
# Device
|
| 45 |
+
if device == "auto":
|
| 46 |
+
if torch.cuda.is_available():
|
| 47 |
+
self.device = "cuda"
|
| 48 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 49 |
+
self.device = "mps"
|
| 50 |
+
else:
|
| 51 |
+
self.device = "cpu"
|
| 52 |
+
else:
|
| 53 |
+
self.device = device
|
| 54 |
+
|
| 55 |
+
self.model = model.to(self.device)
|
| 56 |
+
self.model.eval()
|
| 57 |
+
|
| 58 |
+
# Conversation state
|
| 59 |
+
self.history: List[Dict[str, str]] = []
|
| 60 |
+
self.system_prompt = config.system_prompt
|
| 61 |
+
|
| 62 |
+
def set_system_prompt(self, prompt: str):
|
| 63 |
+
"""Set the system prompt for the conversation."""
|
| 64 |
+
self.system_prompt = prompt
|
| 65 |
+
|
| 66 |
+
def reset(self):
|
| 67 |
+
"""Clear conversation history."""
|
| 68 |
+
self.history = []
|
| 69 |
+
print("\n⦠Conversation reset.\n")
|
| 70 |
+
|
| 71 |
+
def chat(
|
| 72 |
+
self,
|
| 73 |
+
user_message: str,
|
| 74 |
+
temperature: Optional[float] = None,
|
| 75 |
+
top_k: Optional[int] = None,
|
| 76 |
+
top_p: Optional[float] = None,
|
| 77 |
+
max_new_tokens: Optional[int] = None,
|
| 78 |
+
stream: bool = True,
|
| 79 |
+
) -> str:
|
| 80 |
+
"""
|
| 81 |
+
Send a message and get a response.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
user_message: The user's input
|
| 85 |
+
temperature: Override sampling temperature
|
| 86 |
+
top_k: Override top-k
|
| 87 |
+
top_p: Override top-p
|
| 88 |
+
max_new_tokens: Override max generation length
|
| 89 |
+
stream: Whether to stream tokens to stdout
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
The assistant's response text
|
| 93 |
+
"""
|
| 94 |
+
temp = temperature or self.config.temperature
|
| 95 |
+
k = top_k or self.config.top_k
|
| 96 |
+
p = top_p or self.config.top_p
|
| 97 |
+
max_tokens = max_new_tokens or self.config.max_new_tokens
|
| 98 |
+
|
| 99 |
+
# Build conversation messages
|
| 100 |
+
messages = []
|
| 101 |
+
if self.system_prompt:
|
| 102 |
+
messages.append({"role": "system", "content": self.system_prompt})
|
| 103 |
+
messages.extend(self.history)
|
| 104 |
+
messages.append({"role": "user", "content": user_message})
|
| 105 |
+
|
| 106 |
+
# Tokenize
|
| 107 |
+
input_ids = self.tokenizer.encode_chat(messages, add_generation_prompt=True)
|
| 108 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 109 |
+
|
| 110 |
+
# Check sequence length
|
| 111 |
+
if input_tensor.size(1) > self.config.max_seq_len - max_tokens:
|
| 112 |
+
# Truncate history if needed
|
| 113 |
+
while (
|
| 114 |
+
len(self.history) > 0
|
| 115 |
+
and input_tensor.size(1) > self.config.max_seq_len - max_tokens
|
| 116 |
+
):
|
| 117 |
+
self.history.pop(0)
|
| 118 |
+
messages = []
|
| 119 |
+
if self.system_prompt:
|
| 120 |
+
messages.append({"role": "system", "content": self.system_prompt})
|
| 121 |
+
messages.extend(self.history)
|
| 122 |
+
messages.append({"role": "user", "content": user_message})
|
| 123 |
+
input_ids = self.tokenizer.encode_chat(messages, add_generation_prompt=True)
|
| 124 |
+
input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
|
| 125 |
+
|
| 126 |
+
# Generate
|
| 127 |
+
t0 = time.time()
|
| 128 |
+
|
| 129 |
+
if stream:
|
| 130 |
+
response_text = self._generate_streaming(
|
| 131 |
+
input_tensor, max_tokens, temp, k, p
|
| 132 |
+
)
|
| 133 |
+
else:
|
| 134 |
+
with torch.no_grad():
|
| 135 |
+
output_ids = self.model.generate(
|
| 136 |
+
input_tensor,
|
| 137 |
+
max_new_tokens=max_tokens,
|
| 138 |
+
temperature=temp,
|
| 139 |
+
top_k=k,
|
| 140 |
+
top_p=p,
|
| 141 |
+
repetition_penalty=self.config.repetition_penalty,
|
| 142 |
+
eos_token_id=self.tokenizer.special_tokens.get("<|end|>"),
|
| 143 |
+
)
|
| 144 |
+
# Decode only the new tokens
|
| 145 |
+
new_ids = output_ids[0, input_tensor.size(1):].tolist()
|
| 146 |
+
response_text = self.tokenizer.decode(new_ids, skip_special=True)
|
| 147 |
+
|
| 148 |
+
dt = time.time() - t0
|
| 149 |
+
n_tokens = len(self.tokenizer.encode(response_text))
|
| 150 |
+
|
| 151 |
+
# Update history
|
| 152 |
+
self.history.append({"role": "user", "content": user_message})
|
| 153 |
+
self.history.append({"role": "assistant", "content": response_text.strip()})
|
| 154 |
+
|
| 155 |
+
if stream:
|
| 156 |
+
print(f"\n [{n_tokens} tokens, {dt:.1f}s, {n_tokens/dt:.1f} tok/s]")
|
| 157 |
+
|
| 158 |
+
return response_text.strip()
|
| 159 |
+
|
| 160 |
+
@torch.no_grad()
|
| 161 |
+
def _generate_streaming(
|
| 162 |
+
self,
|
| 163 |
+
input_ids: torch.Tensor,
|
| 164 |
+
max_new_tokens: int,
|
| 165 |
+
temperature: float,
|
| 166 |
+
top_k: int,
|
| 167 |
+
top_p: float,
|
| 168 |
+
) -> str:
|
| 169 |
+
"""Generate tokens one at a time, printing as we go."""
|
| 170 |
+
import torch.nn.functional as F
|
| 171 |
+
|
| 172 |
+
model = self.model
|
| 173 |
+
model.eval()
|
| 174 |
+
|
| 175 |
+
eos_id = self.tokenizer.special_tokens.get("<|end|>")
|
| 176 |
+
end_id = self.tokenizer.special_tokens.get("<eos>")
|
| 177 |
+
|
| 178 |
+
# Initial forward pass
|
| 179 |
+
logits, _, kv_caches = model(input_ids, use_cache=True)
|
| 180 |
+
|
| 181 |
+
generated_ids = []
|
| 182 |
+
buffer = b""
|
| 183 |
+
|
| 184 |
+
for step in range(max_new_tokens):
|
| 185 |
+
next_logits = logits[:, -1, :]
|
| 186 |
+
|
| 187 |
+
# Repetition penalty
|
| 188 |
+
if self.config.repetition_penalty != 1.0:
|
| 189 |
+
for tid in set(generated_ids):
|
| 190 |
+
if next_logits[0, tid] > 0:
|
| 191 |
+
next_logits[0, tid] /= self.config.repetition_penalty
|
| 192 |
+
else:
|
| 193 |
+
next_logits[0, tid] *= self.config.repetition_penalty
|
| 194 |
+
|
| 195 |
+
# Temperature + sampling
|
| 196 |
+
if temperature > 0:
|
| 197 |
+
next_logits = next_logits / temperature
|
| 198 |
+
if top_k > 0:
|
| 199 |
+
topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
|
| 200 |
+
next_logits[next_logits < topk_vals[:, -1:]] = float("-inf")
|
| 201 |
+
probs = F.softmax(next_logits, dim=-1)
|
| 202 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 203 |
+
else:
|
| 204 |
+
next_token = next_logits.argmax(dim=-1, keepdim=True)
|
| 205 |
+
|
| 206 |
+
token_id = next_token.item()
|
| 207 |
+
|
| 208 |
+
# Check for stop tokens
|
| 209 |
+
if token_id in (eos_id, end_id):
|
| 210 |
+
break
|
| 211 |
+
|
| 212 |
+
generated_ids.append(token_id)
|
| 213 |
+
|
| 214 |
+
# Decode and print the new token
|
| 215 |
+
token_bytes = self.tokenizer.vocab.get(token_id, b"")
|
| 216 |
+
buffer += token_bytes
|
| 217 |
+
try:
|
| 218 |
+
text = buffer.decode("utf-8")
|
| 219 |
+
sys.stdout.write(text)
|
| 220 |
+
sys.stdout.flush()
|
| 221 |
+
buffer = b""
|
| 222 |
+
except UnicodeDecodeError:
|
| 223 |
+
pass # Wait for more bytes
|
| 224 |
+
|
| 225 |
+
# Forward with KV-cache
|
| 226 |
+
position_offset = input_ids.size(1) + step
|
| 227 |
+
logits, _, kv_caches = model(
|
| 228 |
+
next_token,
|
| 229 |
+
kv_caches=kv_caches,
|
| 230 |
+
use_cache=True,
|
| 231 |
+
position_offset=position_offset,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Flush remaining buffer
|
| 235 |
+
if buffer:
|
| 236 |
+
text = buffer.decode("utf-8", errors="replace")
|
| 237 |
+
sys.stdout.write(text)
|
| 238 |
+
sys.stdout.flush()
|
| 239 |
+
|
| 240 |
+
return self.tokenizer.decode(generated_ids, skip_special=True)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def interactive_chat(chatbot: ChatBot):
|
| 244 |
+
"""Run an interactive chat session in the terminal."""
|
| 245 |
+
print("=" * 60)
|
| 246 |
+
print(" GPT-300M Chatbot")
|
| 247 |
+
print(" Type 'quit' to exit, 'reset' to clear history")
|
| 248 |
+
print(" Type 'system: <prompt>' to set system prompt")
|
| 249 |
+
print("=" * 60)
|
| 250 |
+
print()
|
| 251 |
+
|
| 252 |
+
while True:
|
| 253 |
+
try:
|
| 254 |
+
user_input = input("You: ").strip()
|
| 255 |
+
except (KeyboardInterrupt, EOFError):
|
| 256 |
+
print("\n\nGoodbye!")
|
| 257 |
+
break
|
| 258 |
+
|
| 259 |
+
if not user_input:
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
if user_input.lower() == "quit":
|
| 263 |
+
print("Goodbye!")
|
| 264 |
+
break
|
| 265 |
+
|
| 266 |
+
if user_input.lower() == "reset":
|
| 267 |
+
chatbot.reset()
|
| 268 |
+
continue
|
| 269 |
+
|
| 270 |
+
if user_input.lower().startswith("system:"):
|
| 271 |
+
prompt = user_input[7:].strip()
|
| 272 |
+
chatbot.set_system_prompt(prompt)
|
| 273 |
+
print(f"β¦ System prompt set: {prompt}\n")
|
| 274 |
+
continue
|
| 275 |
+
|
| 276 |
+
print("\nAssistant: ", end="", flush=True)
|
| 277 |
+
chatbot.chat(user_input, stream=True)
|
| 278 |
+
print()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def load_model(checkpoint_path: str, device: str = "auto"):
|
| 282 |
+
"""Load a trained model from checkpoint."""
|
| 283 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 284 |
+
|
| 285 |
+
# Reconstruct config
|
| 286 |
+
config = GPT300MConfig(**checkpoint["config"])
|
| 287 |
+
|
| 288 |
+
# Load model
|
| 289 |
+
model = GPT300M(config)
|
| 290 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 291 |
+
|
| 292 |
+
# Load tokenizer
|
| 293 |
+
tokenizer_path = os.path.join(
|
| 294 |
+
os.path.dirname(checkpoint_path), "tokenizer.json"
|
| 295 |
+
)
|
| 296 |
+
if os.path.exists(tokenizer_path):
|
| 297 |
+
tokenizer = BPETokenizer.load(tokenizer_path)
|
| 298 |
+
else:
|
| 299 |
+
tokenizer = BPETokenizer(vocab_size=config.vocab_size)
|
| 300 |
+
print("Warning: Tokenizer not found, using untrained tokenizer")
|
| 301 |
+
|
| 302 |
+
return model, tokenizer, config
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββ
|
| 306 |
+
# MAIN
|
| 307 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 308 |
+
|
| 309 |
+
if __name__ == "__main__":
|
| 310 |
+
import os
|
| 311 |
+
|
| 312 |
+
parser = argparse.ArgumentParser(description="GPT-300M Chatbot")
|
| 313 |
+
parser.add_argument("--checkpoint", type=str, default=None,
|
| 314 |
+
help="Path to model checkpoint")
|
| 315 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
| 316 |
+
parser.add_argument("--top_k", type=int, default=50)
|
| 317 |
+
parser.add_argument("--top_p", type=float, default=0.9)
|
| 318 |
+
parser.add_argument("--max_tokens", type=int, default=512)
|
| 319 |
+
parser.add_argument("--device", type=str, default="auto")
|
| 320 |
+
args = parser.parse_args()
|
| 321 |
+
|
| 322 |
+
if args.checkpoint and os.path.exists(args.checkpoint):
|
| 323 |
+
model, tokenizer, config = load_model(args.checkpoint, args.device)
|
| 324 |
+
else:
|
| 325 |
+
print("No checkpoint provided. Initializing random model for demo...")
|
| 326 |
+
from config import gpt_tiny
|
| 327 |
+
config = gpt_tiny()
|
| 328 |
+
model = GPT300M(config)
|
| 329 |
+
tokenizer = BPETokenizer(vocab_size=config.vocab_size)
|
| 330 |
+
# Quick train on minimal data
|
| 331 |
+
tokenizer.train("Hello! How are you? I am fine. " * 100)
|
| 332 |
+
|
| 333 |
+
config.temperature = args.temperature
|
| 334 |
+
config.top_k = args.top_k
|
| 335 |
+
config.top_p = args.top_p
|
| 336 |
+
config.max_new_tokens = args.max_tokens
|
| 337 |
+
|
| 338 |
+
chatbot = ChatBot(model, tokenizer, config, device=args.device)
|
| 339 |
+
interactive_chat(chatbot)
|
config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": ["GPT300M"],
|
| 3 |
+
"model_type": "gpt-300m",
|
| 4 |
+
"vocab_size": 32000,
|
| 5 |
+
"max_position_embeddings": 2048,
|
| 6 |
+
"hidden_size": 1024,
|
| 7 |
+
"num_attention_heads": 16,
|
| 8 |
+
"num_hidden_layers": 24,
|
| 9 |
+
"intermediate_size": 4096,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"dropout": 0.1,
|
| 12 |
+
"attention_dropout": 0.1,
|
| 13 |
+
"use_bias": false,
|
| 14 |
+
"tie_word_embeddings": true,
|
| 15 |
+
"rope_theta": 10000.0,
|
| 16 |
+
"rms_norm_eps": 1e-5,
|
| 17 |
+
"torch_dtype": "bfloat16",
|
| 18 |
+
"total_params": 334808064,
|
| 19 |
+
"total_params_trainable": 334808064
|
| 20 |
+
}
|
config.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT-300M Configuration
|
| 3 |
+
======================
|
| 4 |
+
A ~300 million parameter autoregressive transformer language model.
|
| 5 |
+
Built entirely from scratch β no pretrained weights, no fine-tuning.
|
| 6 |
+
|
| 7 |
+
Parameter budget breakdown:
|
| 8 |
+
- Token Embeddings: vocab_size Γ d_model = 32,000 Γ 1,024 = 32.8M
|
| 9 |
+
- Position Embeddings: max_seq_len Γ d_model = 2,048 Γ 1,024 = 2.1M
|
| 10 |
+
- Transformer Layers (Γ24):
|
| 11 |
+
- Multi-Head Attention (Q/K/V/O): 4 Γ d_modelΒ² = 4 Γ 1,048,576 = 4.2M each
|
| 12 |
+
- Feed-Forward Network: 2 Γ d_model Γ d_ff = 2 Γ 1,024 Γ 4,096 = 8.4M each
|
| 13 |
+
- LayerNorms: negligible
|
| 14 |
+
- Per layer total: ~12.6M
|
| 15 |
+
- All 24 layers: ~302M
|
| 16 |
+
- Final LayerNorm + LM Head (tied with embeddings): ~0
|
| 17 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 18 |
+
TOTAL: ~337M parameters (LM head weight-tied β ~304M unique)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from typing import Optional
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class GPT300MConfig:
|
| 29 |
+
"""Configuration for a ~300M parameter GPT model."""
|
| 30 |
+
|
| 31 |
+
# ββ Model Architecture ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
vocab_size: int = 32_000 # BPE vocabulary size
|
| 33 |
+
max_seq_len: int = 2_048 # Maximum sequence length (context window)
|
| 34 |
+
d_model: int = 1_024 # Hidden dimension / embedding size
|
| 35 |
+
n_heads: int = 16 # Number of attention heads
|
| 36 |
+
n_layers: int = 24 # Number of transformer blocks
|
| 37 |
+
d_ff: int = 4_096 # Feed-forward intermediate dimension
|
| 38 |
+
dropout: float = 0.1 # Dropout probability
|
| 39 |
+
bias: bool = False # Use bias in linear layers (modern GPTs skip this)
|
| 40 |
+
tie_weights: bool = True # Tie token embedding and LM head weights
|
| 41 |
+
activation: str = "gelu" # Activation function: "gelu" or "swiglu"
|
| 42 |
+
norm_eps: float = 1e-5 # LayerNorm epsilon
|
| 43 |
+
rope: bool = True # Use Rotary Position Embeddings (RoPE)
|
| 44 |
+
rope_theta: float = 10_000.0 # RoPE base frequency
|
| 45 |
+
|
| 46 |
+
# ββ Training Hyperparameters ββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
batch_size: int = 32 # Micro-batch size per GPU
|
| 48 |
+
gradient_accumulation_steps: int = 8 # Effective batch = batch_size Γ grad_accum Γ n_gpus
|
| 49 |
+
learning_rate: float = 3e-4 # Peak learning rate
|
| 50 |
+
min_learning_rate: float = 3e-5 # Minimum LR after cosine decay
|
| 51 |
+
weight_decay: float = 0.1 # AdamW weight decay
|
| 52 |
+
beta1: float = 0.9 # Adam beta1
|
| 53 |
+
beta2: float = 0.95 # Adam beta2
|
| 54 |
+
max_grad_norm: float = 1.0 # Gradient clipping norm
|
| 55 |
+
warmup_steps: int = 2_000 # Linear warmup steps
|
| 56 |
+
max_steps: int = 600_000 # Total training steps
|
| 57 |
+
eval_interval: int = 1_000 # Evaluate every N steps
|
| 58 |
+
save_interval: int = 5_000 # Save checkpoint every N steps
|
| 59 |
+
log_interval: int = 10 # Log metrics every N steps
|
| 60 |
+
|
| 61 |
+
# ββ Data ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
data_dir: str = "./data" # Directory containing tokenized .bin shards
|
| 63 |
+
train_split: float = 0.98 # Train/val split ratio
|
| 64 |
+
num_workers: int = 4 # DataLoader workers
|
| 65 |
+
|
| 66 |
+
# ββ System ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
device: str = "auto" # "auto", "cuda", "cpu", "mps"
|
| 68 |
+
dtype: str = "bfloat16" # "float32", "float16", "bfloat16"
|
| 69 |
+
compile_model: bool = True # Use torch.compile (PyTorch 2.0+)
|
| 70 |
+
output_dir: str = "./checkpoints" # Where to save checkpoints
|
| 71 |
+
wandb_project: str = "gpt-300m" # Weights & Biases project name
|
| 72 |
+
wandb_run_name: Optional[str] = None
|
| 73 |
+
seed: int = 42
|
| 74 |
+
|
| 75 |
+
# ββ Chat / Inference ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 76 |
+
temperature: float = 0.7 # Sampling temperature
|
| 77 |
+
top_k: int = 50 # Top-k sampling
|
| 78 |
+
top_p: float = 0.9 # Nucleus sampling threshold
|
| 79 |
+
max_new_tokens: int = 512 # Max tokens to generate per turn
|
| 80 |
+
repetition_penalty: float = 1.1 # Penalize repeated tokens
|
| 81 |
+
chat_template: str = (
|
| 82 |
+
"<|system|>{system}<|end|>"
|
| 83 |
+
"<|user|>{user}<|end|>"
|
| 84 |
+
"<|assistant|>"
|
| 85 |
+
)
|
| 86 |
+
system_prompt: str = (
|
| 87 |
+
"You are a helpful, harmless, and honest AI assistant. "
|
| 88 |
+
"Respond naturally and conversationally."
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# ββ Special Token IDs (set during tokenizer init) βββββββββββββββββββ
|
| 92 |
+
pad_token_id: int = 0
|
| 93 |
+
bos_token_id: int = 1
|
| 94 |
+
eos_token_id: int = 2
|
| 95 |
+
|
| 96 |
+
@property
def head_dim(self) -> int:
    """Per-head width: d_model divided evenly across n_heads."""
    quotient, remainder = divmod(self.d_model, self.n_heads)
    assert remainder == 0
    return quotient
|
| 100 |
+
|
| 101 |
+
@property
def total_params_estimate(self) -> int:
    """Rough parameter count implied by the architecture fields.

    Ignores per-layer biases and treats the FFN as two plain projections,
    so this is an estimate rather than an exact count.
    """
    d, n_layers = self.d_model, self.n_layers
    embedding = self.vocab_size * d
    # Learned absolute positions only exist when RoPE is disabled.
    positional = 0 if self.rope else self.max_seq_len * d
    attention = 4 * d * d * n_layers          # fused QKV + output projection
    feed_forward = 2 * d * self.d_ff * n_layers
    norms = (2 * n_layers + 1) * d            # two per block + final norm
    # A tied LM head shares the embedding matrix and adds nothing.
    lm_head = 0 if self.tie_weights else self.vocab_size * d
    return embedding + positional + attention + feed_forward + norms + lm_head
|
| 110 |
+
|
| 111 |
+
def save(self, path: str):
    """Serialize all config fields to *path* as pretty-printed JSON."""
    parent = os.path.dirname(path) or "."
    os.makedirs(parent, exist_ok=True)
    with open(path, "w") as f:
        f.write(json.dumps(self.__dict__, indent=2))
|
| 115 |
+
|
| 116 |
+
@classmethod
def load(cls, path: str) -> "GPT300MConfig":
    """Reconstruct a config from a JSON file written by ``save``."""
    with open(path) as f:
        payload = json.load(f)
    return cls(**payload)
|
| 120 |
+
|
| 121 |
+
def __post_init__(self):
    """Validate cross-field constraints right after dataclass construction."""
    heads_divide_evenly = self.d_model % self.n_heads == 0
    assert heads_divide_evenly, (
        f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
    )
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ββ Preset Configs ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
|
| 129 |
+
def gpt_300m() -> "GPT300MConfig":
    """Build the default 300M-parameter configuration (all fields at defaults)."""
    config = GPT300MConfig()
    return config
|
| 132 |
+
|
| 133 |
+
def gpt_125m() -> "GPT300MConfig":
    """Smaller 125M config for testing."""
    overrides = dict(
        d_model=768,
        n_heads=12,
        n_layers=12,
        d_ff=3072,
        max_seq_len=1024,
        batch_size=64,
    )
    return GPT300MConfig(**overrides)
|
| 139 |
+
|
| 140 |
+
def gpt_tiny() -> "GPT300MConfig":
    """Tiny config for debugging."""
    overrides = dict(
        d_model=128,
        n_heads=4,
        n_layers=4,
        d_ff=512,
        vocab_size=1000,
        max_seq_len=256,
        batch_size=8,
    )
    return GPT300MConfig(**overrides)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
    # Smoke test: instantiate the default config and print its key fields.
    cfg = gpt_300m()
    # Fix: the original used f-strings with no placeholders; plain literals
    # and a loop over the repeated field prints produce identical output.
    print("GPT-300M Configuration")
    print(f"  Estimated parameters: {cfg.total_params_estimate:,}")
    for field_name in ("d_model", "n_heads", "n_layers",
                       "d_ff", "vocab_size", "max_seq_len"):
        print(f"  {field_name}: {getattr(cfg, field_name)}")
|
dataset.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset & DataLoader for GPT-300M
|
| 3 |
+
==================================
|
| 4 |
+
Handles loading, tokenizing, and batching text data for training.
|
| 5 |
+
|
| 6 |
+
Supports two modes:
|
| 7 |
+
1. Pre-tokenized binary shards (fast, for large-scale training)
|
| 8 |
+
2. Raw text files (convenient, for small datasets)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import glob
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
from typing import List, Optional
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| 19 |
+
|
| 20 |
+
from config import GPT300MConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TextDataset(Dataset):
    """Fixed-length language-modeling samples cut from a single text.

    The raw text is tokenized once up front; ``__getitem__`` slices
    overlapping windows out of the resulting id tensor with a sliding stride.
    """

    def __init__(
        self,
        text: str,
        tokenizer,
        seq_len: int,
        stride: Optional[int] = None,
    ):
        """
        Args:
            text: Raw text data.
            tokenizer: BPETokenizer instance (anything with ``.encode()``).
            seq_len: Sequence length for training.
            stride: Sliding-window stride (default: seq_len // 2).
        """
        self.seq_len = seq_len
        self.stride = stride or seq_len // 2

        # One up-front tokenization pass over the whole text.
        ids = tokenizer.encode(text, add_special_tokens=False)
        self.token_ids = torch.tensor(ids, dtype=torch.long)

        # Number of complete (input, target) windows available.
        usable = len(self.token_ids) - seq_len - 1
        self.n_sequences = usable // self.stride + 1 if usable >= 0 else 0

    def __len__(self):
        return self.n_sequences

    def __getitem__(self, idx):
        lo = idx * self.stride
        # +1 extra token so the target can be shifted by one position.
        window = self.token_ids[lo : lo + self.seq_len + 1]
        return window[:-1], window[1:]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ChatDataset(Dataset):
    """Multi-turn conversations rendered to token ids for supervised training.

    Each conversation is encoded with the tokenizer's chat template, gets an
    explicit <eos> appended, and is truncated to max_seq_len + 1 tokens so an
    (input, shifted-target) pair of length max_seq_len can be sliced from it.
    """

    def __init__(
        self,
        conversations: List[List[dict]],
        tokenizer,
        max_seq_len: int,
    ):
        """
        Args:
            conversations: List of conversations, each a list of
                {"role": "user"|"assistant"|"system", "content": "..."}.
            tokenizer: BPETokenizer instance.
            max_seq_len: Maximum sequence length.
        """
        self.max_seq_len = max_seq_len
        self.samples = []
        eos = tokenizer.special_tokens["<eos>"]

        for conversation in conversations:
            ids = tokenizer.encode_chat(conversation, add_generation_prompt=False)
            # Append <eos>, then truncate to max_seq_len + 1 (input + shifted target).
            ids = (ids + [eos])[: max_seq_len + 1]
            # Drop degenerate conversations too short to learn from.
            if len(ids) >= 4:
                self.samples.append(torch.tensor(ids, dtype=torch.long))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        ids = self.samples[idx]
        return ids[:-1], ids[1:]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class ShardedDataset(IterableDataset):
    """Streams (input, target) pairs from pre-tokenized uint16 ``.bin`` shards.

    Shards are memory-mapped so arbitrarily large corpora can be consumed
    without loading them into RAM. Shard order and the chunk order within
    each shard are shuffled with a fixed seed for reproducibility.
    """

    def __init__(
        self,
        data_dir: str,
        seq_len: int,
        split: str = "train",
        seed: int = 42,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.seed = seed

        # Discover shard files for the requested split.
        pattern = os.path.join(data_dir, f"{split}_*.bin")
        self.shards = sorted(glob.glob(pattern))
        if not self.shards:
            raise FileNotFoundError(f"No shards found matching: {pattern}")

        print(f"Found {len(self.shards)} {split} shards")

    def __iter__(self):
        rng = random.Random(self.seed)
        shard_order = list(self.shards)
        rng.shuffle(shard_order)

        span = self.seq_len + 1  # +1 so the target can be shifted by one
        for path in shard_order:
            # memmap keeps the shard on disk; slices are read lazily.
            tokens = np.memmap(path, dtype=np.uint16, mode="r")
            chunk_ids = list(range(len(tokens) // span))
            rng.shuffle(chunk_ids)

            for cid in chunk_ids:
                lo = cid * span
                window = torch.from_numpy(tokens[lo : lo + span].astype(np.int64))
                yield window[:-1], window[1:]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def collate_fn(batch, pad_id: int = 0):
    """Pad a list of (input, target) tensor pairs to a common length.

    Both inputs and targets are padded with *pad_id*; the model's loss is
    expected to ignore that id.
    """
    inputs, targets = zip(*batch)
    width = max(seq.size(0) for seq in inputs)

    batch_x = torch.full((len(inputs), width), pad_id, dtype=torch.long)
    batch_y = torch.full((len(targets), width), pad_id, dtype=torch.long)

    for row, (seq_x, seq_y) in enumerate(zip(inputs, targets)):
        batch_x[row, : seq_x.size(0)] = seq_x
        batch_y[row, : seq_y.size(0)] = seq_y

    return batch_x, batch_y
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def create_dataloaders(
    config: "GPT300MConfig",
    tokenizer,
    text: Optional[str] = None,
    conversations: Optional[List[List[dict]]] = None,
) -> tuple:
    """
    Create train and validation DataLoaders.

    Supply either `text` for raw text training or `conversations` for chat
    training. The split point comes from ``config.train_split``.
    """
    if text is not None:
        cut = int(len(text) * config.train_split)
        train_ds = TextDataset(text[:cut], tokenizer, config.max_seq_len)
        val_ds = TextDataset(text[cut:], tokenizer, config.max_seq_len)
    elif conversations is not None:
        cut = int(len(conversations) * config.train_split)
        train_ds = ChatDataset(conversations[:cut], tokenizer, config.max_seq_len)
        val_ds = ChatDataset(conversations[cut:], tokenizer, config.max_seq_len)
    else:
        raise ValueError("Provide either `text` or `conversations`")

    def _pad(batch):
        # Shared collate closure so both loaders pad with the config's pad id.
        return collate_fn(batch, config.pad_token_id)

    train_dl = DataLoader(
        train_ds,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=_pad,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=_pad,
        num_workers=config.num_workers,
        pin_memory=True,
    )
    return train_dl, val_dl
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 229 |
+
# UTILITIES: Tokenize and save to binary shards
|
| 230 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 231 |
+
|
| 232 |
+
def tokenize_to_shards(
    text: str,
    tokenizer,
    output_dir: str,
    shard_size: int = 100_000_000,  # ~100M tokens per shard
    split: str = "train",
):
    """
    Tokenize text and save to binary shards for efficient loading.

    Token ids are stored as raw little-endian uint16, so the tokenizer's
    vocabulary must fit in 16 bits.
    """
    os.makedirs(output_dir, exist_ok=True)
    token_ids = tokenizer.encode(text, add_special_tokens=False)

    # Slice the id stream into fixed-size shards and dump each as raw uint16.
    bounds = range(0, len(token_ids), shard_size)
    for shard_idx, lo in enumerate(bounds):
        shard = np.asarray(token_ids[lo : lo + shard_size], dtype=np.uint16)
        shard.tofile(os.path.join(output_dir, f"{split}_{shard_idx:04d}.bin"))

    print(f"Saved {len(bounds)} shards ({len(token_ids):,} tokens) to {output_dir}")
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if __name__ == "__main__":
    from tokenizer import BPETokenizer

    # Smoke test: train a tiny tokenizer on synthetic text and slice windows.
    tok = BPETokenizer(vocab_size=500)
    sample_text = "Hello world! " * 1000
    tok.train(sample_text)

    ds = TextDataset(sample_text, tok, seq_len=64)
    print(f"Dataset: {len(ds)} sequences of length 64")

    first_x, first_y = ds[0]
    print(f"Sample x: {first_x[:10]}")
    print(f"Sample y: {first_y[:10]}")
|
model.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT-300M Model Architecture
|
| 3 |
+
============================
|
| 4 |
+
A decoder-only transformer built entirely from scratch in PyTorch.
|
| 5 |
+
|
| 6 |
+
Architecture features:
|
| 7 |
+
- Pre-LayerNorm transformer blocks
|
| 8 |
+
- Rotary Position Embeddings (RoPE)
|
| 9 |
+
- Multi-Head Self-Attention with causal masking
|
| 10 |
+
- GELU activation in feed-forward layers
|
| 11 |
+
- Optional weight tying (token embeddings β LM head)
|
| 12 |
+
- KV-Cache for efficient autoregressive generation
|
| 13 |
+
- Flash Attention support (PyTorch 2.0+)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from typing import Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from config import GPT300MConfig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
# ROTARY POSITION EMBEDDINGS (RoPE)
|
| 28 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
+
|
| 30 |
+
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (Su et al., 2021).

    Precomputes cos/sin lookup tables for every position up to max_seq_len,
    so ``forward`` only slices cached tensors — no trig at model time.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Outer product position x frequency gives every rotation angle.
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the table covers the full head dimension.
        table = torch.cat([angles, angles], dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, seq_len: int, offset: int = 0):
        """Return (cos, sin) tables for positions [offset, offset + seq_len)."""
        window = slice(offset, offset + seq_len)
        return self.cos_cached[window], self.sin_cached[window]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the second half of the last dimension: (a, b) -> (-b, a)."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def apply_rotary_emb(
    q: torch.Tensor, k: torch.Tensor,
    cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to query and key tensors.

    cos/sin arrive as [seq_len, head_dim] and are broadcast over the batch
    and head axes of q/k ([B, n_heads, seq_len, head_dim]).
    """
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]

    def _rotated(t: torch.Tensor) -> torch.Tensor:
        # Inlined half-rotation: (a, b) -> (-b, a) on the last axis.
        a, b = t.chunk(2, dim=-1)
        return t * cos + torch.cat([-b, a], dim=-1) * sin

    return _rotated(q), _rotated(k)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 72 |
+
# RMSNORM (faster alternative to LayerNorm)
|
| 73 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
+
|
| 75 |
+
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for numerical stability, then cast back to the
        # input dtype before applying the learned gain.
        x32 = x.float()
        scale = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * scale).type_as(x) * self.weight
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 89 |
+
# MULTI-HEAD SELF-ATTENTION
|
| 90 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 91 |
+
|
| 92 |
+
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention with causal masking and optional KV-cache."""

    def __init__(self, config: "GPT300MConfig"):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.d_model = config.d_model
        self.dropout = config.dropout

        # Q, K, V projections (fused into a single matmul for efficiency)
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        # Output projection
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Check for Flash Attention support (PyTorch 2.0+ SDPA kernel)
        self.flash_attn = hasattr(F, "scaled_dot_product_attention")

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over x (plus any cached keys/values).

        Args:
            x: [B, T, d_model] input activations.
            cos/sin: RoPE tables for the T query positions, or None.
            mask: optional attention mask (0 = masked); when None a causal
                mask is applied.
            kv_cache: (k, v) from previous steps, each [B, n_heads, T_prev, head_dim].
            use_cache: return the updated (k, v) for incremental decoding.

        Returns:
            (output [B, T, d_model], new (k, v) cache or None)
        """
        B, T, C = x.shape

        # Project to Q, K, V in one shot, then split.
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.d_model, dim=-1)

        # Reshape: [B, T, n_heads, head_dim] -> [B, n_heads, T, head_dim]
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE (cos/sin already carry the caller's position offset).
        if cos is not None and sin is not None:
            q, k = apply_rotary_emb(q, k, cos, sin)

        # KV-Cache for generation: prepend previous keys/values along time.
        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)

        new_cache = (k, v) if use_cache else None

        # Attention
        if self.flash_attn and not use_cache:
            # Use PyTorch's efficient SDPA
            attn_out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=mask is None,
            )
        else:
            # Manual attention for compatibility / KV-cache
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            if mask is not None:
                scores = scores.masked_fill(mask == 0, float("-inf"))
            else:
                # Causal mask aligned with the key axis. With a KV-cache,
                # T_k > T_q and query i sits at global position
                # (T_k - T_q + i), so it may attend to keys 0..T_k - T_q + i.
                # BUG FIX: the previous tril(ones(T_q, T_k)) let a cached
                # single-token query (T_q=1) see only key 0, masking out the
                # entire cache; the diagonal offset restores full visibility.
                T_q, T_k = q.size(2), k.size(2)
                causal = torch.tril(
                    torch.ones(T_q, T_k, device=x.device, dtype=torch.bool),
                    diagonal=T_k - T_q,
                )
                scores = scores.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float("-inf"))

            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            attn_out = torch.matmul(attn_weights, v)

        # Merge heads back and project.
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        out = self.resid_dropout(self.out_proj(attn_out))

        return out, new_cache
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 181 |
+
# FEED-FORWARD NETWORK
|
| 182 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 183 |
+
|
| 184 |
+
class FeedForward(nn.Module):
    """Position-wise Feed-Forward Network (GELU or SwiGLU variant)."""

    def __init__(self, config: "GPT300MConfig"):
        super().__init__()
        self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
        self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        self.use_swiglu = config.activation == "swiglu"
        if config.activation == "gelu":
            self.act = nn.GELU()
        elif self.use_swiglu:
            # SwiGLU: SiLU-gated projection multiplied into the up projection.
            self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
            self.act = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation: {config.activation}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_swiglu:
            hidden = self.act(self.gate_proj(x)) * self.up_proj(x)
        else:
            hidden = self.act(self.up_proj(x))
        return self.dropout(self.down_proj(hidden))
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 211 |
+
# TRANSFORMER BLOCK
|
| 212 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 213 |
+
|
| 214 |
+
class TransformerBlock(nn.Module):
    """Pre-norm block: x + Attn(Norm(x)), then x + FFN(Norm(x))."""

    def __init__(self, config: "GPT300MConfig", layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.ln1 = RMSNorm(config.d_model, eps=config.norm_eps)
        self.attn = MultiHeadAttention(config)
        self.ln2 = RMSNorm(config.d_model, eps=config.norm_eps)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the attention and FFN sub-layers, each behind a residual add."""
        attn_out, updated_cache = self.attn(
            self.ln1(x), cos, sin, mask, kv_cache, use_cache
        )
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))
        return x, updated_cache
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 247 |
+
# GPT-300M: THE FULL MODEL
|
| 248 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 249 |
+
|
| 250 |
+
class GPT300M(nn.Module):
|
| 251 |
+
"""
|
| 252 |
+
GPT-300M: A 300-million parameter autoregressive language model.
|
| 253 |
+
|
| 254 |
+
Architecture:
|
| 255 |
+
Token Embedding β [Transformer Block Γ 24] β RMSNorm β LM Head
|
| 256 |
+
|
| 257 |
+
Each Transformer Block:
|
| 258 |
+
RMSNorm β Multi-Head Attention (+ RoPE) β Residual
|
| 259 |
+
β RMSNorm β Feed-Forward (GELU) β Residual
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
def __init__(self, config: GPT300MConfig):
    """Build the full model: embeddings, n_layers transformer blocks, LM head.

    Args:
        config: Architecture/hyperparameter container; fields read here are
            vocab_size, d_model, dropout, rope, head_dim, max_seq_len,
            rope_theta, n_layers, norm_eps, tie_weights.
    """
    super().__init__()
    self.config = config

    # -- Embeddings --------------------------------------------------
    self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
    self.drop = nn.Dropout(config.dropout)

    # Positions: RoPE tables when enabled, otherwise learned absolute
    # position embeddings (only one of self.rotary / self.pos_emb exists).
    if config.rope:
        self.rotary = RotaryEmbedding(
            config.head_dim, config.max_seq_len, config.rope_theta
        )
    else:
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)

    # -- Transformer Blocks ------------------------------------------
    self.layers = nn.ModuleList([
        TransformerBlock(config, layer_idx=i)
        for i in range(config.n_layers)
    ])

    # -- Output -------------------------------------------------------
    self.ln_f = RMSNorm(config.d_model, eps=config.norm_eps)
    self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    # Weight tying: LM head shares the token-embedding matrix.
    if config.tie_weights:
        self.lm_head.weight = self.token_emb.weight

    # Base init for all Linear/Embedding weights.
    self.apply(self._init_weights)
    # Then shrink residual-path output projections by sqrt(2 * n_layers),
    # GPT-2-style, so residual variance stays bounded with depth.
    for pn, p in self.named_parameters():
        if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"):
            nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))
|
| 298 |
+
|
| 299 |
+
def _init_weights(self, module: nn.Module):
    """GPT-2-style init: N(0, 0.02) for Linear/Embedding weights, zero biases."""
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
|
| 306 |
+
|
| 307 |
+
def forward(
    self,
    input_ids: torch.Tensor,
    targets: Optional[torch.Tensor] = None,
    kv_caches: Optional[list] = None,
    use_cache: bool = False,
    position_offset: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list]]:
    """
    Forward pass.

    Args:
        input_ids: [B, T] token indices
        targets: [B, T] target token indices for loss computation
        kv_caches: List of KV-cache tuples, one per layer
        use_cache: Whether to return updated KV-caches
        position_offset: Offset for position embeddings (for KV-cache generation)

    Returns:
        logits: [B, T, vocab_size]
        loss: scalar loss if targets provided, else None
        new_caches: Updated KV-caches if use_cache=True

    Note:
        No attention mask is propagated to the layers here; padding is
        handled only via ignore_index in the loss, so padded positions
        still participate in attention — presumably acceptable for
        right-padded batches (TODO confirm).
    """
    B, T = input_ids.shape
    assert T <= self.config.max_seq_len, (
        f"Sequence length {T} exceeds max {self.config.max_seq_len}"
    )

    # Token embeddings
    x = self.token_emb(input_ids)  # [B, T, d_model]

    # Position information: RoPE tables are sliced at position_offset so
    # incremental decoding sees the correct absolute positions.
    if self.config.rope:
        cos, sin = self.rotary(T, offset=position_offset)
    else:
        positions = torch.arange(position_offset, position_offset + T, device=input_ids.device)
        x = x + self.pos_emb(positions)
        cos, sin = None, None

    x = self.drop(x)

    # Transformer blocks, threading per-layer KV-caches through.
    new_caches = [] if use_cache else None
    for i, layer in enumerate(self.layers):
        cache_i = kv_caches[i] if kv_caches is not None else None
        x, new_cache = layer(x, cos, sin, kv_cache=cache_i, use_cache=use_cache)
        if use_cache:
            new_caches.append(new_cache)

    # Final norm and LM head
    x = self.ln_f(x)
    logits = self.lm_head(x)  # [B, T, vocab_size]

    # Loss over all positions, flattened; pad positions are excluded via
    # ignore_index so they contribute neither loss nor gradient.
    loss = None
    if targets is not None:
        loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            targets.view(-1),
            ignore_index=self.config.pad_token_id,
        )

    return logits, loss, new_caches
|
| 370 |
+
|
| 371 |
+
@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    eos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """
    Autoregressive generation with KV-cache.

    Args:
        input_ids: [B, T] prompt token IDs
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (<= 0 selects greedy decoding)
        top_k: Top-k sampling (0 disables)
        top_p: Nucleus sampling threshold (1.0 disables)
        repetition_penalty: Penalty for repeated tokens (applied only when B == 1)
        eos_token_id: Stop generation when this token is produced (B == 1 only)

    Returns:
        [B, T + n] generated token IDs, n <= max_new_tokens. Generation may
        stop early on EOS (single sequence) or when the context window fills.
    """
    self.eval()
    B = input_ids.size(0)

    # Initial forward pass to populate the KV-cache with the whole prompt.
    logits, _, kv_caches = self.forward(input_ids, use_cache=True)

    generated = input_ids
    # Token history for the repetition penalty (tracked for B == 1 only).
    all_token_ids = input_ids.tolist()[0] if B == 1 else []

    for _ in range(max_new_tokens):
        # Logits at the last position only.
        next_logits = logits[:, -1, :]  # [B, vocab_size]

        # CTRL-style repetition penalty: shrink positive logits, amplify
        # negative ones, for every token already produced.
        if repetition_penalty != 1.0 and B == 1:
            for token_id in set(all_token_ids):
                if next_logits[0, token_id] > 0:
                    next_logits[0, token_id] /= repetition_penalty
                else:
                    next_logits[0, token_id] *= repetition_penalty

        if temperature > 0:
            next_logits = next_logits / temperature

            # Top-k filtering: mask everything below the k-th largest logit.
            if top_k > 0:
                topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                next_logits[next_logits < topk_vals[:, -1:]] = float("-inf")

            # Top-p (nucleus) filtering: keep the smallest prefix of the
            # sorted distribution whose cumulative mass reaches top_p.
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(sorted_probs, dim=-1)
                # Subtracting the token's own prob keeps the first token
                # that crosses the threshold.
                sorted_mask = cumprobs - sorted_probs >= top_p
                sorted_logits[sorted_mask] = float("-inf")
                # Scatter the sorted logits back to vocabulary order.
                next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            # temperature <= 0: greedy decoding.
            next_token = next_logits.argmax(dim=-1, keepdim=True)

        generated = torch.cat([generated, next_token], dim=1)

        if B == 1:
            all_token_ids.append(next_token.item())
            # FIX: the EOS check previously ran unguarded and .item() raises
            # for B > 1; early stopping is only well-defined per-sequence,
            # so it is applied for single-sequence generation only.
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        # FIX: stop before exceeding the context window. Positions past
        # max_seq_len were never trained and (for the learned-position
        # path) would index past the embedding table.
        if generated.size(1) >= self.config.max_seq_len:
            break

        # Incremental forward: feed only the new token, reuse the KV-cache.
        position_offset = generated.size(1) - 1
        logits, _, kv_caches = self.forward(
            next_token,
            kv_caches=kv_caches,
            use_cache=True,
            position_offset=position_offset,
        )

    return generated
|
| 461 |
+
|
| 462 |
+
def count_parameters(self, trainable_only: bool = True) -> int:
    """Return the number of parameters, optionally only the trainable ones."""
    params = self.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
|
| 467 |
+
|
| 468 |
+
def model_summary(self) -> str:
    """Return a human-readable model summary as a multi-line string.

    Reports parameter counts and the key architecture hyperparameters.
    (Docstring fixed: the method returns the summary, it does not print.)
    """
    total = self.count_parameters(trainable_only=False)
    trainable = self.count_parameters(trainable_only=True)
    lines = [
        "=" * 60,
        " GPT-300M Model Summary",
        "=" * 60,
        f" Total parameters: {total:>15,}",
        f" Trainable parameters: {trainable:>15,}",
        f" d_model: {self.config.d_model:>15}",
        f" n_heads: {self.config.n_heads:>15}",
        f" n_layers: {self.config.n_layers:>15}",
        f" d_ff: {self.config.d_ff:>15}",
        f" vocab_size: {self.config.vocab_size:>15}",
        f" max_seq_len: {self.config.max_seq_len:>15}",
        # FIX: previously hard-coded to 'Yes'; now reflects config.rope,
        # matching how the other flags are reported.
        f" RoPE: {'Yes' if self.config.rope else 'No':>15}",
        f" Weight tying: {'Yes' if self.config.tie_weights else 'No':>15}",
        f" Flash Attention: {'Yes' if self.layers[0].attn.flash_attn else 'No':>15}",
        "=" * 60,
    ]
    return "\n".join(lines)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 493 |
+
# QUICK TEST
|
| 494 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 495 |
+
|
| 496 |
+
if __name__ == "__main__":
    # Smoke test: build a tiny model, run one training-style forward pass,
    # then sample a short continuation. Exercises the full model API.
    from config import gpt_tiny

    # Use tiny config for testing
    cfg = gpt_tiny()
    model = GPT300M(cfg)
    print(model.model_summary())

    # Test forward pass: random token IDs for inputs and targets so that a
    # loss is computed (batch of 2, sequence length 32).
    x = torch.randint(0, cfg.vocab_size, (2, 32))
    targets = torch.randint(0, cfg.vocab_size, (2, 32))
    logits, loss, _ = model(x, targets=targets)
    print(f"\nForward pass OK: logits={logits.shape}, loss={loss.item():.4f}")

    # Test generation: single random 8-token prompt, 16 sampled tokens.
    prompt = torch.randint(0, cfg.vocab_size, (1, 8))
    gen = model.generate(prompt, max_new_tokens=16, temperature=0.8)
    print(f"Generation OK: {gen.shape}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
numpy>=1.24.0
|
| 3 |
+
matplotlib>=3.7.0
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pad_token": "<pad>",
|
| 3 |
+
"bos_token": "<bos>",
|
| 4 |
+
"eos_token": "<eos>",
|
| 5 |
+
"unk_token": "<unk>",
|
| 6 |
+
"additional_special_tokens": [
|
| 7 |
+
"<|system|>",
|
| 8 |
+
"<|user|>",
|
| 9 |
+
"<|assistant|>",
|
| 10 |
+
"<|end|>"
|
| 11 |
+
]
|
| 12 |
+
}
|
tokenizer.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte-Pair Encoding (BPE) Tokenizer β Built From Scratch
|
| 3 |
+
========================================================
|
| 4 |
+
A minimal but complete BPE tokenizer implementation.
|
| 5 |
+
Supports training from raw text, encoding/decoding, and special chat tokens.
|
| 6 |
+
|
| 7 |
+
For production use, you'd typically use SentencePiece or tiktoken,
|
| 8 |
+
but this demonstrates the full tokenizer pipeline.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import re
|
| 14 |
+
from collections import Counter
|
| 15 |
+
from typing import Dict, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BPETokenizer:
    """
    Byte-Pair Encoding tokenizer with special token support.

    Special tokens:
    <pad> = 0 Padding token
    <bos> = 1 Beginning of sequence
    <eos> = 2 End of sequence
    <unk> = 3 Unknown token
    <|system|> = 4 System prompt delimiter
    <|user|> = 5 User turn delimiter
    <|assistant|> = 6 Assistant turn delimiter
    <|end|> = 7 End of turn
    """

    # ID layout: specials occupy 0-7, the 256 raw bytes occupy the next
    # block, and learned merges are appended after that.
    SPECIAL_TOKENS = {
        "<pad>": 0,
        "<bos>": 1,
        "<eos>": 2,
        "<unk>": 3,
        "<|system|>": 4,
        "<|user|>": 5,
        "<|assistant|>": 6,
        "<|end|>": 7,
    }

    # Pre-tokenization regex (GPT-2 style).
    # NOTE(review): `\w` already matches digits, so the ` ?\d+` alternative
    # is unreachable — confirm whether digits were meant to be split out
    # separately from words.
    PAT = re.compile(
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""",
        re.UNICODE,
    )

    def __init__(self, vocab_size: int = 32_000):
        """Create an untrained byte-level BPE tokenizer.

        Args:
            vocab_size: Target vocabulary size after training
                (specials + 256 byte tokens + learned merges).
        """
        self.target_vocab_size = vocab_size
        self.special_tokens = dict(self.SPECIAL_TOKENS)
        self.num_special = len(self.special_tokens)

        # Byte-level base vocab: map each byte (0-255) to a token ID
        self.byte_to_id: Dict[int, int] = {
            b: b + self.num_special for b in range(256)
        }
        self.id_to_byte: Dict[int, int] = {v: k for k, v in self.byte_to_id.items()}

        # Merge rules learned during training (ordered by rank).
        self.merges: List[Tuple[int, int]] = []
        self.merge_to_id: Dict[Tuple[int, int], int] = {}

        # Full vocab (built after training): token ID -> raw bytes.
        self.vocab: Dict[int, bytes] = {}
        self._build_vocab()

    def _build_vocab(self):
        """Reconstruct the full vocabulary from merges."""
        self.vocab = {}
        # Special tokens
        for tok, idx in self.special_tokens.items():
            self.vocab[idx] = tok.encode("utf-8")
        # Byte-level tokens
        for b in range(256):
            self.vocab[self.num_special + b] = bytes([b])
        # Merged tokens. Relies on dict insertion order matching merge
        # order: each merge's components (a, b) always have smaller IDs,
        # so their byte strings already exist when the merge is built.
        for (a, b), idx in self.merge_to_id.items():
            self.vocab[idx] = self.vocab[a] + self.vocab[b]

    @property
    def vocab_size(self) -> int:
        # Actual size; may be below target_vocab_size if training ran out
        # of pairs to merge before reaching the target.
        return len(self.vocab)

    # -- Training ----------------------------------------------------

    def train(self, text: str, verbose: bool = True):
        """
        Train BPE merges from raw text.

        Args:
            text: Raw training text
            verbose: Print progress
        """
        if verbose:
            print(f"Training BPE tokenizer (target vocab: {self.target_vocab_size:,})...")

        # Pre-tokenize into words
        words = re.findall(self.PAT, text)

        # Convert each word to a tuple of byte token IDs
        word_freqs: Counter = Counter()
        for word in words:
            byte_ids = tuple(self.byte_to_id[b] for b in word.encode("utf-8"))
            word_freqs[byte_ids] += 1

        current_vocab_size = self.num_special + 256
        num_merges = self.target_vocab_size - current_vocab_size

        for i in range(num_merges):
            # Count adjacent pairs across all words, weighted by frequency.
            pair_counts: Counter = Counter()
            for word, freq in word_freqs.items():
                for j in range(len(word) - 1):
                    pair_counts[(word[j], word[j + 1])] += freq

            if not pair_counts:
                # Nothing left to merge — final vocab stays below target.
                break

            # Find most frequent pair
            best_pair = pair_counts.most_common(1)[0][0]
            new_id = current_vocab_size

            # Register merge
            self.merges.append(best_pair)
            self.merge_to_id[best_pair] = new_id

            # Apply merge to all words
            new_word_freqs: Counter = Counter()
            for word, freq in word_freqs.items():
                new_word = self._apply_merge(word, best_pair, new_id)
                new_word_freqs[new_word] += freq
            word_freqs = new_word_freqs

            current_vocab_size += 1

            if verbose and (i + 1) % 1000 == 0:
                print(f" Merge {i + 1}/{num_merges}: "
                      f"({best_pair[0]}, {best_pair[1]}) β {new_id}, "
                      f"freq={pair_counts[best_pair]}")

        self._build_vocab()
        if verbose:
            print(f"Done! Final vocab size: {self.vocab_size:,}")

    @staticmethod
    def _apply_merge(
        word: Tuple[int, ...], pair: Tuple[int, int], new_id: int
    ) -> Tuple[int, ...]:
        """Apply a single merge rule to a word."""
        result = []
        i = 0
        # Greedy left-to-right scan; overlapping occurrences are not
        # re-merged within the same pass.
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                result.append(new_id)
                i += 2
            else:
                result.append(word[i])
                i += 1
        return tuple(result)

    # -- Encoding ----------------------------------------------------

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        """
        Encode text to token IDs.

        Args:
            text: Input text
            add_special_tokens: Whether to wrap with <bos>/<eos>

        Returns:
            List of token IDs
        """
        tokens = []

        # Check for special tokens in the text
        parts = self._split_special_tokens(text)

        for part, is_special in parts:
            if is_special:
                tokens.append(self.special_tokens[part])
            else:
                # Pre-tokenize
                words = re.findall(self.PAT, part)
                for word in words:
                    # Convert to byte IDs
                    byte_ids = list(self.byte_to_id[b] for b in word.encode("utf-8"))
                    # Apply merges in order.
                    # NOTE(review): every learned merge is scanned for every
                    # word — O(num_merges * len(word)) per word. Fine for
                    # small vocabs; a rank-priority merge loop would be
                    # faster at production vocab sizes.
                    for pair, new_id in zip(self.merges, range(self.num_special + 256, self.vocab_size)):
                        i = 0
                        while i < len(byte_ids) - 1:
                            if (byte_ids[i], byte_ids[i + 1]) == pair:
                                byte_ids[i] = new_id
                                del byte_ids[i + 1]
                            else:
                                i += 1
                    tokens.extend(byte_ids)

        if add_special_tokens:
            tokens = [self.special_tokens["<bos>"]] + tokens + [self.special_tokens["<eos>"]]

        return tokens

    def _split_special_tokens(self, text: str) -> List[Tuple[str, bool]]:
        """Split text on special token boundaries.

        Returns (segment, is_special) pairs covering the whole input.
        """
        # Build regex to match special tokens
        pattern = "|".join(re.escape(tok) for tok in self.special_tokens.keys())
        if not pattern:
            return [(text, False)]

        parts = []
        last_end = 0
        for match in re.finditer(pattern, text):
            if match.start() > last_end:
                parts.append((text[last_end:match.start()], False))
            parts.append((match.group(), True))
            last_end = match.end()
        # Trailing plain-text segment after the last special token.
        if last_end < len(text):
            parts.append((text[last_end:], False))
        return parts

    # -- Decoding ----------------------------------------------------

    def decode(self, ids: List[int], skip_special: bool = True) -> str:
        """
        Decode token IDs to text.

        Args:
            ids: List of token IDs
            skip_special: Whether to skip special tokens

        Returns:
            Decoded text string
        """
        byte_chunks = []
        for idx in ids:
            # Membership test on .values() is O(num_special) per token —
            # acceptable for 8 special tokens.
            if idx in self.special_tokens.values():
                if not skip_special:
                    # Find the special token string
                    for tok, tid in self.special_tokens.items():
                        if tid == idx:
                            byte_chunks.append(tok.encode("utf-8"))
                            break
            elif idx in self.vocab:
                byte_chunks.append(self.vocab[idx])
        # Unknown IDs are silently dropped; invalid UTF-8 becomes U+FFFD.
        return b"".join(byte_chunks).decode("utf-8", errors="replace")

    # -- Chat Formatting ---------------------------------------------

    def encode_chat(
        self,
        messages: List[Dict[str, str]],
        add_generation_prompt: bool = True,
    ) -> List[int]:
        """
        Encode a chat conversation into token IDs.

        Args:
            messages: List of {"role": "system"|"user"|"assistant", "content": "..."}
            add_generation_prompt: Add the assistant turn start token at the end

        Returns:
            List of token IDs
        """
        tokens = [self.special_tokens["<bos>"]]

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            # Unrecognized roles get no delimiter token — only their content
            # followed by <|end|>.
            if role == "system":
                tokens.append(self.special_tokens["<|system|>"])
            elif role == "user":
                tokens.append(self.special_tokens["<|user|>"])
            elif role == "assistant":
                tokens.append(self.special_tokens["<|assistant|>"])

            tokens.extend(self.encode(content))
            tokens.append(self.special_tokens["<|end|>"])

        if add_generation_prompt:
            tokens.append(self.special_tokens["<|assistant|>"])

        return tokens

    # -- Save / Load -------------------------------------------------

    def save(self, path: str):
        """Save tokenizer to JSON.

        Only the merges and target size are stored; specials and the byte
        base vocab are reconstructed from class constants on load.
        """
        # `or "."` handles a bare filename whose dirname is empty.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        data = {
            "target_vocab_size": self.target_vocab_size,
            "merges": self.merges,
        }
        with open(path, "w") as f:
            json.dump(data, f)

    @classmethod
    def load(cls, path: str) -> "BPETokenizer":
        """Load tokenizer from JSON."""
        with open(path) as f:
            data = json.load(f)
        tok = cls(vocab_size=data["target_vocab_size"])
        # JSON serializes tuples as lists; restore tuples (hashable keys).
        tok.merges = [tuple(m) for m in data["merges"]]
        # Merge IDs are re-assigned sequentially after the byte block,
        # preserving the original rank order.
        tok.merge_to_id = {
            tuple(pair): idx
            for idx, pair in enumerate(tok.merges, start=tok.num_special + 256)
        }
        tok._build_vocab()
        return tok
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 316 |
+
# QUICK TEST
|
| 317 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 318 |
+
|
| 319 |
+
if __name__ == "__main__":
    # Smoke test: train a tiny vocab, round-trip encode/decode, and check
    # the chat formatting path.
    tok = BPETokenizer(vocab_size=500)

    # NOTE(review): `* 20` binds to the LAST implicitly-concatenated
    # literal only — the first two sentences appear once, the third 20
    # times. Confirm this repetition pattern is intended.
    sample = (
        "Hello, world! This is a test of the BPE tokenizer. "
        "The quick brown fox jumps over the lazy dog. "
        "Machine learning is fascinating and powerful. " * 20
    )

    tok.train(sample, verbose=True)

    # Round-trip: decode(encode(text)) should reproduce the input.
    text = "Hello, world! Machine learning is great."
    ids = tok.encode(text)
    decoded = tok.decode(ids)
    print(f"\nOriginal: {text}")
    print(f"Token IDs: {ids[:20]}...")
    print(f"Decoded: {decoded}")

    # Test chat encoding
    chat = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello!"},
    ]
    chat_ids = tok.encode_chat(chat)
    print(f"\nChat IDs: {chat_ids[:20]}...")
    print(f"Chat decoded: {tok.decode(chat_ids, skip_special=False)}")
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "BPETokenizer",
|
| 3 |
+
"model_type": "gpt-300m",
|
| 4 |
+
"vocab_size": 32000,
|
| 5 |
+
"model_max_length": 2048,
|
| 6 |
+
"padding_side": "right",
|
| 7 |
+
"bos_token": "<bos>",
|
| 8 |
+
"eos_token": "<eos>",
|
| 9 |
+
"pad_token": "<pad>",
|
| 10 |
+
"unk_token": "<unk>"
|
| 11 |
+
}
|
train.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT-300M Training Script
|
| 3 |
+
=========================
|
| 4 |
+
Full training pipeline with:
|
| 5 |
+
- Mixed-precision training (bf16/fp16)
|
| 6 |
+
- Gradient accumulation
|
| 7 |
+
- Cosine learning rate schedule with warmup
|
| 8 |
+
- Gradient clipping
|
| 9 |
+
- Periodic evaluation & checkpointing
|
| 10 |
+
- Distributed Data Parallel (DDP) support
|
| 11 |
+
- Weights & Biases logging
|
| 12 |
+
- torch.compile support
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Single GPU
|
| 16 |
+
python train.py
|
| 17 |
+
|
| 18 |
+
# Multi-GPU with DDP
|
| 19 |
+
torchrun --nproc_per_node=4 train.py
|
| 20 |
+
|
| 21 |
+
# With custom config
|
| 22 |
+
python train.py --d_model 768 --n_layers 12 --batch_size 64
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import math
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
import time
|
| 30 |
+
from contextlib import nullcontext
|
| 31 |
+
from typing import Optional
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.distributed as dist
|
| 36 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 37 |
+
|
| 38 |
+
from config import GPT300MConfig, gpt_300m, gpt_tiny
|
| 39 |
+
from model import GPT300M
|
| 40 |
+
from tokenizer import BPETokenizer
|
| 41 |
+
from dataset import TextDataset, ChatDataset, create_dataloaders, collate_fn
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
# LEARNING RATE SCHEDULER
|
| 46 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
|
| 48 |
+
def get_lr(step: int, config: "GPT300MConfig") -> float:
    """Learning-rate schedule: linear warmup, then cosine decay to a floor.

    Args:
        step: Current optimizer step (0-based).
        config: Provides warmup_steps, max_steps, learning_rate,
            min_learning_rate.

    Returns:
        The learning rate for this step.
    """
    peak = config.learning_rate
    floor = config.min_learning_rate
    warmup = config.warmup_steps

    # Phase 1: linear ramp from 0 up to the peak rate.
    if step < warmup:
        return peak * step / warmup

    # Phase 3: past the schedule, hold at the floor.
    if step > config.max_steps:
        return floor

    # Phase 2: cosine anneal from peak down to floor.
    progress = (step - warmup) / (config.max_steps - warmup)
    scale = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + scale * (peak - floor)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
+
# TRAINING LOOP
|
| 65 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 66 |
+
|
| 67 |
+
class Trainer:
|
| 68 |
+
"""
|
| 69 |
+
Full-featured training loop for GPT-300M.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(self, config: GPT300MConfig, resume_from: Optional[str] = None):
    """Build the trainer: distributed state, device, model, optimizer.

    Args:
        config: Full model + training configuration.
        resume_from: Optional checkpoint path to resume from.
    """
    self.config = config
    # Setup order matters: DDP rank info feeds device selection, the
    # device feeds model placement, and the optimizer needs the model.
    self.setup_distributed()
    self.setup_device()
    self.setup_model()
    self.setup_optimizer()
    self.global_step = 0
    self.best_val_loss = float("inf")

    # Resuming overwrites model/optimizer state and the step counters.
    if resume_from:
        self.load_checkpoint(resume_from)
|
| 83 |
+
|
| 84 |
+
def setup_distributed(self):
    """Detect torchrun-style env vars and record DDP rank info on self."""
    # torchrun sets RANK; its absence means a plain single-process run.
    self.ddp = int(os.environ.get("RANK", -1)) != -1

    if not self.ddp:
        # Single process: rank 0 of a world of one.
        self.ddp_rank = 0
        self.ddp_local_rank = 0
        self.ddp_world_size = 1
        self.master_process = True
        return

    dist.init_process_group(backend="nccl")
    self.ddp_rank = int(os.environ["RANK"])
    self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
    self.ddp_world_size = int(os.environ["WORLD_SIZE"])
    # Rank 0 handles logging and checkpointing.
    self.master_process = self.ddp_rank == 0
|
| 98 |
+
|
| 99 |
+
def setup_device(self):
|
| 100 |
+
"""Configure device and mixed precision."""
|
| 101 |
+
cfg = self.config
|
| 102 |
+
|
| 103 |
+
if cfg.device == "auto":
|
| 104 |
+
if torch.cuda.is_available():
|
| 105 |
+
self.device = f"cuda:{self.ddp_local_rank}" if self.ddp else "cuda"
|
| 106 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 107 |
+
self.device = "mps"
|
| 108 |
+
else:
|
| 109 |
+
self.device = "cpu"
|
| 110 |
+
else:
|
| 111 |
+
self.device = cfg.device
|
| 112 |
+
|
| 113 |
+
# Mixed precision context
|
| 114 |
+
if "cuda" in self.device:
|
| 115 |
+
if cfg.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
|
| 116 |
+
self.dtype = torch.bfloat16
|
| 117 |
+
self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 118 |
+
elif cfg.dtype == "float16":
|
| 119 |
+
self.dtype = torch.float16
|
| 120 |
+
self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
|
| 121 |
+
else:
|
| 122 |
+
self.dtype = torch.float32
|
| 123 |
+
self.amp_ctx = nullcontext()
|
| 124 |
+
self.scaler = torch.amp.GradScaler("cuda", enabled=(cfg.dtype == "float16"))
|
| 125 |
+
else:
|
| 126 |
+
self.dtype = torch.float32
|
| 127 |
+
self.amp_ctx = nullcontext()
|
| 128 |
+
self.scaler = torch.amp.GradScaler(enabled=False)
|
| 129 |
+
|
| 130 |
+
if self.master_process:
|
| 131 |
+
print(f"Device: {self.device}, dtype: {cfg.dtype}")
|
| 132 |
+
|
| 133 |
+
def setup_model(self):
    """Initialize the model, optionally compile it, and wrap for DDP.

    Keeps `self.raw_model` pointing at the unwrapped nn.Module so that
    checkpoints save/load with clean parameter names.
    """
    self.model = GPT300M(self.config).to(self.device)

    if self.master_process:
        print(self.model.model_summary())

    # FIX: capture the raw module BEFORE torch.compile / DDP wrapping.
    # Previously raw_model was assigned after compile, so
    # raw_model.state_dict() carried "_orig_mod."-prefixed keys and
    # checkpoints could not be loaded into an uncompiled model.
    self.raw_model = self.model

    # Compile model (PyTorch 2.0+); parameters are shared with raw_model.
    if self.config.compile_model and hasattr(torch, "compile"):
        if self.master_process:
            print("Compiling model with torch.compile...")
        self.model = torch.compile(self.model)

    # Wrap in DDP for gradient synchronization across ranks.
    if self.ddp:
        self.model = DDP(self.model, device_ids=[self.ddp_local_rank])
|
| 151 |
+
|
| 152 |
+
def setup_optimizer(self):
    """Configure AdamW with decoupled weight decay.

    Weight decay is applied only to >=2-D tensors (weight matrices and
    embeddings); biases and norm gains are excluded, following standard
    transformer practice.
    """
    cfg = self.config

    # Separate parameters: decay vs no-decay, keyed on dimensionality.
    decay_params = []
    nodecay_params = []
    for name, param in self.raw_model.named_parameters():
        if not param.requires_grad:
            continue
        if param.dim() >= 2:
            decay_params.append(param)
        else:
            nodecay_params.append(param)

    optim_groups = [
        {"params": decay_params, "weight_decay": cfg.weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]

    # The fused AdamW kernel is CUDA-only.
    # FIX: `use_fused` was previously computed (with an unrelated
    # `hasattr(torch.optim, "_multi_tensor")` probe) but never used; the
    # fused= argument repeated the device check inline. The value passed
    # to the optimizer is unchanged.
    use_fused = "cuda" in self.device
    self.optimizer = torch.optim.AdamW(
        optim_groups,
        lr=cfg.learning_rate,
        betas=(cfg.beta1, cfg.beta2),
        fused=use_fused,
    )

    if self.master_process:
        n_decay = sum(p.numel() for p in decay_params)
        n_nodecay = sum(p.numel() for p in nodecay_params)
        print(f"Optimizer: {n_decay:,} decay params, {n_nodecay:,} no-decay params")
|
| 185 |
+
|
| 186 |
+
@torch.no_grad()
def evaluate(self, val_loader) -> float:
    """Return the mean validation loss over at most 50 batches."""
    self.model.eval()
    running = 0.0
    seen = 0

    for inputs, labels in val_loader:
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)
        with self.amp_ctx:
            _, batch_loss, _ = self.model(inputs, targets=labels)
        running += batch_loss.item()
        seen += 1

        # Cap evaluation cost to 50 batches.
        if seen >= 50:
            break

    self.model.train()
    # max() guards against an empty loader.
    return running / max(seen, 1)
|
| 205 |
+
|
| 206 |
+
def save_checkpoint(self, path: Optional[str] = None):
    """Save model/optimizer state plus training progress to ``path``.

    Only the master process writes; other ranks return immediately.
    If ``path`` is None, a step-numbered file under ``config.output_dir``
    is used.
    """
    if not self.master_process:
        return

    if path is None:
        path = os.path.join(
            self.config.output_dir,
            f"checkpoint_step_{self.global_step}.pt",
        )

    # BUG FIX: os.makedirs("") raises FileNotFoundError, so only create
    # the directory when the path actually has one (a bare filename is
    # saved into the current working directory).
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    checkpoint = {
        "model_state_dict": self.raw_model.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
        "config": self.config.__dict__,
        "global_step": self.global_step,
        "best_val_loss": self.best_val_loss,
    }
    torch.save(checkpoint, path)
    print(f" Saved checkpoint: {path}")
|
| 227 |
+
|
| 228 |
+
def load_checkpoint(self, path: str):
    """Restore model, optimizer, and training progress from a checkpoint.

    Missing bookkeeping keys fall back to fresh-run defaults so older
    checkpoints remain loadable.
    """
    # weights_only=False: our checkpoints carry a plain config dict next to
    # the tensors, and we only load files this trainer itself wrote.
    # (torch >= 2.6 changed the default to True, which would reject them.)
    checkpoint = torch.load(path, map_location=self.device, weights_only=False)
    self.raw_model.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    self.global_step = checkpoint.get("global_step", 0)
    self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
    if self.master_process:
        print(f"Resumed from step {self.global_step}")
|
| 237 |
+
|
| 238 |
+
def train(self, train_loader, val_loader):
    """
    Main training loop.

    Runs optimizer steps from ``self.global_step`` (nonzero after a resume)
    up to ``cfg.max_steps``, with:
      * per-step LR scheduling via ``get_lr``,
      * gradient accumulation over ``cfg.gradient_accumulation_steps``
        micro-batches,
      * AMP forward passes under ``self.amp_ctx`` with ``self.scaler``,
      * optional gradient clipping (``cfg.max_grad_norm``),
      * periodic logging / evaluation / checkpointing (master rank only).
    Tears down the DDP process group at the end when running distributed.
    """
    cfg = self.config
    model = self.model
    optimizer = self.optimizer

    model.train()
    train_iter = iter(train_loader)

    if self.master_process:
        print(f"\n{'='*60}")
        print(f" Starting training")
        print(f" Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps * self.ddp_world_size}")
        print(f" Max steps: {cfg.max_steps:,}")
        print(f"{'='*60}\n")

    t0 = time.time()

    # Resuming: range starts at self.global_step so a loaded checkpoint
    # continues where it left off.
    for step in range(self.global_step, cfg.max_steps):
        self.global_step = step

        # Update learning rate for every param group each step.
        lr = get_lr(step, cfg)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # -- Gradient accumulation loop ---------------------------------
        optimizer.zero_grad(set_to_none=True)
        accumulated_loss = 0.0

        for micro_step in range(cfg.gradient_accumulation_steps):
            # Get next batch; restart the loader to cycle through the data.
            try:
                x, y = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                x, y = next(train_iter)

            x, y = x.to(self.device), y.to(self.device)

            # DDP: only sync (all-reduce) gradients on the last micro-step
            # to avoid redundant communication during accumulation.
            if self.ddp:
                model.require_backward_grad_sync = (
                    micro_step == cfg.gradient_accumulation_steps - 1
                )

            # Forward pass with mixed precision; loss is divided so the
            # accumulated gradient matches one large batch.
            with self.amp_ctx:
                _, loss, _ = model(x, targets=y)
                loss = loss / cfg.gradient_accumulation_steps

            # .item() detaches; accumulated_loss is the full-step loss.
            accumulated_loss += loss.item()

            # Backward pass (AMP loss scaling).
            self.scaler.scale(loss).backward()

        # Gradient clipping: grads must be unscaled before clipping under AMP.
        if cfg.max_grad_norm > 0:
            self.scaler.unscale_(optimizer)
            grad_norm = nn.utils.clip_grad_norm_(
                model.parameters(), cfg.max_grad_norm
            )
        else:
            grad_norm = 0.0

        # Optimizer step (skips on inf/nan grads) and scaler update.
        self.scaler.step(optimizer)
        self.scaler.update()

        # -- Logging ----------------------------------------------------
        if step % cfg.log_interval == 0 and self.master_process:
            # dt covers everything since the last log line, so tok/s is a
            # per-interval average, not a per-step figure.
            dt = time.time() - t0
            tokens_per_sec = (
                cfg.batch_size * cfg.max_seq_len
                * cfg.gradient_accumulation_steps
                * self.ddp_world_size
                / dt
            )
            print(
                f"step {step:>6d} | "
                f"loss {accumulated_loss:.4f} | "
                f"lr {lr:.2e} | "
                f"grad_norm {grad_norm:.2f} | "
                f"tok/s {tokens_per_sec:.0f} | "
                f"dt {dt:.2f}s"
            )
            t0 = time.time()

        # -- Evaluation (master rank only) ------------------------------
        # NOTE(review): under DDP only rank 0 runs eval while other ranks
        # proceed; assumes eval forward passes need no collective sync --
        # confirm against the DDP setup.
        if step > 0 and step % cfg.eval_interval == 0 and self.master_process:
            val_loss = self.evaluate(val_loader)
            print(f" β¦ Validation loss: {val_loss:.4f}")

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.save_checkpoint(
                    os.path.join(cfg.output_dir, "best_model.pt")
                )
                print(f" β¦ New best! Saved best_model.pt")

        # -- Periodic checkpointing -------------------------------------
        if step > 0 and step % cfg.save_interval == 0 and self.master_process:
            self.save_checkpoint()

    # Final save after the loop completes.
    if self.master_process:
        self.save_checkpoint(
            os.path.join(cfg.output_dir, "final_model.pt")
        )
        print("\n⦠Training complete!")

    # Cleanup DDP
    if self.ddp:
        dist.destroy_process_group()
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 357 |
+
# MAIN
|
| 358 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 359 |
+
|
| 360 |
+
def main():
    """CLI entry point: parse args, build config/tokenizer/data, and train.

    Flags allow a tiny debug config (--tiny), a training text file (--data),
    checkpoint resumption (--resume), and overrides of key hyperparameters.
    Falls back to synthetic data when no data file is given.
    """
    parser = argparse.ArgumentParser(description="Train GPT-300M")
    parser.add_argument("--tiny", action="store_true", help="Use tiny config for debugging")
    parser.add_argument("--data", type=str, default=None, help="Path to training text file")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--d_model", type=int, default=None)
    parser.add_argument("--n_layers", type=int, default=None)
    parser.add_argument("--n_heads", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--max_steps", type=int, default=None)
    args = parser.parse_args()

    # Config
    config = gpt_tiny() if args.tiny else gpt_300m()

    # Override config from CLI -- only flags the user actually passed.
    for key in ["d_model", "n_layers", "n_heads", "batch_size", "learning_rate", "max_steps"]:
        val = getattr(args, key, None)
        if val is not None:
            setattr(config, key, val)

    # Seed for reproducibility.
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    # Tokenizer
    tokenizer = BPETokenizer(vocab_size=config.vocab_size)

    # Load data. BUG FIX: explicit UTF-8 -- the default text encoding is
    # platform-dependent and can corrupt non-ASCII training data.
    if args.data and os.path.exists(args.data):
        print(f"Loading data from {args.data}...")
        with open(args.data, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        # Generate synthetic data for demonstration
        print("No data file provided. Generating synthetic training data...")
        text = generate_synthetic_data()

    # Train tokenizer on data and persist it alongside checkpoints.
    # Ensure the output directory exists before the first save.
    os.makedirs(config.output_dir, exist_ok=True)
    print("Training tokenizer...")
    tokenizer.train(text, verbose=True)
    tokenizer.save(os.path.join(config.output_dir, "tokenizer.json"))

    # Create dataloaders
    train_loader, val_loader = create_dataloaders(config, tokenizer, text=text)
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Train!
    trainer = Trainer(config, resume_from=args.resume)
    trainer.train(train_loader, val_loader)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def generate_synthetic_data(n_samples: int = 10_000) -> str:
    """Build a deterministic synthetic User/Assistant dialogue corpus.

    Produces ``n_samples`` exchanges (greeting + question, then answer),
    joined by blank lines. The RNG is reseeded on every call, so output
    is identical across runs for the same ``n_samples``.
    """
    import random
    random.seed(42)

    greetings = ["Hello!", "Hi there!", "Hey!", "Good morning!", "Greetings!"]
    questions = [
        "What is machine learning?",
        "How does gravity work?",
        "What is the meaning of life?",
        "Can you explain photosynthesis?",
        "What are neural networks?",
        "How do computers work?",
        "What is quantum physics?",
        "Tell me about the solar system.",
        "How does the internet work?",
        "What is artificial intelligence?",
    ]
    answers = [
        "That's a great question! Machine learning is a subset of AI that enables systems to learn from data.",
        "Gravity is a fundamental force that attracts objects with mass toward each other.",
        "The meaning of life is a deeply philosophical question that has been debated for centuries.",
        "Photosynthesis is the process by which plants convert sunlight into chemical energy.",
        "Neural networks are computing systems inspired by biological neural networks in the brain.",
        "Computers work by processing binary data through electronic circuits called transistors.",
        "Quantum physics describes the behavior of matter and energy at the atomic scale.",
        "The solar system consists of the Sun and everything that orbits around it.",
        "The internet is a global network of interconnected computers that communicate using protocols.",
        "Artificial intelligence is the simulation of human intelligence by computer systems.",
    ]

    def _exchange() -> str:
        # Draw order (greeting, question, answer) matches the seeded RNG
        # sequence, keeping output stable.
        greeting = random.choice(greetings)
        question = random.choice(questions)
        answer = random.choice(answers)
        return f"User: {greeting} {question}\nAssistant: {answer}\n"

    return "\n".join(_exchange() for _ in range(n_samples))
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# Script entry point: run the full CLI training pipeline.
if __name__ == "__main__":
    main()
|
visual_nn_3d.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT-300M 3D Neural Network Visualization
|
| 3 |
+
==========================================
|
| 4 |
+
A 3D node-and-connection neural network diagram with depth,
|
| 5 |
+
perspective, and accurate parameter counts.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import matplotlib
|
| 9 |
+
matplotlib.use("Agg")
|
| 10 |
+
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 13 |
+
from mpl_toolkits.mplot3d.art3d import Line3DCollection
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 17 |
+
# ACCURATE GPT-300M PARAMETERS
|
| 18 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
+
|
| 20 |
+
VOCAB = 32_000
|
| 21 |
+
D = 1_024
|
| 22 |
+
HEADS = 16
|
| 23 |
+
HEAD_D = 64
|
| 24 |
+
D_FF = 4_096
|
| 25 |
+
N_LAYERS = 24
|
| 26 |
+
|
| 27 |
+
embed_p = VOCAB * D # 32,768,000
|
| 28 |
+
attn_p = 4 * D * D # 4,194,304 per layer
|
| 29 |
+
ffn_p = 2 * D * D_FF # 8,388,608 per layer
|
| 30 |
+
norm_p = 2 * D # 2,048 per layer
|
| 31 |
+
layer_p = attn_p + ffn_p + norm_p # 12,584,960 per layer
|
| 32 |
+
all_layers_p = layer_p * N_LAYERS # 302,039,040
|
| 33 |
+
final_norm_p = D # 1,024
|
| 34 |
+
TOTAL = embed_p + all_layers_p + final_norm_p # 334,808,064
|
| 35 |
+
|
| 36 |
+
# Layer definitions: (name, num_display_nodes, actual_neurons, params, color_hex)
|
| 37 |
+
LAYERS = [
|
| 38 |
+
("Input Tokens", 10, VOCAB, 0, "#4CAF50"),
|
| 39 |
+
("Token Embedding", 12, D, embed_p, "#2196F3"),
|
| 40 |
+
("RoPE Positions", 12, D, 0, "#00BCD4"),
|
| 41 |
+
("Layer 1: Attention QKV", 14, D, attn_p * 3 // 4, "#FF9800"),
|
| 42 |
+
("Layer 1: Attention Out", 12, D, attn_p * 1 // 4, "#FF9800"),
|
| 43 |
+
("Layer 1: FFN Up (GELU)", 16, D_FF, ffn_p // 2, "#8BC34A"),
|
| 44 |
+
("Layer 1: FFN Down", 12, D, ffn_p // 2, "#8BC34A"),
|
| 45 |
+
("Layers 2β23 (Γ22)", 14, D, layer_p * 22, "#9C27B0"),
|
| 46 |
+
("Layer 24: Attention", 14, D, attn_p, "#FF5722"),
|
| 47 |
+
("Layer 24: FFN", 16, D_FF, ffn_p, "#009688"),
|
| 48 |
+
("Layer 24: Norm + Out", 12, D, norm_p + final_norm_p, "#E91E63"),
|
| 49 |
+
("LM Head (weight-tied)", 12, VOCAB, 0, "#F44336"),
|
| 50 |
+
("Output Probabilities", 1, VOCAB, 0, "#FF1744"),
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def hex_to_rgb(h):
    """Convert a '#RRGGBB' (or bare 'RRGGBB') hex string to an (r, g, b) tuple in [0, 1]."""
    digits = h.lstrip("#")
    channels = (digits[0:2], digits[2:4], digits[4:6])
    return tuple(int(pair, 16) / 255.0 for pair in channels)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def generate_3d_network(save_path="neural_network_3d.png", elev=22, azim=-65):
    """Generate a 3D neural network with nodes, connections, and parameter labels.

    Renders the module-level LAYERS table as stacked arcs of scatter "nodes"
    joined by sampled connection lines, annotates each layer with its
    parameter count and the running total, adds a title / summary box /
    legend, and writes the figure to ``save_path``. ``elev``/``azim`` set
    the 3D camera angle.
    """

    fig = plt.figure(figsize=(28, 28), facecolor="#0a0e17")
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)

    # Dark theme for 3D axes: hide panes, grid, and axis decorations.
    ax.set_facecolor("#0a0e17")
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.xaxis.pane.set_edgecolor("#0a0e17")
    ax.yaxis.pane.set_edgecolor("#0a0e17")
    ax.zaxis.pane.set_edgecolor("#0a0e17")
    ax.grid(False)
    ax.set_axis_off()

    ax.view_init(elev=elev, azim=azim)

    n_layers = len(LAYERS)
    y_positions = np.linspace(0, n_layers * 4.0, n_layers)  # depth (layer position)

    all_positions = []  # one (xs, ys, zs) triple per rendered layer
    running_params = 0

    for i, (name, n_nodes, actual, params, color_hex) in enumerate(LAYERS):
        y = y_positions[i]
        running_params += params

        rgb = hex_to_rgb(color_hex)

        # Arrange nodes in a circle/arc for 3D effect.
        if n_nodes == 1:
            xs = np.array([0.0])
            zs = np.array([0.0])
        else:
            # Spread nodes along x (capped so wide layers stay compact).
            spread = min(n_nodes * 0.5, 7.0)
            xs = np.linspace(-spread, spread, n_nodes)
            # Slight parabolic dip gives each row depth perception.
            zs = -0.1 * (xs ** 2)

        ys = np.full_like(xs, y)
        all_positions.append((xs, ys, zs))

        # -- Draw connections to previous layer -------------------------
        if i > 0:
            prev_xs, prev_ys, prev_zs = all_positions[i - 1]

            # Sample at most ~8x8 node pairs to avoid visual clutter.
            n_prev = len(prev_xs)
            n_curr = len(xs)
            step_p = max(1, n_prev // 8)
            step_c = max(1, n_curr // 8)

            lines = []
            colors_lines = []
            for pi in range(0, n_prev, step_p):
                for ci in range(0, n_curr, step_c):
                    lines.append([
                        (prev_xs[pi], prev_ys[pi], prev_zs[pi]),
                        (xs[ci], ys[ci], zs[ci]),
                    ])
                    colors_lines.append((*rgb, 0.18))  # faint, layer-tinted

            if lines:
                lc = Line3DCollection(lines, colors=colors_lines, linewidths=0.7)
                ax.add_collection3d(lc)

        # -- Draw nodes --------------------------------------------------
        node_size = 200 if n_nodes > 12 else 280
        if n_nodes == 1:
            node_size = 600  # lone output node drawn extra large

        ax.scatter(
            xs, ys, zs,
            c=[color_hex], s=node_size,
            alpha=0.95, edgecolors="white", linewidths=0.5,
            depthshade=True, zorder=5,
        )

        # -- Glow effect: larger transparent scatter behind the nodes ----
        ax.scatter(
            xs, ys, zs,
            c=[color_hex], s=node_size * 3,
            alpha=0.08, edgecolors="none",
            depthshade=True, zorder=4,
        )

        # -- Labels ------------------------------------------------------
        label_x = xs[-1] + 1.8 if n_nodes > 1 else 2.0
        ax.text(
            label_x, y, 0,
            name,
            fontsize=9.5, fontweight="bold",
            color="#E6EDF3", fontfamily="monospace",
            zorder=10,
        )

        # Per-layer parameter count (M-suffixed above one million).
        if params > 0:
            if params >= 1_000_000:
                ptxt = f"{params/1e6:.1f}M params"
            else:
                ptxt = f"{params:,} params"
            ax.text(
                label_x, y, -1.0,
                ptxt,
                fontsize=8, color=color_hex,
                fontfamily="monospace", fontweight="bold",
                zorder=10,
            )

        # Running cumulative total up to and including this layer.
        if running_params > 0:
            ax.text(
                label_x, y, -1.8,
                f"Ξ£ {running_params/1e6:.1f}M",
                fontsize=6, color="#8B949E",
                fontfamily="monospace",
                zorder=10,
            )

        # Overflow indicator: real neuron count the display nodes stand for.
        if actual > n_nodes and n_nodes > 1:
            ax.text(
                xs[-1] + 0.5, y, zs[-1],
                f"(+{actual - n_nodes:,})",
                fontsize=6, color="#8B949E",
                fontfamily="monospace",
                zorder=10,
            )

    # -- Title (figure coordinates) -------------------------------------
    ax.text2D(
        0.5, 0.96,
        "GPT-300M β’ 3D Neural Network Architecture",
        transform=fig.transFigure,
        fontsize=22, fontweight="bold", color="#E6EDF3",
        ha="center", fontfamily="monospace",
    )
    ax.text2D(
        0.5, 0.94,
        f"{TOTAL:,} parameters | {N_LAYERS} layers | {HEADS} heads | d_model={D} | d_ff={D_FF}",
        transform=fig.transFigure,
        fontsize=10, color="#8B949E",
        ha="center", fontfamily="monospace",
    )

    # -- Parameter summary box (lower-left) ------------------------------
    summary = (
        f"Parameter Breakdown:\n"
        f" Embedding: {embed_p/1e6:>7.1f}M ({embed_p/TOTAL*100:.1f}%)\n"
        f" Attention Γ24: {attn_p*N_LAYERS/1e6:>7.1f}M ({attn_p*N_LAYERS/TOTAL*100:.1f}%)\n"
        f" FFN Γ24: {ffn_p*N_LAYERS/1e6:>7.1f}M ({ffn_p*N_LAYERS/TOTAL*100:.1f}%)\n"
        f" Norms: {(norm_p*N_LAYERS+final_norm_p)/1e6:>7.3f}M ({(norm_p*N_LAYERS+final_norm_p)/TOTAL*100:.1f}%)\n"
        f" LM Head: tied (0 extra)\n"
        f" βββββββββββββββββββββββ\n"
        f" TOTAL: {TOTAL/1e6:>7.1f}M"
    )
    ax.text2D(
        0.02, 0.06, summary,
        transform=fig.transFigure,
        fontsize=8, color="#58A6FF",
        fontfamily="monospace", verticalalignment="bottom",
        bbox=dict(boxstyle="round,pad=0.6", facecolor="#161B22",
                  edgecolor="#30363D", linewidth=1),
    )

    # -- Legend: color swatches stacked down the figure's right side -----
    legend_items = [
        ("#4CAF50", "Input"), ("#2196F3", "Embeddings"), ("#FF9800", "Attention"),
        ("#8BC34A", "FFN"), ("#9C27B0", "Γ22 Layers"), ("#E91E63", "Norm"),
        ("#F44336", "Output"),
    ]
    for j, (c, l) in enumerate(legend_items):
        ax.text2D(
            0.92, 0.30 - j * 0.025, f"β {l}",
            transform=fig.transFigure,
            fontsize=8, color=c, fontfamily="monospace",
        )

    # Fit axis limits around all plotted nodes; extra +8 on x leaves room
    # for the text labels placed to the right of each layer.
    all_x = np.concatenate([p[0] for p in all_positions])
    all_y = np.concatenate([p[1] for p in all_positions])
    all_z = np.concatenate([p[2] for p in all_positions])
    margin = 4
    ax.set_xlim(all_x.min() - margin, all_x.max() + margin + 8)
    ax.set_ylim(all_y.min() - margin, all_y.max() + margin)
    ax.set_zlim(all_z.min() - margin, all_z.max() + margin)

    plt.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor="#0a0e17", edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close()
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def generate_3d_single_layer(save_path="layer_3d.png", elev=18, azim=-55):
    """3D view of a single transformer layer internals.

    Renders the sub-stages of one layer (QKV projections, attention heads,
    residual/norm points, FFN up/down) as stacked node arcs with sampled
    connection lines, labels each stage with its parameter count, and saves
    the figure to ``save_path``.
    """

    fig = plt.figure(figsize=(22, 18), facecolor="#0a0e17")
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)

    # Dark theme: hide panes, grid, and axis decorations.
    ax.set_facecolor("#0a0e17")
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.xaxis.pane.set_edgecolor("#0a0e17")
    ax.yaxis.pane.set_edgecolor("#0a0e17")
    ax.zaxis.pane.set_edgecolor("#0a0e17")
    ax.grid(False)
    ax.set_axis_off()
    ax.view_init(elev=elev, azim=azim)

    # (name, display_nodes, actual_width, params, color) per sub-stage.
    sub_layers = [
        ("Input (d=1024)", 10, D, 0, "#2196F3"),
        ("Query (d=1024)", 10, D, D*D, "#FF6B6B"),
        ("Key (d=1024)", 10, D, D*D, "#4ECDC4"),
        ("Value (d=1024)", 10, D, D*D, "#45B7D1"),
        ("16 Attention Heads", 16, D, 0, "#FF9800"),
        ("Attn Output (d=1024)", 10, D, D*D, "#FFA726"),
        ("β Residual + RMSNorm", 10, D, D, "#E91E63"),
        ("FFN Up β GELU (d=4096)", 16, D_FF, D*D_FF, "#8BC34A"),
        ("FFN Down (d=1024)", 10, D, D_FF*D, "#7CB342"),
        ("β Residual + RMSNorm", 10, D, D, "#E91E63"),
        ("Layer Output (d=1024)", 10, D, 0, "#2196F3"),
    ]

    n = len(sub_layers)
    y_positions = np.linspace(0, n * 3, n)
    all_pos = []  # one (xs, ys, zs) triple per sub-stage

    for i, (name, n_nodes, actual, params, chex) in enumerate(sub_layers):
        y = y_positions[i]
        rgb = hex_to_rgb(chex)

        # Shallow parabolic arc of nodes for depth perception.
        spread = min(n_nodes * 0.45, 5.5)
        xs = np.linspace(-spread, spread, n_nodes)
        zs = -0.12 * (xs ** 2)
        ys = np.full_like(xs, y)
        all_pos.append((xs, ys, zs))

        # Connections to the previous stage, sampled to ~8x8 pairs max.
        if i > 0:
            pxs, pys, pzs = all_pos[i - 1]
            sp = max(1, len(pxs) // 8)
            sc = max(1, len(xs) // 8)
            lines = []
            cols = []
            for pi in range(0, len(pxs), sp):
                for ci in range(0, len(xs), sc):
                    lines.append([(pxs[pi], pys[pi], pzs[pi]), (xs[ci], ys[ci], zs[ci])])
                    cols.append((*rgb, 0.15))  # faint, stage-tinted
            if lines:
                ax.add_collection3d(Line3DCollection(lines, colors=cols, linewidths=0.6))

        # Nodes plus a faint oversized "glow" pass behind them.
        sz = 130 if n_nodes > 12 else 180
        ax.scatter(xs, ys, zs, c=[chex], s=sz, alpha=0.95,
                   edgecolors="white", linewidths=0.5, depthshade=True, zorder=5)
        ax.scatter(xs, ys, zs, c=[chex], s=sz * 3, alpha=0.07,
                   edgecolors="none", depthshade=True, zorder=4)

        # Stage name and (when non-zero) its parameter count.
        lx = xs[-1] + 1.0
        ax.text(lx, y, 0, name, fontsize=9, fontweight="bold",
                color="#E6EDF3", fontfamily="monospace", zorder=10)
        if params > 0:
            ax.text(lx, y, -0.8, f"{params:,} params",
                    fontsize=7, color=chex, fontfamily="monospace",
                    fontweight="bold", zorder=10)

        # Overflow indicator: real width the display nodes stand for.
        if actual > n_nodes:
            ax.text(xs[-1] + 0.4, y, zs[-1], f"(+{actual-n_nodes:,})",
                    fontsize=6, color="#8B949E", fontfamily="monospace", zorder=10)

    # Title and subtitle in figure coordinates.
    ax.text2D(0.5, 0.96, "Single Transformer Layer β 3D View",
              transform=fig.transFigure, fontsize=20, fontweight="bold",
              color="#E6EDF3", ha="center", fontfamily="monospace")
    ax.text2D(0.5, 0.935,
              f"12,584,960 params/layer Γ 24 layers = 302,039,040 total",
              transform=fig.transFigure, fontsize=10, color="#8B949E",
              ha="center", fontfamily="monospace")

    # Fit axis limits around all nodes; extra +8 on x leaves label room.
    all_x = np.concatenate([p[0] for p in all_pos])
    all_y = np.concatenate([p[1] for p in all_pos])
    all_z = np.concatenate([p[2] for p in all_pos])
    ax.set_xlim(all_x.min() - 2, all_x.max() + 8)
    ax.set_ylim(all_y.min() - 2, all_y.max() + 2)
    ax.set_zlim(all_z.min() - 2, all_z.max() + 2)

    plt.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor="#0a0e17", edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close()
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def generate_3d_rotating_views(base_path="viz"):
    """Render the full-network diagram from several camera angles, plus a
    single-layer detail view, into ``base_path``."""
    import os
    os.makedirs(base_path, exist_ok=True)

    # (filename, elevation, azimuth) for each whole-network render.
    network_views = [
        ("nn_3d_main.png", 12, -15),  # main dramatic angle, more front-facing
        ("nn_3d_top.png", 35, -25),   # angled view from above
        ("nn_3d_side.png", 8, -45),   # low side angle
    ]
    for fname, view_elev, view_azim in network_views:
        generate_3d_network(f"{base_path}/{fname}", elev=view_elev, azim=view_azim)

    # Single-layer internals detail.
    generate_3d_single_layer(f"{base_path}/nn_3d_layer.png", elev=18, azim=-55)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# Script entry point: print the parameter summary and render every view
# into ./viz.
if __name__ == "__main__":
    import os
    os.makedirs("viz", exist_ok=True)

    print("=" * 55)
    print(" GPT-300M β’ 3D Visualization Generator")
    print("=" * 55)
    print(f" Total parameters: {TOTAL:,}")
    print(f" Per layer: {layer_p:,}")
    print(f" Layers: {N_LAYERS}")
    print("=" * 55)

    generate_3d_rotating_views("viz")
    print("\nAll 3D views generated!")
|
visual_nn_nodes.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
GPT-300M Visual Neural Network — Node & Connection Style
==========================================================
Generates a classic neural network diagram (like the user's reference)
with nodes and connection lines, accurately showing the GPT-300M architecture
with correct parameter calculations at each layer.
"""

import matplotlib
matplotlib.use("Agg")  # headless backend: render straight to files, no display needed

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# ───────────────────────────────────────────────────────────────────────
# GPT-300M ARCHITECTURE — ACCURATE PARAMETER COUNTS
# ───────────────────────────────────────────────────────────────────────

# All layer definitions with EXACT parameter counts
# Format: (layer_name, display_nodes, actual_neurons, params_in_layer, color)

# Core hyperparameters of the model being visualized.
VOCAB_SIZE = 32_000  # tokenizer vocabulary size
D_MODEL = 1_024      # hidden / embedding width
N_HEADS = 16         # attention heads per layer
HEAD_DIM = 64        # per-head width (N_HEADS * HEAD_DIM == D_MODEL)
D_FF = 4_096         # feed-forward inner width (4 × d_model)
N_LAYERS = 24        # number of transformer blocks

# Parameter calculations per component:
embed_params = VOCAB_SIZE * D_MODEL  # 32,768,000
# RoPE has no learned parameters (precomputed sin/cos)
rope_params = 0

# Per transformer layer:
qkv_params = 3 * D_MODEL * D_MODEL  # 3,145,728 (Q, K, V projections)
out_proj_params = D_MODEL * D_MODEL  # 1,048,576 (output projection)
attn_total = qkv_params + out_proj_params  # 4,194,304

ffn_up_params = D_MODEL * D_FF  # 4,194,304 (up projection)
ffn_down_params = D_FF * D_MODEL  # 4,194,304 (down projection)
ffn_total = ffn_up_params + ffn_down_params  # 8,388,608

rmsnorm_params = D_MODEL * 2  # 2,048 (2 norms per layer)
layer_total = attn_total + ffn_total + rmsnorm_params  # 12,584,960

all_layers_total = layer_total * N_LAYERS  # 302,039,040

final_norm_params = D_MODEL  # 1,024
# LM Head is weight-tied with embedding, so 0 extra params
lm_head_params = 0  # (tied)

TOTAL_PARAMS = embed_params + all_layers_total + final_norm_params + lm_head_params
# = 32,768,000 + 302,039,040 + 1,024 = 334,808,064
# With weight tying, unique params ≈ 334,808,064

# ───────────────────────────────────────────────────────────────────────
# LAYER DEFINITIONS FOR VISUALIZATION
# ───────────────────────────────────────────────────────────────────────

# One row of the diagram per tuple:
# (name, nodes_to_display, actual_size, params_to_this_layer, color)
# `nodes_to_display` is purely cosmetic; `actual_size` is the real width.
LAYERS = [
    ("Input Tokens", 10, VOCAB_SIZE, 0, "#4CAF50"),  # Green
    ("Token Embedding", 10, D_MODEL, embed_params, "#2196F3"),  # Blue
    ("RoPE Positions", 10, D_MODEL, 0, "#00BCD4"),  # Cyan

    # Show 3 representative transformer layers (of 24)
    ("Layer 1: Attention Q,K,V", 12, D_MODEL, qkv_params, "#FF9800"),  # Orange
    ("Layer 1: Attention Out", 10, D_MODEL, out_proj_params, "#FF9800"),
    ("Layer 1: FFN Up", 14, D_FF, ffn_up_params, "#8BC34A"),  # Light green
    ("Layer 1: FFN Down", 10, D_MODEL, ffn_down_params, "#8BC34A"),

    ("Layer 2–23: ×22 Blocks", 12, D_MODEL, layer_total * 22, "#9C27B0"),  # Purple

    ("Layer 24: Attention", 12, D_MODEL, attn_total, "#FF5722"),  # Deep orange
    ("Layer 24: FFN", 14, D_FF, ffn_total, "#009688"),  # Teal
    ("Layer 24: Output", 10, D_MODEL, rmsnorm_params, "#009688"),

    ("Final RMSNorm", 10, D_MODEL, final_norm_params, "#E91E63"),  # Pink
    ("LM Head (tied)", 10, VOCAB_SIZE, lm_head_params, "#F44336"),  # Red
    ("Output Probabilities", 1, VOCAB_SIZE, 0, "#F44336"),  # Red
]
def draw_neural_network(save_path="neural_network.png"):
    """Render the full GPT-300M stack as a node-and-connection diagram.

    Each entry of the module-level ``LAYERS`` table becomes one row of
    colored nodes; successive rows are joined by sub-sampled connection
    lines. Per-layer parameter counts and a running total are annotated
    on the right, with a summary box and legend at the bottom.

    Args:
        save_path: Output PNG path (its directory must already exist).
    """
    fig, ax = plt.subplots(figsize=(22, 30), facecolor="#0D1117")
    ax.set_facecolor("#0D1117")

    n_layers = len(LAYERS)
    # Rows run top (y=0.92) to bottom (y=0.04) in axes coordinates.
    y_positions = np.linspace(0.92, 0.04, n_layers)

    # Spacing
    x_center = 0.5
    max_spread = 0.38

    all_node_positions = []  # Store (x_list, y) per row for connections

    running_params = 0

    for i, (name, n_display, actual_size, params, color) in enumerate(LAYERS):
        y = y_positions[i]
        running_params += params

        # Calculate x positions for nodes
        if n_display == 1:
            xs = [x_center]
        else:
            xs = np.linspace(x_center - max_spread, x_center + max_spread, n_display)

        all_node_positions.append((xs, y))

        # Draw connections to previous layer
        if i > 0:
            prev_xs, prev_y = all_node_positions[i - 1]

            # Sub-sample both rows and cap the total number of lines
            # for readability.
            max_connections = 200
            step_curr = max(1, len(xs) // 12)
            step_prev = max(1, len(prev_xs) // 12)

            conn_count = 0
            for px in prev_xs[::step_prev]:
                # FIX: the original `break` only exited the inner loop, so
                # the outer loop kept spinning after the cap was reached.
                # Check here too so both loops terminate (drawn output is
                # identical).
                if conn_count > max_connections:
                    break
                for cx in xs[::step_curr]:
                    if conn_count > max_connections:
                        break
                    ax.plot(
                        [px, cx], [prev_y, y],
                        color=color, alpha=0.22, linewidth=0.6,
                        transform=ax.transAxes, zorder=1,
                    )
                    conn_count += 1

        # Draw nodes (smaller radius for crowded rows, larger for a lone node)
        node_radius = 0.01 if n_display <= 12 else 0.008
        if n_display == 1:
            node_radius = 0.016

        for x in xs:
            circle = plt.Circle(
                (x, y), node_radius,
                facecolor=color, edgecolor="white",
                linewidth=0.6, alpha=0.95,
                transform=ax.transAxes, zorder=3,
            )
            ax.add_patch(circle)

        # Draw "+N" indicator if actual size > displayed
        if actual_size > n_display and n_display > 1:
            extra = actual_size - n_display
            if extra > 0:
                ax.text(
                    xs[-1] + 0.03, y,
                    f"(+{extra:,})",
                    transform=ax.transAxes,
                    fontsize=7, color="#8B949E",
                    ha="left", va="center",
                    fontfamily="monospace",
                )

        # Layer label (left side)
        ax.text(
            0.02, y,
            name,
            transform=ax.transAxes,
            fontsize=9, fontweight="bold",
            color="#E6EDF3",
            ha="left", va="center",
            fontfamily="monospace",
        )

        # Parameter count (right side)
        if params > 0:
            param_text = f"{params:,} params"
            ax.text(
                0.98, y,
                param_text,
                transform=ax.transAxes,
                fontsize=8,
                color=color,
                ha="right", va="center",
                fontfamily="monospace",
                fontweight="bold",
            )

        # Running total (far right, smaller)
        if running_params > 0:
            ax.text(
                0.98, y - 0.012,
                f"Σ {running_params / 1e6:.1f}M",
                transform=ax.transAxes,
                fontsize=6.5,
                color="#8B949E",
                ha="right", va="center",
                fontfamily="monospace",
            )

    # ── Title ──────────────────────────────────────────────────────
    ax.text(
        0.5, 0.97,
        "GPT-300M Neural Network",
        transform=ax.transAxes,
        fontsize=24, fontweight="bold",
        color="#E6EDF3", ha="center", va="center",
        fontfamily="monospace",
    )
    ax.text(
        0.5, 0.955,
        f"Total: {TOTAL_PARAMS:,} parameters • {N_LAYERS} transformer layers • "
        f"{N_HEADS} attention heads • d_model={D_MODEL}",
        transform=ax.transAxes,
        fontsize=9, color="#8B949E", ha="center", va="center",
        fontfamily="monospace",
    )

    # ── Parameter Summary Box ──────────────────────────────────────
    summary_y = 0.005
    summary_text = (
        f"┌─────────────── Parameter Summary ────────────────┐\n"
        f"│ Token Embedding:     {embed_params:>13,} ({embed_params/TOTAL_PARAMS*100:4.1f}%) │\n"
        f"│ Attention (×{N_LAYERS}):      {attn_total*N_LAYERS:>13,} ({attn_total*N_LAYERS/TOTAL_PARAMS*100:4.1f}%) │\n"
        f"│ Feed-Forward (×{N_LAYERS}):   {ffn_total*N_LAYERS:>13,} ({ffn_total*N_LAYERS/TOTAL_PARAMS*100:4.1f}%) │\n"
        f"│ RMSNorm (×{N_LAYERS}+1):      {rmsnorm_params*N_LAYERS+final_norm_params:>13,} ({(rmsnorm_params*N_LAYERS+final_norm_params)/TOTAL_PARAMS*100:4.1f}%) │\n"
        f"│ LM Head (tied):      {'0 (shared)':>13} │\n"
        f"├───────────────────────────────────────────────────┤\n"
        f"│ TOTAL:               {TOTAL_PARAMS:>13,} (100%) │\n"
        f"└───────────────────────────────────────────────────┘"
    )
    ax.text(
        0.5, summary_y,
        summary_text,
        transform=ax.transAxes,
        fontsize=8, color="#58A6FF",
        ha="center", va="bottom",
        fontfamily="monospace",
        bbox=dict(boxstyle="round,pad=0.8", facecolor="#161B22",
                  edgecolor="#30363D", linewidth=1),
    )

    # ── Legend ──────────────────────────────────────────────────────
    legend_items = [
        ("#4CAF50", "Input / Tokenization"),
        ("#2196F3", "Embeddings"),
        ("#FF9800", "Self-Attention"),
        ("#8BC34A", "Feed-Forward (GELU)"),
        ("#9C27B0", "Collapsed Layers (×22)"),
        ("#E91E63", "Normalization"),
        ("#F44336", "Output / LM Head"),
    ]
    for j, (c, label) in enumerate(legend_items):
        lx = 0.02
        ly = 0.035 - j * 0.015
        circle = plt.Circle(
            (lx, ly), 0.004,
            facecolor=c, edgecolor="white", linewidth=0.3,
            transform=ax.transAxes, zorder=5,
        )
        ax.add_patch(circle)
        ax.text(
            lx + 0.012, ly, label,
            transform=ax.transAxes,
            fontsize=7, color="#C9D1D9", va="center",
            fontfamily="monospace",
        )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    plt.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor="#0D1117", edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close()
# ───────────────────────────────────────────────────────────────────────
# ALSO: A cleaner "zoomed in" single-layer view
# ───────────────────────────────────────────────────────────────────────

def draw_single_layer_detail(save_path="layer_detail.png"):
    """Draw a detailed view of one transformer layer with node connections."""
    fig, ax = plt.subplots(figsize=(20, 14), facecolor="#0D1117")
    ax.set_facecolor("#0D1117")

    # Data flow through one block:
    # Input (1024) -> Q,K,V (3x1024) -> Attention Heads (16x64) -> Output Proj (1024)
    # -> RMSNorm (1024) -> FFN Up (4096) -> GELU -> FFN Down (1024) -> Output (1024)
    sub_layers = [
        ("Input\n(d=1,024)", 8, D_MODEL, 0, "#2196F3"),
        ("Query\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#FF6B6B"),
        ("Key\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#4ECDC4"),
        ("Value\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#45B7D1"),
        ("Attention Heads\n(16×64)", 16, D_MODEL, 0, "#FF9800"),
        ("Attn Output\n(d=1,024)", 8, D_MODEL, D_MODEL**2, "#FF9800"),
        ("⊕ Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
        ("FFN Up (GELU)\n(d=4,096)", 14, D_FF, D_MODEL*D_FF, "#8BC34A"),
        ("FFN Down\n(d=1,024)", 8, D_MODEL, D_FF*D_MODEL, "#8BC34A"),
        ("⊕ Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
        ("Layer Output\n(d=1,024)", 8, D_MODEL, 0, "#2196F3"),
    ]

    row_ys = np.linspace(0.9, 0.08, len(sub_layers))
    x_center = 0.5
    max_spread = 0.32

    previous = None  # (xs, y) of the row above, for drawing connections

    for (name, n_nodes, actual, params, color), y in zip(sub_layers, row_ys):
        xs = np.linspace(x_center - max_spread, x_center + max_spread, n_nodes)

        # Connection lines down from the previous row (sub-sampled).
        if previous is not None:
            above_xs, above_y = previous
            stride_here = max(1, len(xs) // 10)
            stride_above = max(1, len(above_xs) // 10)
            for top_x in above_xs[::stride_above]:
                for bot_x in xs[::stride_here]:
                    ax.plot([top_x, bot_x], [above_y, y],
                            color=color, alpha=0.2, linewidth=0.7,
                            transform=ax.transAxes, zorder=1)

        # Nodes for this row; shrink radius when the row is crowded.
        radius = 0.011 if n_nodes <= 10 else 0.009
        for node_x in xs:
            ax.add_patch(plt.Circle((node_x, y), radius,
                                    facecolor=color, edgecolor="white",
                                    linewidth=0.6, alpha=0.95,
                                    transform=ax.transAxes, zorder=3))

        # Indicate how many real units the drawn nodes stand in for.
        if actual > n_nodes:
            ax.text(xs[-1] + 0.025, y, f"(+{actual - n_nodes:,})",
                    transform=ax.transAxes, fontsize=7, color="#8B949E",
                    ha="left", va="center", fontfamily="monospace")

        # Row label on the left.
        ax.text(0.03, y, name, transform=ax.transAxes,
                fontsize=9, fontweight="bold", color="#E6EDF3",
                ha="left", va="center", fontfamily="monospace")

        # Parameter count on the right.
        if params > 0:
            ax.text(0.97, y, f"{params:,}", transform=ax.transAxes,
                    fontsize=8, color=color, ha="right", va="center",
                    fontfamily="monospace", fontweight="bold")

        previous = (xs, y)

    # Titles.
    ax.text(0.5, 0.96, "Single Transformer Layer — Detailed View",
            transform=ax.transAxes, fontsize=18, fontweight="bold",
            color="#E6EDF3", ha="center", fontfamily="monospace")
    ax.text(0.5, 0.935,
            f"Parameters per layer: {layer_total:,} • ×{N_LAYERS} layers = {all_layers_total:,} total",
            transform=ax.transAxes, fontsize=9, color="#8B949E",
            ha="center", fontfamily="monospace")

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    plt.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor="#0D1117", edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close()
if __name__ == "__main__":
    import os
    os.makedirs("viz", exist_ok=True)

    # Print a parameter audit first so the numbers drawn in the diagrams
    # can be sanity-checked against this breakdown.
    print("=" * 50)
    print(" GPT-300M Parameter Verification")
    print("=" * 50)
    print(f" Token Embedding:     {embed_params:>13,}")
    print(f" Per-layer Attention: {attn_total:>13,}")
    print(f" Per-layer FFN:       {ffn_total:>13,}")
    print(f" Per-layer Norm:      {rmsnorm_params:>13,}")
    print(f" Per-layer Total:     {layer_total:>13,}")
    print(f" All {N_LAYERS} layers:       {all_layers_total:>13,}")
    print(f" Final Norm:          {final_norm_params:>13,}")
    print(f" LM Head (tied):      {'0 (shared)':>13}")
    print(f" ─────────────────────────────────")
    print(f" TOTAL:               {TOTAL_PARAMS:>13,}")
    print(f" ≈ {TOTAL_PARAMS / 1e6:.1f}M parameters")
    print("=" * 50)

    print("\nGenerating full network diagram...")
    draw_neural_network("viz/neural_network_full.png")

    print("Generating single-layer detail...")
    draw_single_layer_detail("viz/neural_network_layer.png")

    print("\nDone!")
visualize_nn.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
GPT-300M Neural Network Visualizer
====================================
Generates detailed architectural diagrams of the GPT-300M model
using matplotlib, showing:
- Full model architecture flow
- Detailed transformer block internals
- Attention head visualization
- Parameter distribution charts

Usage:
    python visualize_nn.py
    python visualize_nn.py --output architecture.png
"""

import argparse
import matplotlib
matplotlib.use("Agg")  # headless backend: render straight to files, no display needed

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch
import numpy as np

from config import GPT300MConfig, gpt_300m


# ───────────────────────────────────────────────────────────────────────
# COLOR SCHEME
# ───────────────────────────────────────────────────────────────────────

# Dark (GitHub-dark style) palette shared by all drawing helpers below;
# referenced by key throughout this module.
COLORS = {
    "bg": "#0D1117",
    "text": "#E6EDF3",
    "text_dim": "#8B949E",
    "embed": "#58A6FF",     # Blue
    "attn": "#F78166",      # Orange
    "ffn": "#7EE787",       # Green
    "norm": "#D2A8FF",      # Purple
    "residual": "#FFA657",  # Yellow-orange
    "output": "#FF7B72",    # Red
    "arrow": "#484F58",
    "highlight": "#1F6FEB",
    "border": "#30363D",
    "card_bg": "#161B22",
    "accent1": "#79C0FF",
    "accent2": "#BB9AF7",
}
def draw_rounded_box(ax, x, y, w, h, color, label, fontsize=10,
                     text_color=None, alpha=0.9, sublabel=None):
    """Draw a rounded, labeled rectangle centred at (x, y).

    When *sublabel* is given, the main label shifts up slightly and the
    sublabel is rendered beneath it in the dim text colour.
    """
    patch = FancyBboxPatch(
        (x - w / 2, y - h / 2), w, h,
        boxstyle="round,pad=0.1",
        facecolor=color,
        edgecolor="white",
        linewidth=0.5,
        alpha=alpha,
        zorder=3,
    )
    ax.add_patch(patch)

    # Main label; nudged upward when a sublabel needs room below.
    label_y = y + 0.15 if sublabel else y
    ax.text(
        x, label_y,
        label,
        ha="center", va="center",
        fontsize=fontsize,
        fontweight="bold",
        color=text_color or COLORS["text"],
        zorder=4,
    )

    # Optional secondary line underneath the main label.
    if sublabel:
        ax.text(
            x, y - 0.25,
            sublabel,
            ha="center", va="center",
            fontsize=fontsize - 2,
            color=COLORS["text_dim"],
            zorder=4,
        )
def draw_arrow(ax, x1, y1, x2, y2, color=None):
    """Draw a straight arrow from (x1, y1) to (x2, y2).

    Falls back to the palette's default arrow colour when *color* is falsy.
    """
    arrow_props = dict(
        arrowstyle="->",
        color=color or COLORS["arrow"],
        lw=1.5,
        connectionstyle="arc3,rad=0",
    )
    ax.annotate("",
                xy=(x2, y2), xytext=(x1, y1),
                arrowprops=arrow_props,
                zorder=2)
def draw_residual_connection(ax, x_start, y_start, x_end, y_end, offset=1.8):
    """Draw a dashed, curved skip-connection arrow between two points.

    Args:
        ax: Target matplotlib axes.
        x_start, y_start: Tail of the arrow (where the residual branches off).
        x_end, y_end: Head of the arrow (where the residual rejoins).
        offset: Unused; kept for backward compatibility with existing callers.
    """
    ax.annotate(
        "",
        xy=(x_end, y_end), xytext=(x_start, y_start),
        arrowprops=dict(
            arrowstyle="->",
            color=COLORS["residual"],
            lw=1.2,
            linestyle="--",
            # Plain string (was a pointless f-string with no placeholders).
            connectionstyle="arc3,rad=0.3",
        ),
        zorder=1,
    )
# ───────────────────────────────────────────────────────────────────────
# FULL ARCHITECTURE DIAGRAM
# ───────────────────────────────────────────────────────────────────────

def draw_full_architecture(config: GPT300MConfig, save_path: str = None):
    """Draw the complete GPT-300M architecture.

    Renders a top-to-bottom flow diagram (input tokens -> embedding ->
    RoPE -> dropout -> one detailed transformer block with a "× N layers"
    repeat marker -> final norm -> LM head -> softmax), sized by the
    fields of *config*. Saves a PNG when *save_path* is given and returns
    the matplotlib figure either way.
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 24), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])
    # Fixed data coordinates; boxes are laid out manually in this space.
    ax.set_xlim(-4, 4)
    ax.set_ylim(-1, 22)
    ax.axis("off")

    # Title
    ax.text(0, 21.5, "GPT-300M Architecture", ha="center", va="center",
            fontsize=22, fontweight="bold", color=COLORS["text"],
            fontfamily="monospace")
    ax.text(0, 21.0,
            f"{config.total_params_estimate:,} parameters • "
            f"{config.n_layers} layers • "
            f"{config.n_heads} heads • "
            f"d={config.d_model}",
            ha="center", va="center", fontsize=10, color=COLORS["text_dim"],
            fontfamily="monospace")

    y = 19.5  # Starting y position; decremented as each box is placed

    # ── Input ──────────────────────────────────────────────────────
    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["card_bg"], "Input Token IDs",
                     sublabel=f"[batch, seq_len]", fontsize=11)
    y -= 1.1
    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)

    # ── Token Embedding ────────────────────────────────────────────
    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["embed"],
                     "Token Embedding", text_color="#000",
                     sublabel=f"{config.vocab_size:,} × {config.d_model}")
    y -= 1.1
    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)

    # ── RoPE ───────────────────────────────────────────────────────
    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["accent2"],
                     "Rotary Position Embeddings (RoPE)",
                     text_color="#000", fontsize=9,
                     sublabel=f"θ = {config.rope_theta:.0f}")
    y -= 1.0
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.4)

    # ── Dropout ────────────────────────────────────────────────────
    draw_rounded_box(ax, 0, y, 2.5, 0.5, COLORS["border"],
                     f"Dropout (p={config.dropout})", fontsize=9)
    y -= 1.0
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)

    # ── Transformer Blocks ─────────────────────────────────────────
    block_height = 3.2

    # Only the first block is drawn in detail; the repeat marker below
    # stands in for the remaining layers.
    block_y_start = y
    block_y_end = y - block_height

    # Block container (highlighted outline around the detailed block)
    block_box = FancyBboxPatch(
        (-3.3, block_y_end - 0.1), 6.6, block_height + 0.2,
        boxstyle="round,pad=0.15",
        facecolor=COLORS["card_bg"],
        edgecolor=COLORS["highlight"],
        linewidth=1.5,
        alpha=0.8,
        zorder=1,
    )
    ax.add_patch(block_box)
    ax.text(-3.0, block_y_start + 0.05,
            f"Transformer Block × {config.n_layers}",
            fontsize=10, fontweight="bold", color=COLORS["highlight"],
            fontfamily="monospace", zorder=5)

    # Inside the block (pre-norm layout: norm before each sub-module)
    by = block_y_start - 0.4

    # RMSNorm 1
    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
                     "RMSNorm", text_color="#000", fontsize=9)
    by -= 0.7
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)

    # Multi-Head Attention
    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["attn"],
                     "Multi-Head Attention", text_color="#000", fontsize=10,
                     sublabel=f"{config.n_heads} heads × {config.head_dim}d")
    # Residual connection (skips norm + attention, drawn on the left)
    draw_residual_connection(ax, -1.6, block_y_start - 0.2, -1.6, by)
    ax.text(-2.5, by + 0.3, "⊕ residual", fontsize=7,
            color=COLORS["residual"], ha="center")

    by -= 0.8
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)

    # RMSNorm 2
    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
                     "RMSNorm", text_color="#000", fontsize=9)
    by -= 0.7
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)

    # Feed-Forward Network
    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["ffn"],
                     "Feed-Forward Network", text_color="#000", fontsize=10,
                     sublabel=f"{config.d_model} → {config.d_ff} → {config.d_model}")
    # Residual connection (skips norm + FFN, drawn on the right)
    draw_residual_connection(ax, 1.6, by + 1.5, 1.6, by)
    ax.text(2.5, by + 0.7, "⊕ residual", fontsize=7,
            color=COLORS["residual"], ha="center")

    y = block_y_end - 0.4

    # ── Repeated blocks indicator ──────────────────────────────────
    draw_arrow(ax, 0, y + 0.2, 0, y - 0.1)
    ax.text(0, y - 0.3, f"× {config.n_layers} layers", ha="center",
            fontsize=11, fontweight="bold", color=COLORS["text_dim"],
            fontfamily="monospace",
            bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS["card_bg"],
                      edgecolor=COLORS["border"]))
    y -= 0.9
    draw_arrow(ax, 0, y + 0.3, 0, y + 0.05)

    # ── Final RMSNorm ──────────────────────────────────────────────
    draw_rounded_box(ax, 0, y - 0.2, 3.5, 0.5, COLORS["norm"],
                     "Final RMSNorm", text_color="#000", fontsize=10)
    y -= 1.0
    draw_arrow(ax, 0, y + 0.5, 0, y + 0.2)

    # ── LM Head ────────────────────────────────────────────────────
    draw_rounded_box(ax, 0, y - 0.1, 3.5, 0.7, COLORS["output"],
                     "Linear (LM Head)", text_color="#000", fontsize=11,
                     sublabel=f"{config.d_model} → {config.vocab_size:,} (weight-tied)")
    y -= 1.1
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)

    # ── Softmax / Output ───────────────────────────────────────────
    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["card_bg"],
                     "Softmax → Next Token Probabilities", fontsize=10,
                     sublabel=f"[batch, seq_len, {config.vocab_size:,}]")

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved architecture diagram: {save_path}")

    return fig
# ───────────────────────────────────────────────────────────────────────
# PARAMETER DISTRIBUTION CHART
# ───────────────────────────────────────────────────────────────────────

def draw_parameter_chart(config: GPT300MConfig, save_path: str = None):
    """Draw a parameter distribution breakdown.

    Left panel: donut chart of total parameters by component. Right
    panel: stacked per-layer bar chart. Saves a PNG when *save_path* is
    given and returns the figure either way.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 7), facecolor=COLORS["bg"])

    # Calculate parameter counts per component
    emb_params = config.vocab_size * config.d_model
    # 4 square projections per layer: W_q, W_k, W_v, W_o
    attn_params = 4 * config.d_model * config.d_model * config.n_layers
    # 2 rectangular projections per layer: up (d_model→d_ff) and down
    ffn_params = 2 * config.d_model * config.d_ff * config.n_layers
    # 2 norms per layer plus the final norm
    norm_params = 2 * config.d_model * config.n_layers + config.d_model
    total = emb_params + attn_params + ffn_params + norm_params

    # ── Pie Chart ──────────────────────────────────────────────────
    ax = axes[0]
    ax.set_facecolor(COLORS["bg"])
    # NOTE(review): `labels` is built but the pie call below passes
    # labels=None; component names reach the chart via the legend instead.
    labels = ["Token\nEmbedding", "Attention\nLayers", "Feed-Forward\nLayers", "LayerNorm"]
    sizes = [emb_params, attn_params, ffn_params, norm_params]
    colors = [COLORS["embed"], COLORS["attn"], COLORS["ffn"], COLORS["norm"]]

    # width=0.5 turns the pie into a donut.
    wedges, texts, autotexts = ax.pie(
        sizes, labels=None, autopct=lambda p: f"{p:.1f}%",
        colors=colors, startangle=90, pctdistance=0.7,
        wedgeprops=dict(width=0.5, edgecolor=COLORS["bg"], linewidth=2),
        textprops=dict(color=COLORS["text"], fontsize=10),
    )
    for at in autotexts:
        at.set_fontweight("bold")
        at.set_color("#000")

    # Legend
    # NOTE(review): the architecture diagram labels these "RMSNorm" —
    # "LayerNorm" here looks like a mislabel; confirm and align.
    legend_labels = [
        f"{l}\n({s/1e6:.1f}M)" for l, s in zip(
            ["Token Embedding", "Attention", "Feed-Forward", "LayerNorm"],
            sizes
        )
    ]
    ax.legend(
        wedges, legend_labels, loc="center left", bbox_to_anchor=(1.05, 0.5),
        fontsize=9, frameon=False, labelcolor=COLORS["text"],
    )
    ax.set_title("Parameter Distribution", fontsize=14, fontweight="bold",
                 color=COLORS["text"], pad=15)

    # ── Per-Layer Breakdown Bar Chart ──────────────────────────────
    ax = axes[1]
    ax.set_facecolor(COLORS["bg"])

    # Per-layer (not total) component sizes, in raw parameter counts.
    layer_attn = 4 * config.d_model * config.d_model
    layer_ffn = 2 * config.d_model * config.d_ff
    layer_norm = 2 * config.d_model

    layers = range(1, config.n_layers + 1)
    bar_width = 0.8

    # Stacked bars (values in millions): attention at the bottom, then
    # FFN, then the (tiny) norm sliver on top. Every layer is identical.
    ax.bar(layers, [layer_attn / 1e6] * config.n_layers, bar_width,
           label="Attention", color=COLORS["attn"], alpha=0.9)
    ax.bar(layers, [layer_ffn / 1e6] * config.n_layers, bar_width,
           bottom=[layer_attn / 1e6] * config.n_layers,
           label="Feed-Forward", color=COLORS["ffn"], alpha=0.9)
    ax.bar(layers, [layer_norm / 1e6] * config.n_layers, bar_width,
           bottom=[(layer_attn + layer_ffn) / 1e6] * config.n_layers,
           label="Norm", color=COLORS["norm"], alpha=0.9)

    ax.set_xlabel("Layer", fontsize=11, color=COLORS["text"])
    ax.set_ylabel("Parameters (M)", fontsize=11, color=COLORS["text"])
    ax.set_title("Parameters Per Layer", fontsize=14, fontweight="bold",
                 color=COLORS["text"], pad=15)
    ax.legend(fontsize=9, frameon=False, labelcolor=COLORS["text"])
    ax.tick_params(colors=COLORS["text_dim"])
    ax.spines["bottom"].set_color(COLORS["border"])
    ax.spines["left"].set_color(COLORS["border"])
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Overall title
    fig.suptitle(
        f"GPT-300M • {total:,} Total Parameters",
        fontsize=16, fontweight="bold", color=COLORS["text"],
        fontfamily="monospace", y=1.02,
    )

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved parameter chart: {save_path}")

    return fig
|
| 361 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 362 |
+
# ATTENTION HEAD VISUALIZATION
|
| 363 |
+
# ββοΏ½οΏ½οΏ½ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 364 |
+
|
| 365 |
+
def draw_attention_heads(config: GPT300MConfig, save_path: str = None):
    """Visualize the multi-head attention mechanism.

    Draws the Q/K/V projection matrices, the individual attention heads
    (at most 8 are rendered to keep the row legible), the scaled
    dot-product formula card, the concat + output projection, and the
    output tensor box.

    Args:
        config: Model configuration; reads ``n_heads``, ``head_dim`` and
            ``d_model``.
        save_path: Optional path; when given the figure is also written
            to disk as a 200-dpi PNG.

    Returns:
        The matplotlib ``Figure`` containing the diagram.
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 10), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])
    ax.set_xlim(-1, 11)
    ax.set_ylim(-1, 8)
    ax.axis("off")

    # Title and head-count summary line.
    # NOTE(review): the corrupted glyphs in the original labels
    # (mojibake) are restored here to the intended ×, √ and → characters.
    ax.text(5, 7.5, "Multi-Head Self-Attention", ha="center",
            fontsize=18, fontweight="bold", color=COLORS["text"],
            fontfamily="monospace")
    ax.text(5, 7.0,
            f"{config.n_heads} heads × {config.head_dim}d per head = {config.d_model}d total",
            ha="center", fontsize=10, color=COLORS["text_dim"])

    # Input tensor box.
    draw_rounded_box(ax, 5, 6.2, 4, 0.5, COLORS["embed"],
                     f"Input: [B, T, {config.d_model}]", text_color="#000", fontsize=9)

    # Q, K, V projection matrices, fanned out from the input box.
    for i, (name, color) in enumerate(zip(["Q", "K", "V"],
                                          ["#FF6B6B", "#4ECDC4", "#45B7D1"])):
        x = 2 + i * 3
        draw_arrow(ax, 5, 5.9, x, 5.4)
        draw_rounded_box(ax, x, 5.1, 1.8, 0.5, color,
                         f"W_{name}", text_color="#000", fontsize=10,
                         sublabel=f"{config.d_model}×{config.d_model}")

    # Individual heads: cap the number drawn at 8 so boxes don't overlap.
    head_y = 3.8
    n_show = min(config.n_heads, 8)
    head_spacing = 9.0 / n_show

    for h in range(n_show):
        hx = 1 + h * head_spacing
        # Head box.
        box = FancyBboxPatch(
            (hx - 0.4, head_y - 0.3), 0.8, 0.6,
            boxstyle="round,pad=0.05",
            facecolor=COLORS["attn"],
            edgecolor="white",
            linewidth=0.5,
            alpha=0.8,
            zorder=3,
        )
        ax.add_patch(box)
        ax.text(hx, head_y, f"H{h+1}", ha="center", va="center",
                fontsize=8, fontweight="bold", color="#000", zorder=4)

        # Thin fan-in lines from each of the Q/K/V boxes to this head.
        # (Original iterated with enumerate but never used the index.)
        for qx in (2, 5, 8):
            ax.annotate("", xy=(hx, head_y + 0.3), xytext=(qx, 4.8),
                        arrowprops=dict(arrowstyle="-", color=COLORS["arrow"],
                                        lw=0.3, alpha=0.3), zorder=1)

    if config.n_heads > 8:
        ax.text(5, head_y - 0.6, f"... ({config.n_heads} heads total)",
                ha="center", fontsize=9, color=COLORS["text_dim"])

    # Attention computation card, fed by every drawn head.
    draw_rounded_box(ax, 5, 2.5, 6, 0.6, COLORS["card_bg"],
                     "Scaled Dot-Product: softmax(QK^T / √d_k) × V",
                     fontsize=10)
    for h in range(n_show):
        hx = 1 + h * head_spacing
        draw_arrow(ax, hx, head_y - 0.3, 5, 2.85)

    # Concatenate heads and apply the output projection.
    draw_arrow(ax, 5, 2.15, 5, 1.75)
    draw_rounded_box(ax, 5, 1.5, 4, 0.5, COLORS["accent1"],
                     "Concat → W_O projection", text_color="#000", fontsize=10)

    # Output tensor box.
    draw_arrow(ax, 5, 1.2, 5, 0.8)
    draw_rounded_box(ax, 5, 0.5, 4, 0.5, COLORS["ffn"],
                     f"Output: [B, T, {config.d_model}]", text_color="#000", fontsize=9)

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved attention diagram: {save_path}")

    return fig
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ───────────────────────────────────────────────────────────────────────
# MAIN
# ───────────────────────────────────────────────────────────────────────
|
| 455 |
+
|
| 456 |
+
if __name__ == "__main__":
    # CLI entry point: render every diagram into the chosen directory.
    cli = argparse.ArgumentParser(description="Visualize GPT-300M Architecture")
    cli.add_argument("--output", type=str, default="./viz",
                     help="Output directory for images")
    opts = cli.parse_args()

    import os
    os.makedirs(opts.output, exist_ok=True)

    config = gpt_300m()
    print(f"Generating visualizations for GPT-300M ({config.total_params_estimate:,} params)...")

    # Pair each renderer with its output filename and run them in order.
    jobs = (
        (draw_full_architecture, "architecture.png"),
        (draw_parameter_chart, "parameters.png"),
        (draw_attention_heads, "attention.png"),
    )
    for render, filename in jobs:
        render(config, os.path.join(opts.output, filename))

    print("Done! All visualizations saved.")