Initial commit
Browse files- .gitignore +44 -0
- README.md +242 -0
- data/dataloader.py +223 -0
- finetune/README.md +137 -0
- finetune/__init__.py +1 -0
- finetune/chat.py +296 -0
- finetune/check_data.py +269 -0
- finetune/data/meta.json +13 -0
- finetune/data/tokenizer.json +0 -0
- finetune/data/tokenizer_config.json +17 -0
- finetune/prepare_data.py +303 -0
- finetune/sft_dataset.py +103 -0
- finetune/sft_train.py +563 -0
- model/__init__.py +5 -0
- model/attention.py +114 -0
- model/block.py +84 -0
- model/config.py +110 -0
- model/mlp.py +83 -0
- model/model.py +245 -0
- model/norm.py +65 -0
- model/rope.py +172 -0
- model_explained.md +376 -0
- plot_training.py +370 -0
- requirements.txt +20 -0
- run.md +34 -0
- test_chatmodel.py +366 -0
- test_checkpoint.py +290 -0
- tokenizer/bpe.py +134 -0
- tokenizer/fineweb_edu_tokenizer.json +0 -0
- tokenizer/fineweb_edu_tokenizer/special_tokens_map.json +5 -0
- tokenizer/fineweb_edu_tokenizer/tokenizer.json +0 -0
- tokenizer/fineweb_edu_tokenizer/tokenizer_config.json +11 -0
- tokenizer/normalizer.py +42 -0
- tokenizer/post_processor.py +152 -0
- tokenizer/pretokenizer.py +159 -0
- tokenizer/tempCodeRunnerFile.py +5 -0
- tokenizer/tokenize_dataset.py +389 -0
- tokenizer/traintokenizer.py +207 -0
- tokenizer/wrap_tokenizer.py +232 -0
- tokenizer_walkthrough.md +105 -0
- train.py +485 -0
.gitignore
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ── Checkpoints & training runs ──────────────────────────────────────
|
| 2 |
+
runs/
|
| 3 |
+
|
| 4 |
+
# ── Python ───────────────────────────────────────────────────────────
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
.Python
|
| 10 |
+
*.egg-info/
|
| 11 |
+
dist/
|
| 12 |
+
build/
|
| 13 |
+
*.egg
|
| 14 |
+
|
| 15 |
+
# ── Virtual environments ──────────────────────────────────────────────
|
| 16 |
+
.env
|
| 17 |
+
.venv
|
| 18 |
+
env/
|
| 19 |
+
venv/
|
| 20 |
+
|
| 21 |
+
# ── Jupyter ───────────────────────────────────────────────────────────
|
| 22 |
+
.ipynb_checkpoints/
|
| 23 |
+
*.ipynb
|
| 24 |
+
|
| 25 |
+
# ── Data / binaries ──────────────────────────────────────────────────
|
| 26 |
+
*.bin
|
| 27 |
+
*.pt
|
| 28 |
+
*.pth
|
| 29 |
+
*.safetensors
|
| 30 |
+
*.npy
|
| 31 |
+
*.npz
|
| 32 |
+
|
| 33 |
+
# ── Logs ─────────────────────────────────────────────────────────────
|
| 34 |
+
*.log
|
| 35 |
+
*.jsonl
|
| 36 |
+
|
| 37 |
+
# ── OS ───────────────────────────────────────────────────────────────
|
| 38 |
+
.DS_Store
|
| 39 |
+
Thumbs.db
|
| 40 |
+
|
| 41 |
+
# ── IDE ───────────────────────────────────────────────────────────────
|
| 42 |
+
.vscode/
|
| 43 |
+
.idea/
|
| 44 |
+
*.swp
|
README.md
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SLLM — Small Language Model from Scratch
|
| 2 |
+
|
| 3 |
+
A GPT-style decoder-only transformer built and trained from scratch in PyTorch. Two model sizes are available (100M and 150M parameters), designed to fit on consumer GPUs as small as a 4 GB VRAM card (e.g. RTX 3050).
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## ✨ Features
|
| 8 |
+
|
| 9 |
+
- **Architecture**: Decoder-only transformer (GPT-style) with modern improvements
|
| 10 |
+
- RMSNorm instead of LayerNorm (faster, no bias)
|
| 11 |
+
- RoPE (Rotary Position Embeddings) — used in LLaMA, Mistral, Gemma
|
| 12 |
+
- SwiGLU feed-forward network — outperforms GELU at the same parameter count
|
| 13 |
+
- Flash Attention via `F.scaled_dot_product_attention` (O(T²) memory avoided)
|
| 14 |
+
- Weight-tied token embeddings + LM head (saves ~32M parameters)
|
| 15 |
+
- **Training**
|
| 16 |
+
- bf16 mixed-precision with gradient accumulation
|
| 17 |
+
- Gradient checkpointing for low-VRAM GPUs
|
| 18 |
+
- Cosine LR schedule with linear warmup
|
| 19 |
+
- Resumable checkpointing (`--resume`, `--extra_steps`)
|
| 20 |
+
- JSONL metric logging + live training dashboard
|
| 21 |
+
- **Custom BPE Tokenizer** — trained on FineWeb-Edu with byte fallback (zero OOV)
|
| 22 |
+
- **Supervised Fine-Tuning (SFT)** — chat model pipeline included in `finetune/`
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
## 🏗️ Project Structure
|
| 27 |
+
|
| 28 |
+
```
|
| 29 |
+
sllm/
|
| 30 |
+
├── model/ # Model architecture
|
| 31 |
+
│ ├── config.py # ModelConfig dataclass (SLLM_100M, SLLM_150M presets)
|
| 32 |
+
│ ├── model.py # SLLM — full model assembly, weight init, gradient checkpointing
|
| 33 |
+
│ ├── block.py # TransformerBlock (pre-norm, residual)
|
| 34 |
+
│ ├── attention.py # Causal multi-head self-attention + RoPE
|
| 35 |
+
│ ├── mlp.py # SwiGLU feed-forward network
|
| 36 |
+
│ ├── norm.py # RMSNorm
|
| 37 |
+
│ └── rope.py # Rotary Position Embeddings
|
| 38 |
+
│
|
| 39 |
+
├── tokenizer/ # Custom BPE tokenizer
|
| 40 |
+
│ ├── normalizer.py # HTML stripping, unicode NFC, whitespace cleanup
|
| 41 |
+
│ ├── pretokenizer.py # Regex pre-tokenizer (code-aware, contraction-aware)
|
| 42 |
+
│ ├── bpe.py # BPE model config with byte fallback (32k vocab)
|
| 43 |
+
│ ├── traintokenizer.py # Train on FineWeb-Edu stream
|
| 44 |
+
│ ├── post_processor.py # Append <|endoftext|> to every sequence
|
| 45 |
+
│ ├── wrap_tokenizer.py # Wrap into PreTrainedTokenizerFast
|
| 46 |
+
│ └── tokenize_dataset.py # Pack tokens into flat binary .bin shards
|
| 47 |
+
│
|
| 48 |
+
├── data/
|
| 49 |
+
│ └── dataloader.py # Memory-mapped shard dataloader
|
| 50 |
+
│
|
| 51 |
+
├── finetune/ # Supervised fine-tuning (SFT) pipeline
|
| 52 |
+
│ ├── prepare_data.py # Prepare chat data
|
| 53 |
+
│ ├── sft_train.py # SFT training loop
|
| 54 |
+
│ ├── sft_dataset.py # Chat dataset
|
| 55 |
+
│ └── chat.py # Interactive chat with the fine-tuned model
|
| 56 |
+
│
|
| 57 |
+
├── train.py # Pre-training loop
|
| 58 |
+
├── plot_training.py # Training dashboard (static + live mode)
|
| 59 |
+
├── requirements.txt
|
| 60 |
+
├── model_explained.md # Deep-dive into every model component
|
| 61 |
+
└── tokenizer_walkthrough.md # Tokenizer design and pipeline walkthrough
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
---
|
| 65 |
+
|
| 66 |
+
## 📐 Model Configs
|
| 67 |
+
|
| 68 |
+
| Config | d_model | Heads | Layers | Parameters |
|
| 69 |
+
|------------|---------|-------|--------|------------|
|
| 70 |
+
| `SLLM_100M` | 768 | 12 | 12 | ~109.5M |
|
| 71 |
+
| `SLLM_150M` | 1024 | 16 | 9 | ~148.4M |
|
| 72 |
+
|
| 73 |
+
Both configs use:
|
| 74 |
+
- Context length: **1024 tokens**
|
| 75 |
+
- Vocab size: **32,000** (custom BPE)
|
| 76 |
+
- SwiGLU d_ff: computed as `round_up_256(⌊2/3 × 4 × d_model⌋)`
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## ⚙️ Installation
|
| 81 |
+
|
| 82 |
+
**Requires:** Python 3.10+, PyTorch 2.3+, CUDA-capable GPU (bf16 recommended)
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
# Create and activate a conda environment
|
| 86 |
+
conda create -n pytorch python=3.11
|
| 87 |
+
conda activate pytorch
|
| 88 |
+
|
| 89 |
+
# Install dependencies
|
| 90 |
+
pip install -r requirements.txt
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## 🚀 Training
|
| 96 |
+
|
| 97 |
+
### Start a new run (RTX 3050 4GB recommended settings)
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
python train.py \
|
| 101 |
+
--config 150M \
|
| 102 |
+
--data_dir tokenizer/data \
|
| 103 |
+
--batch_size 2 \
|
| 104 |
+
--grad_accum 16 \
|
| 105 |
+
--grad_checkpoint \
|
| 106 |
+
--dtype bf16 \
|
| 107 |
+
--max_steps 5000 \
|
| 108 |
+
--run_dir runs/sllm_150m \
|
| 109 |
+
--log_every 10 \
|
| 110 |
+
--save_every 500 \
|
| 111 |
+
--val_every 500 \
|
| 112 |
+
--warmup_steps 200
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
### Resume from a checkpoint
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
python train.py \
|
| 119 |
+
--resume \
|
| 120 |
+
--run_dir runs/sllm_150m \
|
| 121 |
+
--extra_steps 5000 \
|
| 122 |
+
--data_dir tokenizer/data \
|
| 123 |
+
--batch_size 2 \
|
| 124 |
+
--grad_accum 16 \
|
| 125 |
+
--grad_checkpoint \
|
| 126 |
+
--dtype bf16
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Key training flags
|
| 130 |
+
|
| 131 |
+
| Flag | Default | Description |
|
| 132 |
+
|------|---------|-------------|
|
| 133 |
+
| `--config` | `100M` | Model size (`100M` or `150M`) |
|
| 134 |
+
| `--batch_size` | `4` | Per-device micro-batch size |
|
| 135 |
+
| `--grad_accum` | `8` | Gradient accumulation steps |
|
| 136 |
+
| `--max_steps` | unlimited | Absolute step target |
|
| 137 |
+
| `--extra_steps` | — | Run N more steps from current checkpoint |
|
| 138 |
+
| `--resume` | — | Resume from latest checkpoint in `--run_dir` |
|
| 139 |
+
| `--grad_checkpoint` | — | Enable gradient checkpointing (saves VRAM) |
|
| 140 |
+
| `--dtype` | `bf16` | Mixed precision dtype (`fp32`, `fp16`, `bf16`) |
|
| 141 |
+
| `--synthetic` | — | Use random data (for testing without real shards) |
|
| 142 |
+
|
| 143 |
+
---
|
| 144 |
+
|
| 145 |
+
## 📊 Training Dashboard
|
| 146 |
+
|
| 147 |
+
Visualize training metrics in a dark-mode 6-panel dashboard:
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
# Static plot
|
| 151 |
+
python plot_training.py --run_dir runs/sllm_150m
|
| 152 |
+
|
| 153 |
+
# Live mode — refresh every 30 seconds while training
|
| 154 |
+
python plot_training.py --run_dir runs/sllm_150m --live --interval 30
|
| 155 |
+
|
| 156 |
+
# Compare two runs
|
| 157 |
+
python plot_training.py --run_dir runs/run_a runs/run_b
|
| 158 |
+
|
| 159 |
+
# Save to file
|
| 160 |
+
python plot_training.py --run_dir runs/sllm_150m --save dashboard.png
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
**Dashboard panels:** Training Loss (raw + EMA) · Validation Loss · Learning Rate · Tokens/sec · VRAM usage · Gradient norm
|
| 164 |
+
|
| 165 |
+
---
|
| 166 |
+
|
| 167 |
+
## 💬 Fine-Tuning (Chat Model)
|
| 168 |
+
|
| 169 |
+
After pre-training, you can fine-tune with supervised instruction data:
|
| 170 |
+
|
| 171 |
+
```bash
|
| 172 |
+
# 1. Prepare chat data
|
| 173 |
+
python finetune/prepare_data.py
|
| 174 |
+
|
| 175 |
+
# 2. Fine-tune
|
| 176 |
+
python finetune/sft_train.py \
|
| 177 |
+
--base_ckpt runs/sllm_150m/ckpt_0011500.pt \
|
| 178 |
+
--run_dir runs/sllm_150m_chat \
|
| 179 |
+
--max_steps 2500 \
|
| 180 |
+
--batch_size 4 \
|
| 181 |
+
--grad_accum 8 \
|
| 182 |
+
--grad_checkpoint
|
| 183 |
+
|
| 184 |
+
# 3. Chat interactively
|
| 185 |
+
python finetune/chat.py --run_dir runs/sllm_150m_chat
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
## 🔡 Tokenizer
|
| 191 |
+
|
| 192 |
+
A custom BPE tokenizer trained on [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu), the educational subset of FineWeb:
|
| 193 |
+
|
| 194 |
+
- **32,000 token vocabulary**
|
| 195 |
+
- **Byte fallback** — zero out-of-vocabulary tokens (even math symbols and emojis work)
|
| 196 |
+
- **Code-aware** — preserves `snake_case`, operators (`==`, `->`, `**`), and indentation
|
| 197 |
+
- **Contraction-aware** — `don't`, `I've`, `they're` are split correctly
|
| 198 |
+
- Packaged as a `PreTrainedTokenizerFast` (HuggingFace-compatible)
|
| 199 |
+
|
| 200 |
+
Training data is packed into flat binary `.bin` shards (`np.uint16`, 100M tokens each) for fast memory-mapped loading.
|
| 201 |
+
|
| 202 |
+
See [`tokenizer_walkthrough.md`](tokenizer_walkthrough.md) for a full pipeline deep-dive.
|
| 203 |
+
|
| 204 |
+
---
|
| 205 |
+
|
| 206 |
+
## 🧠 Architecture Deep-Dive
|
| 207 |
+
|
| 208 |
+
See [`model_explained.md`](model_explained.md) for a plain-language walkthrough of every model component, including:
|
| 209 |
+
- Why RMSNorm is faster than LayerNorm
|
| 210 |
+
- How RoPE encodes relative position without extra parameters
|
| 211 |
+
- Why SwiGLU outperforms GELU
|
| 212 |
+
- How weight tying saves 32M parameters
|
| 213 |
+
- Flash Attention and gradient checkpointing explained
|
| 214 |
+
|
| 215 |
+
---
|
| 216 |
+
|
| 217 |
+
## 📋 Checkpoints & Logging
|
| 218 |
+
|
| 219 |
+
- Checkpoints are saved to `<run_dir>/ckpt_NNNNNNN.pt` every `--save_every` steps and on clean exit (Ctrl+C)
|
| 220 |
+
- Metrics are appended to `<run_dir>/train_log.jsonl` (one JSON line per log step)
|
| 221 |
+
- Each checkpoint stores: model weights, optimizer state, step number, loss, and config name
|
| 222 |
+
- Resuming auto-detects the correct model config from the checkpoint
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
## 📦 Requirements
|
| 227 |
+
|
| 228 |
+
```
|
| 229 |
+
torch>=2.3.0
|
| 230 |
+
datasets>=2.14.0 # HuggingFace datasets (streaming)
|
| 231 |
+
tokenizers>=0.15.0 # Fast BPE tokenizer
|
| 232 |
+
transformers>=4.40.0 # PreTrainedTokenizerFast
|
| 233 |
+
numpy>=1.26.0
|
| 234 |
+
tqdm
|
| 235 |
+
matplotlib
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
---
|
| 239 |
+
|
| 240 |
+
## 📄 License
|
| 241 |
+
|
| 242 |
+
This project is released for educational purposes.
|
data/dataloader.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data/dataloader.py
|
| 3 |
+
|
| 4 |
+
Streaming dataloader for the pre-tokenized binary shards produced by
|
| 5 |
+
tokenizer/tokenize_dataset.py.
|
| 6 |
+
|
| 7 |
+
Each shard is a flat binary file of np.uint16 token IDs.
|
| 8 |
+
100M tokens * 2 bytes = ~200MB per shard.
|
| 9 |
+
|
| 10 |
+
Strategy:
|
| 11 |
+
1. Discover all shards matching split name (train/val).
|
| 12 |
+
2. Shuffle shard order at start of each epoch.
|
| 13 |
+
3. For each shard, load it (memmap or full) and yield non-overlapping
|
| 14 |
+
chunks of (context_length + 1) tokens.
|
| 15 |
+
4. Inputs = chunk[:-1] (length context_length)
|
| 16 |
+
Targets = chunk[1:] (length context_length, shifted right by 1)
|
| 17 |
+
|
| 18 |
+
When no data shards exist yet (tokenization not done), a SyntheticDataset
|
| 19 |
+
can be used for architecture testing.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import glob
|
| 24 |
+
import random
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
from torch.utils.data import IterableDataset, DataLoader
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ------------------------------------------------------------------ #
|
| 31 |
+
# SHARD DISCOVERY
|
| 32 |
+
# ------------------------------------------------------------------ #
|
| 33 |
+
|
| 34 |
+
def find_shards(data_dir: str, split: str) -> list[str]:
    """
    Returns sorted list of shard paths for the given split.

    The directory part of the pattern is escaped before globbing, so a
    data_dir whose name contains glob metacharacters (for example
    "runs/exp[1]") is matched literally instead of being read as a pattern.

    Args:
        data_dir : directory containing .bin shard files
        split    : 'train' or 'val'

    Returns:
        Lexicographically sorted paths matching "<data_dir>/<split>_*.bin".
        The list is empty when nothing matches or data_dir does not exist.
    """
    pattern = os.path.join(glob.escape(data_dir), f"{split}_*.bin")
    return sorted(glob.glob(pattern))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ------------------------------------------------------------------ #
|
| 48 |
+
# ITERABLE DATASET
|
| 49 |
+
# ------------------------------------------------------------------ #
|
| 50 |
+
|
| 51 |
+
class ShardedTokenDataset(IterableDataset):
    """
    IterableDataset that streams token chunks from binary shards.

    Every shard is read as a flat array of np.uint16 token IDs and cut into
    non-overlapping windows of (context_length + 1) tokens. Tokens left over
    at the end of a shard that do not fill a whole window are dropped.

    Each worker processes a disjoint subset of shards so we get
    proper parallelism with DataLoader(num_workers=N).

    Usage:
        dataset = ShardedTokenDataset(data_dir, split='train', context_length=1024)
        loader  = DataLoader(dataset, batch_size=4)
        for input_ids, targets in loader:
            ...
    """

    def __init__(
        self,
        data_dir: str,
        split: str,
        context_length: int,
        shuffle_shards: bool = True,
    ):
        """
        Args:
            data_dir       : path to directory with .bin shard files
            split          : 'train' or 'val'
            context_length : sequence length (model context length)
            shuffle_shards : shuffle shard order each epoch (train only)

        Raises:
            FileNotFoundError : no "<split>_*.bin" shard exists in data_dir
        """
        super().__init__()
        self.context_length = context_length
        self.shuffle_shards = shuffle_shards

        self.shards = find_shards(data_dir, split)
        if not self.shards:
            raise FileNotFoundError(
                f"No {split} shards found in {data_dir}.\n"
                f"Run tokenizer/tokenize_dataset.py first to generate data."
            )
        print(f"[DataLoader] Found {len(self.shards)} {split} shards in {data_dir}")

    def _shards_for_this_worker(self) -> list[str]:
        """Returns the shard paths this process should read, in read order."""
        shards = list(self.shards)

        # Partition BEFORE shuffling. Worker processes do not share RNG state,
        # so shuffling first would make every worker slice a different
        # permutation and the subsets would overlap (some shards read twice,
        # others never). Slicing the fixed, sorted list keeps them disjoint.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            shards = shards[worker_info.id :: worker_info.num_workers]

        # Randomising the order inside one worker's own subset is safe.
        if self.shuffle_shards:
            random.shuffle(shards)
        return shards

    def __iter__(self):
        """
        Yields (input_ids, targets) pairs of int64 tensors, each of length
        context_length, where targets is input_ids shifted left by one token.
        """
        chunk = self.context_length + 1   # +1 so we can shift for targets
        token_bytes = np.dtype(np.uint16).itemsize

        for shard_path in self._shards_for_this_worker():
            # Whole tokens only; a stray trailing byte is ignored.
            n_tokens = os.path.getsize(shard_path) // token_bytes
            if n_tokens < chunk:
                # Too short for even one window (also covers empty files,
                # which cannot be memory-mapped).
                continue

            # Map the shard instead of loading it, so only the window being
            # converted is ever copied into memory.
            tokens = np.memmap(
                shard_path, dtype=np.uint16, mode="r", shape=(n_tokens,)
            )

            # Non-overlapping windows: 0, chunk, 2*chunk, ...
            for start in range(0, n_tokens - chunk + 1, chunk):
                window = np.array(tokens[start : start + chunk], dtype=np.int64)
                seq = torch.from_numpy(window)
                yield seq[:-1], seq[1:]   # each (context_length,)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ------------------------------------------------------------------ #
|
| 119 |
+
# SYNTHETIC DATASET (for testing without real data)
|
| 120 |
+
# ------------------------------------------------------------------ #
|
| 121 |
+
|
| 122 |
+
class SyntheticDataset(IterableDataset):
    """
    Stand-in dataset of uniformly random token IDs.

    Intended for exercising the model and training loop before any real
    shards have been produced. Note that n_batches counts individual
    sequences yielded by the dataset, not collated batches.
    """

    def __init__(self, vocab_size: int, context_length: int, n_batches: int = 1000):
        """
        Args:
            vocab_size     : exclusive upper bound for sampled token IDs
            context_length : length of each yielded input / target sequence
            n_batches      : number of sequences to yield per iteration pass
        """
        super().__init__()
        self.n_batches = n_batches
        self.context_length = context_length
        self.vocab_size = vocab_size

    def _draw(self):
        """Samples context_length + 1 tokens and splits them into a shifted pair."""
        sample = torch.randint(0, self.vocab_size, (self.context_length + 1,))
        return sample[:-1], sample[1:]

    def __iter__(self):
        for _step in range(self.n_batches):
            yield self._draw()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ------------------------------------------------------------------ #
|
| 143 |
+
# FACTORY FUNCTION
|
| 144 |
+
# ------------------------------------------------------------------ #
|
| 145 |
+
|
| 146 |
+
def build_dataloader(
    data_dir: str,
    split: str,
    context_length: int,
    batch_size: int,
    num_workers: int = 2,
    use_synthetic: bool = False,
    vocab_size: int = 32_000,
) -> DataLoader:
    """
    Builds and returns a DataLoader for the given split.

    Falls back to SyntheticDataset if use_synthetic=True or no shards found.

    Args:
        data_dir       : directory with .bin shards
        split          : 'train' or 'val'
        context_length : model context length (1024)
        batch_size     : number of sequences per batch
        num_workers    : DataLoader workers (0 = main process)
        use_synthetic  : force synthetic data (for testing)
        vocab_size     : needed for synthetic fallback

    Returns:
        DataLoader yielding (input_ids, targets) each of shape (B, T)
    """
    dataset = None

    if use_synthetic:
        print(f"[DataLoader] Using synthetic data (use_synthetic=True)")
    else:
        try:
            dataset = ShardedTokenDataset(
                data_dir=data_dir,
                split=split,
                context_length=context_length,
                shuffle_shards=(split == "train"),  # keep validation order fixed
            )
        except FileNotFoundError as e:
            # Deliberate best-effort: missing shards downgrade to random data
            # so the rest of the pipeline can still be exercised.
            print(f"[DataLoader] WARNING: {e}")
            print(f"[DataLoader] Falling back to synthetic data for testing.")

    if dataset is None:
        dataset = SyntheticDataset(vocab_size, context_length)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only speeds up CPU->GPU copies; requesting it
        # without CUDA just triggers a warning, so enable it only when a GPU
        # is actually available.
        pin_memory=torch.cuda.is_available(),
    )
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ------------------------------------------------------------------ #
|
| 197 |
+
# QUICK CHECK
|
| 198 |
+
# ------------------------------------------------------------------ #
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
    # Smoke test: build a synthetic loader and print the first four batches.
    import sys

    # Put the project root (the parent of this file's directory) on the
    # import path so the model package resolves when run as a script.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.insert(0, project_root)
    from model.config import SLLM_100M

    cfg = SLLM_100M

    print("Testing with synthetic data...")
    loader = build_dataloader(
        "tokenizer/data",
        "train",
        cfg.context_length,
        4,
        num_workers=0,
        use_synthetic=True,
        vocab_size=cfg.vocab_size,
    )

    # zip against range(4) stops after batches 0..3 without pulling a fifth.
    for i, (x, y) in zip(range(4), loader):
        print(f"Batch {i}: input_ids={x.shape}, targets={y.shape}, dtype={x.dtype}")

    print("DataLoader OK")
|
finetune/README.md
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SLLM-150M → Chat Model (SFT)
|
| 2 |
+
|
| 3 |
+
Supervised Fine-Tuning pipeline to turn the pretrained **SLLM-150M** base model into
|
| 4 |
+
an instruction-following chat model using **OpenHermes-2.5**.
|
| 5 |
+
|
| 6 |
+
## Pipeline
|
| 7 |
+
|
| 8 |
+
```
|
| 9 |
+
Base model (runs/sllm_150m/ckpt_0011500.pt)
|
| 10 |
+
│
|
| 11 |
+
▼
|
| 12 |
+
prepare_data.py ─── download & tokenize OpenHermes-2.5 (80k convs)
|
| 13 |
+
│
|
| 14 |
+
▼
|
| 15 |
+
sft_train.py ─── SFT with ChatML loss masking
|
| 16 |
+
│
|
| 17 |
+
▼
|
| 18 |
+
chat.py ─── interactive CLI chat
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Step 1 — Install dependency
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
pip install datasets
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
## Step 2 — Prepare data
|
| 28 |
+
|
| 29 |
+
Downloads 80k conversations, formats as ChatML, tokenizes, saves shards.
|
| 30 |
+
Also saves the extended tokenizer (vocab 32,002) to `finetune/data/`.
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
python finetune/prepare_data.py
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Options:
|
| 37 |
+
|
| 38 |
+
| Flag | Default | Description |
|
| 39 |
+
|------|---------|-------------|
|
| 40 |
+
| `--n_samples` | `80000` | Conversations to sample |
|
| 41 |
+
| `--val_ratio` | `0.05` | Validation fraction |
|
| 42 |
+
| `--output_dir` | `finetune/data` | Output directory |
|
| 43 |
+
| `--seed` | `42` | Random seed |
|
| 44 |
+
|
| 45 |
+
Expected output:
|
| 46 |
+
```
|
| 47 |
+
finetune/data/
|
| 48 |
+
tokenizer.json ← extended tokenizer (32,002 vocab)
|
| 49 |
+
tokenizer_config.json
|
| 50 |
+
special_tokens_map.json
|
| 51 |
+
train_sft.pt ← ~76,000 examples
|
| 52 |
+
val_sft.pt ← ~4,000 examples
|
| 53 |
+
meta.json ← stats
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Step 3 — Fine-tune
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
python finetune/sft_train.py \
|
| 60 |
+
--base_ckpt runs/sllm_150m/ckpt_0011500.pt \
|
| 61 |
+
--run_dir runs/sllm_150m_chat \
|
| 62 |
+
--max_steps 2000 \
|
| 63 |
+
--batch_size 4 --grad_accum 8 \
|
| 64 |
+
--grad_checkpoint
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
For an RTX 3050 4 GB, these settings use ~3.5 GB VRAM and take **~5–8 minutes**.
|
| 68 |
+
|
| 69 |
+
**Resume training:**
|
| 70 |
+
```bash
|
| 71 |
+
python finetune/sft_train.py \
|
| 72 |
+
--resume --run_dir runs/sllm_150m_chat \
|
| 73 |
+
--extra_steps 1000
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
Key options:
|
| 77 |
+
|
| 78 |
+
| Flag | Default | Description |
|
| 79 |
+
|------|---------|-------------|
|
| 80 |
+
| `--base_ckpt` | `runs/sllm_150m/ckpt_0011500.pt` | Base pretrained checkpoint |
|
| 81 |
+
| `--max_lr` | `1e-5` | Peak LR (30× lower than pretraining) |
|
| 82 |
+
| `--dropout` | `0.1` | SFT dropout (0 in pretraining) |
|
| 83 |
+
| `--max_steps` | `2000` | Total training steps |
|
| 84 |
+
| `--grad_checkpoint` | off | Enable for lower VRAM |
|
| 85 |
+
|
| 86 |
+
Checkpoints are saved to `runs/sllm_150m_chat/ckpt_sft_XXXXXXX.pt`.
|
| 87 |
+
Training log: `runs/sllm_150m_chat/sft_log.jsonl`.
|
| 88 |
+
|
| 89 |
+
## Step 4 — Chat
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
python finetune/chat.py
|
| 93 |
+
python finetune/chat.py --run_dir runs/sllm_150m_chat --temperature 0.7
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
In-chat commands:
|
| 97 |
+
|
| 98 |
+
| Command | Effect |
|
| 99 |
+
|---------|--------|
|
| 100 |
+
| `/reset` | Clear conversation history |
|
| 101 |
+
| `/system <text>` | Change system prompt |
|
| 102 |
+
| `/quit` | Exit |
|
| 103 |
+
|
| 104 |
+
## What changes vs pretraining
|
| 105 |
+
|
| 106 |
+
| | Pretraining (`train.py`) | SFT (`sft_train.py`) |
|
| 107 |
+
|---|---|---|
|
| 108 |
+
| Data | Raw text shards (`.bin`) | ChatML conversations (`.pt`) |
|
| 109 |
+
| Loss | Every token | **Assistant tokens only** (`ignore_index=-100`) |
|
| 110 |
+
| Learning rate | `3e-4` | **`1e-5`** |
|
| 111 |
+
| Warmup | 100 steps | 30 steps |
|
| 112 |
+
| Vocab | 32,000 | **32,002** (`<\|im_start\|>` + `<\|im_end\|>`) |
|
| 113 |
+
| Dropout | 0.0 | **0.1** |
|
| 114 |
+
| Checkpoint prefix | `ckpt_` | `ckpt_sft_` |
|
| 115 |
+
|
| 116 |
+
## Expected loss curve
|
| 117 |
+
|
| 118 |
+
| Stage | Expected loss |
|
| 119 |
+
|-------|--------------|
|
| 120 |
+
| Start (step 0) | 1.5 – 2.5 |
|
| 121 |
+
| Step 500 | 1.0 – 1.5 |
|
| 122 |
+
| Step 2000 | 0.8 – 1.2 |
|
| 123 |
+
|
| 124 |
+
> **If loss starts above 4.0 or goes NaN** → reduce `--max_lr` to `5e-6`.
|
| 125 |
+
|
| 126 |
+
## Prompt format (ChatML)
|
| 127 |
+
|
| 128 |
+
```
|
| 129 |
+
<|im_start|>system
|
| 130 |
+
You are a helpful, concise assistant.<|im_end|>
|
| 131 |
+
<|im_start|>user
|
| 132 |
+
What is the capital of France?<|im_end|>
|
| 133 |
+
<|im_start|>assistant
|
| 134 |
+
The capital of France is Paris.<|im_end|>
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Generation stops automatically when the model produces `<|im_end|>`.
|
finetune/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# finetune package
|
finetune/chat.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune/chat.py
|
| 3 |
+
|
| 4 |
+
Interactive CLI chat with the fine-tuned SLLM-150M chat model.
|
| 5 |
+
|
| 6 |
+
Loads the latest SFT checkpoint from --run_dir, formats your input
|
| 7 |
+
as a ChatML prompt, generates a response token-by-token, and stops
|
| 8 |
+
at the <|im_end|> token.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python finetune/chat.py
|
| 12 |
+
python finetune/chat.py --run_dir runs/sllm_150m_chat
|
| 13 |
+
python finetune/chat.py --temperature 0.7 --top_k 40
|
| 14 |
+
|
| 15 |
+
In-chat commands:
|
| 16 |
+
/reset clear conversation history (start fresh)
|
| 17 |
+
/system <text> change the system prompt
|
| 18 |
+
/quit exit
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import argparse
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from transformers import PreTrainedTokenizerFast
|
| 29 |
+
|
| 30 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 31 |
+
PROJECT_ROOT = SCRIPT_DIR.parent
|
| 32 |
+
DATA_DIR = SCRIPT_DIR / "data"
|
| 33 |
+
|
| 34 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 35 |
+
|
| 36 |
+
from model.config import SLLM_150M
|
| 37 |
+
from model.model import SLLM
|
| 38 |
+
|
| 39 |
+
DEFAULT_SYSTEM = "You are a helpful, concise assistant."
|
| 40 |
+
DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ------------------------------------------------------------------ #
|
| 44 |
+
# HELPERS
|
| 45 |
+
# ------------------------------------------------------------------ #
|
| 46 |
+
|
| 47 |
+
def find_latest_ckpt(run_dir: str) -> str:
    """Return the path of the most recent ``ckpt_sft_*.pt`` file in *run_dir*.

    Checkpoints are ordered by the number at the end of their file name,
    compared numerically, so ``ckpt_sft_1000.pt`` ranks after
    ``ckpt_sft_500.pt`` even when the number is not zero-padded. For
    zero-padded names of equal width this gives the same result as a plain
    alphabetical sort.

    Args:
        run_dir: Directory to scan (not searched recursively).

    Returns:
        ``run_dir`` joined with the file name of the latest checkpoint.

    Raises:
        FileNotFoundError: If no ``ckpt_sft_*.pt`` file exists in ``run_dir``.
    """
    prefix, suffix = "ckpt_sft_", ".pt"

    def sort_key(name: str):
        # Split "ckpt_sft_<tag><digits>.pt" into the textual tag and the
        # trailing number so the number can be compared as an integer.
        stem = name[len(prefix):-len(suffix)]
        tag = stem.rstrip("0123456789")
        digits = stem[len(tag):]
        return (tag, int(digits) if digits else -1, name)

    ckpts = sorted(
        (f for f in os.listdir(run_dir)
         if f.startswith(prefix) and f.endswith(suffix)),
        key=sort_key,
    )
    if not ckpts:
        raise FileNotFoundError(
            f"No SFT checkpoints found in '{run_dir}'.\n"
            f"Run sft_train.py first."
        )
    return os.path.join(run_dir, ckpts[-1])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """Resize the model's token embedding (and tied LM head) in place.

    Same resize logic as sft_train.py — kept local to avoid circular imports.

    Rows shared by the old and new vocabulary are copied over unchanged.
    When the vocabulary grows, every added row is initialised to the mean
    of the old embedding rows. When it shrinks, the trailing rows are
    dropped (previously this case failed with a shape mismatch).

    Args:
        model: Object exposing ``config.vocab_size``, ``config.d_model``,
            ``token_emb`` (an embedding module) and ``lm_head``.
        new_vocab_size: Number of rows the embedding should have afterwards.

    Returns:
        None. ``model.token_emb`` is replaced, ``model.lm_head.weight`` is
        re-pointed at the new embedding weight, and
        ``model.config.vocab_size`` is updated. Nothing happens when the
        size is already ``new_vocab_size``.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return

    d_model = model.config.d_model
    old_weight = model.token_emb.weight.data
    device, dtype = old_weight.device, old_weight.dtype

    # Copy only the rows that exist in both vocabularies.
    n_keep = min(old_size, new_vocab_size)
    new_weight = torch.zeros(new_vocab_size, d_model, dtype=dtype, device=device)
    new_weight[:n_keep] = old_weight[:n_keep]
    if new_vocab_size > old_size:
        # New tokens start at the mean of the old rows (broadcast over the
        # added rows).
        new_weight[old_size:] = old_weight.mean(dim=0)

    new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype)
    new_emb.weight.data = new_weight
    model.token_emb = new_emb
    # Re-tie the output projection to the freshly created embedding matrix.
    model.lm_head.weight = model.token_emb.weight
    model.config.vocab_size = new_vocab_size
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def load_model_and_tokenizer(run_dir: str, device: torch.device):
    """Load the chat tokenizer and the latest fine-tuned model.

    The tokenizer is read from ``DATA_DIR`` when a ``tokenizer.json`` is
    present there; otherwise the base tokenizer is loaded and the two ChatML
    tokens are added by hand. The model is built from ``SLLM_150M``, its
    embedding is resized to the checkpoint's ``vocab_size`` (falling back to
    the tokenizer length), and the checkpoint weights are loaded.

    Args:
        run_dir: Directory containing ``ckpt_sft_*.pt`` checkpoints.
        device: Device the model and checkpoint tensors are mapped to.

    Returns:
        Tuple ``(model, tokenizer, ckpt_path, step, loss)``; ``step`` is
        ``"?"`` and ``loss`` is NaN when the checkpoint lacks those keys.
    """
    # ---- Tokenizer ------------------------------------------------- #
    data_dir = str(DATA_DIR)
    has_sft_tokenizer = os.path.exists(os.path.join(data_dir, "tokenizer.json"))
    if not has_sft_tokenizer:
        # Fallback: base tokenizer + manual special token add
        base_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(base_dir)
        tokenizer.add_special_tokens({
            "additional_special_tokens": ["<|im_start|>", "<|im_end|>"]
        })
    else:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(data_dir)

    # ---- Checkpoint ------------------------------------------------ #
    checkpoint_file = find_latest_ckpt(run_dir)
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # checkpoints from a trusted source should be loaded here.
    state = torch.load(checkpoint_file, map_location=device, weights_only=False)

    # ---- Model ----------------------------------------------------- #
    model = SLLM(SLLM_150M).to(device)
    # The embedding must match the checkpoint's shape before the weights load.
    resize_token_embeddings(model, state.get("vocab_size", len(tokenizer)))
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    step = state.get("step", "?")
    loss = state.get("loss", float("nan"))
    return model, tokenizer, checkpoint_file, step, loss
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ------------------------------------------------------------------ #
|
| 111 |
+
# PROMPT BUILDING
|
| 112 |
+
# ------------------------------------------------------------------ #
|
| 113 |
+
|
| 114 |
+
def build_prompt(history: list[dict], system_prompt: str,
                 tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
    """Render the conversation as ChatML text and tokenise it.

    Layout of the rendered text::

        <|im_start|>system
        {system}<|im_end|>
        <|im_start|>{role}
        {content}<|im_end|>
        ...
        <|im_start|>assistant

    The final assistant header is left open so the model continues from it.

    Args:
        history: Turns in order; each dict provides ``"role"`` and
            ``"content"``.
        system_prompt: Text placed in the leading system turn.
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns the token ids.

    Returns:
        LongTensor of shape ``(1, T)`` holding the prompt token ids.
    """
    segments = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    segments.extend(
        f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
        for turn in history
    )
    # Open assistant header: this is where generation starts.
    segments.append("<|im_start|>assistant\n")

    token_ids = tokenizer.encode("".join(segments), add_special_tokens=False)
    return torch.tensor([token_ids], dtype=torch.long)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ------------------------------------------------------------------ #
|
| 143 |
+
# GENERATION
|
| 144 |
+
# ------------------------------------------------------------------ #
|
| 145 |
+
|
| 146 |
+
@torch.no_grad()
def generate_response(
    model: SLLM,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizerFast,
    max_new_tokens: int = 300,
    temperature: float = 0.8,
    top_k: int = 50,
    device: torch.device = None,
) -> str:
    """Sample a reply one token at a time and return it as text.

    Generation ends as soon as the sampled token is ``<|im_end|>`` or the
    tokenizer's EOS token (neither is kept), or after ``max_new_tokens``
    tokens have been produced.

    Args:
        model: Callable returning ``(logits, _)`` for a batch of token ids;
            must expose ``config.context_length``.
        input_ids: Prompt token ids, moved to ``device`` before use.
        tokenizer: Supplies the stop-token ids and decodes the result.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Divisor applied to the logits (floored at ``1e-8``).
        top_k: Keep only the ``top_k`` highest logits; falsy or non-positive
            disables the filter.
        device: Target device for ``input_ids``.

    Returns:
        The generated tokens decoded with ``skip_special_tokens=True`` and
        stripped of surrounding whitespace.
    """
    stop_ids = (
        tokenizer.convert_tokens_to_ids("<|im_end|>"),
        tokenizer.eos_token_id,
    )
    safe_temperature = max(temperature, 1e-8)

    sequence = input_ids.to(device)
    new_tokens = []

    for _step in range(max_new_tokens):
        # Feed at most the last `context_length` tokens to the model.
        window = sequence[:, -model.config.context_length:]

        logits, _ = model(window)  # (1, T, V)
        scores = logits[:, -1, :] / safe_temperature

        # Top-k filtering: everything below the k-th best score is removed.
        if top_k and top_k > 0:
            k = min(top_k, scores.size(-1))
            kth_best = torch.topk(scores, k).values[:, -1:]
            scores = scores.masked_fill(scores < kth_best, float("-inf"))

        next_token = torch.multinomial(torch.softmax(scores, dim=-1), num_samples=1)  # (1, 1)
        tok_id = next_token.item()

        # Stop conditions
        if tok_id in stop_ids:
            break

        new_tokens.append(tok_id)
        sequence = torch.cat([sequence, next_token], dim=1)

    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# ------------------------------------------------------------------ #
|
| 198 |
+
# MAIN
|
| 199 |
+
# ------------------------------------------------------------------ #
|
| 200 |
+
|
| 201 |
+
def parse_args():
    """Parse the command-line options of the chat CLI.

    Returns:
        Namespace with ``run_dir``, ``temperature``, ``top_k``,
        ``max_new_tokens`` and ``system``.
    """
    parser = argparse.ArgumentParser(description="SLLM-150M Chat")
    # (flag, add_argument keyword arguments) — registered in this order.
    option_specs = (
        ("--run_dir", {"type": str, "default": DEFAULT_RUN_DIR}),
        ("--temperature", {"type": float, "default": 0.8,
                           "help": "Sampling temperature (lower = more focused)"}),
        ("--top_k", {"type": int, "default": 50,
                     "help": "Top-k sampling (0 = disabled)"}),
        ("--max_new_tokens", {"type": int, "default": 300,
                              "help": "Max tokens per assistant response"}),
        ("--system", {"type": str, "default": DEFAULT_SYSTEM,
                      "help": "System prompt"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def main():
    """Run the interactive chat loop against the latest SFT checkpoint.

    Loads the model and tokenizer, then repeatedly reads a line from stdin,
    handles the in-chat commands (``/quit``, ``/reset``, ``/system <text>``)
    and otherwise sends the conversation so far to the model and prints its
    reply. EOF or Ctrl-C at the prompt exits.

    A bare ``/system`` (no text) now prints a usage hint; previously it fell
    through and was sent to the model as an ordinary user message.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("\n" + "=" * 60)
    print(" SLLM-150M Chat")
    print("=" * 60)
    print(f" Device : {device}")
    if device.type == "cuda":
        print(f" GPU : {torch.cuda.get_device_name(0)}")

    # ---- Load ------------------------------------------------------ #
    print("\nLoading model...")
    model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
    print(f" Checkpoint : {ckpt_path}")
    print(f" Step : {step} Loss: {loss:.4f}")
    print(f" Vocab size : {len(tokenizer):,}")

    # ---- Chat loop ------------------------------------------------- #
    system_prompt = args.system
    history: list[dict] = []

    print(f"\n System : {system_prompt}")
    print(" Commands: /reset | /system <new prompt> | /quit")
    print("─" * 60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not user_input:
            continue

        # ---- Commands ---------------------------------------------- #
        command = user_input.lower()

        if command in ("/quit", "/exit", "quit", "exit"):
            print("Bye!")
            break

        if command == "/reset":
            history = []
            print(" [Conversation cleared]\n")
            continue

        if command == "/system" or command.startswith("/system "):
            new_sys = user_input[len("/system"):].strip()
            if new_sys:
                system_prompt = new_sys
                # The old turns were produced under the previous prompt.
                history = []
                print(" [System prompt updated. Conversation cleared.]\n")
            else:
                print(" [Usage: /system <new prompt>]\n")
            continue

        # ---- Build prompt ------------------------------------------ #
        history.append({"role": "user", "content": user_input})
        input_ids = build_prompt(history, system_prompt, tokenizer)

        # Trim history if prompt is getting close to context limit: drop the
        # oldest user+assistant pair until it fits or only the latest
        # exchange is left (generation then crops the context itself).
        budget = model.config.context_length - args.max_new_tokens - 10
        while input_ids.shape[1] > budget and len(history) > 2:
            history = history[2:]
            input_ids = build_prompt(history, system_prompt, tokenizer)

        # ---- Generate ---------------------------------------------- #
        print("SLLM: ", end="", flush=True)
        response = generate_response(
            model, input_ids, tokenizer,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            device=device,
        )
        print(response + "\n")

        history.append({"role": "assistant", "content": response})
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
main()
|
finetune/check_data.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune/check_data.py
|
| 3 |
+
|
| 4 |
+
Smoke-test: picks one row from OpenHermes-2.5, runs it through the
|
| 5 |
+
same format_and_tokenize() logic used by prepare_data.py, and prints
|
| 6 |
+
a full visual audit so you can confirm everything lines up.
|
| 7 |
+
|
| 8 |
+
Checks:
|
| 9 |
+
1. Raw conversation structure from the dataset
|
| 10 |
+
2. ChatML text that gets fed to the tokenizer
|
| 11 |
+
3. Token IDs and decoded tokens (side-by-side)
|
| 12 |
+
4. Label mask — ✓ (labeled) vs (masked -100) for every token
|
| 13 |
+
5. Label ratio (should be ~30-60% assistant tokens)
|
| 14 |
+
|
| 15 |
+
Run from project root:
|
| 16 |
+
python finetune/check_data.py
|
| 17 |
+
python finetune/check_data.py --row 3 # inspect a specific row index
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import sys
|
| 21 |
+
import argparse
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
# ------------------------------------------------------------------ #
|
| 25 |
+
# Paths
|
| 26 |
+
# ------------------------------------------------------------------ #
|
| 27 |
+
|
| 28 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 29 |
+
PROJECT_ROOT = SCRIPT_DIR.parent
|
| 30 |
+
TOKENIZER_DIR = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"
|
| 31 |
+
|
| 32 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 33 |
+
|
| 34 |
+
from transformers import PreTrainedTokenizerFast
|
| 35 |
+
from datasets import load_dataset
|
| 36 |
+
|
| 37 |
+
SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]
|
| 38 |
+
MAX_LENGTH = 1024
|
| 39 |
+
|
| 40 |
+
ROLE_MAP = {
|
| 41 |
+
"system": "system",
|
| 42 |
+
"human": "user",
|
| 43 |
+
"gpt": "assistant",
|
| 44 |
+
"user": "user",
|
| 45 |
+
"assistant": "assistant",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ------------------------------------------------------------------ #
|
| 50 |
+
# Replicated from prepare_data.py (no import to keep this self-contained)
|
| 51 |
+
# ------------------------------------------------------------------ #
|
| 52 |
+
|
| 53 |
+
def load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the base tokenizer and add any ChatML tokens it is missing.

    Returns:
        The tokenizer from ``TOKENIZER_DIR`` with every entry of
        ``SPECIAL_TOKENS`` present in its vocabulary.
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
    vocab = tokenizer.get_vocab()
    missing = [token for token in SPECIAL_TOKENS if token not in vocab]
    if missing:
        tokenizer.add_special_tokens({"additional_special_tokens": missing})
    return tokenizer
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def format_and_tokenize(conversations, tokenizer):
    """Turn a list of chat turns into ``(input_ids, labels)``.

    Identical logic to prepare_data.py. Each turn with non-empty role and
    content is rendered as ``<|im_start|>{role}\\n{content}<|im_end|>\\n``;
    header and body are tokenised separately. Labels are ``-100`` everywhere
    except on the body tokens of assistant turns. Both lists are cut to
    ``MAX_LENGTH``.

    Returns:
        ``(input_ids, labels)``, or ``None`` when no token is labelled or the
        result has fewer than 8 tokens.

    NOTE(review): the "any labelled token" check runs before truncation, so a
    sample whose assistant tokens all lie beyond ``MAX_LENGTH`` is still
    returned with every label masked. Kept as-is to mirror prepare_data.py.
    """
    ignore = -100
    all_ids, all_labels = [], []

    for turn in conversations:
        raw_role = turn.get("from", turn.get("role", "")).strip().lower()
        text = turn.get("value", turn.get("content", "")).strip()
        role = ROLE_MAP.get(raw_role, raw_role)

        if not (text and role):
            continue

        header_ids = tokenizer.encode(f"<|im_start|>{role}\n", add_special_tokens=False)
        body_ids = tokenizer.encode(f"{text}<|im_end|>\n", add_special_tokens=False)
        turn_ids = header_ids + body_ids

        # Only the assistant's body (response + <|im_end|>) carries labels.
        if role == "assistant":
            turn_labels = [ignore] * len(header_ids) + body_ids
        else:
            turn_labels = [ignore] * len(turn_ids)

        all_ids.extend(turn_ids)
        all_labels.extend(turn_labels)

    if all(label == ignore for label in all_labels):
        return None

    all_ids = all_ids[:MAX_LENGTH]
    all_labels = all_labels[:MAX_LENGTH]

    if len(all_ids) < 8:
        return None

    return all_ids, all_labels
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ------------------------------------------------------------------ #
|
| 102 |
+
# Pretty-print helpers
|
| 103 |
+
# ------------------------------------------------------------------ #
|
| 104 |
+
|
| 105 |
+
def print_section(title: str):
    """Print *title* framed by two 60-character horizontal rules."""
    rule = "─" * 60
    print(f"\n{rule}\n {title}\n{rule}")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def print_token_table(input_ids, labels, tokenizer, max_rows: int = 80):
    """Print a per-token table followed by a labelled/masked summary.

    Columns: index, decoded token (``repr``, cut to 22 chars), token id,
    label, and whether the token is learned:

        Green ✓ = labeled (assistant) — model learns this
        Red   ✗ = masked -100        — model ignores this

    Args:
        input_ids: Token ids of the sample.
        labels: Label per token; ``-100`` marks a masked position.
        tokenizer: Object whose ``decode([id])`` returns the token text.
        max_rows: Maximum number of rows to print. A trailing
            "... more tokens not shown" line appears only when rows were
            actually left out.

    The summary is safe for empty input (ratios are reported as 0%).
    """
    GREEN = "\033[92m"
    RED = "\033[91m"
    RESET = "\033[0m"

    print(f"\n {'IDX':>5} {'TOKEN':<22} {'ID':>6} {'LABEL':>8} {'LEARN?'}")
    print(f" {'─'*5} {'─'*22} {'─'*6} {'─'*8} {'─'*6}")

    limit = max(max_rows, 0)
    for i, (tok_id, lbl) in enumerate(zip(input_ids[:limit], labels)):
        tok_str = repr(tokenizer.decode([tok_id]))[:22]
        if lbl == -100:
            learn_str = f"{RED}✗ masked{RESET}"
        else:
            learn_str = f"{GREEN}✓ learn {RESET}"
        # Same width for masked and labelled rows keeps the column aligned.
        lbl_str = f"{lbl:>8}"
        print(f" {i:>5} {tok_str:<22} {tok_id:>6} {lbl_str} {learn_str}")

    remaining = len(input_ids) - min(len(input_ids), limit)
    if remaining > 0:
        print(f" ... ({remaining} more tokens not shown)")

    # Summary
    n_total = len(labels)
    n_labeled = sum(1 for l in labels if l != -100)
    n_masked = n_total - n_labeled
    labeled_frac = n_labeled / n_total if n_total else 0.0
    masked_frac = n_masked / n_total if n_total else 0.0
    print(f"\n Total tokens : {n_total}")
    print(f" Labeled : {n_labeled} ({labeled_frac:.1%}) ← assistant tokens")
    print(f" Masked : {n_masked} ({masked_frac:.1%}) ← user/system tokens")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ------------------------------------------------------------------ #
|
| 150 |
+
# MAIN
|
| 151 |
+
# ------------------------------------------------------------------ #
|
| 152 |
+
|
| 153 |
+
def parse_args():
    """Parse the command-line options of the data check script.

    Returns:
        Namespace with ``row`` (default 0) and ``n_fetch`` (default 20).
    """
    parser = argparse.ArgumentParser(
        description="Check one OpenHermes row through the SFT pipeline"
    )
    int_options = (
        ("--row", 0,
         "Which row to inspect in detail (0-indexed, from the first 20 fetched)"),
        ("--n_fetch", 20,
         "How many rows to fetch from HuggingFace (default: 20)"),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    return parser.parse_args()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def main():
    """Run one OpenHermes row through the SFT formatting and print an audit.

    Steps: load the tokenizer and confirm the ChatML tokens exist, load the
    requested row, show the raw conversation and its ChatML rendering, run
    ``format_and_tokenize``, check the label mask, print the token table and
    a decode round-trip.

    Changes from the earlier version:
      * ``--n_fetch`` is honoured: only the first
        ``max(n_fetch, row + 1)`` rows of the train split are loaded instead
        of the whole split.
      * The closing banner reports failure when an alignment check failed
        instead of always claiming success.
    """
    args = parse_args()
    problems = []  # human-readable descriptions of failed checks

    print("\n" + "=" * 60)
    print(" SFT Pipeline — Data Alignment Check")
    print("=" * 60)

    # ---- 1. Tokenizer ---------------------------------------------- #
    print_section("1. Tokenizer")
    tokenizer = load_tokenizer()
    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    print(f" Vocab size : {len(tokenizer):,}")
    print(f" <|im_start|> : token ID {im_start_id}")
    print(f" <|im_end|> : token ID {im_end_id}")
    assert im_start_id != tokenizer.unk_token_id, "ERROR: <|im_start|> not in vocab!"
    assert im_end_id != tokenizer.unk_token_id, "ERROR: <|im_end|> not in vocab!"
    print(" ✓ Special tokens present in vocab")

    # ---- 2. Load one row ------------------------------------------- #
    print_section(f"2. Loading row {args.row} from OpenHermes-2.5")
    # Load enough rows to contain the requested index, but never fewer than one.
    n_rows = max(args.n_fetch, args.row + 1, 1)
    print(f" Loading first {n_rows} rows from local cache (Arrow format)...")
    ds = load_dataset("teknium/OpenHermes-2.5", split=f"train[:{n_rows}]")
    row = ds[args.row]
    convs = row.get("conversations", [])

    print(f" Row index : {args.row}")
    print(f" Turns in conv : {len(convs)}")

    # ---- 3. Raw conversation --------------------------------------- #
    print_section("3. Raw conversation (from dataset)")
    for i, turn in enumerate(convs):
        role = turn.get("from", "?")
        content = turn.get("value", "").strip()
        preview = content[:120].replace("\n", "↵")
        print(f" [{i}] from={role!r:12s} | {preview!r}")

    # ---- 4. ChatML formatted text ---------------------------------- #
    print_section("4. ChatML text (what tokenizer sees)")
    pieces = []
    for turn in convs:
        role_raw = turn.get("from", "").strip().lower()
        content = turn.get("value", "").strip()
        role = ROLE_MAP.get(role_raw, role_raw)
        if content and role:
            pieces.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
    chatml = "".join(pieces)
    print(chatml[:800])
    if len(chatml) > 800:
        print(f" ... ({len(chatml) - 800} more chars)")

    # ---- 5. Run through format_and_tokenize ----------------------- #
    print_section("5. format_and_tokenize() output")
    result = format_and_tokenize(convs, tokenizer)

    if result is None:
        print(" ✗ RETURNED None — no assistant turn or too short.")
        print(" Try a different --row index.")
        return

    input_ids, labels = result
    print(f" input_ids length : {len(input_ids)}")
    print(f" labels length : {len(labels)}")
    assert len(input_ids) == len(labels), "MISMATCH: input_ids and labels have different lengths!"
    print(" ✓ Lengths match")

    # ---- 6. Verify label alignment --------------------------------- #
    print_section("6. Label alignment sanity checks")

    # Every im_start should be masked
    im_start_positions = [i for i, t in enumerate(input_ids) if t == im_start_id]
    im_end_positions = [i for i, t in enumerate(input_ids) if t == im_end_id]

    print(f" <|im_start|> positions : {im_start_positions}")
    print(f" <|im_end|> positions : {im_end_positions}")

    im_start_masked = all(labels[i] == -100 for i in im_start_positions)
    print(f" All <|im_start|> tokens are masked (-100) : {'✓' if im_start_masked else '✗ FAIL'}")
    if not im_start_masked:
        problems.append("<|im_start|> token(s) carry a label")

    # Decode the labeled span to confirm it's the assistant content
    labeled_ids = [t for t, l in zip(input_ids, labels) if l != -100]
    labeled_text = tokenizer.decode(labeled_ids, skip_special_tokens=False)
    print("\n Labeled (assistant) text preview:")
    print(f" {labeled_text[:300].replace(chr(10), '↵')!r}")

    # Check that labeled text doesn't contain user/system markers.
    # This is a substring heuristic: it also fires when the assistant's own
    # text happens to contain "user\n" or "system\n".
    if "user\n" in labeled_text or "system\n" in labeled_text:
        print(" ✗ WARNING: user/system content found in labeled tokens!")
        problems.append("possible user/system content in labeled tokens")
    else:
        print(" ✓ Labeled tokens contain only assistant content")

    # ---- 7. Token-by-token table ----------------------------------- #
    print_section("7. Token-by-token table (first 80 tokens)")
    print_token_table(input_ids, labels, tokenizer, max_rows=80)

    # ---- 8. Decode round-trip ------------------------------------- #
    print_section("8. Full decode round-trip (skip_special_tokens=False)")
    decoded = tokenizer.decode(input_ids, skip_special_tokens=False)
    print(decoded[:600])

    print("\n" + "=" * 60)
    if problems:
        print(f" CHECK COMPLETE — {len(problems)} issue(s) found ✗")
        for problem in problems:
            print(f" - {problem}")
    else:
        print(" CHECK COMPLETE — pipeline looks aligned ✓")
    print("=" * 60)
    if not problems:
        print("\nWhen ready, run the full data prep:")
        print(" python finetune/prepare_data.py")
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
if __name__ == "__main__":
|
| 269 |
+
main()
|
finetune/data/meta.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dataset": "teknium/OpenHermes-2.5",
|
| 3 |
+
"n_sampled": 80000,
|
| 4 |
+
"n_train": 76000,
|
| 5 |
+
"n_val": 4000,
|
| 6 |
+
"vocab_size": 32002,
|
| 7 |
+
"special_tokens": [
|
| 8 |
+
"<|im_start|>",
|
| 9 |
+
"<|im_end|>"
|
| 10 |
+
],
|
| 11 |
+
"max_length": 1024,
|
| 12 |
+
"seed": 42
|
| 13 |
+
}
|
finetune/data/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
finetune/data/tokenizer_config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<|endoftext|>",
|
| 4 |
+
"eos_token": "<|endoftext|>",
|
| 5 |
+
"extra_special_tokens": [
|
| 6 |
+
"<|im_start|>",
|
| 7 |
+
"<|im_end|>"
|
| 8 |
+
],
|
| 9 |
+
"is_local": true,
|
| 10 |
+
"local_files_only": false,
|
| 11 |
+
"model_max_length": 1024,
|
| 12 |
+
"pad_token": "<|endoftext|>",
|
| 13 |
+
"padding_side": "right",
|
| 14 |
+
"tokenizer_class": "TokenizersBackend",
|
| 15 |
+
"truncation_side": "right",
|
| 16 |
+
"unk_token": null
|
| 17 |
+
}
|
finetune/prepare_data.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune/prepare_data.py
|
| 3 |
+
|
| 4 |
+
Downloads teknium/OpenHermes-2.5 from HuggingFace, formats conversations
|
| 5 |
+
as ChatML, tokenizes with our custom tokenizer + 2 new special tokens,
|
| 6 |
+
and saves train_sft.pt / val_sft.pt to finetune/data/.
|
| 7 |
+
|
| 8 |
+
Also saves the tokenizer (with special tokens baked in) to finetune/data/
|
| 9 |
+
so sft_train.py and chat.py can load it without re-adding tokens.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python finetune/prepare_data.py
|
| 13 |
+
python finetune/prepare_data.py --n_samples 50000
|
| 14 |
+
|
| 15 |
+
Dataset structure (OpenHermes-2.5):
|
| 16 |
+
Each row has a "conversations" key:
|
| 17 |
+
[
|
| 18 |
+
{"from": "system", "value": "..."}, # optional
|
| 19 |
+
{"from": "human", "value": "..."},
|
| 20 |
+
{"from": "gpt", "value": "..."},
|
| 21 |
+
... # may have more turns
|
| 22 |
+
]
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
import json
|
| 28 |
+
import random
|
| 29 |
+
import argparse
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
from transformers import PreTrainedTokenizerFast
|
| 34 |
+
from datasets import load_dataset
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
|
| 37 |
+
# ------------------------------------------------------------------ #
|
| 38 |
+
# Paths (relative to project root, not this script)
|
| 39 |
+
# ------------------------------------------------------------------ #
|
| 40 |
+
|
| 41 |
+
SCRIPT_DIR = Path(__file__).resolve().parent
|
| 42 |
+
PROJECT_ROOT = SCRIPT_DIR.parent
|
| 43 |
+
|
| 44 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 45 |
+
|
| 46 |
+
TOKENIZER_DIR = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"
|
| 47 |
+
|
| 48 |
+
# The two new tokens that define ChatML structure
|
| 49 |
+
SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]
|
| 50 |
+
|
| 51 |
+
MAX_LENGTH = 1024 # model context_length — truncate anything longer
|
| 52 |
+
|
| 53 |
+
# Map OpenHermes role names → ChatML role names
|
| 54 |
+
ROLE_MAP = {
|
| 55 |
+
"system": "system",
|
| 56 |
+
"human": "user",
|
| 57 |
+
"gpt": "assistant",
|
| 58 |
+
"user": "user",
|
| 59 |
+
"assistant": "assistant",
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ------------------------------------------------------------------ #
|
| 64 |
+
# TOKENIZER
|
| 65 |
+
# ------------------------------------------------------------------ #
|
| 66 |
+
|
| 67 |
+
def load_and_extend_tokenizer() -> PreTrainedTokenizerFast:
    """Load the pretrained BPE tokenizer and add the ChatML special tokens.

    Tokens from ``SPECIAL_TOKENS`` that are already in the vocabulary are
    left alone; the rest are registered as additional special tokens. The
    outcome and the final vocabulary size are printed.

    Returns:
        The (possibly extended) tokenizer loaded from ``TOKENIZER_DIR``.
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))

    known = tokenizer.get_vocab()
    missing = [token for token in SPECIAL_TOKENS if token not in known]
    if not missing:
        print(" Special tokens already present — skipping add.")
    else:
        added = tokenizer.add_special_tokens({"additional_special_tokens": missing})
        print(f" Added {added} special token(s): {missing}")

    print(f" Final vocab size: {len(tokenizer):,}")
    return tokenizer
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ------------------------------------------------------------------ #
|
| 86 |
+
# FORMAT + TOKENIZE ONE CONVERSATION
|
| 87 |
+
# ------------------------------------------------------------------ #
|
| 88 |
+
|
| 89 |
+
def format_and_tokenize(
    conversations: list[dict],
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[list[int], list[int]] | None:
    """
    Converts a list of chat turns into (input_ids, labels).

    ChatML format per turn:
        <|im_start|>{role}\\n{content}<|im_end|>\\n

    Each turn is a dict read through the keys "from"/"role" (speaker) and
    "value"/"content" (text). Roles are normalised via ROLE_MAP; turns with
    an empty role or empty content are dropped.

    Labels:
        - User / system turns → all -100 (not learned)
        - Assistant turns → header (-100) + content (actual token ids)
          i.e. we learn the response but not the "<|im_start|>assistant\\n" prefix

    The sequence is truncated to MAX_LENGTH *before* it is validated, so an
    example whose assistant tokens all fall outside the context window is
    rejected instead of being emitted with nothing but -100 labels.

    Returns None for:
        - Conversations with no labeled (assistant) token inside the first
          MAX_LENGTH tokens (nothing to learn)
        - Conversations that tokenize to fewer than 8 tokens
    """
    input_ids: list[int] = []
    labels: list[int] = []

    for turn in conversations:
        # `or ""` guards against explicit None values in the raw records.
        role_raw = (turn.get("from", turn.get("role", "")) or "").strip().lower()
        content = (turn.get("value", turn.get("content", "")) or "").strip()
        role = ROLE_MAP.get(role_raw, role_raw)

        if not content or not role:
            continue

        # ---- header: <|im_start|>role\n — never labeled ----------- #
        header_ids = tokenizer.encode(f"<|im_start|>{role}\n", add_special_tokens=False)

        # ---- body: content<|im_end|>\n ------------------------------ #
        body_ids = tokenizer.encode(f"{content}<|im_end|>\n", add_special_tokens=False)

        input_ids.extend(header_ids)
        input_ids.extend(body_ids)

        labels.extend([-100] * len(header_ids))
        if role == "assistant":
            # Teach the model the body (response + im_end), not the header
            labels.extend(body_ids)
        else:
            # User / system: no learning signal
            labels.extend([-100] * len(body_ids))

        # Everything past the context window is discarded below, so there is
        # no point tokenizing further turns.
        if len(input_ids) >= MAX_LENGTH:
            break

    # Truncate to context window
    input_ids = input_ids[:MAX_LENGTH]
    labels = labels[:MAX_LENGTH]

    # Must have at least one labeled token *after truncation* to be a valid
    # training example; an all -100 example yields a NaN cross-entropy.
    if all(tok == -100 for tok in labels):
        return None

    # Skip micro-sequences (likely malformed)
    if len(input_ids) < 8:
        return None

    return input_ids, labels
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ------------------------------------------------------------------ #
|
| 155 |
+
# ARG PARSING
|
| 156 |
+
# ------------------------------------------------------------------ #
|
| 157 |
+
|
| 158 |
+
def parse_args():
    """Build the command-line interface for data preparation and parse sys.argv."""
    parser = argparse.ArgumentParser(description="Prepare SFT data from OpenHermes-2.5")
    add = parser.add_argument

    add(
        "--n_samples",
        type=int,
        default=80_000,
        help="Number of conversations to sample (default: 80000)",
    )
    add(
        "--val_ratio",
        type=float,
        default=0.05,
        help="Fraction held out for validation (default: 0.05)",
    )
    add(
        "--output_dir",
        type=str,
        default=str(SCRIPT_DIR / "data"),
        help="Where to save train_sft.pt, val_sft.pt, and tokenizer",
    )
    add("--seed", type=int, default=42)

    return parser.parse_args()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ------------------------------------------------------------------ #
|
| 171 |
+
# MAIN
|
| 172 |
+
# ------------------------------------------------------------------ #
|
| 173 |
+
|
| 174 |
+
def main():
    """
    Run the full SFT data-preparation pipeline.

    Steps:
        1. Load the tokenizer, add the ChatML tokens, save it to --output_dir.
        2. Load teknium/OpenHermes-2.5 and draw a seeded random sample.
        3. Format + tokenize every sampled conversation, dropping rejects.
        4. Shuffle, split into train/val, save both shards plus meta.json.

    Raises:
        RuntimeError: if no conversation survives tokenization.
    """
    args = parse_args()
    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print("\n" + "=" * 60)
    print(" SLLM-150M SFT — Data Preparation")
    print("=" * 60)

    # ---------------------------------------------------------------- #
    # 1. Tokenizer
    # ---------------------------------------------------------------- #
    print("\n[1/4] Loading tokenizer + adding ChatML special tokens...")
    tokenizer = load_and_extend_tokenizer()

    # Save the extended tokenizer to data dir so training/chat can load it
    tokenizer.save_pretrained(args.output_dir)
    print(f" Extended tokenizer saved → {args.output_dir}/")

    # ---------------------------------------------------------------- #
    # 2. Dataset download
    # ---------------------------------------------------------------- #
    print("\n[2/4] Loading teknium/OpenHermes-2.5 from HuggingFace...")
    ds = load_dataset("teknium/OpenHermes-2.5")
    full = ds["train"]  # only split in this dataset
    print(f" Full dataset size: {len(full):,} examples")

    # Sample a subset
    n = min(args.n_samples, len(full))
    indices = random.sample(range(len(full)), n)
    subset = full.select(indices)
    print(f" Sampled: {n:,} examples (seed={args.seed})")

    # ---------------------------------------------------------------- #
    # 3. Tokenize
    # ---------------------------------------------------------------- #
    print("\n[3/4] Formatting and tokenizing conversations...")

    all_input_ids: list[torch.Tensor] = []
    all_labels: list[torch.Tensor] = []
    skipped = 0

    for example in tqdm(subset, desc="Tokenizing", unit="conv"):
        conversations = example.get("conversations", [])
        result = format_and_tokenize(conversations, tokenizer)

        if result is None:
            skipped += 1
            continue

        ids, lbls = result
        all_input_ids.append(torch.tensor(ids, dtype=torch.long))
        all_labels.append(torch.tensor(lbls, dtype=torch.long))

    total = len(all_input_ids)
    print(f"\n Kept : {total:,}")
    print(f" Skipped: {skipped:,} (no assistant turn or too short)")

    if total == 0:
        raise RuntimeError("No valid examples produced — check dataset structure.")

    # Print a sample so we can visually verify
    print("\n ── Sample (first conversation, first 400 chars) ──")
    sample_decoded = tokenizer.decode(all_input_ids[0].tolist(), skip_special_tokens=False)
    print(" " + sample_decoded[:400].replace("\n", "\n "))
    print()

    # ---------------------------------------------------------------- #
    # 4. Split + save
    # ---------------------------------------------------------------- #
    print("[4/4] Splitting and saving...")

    perm = list(range(total))
    random.shuffle(perm)
    val_n = max(1, int(total * args.val_ratio))
    # Never let validation swallow the whole set: with a tiny corpus or a
    # val_ratio >= 1 the training shard would otherwise come out empty.
    if total > 1:
        val_n = min(val_n, total - 1)
    train_n = total - val_n

    train_idx, val_idx = perm[:train_n], perm[train_n:]
    train_ids = [all_input_ids[i] for i in train_idx]
    train_lbl = [all_labels[i] for i in train_idx]
    val_ids = [all_input_ids[i] for i in val_idx]
    val_lbl = [all_labels[i] for i in val_idx]

    train_path = os.path.join(args.output_dir, "train_sft.pt")
    val_path = os.path.join(args.output_dir, "val_sft.pt")

    torch.save({"input_ids": train_ids, "labels": train_lbl}, train_path)
    torch.save({"input_ids": val_ids, "labels": val_lbl}, val_path)

    # Stats (computed over train + val together)
    lengths = [len(x) for x in all_input_ids]
    label_ratios = [(t != -100).float().mean().item() for t in all_labels]
    avg_len = sum(lengths) / len(lengths)
    avg_lbl_ratio = sum(label_ratios) / len(label_ratios)

    print(f"\n train_sft.pt : {train_n:,} examples")
    print(f" val_sft.pt : {val_n:,} examples")
    print(f"\n Avg seq length : {avg_len:.0f} tokens (max={max(lengths)})")
    print(f" Avg assistant ratio : {avg_lbl_ratio:.1%} of tokens are labeled")

    # Save metadata for reference
    meta = {
        "dataset": "teknium/OpenHermes-2.5",
        "n_sampled": n,
        "n_train": train_n,
        "n_val": val_n,
        "vocab_size": len(tokenizer),
        "special_tokens": SPECIAL_TOKENS,
        "max_length": MAX_LENGTH,
        "seed": args.seed,
    }
    with open(os.path.join(args.output_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f"\n meta.json saved → {args.output_dir}/meta.json")

    print("\n" + "=" * 60)
    print(" Data preparation complete!")
    print("=" * 60)
    print(f"""
Next step:
python finetune/sft_train.py \\
--base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
--run_dir runs/sllm_150m_chat \\
--max_steps 2000 \\
--batch_size 4 --grad_accum 8 \\
--grad_checkpoint
""")
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
finetune/sft_dataset.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune/sft_dataset.py
|
| 3 |
+
|
| 4 |
+
SFT Dataset — loads pre-tokenized ChatML sequences from .pt shards
|
| 5 |
+
produced by prepare_data.py.
|
| 6 |
+
|
| 7 |
+
Each item returns (input_ids, labels) where labels has -100 for all
|
| 8 |
+
non-assistant tokens so CrossEntropy only trains on assistant responses.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from functools import partial
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch.utils.data import Dataset, DataLoader
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SFTDataset(Dataset):
    """
    Dataset for Supervised Fine-Tuning.

    Loads a .pt shard containing:
        {
            "input_ids": list of LongTensors (variable length),
            "labels":    list of LongTensors (same shapes, -100 for masked)
        }

    Each __getitem__ returns:
        input_ids : (seq_len,) LongTensor
        labels    : (seq_len,) LongTensor — -100 for user/system tokens

    Both are truncated to at most ``context_length`` tokens.
    """

    def __init__(self, data_path: str, context_length: int = 1024):
        """
        Args:
            data_path: path to a shard in the format described on the class.
            context_length: maximum number of tokens returned per example.

        Raises:
            ValueError: if the shard holds a different number of input_ids
                and labels sequences.
        """
        # The shard is just a dict of lists of tensors, so the restricted
        # loader is sufficient and avoids running the full unpickler on
        # file contents.
        data = torch.load(data_path, weights_only=True)
        self.input_ids = data["input_ids"]
        self.labels = data["labels"]
        self.context_length = context_length

        # Explicit raise rather than assert: asserts vanish under `python -O`.
        if len(self.input_ids) != len(self.labels):
            raise ValueError("input_ids / labels length mismatch")
        print(f"[SFTDataset] Loaded {len(self.input_ids):,} examples from {data_path}")

    def __len__(self) -> int:
        """Number of examples in the shard."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the (input_ids, labels) pair at ``idx``, cut to context_length."""
        ids = self.input_ids[idx]
        lbl = self.labels[idx]

        # Hard-truncate to model context length
        if len(ids) > self.context_length:
            ids = ids[: self.context_length]
            lbl = lbl[: self.context_length]

        return ids, lbl
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ------------------------------------------------------------------ #
|
| 57 |
+
# COLLATE
|
| 58 |
+
# ------------------------------------------------------------------ #
|
| 59 |
+
|
| 60 |
+
def sft_collate_fn(batch, pad_token_id: int):
    """
    Right-pad a batch of (input_ids, labels) pairs to a common length.

    The target width is the longest input_ids sequence in the batch.
        input_ids → filled with pad_token_id past each sequence's end
        labels    → filled with -100 past each sequence's end
                    (ignored by CrossEntropy)

    Returns:
        (input_ids_padded, labels_padded), both LongTensors of shape
        (len(batch), max_len).
    """
    sequences, targets = zip(*batch)

    n_rows = len(batch)
    width = max(seq.size(0) for seq in sequences)

    # Start from fully padded tensors and copy the real tokens over the left edge.
    input_ids_padded = torch.full((n_rows, width), pad_token_id, dtype=torch.long)
    labels_padded = torch.full((n_rows, width), -100, dtype=torch.long)

    for row, seq in enumerate(sequences):
        length = seq.size(0)
        input_ids_padded[row, :length] = seq
        labels_padded[row, :length] = targets[row]

    return input_ids_padded, labels_padded
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ------------------------------------------------------------------ #
|
| 82 |
+
# FACTORY
|
| 83 |
+
# ------------------------------------------------------------------ #
|
| 84 |
+
|
| 85 |
+
def build_sft_dataloader(
    data_path: str,
    batch_size: int,
    pad_token_id: int,
    context_length: int = 1024,
    num_workers: int = 0,
    shuffle: bool = True,
) -> DataLoader:
    """
    Build a DataLoader over one pre-tokenized SFT shard.

    Args:
        data_path: path to the .pt shard read by SFTDataset.
        batch_size: number of examples per batch.
        pad_token_id: token id used to right-pad input_ids inside a batch.
        context_length: per-example truncation length passed to SFTDataset.
        num_workers: DataLoader worker processes.
        shuffle: whether to reshuffle the examples every epoch.

    Returns:
        A DataLoader yielding (input_ids, labels) batches padded by
        sft_collate_fn.
    """
    dataset = SFTDataset(data_path, context_length=context_length)
    collate_fn = partial(sft_collate_fn, pad_token_id=pad_token_id)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # Pinned host memory only helps host→GPU copies; asking for it on a
        # CPU-only machine just produces a warning.
        pin_memory=torch.cuda.is_available(),
    )
|
finetune/sft_train.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune/sft_train.py
|
| 3 |
+
|
| 4 |
+
Full Supervised Fine-Tuning (SFT) of SLLM-150M → Chat Model.
|
| 5 |
+
|
| 6 |
+
Starts from the pretrained base checkpoint, resizes the token embedding
|
| 7 |
+
for 2 new ChatML special tokens, then trains with masked CrossEntropy
|
| 8 |
+
so only assistant response tokens contribute to the loss.
|
| 9 |
+
|
| 10 |
+
Usage (first run):
|
| 11 |
+
python finetune/sft_train.py \\
|
| 12 |
+
--base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
|
| 13 |
+
--run_dir runs/sllm_150m_chat \\
|
| 14 |
+
--max_steps 2000 \\
|
| 15 |
+
--batch_size 4 --grad_accum 8 \\
|
| 16 |
+
--grad_checkpoint
|
| 17 |
+
|
| 18 |
+
Resume:
|
| 19 |
+
python finetune/sft_train.py \\
|
| 20 |
+
--resume --run_dir runs/sllm_150m_chat \\
|
| 21 |
+
--extra_steps 1000
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import sys
|
| 26 |
+
import json
|
| 27 |
+
import math
|
| 28 |
+
import time
|
| 29 |
+
import signal
|
| 30 |
+
import argparse
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
from torch.amp import autocast, GradScaler
|
| 37 |
+
from transformers import PreTrainedTokenizerFast
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
|
| 40 |
+
# ------------------------------------------------------------------ #
|
| 41 |
+
# Resolve project root so model/ is importable
|
| 42 |
+
# ------------------------------------------------------------------ #
|
| 43 |
+
|
| 44 |
+
# Absolute directory that contains this script.
SCRIPT_DIR = Path(__file__).resolve().parent
# Parent of the script directory, treated as the project root.
PROJECT_ROOT = SCRIPT_DIR.parent
# Default value for --data_dir.
DATA_DIR = SCRIPT_DIR / "data"

# Both entries are inserted at index 0, so SCRIPT_DIR ends up ahead of
# PROJECT_ROOT on the import path.
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(SCRIPT_DIR))  # so we can import sft_dataset

# These imports rely on the sys.path edits above and must stay below them.
from model.config import SLLM_150M
from model.model import SLLM
from sft_dataset import build_sft_dataloader
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ------------------------------------------------------------------ #
|
| 57 |
+
# ARG PARSING
|
| 58 |
+
# ------------------------------------------------------------------ #
|
| 59 |
+
|
| 60 |
+
def parse_args():
    """Declare every SFT training option and parse them from sys.argv."""
    parser = argparse.ArgumentParser(description="SLLM-150M SFT Training")
    add = parser.add_argument

    # Checkpoints
    add(
        "--base_ckpt",
        type=str,
        default=str(PROJECT_ROOT / "runs" / "sllm_150m" / "ckpt_0011500.pt"),
        help="Path to pretrained base checkpoint (.pt)",
    )
    add(
        "--run_dir",
        type=str,
        default="runs/sllm_150m_chat",
        help="Output directory for SFT checkpoints and logs",
    )
    add(
        "--resume",
        action="store_true",
        help="Resume from latest SFT checkpoint in --run_dir",
    )
    add(
        "--max_steps",
        type=int,
        default=2000,
        help="Absolute step target for this run",
    )
    add(
        "--extra_steps",
        type=int,
        default=None,
        help="Run N more steps from current checkpoint (relative)",
    )

    # Data
    add(
        "--data_dir",
        type=str,
        default=str(DATA_DIR),
        help="Directory with train_sft.pt, val_sft.pt, and tokenizer files",
    )
    add("--num_workers", type=int, default=0)

    # Optimisation — note: much lower LR than pretraining
    add("--batch_size", type=int, default=4)
    add("--grad_accum", type=int, default=8)
    add(
        "--max_lr",
        type=float,
        default=1e-5,
        help="Peak LR (10x lower than pretraining)",
    )
    add("--min_lr", type=float, default=1e-6)
    add("--warmup_steps", type=int, default=30)
    add("--weight_decay", type=float, default=0.1)
    add("--grad_clip", type=float, default=1.0)
    add(
        "--dropout",
        type=float,
        default=0.1,
        help="Dropout rate during SFT (0.0 in pretraining)",
    )

    # Memory
    add(
        "--grad_checkpoint",
        action="store_true",
        help="Enable gradient checkpointing (saves VRAM)",
    )
    add("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])

    # Logging
    add("--log_every", type=int, default=10)
    add("--save_every", type=int, default=500)
    add("--val_every", type=int, default=250)
    add("--val_steps", type=int, default=20)

    return parser.parse_args()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ------------------------------------------------------------------ #
|
| 109 |
+
# VOCAB RESIZE
|
| 110 |
+
# ------------------------------------------------------------------ #
|
| 111 |
+
|
| 112 |
+
def resize_token_embeddings(model: "SLLM", new_vocab_size: int):
    """
    Grows model.token_emb from old_vocab_size → new_vocab_size.

    New rows are initialised to the mean of existing embeddings so
    training starts from a stable point rather than random noise.
    The mean is accumulated in float32 and cast back, so low-precision
    (fp16 / bf16) weights do not degrade it.

    After the swap, model.lm_head.weight is re-pointed at the new embedding
    matrix (weight tying), lm_head.out_features is updated when the head
    exposes that attribute, and model.config.vocab_size is set to the new
    size. A call with the current size is a no-op.

    Args:
        model: object exposing ``config.vocab_size``, ``config.d_model``,
            ``token_emb`` (an embedding module) and ``lm_head``.
        new_vocab_size: target number of embedding rows.

    Raises:
        ValueError: if new_vocab_size is smaller than the current size.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return
    if new_vocab_size < old_size:
        raise ValueError(f"Cannot shrink vocab ({old_size} → {new_vocab_size})")

    d_model = model.config.d_model
    old_weight = model.token_emb.weight.data  # (old_size, d)
    mean_vec = old_weight.float().mean(dim=0).to(old_weight.dtype)  # (d,)

    # new_empty keeps the dtype and device of the existing matrix; every row
    # is written below, so no initial fill is needed.
    new_weight = old_weight.new_empty((new_vocab_size, d_model))
    new_weight[:old_size] = old_weight
    new_weight[old_size:] = mean_vec  # broadcast over the added rows

    # Build the module directly from the prepared matrix (no throw-away
    # random initialisation); freeze=False keeps the weights trainable.
    model.token_emb = nn.Embedding.from_pretrained(new_weight, freeze=False)

    # Re-tie the LM head to the (now larger) embedding
    model.lm_head.weight = model.token_emb.weight
    if hasattr(model.lm_head, "out_features"):
        # Keep the head's declared output size in step with the tied weight.
        model.lm_head.out_features = new_vocab_size

    # Keep config consistent
    model.config.vocab_size = new_vocab_size

    n_new = new_vocab_size - old_size
    print(f" Vocab resized: {old_size:,} → {new_vocab_size:,} (+{n_new} tokens, init=mean)")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ------------------------------------------------------------------ #
|
| 153 |
+
# DROPOUT
|
| 154 |
+
# ------------------------------------------------------------------ #
|
| 155 |
+
|
| 156 |
+
def set_dropout(model: SLLM, rate: float):
    """Set the drop probability of every nn.Dropout module in the model to ``rate``."""
    dropout_layers = [mod for mod in model.modules() if isinstance(mod, nn.Dropout)]
    for layer in dropout_layers:
        layer.p = rate

    # Stay silent when the model contains no nn.Dropout modules at all.
    if dropout_layers:
        print(f" Dropout set to {rate} on {len(dropout_layers)} layer(s)")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ------------------------------------------------------------------ #
|
| 168 |
+
# LR SCHEDULE (cosine with linear warmup, same shape as train.py)
|
| 169 |
+
# ------------------------------------------------------------------ #
|
| 170 |
+
|
| 171 |
+
def get_lr(step: int, warmup_steps: int, total_steps: int,
           max_lr: float, min_lr: float) -> float:
    """
    Learning rate for ``step``: linear warmup followed by cosine decay.

    - step < warmup_steps: ramps linearly, reaching max_lr on the last
      warmup step.
    - warmup_steps <= step < horizon: cosine curve from max_lr down to min_lr.
    - step >= horizon: constant min_lr.

    The horizon is total_steps, or 5,000 when total_steps is falsy (0 / None).
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps

    horizon = total_steps or 5_000
    if step >= horizon:
        return min_lr

    # max(1, …) keeps the division safe when warmup spans the whole horizon.
    span = max(1, horizon - warmup_steps)
    progress = (step - warmup_steps) / span
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine * (max_lr - min_lr)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ------------------------------------------------------------------ #
|
| 184 |
+
# OPTIMIZER (mirrors train.py — AdamW selective decay)
|
| 185 |
+
# ------------------------------------------------------------------ #
|
| 186 |
+
|
| 187 |
+
def build_optimizer(model: SLLM, lr: float, weight_decay: float):
    """
    Create an AdamW optimizer with weight decay applied selectively.

    Trainable parameters with two or more dimensions receive ``weight_decay``;
    trainable parameters with fewer dimensions receive none. Parameters with
    requires_grad=False are left out entirely.
    """
    trainable = [param for _, param in model.named_parameters() if param.requires_grad]
    decay = [param for param in trainable if param.dim() >= 2]
    no_decay = [param for param in trainable if param.dim() < 2]

    n_d = sum(param.numel() for param in decay)
    n_nd = sum(param.numel() for param in no_decay)
    print(f" Optimizer: {n_d/1e6:.1f}M decay | {n_nd/1e6:.1f}M no-decay | lr={lr:.2e}")

    param_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

    # Note: no fused=True here — new embedding rows need correct grad flow
    return torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ------------------------------------------------------------------ #
|
| 210 |
+
# CHECKPOINT SAVE / LOAD
|
| 211 |
+
# ------------------------------------------------------------------ #
|
| 212 |
+
|
| 213 |
+
def save_checkpoint(path: str, model: SLLM, optimizer, step: int,
                    loss: float, vocab_size: int):
    """
    Write a training checkpoint to ``path``, creating parent directories.

    The file holds the step, the model and optimizer state dicts, the loss
    and the vocab size. It is first written to ``<path>.tmp`` and then moved
    into place with os.replace, so an interrupted save never leaves a
    truncated file under the final name.

    Args:
        path: destination file.
        model: model whose state_dict() is saved.
        optimizer: optimizer whose state_dict() is saved.
        step: training step stored in the checkpoint.
        loss: loss value stored in the checkpoint.
        vocab_size: vocabulary size stored in the checkpoint.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    payload = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "vocab_size": vocab_size,
    }

    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present when the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"\n [CKPT] Saved: {path} (step={step}, loss={loss:.4f})")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def load_sft_checkpoint(run_dir: str, model: SLLM, optimizer, device):
    """
    Restore model and optimizer state from the newest SFT checkpoint.

    "Newest" is the file named ckpt_sft_*.pt in ``run_dir`` whose name sorts
    last lexicographically.

    Returns:
        (step, vocab_size) read from the checkpoint; vocab_size falls back to
        model.config.vocab_size when the checkpoint does not store it.

    Raises:
        FileNotFoundError: if run_dir holds no matching checkpoint.
    """
    candidates = [
        name
        for name in os.listdir(run_dir)
        if name.startswith("ckpt_sft_") and name.endswith(".pt")
    ]
    if not candidates:
        raise FileNotFoundError(f"No SFT checkpoints found in {run_dir}")

    path = os.path.join(run_dir, max(candidates))
    state = torch.load(path, map_location=device, weights_only=False)

    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])

    step = state["step"]
    vocab_size = state.get("vocab_size", model.config.vocab_size)
    loss = state.get("loss", float("nan"))

    print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
    return step, vocab_size
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ------------------------------------------------------------------ #
|
| 247 |
+
# VALIDATION (uses ignore_index=-100 like training)
|
| 248 |
+
# ------------------------------------------------------------------ #
|
| 249 |
+
|
| 250 |
+
@torch.no_grad()
def estimate_val_loss(model: SLLM, val_loader, val_steps: int,
                      device, dtype_ctx) -> float:
    """
    Average the masked next-token loss over at most ``val_steps`` batches.

    The model is switched to eval mode for the measurement and back to train
    mode before returning. Each batch contributes one cross-entropy value
    (label positions equal to -100 are ignored); the result is the plain mean
    of those per-batch values.

    Args:
        model: called as ``model(x)``; the first element of its return value
            is used as the logits.
        val_loader: iterable yielding (input_ids, labels) batches.
        val_steps: maximum number of batches to evaluate.
        device: device the batches are moved to.
        dtype_ctx: context manager wrapped around the forward pass.

    Returns:
        Mean validation loss, or NaN if the loader yielded no batch.
    """
    model.eval()
    losses = []
    for i, (x, y) in enumerate(val_loader):
        if i >= val_steps:
            break
        x, y = x.to(device), y.to(device)
        # NOTE(review): the loss is computed inside dtype_ctx here — confirm
        # this matches the intended autocast scope.
        with dtype_ctx:
            logits, _ = model(x)
            # Shift logits and labels by 1 to predict the next token
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = y[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        losses.append(loss.item())
    # Unconditionally returns the model to train mode, whatever mode it was in before.
    model.train()
    return sum(losses) / len(losses) if losses else float("nan")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# ------------------------------------------------------------------ #
|
| 275 |
+
# METRIC LOGGER
|
| 276 |
+
# ------------------------------------------------------------------ #
|
| 277 |
+
|
| 278 |
+
class MetricLogger:
    """Append-only metric log: one JSON object per line."""

    def __init__(self, log_path: str):
        """Remember ``log_path`` and make sure its parent directory exists."""
        self.log_path = log_path
        parent_dir = os.path.dirname(os.path.abspath(log_path))
        os.makedirs(parent_dir, exist_ok=True)
        print(f" [LOG] Logging to: {log_path}")

    def log(self, **kwargs):
        """Append the keyword arguments to the log file as a single JSON line."""
        with open(self.log_path, "a") as handle:
            line = json.dumps(kwargs)
            handle.write(line + "\n")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# ------------------------------------------------------------------ #
|
| 290 |
+
# MAIN TRAINING LOOP
|
| 291 |
+
# ------------------------------------------------------------------ #
|
| 292 |
+
|
| 293 |
+
def train():
    """
    Run supervised fine-tuning (SFT) of the SLLM-150M model.

    Steps, in order:
      1. Parse CLI args, pick device and autocast dtype (bf16 / fp16 / fp32).
      2. Load the tokenizer (from the data dir if present, otherwise the base
         tokenizer plus the two chat special tokens).
      3. Build the model: load base weights, resize the embedding to the
         tokenizer size, apply SFT dropout, optionally enable gradient
         checkpointing.
      4. Build the optimizer and, with --resume, restore an SFT checkpoint.
      5. Train with gradient accumulation, periodic logging, validation and
         checkpointing; Ctrl+C requests a clean stop after the current step.
      6. Save a final checkpoint if at least one step was taken.

    Takes no arguments and returns None; all configuration comes from
    `parse_args()`.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'='*60}")
    print(f" SLLM-150M → Chat Model (SFT)")
    print(f"{'='*60}")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    # ---- dtype ----------------------------------------------------- #
    # Reduced precision is only used on CUDA; everything else runs in fp32.
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch, dtype_name = torch.bfloat16, "bf16"
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch, dtype_name = torch.float16, "fp16"
    else:
        dtype_torch, dtype_name = torch.float32, "fp32"

    print(f"dtype : {dtype_name}")
    use_amp = dtype_torch in (torch.float16, torch.bfloat16)
    # Loss scaling is only enabled for fp16
    scaler = GradScaler(enabled=(dtype_torch == torch.float16))

    # ---- Tokenizer ------------------------------------------------- #
    print("\n[1/5] Loading tokenizer...")
    tok_path = args.data_dir
    if os.path.exists(os.path.join(tok_path, "tokenizer.json")):
        # Prefer the saved tokenizer from prepare_data.py (has special tokens)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        print(f" Loaded from data dir: {tok_path}")
    else:
        # Fallback: load base tokenizer and add special tokens manually
        base_tok_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(base_tok_dir)
        tokenizer.add_special_tokens({"additional_special_tokens":
                                      ["<|im_start|>", "<|im_end|>"]})
        print(f" Loaded base tokenizer + added special tokens")

    new_vocab_size = len(tokenizer)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None \
        else tokenizer.eos_token_id
    print(f" Vocab size : {new_vocab_size:,}")
    print(f" Pad token : {pad_id}")

    # ---- Model ----------------------------------------------------- #
    print("\n[2/5] Loading model...")
    cfg = SLLM_150M

    def _build_model(load_base: bool):
        """Create the model, optionally load base weights, then prepare it for SFT."""
        net = SLLM(cfg).to(device)

        if load_base:
            # Base weights must be loaded BEFORE the embedding is resized,
            # while the parameter shapes still match the base checkpoint.
            print(f" Loading base checkpoint: {args.base_ckpt}")
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only point --base_ckpt at checkpoints you trust.
            base_ckpt = torch.load(args.base_ckpt, map_location=device, weights_only=False)
            net.load_state_dict(base_ckpt["model_state_dict"])
            base_step = base_ckpt.get("step", "?")
            base_loss = base_ckpt.get("loss", float("nan"))
            print(f" Base model step={base_step} loss={base_loss:.4f}")
            del base_ckpt

        # Grow embedding for the 2 new special tokens
        resize_token_embeddings(net, new_vocab_size)

        # Apply SFT dropout (was 0.0 in pretraining)
        set_dropout(net, args.dropout)

        if args.grad_checkpoint:
            net.enable_gradient_checkpointing()
            print(" Gradient checkpointing: ON")
        return net

    model = _build_model(load_base=not args.resume)
    print(f" Model params: {model.count_params()/1e6:.1f}M")

    # ---- Optimizer ------------------------------------------------- #
    print("\n[3/5] Building optimizer...")
    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    # ---- Resume from SFT checkpoint -------------------------------- #
    start_step = 0
    if args.resume:
        try:
            start_step, _ = load_sft_checkpoint(args.run_dir, model, optimizer, device)
        except FileNotFoundError as e:
            print(f" [WARN] {e} — starting SFT from base checkpoint.")
            # Make the message true: with --resume the base weights were
            # skipped above, so without this the run would silently train a
            # randomly initialized model. Rebuild from the base checkpoint and
            # recreate the optimizer so it tracks the new parameters.
            del model, optimizer
            model = _build_model(load_base=True)
            optimizer = build_optimizer(model, lr=args.max_lr,
                                        weight_decay=args.weight_decay)

    # Resolve --extra_steps → --max_steps
    if args.extra_steps is not None:
        args.max_steps = start_step + args.extra_steps
        print(f" --extra_steps {args.extra_steps} → max_steps={args.max_steps}")

    if args.max_steps is not None and start_step >= args.max_steps:
        print(f"\n [WARN] Already at step {start_step} >= max_steps {args.max_steps}.")
        print(f" Use --extra_steps N to run N more steps.")
        return

    # ---- Data ------------------------------------------------------ #
    print("\n[4/5] Loading SFT dataset...")
    train_path = os.path.join(args.data_dir, "train_sft.pt")
    val_path = os.path.join(args.data_dir, "val_sft.pt")

    train_loader = build_sft_dataloader(
        data_path=train_path, batch_size=args.batch_size,
        pad_token_id=pad_id, context_length=cfg.context_length,
        num_workers=args.num_workers, shuffle=True,
    )
    val_loader = build_sft_dataloader(
        data_path=val_path, batch_size=args.batch_size,
        pad_token_id=pad_id, context_length=cfg.context_length,
        num_workers=0, shuffle=False,
    )

    # ---- Run dir + logger ------------------------------------------ #
    os.makedirs(args.run_dir, exist_ok=True)
    log_path = os.path.join(args.run_dir, "sft_log.jsonl")
    logger = MetricLogger(log_path)

    # ---- Training info --------------------------------------------- #
    eff_batch = args.batch_size * args.grad_accum
    print(f"\n[5/5] Training config:")
    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} → eff={eff_batch})")
    print(f" max_steps : {args.max_steps}")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else '∞'}")
    print(f" max_lr / min_lr: {args.max_lr:.2e} / {args.min_lr:.2e}")
    print(f" warmup_steps : {args.warmup_steps}")
    print(f" save_every : {args.save_every}")
    print(f" val_every : {args.val_every}")

    # ---- Ctrl+C handler -------------------------------------------- #
    # The handler only raises a flag; the loop checks it between optimizer
    # steps so the final checkpoint below is still written.
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C — will save and exit after this step.")
        stop_flag["stop"] = True

    signal.signal(signal.SIGINT, _signal_handler)

    # ================================================================ #
    # TRAINING LOOP
    # ================================================================ #
    model.train()
    step = start_step
    running_loss = 0.0
    t_start = time.time()
    data_iter = iter(train_loader)

    print(f"\n{'='*60}")
    print(f" SFT STARTED (step {step} → {args.max_steps})")
    print(f"{'='*60}\n")

    pbar = tqdm(
        initial=step, total=args.max_steps,
        desc="SFT", unit="step", dynamic_ncols=True,
    )

    while True:
        # ---- Stop conditions --------------------------------------- #
        if stop_flag["stop"]:
            break
        if args.max_steps is not None and step >= args.max_steps:
            print(f"\n [DONE] Reached max_steps={args.max_steps}")
            break

        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0

        # ---- Gradient accumulation micro-steps --------------------- #
        for _ in range(args.grad_accum):
            try:
                x, y = next(data_iter)
            except StopIteration:
                # Loader exhausted: start the next pass over the data
                data_iter = iter(train_loader)
                x, y = next(data_iter)

            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

            with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
                logits, _ = model(x)  # (B, T, V) — don't use built-in loss
                # Shift logits and labels by 1 to predict the next token
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = y[..., 1:].contiguous()
                # Use ignore_index=-100 so only assistant tokens drive the loss
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                ) / args.grad_accum  # scale for accumulation

            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # ---- Grad clip --------------------------------------------- #
        if args.grad_clip > 0:
            # Gradients must be unscaled before their norm is measured/clipped
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        else:
            grad_norm = float("nan")

        # ---- LR ---------------------------------------------------- #
        lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # ---- Optimizer step ---------------------------------------- #
        scaler.step(optimizer)
        scaler.update()

        step += 1
        running_loss = accum_loss
        t_now = time.time()

        pbar.update(1)
        pbar.set_postfix({"loss": f"{running_loss:.4f}", "lr": f"{lr:.1e}"})

        # ---- Logging ----------------------------------------------- #
        if step % args.log_every == 0:
            grad_norm_value = float(grad_norm)
            entry = {
                "step": step,
                "loss": round(running_loss, 6),
                "lr": lr,
                # NaN (clipping disabled) and inf (e.g. an fp16 overflow
                # step) are logged as null: json.dumps would otherwise emit
                # the non-standard tokens NaN / Infinity.
                "grad_norm": round(grad_norm_value, 4)
                if math.isfinite(grad_norm_value) else None,
                "elapsed_s": round(t_now - t_start, 1),
            }
            if device.type == "cuda":
                entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
            logger.log(**entry)

        # ---- Validation -------------------------------------------- #
        if step % args.val_every == 0:
            v_ctx = autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp)
            val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, v_ctx)
            tqdm.write(
                f" [STEP {step:5d}] train={running_loss:.4f} "
                f"val={val_loss:.4f} lr={lr:.1e}"
            )
            logger.log(step=step, val_loss=round(val_loss, 6))

        # ---- Checkpoint -------------------------------------------- #
        if step % args.save_every == 0:
            ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
            save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)

    # ================================================================ #
    # FINAL SAVE
    # ================================================================ #
    pbar.close()
    steps_done = step - start_step
    if steps_done > 0:
        ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
        save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)
    else:
        print("\n [SKIP] No steps taken — skipping checkpoint save.")

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(f" SFT COMPLETE")
    print(f"{'='*60}")
    print(f" Steps done : {steps_done}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print(f"\nStart chatting:")
    print(f" python finetune/chat.py --run_dir {args.run_dir}")
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
# Script entry point: run the SFT pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    train()
|
model/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model/__init__.py
|
| 2 |
+
from model.config import ModelConfig, SLLM_100M, SLLM_150M
|
| 3 |
+
from model.model import SLLM
|
| 4 |
+
|
| 5 |
+
__all__ = ["ModelConfig", "SLLM_100M", "SLLM_150M", "SLLM"]
|
model/attention.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model/attention.py
|
| 3 |
+
|
| 4 |
+
Causal Multi-Head Self-Attention with RoPE.
|
| 5 |
+
|
| 6 |
+
Architecture:
|
| 7 |
+
Input x (B, T, d_model)
|
| 8 |
+
-> Linear projections Q, K, V (no bias)
|
| 9 |
+
-> Reshape to (B, n_heads, T, head_dim)
|
| 10 |
+
-> Apply RoPE to Q and K
|
| 11 |
+
-> Scaled dot-product attention with causal mask
|
| 12 |
+
-> Reshape back to (B, T, d_model)
|
| 13 |
+
-> Output projection O (no bias)
|
| 14 |
+
|
| 15 |
+
Uses torch.nn.functional.scaled_dot_product_attention (Flash Attention
|
| 16 |
+
when available via PyTorch 2.0+) for memory-efficient attention.
|
| 17 |
+
The causal mask is handled by is_causal=True — no need to materialize
|
| 18 |
+
an explicit O(T^2) mask tensor.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
from model.config import ModelConfig
|
| 26 |
+
from model.rope import RoPECache, apply_rope
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CausalSelfAttention(nn.Module):
    """
    Causal multi-head self-attention with rotary position embeddings.

    Q, K and V come from a single fused linear projection; attention itself
    is delegated to F.scaled_dot_product_attention with is_causal=True.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        d_model = config.d_model
        use_bias = config.bias

        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.d_model = d_model
        self.dropout = config.dropout

        # Fused Q/K/V projection: one matmul producing (B, T, 3 * d_model)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=use_bias)

        # Output projection applied after the heads are merged
        self.o_proj = nn.Linear(d_model, d_model, bias=use_bias)

        # Dropout probability handed to sdpa (training mode only)
        self.attn_dropout = config.dropout

        # Rotary embedding cache sized for the full context window
        self.rope = RoPECache(
            head_dim=config.head_dim,
            max_seq_len=config.context_length,
            theta=config.rope_theta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x : (B, T, d_model)

        Returns:
            out : (B, T, d_model)
        """
        batch, seq_len, width = x.shape  # width = d_model
        per_head_shape = (batch, seq_len, self.n_heads, self.head_dim)

        # ---- Project, split into Q/K/V, move heads to dim 1 -------- #
        # each of q, k, v: (B, n_heads, T, head_dim)
        q, k, v = (
            chunk.view(per_head_shape).transpose(1, 2)
            for chunk in self.qkv_proj(x).split(self.d_model, dim=-1)
        )

        # ---- Rotary embeddings on queries and keys only ------------ #
        cos, sin = self.rope.get(seq_len)
        q, k = apply_rope(q, k, cos, sin)

        # ---- Attention --------------------------------------------- #
        # The causal mask is applied inside sdpa; dropout is off in eval.
        drop_p = self.attn_dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=drop_p,
            is_causal=True,
        )  # (B, n_heads, T, head_dim)

        # ---- Merge heads and project ------------------------------- #
        # transpose makes the tensor non-contiguous, so copy before view
        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.o_proj(merged)  # (B, T, d_model)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ------------------------------------------------------------------ #
|
| 97 |
+
# QUICK CHECK
|
| 98 |
+
# ------------------------------------------------------------------ #
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
    # Smoke test: build one attention layer and check the output shape.
    from model.config import SLLM_100M

    cfg = SLLM_100M
    layer = CausalSelfAttention(cfg)
    n_params = sum(p.numel() for p in layer.parameters())
    print(f"Attention params : {n_params/1e6:.2f}M")

    B, T = 2, 64
    expected_shape = (B, T, cfg.d_model)
    sample = torch.randn(*expected_shape)
    result = layer(sample)

    print(f"Input shape : {sample.shape}")
    print(f"Output shape : {result.shape}")
    assert result.shape == expected_shape, "Shape mismatch!"
    print("PASS")
|
model/block.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model/block.py
|
| 3 |
+
|
| 4 |
+
Single Transformer Block (pre-norm LLaMA-style).
|
| 5 |
+
|
| 6 |
+
Pre-Norm vs Post-Norm:
|
| 7 |
+
Original Transformer (post-norm): x = LayerNorm(x + Attention(x)) <- less stable
|
| 8 |
+
LLaMA / GPT-2 (pre-norm): x = x + Attention(LayerNorm(x)) <- more stable
|
| 9 |
+
|
| 10 |
+
We use PRE-NORM with RMSNorm for training stability at scale.
|
| 11 |
+
|
| 12 |
+
Block structure:
|
| 13 |
+
x -> RMSNorm -> CausalSelfAttention -> (+residual)
|
| 14 |
+
-> RMSNorm -> SwiGLU MLP -> (+residual)
|
| 15 |
+
-> output
|
| 16 |
+
|
| 17 |
+
Note: Residual connections bypass both norm and sublayer, which allows
|
| 18 |
+
gradients to flow directly to earlier layers during backprop.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
|
| 24 |
+
from model.config import ModelConfig
|
| 25 |
+
from model.norm import RMSNorm
|
| 26 |
+
from model.attention import CausalSelfAttention
|
| 27 |
+
from model.mlp import SwiGLU
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TransformerBlock(nn.Module):
    """
    One pre-norm decoder block.

    Each sub-layer (attention, then SwiGLU MLP) reads an RMS-normalized copy
    of the residual stream and adds its output back onto the un-normalized
    stream: x = x + sublayer(norm(x)).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        width = config.d_model

        # Attention sub-layer: norm first, then causal self-attention (RoPE)
        self.norm_attn = RMSNorm(width)
        self.attn = CausalSelfAttention(config)

        # Feed-forward sub-layer: norm first, then SwiGLU
        self.norm_mlp = RMSNorm(width)
        self.mlp = SwiGLU(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x : (B, T, d_model)

        Returns:
            x : (B, T, d_model)
        """
        # Residual update from attention
        attn_update = self.attn(self.norm_attn(x))
        x = x + attn_update

        # Residual update from the MLP
        mlp_update = self.mlp(self.norm_mlp(x))
        return x + mlp_update
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ------------------------------------------------------------------ #
|
| 65 |
+
# QUICK CHECK
|
| 66 |
+
# ------------------------------------------------------------------ #
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
    # Smoke test: one block must preserve the input shape.
    from model.config import SLLM_100M

    cfg = SLLM_100M
    layer = TransformerBlock(cfg)

    n_params = sum(p.numel() for p in layer.parameters())
    print(f"Block params : {n_params/1e6:.3f}M")

    B, T = 2, 64
    sample = torch.randn(B, T, cfg.d_model)
    result = layer(sample)

    print(f"Input shape : {sample.shape}")
    print(f"Output shape : {result.shape}")
    assert result.shape == sample.shape, "Shape mismatch!"
    print("PASS")
|
model/config.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model/config.py
|
| 3 |
+
|
| 4 |
+
ModelConfig dataclass + preset configs for SLLM-100M and SLLM-150M.
|
| 5 |
+
All hyperparameters live here so every other module imports from one place.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _swiglu_d_ff(d_model: int) -> int:
|
| 12 |
+
"""
|
| 13 |
+
SwiGLU hidden dimension.
|
| 14 |
+
LLaMA formula: round_up_256( int(2/3 * 4 * d_model) )
|
| 15 |
+
"""
|
| 16 |
+
raw = int(2 / 3 * 4 * d_model)
|
| 17 |
+
return ((raw + 255) // 256) * 256 # round up to nearest 256
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class ModelConfig:
    """Hyperparameters for one SLLM model; d_ff is derived when left at 0."""

    # ---- Vocabulary ------------------------------------------------- #
    vocab_size: int = 32_000          # must match trained tokenizer

    # ---- Sequence --------------------------------------------------- #
    context_length: int = 1024        # max tokens per sequence

    # ---- Transformer dimensions ------------------------------------- #
    d_model: int = 768                # embedding / hidden dim
    n_heads: int = 12                 # number of attention heads
    n_layers: int = 12                # number of transformer blocks

    # ---- FFN -------------------------------------------------------- #
    # 0 means "derive the SwiGLU width from d_model in __post_init__"
    d_ff: int = 0

    # ---- Regularization --------------------------------------------- #
    dropout: float = 0.0              # 0.0 for pre-training

    # ---- Misc ------------------------------------------------------- #
    bias: bool = False                # no bias in linear layers
    rope_theta: float = 10_000.0      # RoPE base frequency

    def __post_init__(self):
        """Fill in the automatic d_ff and validate the head split."""
        if self.d_ff == 0:
            self.d_ff = _swiglu_d_ff(self.d_model)

        # Heads must split d_model evenly
        leftover = self.d_model % self.n_heads
        assert leftover == 0, (
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        )

    @property
    def head_dim(self) -> int:
        """Per-head width (d_model divided evenly across heads)."""
        return self.d_model // self.n_heads

    def count_params(self) -> int:
        """Returns total trainable parameter count (with tied embeddings)."""
        d = self.d_model
        per_block = (
            4 * d * d             # attention: Q, K, V, O
            + 3 * d * self.d_ff   # MLP: gate, up, down
            + 2 * d               # two norms per block
        )
        # embedding table + all blocks + final norm
        return self.vocab_size * d + self.n_layers * per_block + d

    def __repr__(self) -> str:
        millions = self.count_params() / 1e6
        parts = [
            f"d={self.d_model}",
            f"h={self.n_heads}",
            f"l={self.n_layers}",
            f"ff={self.d_ff}",
            f"ctx={self.context_length}",
            f"params={millions:.1f}M",
        ]
        return "ModelConfig(" + ", ".join(parts) + ")"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ------------------------------------------------------------------ #
|
| 79 |
+
# PRESET CONFIGS
|
| 80 |
+
# ------------------------------------------------------------------ #
|
| 81 |
+
|
| 82 |
+
# ~100M-parameter preset: 12 layers of width 768 with 12 heads (head_dim 64).
SLLM_100M = ModelConfig(
    vocab_size = 32_000,
    context_length = 1024,
    d_model = 768,
    n_heads = 12,
    n_layers = 12,
    # d_ff auto = 2048
)

# ~150M-parameter preset: wider and shallower — 9 layers of width 1024
# with 16 heads (head_dim 64).
SLLM_150M = ModelConfig(
    vocab_size = 32_000,
    context_length = 1024,
    d_model = 1024,
    n_heads = 16,
    n_layers = 9,
    # d_ff auto = 2816
)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ------------------------------------------------------------------ #
|
| 102 |
+
# QUICK CHECK
|
| 103 |
+
# ------------------------------------------------------------------ #
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
    # Print each preset with its derived dimensions, one blank line after each.
    for preset in (SLLM_100M, SLLM_150M):
        print(
            preset,
            f" head_dim : {preset.head_dim}",
            f" d_ff : {preset.d_ff}",
            "",
            sep="\n",
        )
|
model/mlp.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model/mlp.py
|
| 3 |
+
|
| 4 |
+
SwiGLU Feed-Forward Network — used in LLaMA, PaLM, Mistral, etc.
|
| 5 |
+
|
| 6 |
+
Standard FFN (GPT-2):
|
| 7 |
+
out = dropout(W2 * GELU(W1 * x))
|
| 8 |
+
|
| 9 |
+
SwiGLU FFN (LLaMA):
|
| 10 |
+
gate = W_gate * x # linear gate
|
| 11 |
+
up = W_up * x # linear up-proj
|
| 12 |
+
hidden = SiLU(gate) * up # element-wise gating (learned)
|
| 13 |
+
out = W_down * hidden # down-proj back to d_model
|
| 14 |
+
|
| 15 |
+
SiLU (Sigmoid Linear Unit):
|
| 16 |
+
SiLU(x) = x * sigmoid(x)
|
| 17 |
+
|
| 18 |
+
Why SwiGLU is better:
|
| 19 |
+
- The gating mechanism (SiLU(gate) * up) gives the model a learned
|
| 20 |
+
way to activate or suppress each hidden dimension independently.
|
| 21 |
+
- Empirically outperforms GELU/ReLU FFNs at the same parameter count.
|
| 22 |
+
- d_ff is set to int(2/3 * 4 * d_model), rounded up to a multiple of 256.
|
| 23 |
+
This compensates for having 3 matrices instead of 2, keeping
|
| 24 |
+
total parameter count comparable to a standard 4x FFN.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
|
| 31 |
+
from model.config import ModelConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SwiGLU(nn.Module):
    """
    Gated feed-forward network: down(SiLU(gate(x)) * up(x)).

    Uses three linear maps (gate, up, down); the bias setting comes from
    the config.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        use_bias = config.bias

        self.gate = nn.Linear(width, hidden, bias=use_bias)   # gate projection
        self.up = nn.Linear(width, hidden, bias=use_bias)     # up projection
        self.down = nn.Linear(hidden, width, bias=use_bias)   # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x : (B, T, d_model)
        Returns:
            out : (B, T, d_model)
        """
        # SiLU(z) = z * sigmoid(z); the activated gate scales `up` element-wise
        gate_act = F.silu(self.gate(x))
        gated = gate_act * self.up(x)
        return self.down(gated)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ------------------------------------------------------------------ #
|
| 60 |
+
# QUICK CHECK
|
| 61 |
+
# ------------------------------------------------------------------ #
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
    # Smoke test: report the parameter breakdown and check the output shape.
    from model.config import SLLM_100M

    cfg = SLLM_100M
    ffn = SwiGLU(cfg)

    width, hidden = cfg.d_model, cfg.d_ff
    n_params = sum(p.numel() for p in ffn.parameters())
    print(f"SwiGLU d_model={width} d_ff={hidden}")
    print(f" gate : {width} x {hidden} = {width * hidden:,}")
    print(f" up : {width} x {hidden} = {width * hidden:,}")
    print(f" down : {hidden} x {width} = {hidden * width:,}")
    print(f" total MLP params : {n_params/1e6:.3f}M")

    B, T = 2, 64
    sample = torch.randn(B, T, width)
    result = ffn(sample)

    print(f"Input shape : {sample.shape}")
    print(f"Output shape : {result.shape}")
    assert result.shape == sample.shape, "Shape mismatch!"
    print("PASS")
|
model/model.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model/model.py
|
| 3 |
+
|
| 4 |
+
SLLM — Small Language Model (decoder-only Transformer).
|
| 5 |
+
|
| 6 |
+
Full architecture:
|
| 7 |
+
tokens (B, T)
|
| 8 |
+
-> Embedding (vocab_size -> d_model)
|
| 9 |
+
-> N x TransformerBlock (attention + FFN)
|
| 10 |
+
-> Final RMSNorm
|
| 11 |
+
-> LM Head (Linear d_model -> vocab_size) <- weight-TIED to embedding
|
| 12 |
+
|
| 13 |
+
Weight tying:
|
| 14 |
+
The embedding matrix and the LM head output matrix share the same weights.
|
| 15 |
+
- Halves memory for the embedding/output layers.
|
| 16 |
+
- A standard practice since GPT-2 (Press & Wolf, 2016).
|
| 17 |
+
|
| 18 |
+
Weight initialization:
|
| 19 |
+
- Embeddings: std=0.02 (GPT-2 convention)
|
| 20 |
+
- Linear layers: std=0.02
|
| 21 |
+
- Output projections (attn.o_proj, mlp.down): std = 0.02/sqrt(2*n_layers)
|
| 22 |
+
- Scaled down per GPT-2/NanoGPT: at initialization, the residual
|
| 23 |
+
stream grows as sqrt(n_layers), so we scale residual contributions down.
|
| 24 |
+
|
| 25 |
+
Forward:
|
| 26 |
+
Returns logits (B, T, vocab_size).
|
| 27 |
+
Loss is computed externally in the training loop for flexibility.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import math
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
from torch.utils.checkpoint import checkpoint
|
| 34 |
+
|
| 35 |
+
from model.config import ModelConfig
|
| 36 |
+
from model.norm import RMSNorm
|
| 37 |
+
from model.block import TransformerBlock
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class SLLM(nn.Module):
    """
    Decoder-only Transformer language model.

    Pipeline: token embedding -> ``config.n_layers`` TransformerBlocks ->
    final RMSNorm -> bias-free LM head whose weight tensor is shared with
    the token embedding (weight tying).
    """

    def __init__(self, config: "ModelConfig"):
        """
        Build all sub-modules and initialise their weights.

        Args:
            config : model hyper-parameters; this class reads vocab_size,
                     d_model, n_layers and context_length from it.
        """
        super().__init__()
        self.config = config

        # ---- Token embedding --------------------------------------- #
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)

        # ---- Transformer blocks ------------------------------------ #
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )

        # ---- Final norm -------------------------------------------- #
        self.norm = RMSNorm(config.d_model)

        # ---- LM Head (weight-tied) --------------------------------- #
        # Linear: d_model -> vocab_size, no bias. Re-pointing .weight at the
        # embedding parameter makes both layers share one tensor.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight

        # ---- Gradient checkpointing flag --------------------------- #
        # Switched on via enable_gradient_checkpointing().
        self._gradient_checkpointing = False

        # ---- Initialize weights ------------------------------------ #
        # The residual projections must be tagged BEFORE the init pass runs,
        # otherwise _init_weights never sees the `_is_residual` flag and the
        # scaled-down std is silently skipped.
        self._mark_residual_projections()

    def _init_weights(self, module: nn.Module):
        """
        Per-module initialiser, intended for use with ``self.apply``.

        - nn.Linear    : weight ~ Normal(0, 0.02), bias (if any) zeroed.
        - nn.Embedding : weight ~ Normal(0, 0.02).
        - nn.Linear tagged with ``_is_residual = True``:
          weight ~ Normal(0, 0.02 / sqrt(2 * n_layers)).
        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if getattr(module, "_is_residual", False):
                # Residual-stream writers get a smaller std.
                std = 0.02 / math.sqrt(2 * self.config.n_layers)
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _mark_residual_projections(self):
        """
        Tag each block's ``attn.o_proj`` and ``mlp.down`` as residual
        projections, then (re-)run weight initialisation on the whole model.

        Called from ``__init__``. Calling it again re-randomises every
        Linear/Embedding weight, so do not call it after loading a checkpoint.
        """
        for block in self.blocks:
            block.attn.o_proj._is_residual = True
            block.mlp.down._is_residual = True
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
    ):
        """
        Args:
            input_ids : (B, T) — integer token IDs
            targets   : (B, T) — optional, for loss computation

        Returns:
            logits : (B, T, vocab_size)
            loss   : scalar cross-entropy (PyTorch default reduction) over
                     the flattened positions if targets given, else None

        Raises:
            AssertionError: if T exceeds ``config.context_length``.
        """
        B, T = input_ids.shape
        assert T <= self.config.context_length, (
            f"Sequence length {T} exceeds context_length {self.config.context_length}"
        )

        # ---- Embedding --------------------------------------------- #
        x = self.token_emb(input_ids)           # (B, T, d_model)

        # ---- Transformer blocks ------------------------------------ #
        use_ckpt = self._gradient_checkpointing and self.training
        for block in self.blocks:
            if use_ckpt:
                # Recompute this block's activations during backward instead
                # of storing them (non-reentrant checkpoint variant).
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        # ---- Final norm + LM head ---------------------------------- #
        x = self.norm(x)                        # (B, T, d_model)
        logits = self.lm_head(x)                # (B, T, vocab_size)

        # ---- Loss -------------------------------------------------- #
        loss = None
        if targets is not None:
            # Flatten for cross-entropy: (B*T, vocab_size) vs (B*T,)
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation (greedy, full sampling, or top-k sampling).

        Args:
            input_ids      : (B, T) prompt tokens
            max_new_tokens : number of tokens to generate
            temperature    : softmax temperature (1.0 = neutral, <1 = sharper).
                             0 selects greedy decoding (argmax) instead of
                             dividing by zero.
            top_k          : if a positive int, sample from the top-k tokens
                             only; None or 0 disables the filter.

        Returns:
            (B, T + max_new_tokens) token IDs

        Note:
            Puts the module in eval mode and leaves it there; callers that
            continue training must call ``.train()`` themselves.
        """
        self.eval()
        max_ctx = self.config.context_length
        greedy = temperature == 0

        for _ in range(max_new_tokens):
            # Keep only the most recent context_length tokens.
            ctx = input_ids
            if ctx.shape[1] > max_ctx:
                ctx = ctx[:, -max_ctx:]

            # Forward pass — only the last position's logits are needed.
            logits, _ = self(ctx)
            logits = logits[:, -1, :]           # (B, vocab_size)

            if greedy:
                next_token = logits.argmax(dim=-1, keepdim=True)    # (B, 1)
            else:
                logits = logits / temperature

                # Optional top-k filtering: everything below the k-th largest
                # logit is pushed to -inf so softmax assigns it probability 0.
                if top_k is not None and top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_value = torch.topk(logits, k).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_value, float("-inf"))

                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def enable_gradient_checkpointing(self):
        """
        Turn on gradient checkpointing for subsequent training-mode forward
        passes: block activations are recomputed during backward instead of
        being stored, trading extra compute for lower memory use.
        """
        self._gradient_checkpointing = True

    def count_params(self, non_embedding: bool = False) -> int:
        """
        Returns parameter count.

        Args:
            non_embedding: if True, subtract the token-embedding matrix.
                           Because the LM head shares that tensor, it is
                           counted (and subtracted) exactly once.
        """
        total = sum(p.numel() for p in self.parameters())
        if non_embedding:
            total -= self.token_emb.weight.numel()
        return total
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ------------------------------------------------------------------ #
|
| 221 |
+
# QUICK CHECK
|
| 222 |
+
# ------------------------------------------------------------------ #
|
| 223 |
+
|
| 224 |
+
if __name__ == "__main__":
    # Smoke test: instantiate each preset, print its parameter counts, then
    # push one random batch through forward() with targets.
    from model.config import SLLM_100M, SLLM_150M

    presets = [("SLLM-100M", SLLM_100M), ("SLLM-150M", SLLM_150M)]
    batch_size, seq_len = 2, 64

    for name, cfg in presets:
        model = SLLM(cfg)

        n_total = model.count_params()
        n_non_emb = model.count_params(non_embedding=True)
        n_emb = n_total - n_non_emb
        print(f"{name}")
        print(f" total params : {n_total/1e6:.1f}M")
        print(f" non-embedding params : {n_non_emb/1e6:.1f}M")
        print(f" embedding params : {n_emb/1e6:.1f}M")

        # Forward pass check (inputs drawn first, then targets)
        batch_shape = (batch_size, seq_len)
        ids = torch.randint(0, cfg.vocab_size, batch_shape)
        targets = torch.randint(0, cfg.vocab_size, batch_shape)

        logits, loss = model(ids, targets)
        print(f" logits shape : {logits.shape}")
        print(f" loss : {loss.item():.4f}")
        print()
|
model/norm.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model/norm.py
|
| 3 |
+
|
| 4 |
+
RMSNorm — Root Mean Square Layer Normalization.
|
| 5 |
+
Used in LLaMA-style transformers instead of standard LayerNorm.
|
| 6 |
+
|
| 7 |
+
Key difference from LayerNorm:
|
| 8 |
+
- No mean subtraction (centering)
|
| 9 |
+
- No bias term
|
| 10 |
+
- Only re-scales with a single learned gain vector (weight)
|
| 11 |
+
- ~40% faster in practice (no mean computation)
|
| 12 |
+
|
| 13 |
+
Formula:
|
| 14 |
+
RMSNorm(x) = x / RMS(x) * weight
|
| 15 |
+
where RMS(x) = sqrt( mean(x^2) + eps )
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last axis with a learned
    per-feature gain and no bias.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        """
        Args:
            d_model : size of the last axis of the input
            eps     : added to mean(x^2) before the reciprocal square root
        """
        super().__init__()
        self.eps = eps
        # Learnable gain, initialised to ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide x (..., d_model) by its RMS taken along the last axis."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, cast back to x's dtype, then apply the gain."""
        normed = self._norm(x.float())
        normed = normed.type_as(x)
        return normed * self.weight
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ------------------------------------------------------------------ #
|
| 46 |
+
# QUICK CHECK
|
| 47 |
+
# ------------------------------------------------------------------ #
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
    # Sanity check: with the gain still at its initial value of ones, every
    # normalised vector should have RMS close to 1.
    torch.manual_seed(0)
    batch, seq, width = 2, 16, 768
    x = torch.randn(batch, seq, width)
    norm = RMSNorm(width)

    out = norm(x)
    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    print(f"Output dtype : {out.dtype}")

    # Per-vector RMS before and after normalisation
    rms_before = x.pow(2).mean(dim=-1).sqrt()
    rms_after = out.pow(2).mean(dim=-1).sqrt()
    print(f"RMS before norm : {rms_before.mean():.3f}")
    print(f"RMS after norm : {rms_after.mean():.3f} (weight=1 so should be ~1.0)")

    unit_rms = torch.allclose(rms_after, torch.ones_like(rms_after), atol=1e-4)
    print("PASS" if unit_rms else "FAIL")
|
model/rope.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model/rope.py
|
| 3 |
+
|
| 4 |
+
Rotary Position Embedding (RoPE) — Su et al. 2021 (RoFormer).
|
| 5 |
+
Used in LLaMA, Mistral, Gemma, etc.
|
| 6 |
+
|
| 7 |
+
Core idea:
|
| 8 |
+
Instead of adding position embeddings to token vectors, we ROTATE
|
| 9 |
+
the query and key vectors in attention using position-dependent angles.
|
| 10 |
+
|
| 11 |
+
- Relative positions are encoded implicitly via dot-product invariance.
|
| 12 |
+
- Works for any sequence length (extrapolates beyond training length).
|
| 13 |
+
- Only applied to Q and K, NOT V.
|
| 14 |
+
|
| 15 |
+
Implementation:
|
| 16 |
+
1. Precompute cos/sin tables for all positions up to max_seq_len.
|
| 17 |
+
Shape: (max_seq_len, head_dim)
|
| 18 |
+
|
| 19 |
+
2. At forward time, slice cos/sin to the current seq_len and
|
| 20 |
+
apply rotation to Q and K.
|
| 21 |
+
|
| 22 |
+
Rotation formula (pairs of dims):
|
| 23 |
+
Given a vector x with dims [x0, x1, x2, x3, ...]:
|
| 24 |
+
Pair each consecutive two dims: (x0,x1), (x2,x3), ...
|
| 25 |
+
Rotate each pair by angle theta_i * position:
|
| 26 |
+
[x0*cos - x1*sin, x0*sin + x1*cos, ...]
|
| 27 |
+
|
| 28 |
+
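Equivalent implementation using rotate_half (same rotation up to a fixed permutation of dims: dim i is paired with dim i + head_dim/2 rather than with its neighbour):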
|
| 29 |
+
rotated = concat([-x_second_half, x_first_half]) # swapped halves
|
| 30 |
+
out = x * cos + rotated * sin
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
from typing import Tuple
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def precompute_rope_freqs(
    head_dim: int,
    max_seq_len: int,
    theta: float = 10_000.0,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the RoPE cosine and sine lookup tables.

    Args:
        head_dim    : per-head feature size (must be even)
        max_seq_len : number of positions to tabulate (0 .. max_seq_len - 1)
        theta       : base of the geometric frequency schedule
        device      : device on which the tables are created

    Returns:
        cos : (max_seq_len, head_dim)
        sin : (max_seq_len, head_dim)

    Raises:
        AssertionError: if head_dim is odd.
    """
    assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"

    # One inverse frequency per dimension pair: theta^(-2j / head_dim),
    # shape (head_dim // 2,).
    even_idx = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (theta ** (even_idx / head_dim))

    # angle[p, j] = p * inv_freq[j]  ->  (max_seq_len, head_dim // 2)
    pos = torch.arange(max_seq_len, dtype=torch.float32, device=device)
    angles = pos[:, None] * inv_freq[None, :]

    # Tile the half-width table twice along the feature axis so the layout is
    # [a_0 .. a_{h/2-1}, a_0 .. a_{h/2-1}]: dim i and dim i + head_dim/2 share
    # an angle, which is the pairing rotate_half() relies on.
    angles = angles.repeat(1, 2)

    return angles.cos(), angles.sin()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Swap the two halves of the last axis, negating the half that moves to
    the front: [lo, hi] -> [-hi, lo].

    Args:
        x: (..., head_dim)
    Returns:
        rotated: (..., head_dim)
    """
    width = x.shape[-1]
    split_at = width // 2
    lo, hi = torch.split(x, [split_at, width - split_at], dim=-1)
    return torch.cat((-hi, lo), dim=-1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors with the given cos/sin tables.

    Args:
        q   : (B, n_heads, T, head_dim)
        k   : (B, n_heads, T, head_dim)
        cos : (T, head_dim) - table slice for the current sequence length
        sin : (T, head_dim) - table slice for the current sequence length

    Returns:
        q_rot, k_rot : same shapes as inputs
    """
    # Prepend two singleton axes so the tables broadcast over batch and heads.
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class RoPECache(nn.Module):
    """
    Holds the precomputed RoPE cos/sin tables as module buffers.

    There are no learnable parameters. Registering the tables as buffers
    means they follow the module through ``.to(device)``; with
    ``persistent=True`` they are also written to the state dict.
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10_000.0):
        """
        Args:
            head_dim    : per-head feature size (must be even)
            max_seq_len : number of positions to precompute
            theta       : RoPE base frequency
        """
        super().__init__()
        cos, sin = precompute_rope_freqs(head_dim, max_seq_len, theta)
        # register_buffer: not a parameter, but moves with .to(device)
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)

    def get(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the first ``seq_len`` rows of the cos and sin tables.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of cached positions.
                        Plain slicing would otherwise hand back a shorter
                        table without any warning.
        """
        cached_len = self.cos.shape[0]
        if seq_len > cached_len:
            raise ValueError(
                f"seq_len {seq_len} exceeds RoPE cache length {cached_len}"
            )
        return self.cos[:seq_len], self.sin[:seq_len]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ------------------------------------------------------------------ #
|
| 142 |
+
# QUICK CHECK
|
| 143 |
+
# ------------------------------------------------------------------ #
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
    # Sanity check: rotate random q/k, confirm shapes and that the rotation
    # leaves per-vector norms unchanged, then exercise RoPECache.get().
    torch.manual_seed(0)

    B, n_heads, T, head_dim = 2, 12, 16, 64
    qk_shape = (B, n_heads, T, head_dim)

    cos, sin = precompute_rope_freqs(head_dim, max_seq_len=1024)
    cos_T, sin_T = cos[:T], sin[:T]

    # q is drawn before k so the random stream matches the seed above.
    q = torch.randn(*qk_shape)
    k = torch.randn(*qk_shape)

    q_rot, k_rot = apply_rope(q, k, cos_T, sin_T)

    print(f"q shape : {q.shape}")
    print(f"q_rot shape : {q_rot.shape}")
    print(f"k_rot shape : {k_rot.shape}")

    # A rotation should preserve the norm of every vector (|x| = |Rx|)
    norm_kept = torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
    print(f"Norm preserved (q): {norm_kept}")

    # Test RoPECache
    cache = RoPECache(head_dim=64, max_seq_len=1024)
    c, s = cache.get(T)
    print(f"Cache cos shape: {c.shape}")
    print("PASS")
|
model_explained.md
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Folder — Plain Language Explanation
|
| 2 |
+
|
| 3 |
+
The `model/` folder builds a **GPT-style decoder-only transformer** from scratch,
|
| 4 |
+
piece by piece. Each file is one component. Here's how they stack:
|
| 5 |
+
|
| 6 |
+
```
|
| 7 |
+
tokens (integers)
|
| 8 |
+
│
|
| 9 |
+
▼
|
| 10 |
+
┌─────────────┐
|
| 11 |
+
│ Embedding │ config.py defines the shape of everything
|
| 12 |
+
└──────┬──────┘
|
| 13 |
+
│
|
| 14 |
+
▼ ×N layers
|
| 15 |
+
┌──────────────────────────────────────┐
|
| 16 |
+
│ TransformerBlock │ block.py
|
| 17 |
+
│ │
|
| 18 |
+
│ ┌──────────┐ ┌──────────────┐ │
|
| 19 |
+
│ │ RMSNorm │ │ RMSNorm │ │ norm.py
|
| 20 |
+
│ └────┬─────┘ └──────┬───────┘ │
|
| 21 |
+
│ │ │ │
|
| 22 |
+
│ ┌────▼─────┐ ┌──────▼───────┐ │
|
| 23 |
+
│ │Attention │ │ SwiGLU MLP │ │ attention.py / mlp.py
|
| 24 |
+
│ │ + RoPE │ │ │ │ rope.py
|
| 25 |
+
│ └────┬─────┘ └──────┬───────┘ │
|
| 26 |
+
│ │ (+residual) │ (+residual)│
|
| 27 |
+
└────────┼─────────────────┼───────────┘
|
| 28 |
+
│ │
|
| 29 |
+
└────────┬────────┘
|
| 30 |
+
│
|
| 31 |
+
▼
|
| 32 |
+
┌──────────┐
|
| 33 |
+
│ RMSNorm │ final norm
|
| 34 |
+
└────┬─────┘
|
| 35 |
+
│
|
| 36 |
+
┌────▼─────┐
|
| 37 |
+
│ LM Head │ Linear → vocab_size logits
|
| 38 |
+
└──────────┘
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## 1. `config.py` — The Blueprint
|
| 44 |
+
|
| 45 |
+
**What it does:** Stores all the numbers that define the model size.
|
| 46 |
+
Nothing computes anything here — it's just a settings object.
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
@dataclass
|
| 50 |
+
class ModelConfig:
|
| 51 |
+
vocab_size = 32_000 # how many tokens exist
|
| 52 |
+
context_length = 1024 # max sequence length
|
| 53 |
+
d_model = 1024 # width of every vector throughout the model
|
| 54 |
+
n_heads = 16 # how many attention heads
|
| 55 |
+
n_layers = 9 # how many transformer blocks stacked
|
| 56 |
+
d_ff = 2816 # width of the MLP hidden layer (auto-computed)
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
**Why these numbers?**
|
| 60 |
+
- `d_model` is the "resolution" of the model — bigger = more expressive but more memory
|
| 61 |
+
- `n_heads` splits each attention layer into parallel sub-attentions
|
| 62 |
+
- `head_dim = d_model / n_heads = 64` — each head sees 64-dim slices
|
| 63 |
+
- `d_ff` for SwiGLU = `round_up_256( 2/3 × 4 × d_model )` (rounded up to a multiple of 256) — the 2/3 factor compensates for having 3 matrices instead of 2
|
| 64 |
+
|
| 65 |
+
**Presets defined here:**
|
| 66 |
+
```
|
| 67 |
+
SLLM_100M: d=768, h=12, l=12 → 109.5M params
|
| 68 |
+
SLLM_150M: d=1024, h=16, l=9 → 148.4M params
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## 2. `norm.py` — RMSNorm
|
| 74 |
+
|
| 75 |
+
**What it does:** Normalizes vectors so they don't explode or vanish during training.
|
| 76 |
+
Used before every attention and MLP layer.
|
| 77 |
+
|
| 78 |
+
**Standard LayerNorm (GPT-2):**
|
| 79 |
+
```
|
| 80 |
+
1. Compute mean of x
|
| 81 |
+
2. Subtract mean (centering)
|
| 82 |
+
3. Divide by std
|
| 83 |
+
4. Scale by learned weight
|
| 84 |
+
5. Add learned bias
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
**RMSNorm (LLaMA / our model):**
|
| 88 |
+
```
|
| 89 |
+
1. Compute RMS = sqrt( mean(x²) ) ← no mean subtraction!
|
| 90 |
+
2. Divide by RMS
|
| 91 |
+
3. Scale by learned weight ← no bias!
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
**Why simpler is better:**
|
| 95 |
+
- No mean subtraction → ~40% faster
|
| 96 |
+
- No bias → fewer parameters
|
| 97 |
+
- Works just as well in practice
|
| 98 |
+
- LLaMA, Mistral, Gemma all use it
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
# What it computes:
|
| 102 |
+
output = (x / sqrt(mean(x²) + 1e-6)) * weight
|
| 103 |
+
# ↑ normalize ↑ rescale with learned gain
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
The `weight` starts at all-ones (no change at init) and is learned during training.
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## 3. `rope.py` — Rotary Position Embedding (RoPE)
|
| 111 |
+
|
| 112 |
+
**The problem it solves:** Transformers have no built-in sense of position.
|
| 113 |
+
Without position encoding, `"cat sat on mat"` and `"mat on sat cat"` look identical.
|
| 114 |
+
|
| 115 |
+
**How older models solved it (GPT-2):**
|
| 116 |
+
Added a learned absolute-position vector to each token: `token[i] += position_embedding[i]`
|
| 117 |
+
Problem: can't generalize beyond the training length.
|
| 118 |
+
|
| 119 |
+
**What RoPE does instead:**
|
| 120 |
+
Instead of adding position info to token vectors, it **rotates** the Query and Key
|
| 121 |
+
vectors in attention by an angle that depends on their position.
|
| 122 |
+
|
| 123 |
+
```
|
| 124 |
+
Token at position 3 → rotate Q and K by angle θ₃
|
| 125 |
+
Token at position 7 → rotate Q and K by angle θ₇
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
When you compute attention score `Q·K`, the rotation cancels out in a way that
|
| 129 |
+
encodes *relative distance* between tokens, not absolute positions.
|
| 130 |
+
|
| 131 |
+
**Why this is better:**
|
| 132 |
+
- No extra parameters (pure math, no learned table)
|
| 133 |
+
- Can be evaluated at positions beyond the training length (though quality usually degrades there without extra scaling tricks)
|
| 134 |
+
- Used in LLaMA, Mistral, GPT-4 (likely), Gemma
|
| 135 |
+
|
| 136 |
+
**How the code works:**
|
| 137 |
+
```python
|
| 138 |
+
# Step 1: precompute a table of cos/sin values for every position
|
| 139 |
+
cos, sin = precompute_rope_freqs(head_dim=64, max_seq_len=1024)
|
| 140 |
+
# cos/sin shape: (1024, 64)
|
| 141 |
+
|
| 142 |
+
# Step 2: at forward time, rotate Q and K
|
| 143 |
+
q_rotated = q * cos + rotate_half(q) * sin
|
| 144 |
+
k_rotated = k * cos + rotate_half(k) * sin
|
| 145 |
+
|
| 146 |
+
# rotate_half(x): splits x in half, negates second half, swaps
|
| 147 |
+
# [a, b, c, d] → [-c, -d, a, b]
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
V (values) are **not** rotated — only Q and K get position encoding.
|
| 151 |
+
|
| 152 |
+
---
|
| 153 |
+
|
| 154 |
+
## 4. `attention.py` — Causal Self-Attention
|
| 155 |
+
|
| 156 |
+
**What it does:** Lets every token look at all *previous* tokens and decide
|
| 157 |
+
which ones are relevant to predict the next token.
|
| 158 |
+
|
| 159 |
+
**The full flow:**
|
| 160 |
+
|
| 161 |
+
```
|
| 162 |
+
Input x: (Batch, Tokens, d_model)
|
| 163 |
+
e.g. (2, 1024, 1024)
|
| 164 |
+
│
|
| 165 |
+
▼
|
| 166 |
+
QKV projection: one big Linear(d_model → 3×d_model)
|
| 167 |
+
│
|
| 168 |
+
├─── Q: (2, 1024, 1024) — "what am I looking for?"
|
| 169 |
+
├─── K: (2, 1024, 1024) — "what do I contain?"
|
| 170 |
+
└─── V: (2, 1024, 1024) — "what do I send if attended to?"
|
| 171 |
+
│
|
| 172 |
+
▼
|
| 173 |
+
Reshape to heads: (2, 16_heads, 1024, 64_head_dim)
|
| 174 |
+
│
|
| 175 |
+
▼
|
| 176 |
+
Apply RoPE to Q and K ← position encoding happens here
|
| 177 |
+
│
|
| 178 |
+
▼
|
| 179 |
+
Scaled Dot-Product Attention:
|
| 180 |
+
scores = Q @ K^T / sqrt(64) # how strongly each token attends to every other token
|
| 181 |
+
mask = causal mask # can only look LEFT (past), not right (future)
|
| 182 |
+
weights = softmax(scores + mask)
|
| 183 |
+
out = weights @ V # weighted sum of values
|
| 184 |
+
│
|
| 185 |
+
▼
|
| 186 |
+
Reshape back: (2, 1024, 1024)
|
| 187 |
+
│
|
| 188 |
+
▼
|
| 189 |
+
Output projection: Linear(d_model → d_model)
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
**Causal mask** — this is what makes it a *language model* (predicts next token):
|
| 193 |
+
```
|
| 194 |
+
Position: 0 1 2 3
|
| 195 |
+
Token 0: [✓ ✗ ✗ ✗] can only see itself
|
| 196 |
+
Token 1: [✓ ✓ ✗ ✗] can see 0,1
|
| 197 |
+
Token 2: [✓ ✓ ✓ ✗] can see 0,1,2
|
| 198 |
+
Token 3: [✓ ✓ ✓ ✓] can see all
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
**Flash Attention:** We use `F.scaled_dot_product_attention(..., is_causal=True)`
|
| 202 |
+
which dispatches to PyTorch 2.0's fused Flash Attention kernel when one is available — that kernel never materializes the full
|
| 203 |
+
O(T²) attention matrix in memory. Much faster and uses far less VRAM.
|
| 204 |
+
|
| 205 |
+
---
|
| 206 |
+
|
| 207 |
+
## 5. `mlp.py` — SwiGLU Feed-Forward Network
|
| 208 |
+
|
| 209 |
+
**What it does:** After attention (which mixes *between* tokens), the MLP
|
| 210 |
+
transforms each token *independently* — it's where most of the model's
|
| 211 |
+
"knowledge" is stored.
|
| 212 |
+
|
| 213 |
+
**Standard MLP (GPT-2):**
|
| 214 |
+
```python
|
| 215 |
+
out = W2 @ GELU(W1 @ x) # 2 matrices
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
**SwiGLU (LLaMA / our model):**
|
| 219 |
+
```python
|
| 220 |
+
gate = W_gate @ x # linear
|
| 221 |
+
up = W_up @ x # linear
|
| 222 |
+
hidden = SiLU(gate) * up # element-wise gate ← the key difference
|
| 223 |
+
out = W_down @ hidden # 3 matrices total
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
**What is SiLU?**
|
| 227 |
+
```
|
| 228 |
+
SiLU(x) = x × sigmoid(x)
|
| 229 |
+
```
|
| 230 |
+
It's a smooth version of ReLU — instead of clamping negative inputs to exactly zero, it dips slightly below zero for them.
|
| 231 |
+
|
| 232 |
+
**Why gating matters:**
|
| 233 |
+
- `SiLU(gate)` acts as a learned on/off switch for each hidden dimension
|
| 234 |
+
- The model learns to activate only the neurons relevant to each input
|
| 235 |
+
- Empirically outperforms GELU at the same parameter count
|
| 236 |
+
- Used in LLaMA, PaLM, Mistral
|
| 237 |
+
|
| 238 |
+
**The d_ff formula:**
|
| 239 |
+
```
|
| 240 |
+
d_ff = round_up_256( int(2/3 × 4 × d_model) )
|
| 241 |
+
|
| 242 |
+
For 150M: round_up_256( int(2/3 × 4 × 1024) ) = round_up_256(2730) = 2816
|
| 243 |
+
```
|
| 244 |
+
The `2/3` factor compensates for having 3 matrices instead of 2 — keeps
|
| 245 |
+
total parameter count roughly equal to a standard 4× FFN (the round-up to a multiple of 256 adds a little).
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
+
## 6. `block.py` — TransformerBlock
|
| 250 |
+
|
| 251 |
+
**What it does:** Wraps attention + MLP into one reusable block.
|
| 252 |
+
The model is just N copies of this block stacked.
|
| 253 |
+
|
| 254 |
+
```python
|
| 255 |
+
def forward(x):
|
| 256 |
+
# Attention sub-layer
|
| 257 |
+
x = x + attention( rmsnorm(x) ) # pre-norm + residual
|
| 258 |
+
|
| 259 |
+
# MLP sub-layer
|
| 260 |
+
x = x + mlp( rmsnorm(x) ) # pre-norm + residual
|
| 261 |
+
|
| 262 |
+
return x
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
**Two key ideas:**
|
| 266 |
+
|
| 267 |
+
**1. Pre-norm (normalize BEFORE the sublayer):**
|
| 268 |
+
```
|
| 269 |
+
Pre-norm (LLaMA): x → norm → attention → + original x
|
| 270 |
+
Post-norm (original Transformer): x → attention → + original x → norm
|
| 271 |
+
```
|
| 272 |
+
Pre-norm is more stable at large scale — gradients flow more cleanly.
|
| 273 |
+
|
| 274 |
+
**2. Residual connections (`x + sublayer(x)`):**
|
| 275 |
+
The output of each sublayer is *added* back to the input, not replacing it.
|
| 276 |
+
This means:
|
| 277 |
+
- Gradients can skip directly to earlier layers during backprop
|
| 278 |
+
- The model learns *corrections* to the input, not transformations from scratch
|
| 279 |
+
- Allows stacking many layers without vanishing gradients
|
| 280 |
+
|
| 281 |
+
---
|
| 282 |
+
|
| 283 |
+
## 7. `model.py` — SLLM (The Full Model)
|
| 284 |
+
|
| 285 |
+
**What it does:** Assembles everything into the complete language model.
|
| 286 |
+
|
| 287 |
+
```
|
| 288 |
+
tokens: (B, T) ← integer IDs like [423, 1829, 55, ...]
|
| 289 |
+
│
|
| 290 |
+
▼
|
| 291 |
+
token_emb: Embedding(32000 → 1024)
|
| 292 |
+
│ converts each integer to a 1024-dim vector
|
| 293 |
+
▼
|
| 294 |
+
blocks[0]: TransformerBlock ─┐
|
| 295 |
+
blocks[1]: TransformerBlock │ 9 blocks for 150M
|
| 296 |
+
... │
|
| 297 |
+
blocks[8]: TransformerBlock ─┘
|
| 298 |
+
│
|
| 299 |
+
▼
|
| 300 |
+
norm: RMSNorm(1024) ← final stabilization
|
| 301 |
+
│
|
| 302 |
+
▼
|
| 303 |
+
lm_head: Linear(1024 → 32000)
|
| 304 |
+
│ produces a score for each possible next token
|
| 305 |
+
▼
|
| 306 |
+
logits: (B, T, 32000) ← unnormalized scores
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
**Weight tying:**
|
| 310 |
+
The `token_emb` matrix and `lm_head` matrix **share the same weights**.
|
| 311 |
+
```python
|
| 312 |
+
self.lm_head.weight = self.token_emb.weight
|
| 313 |
+
```
|
| 314 |
+
- Same matrix used for: embedding lookup (input) AND output projection
|
| 315 |
+
- Saves ~32.8M parameters (32000 × 1024)
|
| 316 |
+
- Works because: if token X has a similar embedding to the current hidden state,
|
| 317 |
+
it should also score highly as the next token prediction
|
| 318 |
+
|
| 319 |
+
**Loss computation:**
|
| 320 |
+
```python
|
| 321 |
+
# Cross-entropy: at each position, predict the NEXT token
|
| 322 |
+
# Input: [The, cat, sat, on] → predicts [cat, sat, on, mat]
|
| 323 |
+
# targets = input shifted by 1
|
| 324 |
+
loss = cross_entropy(logits.view(-1, 32000), targets.view(-1))
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
**Gradient checkpointing** (`enable_gradient_checkpointing()`):
|
| 328 |
+
Normally PyTorch saves all intermediate activations during forward pass to use
|
| 329 |
+
in backprop. For 9 layers with batch_size=2 and seq_len=1024, that's ~1.5GB.
|
| 330 |
+
|
| 331 |
+
With gradient checkpointing:
|
| 332 |
+
- Most intermediate activations are **NOT saved** during the forward pass (only the inputs to each checkpointed segment are kept)
|
| 333 |
+
- During backward pass, they are **recomputed on-the-fly**
|
| 334 |
+
- Result: ~40% less VRAM, ~30% slower training
|
| 335 |
+
- Essential for fitting 150M on a 4GB GPU
|
| 336 |
+
|
| 337 |
+
**Weight initialization:**
|
| 338 |
+
```python
|
| 339 |
+
# All Linear and Embedding weights: Normal(mean=0, std=0.02)
|
| 340 |
+
# Residual projections (o_proj, mlp.down): scaled down by 1/sqrt(2 × n_layers)
|
| 341 |
+
```
|
| 342 |
+
The residual scaling prevents the residual stream from growing too large
|
| 343 |
+
at initialization when many layers add to it.
|
| 344 |
+
|
| 345 |
+
---
|
| 346 |
+
|
| 347 |
+
## How it all fits together — One forward pass
|
| 348 |
+
|
| 349 |
+
```
|
| 350 |
+
"The cat sat" → tokenizer → [423, 1829, 55]
|
| 351 |
+
|
| 352 |
+
token_emb: [423]→[0.1,-0.3,...] (1024 floats)
|
| 353 |
+
[1829]→[0.8, 0.2,...] (1024 floats)
|
| 354 |
+
[55] →[-0.1,0.4,...] (1024 floats)
|
| 355 |
+
|
| 356 |
+
Block 0:
|
| 357 |
+
norm → Q,K,V projections → RoPE rotation → Flash Attention → output proj → + residual
|
| 358 |
+
norm → gate,up projections → SiLU(gate)*up → down proj → + residual
|
| 359 |
+
|
| 360 |
+
Block 1..8: same
|
| 361 |
+
|
| 362 |
+
Final norm → LM head → 32000 scores per position
|
| 363 |
+
|
| 364 |
+
softmax → probabilities → sample next token
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
**Total parameters (150M):**
|
| 368 |
+
```
|
| 369 |
+
Embedding: 32000 × 1024 = 32.8M
|
| 370 |
+
Per block: attn(4.2M) + mlp(8.6M) + norms(~0M) = 12.85M
|
| 371 |
+
9 blocks: 9 × 12.85M = 115.6M
|
| 372 |
+
Final norm: 1024 = ~0M
|
| 373 |
+
LM head: TIED to embedding = 0M (reuses same weights)
|
| 374 |
+
─────────────────────────────────────────
|
| 375 |
+
TOTAL: 148.4M params
|
| 376 |
+
```
|
plot_training.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
plot_training.py — Training Visualization Dashboard
|
| 3 |
+
|
| 4 |
+
Reads train_log.jsonl and renders a clean, dark-mode training dashboard.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
# Static plot of completed/current run
|
| 8 |
+
python plot_training.py --run_dir runs/run_001
|
| 9 |
+
|
| 10 |
+
# Live mode: refresh every --interval seconds (default: 10) while training runs
|
| 11 |
+
python plot_training.py --run_dir runs/run_001 --live
|
| 12 |
+
|
| 13 |
+
# Compare multiple runs
|
| 14 |
+
python plot_training.py --run_dir runs/run_001 runs/run_002
|
| 15 |
+
|
| 16 |
+
Dashboard panels:
|
| 17 |
+
1. Training Loss (raw + EMA smoothed)
|
| 18 |
+
2. Validation Loss (if available)
|
| 19 |
+
3. Learning Rate schedule
|
| 20 |
+
4. Tokens / second (throughput)
|
| 21 |
+
5. VRAM usage (if logged)
|
| 22 |
+
6. Gradient norm (if logged)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
import json
|
| 28 |
+
import time
|
| 29 |
+
import argparse
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
import matplotlib
|
| 33 |
+
import matplotlib.pyplot as plt
|
| 34 |
+
import matplotlib.gridspec as gridspec
|
| 35 |
+
import matplotlib.ticker as ticker
|
| 36 |
+
import numpy as np
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ------------------------------------------------------------------ #
|
| 40 |
+
# STYLE
|
| 41 |
+
# ------------------------------------------------------------------ #
|
| 42 |
+
|
| 43 |
+
DARK_BG = "#0d1117"
|
| 44 |
+
PANEL_BG = "#161b22"
|
| 45 |
+
GRID_COLOR = "#21262d"
|
| 46 |
+
TEXT_COLOR = "#c9d1d9"
|
| 47 |
+
MUTED_COLOR = "#6e7681"
|
| 48 |
+
ACCENT_BLUE = "#58a6ff"
|
| 49 |
+
ACCENT_GREEN = "#3fb950"
|
| 50 |
+
ACCENT_ORANGE= "#d29922"
|
| 51 |
+
ACCENT_RED = "#f85149"
|
| 52 |
+
ACCENT_PURPLE= "#bc8cff"
|
| 53 |
+
ACCENT_TEAL = "#39d353"
|
| 54 |
+
|
| 55 |
+
matplotlib.rcParams.update({
|
| 56 |
+
"figure.facecolor": DARK_BG,
|
| 57 |
+
"axes.facecolor": PANEL_BG,
|
| 58 |
+
"axes.edgecolor": GRID_COLOR,
|
| 59 |
+
"axes.labelcolor": TEXT_COLOR,
|
| 60 |
+
"axes.titlecolor": TEXT_COLOR,
|
| 61 |
+
"xtick.color": MUTED_COLOR,
|
| 62 |
+
"ytick.color": MUTED_COLOR,
|
| 63 |
+
"grid.color": GRID_COLOR,
|
| 64 |
+
"grid.linestyle": "--",
|
| 65 |
+
"grid.linewidth": 0.5,
|
| 66 |
+
"grid.alpha": 0.7,
|
| 67 |
+
"legend.facecolor": PANEL_BG,
|
| 68 |
+
"legend.edgecolor": GRID_COLOR,
|
| 69 |
+
"legend.labelcolor": TEXT_COLOR,
|
| 70 |
+
"text.color": TEXT_COLOR,
|
| 71 |
+
"font.family": "DejaVu Sans",
|
| 72 |
+
"font.size": 10,
|
| 73 |
+
"axes.titlesize": 11,
|
| 74 |
+
"axes.labelsize": 10,
|
| 75 |
+
})
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ------------------------------------------------------------------ #
|
| 79 |
+
# DATA LOADING
|
| 80 |
+
# ------------------------------------------------------------------ #
|
| 81 |
+
|
| 82 |
+
def load_log(log_path: str) -> "dict | None":
    """
    Loads train_log.jsonl and returns separate series for each metric.

    Each non-empty line of the file is parsed as one JSON object. Lines that
    are not valid JSON (e.g. a half-written last line while training is still
    running), lines that are not JSON objects, and entries without a "step"
    are skipped. A metric whose value is missing or null on a given entry is
    skipped for that entry only.

    Args:
        log_path : path to the JSONL log file

    Returns:
        None if the file does not exist, otherwise a dict mapping
        "train", "val", "lr", "tok", "vram" and "grad" to a
        (steps, values) pair of equally long lists.
    """
    # (key in the returned dict, key in each JSONL entry)
    metric_keys = (
        ("train", "loss"),
        ("val", "val_loss"),
        ("lr", "lr"),
        ("tok", "tok_per_sec"),
        ("vram", "vram_gb"),
        ("grad", "grad_norm"),
    )
    series = {name: ([], []) for name, _ in metric_keys}

    # Open directly instead of checking existence first: the log may be
    # created or rotated by the training process between the two calls.
    try:
        log_file = open(log_path, "r", encoding="utf-8")
    except FileNotFoundError:
        return None

    with log_file:
        for line in log_file:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A line such as `[1, 2]` or `3` is valid JSON but not a record.
            if not isinstance(entry, dict):
                continue

            step = entry.get("step")
            if step is None:
                continue

            for name, key in metric_keys:
                value = entry.get(key)
                if value is None:
                    continue
                steps, values = series[name]
                steps.append(step)
                values.append(value)

    return series
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def ema_smooth(values: list, alpha: float = 0.9) -> list:
    """
    Exponential moving average smoothing.

    The first output equals the first input; every later output is
    ``alpha * previous_output + (1 - alpha) * current_input``, so a larger
    alpha gives a smoother (slower-reacting) curve.

    Args:
        values : numbers to smooth; any iterable works (list, tuple,
                 numpy array, generator)
        alpha  : weight given to the previous smoothed value

    Returns:
        A new list with one smoothed value per input value. The input is
        never returned or modified, even when it is empty.
    """
    smoothed = []
    for value in values:
        if smoothed:
            value = alpha * smoothed[-1] + (1 - alpha) * value
        smoothed.append(value)
    return smoothed
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ------------------------------------------------------------------ #
|
| 162 |
+
# PLOTTING
|
| 163 |
+
# ------------------------------------------------------------------ #
|
| 164 |
+
|
| 165 |
+
def make_dashboard(data_dict: dict, run_names: list, save_path: str = None,
                   vram_limit_gb: float = 4.0):
    """
    Renders a multi-panel training dashboard.

    Args:
        data_dict     : dict of run_name -> metrics dict; each metrics dict
                        maps "train", "val", "lr", "tok", "vram" and "grad"
                        to a (steps, values) pair. Entries whose value is
                        None are skipped.
        run_names     : list of run display names. Currently unused: curve
                        labels are taken from the keys of data_dict. Kept so
                        existing callers keep working.
        save_path     : if set, saves figure to this path instead of showing
        vram_limit_gb : height (in GB) of the dotted reference line drawn on
                        the VRAM panel; pass None to omit the line
    """
    fig = plt.figure(figsize=(16, 10), facecolor=DARK_BG)
    fig.suptitle(
        "SLLM Training Dashboard",
        fontsize=16,
        fontweight="bold",
        color=TEXT_COLOR,
        y=0.98,
    )

    # 3x2 grid of panels
    gs = gridspec.GridSpec(3, 2, figure=fig, hspace=0.45, wspace=0.3,
                           left=0.06, right=0.97, top=0.93, bottom=0.06)

    ax_loss = fig.add_subplot(gs[0, 0])
    ax_val = fig.add_subplot(gs[0, 1])
    ax_lr = fig.add_subplot(gs[1, 0])
    ax_tok = fig.add_subplot(gs[1, 1])
    ax_vram = fig.add_subplot(gs[2, 0])
    ax_grad = fig.add_subplot(gs[2, 1])

    colors = [ACCENT_BLUE, ACCENT_GREEN, ACCENT_ORANGE, ACCENT_PURPLE]

    def _annotate_last(ax, xs, ys, color):
        """Writes the final y value next to the last point of a curve."""
        ax.annotate(
            f"{ys[-1]:.4f}",
            xy=(xs[-1], ys[-1]),
            xytext=(5, 0), textcoords="offset points",
            color=color, fontsize=8, va="center",
        )

    def _style(ax, title, xlabel, ylabel, legend=True):
        """Applies title, axis labels, grid and (if any curve is labelled) a legend."""
        ax.set_title(title, fontweight="bold", pad=8)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.tick_params(which="both", length=3)
        if legend and ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8, loc="upper right")

    def _placeholder(ax, title, message):
        """Fills a panel that has no data with a centered, muted message."""
        ax.text(0.5, 0.5, message, ha="center", va="center",
                transform=ax.transAxes, color=MUTED_COLOR, fontsize=11)
        ax.set_title(title, fontweight="bold", pad=8)

    has_val = False
    has_vram = False
    has_grad = False

    for idx, (run_name, data) in enumerate(data_dict.items()):
        if data is None:
            continue
        # Colors cycle when there are more runs than palette entries.
        color = colors[idx % len(colors)]

        # --- Train loss ------------------------------------------ #
        steps, loss = data["train"]
        if steps:
            smoothed = ema_smooth(loss, alpha=0.92)
            # Faint raw curve underneath, bold smoothed curve on top.
            ax_loss.plot(steps, loss, color=color, alpha=0.25, linewidth=0.8)
            ax_loss.plot(steps, smoothed, color=color, alpha=1.0, linewidth=1.8,
                         label=run_name)
            _annotate_last(ax_loss, steps, smoothed, color)

        # --- Val loss -------------------------------------------- #
        vsteps, vloss = data["val"]
        if vsteps:
            has_val = True
            ax_val.plot(vsteps, vloss, color=color, linewidth=2, marker="o",
                        markersize=4, label=run_name)
            _annotate_last(ax_val, vsteps, vloss, color)

        # --- LR -------------------------------------------------- #
        lsteps, lvals = data["lr"]
        if lsteps:
            ax_lr.plot(lsteps, lvals, color=color, linewidth=1.5, label=run_name)

        # --- Throughput ------------------------------------------ #
        tsteps, tvals = data["tok"]
        if tsteps:
            avg_tok = np.mean(tvals)
            ax_tok.plot(tsteps, tvals, color=color, alpha=0.6, linewidth=1.0)
            ax_tok.axhline(avg_tok, color=color, linewidth=1.5, linestyle="--",
                           label=f"{run_name} (avg {avg_tok:.0f})")

        # --- VRAM ------------------------------------------------- #
        vsteps2, vvals = data["vram"]
        if vsteps2:
            has_vram = True
            ax_vram.plot(vsteps2, vvals, color=color, linewidth=1.5, label=run_name)

        # --- Grad norm ------------------------------------------- #
        gsteps, gvals = data["grad"]
        if gsteps:
            has_grad = True
            smoothed_g = ema_smooth(gvals, alpha=0.85)
            ax_grad.plot(gsteps, gvals, color=color, alpha=0.2, linewidth=0.8)
            ax_grad.plot(gsteps, smoothed_g, color=color, linewidth=1.5, label=run_name)

    # --- Style panels -------------------------------------------- #
    _style(ax_loss, "Training Loss (EMA smoothed)", "Step", "Loss")
    _style(ax_lr, "Learning Rate Schedule", "Step", "LR")
    _style(ax_tok, "Throughput", "Step", "Tokens / sec")

    if has_val:
        _style(ax_val, "Validation Loss", "Step", "Val Loss")
    else:
        _placeholder(ax_val, "Validation Loss", "No validation data yet")

    if has_vram:
        _style(ax_vram, "VRAM Usage", "Step", "GB")
        if vram_limit_gb is not None:
            ax_vram.axhline(vram_limit_gb, color=ACCENT_RED, linewidth=1,
                            linestyle=":", alpha=0.6,
                            label=f"{vram_limit_gb:g} GB limit")
            # Rebuild the legend so it includes the reference line.
            ax_vram.legend(fontsize=8)
    else:
        _placeholder(ax_vram, "VRAM Usage", "No VRAM data\n(requires CUDA)")

    if has_grad:
        _style(ax_grad, "Gradient Norm (EMA smoothed)", "Step", "Norm")
    else:
        _placeholder(ax_grad, "Gradient Norm", "No gradient norm data")

    # LR scientific notation
    ax_lr.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    ax_lr.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))

    if save_path:
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=DARK_BG)
        print(f"[PLOT] Saved to {save_path}")
    else:
        plt.show()
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# ------------------------------------------------------------------ #
|
| 310 |
+
# CLI
|
| 311 |
+
# ------------------------------------------------------------------ #
|
| 312 |
+
|
| 313 |
+
def parse_args(argv=None):
    """
    Parses the dashboard's command-line options.

    Args:
        argv : list of argument strings to parse; None (the default) parses
               the process command line

    Returns:
        argparse.Namespace with run_dir (list of str), live (bool),
        interval (positive int, seconds) and save (str or None).
    """
    def _positive_int(text):
        # A zero or negative refresh interval would turn live mode into a
        # busy loop or make the pause call fail, so reject it up front.
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}")
        if value <= 0:
            raise argparse.ArgumentTypeError(f"must be greater than 0, got {value}")
        return value

    p = argparse.ArgumentParser(description="SLLM Training Dashboard")
    p.add_argument("--run_dir", nargs="+", default=["runs/run_001"],
                   help="One or more run directories to plot")
    p.add_argument("--live", action="store_true",
                   help="Refresh plot every --interval seconds (live mode)")
    p.add_argument("--interval", type=_positive_int, default=10,
                   help="Refresh interval in seconds for --live mode")
    p.add_argument("--save", type=str, default=None,
                   help="Save plot to this path instead of showing interactively")
    return p.parse_args(argv)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def main():
    """
    Entry point: loads the log of every requested run and draws the dashboard.

    Without --live the dashboard is drawn once. With --live the logs are
    re-read and the dashboard redrawn every --interval seconds until the
    user presses Ctrl+C.
    """
    args = parse_args()

    run_dirs = args.run_dir
    run_names = [Path(d).name or str(d) for d in run_dirs]
    if len(set(run_names)) != len(run_names):
        # Two directories share a basename (e.g. a/run_001 and b/run_001).
        # Names are used as dict keys below, so one run would silently
        # replace the other; fall back to the full paths as display names.
        run_names = [str(d) for d in run_dirs]

    def _reload_and_plot():
        """Re-reads every log and redraws the dashboard if any run has data."""
        data_dict = {}
        for name, run_dir in zip(run_names, run_dirs):
            log_path = os.path.join(run_dir, "train_log.jsonl")
            data = load_log(log_path)
            if data is None:
                print(f"[WARN] No log found at: {log_path}")
            data_dict[name] = data

        # Number of logged training steps per run that has a log.
        steps_info = {n: len(d["train"][0]) for n, d in data_dict.items() if d}

        if sum(steps_info.values()) == 0:
            print("[PLOT] No data logged yet. Waiting...")
            return

        print(f"[PLOT] Plotting {steps_info} train steps")

        # Drop the previous refresh's figure before drawing a new one.
        plt.close("all")
        make_dashboard(data_dict, run_names, save_path=args.save)

    if not args.live:
        _reload_and_plot()
        return

    print(f"[LIVE] Refreshing every {args.interval}s (Ctrl+C to stop)")
    if sys.platform == "win32":
        matplotlib.use("TkAgg")
    # In interactive mode plt.show() returns immediately; otherwise the first
    # show() would block until the window is closed and the loop below would
    # never get to refresh the plot.
    plt.ion()
    try:
        while True:
            _reload_and_plot()
            plt.pause(args.interval)
    except KeyboardInterrupt:
        print("\n[LIVE] Stopped.")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
if __name__ == "__main__":
|
| 370 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# requirements.txt — SLLM project
|
| 2 |
+
# Install into the 'pytorch' conda env:
|
| 3 |
+
# conda run -n pytorch pip install -r requirements.txt
|
| 4 |
+
|
| 5 |
+
# Core ML
|
| 6 |
+
torch>=2.3.0
|
| 7 |
+
torchvision
|
| 8 |
+
|
| 9 |
+
# Data
|
| 10 |
+
datasets>=2.14.0 # HuggingFace datasets (streaming)
|
| 11 |
+
tokenizers>=0.15.0 # fast BPE tokenizer
|
| 12 |
+
transformers>=4.40.0 # PreTrainedTokenizerFast
|
| 13 |
+
|
| 14 |
+
# Utilities
|
| 15 |
+
numpy>=1.26.0
|
| 16 |
+
tqdm
|
| 17 |
+
matplotlib # training plots
|
| 18 |
+
rich # pretty terminal output (optional)
|
| 19 |
+
|
| 20 |
+
# Dev
|
run.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Start training (first run):
|
| 2 |
+
|
| 3 |
+
python train.py ^
|
| 4 |
+
--config 150M ^
|
| 5 |
+
--data_dir tokenizer/data ^
|
| 6 |
+
--batch_size 2 ^
|
| 7 |
+
--grad_accum 16 ^
|
| 8 |
+
--grad_checkpoint ^
|
| 9 |
+
--dtype bf16 ^
|
| 10 |
+
--max_steps 5000 ^
|
| 11 |
+
--run_dir runs/sllm_150m ^
|
| 12 |
+
--log_every 10 ^
|
| 13 |
+
--save_every 500 ^
|
| 14 |
+
--val_every 500 ^
|
| 15 |
+
--val_steps 20 ^
|
| 16 |
+
--warmup_steps 200
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
Resume from where you stopped:
|
| 21 |
+
|
| 22 |
+
python train.py --resume --data_dir tokenizer/data --batch_size 2 --grad_accum 16 --grad_checkpoint --dtype bf16 --extra_steps 5000 --run_dir runs/sllm_150m --log_every 10 --save_every 500 --val_every 500 --val_steps 20 --warmup_steps 200
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
Plot while training (in a second terminal):
|
| 27 |
+
conda activate pytorch
|
| 28 |
+
cd c:\geetesh\aimldl\projects\sllm
|
| 29 |
+
python plot_training.py --run_dir runs/sllm_150m --live --interval 30
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
python finetune/prepare_data.py
|
| 33 |
+
python finetune/sft_train.py --base_ckpt runs/sllm_150m/ckpt_0011500.pt --run_dir runs/sllm_150m_chat --max_steps 2500 --batch_size 4 --grad_accum 8 --grad_checkpoint
|
| 34 |
+
python finetune/chat.py --run_dir runs/sllm_150m_chat
|
test_chatmodel.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
test_chatmodel.py — Interactive CLI chat and evaluation for the fine-tuned SLLM chat model.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python test_chatmodel.py --run_dir runs/sllm_150m_chat
|
| 6 |
+
python test_chatmodel.py --run_dir runs/sllm_150m_chat --mode sample
|
| 7 |
+
|
| 8 |
+
In interactive mode:
|
| 9 |
+
Type your message and press Enter.
|
| 10 |
+
Special commands:
|
| 11 |
+
/reset Clear conversation history
|
| 12 |
+
/system <text> Change the system prompt
|
| 13 |
+
/quit Exit the chat
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import argparse
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from torch.amp import autocast
|
| 24 |
+
from transformers import PreTrainedTokenizerFast
|
| 25 |
+
|
| 26 |
+
# Add project root to path
|
| 27 |
+
PROJECT_ROOT = Path(__file__).resolve().parent
|
| 28 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 29 |
+
|
| 30 |
+
from model.config import SLLM_150M
|
| 31 |
+
from model.model import SLLM
|
| 32 |
+
|
| 33 |
+
DEFAULT_SYSTEM = "You are a helpful, concise assistant."
|
| 34 |
+
DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ------------------------------------------------------------------ #
|
| 38 |
+
# HELPERS
|
| 39 |
+
# ------------------------------------------------------------------ #
|
| 40 |
+
|
| 41 |
+
def find_latest_ckpt(run_dir: str) -> str:
    """
    Returns path to the most recent SFT or base checkpoint in run_dir.

    Candidates are files named ``ckpt_*.pt`` (which includes
    ``ckpt_sft_*.pt``). They are ordered by the text before the trailing
    number and then by that number compared numerically, so ``ckpt_10.pt``
    ranks above ``ckpt_9.pt`` and any ``ckpt_sft_<n>.pt`` ranks above every
    ``ckpt_<n>.pt``.

    Raises:
        FileNotFoundError: if run_dir is not a directory or holds no
            matching checkpoint file.
    """
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run directory '{run_dir}' does not exist.")

    def _natural_key(name):
        # Split "ckpt_sft_0002500.pt" into ("ckpt_sft_", 2500). Comparing the
        # step as an int keeps the order correct when the numbers are not
        # zero-padded to the same width.
        stem = name[: -len(".pt")]
        prefix = stem.rstrip("0123456789")
        digits = stem[len(prefix):]
        return (prefix, int(digits) if digits else -1, name)

    ckpts = [
        f for f in os.listdir(run_dir)
        if f.startswith("ckpt_") and f.endswith(".pt")
    ]
    if not ckpts:
        raise FileNotFoundError(
            f"No checkpoints found in '{run_dir}'.\n"
            f"Please ensure you have trained the model or point to the correct folder."
        )
    return os.path.join(run_dir, max(ckpts, key=_natural_key))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """
    Resizes the token embeddings matrix to support added special tokens.

    Rows that exist in both the old and the new vocabulary are copied over.
    When the vocabulary grows, every added row is set to the mean of the old
    embedding rows. When it shrinks, the trailing rows are dropped. The
    output head is re-tied to the new embedding weight and
    ``model.config.vocab_size`` is updated. Nothing happens when the size is
    unchanged.

    Args:
        model          : model whose ``token_emb``, ``lm_head`` and
                         ``config`` are updated in place
        new_vocab_size : number of rows the embedding should have afterwards

    Raises:
        ValueError: if new_vocab_size is not positive.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return
    if new_vocab_size <= 0:
        raise ValueError(f"new_vocab_size must be positive, got {new_vocab_size}")

    d_model = model.config.d_model
    old_weight = model.token_emb.weight.detach()
    kept = min(old_size, new_vocab_size)

    new_emb = nn.Embedding(new_vocab_size, d_model).to(
        device=old_weight.device, dtype=old_weight.dtype
    )
    with torch.no_grad():
        new_emb.weight[:kept] = old_weight[:kept]
        if new_vocab_size > old_size:
            # Start added tokens at the mean of the existing embeddings.
            new_emb.weight[old_size:] = old_weight.mean(dim=0)

    model.token_emb = new_emb
    # Keep the output projection tied to the (new) embedding matrix.
    model.lm_head.weight = model.token_emb.weight
    model.config.vocab_size = new_vocab_size
    print(f" [INFO] Resized model vocab embedding from {old_size:,} to {new_vocab_size:,}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def load_model_and_tokenizer(run_dir: str, device: torch.device):
    """
    Loads tokenizer and the latest model checkpoint.

    The tokenizer is read from ``finetune/data`` when that folder contains a
    ``tokenizer.json``; otherwise the base tokenizer under
    ``tokenizer/fineweb_edu_tokenizer`` is loaded and the two ChatML tokens
    are added to it. The checkpoint is the latest one in ``run_dir``, or the
    latest one in ``runs/sllm_150m`` when ``run_dir`` yields none.

    Args:
        run_dir : directory searched first for checkpoints
        device  : device the model is moved to and the checkpoint mapped onto

    Returns:
        (model, tokenizer, ckpt_path, step, loss) where step and loss come
        from the checkpoint, defaulting to "?" and NaN when absent.

    Raises:
        FileNotFoundError: if neither tokenizer location exists, or no
            checkpoint is found in either run directory.
    """
    # ---- Tokenizer ------------------------------------------------- #
    extended_dir = PROJECT_ROOT / "finetune" / "data"
    fallback_dir = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"

    use_extended = os.path.exists(extended_dir / "tokenizer.json")
    if not use_extended and not os.path.exists(fallback_dir):
        raise FileNotFoundError("Could not find a tokenizer directory.")

    tok_path = str(extended_dir if use_extended else fallback_dir)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
    if use_extended:
        print(f" Tokenizer: Loaded extended tokenizer from '{tok_path}'")
    else:
        # The base tokenizer has no chat markers, so register them here.
        chatml_tokens = ["<|im_start|>", "<|im_end|>"]
        tokenizer.add_special_tokens({"additional_special_tokens": chatml_tokens})
        print(f" Tokenizer: Loaded base tokenizer from '{tok_path}' and added ChatML tokens")

    # ---- Checkpoint ------------------------------------------------ #
    try:
        ckpt_path = find_latest_ckpt(run_dir)
    except FileNotFoundError:
        # Fall back to base pretraining checkpoint if SFT directory is empty
        print(f" [WARN] No checkpoint found in '{run_dir}'. Trying pretraining base run...")
        pretrain_dir = PROJECT_ROOT / "runs" / "sllm_150m"
        ckpt_path = find_latest_ckpt(str(pretrain_dir))

    print(f" Loading checkpoint: {ckpt_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # ---- Model ----------------------------------------------------- #
    model = SLLM(SLLM_150M).to(device)
    # Match the embedding size to the checkpoint before loading its weights;
    # when the checkpoint does not record one, use the tokenizer's size.
    vocab_in_ckpt = checkpoint.get("vocab_size", len(tokenizer))
    resize_token_embeddings(model, vocab_in_ckpt)

    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return (
        model,
        tokenizer,
        ckpt_path,
        checkpoint.get("step", "?"),
        checkpoint.get("loss", float("nan")),
    )
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ------------------------------------------------------------------ #
|
| 128 |
+
# PROMPT BUILDING
|
| 129 |
+
# ------------------------------------------------------------------ #
|
| 130 |
+
|
| 131 |
+
def build_prompt(history: list[dict], system_prompt: str,
                 tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
    """Render a conversation as ChatML text and tokenize it.

    The prompt consists of one system block, one block per entry of
    ``history`` (each entry supplies ``"role"`` and ``"content"``), and an
    opened assistant block so the model continues as the assistant.

    Returns:
        A ``torch.long`` tensor holding the token ids in a single row.
    """
    segments = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    segments.extend(
        f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
        for turn in history
    )
    # Leave the assistant block open: the model writes its body.
    segments.append("<|im_start|>assistant\n")

    # The ChatML markers are already spelled out in the text, so the
    # tokenizer must not add special tokens of its own.
    token_ids = tokenizer.encode("".join(segments), add_special_tokens=False)
    return torch.tensor([token_ids], dtype=torch.long)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ------------------------------------------------------------------ #
|
| 145 |
+
# GENERATION
|
| 146 |
+
# ------------------------------------------------------------------ #
|
| 147 |
+
|
| 148 |
+
@torch.no_grad()
def generate_response(
    model: SLLM,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizerFast,
    max_new_tokens: int = 200,
    temperature: float = 0.7,
    top_k: int = 40,
    top_p: float = 0.9,
    device: torch.device = None,
    dtype_torch: torch.dtype = torch.float32,
    use_amp: bool = False,
) -> str:
    """Sample an assistant reply for an already-tokenized prompt.

    Tokens are produced one at a time until ``max_new_tokens`` is reached or
    the model emits ``<|im_end|>`` or the tokenizer's EOS token. The stop
    token itself is not part of the returned text.

    Args:
        model: Called as ``model(ids)`` and expected to return a
            ``(logits, _)`` pair; ``model.config.context_length`` bounds the
            context fed to it.
        input_ids: Prompt token ids with a leading batch dimension of 1.
        tokenizer: Supplies the stop-token ids and decodes the reply.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature; any value ``<= 0`` selects greedy
            decoding.
        top_k: Keep only the ``top_k`` highest logits (``0``/``None`` disables).
        top_p: Nucleus threshold (``>= 1.0`` disables).
        device: Device to run on. Defaults to the device of ``input_ids``.
        dtype_torch: Autocast dtype, used only when ``use_amp`` is true.
        use_amp: Whether to run the forward pass under autocast.

    Returns:
        The decoded reply with special tokens removed and whitespace stripped.
    """
    # The signature defaults to None, but `device.type` is needed for
    # autocast below, so resolve a concrete device up front.
    device = input_ids.device if device is None else torch.device(device)

    stop_ids = (
        tokenizer.convert_tokens_to_ids("<|im_end|>"),
        tokenizer.eos_token_id,
    )
    ctx_len = model.config.context_length

    ids = input_ids.to(device)
    generated = []

    for _ in range(max_new_tokens):
        # Crop context to the model window (no-op while the sequence fits).
        ctx = ids[:, -ctx_len:]

        with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
            logits, _ = model(ctx)  # (1, T, V)

        # Sample in float32: the filtering below runs outside autocast, and a
        # cumulative sum over the whole vocabulary is unreliable in bf16/fp16.
        logits = logits[:, -1, :].float()

        if temperature <= 0.0:
            # Greedy
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / max(temperature, 1e-8)

            # Top-k filtering
            if top_k and top_k > 0:
                kth_best, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < kth_best[:, [-1]]] = float("-inf")

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(sorted_probs, dim=-1)
                # Drop a token once the mass *before* it already exceeds top_p.
                drop = cumprobs - sorted_probs > top_p
                # Always keep the most likely token so the distribution can
                # never become empty (e.g. for top_p <= 0).
                drop[:, 0] = False
                sorted_logits[drop] = float("-inf")
                logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

        tok_id = next_token.item()

        # Stop if end of message or end of stream token is generated
        if tok_id in stop_ids:
            break

        generated.append(tok_id)
        ids = torch.cat([ids, next_token], dim=1)

    return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# ------------------------------------------------------------------ #
|
| 213 |
+
# MODES
|
| 214 |
+
# ------------------------------------------------------------------ #
|
| 215 |
+
|
| 216 |
+
def run_interactive(model, tokenizer, device, dtype_torch, use_amp, args):
    """Run a terminal chat loop against the model.

    Reads user input until EOF, Ctrl-C or a quit command. The conversation is
    kept in ``history`` and re-rendered into a ChatML prompt each turn.

    Commands:
        /quit, /exit, quit, exit  leave the loop
        /reset                    clear the conversation history
        /system <prompt>          replace the system prompt and clear history
        /system                   show the current system prompt

    Args:
        model: Model passed through to ``generate_response``.
        tokenizer: Tokenizer used for prompt building and decoding.
        device, dtype_torch, use_amp: Forwarded to ``generate_response``.
        args: Parsed CLI arguments; reads ``system``, ``max_new_tokens``,
            ``temperature``, ``top_k`` and ``top_p``.
    """
    system_prompt = args.system
    history = []

    print("\n" + "=" * 60)
    print(" CHAT MODE (Interactive)")
    print("=" * 60)
    print(f" System prompt : {system_prompt}")
    print(" Commands : /reset to clear memory | /system <prompt> | /quit to exit")
    print("─" * 60 + "\n")

    # Prompt size that still leaves room for the reply plus a small margin.
    prompt_budget = model.config.context_length - args.max_new_tokens - 10

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not user_input:
            continue

        # Check for commands
        command = user_input.lower()
        if command in ("/quit", "/exit", "quit", "exit"):
            print("Bye!")
            break

        if command == "/reset":
            history = []
            print(" [Conversation history reset]\n")
            continue

        # The input is stripped, so a bare "/system" never carries a trailing
        # space; match it explicitly instead of sending it to the model.
        if command == "/system" or command.startswith("/system "):
            new_sys = user_input[len("/system"):].strip()
            if new_sys:
                system_prompt = new_sys
                history = []
                print(" [System prompt updated. History cleared.]\n")
            else:
                print(f" [Current system prompt: {system_prompt}]\n")
            continue

        # Add to history and build ChatML prompt
        history.append({"role": "user", "content": user_input})
        input_ids = build_prompt(history, system_prompt, tokenizer)

        # Trim the conversation while it exceeds the budget. Turns are dropped
        # in user+assistant pairs so the history keeps alternating roles; the
        # newest user turn is never dropped.
        while input_ids.shape[1] > prompt_budget and len(history) > 2:
            history = history[2:]
            input_ids = build_prompt(history, system_prompt, tokenizer)

        print("SLLM: ", end="", flush=True)
        response = generate_response(
            model, input_ids, tokenizer,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            device=device,
            dtype_torch=dtype_torch,
            use_amp=use_amp,
        )
        print(response + "\n")
        history.append({"role": "assistant", "content": response})
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def run_sample(model, tokenizer, device, dtype_torch, use_amp, args):
    """Generate one reply for each built-in sample prompt and print it.

    Every prompt is treated as a fresh single-turn conversation under
    ``args.system``; sampling settings are taken from ``args``.
    """
    sample_prompts = [
        "Hello! Who are you?",
        "What is the capital of France?",
        "Write a quick, 3-line poem about a small robot learning to speak.",
        "Explain gravity in one simple sentence.",
    ]
    heavy_rule = "=" * 60
    light_rule = "─" * 60

    print("\n" + heavy_rule)
    print(" SAMPLE EVALUATION MODE")
    print(heavy_rule)
    print(f" System prompt: {args.system}")
    print(light_rule)

    # Identical for every prompt, so assemble the keyword arguments once.
    generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "device": device,
        "dtype_torch": dtype_torch,
        "use_amp": use_amp,
    }

    for prompt in sample_prompts:
        print(f"\n[PROMPT] : {prompt}")
        conversation = [{"role": "user", "content": prompt}]
        input_ids = build_prompt(conversation, args.system, tokenizer)

        print("[SLLM] : ", end="", flush=True)
        response = generate_response(model, input_ids, tokenizer, **generation_kwargs)
        print(response)

    # NOTE(review): indentation was lost in the reviewed copy of this file;
    # the closing rule is assumed to be printed once after all prompts rather
    # than after each one — confirm against the original.
    print("\n" + light_rule + "\n")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# ------------------------------------------------------------------ #
|
| 317 |
+
# MAIN
|
| 318 |
+
# ------------------------------------------------------------------ #
|
| 319 |
+
|
| 320 |
+
def main():
    """Parse CLI options, load the chat model and run the selected mode.

    Loading errors are reported on stdout and end the program without a
    traceback. The printed dtype is the precision actually used, which falls
    back to fp32 when the requested one is unavailable on the device.
    """
    p = argparse.ArgumentParser(description="SLLM Chat Checker")
    p.add_argument("--run_dir", type=str, default=DEFAULT_RUN_DIR)
    p.add_argument("--mode", type=str, default="interactive", choices=["interactive", "sample"])
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top_k", type=int, default=40)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--max_new_tokens", type=int, default=200)
    p.add_argument("--system", type=str, default=DEFAULT_SYSTEM)
    p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")

    # Precision setup: reduced precision is only enabled on CUDA (and, for
    # bf16, only when the GPU supports it); everything else runs in fp32.
    use_amp = False
    dtype_torch = torch.float32
    effective_dtype = "fp32"
    if device.type == "cuda":
        if args.dtype == "bf16" and torch.cuda.is_bf16_supported():
            dtype_torch, use_amp, effective_dtype = torch.bfloat16, True, "bf16"
        elif args.dtype == "fp16":
            dtype_torch, use_amp, effective_dtype = torch.float16, True, "fp16"
    if effective_dtype == args.dtype:
        print(f"dtype : {args.dtype}")
    else:
        # Make the silent fallback visible instead of echoing the request.
        print(f"dtype : {effective_dtype} (requested {args.dtype} is not available on {device.type})")

    # Load Model and Tokenizer. Only the load itself is guarded, so that a
    # problem while printing metadata is not misreported as a load failure.
    try:
        model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
    except Exception as e:  # top-level boundary: report and exit cleanly
        print(f"\n[ERROR] Failed to load chat model: {e}")
        return

    print(f" Step : {step}")
    try:
        loss_value = float(loss)
    except (TypeError, ValueError):
        loss_value = float("nan")
    if not torch.isnan(torch.tensor(loss_value)):
        print(f" Loss : {loss_value:.4f}")

    if args.mode == "interactive":
        run_interactive(model, tokenizer, device, dtype_torch, use_amp, args)
    elif args.mode == "sample":
        run_sample(model, tokenizer, device, dtype_torch, use_amp, args)


if __name__ == "__main__":
    main()
|
test_checkpoint.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
test_checkpoint.py — Load a checkpoint and run inference / inspect it.
|
| 3 |
+
|
| 4 |
+
QUICK START: Edit the variables in the CONFIG section below, then run:
|
| 5 |
+
python test_checkpoint.py
|
| 6 |
+
|
| 7 |
+
Modes:
|
| 8 |
+
INTERACTIVE — Chat loop: type prompts, model responds.
|
| 9 |
+
SAMPLE — Auto-generate N samples from fixed prompts and exit.
|
| 10 |
+
INSPECT — Just print checkpoint info (no generation).
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import torch
|
| 16 |
+
from torch.amp import autocast
|
| 17 |
+
|
| 18 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 19 |
+
from model.config import SLLM_100M, SLLM_150M, ModelConfig
|
| 20 |
+
from model.model import SLLM
|
| 21 |
+
|
| 22 |
+
# ================================================================== #
# ✏️ EDIT THESE VARIABLES
# ================================================================== #

# --- Checkpoint to load -------------------------------------------
# RUN_DIR is the folder holding the checkpoints; CKPT_FILE optionally
# names one file inside it.
# Examples:
# RUN_DIR = "runs/sllm_150m" # loads the latest ckpt_*.pt in this folder
# CKPT_FILE = None # set to a specific filename to override
# CKPT_FILE = "ckpt_0002000.pt" # or pick a specific step
RUN_DIR = "runs/sllm_150m"
CKPT_FILE = None # None = auto-pick latest checkpoint in RUN_DIR

# --- Model config --------------------------------------------------
# Must match what you trained with: "100M" or "150M".
# A "config_name" stored in the checkpoint takes precedence over this value.
CONFIG = "150M"

# --- Generation settings ------------------------------------------
MAX_NEW_TOKENS = 100 # tokens to generate per prompt
TEMPERATURE = 0.8 # 0.0 = greedy, 1.0 = sample from the unscaled distribution, 0.8 = balanced
TOP_K = 50 # keep only top-k logits (0 = disabled)
TOP_P = 0.95 # nucleus sampling threshold (1.0 = disabled)

# --- Mode ---------------------------------------------------------
# "interactive" : chat loop in the terminal
# "sample" : run SAMPLE_PROMPTS list and exit
# "inspect" : just print checkpoint metadata, no generation
MODE = "sample"

# --- Prompts for SAMPLE mode --------------------------------------
SAMPLE_PROMPTS = [
    "Once upon a time",
    "The meaning of life is",
    "In the year 2050,",
]

# --- dtype --------------------------------------------------------
# "bf16" (recommended on RTX cards), "fp16", or "fp32".
# bf16/fp16 are only used on CUDA; otherwise the script runs in fp32.
DTYPE = "bf16"
|
| 61 |
+
|
| 62 |
+
# ================================================================== #
|
| 63 |
+
# INTERNALS (no need to edit below)
|
| 64 |
+
# ================================================================== #
|
| 65 |
+
|
| 66 |
+
def resolve_checkpoint(run_dir: str, ckpt_file) -> str:
    """Return the full path of the checkpoint to load.

    Args:
        run_dir: Directory that holds the checkpoints.
        ckpt_file: File name inside ``run_dir``, or ``None`` to pick the
            latest ``ckpt_*.pt`` automatically.

    Returns:
        Path of the selected checkpoint file.

    Raises:
        FileNotFoundError: If the named file, the directory, or any
            ``ckpt_*.pt`` file cannot be found.
    """
    if ckpt_file is not None:
        path = os.path.join(run_dir, ckpt_file)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path

    # Auto-pick latest
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run directory not found: {run_dir}")

    prefix, suffix = "ckpt_", ".pt"

    def _recency_key(name):
        # Compare step numbers numerically: a plain string sort would rank
        # "ckpt_900.pt" above "ckpt_1000.pt". Names without a numeric step
        # keep their previous position after all numbered checkpoints.
        stem = name[len(prefix):-len(suffix)]
        if stem.isascii() and stem.isdigit():
            return (0, int(stem), name)
        return (1, 0, name)

    ckpts = [
        f for f in os.listdir(run_dir)
        if f.startswith(prefix) and f.endswith(suffix)
    ]
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found in: {run_dir}")
    return os.path.join(run_dir, max(ckpts, key=_recency_key))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def load_model(ckpt_path: str, config_name: str, device, dtype_torch):
    """Build the model described by a checkpoint and load its weights.

    Args:
        ckpt_path: Path to the ``.pt`` checkpoint file.
        config_name: Fallback config key ("100M" or "150M"), used when the
            checkpoint does not store a ``config_name`` of its own.
        device: Device the checkpoint is mapped to and the model is moved to.
        dtype_torch: Accepted for call compatibility; not used here.

    Returns:
        ``(model, cfg, step, loss)``: the model in eval mode, the config it
        was built from, and the checkpoint's ``step`` / ``loss`` entries
        (``"?"`` / NaN when absent).

    Raises:
        KeyError: If the resolved config name is unknown, or the checkpoint
            has no ``model_state_dict`` entry.
    """
    cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Prefer config_name stored in checkpoint (override CONFIG if available)
    ckpt_cfg_name = ckpt.get("config_name", config_name)
    if ckpt_cfg_name != config_name:
        print(f" [WARN] Checkpoint config_name='{ckpt_cfg_name}' "
              f"differs from CONFIG='{config_name}'. "
              f"Using checkpoint's config: '{ckpt_cfg_name}'")
    if ckpt_cfg_name not in cfg_map:
        raise KeyError(
            f"Unknown model config '{ckpt_cfg_name}'; expected one of {sorted(cfg_map)}"
        )
    cfg = cfg_map[ckpt_cfg_name]

    # Build the model once, after the effective config is known, instead of
    # constructing it up front and rebuilding it on a mismatch.
    print(f"\n Config : {cfg}")
    model = SLLM(cfg).to(device)

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", float("nan"))
    return model, cfg, step, loss
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@torch.no_grad()
def generate(model, prompt_ids: list[int], cfg: ModelConfig, device,
             dtype_torch, use_amp: bool,
             max_new_tokens: int, temperature: float,
             top_k: int, top_p: float) -> list[int]:
    """Generate a continuation token by token.

    Args:
        model: Called as ``model(ids)`` and expected to return a
            ``(logits, _)`` pair.
        prompt_ids: Token ids of the prompt; must not be empty.
        cfg: Model config; ``cfg.context_length`` bounds the context window.
        device: Torch device to run on.
        dtype_torch: Autocast dtype, used only when ``use_amp`` is true.
        use_amp: Whether to run the forward pass under autocast.
        max_new_tokens: Number of tokens to generate.
        temperature: Softmax temperature; any value ``<= 0`` selects greedy
            decoding.
        top_k: Keep only the ``top_k`` highest logits (``0``/``None`` disables).
        top_p: Nucleus threshold (``>= 1.0`` disables).

    Returns:
        The prompt ids followed by the generated ids.

    Raises:
        ValueError: If ``prompt_ids`` is empty.
    """
    if not prompt_ids:
        # An empty context has no last position to read logits from.
        raise ValueError("prompt_ids must contain at least one token")

    ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    ctx_len = cfg.context_length

    for _ in range(max_new_tokens):
        # Crop to context window
        ids_crop = ids[:, -ctx_len:]

        with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
            logits, _ = model(ids_crop)

        # Logits for the last position, in float32: the filtering below runs
        # outside autocast, and a cumulative sum over the whole vocabulary is
        # unreliable in bf16/fp16.
        logits = logits[:, -1, :].float()  # (1, vocab)

        if temperature <= 0.0:
            # Greedy. Also covers negative temperatures, which would
            # otherwise flip the distribution toward the least likely tokens.
            next_id = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temperature

            # Top-K filtering
            if top_k and top_k > 0:
                vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < vals[:, [-1]]] = float("-inf")

            # Top-P (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(sorted_probs, dim=-1)
                # Drop a token once the mass *before* it already exceeds top_p.
                drop = cumprobs - sorted_probs > top_p
                # Always keep the most likely token so the distribution can
                # never become empty (e.g. for top_p <= 0).
                drop[:, 0] = False
                sorted_logits[drop] = float("-inf")
                logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)

            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)

        ids = torch.cat([ids, next_id], dim=1)

    return ids[0].tolist()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def char_tokenize(text: str, vocab_size: int = 32_000) -> list[int]:
    """Fallback character-level tokenizer.

    Your model uses a real tokenizer — swap this out with yours if available.
    Each char maps to its Unicode code point, capped at ``vocab_size - 1`` so
    every id stays inside the embedding table.

    Args:
        text: Text to encode.
        vocab_size: Size of the model vocabulary; defaults to 32,000.

    Returns:
        One id per character of ``text``.
    """
    max_id = vocab_size - 1
    return [min(ord(c), max_id) for c in text]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def char_detokenize(ids: list[int]) -> str:
    """Decode ids produced by the character fallback tokenizer.

    Only printable ASCII (code points 32..126) is restored; every other id,
    including newlines and capped code points, is rendered as "?".
    """
    pieces = []
    for token_id in ids:
        is_printable_ascii = 32 <= token_id < 127
        pieces.append(chr(token_id) if is_printable_ascii else "?")
    return "".join(pieces)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def try_load_sentencepiece(tokenizer_dir="tokenizer/fineweb_edu_tokenizer"):
    """Return ``(encode, decode)`` callables for the training tokenizer.

    Despite its name, this loads a HuggingFace ``PreTrainedTokenizerFast``
    from ``tokenizer_dir``. If that fails for any reason, the reason is
    printed and the character-level fallback pair is returned instead.
    """
    try:
        from transformers import PreTrainedTokenizerFast

        hf_tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)

        def encode(text):
            return hf_tokenizer.encode(text)

        def decode(ids):
            return hf_tokenizer.decode(ids, skip_special_tokens=True)

        print(f" Tokenizer: HuggingFace tokenizer loaded from '{tokenizer_dir}'")
        print(f" vocab_size={hf_tokenizer.vocab_size:,} eos_id={hf_tokenizer.eos_token_id}")
        return encode, decode
    except Exception as e:
        # Deliberate best-effort: any load problem degrades to the fallback.
        print(f" Tokenizer: Could not load HuggingFace tokenizer ({e})")
        print(" Falling back to char tokenizer — output will be garbled!")
        return char_tokenize, char_detokenize
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode):
    """Prompt loop: read a prompt, generate a continuation, print it.

    The loop ends on EOF, Ctrl-C, an empty line, or the words "quit"/"exit".
    Generation settings come from the module-level configuration constants.
    """
    banner = "=" * 60
    print("\n" + banner)
    print(" INTERACTIVE MODE (type 'quit' or 'exit' to stop)")
    print(banner)
    print(f" max_new_tokens : {MAX_NEW_TOKENS}")
    print(f" temperature : {TEMPERATURE}")
    print(f" top_k / top_p : {TOP_K} / {TOP_P}")
    print()

    stop_inputs = ("quit", "exit", "")
    while True:
        try:
            prompt = input("Prompt> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n Exiting.")
            return

        if prompt.lower() in stop_inputs:
            print(" Exiting.")
            return

        prompt_ids = encode(prompt)
        output_ids = generate(
            model, prompt_ids, cfg, device, dtype_torch, use_amp,
            MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P,
        )
        # generate() echoes the prompt ids first; show only what was added.
        continuation = output_ids[len(prompt_ids):]
        print(f"\nGenerated: {decode(continuation)}\n")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode):
    """Generate and print one continuation for each entry of SAMPLE_PROMPTS."""
    banner = "=" * 60
    print("\n" + banner)
    print(" SAMPLE MODE")
    print(banner)

    for number, prompt in enumerate(SAMPLE_PROMPTS, start=1):
        print(f"\n[{number}] Prompt : {prompt!r}")
        prompt_ids = encode(prompt)
        output_ids = generate(
            model, prompt_ids, cfg, device, dtype_torch, use_amp,
            MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P,
        )
        # Skip the echoed prompt so only the continuation is printed.
        continuation = output_ids[len(prompt_ids):]
        print(f" Output : {decode(continuation)}")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def run_inspect(ckpt_path, step, loss, cfg):
    """Print checkpoint metadata without generating anything.

    ``loss`` is shown with four decimals when it is a float and verbatim
    otherwise. The parameter count comes from ``cfg.count_params()``.
    """
    banner = "=" * 60
    if isinstance(loss, float):
        loss_line = f" Loss : {loss:.4f}"
    else:
        loss_line = f" Loss: {loss}"
    million_params = cfg.count_params() / 1e6

    print("\n" + banner)
    print(" INSPECT MODE")
    print(banner)
    print(f" Checkpoint : {ckpt_path}")
    print(f" Step : {step}")
    print(loss_line)
    print(f" Config : {cfg}")
    print(f" Params : {million_params:.1f}M")
    print()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def main():
    """Load the configured checkpoint and run the configured MODE.

    MODE is validated before anything is loaded, so a typo fails fast instead
    of after the model has been built. The printed dtype is the precision
    actually used, which falls back to fp32 when the requested one is
    unavailable on the device.
    """
    if MODE not in ("interactive", "sample", "inspect"):
        print(f" [ERROR] Unknown MODE: '{MODE}'. Use 'interactive', 'sample', or 'inspect'.")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # dtype setup: reduced precision is only enabled on CUDA (and, for bf16,
    # only when the GPU supports it); everything else runs in fp32.
    use_amp = False
    dtype_torch = torch.float32
    effective_dtype = "fp32"
    if device.type == "cuda":
        if DTYPE == "bf16" and torch.cuda.is_bf16_supported():
            dtype_torch, use_amp, effective_dtype = torch.bfloat16, True, "bf16"
        elif DTYPE == "fp16":
            dtype_torch, use_amp, effective_dtype = torch.float16, True, "fp16"
    if effective_dtype == DTYPE:
        print(f"dtype : {DTYPE}")
    else:
        # Make the silent fallback visible instead of echoing the request.
        print(f"dtype : {effective_dtype} (requested {DTYPE} is not available on {device.type})")

    # Resolve checkpoint path
    ckpt_path = resolve_checkpoint(RUN_DIR, CKPT_FILE)
    print(f"\nCheckpoint: {ckpt_path}")

    # Load model
    model, cfg, step, loss = load_model(ckpt_path, CONFIG, device, dtype_torch)
    # The stored loss is not guaranteed to be a float; format defensively,
    # the same way run_inspect does.
    loss_text = f"{loss:.4f}" if isinstance(loss, float) else str(loss)
    print(f" Loaded : step={step}, loss={loss_text}")
    print(f" Params : {model.count_params()/1e6:.1f}M")

    if MODE == "inspect":
        run_inspect(ckpt_path, step, loss, cfg)
        return

    # Load tokenizer
    encode, decode = try_load_sentencepiece()

    if MODE == "interactive":
        run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode)
    elif MODE == "sample":
        run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode)


if __name__ == "__main__":
    main()
|
tokenizer/bpe.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizers import Tokenizer, AddedToken
|
| 2 |
+
from tokenizers.models import BPE
|
| 3 |
+
from tokenizers.trainers import BpeTrainer
|
| 4 |
+
from tokenizers.pre_tokenizers import Sequence, ByteLevel
|
| 5 |
+
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
|
| 6 |
+
|
| 7 |
+
from pretokenizer import get_pretokenizer
|
| 8 |
+
|
| 9 |
+
VOCAB_SIZE = 32_000
|
| 10 |
+
MIN_FREQUENCY = 3
|
| 11 |
+
SPECIAL_TOKENS = ["<|endoftext|>"]
|
| 12 |
+
|
| 13 |
+
def build_tokenizer() -> Tokenizer:
    """
    Builds and returns an untrained tokenizer with all components configured.
    Call .train_from_iterator() or .train() on the returned object to train it.

    Pipeline:
        Raw text
        -> Normalizer (handled externally in our normalize() fn)
        -> Pre-tokenizer (custom regex splits + byte level conversion)
        -> BPE Model (learns merge rules during training)
        -> Decoder (reverses byte level for human readable output)
    """

    # ---- 1. BPE Model ------------------------------------------------
    # unk_token=None because byte-level means we NEVER have unknowns:
    # every character always maps to at least one byte token.
    #
    # NOTE(review): byte_fallback is a separate mechanism that spells unknown
    # characters as <0xXX> tokens (e.g. ∇ -> <0xE2><0x88><0x87>). With the
    # ByteLevel pre-tokenizer below, input is already converted to the
    # 256-symbol byte alphabet before it reaches the model, so this fallback
    # is presumably never triggered — confirm whether it is still needed.
    model = BPE(
        unk_token=None,  # no unknown token
        byte_fallback=True,  # see NOTE(review) above
    )

    tokenizer = Tokenizer(model)

    # ---- 2. Pre-tokenizer --------------------------------------------
    # Sequence chains two pre-tokenizers in order:
    #
    # Step A: Our custom regex splits text into meaningful chunks
    #         (contractions, abbreviations, numbers, operators etc.)
    #
    # Step B: ByteLevel converts each chunk's characters to their
    #         byte representation using a 256-char printable alphabet
    #         e.g. é (bytes 0xC3 0xA9) -> "Ã©"
    #
    # add_prefix_space=False because our regex already handles
    # whitespace explicitly as its own token category
    tokenizer.pre_tokenizer = Sequence([
        get_pretokenizer(),  # Step A - our regex
        ByteLevel(add_prefix_space=False),  # Step B - byte conversion
    ])

    # ---- 3. Decoder --------------------------------------------------
    # Reverses the ByteLevel encoding so output is human readable.
    # Without this tokenizer.decode() would return "Ã©" instead of "é"
    tokenizer.decoder = ByteLevelDecoder()

    return tokenizer
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ------------------------------------------------------------------ #
|
| 63 |
+
# TRAINER CONFIG
|
| 64 |
+
# ------------------------------------------------------------------ #
|
| 65 |
+
|
| 66 |
+
def build_trainer() -> BpeTrainer:
    """
    Creates the BPE trainer used to learn the vocabulary.

    How the vocabulary budget is spent:
        256 base byte tokens (one per possible byte value, always present)
        + 31,743 learned BPE merge tokens
        + 1 special token (<|endoftext|>)
        = 32,000 total

    The base tokens count toward vocab_size, so passing vocab_size=32_000
    yields that total directly.
    """
    trainer_options = {
        "vocab_size": VOCAB_SIZE,
        "min_frequency": MIN_FREQUENCY,
        "special_tokens": SPECIAL_TOKENS,
        # Display a progress bar while training runs.
        "show_progress": True,
        # Seed the vocabulary with all 256 byte symbols before any merge is
        # learned; this is what makes every input representable.
        "initial_alphabet": ByteLevel.alphabet(),
    }
    return BpeTrainer(**trainer_options)
|
| 92 |
+
|
| 93 |
+
# CONVENIENCE: get special token IDs after training
|
| 94 |
+
|
| 95 |
+
def get_special_token_ids(tokenizer: Tokenizer) -> dict:
    """
    Maps every entry of SPECIAL_TOKENS to its id in ``tokenizer``.
    Only meaningful AFTER training, once the ids are final.

    Example:
        ids = get_special_token_ids(tokenizer)
        eot_id = ids["<|endoftext|>"]  # typically 0
    """
    special_ids = {}
    for special in SPECIAL_TOKENS:
        special_ids[special] = tokenizer.token_to_id(special)
    return special_ids
|
| 108 |
+
|
| 109 |
+
# QUICK SANITY CHECK
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
    # Smoke check: build both objects and print how they are configured.
    print("Building tokenizer...")
    tokenizer = build_tokenizer()

    print("Building trainer...")
    trainer = build_trainer()

    # Pre-tokenizer chain (regex split followed by byte-level conversion)
    print("\nPre-tokenizer chain:")
    print(f" {tokenizer.pre_tokenizer}")

    # Decoder
    print("\nDecoder:")
    print(f" {tokenizer.decoder}")

    # Trainer settings
    trainer_summary = [
        f" vocab_size : {trainer.vocab_size}",
        f" min_frequency : {trainer.min_frequency}",
        f" special_tokens: {trainer.special_tokens}",
        f" base alphabet : {len(ByteLevel.alphabet())} byte tokens",
    ]
    print("\nTrainer config:")
    for summary_line in trainer_summary:
        print(summary_line)

    print("\nAll good - ready to train.")
    print("Next step: pipe FineWeb-Edu text into tokenizer.train_from_iterator()")
|
tokenizer/fineweb_edu_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer/fineweb_edu_tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<|endoftext|>",
|
| 3 |
+
"eos_token": "<|endoftext|>",
|
| 4 |
+
"pad_token": "<|endoftext|>"
|
| 5 |
+
}
|
tokenizer/fineweb_edu_tokenizer/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer/fineweb_edu_tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<|endoftext|>",
|
| 4 |
+
"eos_token": "<|endoftext|>",
|
| 5 |
+
"model_max_length": 1024,
|
| 6 |
+
"pad_token": "<|endoftext|>",
|
| 7 |
+
"padding_side": "right",
|
| 8 |
+
"tokenizer_class": "TokenizersBackend",
|
| 9 |
+
"truncation_side": "right",
|
| 10 |
+
"unk_token": null
|
| 11 |
+
}
|
tokenizer/normalizer.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import html
|
| 3 |
+
import unicodedata
|
| 4 |
+
|
| 5 |
+
# Patterns are compiled once at import time: normalization() runs once per
# document, so the hot path should not rebuild or re-look-up patterns.

# Per the HTML syntax, a tag opens with "<" immediately followed by an ASCII
# letter, "/", "!" or "?".  Requiring that keeps plain text such as
# "1 < 2 and 3 > 2" or "a <- b -> c" intact, which the looser r'<[^>]+>'
# silently deleted.  [^>] also matches newlines, so multi-line tags ARE removed.
_HTML_TAG_RE = re.compile(r'<(?:/?[A-Za-z]|[!?])[^>]*>')
# C0 control characters except \t, \n, \r, plus \x7f (DEL)
_CONTROL_CHARS_RE = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')
# Unicode line separator / paragraph separator
_UNICODE_NEWLINES_RE = re.compile(r'[\u2028\u2029]')
# Zero-width space / non-joiner / joiner, BOM, and soft hyphen
_INVISIBLE_CHARS_RE = re.compile(r'[\u200b\u200c\u200d\ufeff\u00ad]')
_SPACE_RUN_RE = re.compile(r' +')
_TRAILING_WS_RE = re.compile(r'[ \t]+\n')
_NEWLINE_RUN_RE = re.compile(r'\n{3,}')


def normalization(text):
    """
    Clean one raw document before tokenization.

    Steps, in order (the order matters: tags are stripped BEFORE entity
    decoding, so escaped markup such as "&lt;b&gt;" survives as the literal
    text "<b>"):

        1. replace HTML tags with a single space
        2. decode HTML entities
        3. apply Unicode NFC normalization
        4. drop control characters (tab, newline, carriage return are kept)
        5. turn U+2028 / U+2029 into "\\n"
        6. drop zero-width characters, BOM and soft hyphens
        7. drop U+FFFD replacement characters
        8. normalize "\\r\\n" and "\\r" to "\\n"
        9. collapse runs of spaces (tabs are left alone)
       10. strip trailing spaces/tabs before each newline
       11. collapse 3+ consecutive newlines to a blank line
       12. strip leading/trailing whitespace of the whole document

    Args:
        text: raw document string

    Returns:
        The normalized string (may be empty).
    """
    # Strip HTML tags (single- or multi-line)
    text = _HTML_TAG_RE.sub(' ', text)

    # HTML entity decoding
    text = html.unescape(text)

    # NFC normalization
    text = unicodedata.normalize('NFC', text)

    # Control characters, including \x7f (DEL)
    text = _CONTROL_CHARS_RE.sub('', text)

    # Unicode line/paragraph separators -> newline (structural, not removed)
    text = _UNICODE_NEWLINES_RE.sub('\n', text)

    # Zero-width / invisible characters
    text = _INVISIBLE_CHARS_RE.sub('', text)

    # Replacement character
    text = text.replace('\ufffd', '')

    # Normalize line endings ("\r\n" first so it does not become "\n\n")
    text = text.replace('\r\n', '\n')
    text = text.replace('\r', '\n')

    # Collapse spaces only (tabs are preserved for indentation)
    text = _SPACE_RUN_RE.sub(' ', text)

    # Trailing spaces/tabs at end of line
    text = _TRAILING_WS_RE.sub('\n', text)

    # Collapse excess newlines
    text = _NEWLINE_RUN_RE.sub('\n\n', text)

    return text.strip()
|
tokenizer/post_processor.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizers.processors import TemplateProcessing
|
| 2 |
+
from tokenizers import Tokenizer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# ------------------------------------------------------------------ #
|
| 6 |
+
# POST-PROCESSOR
|
| 7 |
+
# Runs after BPE encoding, appends <|endoftext|> to every sequence
|
| 8 |
+
# ------------------------------------------------------------------ #
|
| 9 |
+
|
| 10 |
+
def add_post_processor(tokenizer: Tokenizer) -> Tokenizer:
    """
    Attach a TemplateProcessing post-processor that terminates every
    encoded sequence with <|endoftext|>.

    Call this only on a TRAINED tokenizer: the template needs the ID that
    training assigned to <|endoftext|>.

    Resulting layouts:
        single : [tokens...] <|endoftext|>
        pair   : [tokens_A...] <|endoftext|> [tokens_B...] <|endoftext|>

    Args:
        tokenizer: a trained Tokenizer object

    Returns:
        The same tokenizer object, now carrying the post-processor.

    Raises:
        ValueError: if <|endoftext|> is missing from the vocab.
    """
    end_id = tokenizer.token_to_id("<|endoftext|>")

    # Guard: an untrained tokenizer has no ID for the special token yet.
    if end_id is None:
        raise ValueError(
            "<|endoftext|> not found in vocab. "
            "Make sure the tokenizer is trained before adding post-processor."
        )

    # Template syntax: $A / $B are the input sequences, "<tok>:N" inserts a
    # special token with type id N.  The pair form is provided so two-segment
    # inputs (e.g. question + context) work without changing the tokenizer.
    single_template = "$A <|endoftext|>:0"
    pair_template = "$A <|endoftext|>:0 $B:1 <|endoftext|>:0"
    special_tokens = [("<|endoftext|>", end_id)]

    tokenizer.post_processor = TemplateProcessing(
        single=single_template,
        pair=pair_template,
        special_tokens=special_tokens,
    )

    print(f"Post-processor added: <|endoftext|> (ID: {end_id}) appended to sequences")
    return tokenizer
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ------------------------------------------------------------------ #
|
| 63 |
+
# VERIFICATION
|
| 64 |
+
# ------------------------------------------------------------------ #
|
| 65 |
+
|
| 66 |
+
def verify_post_processor(tokenizer: Tokenizer):
    """
    Verifies the post-processor is working correctly.

    Checks that <|endoftext|> is the last token of every single-sequence
    encoding and that a pair encoding contains exactly two <|endoftext|>
    tokens.  A report is printed for every check.

    Args:
        tokenizer: a Tokenizer with the post-processor already attached

    Returns:
        True when every check (single sequences AND the pair check) passed,
        False otherwise.
    """
    eot_token = "<|endoftext|>"
    eot_id = tokenizer.token_to_id(eot_token)

    print("\n" + "="*60)
    print(" POST-PROCESSOR VERIFICATION")
    print("="*60 + "\n")

    test_cases = [
        # Single documents
        "The mitochondria is the powerhouse of the cell.",
        "CO2 levels rose by 1.5e-3 ppm.",
        # Short edge cases
        "Hi.",
        "42",
    ]

    all_passed = True

    for text in test_cases:
        encoded = tokenizer.encode(text)
        last_token = encoded.tokens[-1]
        last_id = encoded.ids[-1]
        passed = last_token == eot_token and last_id == eot_id

        if not passed:
            all_passed = False

        status = "PASS" if passed else "FAIL"
        print(f"[{status}] {repr(text)}")
        print(f" tokens : {encoded.tokens}")
        print(f" last : {last_token!r} (ID: {last_id})")
        print()

    # Verify pair encoding
    encoded_pair = tokenizer.encode("question here", "answer here")
    eot_positions = [
        pos for pos, token_id in enumerate(encoded_pair.ids) if token_id == eot_id
    ]
    pair_passed = len(eot_positions) == 2

    # The pair check counts towards the overall verdict too; previously a
    # failing pair test could still be followed by "All tests passed: True".
    if not pair_passed:
        all_passed = False

    print(f"Pair encoding test:")
    print(f" tokens : {encoded_pair.tokens}")
    print(f" eot positions: {eot_positions}")
    print(f" expected : 2 eot tokens (one after each sequence)")
    print(f" [{'PASS' if pair_passed else 'FAIL'}]")

    print(f"\nAll tests passed: {all_passed}")
    return all_passed
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ------------------------------------------------------------------ #
|
| 120 |
+
# HOW THIS FITS INTO THE FULL PIPELINE
|
| 121 |
+
# ------------------------------------------------------------------ #
|
| 122 |
+
|
| 123 |
+
# The correct order when building your full tokenizer:
|
| 124 |
+
#
|
| 125 |
+
# 1. build_tokenizer() <- sets up model + pre-tokenizer + decoder
|
| 126 |
+
# 2. train_from_iterator() <- trains BPE, assigns real vocab IDs
|
| 127 |
+
# 3. add_post_processor() <- NOW we can add post-processor (needs real IDs)
|
| 128 |
+
# 4. tokenizer.save() <- saves everything including post-processor
|
| 129 |
+
#
|
| 130 |
+
# Loading later:
|
| 131 |
+
# tokenizer = Tokenizer.from_file("fineweb_edu_tokenizer.json")
|
| 132 |
+
# <- post-processor is automatically restored, no extra steps
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
    import sys

    # Usage: python post_processor.py [path/to/tokenizer.json]
    # Without an argument the default file name in the current directory is used.
    cli_args = sys.argv[1:]
    path = cli_args[0] if cli_args else "fineweb_edu_tokenizer.json"

    print(f"Loading tokenizer from: {path}")
    tokenizer = add_post_processor(Tokenizer.from_file(path))
    verify_post_processor(tokenizer)

    # Overwrite the input file so the post-processor is stored with the tokenizer.
    tokenizer.save(path)
    print(f"\nTokenizer re-saved with post-processor to: {path}")
|
tokenizer/pretokenizer.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from tokenizers.pre_tokenizers import PreTokenizer, Split
|
| 3 |
+
from tokenizers import Regex
|
| 4 |
+
|
| 5 |
+
# Each category is defined separately so its easy to understand, modify, or debug individually
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# 1. Contractions
# Matches: 's 't 're 've 'll 'm 'd
# Example: "don't" -> ["don", "'t"]
CONTRACTIONS = r"'(?:s|t|re|ve|ll|m|d)"

# 2. Abbreviations
# Matches: SINGLE letters separated by dots, optional trailing dot
# Example: "U.S.A" -> ["U.S.A"]
#          "e.g."  -> ["e.g."]
# \b = word boundary, ensures we don't partially match inside a word
# NOTE: multi-letter segments are not covered, so "Ph.D" is NOT kept whole;
#       the later rules split it into ["Ph", ".", "D"].
ABBREVIATIONS = r"\b[A-Za-z](?:\.[A-Za-z])+\.?"

# 3. Scientific Notation
# Matches: number, optional decimal, e/E, optional sign, exponent
# Example: "1.5e-3" -> ["1.5e-3"]
#          "3e10"   -> ["3e10"]
#          "2.0E+4" -> ["2.0E+4"]
# Must come BEFORE decimals otherwise "1.5" in "1.5e-3" matches first
SCIENTIFIC = r"\d+\.?\d*[eE][+-]?\d+"

# 4. Decimal Numbers
# Matches: digits, dot, digits
# Example: "3.14"  -> ["3.14"]
#          "0.001" -> ["0.001"]
# Must come BEFORE integers otherwise "3" in "3.14" matches first
DECIMALS = r"\d+\.\d+"

# 5. Integers
# Matches: any sequence of digits
# Example: "42"   -> ["42"]
#          "1984" -> ["1984"]
# Comes last among numbers since scientific and decimal match first
INTEGERS = r"\d+"

# 6. Multi-character Operators
# Matches: common programming operators that are 2 characters
# Example: "==" -> ["=="]   "!=" -> ["!="]
#          "->" -> ["->"]   "+=" -> ["+="]
# Must come BEFORE single punctuation catch-all
# [-+*/]= matches +=, -=, *=, /= in one pattern
OPERATORS = r"==|!=|->|<=|>=|\*\*|//|[-+*/]="

# 7. ASCII Identifiers (snake_case and plain ASCII words)
# Matches: an ASCII letter/underscore followed by ASCII letters, digits, underscores
# Example: "snake_case" -> ["snake_case"]
#          "var_name_2" -> ["var_name_2"]
#          "_private"   -> ["_private"]
# The (?!\w) lookahead makes this rule claim only COMPLETE words. Without it,
# the rule grabbed the ASCII prefix of a non-ASCII word ("café" -> "caf" + "é")
# because it is tried before WORDS; now such words fall through to WORDS intact.
SNAKE_CASE = r"[A-Za-z_][A-Za-z0-9_]*(?!\w)"

# 8. Regular Unicode Words
# Matches: any sequence of word characters (letters, digits, underscore)
# \w+ in unicode mode covers non-english letters too
# Example: "hello" -> ["hello"]
#          "café"  -> ["café"]
WORDS = r"\w+"

# 9. Whitespace
# Newlines are matched separately from spaces/tabs
# This preserves document structure (paragraph breaks etc.)
# Example: "\n\n" -> ["\n\n"]   "  " -> ["  "]
WHITESPACE = r"\n+|[ \t]+"

# 10. Punctuation Catch-all
# Matches any single non-whitespace character that nothing above caught
# Example: "!" -> ["!"]   "@" -> ["@"]   "." -> ["."]
PUNCTUATION = r"[^\s]"

# ------------------------------------------------------------------ #
# Combine all patterns in ORDER - first match wins
# ------------------------------------------------------------------ #

PRETOKENIZER_PATTERN = "|".join([
    CONTRACTIONS,   # 1  - most specific first
    ABBREVIATIONS,  # 2  - before plain words
    SCIENTIFIC,     # 3  - before decimals
    DECIMALS,       # 4  - before integers
    INTEGERS,       # 5
    OPERATORS,      # 6  - before single punctuation
    SNAKE_CASE,     # 7  - before plain words
    WORDS,          # 8
    WHITESPACE,     # 9
    PUNCTUATION,    # 10 - catch everything else
])
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_pretokenizer():
    """
    Build the HuggingFace Split pre-tokenizer driven by PRETOKENIZER_PATTERN.

    Split options used here:
      - pattern  : our combined regex, wrapped in tokenizers.Regex
      - behavior : "isolated" keeps every piece as its own token, so
                   whitespace, operators and punctuation are preserved
                   rather than discarded (as "removed" would do)
      - invert   : True, i.e. the regex describes the tokens themselves
                   rather than the delimiters between them

    Returns:
        A configured tokenizers.pre_tokenizers.Split instance.
    """
    token_regex = Regex(PRETOKENIZER_PATTERN)
    return Split(pattern=token_regex, behavior="isolated", invert=True)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ------------------------------------------------------------------ #
|
| 117 |
+
# Quick test - run this file directly to verify behavior
|
| 118 |
+
# ------------------------------------------------------------------ #
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
    from tokenizers import Tokenizer
    from tokenizers.models import BPE

    # A bare BPE tokenizer is enough here: only the pre-tokenizer is exercised
    tokenizer = Tokenizer(BPE())
    tokenizer.pre_tokenizer = get_pretokenizer()

    # (label, sample text) pairs, one per pattern category
    test_cases = [
        ("Contractions", "don't she'll they've"),
        ("Abbreviations", "U.S.A has a Ph.D e.g. this"),
        ("Scientific", "the value is 1.5e-3 and 2.0E+4"),
        ("Decimals", "pi is 3.14159 and e is 2.718"),
        ("Integers", "there are 1000 students in 2024"),
        ("Operators", "if x==0 or y!=1 then z+=2"),
        ("Snake case", "my_variable and snake_case_name"),
        # Mixed real world
        ("Real world", "The CO2 level is 415.2 ppm\n\nSee e.g. Smith et al."),
        # Code like
        ("Code-like", "def my_func(x):\n return x**2 + 1"),
    ]

    banner = "=" * 60
    print(f"\n{banner}")
    print(f" PRE-TOKENIZER TEST")
    print(f"{banner}\n")

    for label, text in test_cases:
        # pre_tokenize_str yields (string, offsets) tuples; keep the strings
        pieces = tokenizer.pre_tokenizer.pre_tokenize_str(text)
        token_strings = [piece for piece, _offsets in pieces]
        print(f"[{label}]")
        print(f" Input : {repr(text)}")
        print(f" Tokens : {token_strings}")
        print()
|
tokenizer/tempCodeRunnerFile.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# NOTE(review): this looks like an editor-generated scratch snippet rather
# than a real module: `os`, `json`, `save_dir` and `special_tokens_map` are
# not defined in this file, so it cannot run on its own. Presumably it was
# copied out of a larger script; confirm whether it should be deleted or
# git-ignored.

# Write the special-token mapping as pretty-printed JSON into save_dir.
with open(os.path.join(save_dir, "special_tokens_map.json"), "w") as f:
    json.dump(special_tokens_map, f, indent=2)

print("special_tokens_map.json written manually")
|
tokenizer/tokenize_dataset.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenize_dataset.py — Parallel tokenization pipeline
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
Main thread : stream HF dataset → filter → normalize → batch texts
|
| 6 |
+
Worker pool : N_WORKERS processes, each with own loaded tokenizer,
|
| 7 |
+
tokenize batches concurrently using ProcessPoolExecutor
|
| 8 |
+
Main thread : collect results IN ORDER → route train/val → flush shards
|
| 9 |
+
|
| 10 |
+
Why this is faster:
|
| 11 |
+
Old code: stream → [normalize] → [tokenize 1000 docs, 1 CPU] → write
|
| 12 |
+
New code: stream → [normalize] → [tokenize 1000 docs × N cores] → write
|
| 13 |
+
|
| 14 |
+
On 12-core machine: expect 6-10× speedup on tokenization step.
|
| 15 |
+
Bottleneck shifts to HF streaming bandwidth, not CPU.
|
| 16 |
+
|
| 17 |
+
Notes:
|
| 18 |
+
- Workers are initialized ONCE with the tokenizer loaded (no repeated disk reads)
|
| 19 |
+
- Results collected in SUBMISSION ORDER so train/val routing is deterministic
|
| 20 |
+
- Sliding window of MAX_PENDING futures keeps all cores busy without
|
| 21 |
+
unbounded memory growth
|
| 22 |
+
- Ctrl+C safe: flushes remaining buffers before exit
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
import warnings
|
| 29 |
+
import numpy as np
|
| 30 |
+
from collections import deque
|
| 31 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 32 |
+
from datasets import load_dataset
|
| 33 |
+
from transformers import PreTrainedTokenizerFast, logging as hf_logging
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
|
| 36 |
+
# Import normalizer from same directory
|
| 37 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 38 |
+
from normalizer import normalization
|
| 39 |
+
|
| 40 |
+
hf_logging.set_verbosity_error()
|
| 41 |
+
warnings.filterwarnings("ignore")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ------------------------------------------------------------------ #
|
| 45 |
+
# CONSTANTS
|
| 46 |
+
# ------------------------------------------------------------------ #
|
| 47 |
+
|
| 48 |
+
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
DATASET_SUBSET = "CC-MAIN-2014-49"
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
TOKENIZER_DIR = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")
DATA_DIR = os.path.join(SCRIPT_DIR, "data")

MIN_QUALITY = 3                 # minimum int_score a doc needs to be kept
SHARD_SIZE = 100_000_000        # tokens per shard (~190 MB at uint16)
BATCH_SIZE = 2_000              # docs per tokenization task (↑ from 1000)
VAL_RATIO = 100                 # every 100th accepted doc → val
SHUFFLE_BUFFER = 10_000         # buffer size of the streaming shuffle
MIN_DOC_LENGTH = 100            # minimum doc length in characters
DTYPE = np.uint16               # on-disk token dtype; holds IDs up to 65,535
MAX_TOKENS = 3_200_000_000      # stop streaming once this many tokens are collected

# Parallel workers: leave 2 cores for OS + HF streaming.
# os.cpu_count() may return None when the count cannot be determined,
# so fall back to 1 instead of failing with a TypeError at import time.
N_WORKERS = max(1, (os.cpu_count() or 1) - 2)

# How many tokenization futures to keep in-flight at once
# = N_WORKERS × 2 keeps the pipeline full without excess memory
MAX_PENDING = N_WORKERS * 2
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ------------------------------------------------------------------ #
|
| 72 |
+
# WORKER PROCESS — loaded once per process at startup
|
| 73 |
+
# ------------------------------------------------------------------ #
|
| 74 |
+
|
| 75 |
+
# Module-level tokenizer in each worker process
|
| 76 |
+
_worker_tokenizer = None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _worker_init(tokenizer_dir: str):
    """
    Per-process initializer for the tokenization pool.

    Passed as the pool's ``initializer`` so it runs once in each worker
    process: it silences library warnings there and loads the tokenizer
    into the module-level ``_worker_tokenizer``, which
    ``_tokenize_worker_fn`` then reuses for every batch (no repeated
    disk reads).

    Args:
        tokenizer_dir : directory holding the saved tokenizer files
    """
    global _worker_tokenizer

    # Imported here so the names are bound inside the worker process itself
    import warnings
    from transformers import PreTrainedTokenizerFast, logging as worker_hf_logging

    warnings.filterwarnings("ignore")
    worker_hf_logging.set_verbosity_error()

    _worker_tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _tokenize_worker_fn(texts: list) -> list:
    """
    Encode one batch of already-normalized documents inside a worker.

    Uses the tokenizer that ``_worker_init`` stored in the module-level
    ``_worker_tokenizer``.  Special tokens are requested
    (add_special_tokens=True) — presumably this is what appends
    <|endoftext|> to each doc via the tokenizer's post-processor; confirm
    against the saved tokenizer.

    Args:
        texts : list of normalized strings (already filtered, normalized)

    Returns:
        list of list[int] — token IDs per document, in input order
    """
    # Read-only use of the module global, so no `global` declaration is needed
    batch_encoding = _worker_tokenizer(
        texts,
        add_special_tokens=True,       # include the tokenizer's special tokens
        truncation=False,              # keep full document
        padding=False,                 # no padding (we pack shards)
        return_attention_mask=False,   # not needed
    )
    return batch_encoding["input_ids"]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ------------------------------------------------------------------ #
|
| 117 |
+
# SHARD HELPERS
|
| 118 |
+
# ------------------------------------------------------------------ #
|
| 119 |
+
|
| 120 |
+
def get_shard_path(split: str, shard_idx: int) -> str:
    """Return the file path of shard ``shard_idx`` of ``split`` inside DATA_DIR.

    The file name is "<split>_<idx>.bin" with the index zero-padded to at
    least three digits, e.g. "train_007.bin".
    """
    shard_name = f"{split}_{shard_idx:03d}.bin"
    return os.path.join(DATA_DIR, shard_name)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def save_shard(tokens: list, split: str, shard_idx: int):
    """
    Write one shard of token IDs to disk as a flat binary file.

    The IDs are converted to a DTYPE array and dumped with ndarray.tofile
    (raw values, no header) at the path given by get_shard_path; a one-line
    report is emitted through tqdm.write.

    Args:
        tokens    : flat list of token IDs for this shard
        split     : split name used in the file name
        shard_idx : index of the shard within the split
    """
    shard = np.array(tokens, dtype=DTYPE)
    shard.tofile(get_shard_path(split, shard_idx))

    megabytes = shard.nbytes / 1024 / 1024
    tqdm.write(f" saved {split}_{shard_idx:03d}.bin | {len(tokens):,} tokens | {megabytes:.1f} MB")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# ------------------------------------------------------------------ #
|
| 133 |
+
# ROUTE BATCH RESULTS → train / val buffers
|
| 134 |
+
# ------------------------------------------------------------------ #
|
| 135 |
+
|
| 136 |
+
def route_results(
    all_ids        : list,
    doc_count_start: int,
    train_buffer   : list,
    val_buffer     : list,
    train_tokens   : int,
    val_tokens     : int,
    total_tokens   : int,
) -> tuple:
    """
    Distribute one tokenized batch over the train and val buffers.

    Document i of the batch has global index doc_count_start + i; documents
    whose global index is a multiple of VAL_RATIO go to val, all others to
    train.  Both buffers are extended in place.

    Args:
        all_ids         : list of token-ID lists, one per document
        doc_count_start : global index of the first document in all_ids
        train_buffer    : flat list of pending train token IDs
        val_buffer      : flat list of pending val token IDs
        train_tokens    : running count of train tokens
        val_tokens      : running count of val tokens
        total_tokens    : running count of all tokens

    Returns:
        (train_buffer, val_buffer, train_tokens, val_tokens, total_tokens,
         batch_tok_count) with the counters updated; batch_tok_count is the
        number of tokens contributed by this batch.
    """
    batch_tok_count = 0

    for doc_num, ids in enumerate(all_ids, start=doc_count_start):
        n_ids = len(ids)

        if doc_num % VAL_RATIO:
            train_buffer.extend(ids)
            train_tokens += n_ids
        else:
            # remainder 0 → every VAL_RATIO-th doc is held out for validation
            val_buffer.extend(ids)
            val_tokens += n_ids

        batch_tok_count += n_ids

    total_tokens += batch_tok_count
    return (
        train_buffer,
        val_buffer,
        train_tokens,
        val_tokens,
        total_tokens,
        batch_tok_count,
    )
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ------------------------------------------------------------------ #
|
| 168 |
+
# MAIN PARALLEL TOKENIZATION PIPELINE
|
| 169 |
+
# ------------------------------------------------------------------ #
|
| 170 |
+
|
| 171 |
+
def tokenize_dataset():
    """
    Stream the dataset, tokenize it in parallel, and write train/val shards.

    Pipeline:
      1. The main process streams and shuffles the dataset, drops documents
         with int_score < MIN_QUALITY or fewer than MIN_DOC_LENGTH characters
         (checked before and after normalization), and groups the rest into
         batches of BATCH_SIZE.
      2. Each batch is submitted to a ProcessPoolExecutor whose workers were
         initialized once with the tokenizer (_worker_init).
      3. Results are consumed strictly in submission order, routed to the
         train/val buffers (route_results), and written out every
         SHARD_SIZE tokens.
      4. Streaming stops once at least MAX_TOKENS tokens were collected;
         batches already in flight are still drained, so the final count can
         exceed the cap.

    Ctrl+C: a KeyboardInterrupt during streaming cancels batches that have
    not started, discards the other in-flight batches, and then flushes the
    tokens already buffered as partial shards, so work done so far is kept.
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    print(f"Loading tokenizer from: {TOKENIZER_DIR}")
    print(f" workers : {N_WORKERS} of {os.cpu_count()} CPUs")

    print(f"\nLoading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
    ds = load_dataset(
        DATASET_NAME,
        name=DATASET_SUBSET,
        split="train",
        streaming=True,
    ).shuffle(buffer_size=SHUFFLE_BUFFER, seed=42)

    # ---- State ------------------------------------------------------ #
    train_buffer = []
    val_buffer = []
    train_shard = 0
    val_shard = 0
    total_docs = 0
    skipped_docs = 0
    total_tokens = 0
    train_tokens = 0
    val_tokens = 0
    batch_texts = []      # accumulating next batch to submit

    # pending: deque of (future, doc_count_start)
    # We always pop from the LEFT (oldest submission) to preserve order
    pending = deque()
    cap_reached = False
    interrupted = False

    # ---- Progress bars ----------------------------------------------- #
    token_bar = tqdm(
        total=MAX_TOKENS,
        desc="tokens",
        unit="tok",
        unit_scale=True,
        unit_divisor=1000,
        colour="green",
        position=0,
    )
    doc_bar = tqdm(
        desc="docs ",
        unit="doc",
        unit_scale=True,
        colour="blue",
        position=1,
    )

    # ------------------------------------------------------------------ #
    # DRAIN HELPER — collect the oldest pending future and process it
    # ------------------------------------------------------------------ #

    def drain_one():
        """Consume the oldest pending batch; return False if none is pending."""
        nonlocal train_buffer, val_buffer, train_shard, val_shard
        nonlocal total_tokens, train_tokens, val_tokens

        if not pending:
            return False

        future, doc_start = pending.popleft()
        all_ids = future.result()      # blocks until this task is done

        (train_buffer, val_buffer,
         train_tokens, val_tokens,
         total_tokens, batch_tok) = route_results(
            all_ids, doc_start,
            train_buffer, val_buffer,
            train_tokens, val_tokens, total_tokens,
        )

        token_bar.update(batch_tok)
        token_bar.set_postfix({
            "train": f"{train_tokens/1e9:.2f}B",
            "val" : f"{val_tokens/1e6:.0f}M",
            "shards": train_shard,
        })

        # Flush train shards
        while len(train_buffer) >= SHARD_SIZE:
            save_shard(train_buffer[:SHARD_SIZE], "train", train_shard)
            train_buffer = train_buffer[SHARD_SIZE:]
            train_shard += 1

        # Flush val shards
        while len(val_buffer) >= SHARD_SIZE:
            save_shard(val_buffer[:SHARD_SIZE], "val", val_shard)
            val_buffer = val_buffer[SHARD_SIZE:]
            val_shard += 1

        return True

    # ------------------------------------------------------------------ #
    # MAIN LOOP with ProcessPoolExecutor
    # ------------------------------------------------------------------ #

    print(f"\nStarting tokenization...")
    print(f" token target : {MAX_TOKENS:,}")
    print(f" shard size : {SHARD_SIZE:,} tokens")
    print(f" batch size : {BATCH_SIZE} docs")
    print(f" val ratio : every {VAL_RATIO}th doc")
    print(f" quality : int_score >= {MIN_QUALITY}\n")

    try:
        with ProcessPoolExecutor(
            max_workers=N_WORKERS,
            initializer=_worker_init,
            initargs=(TOKENIZER_DIR,),
        ) as executor:
            try:
                for doc in ds:

                    # ---- Quality filter ---------------------------- #
                    if doc["int_score"] < MIN_QUALITY:
                        skipped_docs += 1
                        doc_bar.set_postfix({"skipped": skipped_docs})
                        continue

                    # ---- Length + normalize ------------------------ #
                    text = doc["text"]
                    if len(text) < MIN_DOC_LENGTH:
                        skipped_docs += 1
                        doc_bar.set_postfix({"skipped": skipped_docs})
                        continue

                    # Normalization can shrink the text, so re-check length
                    text = normalization(text)
                    if len(text) < MIN_DOC_LENGTH:
                        skipped_docs += 1
                        doc_bar.set_postfix({"skipped": skipped_docs})
                        continue

                    batch_texts.append(text)
                    total_docs += 1
                    doc_bar.update(1)

                    # ---- Submit batch when full -------------------- #
                    if len(batch_texts) == BATCH_SIZE:
                        # Record which doc index this batch starts at
                        doc_start = total_docs - BATCH_SIZE

                        future = executor.submit(_tokenize_worker_fn, batch_texts)
                        pending.append((future, doc_start))
                        batch_texts = []

                        # ---- Backpressure: drain oldest if queue full #
                        # This prevents unbounded memory accumulation
                        # while keeping all N_WORKERS busy
                        while len(pending) >= MAX_PENDING:
                            drain_one()

                        # ---- Check token cap ----------------------- #
                        if total_tokens >= MAX_TOKENS:
                            tqdm.write(f"\nToken cap reached: {total_tokens:,} tokens from {total_docs:,} docs")
                            cap_reached = True
                            break

                # ---- Submit any remaining partial batch ------------ #
                if batch_texts and not cap_reached:
                    doc_start = total_docs - len(batch_texts)
                    future = executor.submit(_tokenize_worker_fn, batch_texts)
                    pending.append((future, doc_start))

                # ---- Drain all remaining pending futures ----------- #
                while pending:
                    drain_one()

            except KeyboardInterrupt:
                # Keep what is already buffered; give up on in-flight work so
                # shutdown is quick. cancel() only affects batches that have
                # not started running yet.
                interrupted = True
                for future, _ in pending:
                    future.cancel()
                pending.clear()
                tqdm.write("\nInterrupted: discarding in-flight batches, flushing buffered tokens")
    finally:
        # ---- Close progress bars (also on errors) ------------------ #
        token_bar.close()
        doc_bar.close()

    # ---- Save remaining partial shards ----------------------------- #
    if train_buffer:
        save_shard(train_buffer, "train", train_shard)
        train_shard += 1

    if val_buffer:
        save_shard(val_buffer, "val", val_shard)
        val_shard += 1

    # ---- Final summary --------------------------------------------- #
    print(f"\n{'='*60}")
    if interrupted:
        # total docs below includes docs from the discarded in-flight batches
        print(f" TOKENIZATION INTERRUPTED")
    else:
        print(f" TOKENIZATION COMPLETE")
    print(f"{'='*60}")
    print(f" total docs : {total_docs:,}")
    print(f" skipped docs : {skipped_docs:,}")
    print(f" total tokens : {total_tokens:,}")
    print(f" train tokens : {train_tokens:,}")
    print(f" val tokens : {val_tokens:,}")
    print(f" train shards : {train_shard}")
    print(f" val shards : {val_shard}")
    print(f" data dir : {os.path.abspath(DATA_DIR)}")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ------------------------------------------------------------------ #
|
| 367 |
+
# LOAD SHARDS DURING TRAINING (unchanged)
|
| 368 |
+
# ------------------------------------------------------------------ #
|
| 369 |
+
|
| 370 |
+
def load_shard(split: str, shard_idx: int) -> np.ndarray:
    """
    Open one token shard as a read-only memory map.

    The shard file is located via get_shard_path(split, shard_idx) and
    mapped with np.memmap (dtype=DTYPE, mode="r"), so slices are read
    from disk on demand instead of loading the whole shard into RAM.

    Usage during training:
        shard = load_shard("train", 0)
        chunk = shard[i : i + 1024]
    """
    return np.memmap(get_shard_path(split, shard_idx), dtype=DTYPE, mode="r")
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# ------------------------------------------------------------------ #
|
| 384 |
+
# ENTRY POINT
|
| 385 |
+
# ------------------------------------------------------------------ #
|
| 386 |
+
|
| 387 |
+
# Run the tokenization pipeline only when this file is executed as a script.
if __name__ == "__main__":
    # Windows requires this guard for multiprocessing with spawn start method:
    # spawned worker processes re-import this module, and without the guard
    # every such import would start tokenize_dataset() again.
    tokenize_dataset()
|
tokenizer/traintokenizer.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
from tokenizers import Tokenizer
|
| 3 |
+
|
| 4 |
+
# Import our components
|
| 5 |
+
from normalizer import normalization # our normalize function
|
| 6 |
+
from bpe import build_tokenizer, build_trainer, get_special_token_ids
|
| 7 |
+
|
| 8 |
+
from post_processor import add_post_processor
|
| 9 |
+
# ------------------------------------------------------------------ #
|
| 10 |
+
# CONSTANTS
|
| 11 |
+
# ------------------------------------------------------------------ #
|
| 12 |
+
|
| 13 |
+
DATASET_NAME = "HuggingFaceFW/fineweb-edu"
|
| 14 |
+
DATASET_SUBSET = "CC-MAIN-2014-49"
|
| 15 |
+
MIN_QUALITY = 3 # int_score >= 3 only
|
| 16 |
+
MAX_TOKENS = 25_000_000 # ~100M characters worth, enough for BPE training
|
| 17 |
+
# FineWeb-Edu tokens avg 4-5 chars each
|
| 18 |
+
MIN_DOC_LENGTH = 100 # skip very short documents, likely boilerplate
|
| 19 |
+
import os
|
| 20 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 21 |
+
SAVE_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ------------------------------------------------------------------ #
|
| 25 |
+
# DATA GENERATOR
|
| 26 |
+
# ------------------------------------------------------------------ #
|
| 27 |
+
|
| 28 |
+
def fineweb_edu_iterator(
    max_tokens: int = MAX_TOKENS,
    min_quality: int = MIN_QUALITY,
    min_length: int = MIN_DOC_LENGTH,
):
    """
    Generate normalized FineWeb-Edu documents for BPE training.

    Documents are streamed from DATASET_NAME / DATASET_SUBSET. A document
    is dropped when its int_score is below min_quality, or when its text
    is shorter than min_length either before or after normalization().

    The token budget is tracked with each kept document's "token_count"
    field and is checked at the start of every iteration, so the document
    that crosses the cap is still yielded.

    Args:
        max_tokens  : stop once this many tokens have been consumed
        min_quality : minimum int_score a document needs to be kept
        min_length  : minimum number of characters a document needs

    Yields:
        str: the normalized text of each kept document
    """
    print(f"Loading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
    stream = load_dataset(
        DATASET_NAME,
        name=DATASET_SUBSET,
        split="train",
        streaming=True,
    )

    consumed = 0  # tokens accounted for so far
    kept = 0      # documents that passed every filter
    dropped = 0   # documents rejected by any filter

    for record in stream:
        # Budget check comes first so nothing is processed past the cap.
        if consumed >= max_tokens:
            break

        # Cheap filters on the raw record: quality score, then raw length.
        # The length check runs before normalization to avoid wasting
        # compute on boilerplate/empty documents.
        if record["int_score"] < min_quality or len(record["text"]) < min_length:
            dropped += 1
            continue

        cleaned = normalization(record["text"])

        # Normalization can shrink a document below the minimum length.
        if len(cleaned) < min_length:
            dropped += 1
            continue

        consumed += record["token_count"]
        kept += 1

        # Periodic progress report, once per 100k kept documents.
        if kept % 100_000 == 0:
            print(
                f" docs yielded: {kept:,} | "
                f"docs skipped: {dropped:,} | "
                f"tokens seen: {consumed:,} / {max_tokens:,} "
                f"({100 * consumed / max_tokens:.1f}%)"
            )

        yield cleaned

    # Summary; only reached when the consumer drains the generator.
    print(f"\nStream complete:")
    print(f" docs yielded : {kept:,}")
    print(f" docs skipped : {dropped:,}")
    print(f" tokens seen : {consumed:,}")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ------------------------------------------------------------------ #
|
| 111 |
+
# TRAINING
|
| 112 |
+
# ------------------------------------------------------------------ #
|
| 113 |
+
|
| 114 |
+
def train_tokenizer() -> Tokenizer:
    """
    Builds, trains, and saves the tokenizer.

    Steps:
        1. Build the untrained tokenizer and its trainer
           (build_tokenizer / build_trainer).
        2. Train on the documents yielded by fineweb_edu_iterator().
        3. Attach the post-processor (add_post_processor).
        4. Print the special token IDs and save to f"{SAVE_PATH}.json".

    Returns:
        Trained Tokenizer object
    """

    # Build untrained tokenizer and trainer
    tokenizer = build_tokenizer()
    trainer = build_trainer()

    print("\nStarting BPE training...")
    print(f" vocab size : {trainer.vocab_size:,}")
    print(f" min frequency : {trainer.min_frequency}")
    print(f" quality filter: int_score >= {MIN_QUALITY}")
    print(f" max tokens : {MAX_TOKENS:,}\n")

    # train_from_iterator expects an iterable of strings;
    # our generator yields one clean document string at a time.
    #
    # `length` is intentionally not passed. It is the number of sequences
    # (documents) in the iterator, used only for progress tracking, and
    # that count is unknown for a token-capped stream. Passing MAX_TOKENS
    # (a token budget) gave the progress bar a wrong total.
    tokenizer.train_from_iterator(
        iterator=fineweb_edu_iterator(),
        trainer=trainer,
    )

    print("\nTraining complete.")

    tokenizer = add_post_processor(tokenizer)

    # Print special token IDs
    ids = get_special_token_ids(tokenizer)
    print(f"\nSpecial token IDs:")
    for token, token_id in ids.items():
        print(f" {token} -> {token_id}")

    # Save tokenizer to disk
    tokenizer.save(f"{SAVE_PATH}.json")
    print(f"\nTokenizer saved to: {SAVE_PATH}.json")

    return tokenizer
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ------------------------------------------------------------------ #
|
| 159 |
+
# QUICK VERIFICATION after training
|
| 160 |
+
# ------------------------------------------------------------------ #
|
| 161 |
+
|
| 162 |
+
def verify_tokenizer(tokenizer: Tokenizer):
    """
    Print a quick sanity report for a freshly trained tokenizer.

    Each sample string is encoded and decoded again; the tokens, IDs,
    token count, decoded text and whether the round trip reproduced the
    input are printed. The report ends with the vocabulary size and the
    ID assigned to <|endoftext|>.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" TOKENIZER VERIFICATION")
    print(rule + "\n")

    samples = (
        "The mitochondria is the powerhouse of the cell.",
        "CO2 levels rose by 1.5e-3 ppm in 2024.",
        "def compute_loss(y_pred, y_true):\n return (y_pred - y_true)**2",
        "U.S.A has a Ph.D program e.g. at MIT.",
        "don't they've she'll",
        "∇f(x) = 0 is a necessary condition.",  # tests byte fallback
    )

    for sample in samples:
        enc = tokenizer.encode(sample)
        token_ids = enc.ids
        roundtrip = tokenizer.decode(token_ids)

        print(f"Input : {repr(sample)}")
        print(f"Tokens : {enc.tokens}")
        print(f"IDs : {token_ids}")
        print(f"N tokens: {len(token_ids)}")
        print(f"Decoded : {repr(roundtrip)}")
        print(f"Lossless: {sample == roundtrip}")
        print()

    # Vocabulary size after training
    print(f"Final vocab size: {tokenizer.get_vocab_size():,}")

    # The end-of-text token should resolve to an ID
    print(f"<|endoftext|> ID: {tokenizer.token_to_id('<|endoftext|>')}")
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ------------------------------------------------------------------ #
|
| 202 |
+
# ENTRY POINT
|
| 203 |
+
# ------------------------------------------------------------------ #
|
| 204 |
+
|
| 205 |
+
# Script entry point: train and save the tokenizer, then print the
# round-trip verification report for the trained object.
if __name__ == "__main__":
    tokenizer = train_tokenizer()
    verify_tokenizer(tokenizer)
|
tokenizer/wrap_tokenizer.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizers import Tokenizer
|
| 2 |
+
from transformers import PreTrainedTokenizerFast
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# ------------------------------------------------------------------ #
|
| 7 |
+
# CONSTANTS
|
| 8 |
+
# ------------------------------------------------------------------ #
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 12 |
+
TOKENIZER_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer.json")
|
| 13 |
+
SAVE_DIR = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer") # output folder
|
| 14 |
+
MODEL_MAX_LENGTH = 1024 # context length
|
| 15 |
+
PADDING_SIDE = "right" # causal LM standard
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ------------------------------------------------------------------ #
|
| 19 |
+
# WRAP
|
| 20 |
+
# ------------------------------------------------------------------ #
|
| 21 |
+
|
| 22 |
+
def wrap_tokenizer(
    tokenizer_path: str = TOKENIZER_PATH,
    save_dir: str = SAVE_DIR,
) -> PreTrainedTokenizerFast:
    """
    Load a trained tokenizer file and wrap it as a PreTrainedTokenizerFast.

    The wrapper is what the rest of the HuggingFace stack expects:
        - datasets.map() compatibility for bulk tokenization
        - HuggingFace Trainer + DataCollator compatibility
        - Automatic padding, truncation, attention masks
        - from_pretrained() loading support
        - return_tensors="pt" for PyTorch tensors

    Args:
        tokenizer_path : path to trained tokenizer .json file
        save_dir       : folder to save the wrapped tokenizer

    Returns:
        PreTrainedTokenizerFast ready for training
    """
    print(f"Loading trained tokenizer from: {tokenizer_path}")
    backend = Tokenizer.from_file(tokenizer_path)

    # <|endoftext|> fills all three special-token roles:
    #   eos_token - end of sequence marker, tells generation when to stop
    #   bos_token - GPT-2 style: there is no separate BOS token
    #   pad_token - reused because sequences are packed during pretraining
    #               and never actually padded; defined so HuggingFace
    #               doesn't complain about a missing pad token
    #   unk_token - None because byte-level means no unknowns ever
    eot = "<|endoftext|>"

    wrapped = PreTrainedTokenizerFast(
        tokenizer_object=backend,
        eos_token=eot,
        bos_token=eot,
        pad_token=eot,
        unk_token=None,
        model_max_length=MODEL_MAX_LENGTH,  # context length
        padding_side=PADDING_SIDE,
        # keep the beginning of the sequence, drop the end
        truncation_side="right",
    )

    wrapped.add_special_tokens(
        {"eos_token": eot, "bos_token": eot, "pad_token": eot}
    )

    # Write special_tokens_map.json by hand before save_pretrained().
    # NOTE(review): save_pretrained() may write its own copy of this file
    # over the manual one, depending on the transformers version - confirm.
    os.makedirs(save_dir, exist_ok=True)
    map_path = os.path.join(save_dir, "special_tokens_map.json")
    with open(map_path, "w") as fh:
        json.dump(
            {"bos_token": eot, "eos_token": eot, "pad_token": eot},
            fh,
            indent=2,
        )
    print("special_tokens_map.json written manually")

    # Persist the wrapped tokenizer (tokenizer.json, tokenizer_config.json)
    # into the same folder.
    wrapped.save_pretrained(save_dir)
    print(f"Tokenizer saved to: {save_dir}/")
    print(" tokenizer.json")
    print(" tokenizer_config.json")
    print(" special_tokens_map.json")

    return wrapped
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ------------------------------------------------------------------ #
|
| 114 |
+
# VERIFICATION
|
| 115 |
+
# ------------------------------------------------------------------ #
|
| 116 |
+
|
| 117 |
+
def verify_wrapped_tokenizer(
    tokenizer: PreTrainedTokenizerFast,
    save_dir: str = SAVE_DIR,
):
    """
    Verifies the wrapped tokenizer behaves correctly.

    Prints a report covering the basic config, a single encode/decode
    round trip, batch padding, truncation, and a reload from disk whose
    IDs are compared against the in-memory tokenizer.

    Args:
        tokenizer : the wrapped tokenizer to check
        save_dir  : folder the tokenizer was saved to; the reload check
                    reads from here. Defaults to SAVE_DIR, so existing
                    single-argument calls behave as before. Pass the same
                    value given to wrap_tokenizer() when a custom folder
                    was used, otherwise the check would load a different
                    (or missing) tokenizer.
    """

    print("\n" + "="*60)
    print(" WRAPPED TOKENIZER VERIFICATION")
    print("="*60 + "\n")

    eot_id = tokenizer.eos_token_id

    # ---- 1. Basic config -----------------------------------------
    print("Config:")
    print(f" vocab size : {tokenizer.vocab_size:,}")
    print(f" model_max_length : {tokenizer.model_max_length}")
    print(f" padding_side : {tokenizer.padding_side}")
    print(f" eos_token : {tokenizer.eos_token!r} (ID: {eot_id})")
    print(f" bos_token : {tokenizer.bos_token!r}")
    print(f" pad_token : {tokenizer.pad_token!r} (ID: {tokenizer.pad_token_id})")
    print(f" unk_token : {tokenizer.unk_token!r}")
    print()

    # ---- 2. Basic encode/decode ----------------------------------
    text = "The mitochondria is the powerhouse of the cell."
    encoded = tokenizer(text)
    decoded = tokenizer.decode(encoded["input_ids"])

    print("Basic encode/decode:")
    print(f" input : {repr(text)}")
    print(f" input_ids: {encoded['input_ids']}")
    print(f" decoded : {repr(decoded)}")
    print()

    # ---- 3. Padding ----------------------------------------------
    # Batch of two sequences with different lengths;
    # the shorter one should be right-padded to match the longer.
    batch = [
        "Short sentence.",
        "This is a much longer sentence that has more tokens in it.",
    ]

    encoded_batch = tokenizer(
        batch,
        padding=True,         # pad to longest in batch
        return_tensors="pt",  # return PyTorch tensors
    )

    print("Batch padding (right padding):")
    print(f" input_ids shape : {encoded_batch['input_ids'].shape}")
    print(f" attention_mask shape : {encoded_batch['attention_mask'].shape}")
    print(f" input_ids[0] : {encoded_batch['input_ids'][0].tolist()}")
    print(f" input_ids[1] : {encoded_batch['input_ids'][1].tolist()}")
    print(f" attention_mask[0] : {encoded_batch['attention_mask'][0].tolist()}")
    print()

    # ---- 4. Truncation -------------------------------------------
    # Sequence longer than MODEL_MAX_LENGTH should be truncated
    long_text = "word " * 2000  # 2000 words >> 1024 tokens
    encoded_long = tokenizer(
        long_text,
        truncation=True,
        max_length=MODEL_MAX_LENGTH,
    )

    print("Truncation:")
    print(f" input length : {len(long_text.split())} words")
    print(f" token count : {len(encoded_long['input_ids'])} (max: {MODEL_MAX_LENGTH})")
    print(f" truncated : {len(encoded_long['input_ids']) <= MODEL_MAX_LENGTH}")
    print()

    # ---- 5. Load from disk and verify ----------------------------
    # Reload from the folder the tokenizer was actually saved to, not
    # unconditionally from the module-level SAVE_DIR.
    print("Loading from disk:")
    reloaded = PreTrainedTokenizerFast.from_pretrained(save_dir)
    reloaded_ids = reloaded(text)["input_ids"]
    original_ids = encoded["input_ids"]
    match = reloaded_ids == original_ids

    print(f" from_pretrained() : OK")
    print(f" IDs match original: {match}")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ------------------------------------------------------------------ #
|
| 200 |
+
# ENTRY POINT
|
| 201 |
+
# ------------------------------------------------------------------ #
|
| 202 |
+
|
| 203 |
+
# Script entry point: wrap the trained tokenizer with the default paths,
# run the verification report, then print a copy-pasteable usage cheat sheet.
if __name__ == "__main__":
    tokenizer = wrap_tokenizer()
    verify_wrapped_tokenizer(tokenizer)

    print("\n" + "="*60)
    print(" USAGE EXAMPLES")
    print("="*60)
    # The block below is only printed, never executed.
    print("""
# Load anywhere with one line
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast.from_pretrained("fineweb_edu_tokenizer")

# Single encode
ids = tokenizer("Hello world")["input_ids"]

# Batch encode with padding and tensors
batch = tokenizer(
["sentence one", "sentence two"],
padding=True,
truncation=True,
max_length=1024,
return_tensors="pt",
)

# Decode
text = tokenizer.decode(ids, skip_special_tokens=True)

# Get eos token id (use as document separator when packing)
eot_id = tokenizer.eos_token_id
""")
|
tokenizer_walkthrough.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Walkthrough: SLLM Custom BPE Tokenizer
|
| 2 |
+
|
| 3 |
+
This document explains the architecture, execution pipeline, and design choices of the custom **Byte-Pair Encoding (BPE)** tokenizer implemented in the `tokenizer/` directory of the `sllm` project.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 🏗️ Overall Architecture & Pipeline
|
| 8 |
+
|
| 9 |
+
The SLLM tokenizer is a custom-built BPE tokenizer tailored for pre-training small language models on the educational subset of HuggingFace's **FineWeb-Edu** dataset. It integrates custom text normalization, a regex-based pre-tokenization strategy, standard BPE training with byte-level fallback, and packaging utility scripts for high-performance training.
|
| 10 |
+
|
| 11 |
+
```mermaid
|
| 12 |
+
graph TD
|
| 13 |
+
A[Raw Text Stream] --> B[normalizer.py: Normalization]
|
| 14 |
+
B --> C[pretokenizer.py: Custom Regex Split]
|
| 15 |
+
C --> D[bpe.py: Byte-Level Encoding]
|
| 16 |
+
D --> E[traintokenizer.py: BPE Trainer]
|
| 17 |
+
E --> F[post_processor.py: Template Post-Processing]
|
| 18 |
+
F --> G[wrap_tokenizer.py: PreTrainedTokenizerFast Wrapper]
|
| 19 |
+
G --> H[tokenize_dataset.py: Packed binary .bin Shards]
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## 📁 Component-by-Component Breakdown
|
| 25 |
+
|
| 26 |
+
### 1. `normalizer.py` (Text Normalization)
|
| 27 |
+
Before any splitting occurs, the raw input text is standardized and cleaned to eliminate noise while preserving syntax and code structure:
|
| 28 |
+
* **HTML Stripping & Decoding**: Removes HTML tags using regex and decodes HTML entities (e.g., `&amp;` $\rightarrow$ `&`).
|
| 29 |
+
* **Unicode Normalization**: Performs **NFC** normalization to ensure characters like accented letters are represented consistently.
|
| 30 |
+
* **Noise Removal**: Eliminates raw control characters, zero-width characters (e.g., zero-width spaces/joins), and the Unicode replacement character (`\ufffd`).
|
| 31 |
+
* **Whitespace Control**:
|
| 32 |
+
* Collapses multiple consecutive spaces into a single space (preserving leading tabs for code indentation).
|
| 33 |
+
* Cleans trailing whitespaces at the end of lines.
|
| 34 |
+
* Collapses 3 or more consecutive newlines into exactly two newlines (`\n\n`) to preserve paragraph structure.
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
### 2. `pretokenizer.py` (Custom Regex Segmentation)
|
| 39 |
+
Instead of relying on standard GPT-2/Llama pre-tokenization, this model implements a custom, ordered, priority-based regex pre-tokenizer:
|
| 40 |
+
1. **Contractions**: `'s`, `'t`, `'re`, `'ve`, `'ll`, `'m`, `'d`.
|
| 41 |
+
2. **Abbreviations**: Acronyms and shorthand (e.g., `U.S.A`, `e.g.`, `Ph.D`).
|
| 42 |
+
3. **Scientific Notation**: E.g., `1.5e-3`, `3e10`, `2.0E+4` (evaluated *before* decimals to avoid splitting).
|
| 43 |
+
4. **Decimal Numbers**: E.g., `3.14` (evaluated *before* integers).
|
| 44 |
+
5. **Integers**: E.g., `42`, `1984`.
|
| 45 |
+
6. **Multi-character Operators**: Common coding operators like `==`, `!=`, `->`, `<=`, `>=`, `**`, `//`, `+=`, `-=`, `*=`, `/=`.
|
| 46 |
+
7. **Snake Case Identifiers**: E.g., `snake_case`, `_private` (evaluated *before* plain words for clean code representation).
|
| 47 |
+
8. **Regular Unicode Words**: Alphanumeric words covering non-English languages.
|
| 48 |
+
9. **Whitespace**: Preserves sequences of spaces/tabs separately from newlines to keep structural formatting.
|
| 49 |
+
10. **Punctuation Catch-all**: Individual punctuation characters.
|
| 50 |
+
|
| 51 |
+
> [!NOTE]
|
| 52 |
+
> The pre-tokenizer uses HuggingFace's `Split` pre-tokenizer with `behavior="isolated"` and `invert=True`, meaning matched strings are isolated and kept as distinct, individual tokens instead of being discarded as delimiters.
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
### 3. `bpe.py` (BPE Model Configuration)
|
| 57 |
+
Defines the base tokenizer pipeline:
|
| 58 |
+
* **Byte Fallback**: Configures the BPE model with `unk_token=None` and `byte_fallback=True`. This guarantees that *every* character maps to at least one byte-level token, resulting in **zero out-of-vocabulary (OOV)** issues.
|
| 59 |
+
* **Pre-Tokenizer Chain**: Sequentially runs the custom Regex pre-tokenizer followed by `ByteLevel(add_prefix_space=False)` to translate character segments to their corresponding byte values.
|
| 60 |
+
* **Decoder**: Instantiates the standard `ByteLevelDecoder` to reverse byte conversions, allowing human-readable decoded strings.
|
| 61 |
+
* **Trainer Config**: Builds a `BpeTrainer` specifying a vocabulary of `32,000` tokens, minimum merge frequency of `3`, and initial alphabet containing all `256` bytes to enforce the fallback capability.
|
| 62 |
+
|
| 63 |
+
---
|
| 64 |
+
|
| 65 |
+
### 4. `post_processor.py` (Sequence Endings)
|
| 66 |
+
Once BPE rules have been learned and vocabulary IDs are assigned:
|
| 67 |
+
* Attaches `TemplateProcessing` to automatically append `<|endoftext|>` to every sequence.
|
| 68 |
+
* For single documents, it maps to `[tokens...] <|endoftext|>`.
|
| 69 |
+
* For sequence pairs (useful in downstream tasks like Question-Answering), it automatically maps to `[tokens_A...] <|endoftext|> [tokens_B...] <|endoftext|>`.
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
### 5. `traintokenizer.py` (BPE Training Loop)
|
| 74 |
+
* Streams the `CC-MAIN-2014-49` subset of `HuggingFaceFW/fineweb-edu` (its `train` split).
|
| 75 |
+
* Filters out low-quality documents (requires educational score `int_score >= 3`) and documents shorter than 100 characters.
|
| 76 |
+
* Feeds documents iteratively into BPE training via `train_from_iterator()`.
|
| 77 |
+
* Adds the post-processor and runs comprehensive verification checks against edge cases (equations, scientific numbers, code snippets, byte fallbacks, and contractions).
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
### 6. `wrap_tokenizer.py` (HuggingFace Integration)
|
| 82 |
+
Wraps the trained HuggingFace BPE model into `PreTrainedTokenizerFast` from `transformers`:
|
| 83 |
+
* Associates `<|endoftext|>` as the `bos_token`, `eos_token`, and `pad_token`.
|
| 84 |
+
* Enables compatibility with the `datasets.map()` bulk utility, the HuggingFace Trainer, and PyTorch dataloaders.
|
| 85 |
+
* Standardizes right-padding, right-truncation, and context length configurations (`model_max_length=1024`).
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
### 7. `tokenize_dataset.py` (Dataset Packing)
|
| 90 |
+
A highly optimized bulk-tokenization utility:
|
| 91 |
+
* Tokenizes the streamed FineWeb-Edu dataset up to a target cap (e.g., `3.2` Billion tokens).
|
| 92 |
+
* Performs a 99% train and 1% validation split (every 100th document is routed to the validation buffer).
|
| 93 |
+
* Concatenates/packs documents sequentially (using `<|endoftext|>` as the document boundary) and writes them to disk as high-performance flat binary shards (`.bin` files of `np.uint16` type).
|
| 94 |
+
* Standard shard size is `100,000,000` tokens.
|
| 95 |
+
* Provides a memory-mapped helper `load_shard(split, shard_idx)` using `np.memmap` so that models can stream training batches without loading multi-gigabyte files into RAM.
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## 💡 Key Design Highlights
|
| 100 |
+
|
| 101 |
+
> [!TIP]
|
| 102 |
+
> **Why Byte Fallback is Critical**: By initializing the alphabet with 256 unique byte values and enabling fallback, characters like math symbols ($\nabla$) or emojis don't fail or return an `<unk>` token; instead, they represent themselves as their raw UTF-8 bytes (e.g., $\nabla$ is parsed perfectly as `<0xE2><0x88><0x87>`).
|
| 103 |
+
|
| 104 |
+
> [!TIP]
|
| 105 |
+
> **Code-Aware Features**: The combination of preserving leading tabs in `normalizer.py`, isolating multi-character operators (`==`, `!=`, etc.), and protecting `snake_case` variables guarantees high-fidelity, compact token representation when the language model is trained on code.
|
train.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train.py — SLLM Training Loop
|
| 3 |
+
|
| 4 |
+
Supports:
|
| 5 |
+
--max_steps N Run for exactly N steps then save checkpoint and exit.
|
| 6 |
+
Omit to train indefinitely (until Ctrl+C or data exhausted).
|
| 7 |
+
--resume Resume from the latest checkpoint in --run_dir.
|
| 8 |
+
--config 100M|150M Choose model config (default: 100M).
|
| 9 |
+
--synthetic Use synthetic data (for testing without real shards).
|
| 10 |
+
|
| 11 |
+
Features:
|
| 12 |
+
- bf16 mixed precision (autocast) + GradScaler for stable training
|
| 13 |
+
- Gradient accumulation: --grad_accum N steps per optimizer update
|
| 14 |
+
- Gradient checkpointing: --grad_checkpoint to save VRAM
|
| 15 |
+
- Cosine LR schedule with linear warmup
|
| 16 |
+
- Checkpoint save every --save_every steps (and on clean exit/Ctrl+C)
|
| 17 |
+
- Metric logging to <run_dir>/train_log.jsonl (one JSON line per log step)
|
| 18 |
+
- Real-time terminal progress with tqdm
|
| 19 |
+
|
| 20 |
+
Recommended for RTX 3050 4GB:
|
| 21 |
+
python train.py --config 100M --batch_size 4 --grad_accum 8 \\
|
| 22 |
+
--grad_checkpoint --max_steps 1000
|
| 23 |
+
|
| 24 |
+
Run for N steps, stop, then resume:
|
| 25 |
+
python train.py --max_steps 500 --run_dir runs/my_run
|
| 26 |
+
python train.py --max_steps 500 --run_dir runs/my_run --resume
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import os
|
| 30 |
+
import sys
|
| 31 |
+
import json
|
| 32 |
+
import math
|
| 33 |
+
import time
|
| 34 |
+
import signal
|
| 35 |
+
import argparse
|
| 36 |
+
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn.functional as F
|
| 39 |
+
from torch.amp import autocast, GradScaler
|
| 40 |
+
from tqdm import tqdm
|
| 41 |
+
|
| 42 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 43 |
+
from model.config import SLLM_100M, SLLM_150M, ModelConfig
|
| 44 |
+
from model.model import SLLM
|
| 45 |
+
from data.dataloader import build_dataloader
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ------------------------------------------------------------------ #
|
| 49 |
+
# ARG PARSING
|
| 50 |
+
# ------------------------------------------------------------------ #
|
| 51 |
+
|
| 52 |
+
def parse_args():
    """Build the command-line parser for the training script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with run-management, model, data, optimisation,
        memory and logging options.

    Counts that are later used as loop lengths or modulo divisors
    (``--batch_size``, ``--grad_accum``, ``--log_every``, ``--save_every``,
    ``--val_every``) must be >= 1; anything else is rejected here with a
    normal argparse usage error instead of failing mid-training.
    """
    p = argparse.ArgumentParser(description="SLLM Training Loop")

    # Run management
    p.add_argument("--run_dir", type=str, default="runs/run_001", help="Directory for checkpoints and logs")
    p.add_argument("--run_name", type=str, default=None, help="Override run name (defaults to run_dir basename)")
    p.add_argument("--resume", action="store_true", help="Resume from latest checkpoint in run_dir")
    p.add_argument("--max_steps", type=int, default=None, help="Absolute step target — stop when step reaches this number.")
    p.add_argument("--extra_steps", type=int, default=None, help="Run N MORE steps from current checkpoint (relative). Converted to --max_steps internally.")

    # Model
    p.add_argument("--config", type=str, default="100M", choices=["100M", "150M"])

    # Data
    p.add_argument("--data_dir", type=str, default="tokenizer/data")
    p.add_argument("--synthetic", action="store_true", help="Use synthetic random data (for testing)")
    p.add_argument("--num_workers", type=int, default=2)

    # Training
    p.add_argument("--batch_size", type=_positive_int, default=4, help="Per-device batch size")
    p.add_argument("--grad_accum", type=_positive_int, default=8, help="Gradient accumulation steps")
    p.add_argument("--max_lr", type=float, default=3e-4)
    p.add_argument("--min_lr", type=float, default=3e-5)
    p.add_argument("--warmup_steps", type=int, default=100)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping norm (0 = disabled)")

    # Memory
    p.add_argument("--grad_checkpoint", action="store_true", help="Enable gradient checkpointing (saves VRAM, slower)")
    p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])

    # Logging / Saving
    # These three are used as `step % N`, so zero would raise ZeroDivisionError.
    p.add_argument("--log_every", type=_positive_int, default=10, help="Log metrics every N optimizer steps")
    p.add_argument("--save_every", type=_positive_int, default=500, help="Save checkpoint every N optimizer steps")
    p.add_argument("--val_every", type=_positive_int, default=250, help="Run validation every N optimizer steps")
    p.add_argument("--val_steps", type=int, default=20, help="Number of val batches to average")

    return p.parse_args()


def _positive_int(text: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ------------------------------------------------------------------ #
|
| 93 |
+
# LEARNING RATE SCHEDULE
|
| 94 |
+
# ------------------------------------------------------------------ #
|
| 95 |
+
|
| 96 |
+
def get_lr(step: int, warmup_steps: int, total_steps: int, max_lr: float, min_lr: float) -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    During the first ``warmup_steps`` steps the rate ramps linearly up to
    ``max_lr``.  Afterwards it follows a half-cosine from ``max_lr`` down to
    ``min_lr`` and stays at ``min_lr`` once the decay horizon is reached.
    A falsy ``total_steps`` (e.g. ``None`` when training indefinitely) selects
    a fixed 10,000-step horizon.
    """
    if step < warmup_steps:
        # Ramp phase: (step + 1) so that the very first step is non-zero.
        return max_lr * (step + 1) / warmup_steps

    horizon = total_steps if total_steps else 10_000
    if step < horizon:
        # Fraction of the post-warmup window that has elapsed, in [0, 1).
        window = max(1, horizon - warmup_steps)
        frac = (step - warmup_steps) / window
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * frac))
        return min_lr + cosine_weight * (max_lr - min_lr)

    # Past the horizon: hold the floor value.
    return min_lr
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ------------------------------------------------------------------ #
|
| 117 |
+
# OPTIMIZER (AdamW with selective weight decay)
|
| 118 |
+
# ------------------------------------------------------------------ #
|
| 119 |
+
|
| 120 |
+
def build_optimizer(model: SLLM, lr: float, weight_decay: float) -> torch.optim.AdamW:
    """Create AdamW with weight decay applied only to parameters of rank >= 2.

    Trainable parameters are split purely by tensor rank:

    * ``dim() >= 2`` (weight matrices) -> ``weight_decay``
    * ``dim() < 2`` (vectors/scalars such as norm gains and biases) -> no decay

    NOTE(review): the split is by rank only, so any 2-D parameter is decayed.
    That would include an embedding table if the model stores one as a 2-D
    weight; confirm this is intended.

    Args:
        model: Module whose trainable parameters are optimised.
        lr: Initial learning rate for every parameter group.
        weight_decay: Decay coefficient for the rank >= 2 group.

    Returns:
        A ``torch.optim.AdamW`` with two parameter groups (decay, no-decay),
        betas ``(0.9, 0.95)`` and eps ``1e-8``.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    no_decay_params = [p for p in trainable if p.dim() < 2]

    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f" Optimizer: {n_decay/1e6:.1f}M decay params | {n_no_decay/1e6:.1f}M no-decay params")

    # The fused kernel is only requested when every parameter lives on CUDA.
    # The training script falls back to CPU when no GPU is present, and fused
    # AdamW is not available for CPU tensors on every torch release; otherwise
    # let torch choose its default implementation.
    use_fused = bool(trainable) and all(p.is_cuda for p in trainable)
    impl_kwargs = {"fused": True} if use_fused else {}

    return torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, **impl_kwargs)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ------------------------------------------------------------------ #
|
| 153 |
+
# CHECKPOINT SAVE / LOAD
|
| 154 |
+
# ------------------------------------------------------------------ #
|
| 155 |
+
|
| 156 |
+
def save_checkpoint(path: str, model: SLLM, optimizer, step: int, args, loss: float):
    """Atomically write a training checkpoint to ``path``.

    The checkpoint is a dict with keys ``step``, ``model_state_dict``,
    ``optimizer_state_dict``, ``loss`` and ``config_name`` (taken from
    ``args.config``).

    The data is first written to ``<path>.tmp`` and then renamed over
    ``path``.  An interrupted save therefore never leaves a truncated
    ``ckpt_*.pt`` behind for a later resume to pick up as the newest file
    (the temporary name does not end in ``.pt``).

    Args:
        path: Destination file; its parent directory is created if needed.
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        step: Optimizer step count to record.
        args: Parsed CLI namespace; only ``args.config`` is read.
        loss: Most recent training loss, stored for reference.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(parent, exist_ok=True)

    payload = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "config_name": args.config,
    }

    tmp_path = path + ".tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)  # atomic rename on the same filesystem
    finally:
        # Only still present if the save or rename failed: drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"\n [CKPT] Saved checkpoint: {path} (step={step}, loss={loss:.4f})")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def load_checkpoint(run_dir: str, model: SLLM, optimizer, device):
    """Load the newest ``ckpt_<step>.pt`` in ``run_dir`` into model and optimizer.

    "Newest" is decided by the integer step embedded in the filename rather
    than by string order, so the choice stays correct when step numbers
    outgrow the zero-padded width.  Files whose suffix is not an integer
    (e.g. ``ckpt_best.pt``) are only chosen when no numbered checkpoint exists.

    Args:
        run_dir: Directory to search.
        model: Receives ``ckpt["model_state_dict"]``.
        optimizer: Receives ``ckpt["optimizer_state_dict"]``.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The step number stored in the checkpoint.

    Raises:
        FileNotFoundError: If ``run_dir`` contains no ``ckpt_*.pt`` file.
    """
    prefix, suffix = "ckpt_", ".pt"

    def _step_of(filename: str) -> int:
        try:
            return int(filename[len(prefix):-len(suffix)])
        except ValueError:
            return -1

    ckpts = [
        f for f in os.listdir(run_dir)
        if f.startswith(prefix) and f.endswith(suffix)
    ]
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found in {run_dir}")

    # Ties (and non-numeric names) fall back to filename order.
    latest = max(ckpts, key=lambda f: (_step_of(f), f))
    path = os.path.join(run_dir, latest)

    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(path, map_location=device, weights_only=False)

    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    step = ckpt["step"]
    loss = ckpt.get("loss", float("nan"))
    print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
    return step
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ------------------------------------------------------------------ #
|
| 190 |
+
# VALIDATION
|
| 191 |
+
# ------------------------------------------------------------------ #
|
| 192 |
+
|
| 193 |
+
@torch.no_grad()
def estimate_val_loss(model, val_loader, val_steps: int, device, dtype_ctx) -> float:
    """Average the model's loss over at most ``val_steps`` validation batches.

    The model is put in eval mode for the duration and switched back to train
    mode before returning.  Each batch is moved to ``device`` and evaluated
    inside ``dtype_ctx``; the second element of the model's return value is
    taken as the loss.  Returns NaN when no batch was evaluated.
    """
    model.eval()

    per_batch = []
    stream = iter(val_loader)
    n_seen = 0
    while True:
        try:
            inputs, targets = next(stream)
        except StopIteration:
            break
        if n_seen >= val_steps:
            break
        n_seen += 1

        inputs = inputs.to(device)
        targets = targets.to(device)
        with dtype_ctx:
            _, batch_loss = model(inputs, targets)
        per_batch.append(batch_loss.item())

    model.train()

    if not per_batch:
        return float("nan")
    return sum(per_batch) / len(per_batch)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# ------------------------------------------------------------------ #
|
| 209 |
+
# METRIC LOGGING
|
| 210 |
+
# ------------------------------------------------------------------ #
|
| 211 |
+
|
| 212 |
+
class MetricLogger:
    """Append-only JSON-lines metric log (one JSON object per ``log`` call)."""

    def __init__(self, log_path: str):
        """Remember ``log_path`` and make sure its parent directory exists.

        An existing file is left untouched so that a resumed run keeps
        appending to the same log.
        """
        self.log_path = log_path
        parent = os.path.dirname(log_path)
        if parent:  # a bare filename has no parent; os.makedirs("") would raise
            os.makedirs(parent, exist_ok=True)
        print(f" [LOG] Logging to: {log_path}")

    def log(self, **kwargs):
        """Serialise the keyword arguments as one JSON line and append it."""
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(kwargs) + "\n")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ------------------------------------------------------------------ #
|
| 227 |
+
# MAIN TRAINING LOOP
|
| 228 |
+
# ------------------------------------------------------------------ #
|
| 229 |
+
|
| 230 |
+
def train():
    """Run the SLLM training loop end to end.

    Parses the CLI arguments, builds the model, optimizer and data loaders,
    optionally resumes from the newest checkpoint in ``--run_dir``, then
    performs optimizer steps (each made of ``--grad_accum`` micro-batches)
    until ``--max_steps`` is reached or Ctrl+C is pressed.  Metrics are
    appended to ``<run_dir>/train_log.jsonl``; a checkpoint is written every
    ``--save_every`` steps and once more on exit if the last step has not
    been saved yet.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ---- dtype context --------------------------------------------- #
    # Reduced precision is only used on CUDA; anything else runs in fp32.
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch = torch.bfloat16
        dtype_name = "bf16"
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch = torch.float16
        dtype_name = "fp16"
    else:
        dtype_torch = torch.float32
        dtype_name = "fp32"

    print(f"dtype : {dtype_name}")
    use_amp = dtype_torch in (torch.float16, torch.bfloat16)
    scaler = GradScaler(enabled=(dtype_torch == torch.float16))  # bf16 doesn't need scaler

    # ---- Auto-detect config on resume ------------------------------ #
    # Best effort: peek at the newest checkpoint so the model is built with
    # the config it was trained with.  Failures are reported, not hidden; the
    # actual resume below will surface any hard error.
    if args.resume:
        try:
            ckpts = sorted([
                f for f in os.listdir(args.run_dir)
                if f.startswith("ckpt_") and f.endswith(".pt")
            ])
        except FileNotFoundError:
            # Run directory does not exist yet; the resume step reports that.
            ckpts = []
        except OSError as err:
            print(f" [WARN] Could not list {args.run_dir} for config auto-detect: {err}")
            ckpts = []

        if ckpts:
            ckpt_path = os.path.join(args.run_dir, ckpts[-1])
            try:
                # SECURITY: weights_only=False unpickles arbitrary objects —
                # only resume from checkpoints you trust.
                _tmp_ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
                if "config_name" in _tmp_ckpt and _tmp_ckpt["config_name"] != args.config:
                    print(f" [CKPT] Auto-switching config from '{args.config}' to '{_tmp_ckpt['config_name']}' to match checkpoint.")
                    args.config = _tmp_ckpt["config_name"]
                del _tmp_ckpt
            except Exception as err:  # best-effort probe; keep going with args.config
                print(f" [WARN] Could not read {ckpt_path} for config auto-detect: {err}")

    # ---- Model ----------------------------------------------------- #
    cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}
    cfg = cfg_map[args.config]
    model = SLLM(cfg).to(device)

    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()
        print(" Gradient checkpointing: ON")

    print(f"\nModel : SLLM-{args.config} ({model.count_params()/1e6:.1f}M params)")
    print(f"Config : {cfg}")

    # ---- Optimizer ------------------------------------------------- #
    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    # ---- Data ------------------------------------------------------ #
    train_loader = build_dataloader(
        data_dir=args.data_dir,
        split="train",
        context_length=cfg.context_length,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_synthetic=args.synthetic,
        vocab_size=cfg.vocab_size,
    )
    val_loader = build_dataloader(
        data_dir=args.data_dir,
        split="val",
        context_length=cfg.context_length,
        batch_size=args.batch_size,
        num_workers=0,
        use_synthetic=args.synthetic,
        vocab_size=cfg.vocab_size,
    )

    # ---- Run directory --------------------------------------------- #
    os.makedirs(args.run_dir, exist_ok=True)
    log_path = os.path.join(args.run_dir, "train_log.jsonl")
    logger = MetricLogger(log_path)

    # ---- Resume ---------------------------------------------------- #
    start_step = 0
    if args.resume:
        try:
            start_step = load_checkpoint(args.run_dir, model, optimizer, device)
        except FileNotFoundError as e:
            print(f" [WARN] {e} — starting from scratch.")

    # ---- Effective batch size info --------------------------------- #
    eff_batch = args.batch_size * args.grad_accum
    tokens_per_step = eff_batch * cfg.context_length
    print(f"\nTraining:")

    # ---- Resolve extra_steps -> max_steps -------------------------- #
    if args.extra_steps is not None:
        if args.max_steps is not None:
            print(" [WARN] Both --extra_steps and --max_steps given. --extra_steps takes priority.")
        args.max_steps = start_step + args.extra_steps
        print(f" [INFO] --extra_steps {args.extra_steps} → running until step {args.max_steps}")

    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} -> effective={eff_batch})")
    print(f" tokens/step : {tokens_per_step:,}")
    print(f" max_steps : {args.max_steps or 'unlimited'} (absolute step target)")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else 'unlimited'}")
    print(f" save_every : {args.save_every}")
    print(f" log_every : {args.log_every}")

    # ---- Early exit if already past max_steps ---------------------- #
    if args.max_steps is not None and start_step >= args.max_steps:
        print(f"\n [WARN] start_step ({start_step}) >= max_steps ({args.max_steps}).")
        print(f" Nothing to train. Use --extra_steps N to run N more steps.")
        print(f"\nExample: python train.py --resume --run_dir {args.run_dir} --extra_steps 5000")
        return

    # ---- Graceful Ctrl+C handler ----------------------------------- #
    # The handler only raises a flag; the loop checks it between optimizer
    # steps so a checkpoint can be written before exiting.
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C received — will save checkpoint and exit after current step.")
        stop_flag["stop"] = True

    signal.signal(signal.SIGINT, _signal_handler)

    # ---- Training loop --------------------------------------------- #
    model.train()
    step = start_step
    running_loss = 0.0  # loss of the most recent optimizer step
    last_saved_step = -1  # step of the most recent periodic checkpoint
    t_start = time.time()
    t_step_start = time.time()
    data_iter = iter(train_loader)

    print(f"\n{'='*60}")
    print(f" TRAINING STARTED (step {step} -> {args.max_steps or '∞'})")
    print(f"{'='*60}\n")

    pbar = tqdm(
        initial=step,
        total=args.max_steps,
        desc="Training",
        unit="step",
        dynamic_ncols=True,
    )

    while True:
        # ---- Stop conditions --------------------------------------- #
        if stop_flag["stop"]:
            break
        if args.max_steps is not None and step >= args.max_steps:
            print(f"\n [DONE] Reached max_steps={args.max_steps}")
            break

        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0

        # ---- Gradient accumulation micro-steps --------------------- #
        for _ in range(args.grad_accum):
            # Get next batch; restart the loader when it is exhausted.
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                x, y = next(data_iter)

            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

            # Forward + loss (inside AMP context)
            with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
                _, loss = model(x, y)
                # Scale loss by grad_accum so gradients average correctly
                loss = loss / args.grad_accum

            # Backward
            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # ---- Gradient clipping ------------------------------------- #
        if args.grad_clip > 0:
            # Gradients must be unscaled before their norm is measured.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        else:
            grad_norm = float("nan")

        # ---- LR update --------------------------------------------- #
        lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # ---- Optimizer step ---------------------------------------- #
        scaler.step(optimizer)
        scaler.update()

        step += 1
        running_loss = accum_loss  # loss for this step

        # ---- Tokens per second ------------------------------------- #
        t_now = time.time()
        elapsed = t_now - t_step_start
        t_step_start = t_now
        tok_per_sec = tokens_per_step / max(elapsed, 1e-6)

        # ---- Progress bar update ----------------------------------- #
        pbar.update(1)
        pbar.set_postfix({
            "loss": f"{running_loss:.4f}",
            "lr": f"{lr:.2e}",
            "tok/s": f"{tok_per_sec:.0f}",
        })

        # ---- Logging ----------------------------------------------- #
        if step % args.log_every == 0:
            log_entry = {
                "step": step,
                "loss": round(running_loss, 6),
                "lr": lr,
                "grad_norm": round(float(grad_norm), 4) if not math.isnan(float(grad_norm)) else None,
                "tok_per_sec": round(tok_per_sec, 1),
                "elapsed_s": round(t_now - t_start, 1),
            }
            if device.type == "cuda":
                log_entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
            logger.log(**log_entry)

        # ---- Validation -------------------------------------------- #
        if step % args.val_every == 0:
            val_loss = estimate_val_loss(
                model,
                val_loader,
                args.val_steps,
                device,
                autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp),
            )
            tqdm.write(f" [STEP {step:6d}] train_loss={running_loss:.4f} val_loss={val_loss:.4f} lr={lr:.2e}")
            logger.log(step=step, val_loss=round(val_loss, 6))

        # ---- Checkpoint -------------------------------------------- #
        if step % args.save_every == 0:
            ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
            save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)
            last_saved_step = step

    # ---- Final checkpoint on exit (only if we actually ran steps) -- #
    pbar.close()
    steps_done = step - start_step
    if steps_done <= 0:
        print("\n [SKIP] No steps were taken — skipping final checkpoint save.")
    elif step == last_saved_step:
        # The periodic save above already wrote this exact state.
        print(f"\n [CKPT] Step {step} already saved — skipping duplicate final checkpoint.")
    else:
        ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
        save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(f" TRAINING COMPLETE")
    print(f"{'='*60}")
    print(f" Steps completed : {step - start_step}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print(f"\nTo resume: python train.py --resume --run_dir {args.run_dir} --max_steps <N>")
    print(f"To plot : python plot_training.py --run_dir {args.run_dir}")
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()
|