geeteshcodes commited on
Commit
7f974df
·
verified ·
1 Parent(s): 80e2a42

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ── Checkpoints & training runs ──────────────────────────────────────
2
+ runs/
3
+
4
+ # ── Python ───────────────────────────────────────────────────────────
5
+ __pycache__/
6
+ *.py[cod]
7
+ *.pyo
8
+ *.pyd
9
+ .Python
10
+ *.egg-info/
11
+ dist/
12
+ build/
13
+ *.egg
14
+
15
+ # ── Virtual environments ──────────────────────────────────────────────
16
+ .env
17
+ .venv
18
+ env/
19
+ venv/
20
+
21
+ # ── Jupyter ───────────────────────────────────────────────────────────
22
+ .ipynb_checkpoints/
23
+ *.ipynb
24
+
25
+ # ── Data / binaries ──────────────────────────────────────────────────
26
+ *.bin
27
+ *.pt
28
+ *.pth
29
+ *.safetensors
30
+ *.npy
31
+ *.npz
32
+
33
+ # ── Logs ─────────────────────────────────────────────────────────────
34
+ *.log
35
+ *.jsonl
36
+
37
+ # ── OS ───────────────────────────────────────────────────────────────
38
+ .DS_Store
39
+ Thumbs.db
40
+
41
+ # ── IDE ───────────────────────────────────────────────────────────────
42
+ .vscode/
43
+ .idea/
44
+ *.swp
README.md ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SLLM — Small Language Model from Scratch
2
+
3
+ A GPT-style decoder-only transformer built and trained from scratch in PyTorch. Two model sizes are available (100M and 150M parameters), designed to fit on consumer GPUs as small as a 4 GB VRAM card (e.g. RTX 3050).
4
+
5
+ ---
6
+
7
+ ## ✨ Features
8
+
9
+ - **Architecture**: Decoder-only transformer (GPT-style) with modern improvements
10
+ - RMSNorm instead of LayerNorm (faster, no bias)
11
+ - RoPE (Rotary Position Embeddings) — used in LLaMA, Mistral, Gemma
12
+ - SwiGLU feed-forward network — outperforms GELU at the same parameter count
13
+ - Flash Attention via `F.scaled_dot_product_attention` (O(T²) memory avoided)
14
+ - Weight-tied token embeddings + LM head (saves ~32M parameters)
15
+ - **Training**
16
+ - bf16 mixed-precision with gradient accumulation
17
+ - Gradient checkpointing for low-VRAM GPUs
18
+ - Cosine LR schedule with linear warmup
19
+ - Resumable checkpointing (`--resume`, `--extra_steps`)
20
+ - JSONL metric logging + live training dashboard
21
+ - **Custom BPE Tokenizer** — trained on FineWeb-Edu with byte fallback (zero OOV)
22
+ - **Supervised Fine-Tuning (SFT)** — chat model pipeline included in `finetune/`
23
+
24
+ ---
25
+
26
+ ## 🏗️ Project Structure
27
+
28
+ ```
29
+ sllm/
30
+ ├── model/ # Model architecture
31
+ │ ├── config.py # ModelConfig dataclass (SLLM_100M, SLLM_150M presets)
32
+ │ ├── model.py # SLLM — full model assembly, weight init, gradient checkpointing
33
+ │ ├── block.py # TransformerBlock (pre-norm, residual)
34
+ │ ├── attention.py # Causal multi-head self-attention + RoPE
35
+ │ ├── mlp.py # SwiGLU feed-forward network
36
+ │ ├── norm.py # RMSNorm
37
+ │ └── rope.py # Rotary Position Embeddings
38
+
39
+ ├── tokenizer/ # Custom BPE tokenizer
40
+ │ ├── normalizer.py # HTML stripping, unicode NFC, whitespace cleanup
41
+ │ ├── pretokenizer.py # Regex pre-tokenizer (code-aware, contraction-aware)
42
+ │ ├── bpe.py # BPE model config with byte fallback (32k vocab)
43
+ │ ├── traintokenizer.py # Train on FineWeb-Edu stream
44
+ │ ├── post_processor.py # Append <|endoftext|> to every sequence
45
+ │ ├── wrap_tokenizer.py # Wrap into PreTrainedTokenizerFast
46
+ │ └── tokenize_dataset.py # Pack tokens into flat binary .bin shards
47
+
48
+ ├── data/
49
+ │ └── dataloader.py # Memory-mapped shard dataloader
50
+
51
+ ├── finetune/ # Supervised fine-tuning (SFT) pipeline
52
+ │ ├── prepare_data.py # Prepare chat data
53
+ │ ├── sft_train.py # SFT training loop
54
+ │ ├── sft_dataset.py # Chat dataset
55
+ │ └── chat.py # Interactive chat with the fine-tuned model
56
+
57
+ ├── train.py # Pre-training loop
58
+ ├── plot_training.py # Training dashboard (static + live mode)
59
+ ├── requirements.txt
60
+ ├── model_explained.md # Deep-dive into every model component
61
+ └── tokenizer_walkthrough.md # Tokenizer design and pipeline walkthrough
62
+ ```
63
+
64
+ ---
65
+
66
+ ## 📐 Model Configs
67
+
68
+ | Config | d_model | Heads | Layers | Parameters |
69
+ |------------|---------|-------|--------|------------|
70
+ | `SLLM_100M` | 768 | 12 | 12 | ~109.5M |
71
+ | `SLLM_150M` | 1024 | 16 | 9 | ~148.4M |
72
+
73
+ Both configs use:
74
+ - Context length: **1024 tokens**
75
+ - Vocab size: **32,000** (custom BPE)
76
+ - SwiGLU d_ff: computed as `round_up_256(⌊2/3 × 4 × d_model⌋)`
77
+
78
+ ---
79
+
80
+ ## ⚙️ Installation
81
+
82
+ **Requires:** Python 3.10+, PyTorch 2.3+, CUDA-capable GPU (bf16 recommended)
83
+
84
+ ```bash
85
+ # Create and activate a conda environment
86
+ conda create -n pytorch python=3.11
87
+ conda activate pytorch
88
+
89
+ # Install dependencies
90
+ pip install -r requirements.txt
91
+ ```
92
+
93
+ ---
94
+
95
+ ## 🚀 Training
96
+
97
+ ### Start a new run (RTX 3050 4GB recommended settings)
98
+
99
+ ```bash
100
+ python train.py \
101
+ --config 150M \
102
+ --data_dir tokenizer/data \
103
+ --batch_size 2 \
104
+ --grad_accum 16 \
105
+ --grad_checkpoint \
106
+ --dtype bf16 \
107
+ --max_steps 5000 \
108
+ --run_dir runs/sllm_150m \
109
+ --log_every 10 \
110
+ --save_every 500 \
111
+ --val_every 500 \
112
+ --warmup_steps 200
113
+ ```
114
+
115
+ ### Resume from a checkpoint
116
+
117
+ ```bash
118
+ python train.py \
119
+ --resume \
120
+ --run_dir runs/sllm_150m \
121
+ --extra_steps 5000 \
122
+ --data_dir tokenizer/data \
123
+ --batch_size 2 \
124
+ --grad_accum 16 \
125
+ --grad_checkpoint \
126
+ --dtype bf16
127
+ ```
128
+
129
+ ### Key training flags
130
+
131
+ | Flag | Default | Description |
132
+ |------|---------|-------------|
133
+ | `--config` | `100M` | Model size (`100M` or `150M`) |
134
+ | `--batch_size` | `4` | Per-device micro-batch size |
135
+ | `--grad_accum` | `8` | Gradient accumulation steps |
136
+ | `--max_steps` | unlimited | Absolute step target |
137
+ | `--extra_steps` | — | Run N more steps from current checkpoint |
138
+ | `--resume` | — | Resume from latest checkpoint in `--run_dir` |
139
+ | `--grad_checkpoint` | — | Enable gradient checkpointing (saves VRAM) |
140
+ | `--dtype` | `bf16` | Mixed precision dtype (`fp32`, `fp16`, `bf16`) |
141
+ | `--synthetic` | — | Use random data (for testing without real shards) |
142
+
143
+ ---
144
+
145
+ ## 📊 Training Dashboard
146
+
147
+ Visualize training metrics in a dark-mode 6-panel dashboard:
148
+
149
+ ```bash
150
+ # Static plot
151
+ python plot_training.py --run_dir runs/sllm_150m
152
+
153
+ # Live mode — refresh every 30 seconds while training
154
+ python plot_training.py --run_dir runs/sllm_150m --live --interval 30
155
+
156
+ # Compare two runs
157
+ python plot_training.py --run_dir runs/run_a runs/run_b
158
+
159
+ # Save to file
160
+ python plot_training.py --run_dir runs/sllm_150m --save dashboard.png
161
+ ```
162
+
163
+ **Dashboard panels:** Training Loss (raw + EMA) · Validation Loss · Learning Rate · Tokens/sec · VRAM usage · Gradient norm
164
+
165
+ ---
166
+
167
+ ## 💬 Fine-Tuning (Chat Model)
168
+
169
+ After pre-training, you can fine-tune with supervised instruction data:
170
+
171
+ ```bash
172
+ # 1. Prepare chat data
173
+ python finetune/prepare_data.py
174
+
175
+ # 2. Fine-tune
176
+ python finetune/sft_train.py \
177
+ --base_ckpt runs/sllm_150m/ckpt_0011500.pt \
178
+ --run_dir runs/sllm_150m_chat \
179
+ --max_steps 2500 \
180
+ --batch_size 4 \
181
+ --grad_accum 8 \
182
+ --grad_checkpoint
183
+
184
+ # 3. Chat interactively
185
+ python finetune/chat.py --run_dir runs/sllm_150m_chat
186
+ ```
187
+
188
+ ---
189
+
190
+ ## 🔡 Tokenizer
191
+
192
+ A custom BPE tokenizer trained on the educational subset of [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu):
193
+
194
+ - **32,000 token vocabulary**
195
+ - **Byte fallback** — zero out-of-vocabulary tokens (even math symbols and emojis work)
196
+ - **Code-aware** — preserves `snake_case`, operators (`==`, `->`, `**`), and indentation
197
+ - **Contraction-aware** — `don't`, `I've`, `they're` are split correctly
198
+ - Packaged as a `PreTrainedTokenizerFast` (HuggingFace-compatible)
199
+
200
+ Training data is packed into flat binary `.bin` shards (`np.uint16`, 100M tokens each) for fast memory-mapped loading.
201
+
202
+ See [`tokenizer_walkthrough.md`](tokenizer_walkthrough.md) for a full pipeline deep-dive.
203
+
204
+ ---
205
+
206
+ ## 🧠 Architecture Deep-Dive
207
+
208
+ See [`model_explained.md`](model_explained.md) for a plain-language walkthrough of every model component, including:
209
+ - Why RMSNorm is faster than LayerNorm
210
+ - How RoPE encodes relative position without extra parameters
211
+ - Why SwiGLU outperforms GELU
212
+ - How weight tying saves 32M parameters
213
+ - Flash Attention and gradient checkpointing explained
214
+
215
+ ---
216
+
217
+ ## 📋 Checkpoints & Logging
218
+
219
+ - Checkpoints are saved to `<run_dir>/ckpt_NNNNNNN.pt` every `--save_every` steps and on clean exit (Ctrl+C)
220
+ - Metrics are appended to `<run_dir>/train_log.jsonl` (one JSON line per log step)
221
+ - Each checkpoint stores: model weights, optimizer state, step number, loss, and config name
222
+ - Resuming auto-detects the correct model config from the checkpoint
223
+
224
+ ---
225
+
226
+ ## 📦 Requirements
227
+
228
+ ```
229
+ torch>=2.3.0
230
+ datasets>=2.14.0 # HuggingFace datasets (streaming)
231
+ tokenizers>=0.15.0 # Fast BPE tokenizer
232
+ transformers>=4.40.0 # PreTrainedTokenizerFast
233
+ numpy>=1.26.0
234
+ tqdm
235
+ matplotlib
236
+ ```
237
+
238
+ ---
239
+
240
+ ## 📄 License
241
+
242
+ This project is released for educational purposes.
data/dataloader.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ data/dataloader.py
3
+
4
+ Streaming dataloader for the pre-tokenized binary shards produced by
5
+ tokenizer/tokenize_dataset.py.
6
+
7
+ Each shard is a flat binary file of np.uint16 token IDs.
8
+ 100M tokens * 2 bytes = ~200MB per shard.
9
+
10
+ Strategy:
11
+ 1. Discover all shards matching split name (train/val).
12
+ 2. Shuffle shard order at start of each epoch.
13
+ 3. For each shard, load it (memmap or full) and yield non-overlapping
14
+ chunks of (context_length + 1) tokens.
15
+ 4. Inputs = chunk[:-1] (length context_length)
16
+ Targets = chunk[1:] (length context_length, shifted right by 1)
17
+
18
+ When no data shards exist yet (tokenization not done), a SyntheticShard
19
+ can be used for architecture testing.
20
+ """
21
+
22
+ import os
23
+ import glob
24
+ import random
25
+ import numpy as np
26
+ import torch
27
+ from torch.utils.data import IterableDataset, DataLoader
28
+
29
+
30
+ # ------------------------------------------------------------------ #
31
+ # SHARD DISCOVERY
32
+ # ------------------------------------------------------------------ #
33
+
34
+ def find_shards(data_dir: str, split: str) -> list[str]:
35
+ """
36
+ Returns sorted list of shard paths for the given split.
37
+
38
+ Args:
39
+ data_dir : directory containing .bin shard files
40
+ split : 'train' or 'val'
41
+ """
42
+ pattern = os.path.join(data_dir, f"{split}_*.bin")
43
+ shards = sorted(glob.glob(pattern))
44
+ return shards
45
+
46
+
47
+ # ------------------------------------------------------------------ #
48
+ # ITERABLE DATASET
49
+ # ------------------------------------------------------------------ #
50
+
51
+ class ShardedTokenDataset(IterableDataset):
52
+ """
53
+ IterableDataset that streams token chunks from binary shards.
54
+
55
+ Each worker processes a disjoint subset of shards so we get
56
+ proper parallelism with DataLoader(num_workers=N).
57
+
58
+ Usage:
59
+ dataset = ShardedTokenDataset(data_dir, split='train', context_length=1024)
60
+ loader = DataLoader(dataset, batch_size=4)
61
+ for input_ids, targets in loader:
62
+ ...
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ data_dir: str,
68
+ split: str,
69
+ context_length: int,
70
+ shuffle_shards: bool = True,
71
+ ):
72
+ """
73
+ Args:
74
+ data_dir : path to directory with .bin shard files
75
+ split : 'train' or 'val'
76
+ context_length : sequence length (model context length)
77
+ shuffle_shards : shuffle shard order each epoch (train only)
78
+ """
79
+ super().__init__()
80
+ self.context_length = context_length
81
+ self.shuffle_shards = shuffle_shards
82
+
83
+ self.shards = find_shards(data_dir, split)
84
+ if not self.shards:
85
+ raise FileNotFoundError(
86
+ f"No {split} shards found in {data_dir}.\n"
87
+ f"Run tokenizer/tokenize_dataset.py first to generate data."
88
+ )
89
+ print(f"[DataLoader] Found {len(self.shards)} {split} shards in {data_dir}")
90
+
91
+ def __iter__(self):
92
+ worker_info = torch.utils.data.get_worker_info()
93
+
94
+ shards = self.shards.copy()
95
+ if self.shuffle_shards:
96
+ random.shuffle(shards)
97
+
98
+ # Split shards across workers
99
+ if worker_info is not None:
100
+ shards = shards[worker_info.id :: worker_info.num_workers]
101
+
102
+ chunk = self.context_length + 1 # +1 so we can shift for targets
103
+
104
+ for shard_path in shards:
105
+ # Load shard as uint16 array
106
+ tokens = np.fromfile(shard_path, dtype=np.uint16).astype(np.int32)
107
+
108
+ # Yield non-overlapping chunks
109
+ n_chunks = len(tokens) // chunk
110
+ for i in range(n_chunks):
111
+ start = i * chunk
112
+ seq = torch.from_numpy(tokens[start : start + chunk].copy())
113
+ input_ids = seq[:-1].long() # (context_length,)
114
+ targets = seq[1:].long() # (context_length,)
115
+ yield input_ids, targets
116
+
117
+
118
+ # ------------------------------------------------------------------ #
119
+ # SYNTHETIC DATASET (for testing without real data)
120
+ # ------------------------------------------------------------------ #
121
+
122
+ class SyntheticDataset(IterableDataset):
123
+ """
124
+ Generates random token sequences for architecture testing.
125
+ Use when real shards are not yet available.
126
+ """
127
+
128
+ def __init__(self, vocab_size: int, context_length: int, n_batches: int = 1000):
129
+ super().__init__()
130
+ self.vocab_size = vocab_size
131
+ self.context_length = context_length
132
+ self.n_batches = n_batches
133
+
134
+ def __iter__(self):
135
+ for _ in range(self.n_batches):
136
+ seq = torch.randint(0, self.vocab_size, (self.context_length + 1,))
137
+ input_ids = seq[:-1]
138
+ targets = seq[1:]
139
+ yield input_ids, targets
140
+
141
+
142
+ # ------------------------------------------------------------------ #
143
+ # FACTORY FUNCTION
144
+ # ------------------------------------------------------------------ #
145
+
146
+ def build_dataloader(
147
+ data_dir: str,
148
+ split: str,
149
+ context_length: int,
150
+ batch_size: int,
151
+ num_workers: int = 2,
152
+ use_synthetic: bool = False,
153
+ vocab_size: int = 32_000,
154
+ ) -> DataLoader:
155
+ """
156
+ Builds and returns a DataLoader for the given split.
157
+
158
+ Falls back to SyntheticDataset if use_synthetic=True or no shards found.
159
+
160
+ Args:
161
+ data_dir : directory with .bin shards
162
+ split : 'train' or 'val'
163
+ context_length : model context length (1024)
164
+ batch_size : number of sequences per batch
165
+ num_workers : DataLoader workers (0 = main process)
166
+ use_synthetic : force synthetic data (for testing)
167
+ vocab_size : needed for synthetic fallback
168
+
169
+ Returns:
170
+ DataLoader yielding (input_ids, targets) each of shape (B, T)
171
+ """
172
+ if use_synthetic:
173
+ dataset = SyntheticDataset(vocab_size, context_length)
174
+ print(f"[DataLoader] Using synthetic data (use_synthetic=True)")
175
+ else:
176
+ try:
177
+ dataset = ShardedTokenDataset(
178
+ data_dir = data_dir,
179
+ split = split,
180
+ context_length = context_length,
181
+ shuffle_shards = (split == "train"),
182
+ )
183
+ except FileNotFoundError as e:
184
+ print(f"[DataLoader] WARNING: {e}")
185
+ print(f"[DataLoader] Falling back to synthetic data for testing.")
186
+ dataset = SyntheticDataset(vocab_size, context_length)
187
+
188
+ return DataLoader(
189
+ dataset,
190
+ batch_size = batch_size,
191
+ num_workers = num_workers,
192
+ pin_memory = True, # faster CPU->GPU transfer
193
+ )
194
+
195
+
196
+ # ------------------------------------------------------------------ #
197
+ # QUICK CHECK
198
+ # ------------------------------------------------------------------ #
199
+
200
+ if __name__ == "__main__":
201
+ import sys
202
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
203
+ from model.config import SLLM_100M
204
+
205
+ cfg = SLLM_100M
206
+
207
+ print("Testing with synthetic data...")
208
+ loader = build_dataloader(
209
+ data_dir = "tokenizer/data",
210
+ split = "train",
211
+ context_length = cfg.context_length,
212
+ batch_size = 4,
213
+ num_workers = 0,
214
+ use_synthetic = True,
215
+ vocab_size = cfg.vocab_size,
216
+ )
217
+
218
+ for i, (x, y) in enumerate(loader):
219
+ print(f"Batch {i}: input_ids={x.shape}, targets={y.shape}, dtype={x.dtype}")
220
+ if i == 3:
221
+ break
222
+
223
+ print("DataLoader OK")
finetune/README.md ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SLLM-150M → Chat Model (SFT)
2
+
3
+ Supervised Fine-Tuning pipeline to turn the pretrained **SLLM-150M** base model into
4
+ an instruction-following chat model using **OpenHermes-2.5**.
5
+
6
+ ## Pipeline
7
+
8
+ ```
9
+ Base model (runs/sllm_150m/ckpt_0011500.pt)
10
+
11
+
12
+ prepare_data.py ─── download & tokenize OpenHermes-2.5 (80k convs)
13
+
14
+
15
+ sft_train.py ─── SFT with ChatML loss masking
16
+
17
+
18
+ chat.py ─── interactive CLI chat
19
+ ```
20
+
21
+ ## Step 1 — Install dependency
22
+
23
+ ```bash
24
+ pip install datasets
25
+ ```
26
+
27
+ ## Step 2 — Prepare data
28
+
29
+ Downloads 80k conversations, formats as ChatML, tokenizes, saves shards.
30
+ Also saves the extended tokenizer (vocab 32,002) to `finetune/data/`.
31
+
32
+ ```bash
33
+ python finetune/prepare_data.py
34
+ ```
35
+
36
+ Options:
37
+
38
+ | Flag | Default | Description |
39
+ |------|---------|-------------|
40
+ | `--n_samples` | `80000` | Conversations to sample |
41
+ | `--val_ratio` | `0.05` | Validation fraction |
42
+ | `--output_dir` | `finetune/data` | Output directory |
43
+ | `--seed` | `42` | Random seed |
44
+
45
+ Expected output:
46
+ ```
47
+ finetune/data/
48
+ tokenizer.json ← extended tokenizer (32,002 vocab)
49
+ tokenizer_config.json
50
+ special_tokens_map.json
51
+ train_sft.pt ← ~76,000 examples
52
+ val_sft.pt ← ~4,000 examples
53
+ meta.json ← stats
54
+ ```
55
+
56
+ ## Step 3 — Fine-tune
57
+
58
+ ```bash
59
+ python finetune/sft_train.py \
60
+ --base_ckpt runs/sllm_150m/ckpt_0011500.pt \
61
+ --run_dir runs/sllm_150m_chat \
62
+ --max_steps 2000 \
63
+ --batch_size 4 --grad_accum 8 \
64
+ --grad_checkpoint
65
+ ```
66
+
67
+ For an RTX 3050 4 GB, these settings use ~3.5 GB VRAM and take **~5–8 minutes**.
68
+
69
+ **Resume training:**
70
+ ```bash
71
+ python finetune/sft_train.py \
72
+ --resume --run_dir runs/sllm_150m_chat \
73
+ --extra_steps 1000
74
+ ```
75
+
76
+ Key options:
77
+
78
+ | Flag | Default | Description |
79
+ |------|---------|-------------|
80
+ | `--base_ckpt` | `runs/sllm_150m/ckpt_0011500.pt` | Base pretrained checkpoint |
81
+ | `--max_lr` | `1e-5` | Peak LR (10× lower than pretraining) |
82
+ | `--dropout` | `0.1` | SFT dropout (0 in pretraining) |
83
+ | `--max_steps` | `2000` | Total training steps |
84
+ | `--grad_checkpoint` | off | Enable for lower VRAM |
85
+
86
+ Checkpoints are saved to `runs/sllm_150m_chat/ckpt_sft_XXXXXXX.pt`.
87
+ Training log: `runs/sllm_150m_chat/sft_log.jsonl`.
88
+
89
+ ## Step 4 — Chat
90
+
91
+ ```bash
92
+ python finetune/chat.py
93
+ python finetune/chat.py --run_dir runs/sllm_150m_chat --temperature 0.7
94
+ ```
95
+
96
+ In-chat commands:
97
+
98
+ | Command | Effect |
99
+ |---------|--------|
100
+ | `/reset` | Clear conversation history |
101
+ | `/system <text>` | Change system prompt |
102
+ | `/quit` | Exit |
103
+
104
+ ## What changes vs pretraining
105
+
106
+ | | Pretraining (`train.py`) | SFT (`sft_train.py`) |
107
+ |---|---|---|
108
+ | Data | Raw text shards (`.bin`) | ChatML conversations (`.pt`) |
109
+ | Loss | Every token | **Assistant tokens only** (`ignore_index=-100`) |
110
+ | Learning rate | `3e-4` | **`1e-5`** |
111
+ | Warmup | 100 steps | 30 steps |
112
+ | Vocab | 32,000 | **32,002** (`<\|im_start\|>` + `<\|im_end\|>`) |
113
+ | Dropout | 0.0 | **0.1** |
114
+ | Checkpoint prefix | `ckpt_` | `ckpt_sft_` |
115
+
116
+ ## Expected loss curve
117
+
118
+ | Stage | Expected loss |
119
+ |-------|--------------|
120
+ | Start (step 0) | 1.5 – 2.5 |
121
+ | Step 500 | 1.0 – 1.5 |
122
+ | Step 2000 | 0.8 – 1.2 |
123
+
124
+ > **If loss starts above 4.0 or goes NaN** → reduce `--max_lr` to `5e-6`.
125
+
126
+ ## Prompt format (ChatML)
127
+
128
+ ```
129
+ <|im_start|>system
130
+ You are a helpful, concise assistant.<|im_end|>
131
+ <|im_start|>user
132
+ What is the capital of France?<|im_end|>
133
+ <|im_start|>assistant
134
+ The capital of France is Paris.<|im_end|>
135
+ ```
136
+
137
+ Generation stops automatically when the model produces `<|im_end|>`.
finetune/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # finetune package
finetune/chat.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ finetune/chat.py
3
+
4
+ Interactive CLI chat with the fine-tuned SLLM-150M chat model.
5
+
6
+ Loads the latest SFT checkpoint from --run_dir, formats your input
7
+ as a ChatML prompt, generates a response token-by-token, and stops
8
+ at the <|im_end|> token.
9
+
10
+ Usage:
11
+ python finetune/chat.py
12
+ python finetune/chat.py --run_dir runs/sllm_150m_chat
13
+ python finetune/chat.py --temperature 0.7 --top_k 40
14
+
15
+ In-chat commands:
16
+ /reset clear conversation history (start fresh)
17
+ /system <text> change the system prompt
18
+ /quit exit
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import argparse
24
+ from pathlib import Path
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ from transformers import PreTrainedTokenizerFast
29
+
30
+ SCRIPT_DIR = Path(__file__).resolve().parent
31
+ PROJECT_ROOT = SCRIPT_DIR.parent
32
+ DATA_DIR = SCRIPT_DIR / "data"
33
+
34
+ sys.path.insert(0, str(PROJECT_ROOT))
35
+
36
+ from model.config import SLLM_150M
37
+ from model.model import SLLM
38
+
39
+ DEFAULT_SYSTEM = "You are a helpful, concise assistant."
40
+ DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat")
41
+
42
+
43
+ # ------------------------------------------------------------------ #
44
+ # HELPERS
45
+ # ------------------------------------------------------------------ #
46
+
47
+ def find_latest_ckpt(run_dir: str) -> str:
48
+ """Returns path to the most recent ckpt_sft_*.pt in run_dir."""
49
+ ckpts = sorted([
50
+ f for f in os.listdir(run_dir)
51
+ if f.startswith("ckpt_sft_") and f.endswith(".pt")
52
+ ])
53
+ if not ckpts:
54
+ raise FileNotFoundError(
55
+ f"No SFT checkpoints found in '{run_dir}'.\n"
56
+ f"Run sft_train.py first."
57
+ )
58
+ return os.path.join(run_dir, ckpts[-1])
59
+
60
+
61
+ def resize_token_embeddings(model: SLLM, new_vocab_size: int):
62
+ """Same resize logic as sft_train.py — kept local to avoid circular imports."""
63
+ old_size = model.config.vocab_size
64
+ if new_vocab_size == old_size:
65
+ return
66
+ d_model = model.config.d_model
67
+ device = model.token_emb.weight.device
68
+ dtype = model.token_emb.weight.dtype
69
+ old_weight = model.token_emb.weight.data.clone()
70
+ mean_vec = old_weight.mean(dim=0)
71
+ new_weight = torch.zeros(new_vocab_size, d_model, dtype=dtype, device=device)
72
+ new_weight[:old_size] = old_weight
73
+ new_weight[old_size:] = mean_vec.unsqueeze(0).expand(new_vocab_size - old_size, -1)
74
+ new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype)
75
+ new_emb.weight.data = new_weight
76
+ model.token_emb = new_emb
77
+ model.lm_head.weight = model.token_emb.weight
78
+ model.config.vocab_size = new_vocab_size
79
+
80
+
81
+ def load_model_and_tokenizer(run_dir: str, device: torch.device):
82
+ """Loads tokenizer (from data dir) and fine-tuned model (from run_dir)."""
83
+
84
+ # ---- Tokenizer ------------------------------------------------- #
85
+ tok_path = str(DATA_DIR)
86
+ if os.path.exists(os.path.join(tok_path, "tokenizer.json")):
87
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
88
+ else:
89
+ # Fallback: base tokenizer + manual special token add
90
+ base_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
91
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(base_dir)
92
+ tokenizer.add_special_tokens({
93
+ "additional_special_tokens": ["<|im_start|>", "<|im_end|>"]
94
+ })
95
+
96
+ # ---- Checkpoint ------------------------------------------------ #
97
+ ckpt_path = find_latest_ckpt(run_dir)
98
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
99
+
100
+ # ---- Model ----------------------------------------------------- #
101
+ model = SLLM(SLLM_150M).to(device)
102
+ saved_vocab = ckpt.get("vocab_size", len(tokenizer))
103
+ resize_token_embeddings(model, saved_vocab)
104
+ model.load_state_dict(ckpt["model_state_dict"])
105
+ model.eval()
106
+
107
+ return model, tokenizer, ckpt_path, ckpt.get("step", "?"), ckpt.get("loss", float("nan"))
108
+
109
+
110
+ # ------------------------------------------------------------------ #
111
+ # PROMPT BUILDING
112
+ # ------------------------------------------------------------------ #
113
+
114
+ def build_prompt(history: list[dict], system_prompt: str,
115
+ tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
116
+ """
117
+ Formats conversation history as ChatML and tokenises it.
118
+
119
+ Template:
120
+ <|im_start|>system
121
+ {system}<|im_end|>
122
+ <|im_start|>user
123
+ {user}<|im_end|>
124
+ <|im_start|>assistant
125
+ {assistant}<|im_end|>
126
+ ...
127
+ <|im_start|>assistant\\n ← left open for the model to complete
128
+
129
+ Returns:
130
+ input_ids : (1, T) LongTensor
131
+ """
132
+ text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
133
+ for turn in history:
134
+ text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
135
+ # Prime the model to generate as assistant
136
+ text += "<|im_start|>assistant\n"
137
+
138
+ ids = tokenizer.encode(text, add_special_tokens=False)
139
+ return torch.tensor([ids], dtype=torch.long)
140
+
141
+
142
+ # ------------------------------------------------------------------ #
143
+ # GENERATION
144
+ # ------------------------------------------------------------------ #
145
+
146
+ @torch.no_grad()
147
+ def generate_response(
148
+ model: SLLM,
149
+ input_ids: torch.Tensor,
150
+ tokenizer: PreTrainedTokenizerFast,
151
+ max_new_tokens: int = 300,
152
+ temperature: float = 0.8,
153
+ top_k: int = 50,
154
+ device: torch.device = None,
155
+ ) -> str:
156
+ """
157
+ Autoregressively generates tokens until:
158
+ - <|im_end|> is produced (clean stop), or
159
+ - eos_token_id is produced, or
160
+ - max_new_tokens is reached
161
+
162
+ Returns the decoded response string (special tokens stripped).
163
+ """
164
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
165
+ eos_id = tokenizer.eos_token_id
166
+
167
+ ids = input_ids.to(device)
168
+ generated = []
169
+
170
+ for _ in range(max_new_tokens):
171
+ # Crop to context window
172
+ ctx = ids if ids.shape[1] <= model.config.context_length \
173
+ else ids[:, -model.config.context_length:]
174
+
175
+ logits, _ = model(ctx) # (1, T, V)
176
+ logits = logits[:, -1, :] / max(temperature, 1e-8)
177
+
178
+ # Top-k filtering
179
+ if top_k and top_k > 0:
180
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
181
+ logits[logits < v[:, [-1]]] = float("-inf")
182
+
183
+ probs = torch.softmax(logits, dim=-1)
184
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
185
+ tok_id = next_token.item()
186
+
187
+ # Stop conditions
188
+ if tok_id == im_end_id or tok_id == eos_id:
189
+ break
190
+
191
+ generated.append(tok_id)
192
+ ids = torch.cat([ids, next_token], dim=1)
193
+
194
+ return tokenizer.decode(generated, skip_special_tokens=True).strip()
195
+
196
+
197
+ # ------------------------------------------------------------------ #
198
+ # MAIN
199
+ # ------------------------------------------------------------------ #
200
+
201
+ def parse_args():
202
+ p = argparse.ArgumentParser(description="SLLM-150M Chat")
203
+ p.add_argument("--run_dir", type=str, default=DEFAULT_RUN_DIR)
204
+ p.add_argument("--temperature", type=float, default=0.8,
205
+ help="Sampling temperature (lower = more focused)")
206
+ p.add_argument("--top_k", type=int, default=50,
207
+ help="Top-k sampling (0 = disabled)")
208
+ p.add_argument("--max_new_tokens", type=int, default=300,
209
+ help="Max tokens per assistant response")
210
+ p.add_argument("--system", type=str, default=DEFAULT_SYSTEM,
211
+ help="System prompt")
212
+ return p.parse_args()
213
+
214
+
215
+ def main():
216
+ args = parse_args()
217
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
218
+
219
+ print("\n" + "=" * 60)
220
+ print(" SLLM-150M Chat")
221
+ print("=" * 60)
222
+ print(f" Device : {device}")
223
+ if device.type == "cuda":
224
+ print(f" GPU : {torch.cuda.get_device_name(0)}")
225
+
226
+ # ---- Load ------------------------------------------------------ #
227
+ print("\nLoading model...")
228
+ model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
229
+ print(f" Checkpoint : {ckpt_path}")
230
+ print(f" Step : {step} Loss: {loss:.4f}")
231
+ print(f" Vocab size : {len(tokenizer):,}")
232
+
233
+ # ---- Chat loop ------------------------------------------------- #
234
+ system_prompt = args.system
235
+ history: list[dict] = []
236
+
237
+ print(f"\n System : {system_prompt}")
238
+ print(" Commands: /reset | /system <new prompt> | /quit")
239
+ print("─" * 60 + "\n")
240
+
241
+ while True:
242
+ try:
243
+ user_input = input("You: ").strip()
244
+ except (EOFError, KeyboardInterrupt):
245
+ print("\nBye!")
246
+ break
247
+
248
+ if not user_input:
249
+ continue
250
+
251
+ # ---- Commands ---------------------------------------------- #
252
+ if user_input.lower() in ("/quit", "/exit", "quit", "exit"):
253
+ print("Bye!")
254
+ break
255
+
256
+ if user_input.lower() == "/reset":
257
+ history = []
258
+ print(" [Conversation cleared]\n")
259
+ continue
260
+
261
+ if user_input.lower().startswith("/system "):
262
+ new_sys = user_input[8:].strip()
263
+ if new_sys:
264
+ system_prompt = new_sys
265
+ history = []
266
+ print(f" [System prompt updated. Conversation cleared.]\n")
267
+ continue
268
+
269
+ # ---- Build prompt ------------------------------------------ #
270
+ history.append({"role": "user", "content": user_input})
271
+ input_ids = build_prompt(history, system_prompt, tokenizer)
272
+
273
+ # Trim history if prompt is getting close to context limit
274
+ while input_ids.shape[1] > model.config.context_length - args.max_new_tokens - 10:
275
+ if len(history) > 2:
276
+ history = history[2:] # drop oldest user+assistant pair
277
+ input_ids = build_prompt(history, system_prompt, tokenizer)
278
+ else:
279
+ break # can't trim further — just truncate in generation
280
+
281
+ # ---- Generate ---------------------------------------------- #
282
+ print("SLLM: ", end="", flush=True)
283
+ response = generate_response(
284
+ model, input_ids, tokenizer,
285
+ max_new_tokens = args.max_new_tokens,
286
+ temperature = args.temperature,
287
+ top_k = args.top_k,
288
+ device = device,
289
+ )
290
+ print(response + "\n")
291
+
292
+ history.append({"role": "assistant", "content": response})
293
+
294
+
295
+ if __name__ == "__main__":
296
+ main()
finetune/check_data.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ finetune/check_data.py
3
+
4
+ Smoke-test: loads 5 rows from OpenHermes-2.5, runs them through the
5
+ same format_and_tokenize() logic used by prepare_data.py, and prints
6
+ a full visual audit so you can confirm everything lines up.
7
+
8
+ Checks:
9
+ 1. Raw conversation structure from the dataset
10
+ 2. ChatML text that gets fed to the tokenizer
11
+ 3. Token IDs and decoded tokens (side-by-side)
12
+ 4. Label mask — ✓ (labeled) vs (masked -100) for every token
13
+ 5. Label ratio (should be ~30-60% assistant tokens)
14
+
15
+ Run from project root:
16
+ python finetune/check_data.py
17
+ python finetune/check_data.py --row 3 # inspect a specific row index
18
+ """
19
+
20
+ import sys
21
+ import argparse
22
+ from pathlib import Path
23
+
24
+ # ------------------------------------------------------------------ #
25
+ # Paths
26
+ # ------------------------------------------------------------------ #
27
+
28
+ SCRIPT_DIR = Path(__file__).resolve().parent
29
+ PROJECT_ROOT = SCRIPT_DIR.parent
30
+ TOKENIZER_DIR = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"
31
+
32
+ sys.path.insert(0, str(PROJECT_ROOT))
33
+
34
+ from transformers import PreTrainedTokenizerFast
35
+ from datasets import load_dataset
36
+
37
+ SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]
38
+ MAX_LENGTH = 1024
39
+
40
+ ROLE_MAP = {
41
+ "system": "system",
42
+ "human": "user",
43
+ "gpt": "assistant",
44
+ "user": "user",
45
+ "assistant": "assistant",
46
+ }
47
+
48
+
49
+ # ------------------------------------------------------------------ #
50
+ # Replicated from prepare_data.py (no import to keep this self-contained)
51
+ # ------------------------------------------------------------------ #
52
+
53
+ def load_tokenizer() -> PreTrainedTokenizerFast:
54
+ tok = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
55
+ new = [t for t in SPECIAL_TOKENS if t not in tok.get_vocab()]
56
+ if new:
57
+ tok.add_special_tokens({"additional_special_tokens": new})
58
+ return tok
59
+
60
+
61
+ def format_and_tokenize(conversations, tokenizer):
62
+ """Identical logic to prepare_data.py — returns (input_ids, labels) or None."""
63
+ input_ids, labels = [], []
64
+
65
+ for turn in conversations:
66
+ role_raw = turn.get("from", turn.get("role", "")).strip().lower()
67
+ content = turn.get("value", turn.get("content", "")).strip()
68
+ role = ROLE_MAP.get(role_raw, role_raw)
69
+
70
+ if not content or not role:
71
+ continue
72
+
73
+ header_text = f"<|im_start|>{role}\n"
74
+ header_ids = tokenizer.encode(header_text, add_special_tokens=False)
75
+
76
+ body_text = f"{content}<|im_end|>\n"
77
+ body_ids = tokenizer.encode(body_text, add_special_tokens=False)
78
+
79
+ turn_input = header_ids + body_ids
80
+
81
+ if role == "assistant":
82
+ turn_labels = [-100] * len(header_ids) + body_ids
83
+ else:
84
+ turn_labels = [-100] * len(turn_input)
85
+
86
+ input_ids.extend(turn_input)
87
+ labels.extend(turn_labels)
88
+
89
+ if not any(l != -100 for l in labels):
90
+ return None
91
+
92
+ input_ids = input_ids[:MAX_LENGTH]
93
+ labels = labels[:MAX_LENGTH]
94
+
95
+ if len(input_ids) < 8:
96
+ return None
97
+
98
+ return input_ids, labels
99
+
100
+
101
+ # ------------------------------------------------------------------ #
102
+ # Pretty-print helpers
103
+ # ------------------------------------------------------------------ #
104
+
105
+ def print_section(title: str):
106
+ print(f"\n{'─'*60}")
107
+ print(f" {title}")
108
+ print(f"{'─'*60}")
109
+
110
+
111
+ def print_token_table(input_ids, labels, tokenizer, max_rows: int = 80):
112
+ """
113
+ Prints a table: idx | token_str | label (✓ or ✗)
114
+ Green ✓ = labeled (assistant) — model learns this
115
+ Red ✗ = masked -100 — model ignores this
116
+ """
117
+ GREEN = "\033[92m"
118
+ RED = "\033[91m"
119
+ RESET = "\033[0m"
120
+
121
+ print(f"\n {'IDX':>5} {'TOKEN':<22} {'ID':>6} {'LABEL':>8} {'LEARN?'}")
122
+ print(f" {'─'*5} {'─'*22} {'─'*6} {'─'*8} {'─'*6}")
123
+
124
+ shown = 0
125
+ for i, (tok_id, lbl) in enumerate(zip(input_ids, labels)):
126
+ tok_str = repr(tokenizer.decode([tok_id]))[:22]
127
+ if lbl == -100:
128
+ learn_str = f"{RED}✗ masked{RESET}"
129
+ lbl_str = " -100"
130
+ else:
131
+ learn_str = f"{GREEN}✓ learn {RESET}"
132
+ lbl_str = f"{lbl:>8}"
133
+
134
+ print(f" {i:>5} {tok_str:<22} {tok_id:>6} {lbl_str} {learn_str}")
135
+ shown += 1
136
+ if shown >= max_rows:
137
+ remaining = len(input_ids) - max_rows
138
+ print(f" ... ({remaining} more tokens not shown)")
139
+ break
140
+
141
+ # Summary
142
+ n_labeled = sum(1 for l in labels if l != -100)
143
+ n_total = len(labels)
144
+ print(f"\n Total tokens : {n_total}")
145
+ print(f" Labeled : {n_labeled} ({n_labeled/n_total:.1%}) ← assistant tokens")
146
+ print(f" Masked : {n_total - n_labeled} ({(n_total-n_labeled)/n_total:.1%}) ← user/system tokens")
147
+
148
+
149
+ # ------------------------------------------------------------------ #
150
+ # MAIN
151
+ # ------------------------------------------------------------------ #
152
+
153
+ def parse_args():
154
+ p = argparse.ArgumentParser(description="Check one OpenHermes row through the SFT pipeline")
155
+ p.add_argument("--row", type=int, default=0,
156
+ help="Which row to inspect in detail (0-indexed, from the first 20 fetched)")
157
+ p.add_argument("--n_fetch", type=int, default=20,
158
+ help="How many rows to fetch from HuggingFace (default: 20)")
159
+ return p.parse_args()
160
+
161
+
162
+ def main():
163
+ args = parse_args()
164
+
165
+ print("\n" + "=" * 60)
166
+ print(" SFT Pipeline — Data Alignment Check")
167
+ print("=" * 60)
168
+
169
+ # ---- 1. Tokenizer ---------------------------------------------- #
170
+ print_section("1. Tokenizer")
171
+ tokenizer = load_tokenizer()
172
+ im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
173
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
174
+ print(f" Vocab size : {len(tokenizer):,}")
175
+ print(f" <|im_start|> : token ID {im_start_id}")
176
+ print(f" <|im_end|> : token ID {im_end_id}")
177
+ assert im_start_id != tokenizer.unk_token_id, "ERROR: <|im_start|> not in vocab!"
178
+ assert im_end_id != tokenizer.unk_token_id, "ERROR: <|im_end|> not in vocab!"
179
+ print(" ✓ Special tokens present in vocab")
180
+
181
+ # ---- 2. Load one row ------------------------------------------- #
182
+ print_section(f"2. Loading row {args.row} from OpenHermes-2.5")
183
+ print(f" Loading first {args.n_fetch} rows from local cache (Arrow format)...")
184
+ ds = load_dataset("teknium/OpenHermes-2.5", split="train")
185
+ row = ds[args.row]
186
+ convs = row.get("conversations", [])
187
+
188
+ print(f" Row index : {args.row}")
189
+ print(f" Turns in conv : {len(convs)}")
190
+
191
+ # ---- 3. Raw conversation --------------------------------------- #
192
+ print_section("3. Raw conversation (from dataset)")
193
+ for i, turn in enumerate(convs):
194
+ role = turn.get("from", "?")
195
+ content = turn.get("value", "").strip()
196
+ preview = content[:120].replace("\n", "↵")
197
+ print(f" [{i}] from={role!r:12s} | {preview!r}")
198
+
199
+ # ---- 4. ChatML formatted text ---------------------------------- #
200
+ print_section("4. ChatML text (what tokenizer sees)")
201
+ chatml = ""
202
+ for turn in convs:
203
+ role_raw = turn.get("from", "").strip().lower()
204
+ content = turn.get("value", "").strip()
205
+ role = ROLE_MAP.get(role_raw, role_raw)
206
+ if content and role:
207
+ chatml += f"<|im_start|>{role}\n{content}<|im_end|>\n"
208
+ print(chatml[:800])
209
+ if len(chatml) > 800:
210
+ print(f" ... ({len(chatml) - 800} more chars)")
211
+
212
+ # ---- 5. Run through format_and_tokenize ----------------------- #
213
+ print_section("5. format_and_tokenize() output")
214
+ result = format_and_tokenize(convs, tokenizer)
215
+
216
+ if result is None:
217
+ print(" ✗ RETURNED None — no assistant turn or too short.")
218
+ print(" Try a different --row index.")
219
+ return
220
+
221
+ input_ids, labels = result
222
+ print(f" input_ids length : {len(input_ids)}")
223
+ print(f" labels length : {len(labels)}")
224
+ assert len(input_ids) == len(labels), "MISMATCH: input_ids and labels have different lengths!"
225
+ print(" ✓ Lengths match")
226
+
227
+ # ---- 6. Verify label alignment --------------------------------- #
228
+ print_section("6. Label alignment sanity checks")
229
+
230
+ # Every im_start should be masked
231
+ im_start_positions = [i for i, t in enumerate(input_ids) if t == im_start_id]
232
+ im_end_positions = [i for i, t in enumerate(input_ids) if t == im_end_id]
233
+
234
+ print(f" <|im_start|> positions : {im_start_positions}")
235
+ print(f" <|im_end|> positions : {im_end_positions}")
236
+
237
+ im_start_masked = all(labels[i] == -100 for i in im_start_positions)
238
+ print(f" All <|im_start|> tokens are masked (-100) : {'✓' if im_start_masked else '✗ FAIL'}")
239
+
240
+ # Decode the labeled span to confirm it's the assistant content
241
+ labeled_ids = [t for t, l in zip(input_ids, labels) if l != -100]
242
+ labeled_text = tokenizer.decode(labeled_ids, skip_special_tokens=False)
243
+ print(f"\n Labeled (assistant) text preview:")
244
+ print(f" {labeled_text[:300].replace(chr(10), '↵')!r}")
245
+
246
+ # Check that labeled text doesn't contain user/system markers
247
+ if "user\n" in labeled_text or "system\n" in labeled_text:
248
+ print(" ✗ WARNING: user/system content found in labeled tokens!")
249
+ else:
250
+ print(" ✓ Labeled tokens contain only assistant content")
251
+
252
+ # ---- 7. Token-by-token table ----------------------------------- #
253
+ print_section("7. Token-by-token table (first 80 tokens)")
254
+ print_token_table(input_ids, labels, tokenizer, max_rows=80)
255
+
256
+ # ---- 8. Decode round-trip ------------------------------------- #
257
+ print_section("8. Full decode round-trip (skip_special_tokens=False)")
258
+ decoded = tokenizer.decode(input_ids, skip_special_tokens=False)
259
+ print(decoded[:600])
260
+
261
+ print("\n" + "=" * 60)
262
+ print(" CHECK COMPLETE — pipeline looks aligned ✓")
263
+ print("=" * 60)
264
+ print(f"\nWhen ready, run the full data prep:")
265
+ print(f" python finetune/prepare_data.py")
266
+
267
+
268
+ if __name__ == "__main__":
269
+ main()
finetune/data/meta.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset": "teknium/OpenHermes-2.5",
3
+ "n_sampled": 80000,
4
+ "n_train": 76000,
5
+ "n_val": 4000,
6
+ "vocab_size": 32002,
7
+ "special_tokens": [
8
+ "<|im_start|>",
9
+ "<|im_end|>"
10
+ ],
11
+ "max_length": 1024,
12
+ "seed": 42
13
+ }
finetune/data/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
finetune/data/tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<|endoftext|>",
4
+ "eos_token": "<|endoftext|>",
5
+ "extra_special_tokens": [
6
+ "<|im_start|>",
7
+ "<|im_end|>"
8
+ ],
9
+ "is_local": true,
10
+ "local_files_only": false,
11
+ "model_max_length": 1024,
12
+ "pad_token": "<|endoftext|>",
13
+ "padding_side": "right",
14
+ "tokenizer_class": "TokenizersBackend",
15
+ "truncation_side": "right",
16
+ "unk_token": null
17
+ }
finetune/prepare_data.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ finetune/prepare_data.py
3
+
4
+ Downloads teknium/OpenHermes-2.5 from HuggingFace, formats conversations
5
+ as ChatML, tokenizes with our custom tokenizer + 2 new special tokens,
6
+ and saves train_sft.pt / val_sft.pt to finetune/data/.
7
+
8
+ Also saves the tokenizer (with special tokens baked in) to finetune/data/
9
+ so sft_train.py and chat.py can load it without re-adding tokens.
10
+
11
+ Usage:
12
+ python finetune/prepare_data.py
13
+ python finetune/prepare_data.py --n_samples 50000
14
+
15
+ Dataset structure (OpenHermes-2.5):
16
+ Each row has a "conversations" key:
17
+ [
18
+ {"from": "system", "value": "..."}, # optional
19
+ {"from": "human", "value": "..."},
20
+ {"from": "gpt", "value": "..."},
21
+ ... # may have more turns
22
+ ]
23
+ """
24
+
25
+ import os
26
+ import sys
27
+ import json
28
+ import random
29
+ import argparse
30
+ from pathlib import Path
31
+
32
+ import torch
33
+ from transformers import PreTrainedTokenizerFast
34
+ from datasets import load_dataset
35
+ from tqdm import tqdm
36
+
37
+ # ------------------------------------------------------------------ #
38
+ # Paths (relative to project root, not this script)
39
+ # ------------------------------------------------------------------ #
40
+
41
+ SCRIPT_DIR = Path(__file__).resolve().parent
42
+ PROJECT_ROOT = SCRIPT_DIR.parent
43
+
44
+ sys.path.insert(0, str(PROJECT_ROOT))
45
+
46
+ TOKENIZER_DIR = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"
47
+
48
+ # The two new tokens that define ChatML structure
49
+ SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]
50
+
51
+ MAX_LENGTH = 1024 # model context_length — truncate anything longer
52
+
53
+ # Map OpenHermes role names → ChatML role names
54
+ ROLE_MAP = {
55
+ "system": "system",
56
+ "human": "user",
57
+ "gpt": "assistant",
58
+ "user": "user",
59
+ "assistant": "assistant",
60
+ }
61
+
62
+
63
+ # ------------------------------------------------------------------ #
64
+ # TOKENIZER
65
+ # ------------------------------------------------------------------ #
66
+
67
+ def load_and_extend_tokenizer() -> PreTrainedTokenizerFast:
68
+ """
69
+ Loads our pretrained BPE tokenizer and adds the two ChatML tokens.
70
+ Returns the extended tokenizer (vocab 32,000 → 32,002).
71
+ """
72
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
73
+
74
+ new_tokens = [t for t in SPECIAL_TOKENS if t not in tokenizer.get_vocab()]
75
+ if new_tokens:
76
+ added = tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
77
+ print(f" Added {added} special token(s): {new_tokens}")
78
+ else:
79
+ print(" Special tokens already present — skipping add.")
80
+
81
+ print(f" Final vocab size: {len(tokenizer):,}")
82
+ return tokenizer
83
+
84
+
85
+ # ------------------------------------------------------------------ #
86
+ # FORMAT + TOKENIZE ONE CONVERSATION
87
+ # ------------------------------------------------------------------ #
88
+
89
+ def format_and_tokenize(
90
+ conversations: list[dict],
91
+ tokenizer: PreTrainedTokenizerFast,
92
+ ) -> tuple[list[int], list[int]] | None:
93
+ """
94
+ Converts a list of chat turns into (input_ids, labels).
95
+
96
+ ChatML format per turn:
97
+ <|im_start|>{role}\\n{content}<|im_end|>\\n
98
+
99
+ Labels:
100
+ - User / system turns → all -100 (not learned)
101
+ - Assistant turns → header (-100) + content (actual token ids)
102
+ i.e. we learn the response but not the "<|im_start|>assistant\\n" prefix
103
+
104
+ Returns None for:
105
+ - Conversations with no assistant turns (nothing to learn)
106
+ - Conversations that tokenize to fewer than 8 tokens
107
+ """
108
+ input_ids: list[int] = []
109
+ labels: list[int] = []
110
+
111
+ for turn in conversations:
112
+ role_raw = turn.get("from", turn.get("role", "")).strip().lower()
113
+ content = turn.get("value", turn.get("content", "")).strip()
114
+ role = ROLE_MAP.get(role_raw, role_raw)
115
+
116
+ if not content or not role:
117
+ continue
118
+
119
+ # ---- header: <|im_start|>role\n — never labeled ----------- #
120
+ header_text = f"<|im_start|>{role}\n"
121
+ header_ids = tokenizer.encode(header_text, add_special_tokens=False)
122
+
123
+ # ---- body: content<|im_end|>\n ------------------------------ #
124
+ body_text = f"{content}<|im_end|>\n"
125
+ body_ids = tokenizer.encode(body_text, add_special_tokens=False)
126
+
127
+ turn_input = header_ids + body_ids
128
+
129
+ if role == "assistant":
130
+ # Teach the model the body (response + im_end), not the header
131
+ turn_labels = [-100] * len(header_ids) + body_ids
132
+ else:
133
+ # User / system: no learning signal
134
+ turn_labels = [-100] * len(turn_input)
135
+
136
+ input_ids.extend(turn_input)
137
+ labels.extend(turn_labels)
138
+
139
+ # Must have at least one labeled token to be a valid training example
140
+ if not any(l != -100 for l in labels):
141
+ return None
142
+
143
+ # Truncate to context window
144
+ input_ids = input_ids[:MAX_LENGTH]
145
+ labels = labels[:MAX_LENGTH]
146
+
147
+ # Skip micro-sequences (likely malformed)
148
+ if len(input_ids) < 8:
149
+ return None
150
+
151
+ return input_ids, labels
152
+
153
+
154
+ # ------------------------------------------------------------------ #
155
+ # ARG PARSING
156
+ # ------------------------------------------------------------------ #
157
+
158
+ def parse_args():
159
+ p = argparse.ArgumentParser(description="Prepare SFT data from OpenHermes-2.5")
160
+ p.add_argument("--n_samples", type=int, default=80_000,
161
+ help="Number of conversations to sample (default: 80000)")
162
+ p.add_argument("--val_ratio", type=float, default=0.05,
163
+ help="Fraction held out for validation (default: 0.05)")
164
+ p.add_argument("--output_dir", type=str, default=str(SCRIPT_DIR / "data"),
165
+ help="Where to save train_sft.pt, val_sft.pt, and tokenizer")
166
+ p.add_argument("--seed", type=int, default=42)
167
+ return p.parse_args()
168
+
169
+
170
+ # ------------------------------------------------------------------ #
171
+ # MAIN
172
+ # ------------------------------------------------------------------ #
173
+
174
+ def main():
175
+ args = parse_args()
176
+ random.seed(args.seed)
177
+ os.makedirs(args.output_dir, exist_ok=True)
178
+
179
+ print("\n" + "=" * 60)
180
+ print(" SLLM-150M SFT — Data Preparation")
181
+ print("=" * 60)
182
+
183
+ # ---------------------------------------------------------------- #
184
+ # 1. Tokenizer
185
+ # ---------------------------------------------------------------- #
186
+ print("\n[1/4] Loading tokenizer + adding ChatML special tokens...")
187
+ tokenizer = load_and_extend_tokenizer()
188
+
189
+ # Save the extended tokenizer to data dir so training/chat can load it
190
+ tokenizer.save_pretrained(args.output_dir)
191
+ print(f" Extended tokenizer saved → {args.output_dir}/")
192
+
193
+ # ---------------------------------------------------------------- #
194
+ # 2. Dataset download
195
+ # ---------------------------------------------------------------- #
196
+ print(f"\n[2/4] Loading teknium/OpenHermes-2.5 from HuggingFace...")
197
+ ds = load_dataset("teknium/OpenHermes-2.5")
198
+ full = ds["train"] # only split in this dataset
199
+ print(f" Full dataset size: {len(full):,} examples")
200
+
201
+ # Sample a subset
202
+ n = min(args.n_samples, len(full))
203
+ indices = random.sample(range(len(full)), n)
204
+ subset = full.select(indices)
205
+ print(f" Sampled: {n:,} examples (seed={args.seed})")
206
+
207
+ # ---------------------------------------------------------------- #
208
+ # 3. Tokenize
209
+ # ---------------------------------------------------------------- #
210
+ print(f"\n[3/4] Formatting and tokenizing conversations...")
211
+
212
+ all_input_ids: list[torch.Tensor] = []
213
+ all_labels: list[torch.Tensor] = []
214
+ skipped = 0
215
+
216
+ for example in tqdm(subset, desc="Tokenizing", unit="conv"):
217
+ conversations = example.get("conversations", [])
218
+ result = format_and_tokenize(conversations, tokenizer)
219
+
220
+ if result is None:
221
+ skipped += 1
222
+ continue
223
+
224
+ ids, lbls = result
225
+ all_input_ids.append(torch.tensor(ids, dtype=torch.long))
226
+ all_labels.append( torch.tensor(lbls, dtype=torch.long))
227
+
228
+ total = len(all_input_ids)
229
+ print(f"\n Kept : {total:,}")
230
+ print(f" Skipped: {skipped:,} (no assistant turn or too short)")
231
+
232
+ if total == 0:
233
+ raise RuntimeError("No valid examples produced — check dataset structure.")
234
+
235
+ # Print a sample so we can visually verify
236
+ print("\n ── Sample (first conversation, first 400 chars) ──")
237
+ sample_decoded = tokenizer.decode(all_input_ids[0].tolist(), skip_special_tokens=False)
238
+ print(" " + sample_decoded[:400].replace("\n", "\n "))
239
+ print()
240
+
241
+ # ---------------------------------------------------------------- #
242
+ # 4. Split + save
243
+ # ---------------------------------------------------------------- #
244
+ print(f"[4/4] Splitting and saving...")
245
+
246
+ perm = list(range(total))
247
+ random.shuffle(perm)
248
+ val_n = max(1, int(total * args.val_ratio))
249
+ train_n = total - val_n
250
+
251
+ train_ids = [all_input_ids[i] for i in perm[:train_n]]
252
+ train_lbl = [all_labels[i] for i in perm[:train_n]]
253
+ val_ids = [all_input_ids[i] for i in perm[train_n:]]
254
+ val_lbl = [all_labels[i] for i in perm[train_n:]]
255
+
256
+ train_path = os.path.join(args.output_dir, "train_sft.pt")
257
+ val_path = os.path.join(args.output_dir, "val_sft.pt")
258
+
259
+ torch.save({"input_ids": train_ids, "labels": train_lbl}, train_path)
260
+ torch.save({"input_ids": val_ids, "labels": val_lbl}, val_path)
261
+
262
+ # Stats
263
+ lengths = [len(x) for x in all_input_ids]
264
+ label_ratios = [(t != -100).float().mean().item() for t in all_labels]
265
+ avg_len = sum(lengths) / len(lengths)
266
+ avg_lbl_ratio = sum(label_ratios) / len(label_ratios)
267
+
268
+ print(f"\n train_sft.pt : {train_n:,} examples")
269
+ print(f" val_sft.pt : {val_n:,} examples")
270
+ print(f"\n Avg seq length : {avg_len:.0f} tokens (max={max(lengths)})")
271
+ print(f" Avg assistant ratio : {avg_lbl_ratio:.1%} of tokens are labeled")
272
+
273
+ # Save metadata for reference
274
+ meta = {
275
+ "dataset": "teknium/OpenHermes-2.5",
276
+ "n_sampled": n,
277
+ "n_train": train_n,
278
+ "n_val": val_n,
279
+ "vocab_size": len(tokenizer),
280
+ "special_tokens": SPECIAL_TOKENS,
281
+ "max_length": MAX_LENGTH,
282
+ "seed": args.seed,
283
+ }
284
+ with open(os.path.join(args.output_dir, "meta.json"), "w") as f:
285
+ json.dump(meta, f, indent=2)
286
+ print(f"\n meta.json saved → {args.output_dir}/meta.json")
287
+
288
+ print("\n" + "=" * 60)
289
+ print(" Data preparation complete!")
290
+ print("=" * 60)
291
+ print(f"""
292
+ Next step:
293
+ python finetune/sft_train.py \\
294
+ --base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
295
+ --run_dir runs/sllm_150m_chat \\
296
+ --max_steps 2000 \\
297
+ --batch_size 4 --grad_accum 8 \\
298
+ --grad_checkpoint
299
+ """)
300
+
301
+
302
+ if __name__ == "__main__":
303
+ main()
finetune/sft_dataset.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ finetune/sft_dataset.py
3
+
4
+ SFT Dataset — loads pre-tokenized ChatML sequences from .pt shards
5
+ produced by prepare_data.py.
6
+
7
+ Each item returns (input_ids, labels) where labels has -100 for all
8
+ non-assistant tokens so CrossEntropy only trains on assistant responses.
9
+ """
10
+
11
+ from functools import partial
12
+
13
+ import torch
14
+ from torch.utils.data import Dataset, DataLoader
15
+
16
+
17
+ class SFTDataset(Dataset):
18
+ """
19
+ Dataset for Supervised Fine-Tuning.
20
+
21
+ Loads a .pt shard containing:
22
+ {
23
+ "input_ids": list of LongTensors (variable length),
24
+ "labels": list of LongTensors (same shapes, -100 for masked)
25
+ }
26
+
27
+ Each __getitem__ returns:
28
+ input_ids : (seq_len,) LongTensor
29
+ labels : (seq_len,) LongTensor — -100 for user/system tokens
30
+ """
31
+
32
+ def __init__(self, data_path: str, context_length: int = 1024):
33
+ data = torch.load(data_path, weights_only=False)
34
+ self.input_ids = data["input_ids"]
35
+ self.labels = data["labels"]
36
+ self.context_length = context_length
37
+
38
+ assert len(self.input_ids) == len(self.labels), "input_ids / labels length mismatch"
39
+ print(f"[SFTDataset] Loaded {len(self.input_ids):,} examples from {data_path}")
40
+
41
+ def __len__(self) -> int:
42
+ return len(self.input_ids)
43
+
44
+ def __getitem__(self, idx):
45
+ ids = self.input_ids[idx]
46
+ lbl = self.labels[idx]
47
+
48
+ # Hard-truncate to model context length
49
+ if len(ids) > self.context_length:
50
+ ids = ids[: self.context_length]
51
+ lbl = lbl[: self.context_length]
52
+
53
+ return ids, lbl
54
+
55
+
56
+ # ------------------------------------------------------------------ #
57
+ # COLLATE
58
+ # ------------------------------------------------------------------ #
59
+
60
+ def sft_collate_fn(batch, pad_token_id: int):
61
+ """
62
+ Pads a batch of variable-length sequences to the same length.
63
+ input_ids → padded with pad_token_id
64
+ labels → padded with -100 (ignored by CrossEntropy)
65
+ """
66
+ input_ids_list, labels_list = zip(*batch)
67
+
68
+ max_len = max(x.size(0) for x in input_ids_list)
69
+
70
+ input_ids_padded = torch.full((len(batch), max_len), pad_token_id, dtype=torch.long)
71
+ labels_padded = torch.full((len(batch), max_len), -100, dtype=torch.long)
72
+
73
+ for i, (ids, lbl) in enumerate(zip(input_ids_list, labels_list)):
74
+ n = ids.size(0)
75
+ input_ids_padded[i, :n] = ids
76
+ labels_padded[i, :n] = lbl
77
+
78
+ return input_ids_padded, labels_padded
79
+
80
+
81
+ # ------------------------------------------------------------------ #
82
+ # FACTORY
83
+ # ------------------------------------------------------------------ #
84
+
85
+ def build_sft_dataloader(
86
+ data_path: str,
87
+ batch_size: int,
88
+ pad_token_id: int,
89
+ context_length: int = 1024,
90
+ num_workers: int = 0,
91
+ shuffle: bool = True,
92
+ ) -> DataLoader:
93
+ dataset = SFTDataset(data_path, context_length=context_length)
94
+ collate_fn = partial(sft_collate_fn, pad_token_id=pad_token_id)
95
+
96
+ return DataLoader(
97
+ dataset,
98
+ batch_size = batch_size,
99
+ shuffle = shuffle,
100
+ num_workers = num_workers,
101
+ collate_fn = collate_fn,
102
+ pin_memory = True,
103
+ )
finetune/sft_train.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ finetune/sft_train.py
3
+
4
+ Full Supervised Fine-Tuning (SFT) of SLLM-150M → Chat Model.
5
+
6
+ Starts from the pretrained base checkpoint, resizes the token embedding
7
+ for 2 new ChatML special tokens, then trains with masked CrossEntropy
8
+ so only assistant response tokens contribute to the loss.
9
+
10
+ Usage (first run):
11
+ python finetune/sft_train.py \\
12
+ --base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
13
+ --run_dir runs/sllm_150m_chat \\
14
+ --max_steps 2000 \\
15
+ --batch_size 4 --grad_accum 8 \\
16
+ --grad_checkpoint
17
+
18
+ Resume:
19
+ python finetune/sft_train.py \\
20
+ --resume --run_dir runs/sllm_150m_chat \\
21
+ --extra_steps 1000
22
+ """
23
+
24
+ import os
25
+ import sys
26
+ import json
27
+ import math
28
+ import time
29
+ import signal
30
+ import argparse
31
+ from pathlib import Path
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ from torch.amp import autocast, GradScaler
37
+ from transformers import PreTrainedTokenizerFast
38
+ from tqdm import tqdm
39
+
40
+ # ------------------------------------------------------------------ #
41
+ # Resolve project root so model/ is importable
42
+ # ------------------------------------------------------------------ #
43
+
44
+ SCRIPT_DIR = Path(__file__).resolve().parent
45
+ PROJECT_ROOT = SCRIPT_DIR.parent
46
+ DATA_DIR = SCRIPT_DIR / "data"
47
+
48
+ sys.path.insert(0, str(PROJECT_ROOT))
49
+ sys.path.insert(0, str(SCRIPT_DIR)) # so we can import sft_dataset
50
+
51
+ from model.config import SLLM_150M
52
+ from model.model import SLLM
53
+ from sft_dataset import build_sft_dataloader
54
+
55
+
56
+ # ------------------------------------------------------------------ #
57
+ # ARG PARSING
58
+ # ------------------------------------------------------------------ #
59
+
60
+ def parse_args():
61
+ p = argparse.ArgumentParser(description="SLLM-150M SFT Training")
62
+
63
+ # Checkpoints
64
+ p.add_argument("--base_ckpt", type=str,
65
+ default=str(PROJECT_ROOT / "runs" / "sllm_150m" / "ckpt_0011500.pt"),
66
+ help="Path to pretrained base checkpoint (.pt)")
67
+ p.add_argument("--run_dir", type=str, default="runs/sllm_150m_chat",
68
+ help="Output directory for SFT checkpoints and logs")
69
+ p.add_argument("--resume", action="store_true",
70
+ help="Resume from latest SFT checkpoint in --run_dir")
71
+ p.add_argument("--max_steps", type=int, default=2000,
72
+ help="Absolute step target for this run")
73
+ p.add_argument("--extra_steps", type=int, default=None,
74
+ help="Run N more steps from current checkpoint (relative)")
75
+
76
+ # Data
77
+ p.add_argument("--data_dir", type=str, default=str(DATA_DIR),
78
+ help="Directory with train_sft.pt, val_sft.pt, and tokenizer files")
79
+ p.add_argument("--num_workers", type=int, default=0)
80
+
81
+ # Optimisation — note: much lower LR than pretraining
82
+ p.add_argument("--batch_size", type=int, default=4)
83
+ p.add_argument("--grad_accum", type=int, default=8)
84
+ p.add_argument("--max_lr", type=float, default=1e-5,
85
+ help="Peak LR (10x lower than pretraining)")
86
+ p.add_argument("--min_lr", type=float, default=1e-6)
87
+ p.add_argument("--warmup_steps", type=int, default=30)
88
+ p.add_argument("--weight_decay", type=float, default=0.1)
89
+ p.add_argument("--grad_clip", type=float, default=1.0)
90
+ p.add_argument("--dropout", type=float, default=0.1,
91
+ help="Dropout rate during SFT (0.0 in pretraining)")
92
+
93
+ # Memory
94
+ p.add_argument("--grad_checkpoint", action="store_true",
95
+ help="Enable gradient checkpointing (saves VRAM)")
96
+ p.add_argument("--dtype", type=str, default="bf16",
97
+ choices=["fp32", "fp16", "bf16"])
98
+
99
+ # Logging
100
+ p.add_argument("--log_every", type=int, default=10)
101
+ p.add_argument("--save_every", type=int, default=500)
102
+ p.add_argument("--val_every", type=int, default=250)
103
+ p.add_argument("--val_steps", type=int, default=20)
104
+
105
+ return p.parse_args()
106
+
107
+
108
+ # ------------------------------------------------------------------ #
109
+ # VOCAB RESIZE
110
+ # ------------------------------------------------------------------ #
111
+
112
+ def resize_token_embeddings(model: SLLM, new_vocab_size: int):
113
+ """
114
+ Grows model.token_emb from old_vocab_size → new_vocab_size.
115
+
116
+ New rows are initialised to the mean of existing embeddings so
117
+ training starts from a stable point rather than random noise.
118
+ lm_head weight-tying is re-applied automatically.
119
+ """
120
+ old_size = model.config.vocab_size
121
+ if new_vocab_size == old_size:
122
+ return
123
+ if new_vocab_size < old_size:
124
+ raise ValueError(f"Cannot shrink vocab ({old_size} → {new_vocab_size})")
125
+
126
+ d_model = model.config.d_model
127
+ device = model.token_emb.weight.device
128
+ dtype = model.token_emb.weight.dtype
129
+ old_weight = model.token_emb.weight.data.clone() # (old_size, d)
130
+ mean_vec = old_weight.mean(dim=0) # (d,)
131
+
132
+ new_weight = torch.zeros(new_vocab_size, d_model, dtype=dtype, device=device)
133
+ new_weight[:old_size] = old_weight
134
+ # Broadcast mean_vec into new rows
135
+ new_weight[old_size:] = mean_vec.unsqueeze(0).expand(new_vocab_size - old_size, -1)
136
+
137
+ # Replace the embedding module in-place
138
+ new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype)
139
+ new_emb.weight.data = new_weight
140
+ model.token_emb = new_emb
141
+
142
+ # Re-tie the LM head to the (now larger) embedding
143
+ model.lm_head.weight = model.token_emb.weight
144
+
145
+ # Keep config consistent
146
+ model.config.vocab_size = new_vocab_size
147
+
148
+ n_new = new_vocab_size - old_size
149
+ print(f" Vocab resized: {old_size:,} → {new_vocab_size:,} (+{n_new} tokens, init=mean)")
150
+
151
+
152
+ # ------------------------------------------------------------------ #
153
+ # DROPOUT
154
+ # ------------------------------------------------------------------ #
155
+
156
+ def set_dropout(model: SLLM, rate: float):
157
+ """Applies dropout rate to every nn.Dropout in the model."""
158
+ count = 0
159
+ for m in model.modules():
160
+ if isinstance(m, nn.Dropout):
161
+ m.p = rate
162
+ count += 1
163
+ if count:
164
+ print(f" Dropout set to {rate} on {count} layer(s)")
165
+
166
+
167
+ # ------------------------------------------------------------------ #
168
+ # LR SCHEDULE (cosine with linear warmup, same shape as train.py)
169
+ # ------------------------------------------------------------------ #
170
+
171
+ def get_lr(step: int, warmup_steps: int, total_steps: int,
172
+ max_lr: float, min_lr: float) -> float:
173
+ if step < warmup_steps:
174
+ return max_lr * (step + 1) / warmup_steps
175
+ decay_steps = total_steps if total_steps else 5_000
176
+ if step >= decay_steps:
177
+ return min_lr
178
+ progress = (step - warmup_steps) / max(1, decay_steps - warmup_steps)
179
+ coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
180
+ return min_lr + coeff * (max_lr - min_lr)
181
+
182
+
183
+ # ------------------------------------------------------------------ #
184
+ # OPTIMIZER (mirrors train.py — AdamW selective decay)
185
+ # ------------------------------------------------------------------ #
186
+
187
+ def build_optimizer(model: SLLM, lr: float, weight_decay: float):
188
+ decay, no_decay = [], []
189
+ for name, param in model.named_parameters():
190
+ if not param.requires_grad:
191
+ continue
192
+ if param.dim() >= 2:
193
+ decay.append(param)
194
+ else:
195
+ no_decay.append(param)
196
+
197
+ groups = [
198
+ {"params": decay, "weight_decay": weight_decay},
199
+ {"params": no_decay, "weight_decay": 0.0},
200
+ ]
201
+ n_d = sum(p.numel() for p in decay)
202
+ n_nd = sum(p.numel() for p in no_decay)
203
+ print(f" Optimizer: {n_d/1e6:.1f}M decay | {n_nd/1e6:.1f}M no-decay | lr={lr:.2e}")
204
+
205
+ # Note: no fused=True here — new embedding rows need correct grad flow
206
+ return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)
207
+
208
+
209
+ # ------------------------------------------------------------------ #
210
+ # CHECKPOINT SAVE / LOAD
211
+ # ------------------------------------------------------------------ #
212
+
213
+ def save_checkpoint(path: str, model: SLLM, optimizer, step: int,
214
+ loss: float, vocab_size: int):
215
+ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
216
+ torch.save({
217
+ "step": step,
218
+ "model_state_dict": model.state_dict(),
219
+ "optimizer_state_dict": optimizer.state_dict(),
220
+ "loss": loss,
221
+ "vocab_size": vocab_size,
222
+ }, path)
223
+ print(f"\n [CKPT] Saved: {path} (step={step}, loss={loss:.4f})")
224
+
225
+
226
+ def load_sft_checkpoint(run_dir: str, model: SLLM, optimizer, device):
227
+ """Loads the latest ckpt_sft_*.pt from run_dir. Returns (step, vocab_size)."""
228
+ ckpts = sorted([
229
+ f for f in os.listdir(run_dir)
230
+ if f.startswith("ckpt_sft_") and f.endswith(".pt")
231
+ ])
232
+ if not ckpts:
233
+ raise FileNotFoundError(f"No SFT checkpoints found in {run_dir}")
234
+
235
+ path = os.path.join(run_dir, ckpts[-1])
236
+ ckpt = torch.load(path, map_location=device, weights_only=False)
237
+ model.load_state_dict(ckpt["model_state_dict"])
238
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
239
+ step = ckpt["step"]
240
+ vocab_size = ckpt.get("vocab_size", model.config.vocab_size)
241
+ loss = ckpt.get("loss", float("nan"))
242
+ print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
243
+ return step, vocab_size
244
+
245
+
246
+ # ------------------------------------------------------------------ #
247
+ # VALIDATION (uses ignore_index=-100 like training)
248
+ # ------------------------------------------------------------------ #
249
+
250
+ @torch.no_grad()
251
+ def estimate_val_loss(model: SLLM, val_loader, val_steps: int,
252
+ device, dtype_ctx) -> float:
253
+ model.eval()
254
+ losses = []
255
+ for i, (x, y) in enumerate(val_loader):
256
+ if i >= val_steps:
257
+ break
258
+ x, y = x.to(device), y.to(device)
259
+ with dtype_ctx:
260
+ logits, _ = model(x)
261
+ # Shift logits and labels by 1 to predict the next token
262
+ shift_logits = logits[..., :-1, :].contiguous()
263
+ shift_labels = y[..., 1:].contiguous()
264
+ loss = F.cross_entropy(
265
+ shift_logits.view(-1, shift_logits.size(-1)),
266
+ shift_labels.view(-1),
267
+ ignore_index=-100,
268
+ )
269
+ losses.append(loss.item())
270
+ model.train()
271
+ return sum(losses) / len(losses) if losses else float("nan")
272
+
273
+
274
+ # ------------------------------------------------------------------ #
275
+ # METRIC LOGGER
276
+ # ------------------------------------------------------------------ #
277
+
278
+ class MetricLogger:
279
+ def __init__(self, log_path: str):
280
+ self.log_path = log_path
281
+ os.makedirs(os.path.dirname(os.path.abspath(log_path)), exist_ok=True)
282
+ print(f" [LOG] Logging to: {log_path}")
283
+
284
+ def log(self, **kwargs):
285
+ with open(self.log_path, "a") as f:
286
+ f.write(json.dumps(kwargs) + "\n")
287
+
288
+
289
+ # ------------------------------------------------------------------ #
290
+ # MAIN TRAINING LOOP
291
+ # ------------------------------------------------------------------ #
292
+
293
+ def train():
294
+ args = parse_args()
295
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
296
+
297
+ print(f"\n{'='*60}")
298
+ print(f" SLLM-150M → Chat Model (SFT)")
299
+ print(f"{'='*60}")
300
+ print(f"\nDevice : {device}")
301
+ if device.type == "cuda":
302
+ print(f"GPU : {torch.cuda.get_device_name(0)}")
303
+ print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
304
+
305
+ # ---- dtype ----------------------------------------------------- #
306
+ if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
307
+ dtype_torch, dtype_name = torch.bfloat16, "bf16"
308
+ elif args.dtype == "fp16" and device.type == "cuda":
309
+ dtype_torch, dtype_name = torch.float16, "fp16"
310
+ else:
311
+ dtype_torch, dtype_name = torch.float32, "fp32"
312
+
313
+ print(f"dtype : {dtype_name}")
314
+ use_amp = dtype_torch in (torch.float16, torch.bfloat16)
315
+ dtype_ctx = (autocast(device_type=device.type, dtype=dtype_torch)
316
+ if use_amp else torch.no_grad().__class__())
317
+ scaler = GradScaler(enabled=(dtype_torch == torch.float16))
318
+
319
+ # ---- Tokenizer ------------------------------------------------- #
320
+ print("\n[1/5] Loading tokenizer...")
321
+ tok_path = args.data_dir
322
+ if os.path.exists(os.path.join(tok_path, "tokenizer.json")):
323
+ # Prefer the saved tokenizer from prepare_data.py (has special tokens)
324
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
325
+ print(f" Loaded from data dir: {tok_path}")
326
+ else:
327
+ # Fallback: load base tokenizer and add special tokens manually
328
+ base_tok_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
329
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(base_tok_dir)
330
+ tokenizer.add_special_tokens({"additional_special_tokens":
331
+ ["<|im_start|>", "<|im_end|>"]})
332
+ print(f" Loaded base tokenizer + added special tokens")
333
+
334
+ new_vocab_size = len(tokenizer)
335
+ pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None \
336
+ else tokenizer.eos_token_id
337
+ print(f" Vocab size : {new_vocab_size:,}")
338
+ print(f" Pad token : {pad_id}")
339
+
340
+ # ---- Model ----------------------------------------------------- #
341
+ print("\n[2/5] Loading model...")
342
+ cfg = SLLM_150M
343
+ model = SLLM(cfg).to(device)
344
+
345
+ if not args.resume:
346
+ # Load pretrained base weights (step 11,500)
347
+ print(f" Loading base checkpoint: {args.base_ckpt}")
348
+ base_ckpt = torch.load(args.base_ckpt, map_location=device, weights_only=False)
349
+ model.load_state_dict(base_ckpt["model_state_dict"])
350
+ base_step = base_ckpt.get("step", "?")
351
+ base_loss = base_ckpt.get("loss", float("nan"))
352
+ print(f" Base model step={base_step} loss={base_loss:.4f}")
353
+ del base_ckpt
354
+
355
+ # Grow embedding for the 2 new special tokens
356
+ resize_token_embeddings(model, new_vocab_size)
357
+
358
+ # Apply SFT dropout (was 0.0 in pretraining)
359
+ set_dropout(model, args.dropout)
360
+
361
+ if args.grad_checkpoint:
362
+ model.enable_gradient_checkpointing()
363
+ print(" Gradient checkpointing: ON")
364
+
365
+ print(f" Model params: {model.count_params()/1e6:.1f}M")
366
+
367
+ # ---- Optimizer ------------------------------------------------- #
368
+ print("\n[3/5] Building optimizer...")
369
+ optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)
370
+
371
+ # ---- Resume from SFT checkpoint -------------------------------- #
372
+ start_step = 0
373
+ if args.resume:
374
+ try:
375
+ start_step, _ = load_sft_checkpoint(args.run_dir, model, optimizer, device)
376
+ except FileNotFoundError as e:
377
+ print(f" [WARN] {e} — starting SFT from base checkpoint.")
378
+
379
+ # Resolve --extra_steps → --max_steps
380
+ if args.extra_steps is not None:
381
+ args.max_steps = start_step + args.extra_steps
382
+ print(f" --extra_steps {args.extra_steps} → max_steps={args.max_steps}")
383
+
384
+ if args.max_steps is not None and start_step >= args.max_steps:
385
+ print(f"\n [WARN] Already at step {start_step} >= max_steps {args.max_steps}.")
386
+ print(f" Use --extra_steps N to run N more steps.")
387
+ return
388
+
389
+ # ---- Data ------------------------------------------------------ #
390
+ print("\n[4/5] Loading SFT dataset...")
391
+ train_path = os.path.join(args.data_dir, "train_sft.pt")
392
+ val_path = os.path.join(args.data_dir, "val_sft.pt")
393
+
394
+ train_loader = build_sft_dataloader(
395
+ data_path=train_path, batch_size=args.batch_size,
396
+ pad_token_id=pad_id, context_length=cfg.context_length,
397
+ num_workers=args.num_workers, shuffle=True,
398
+ )
399
+ val_loader = build_sft_dataloader(
400
+ data_path=val_path, batch_size=args.batch_size,
401
+ pad_token_id=pad_id, context_length=cfg.context_length,
402
+ num_workers=0, shuffle=False,
403
+ )
404
+
405
+ # ---- Run dir + logger ------------------------------------------ #
406
+ os.makedirs(args.run_dir, exist_ok=True)
407
+ log_path = os.path.join(args.run_dir, "sft_log.jsonl")
408
+ logger = MetricLogger(log_path)
409
+
410
+ # ---- Training info --------------------------------------------- #
411
+ eff_batch = args.batch_size * args.grad_accum
412
+ print(f"\n[5/5] Training config:")
413
+ print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} → eff={eff_batch})")
414
+ print(f" max_steps : {args.max_steps}")
415
+ print(f" start_step : {start_step}")
416
+ print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else '∞'}")
417
+ print(f" max_lr / min_lr: {args.max_lr:.2e} / {args.min_lr:.2e}")
418
+ print(f" warmup_steps : {args.warmup_steps}")
419
+ print(f" save_every : {args.save_every}")
420
+ print(f" val_every : {args.val_every}")
421
+
422
+ # ---- Ctrl+C handler -------------------------------------------- #
423
+ stop_flag = {"stop": False}
424
+ def _signal_handler(sig, frame):
425
+ print("\n [SIGNAL] Ctrl+C — will save and exit after this step.")
426
+ stop_flag["stop"] = True
427
+ signal.signal(signal.SIGINT, _signal_handler)
428
+
429
+ # ================================================================ #
430
+ # TRAINING LOOP
431
+ # ================================================================ #
432
+ model.train()
433
+ step = start_step
434
+ running_loss = 0.0
435
+ t_start = time.time()
436
+ t_step_start = time.time()
437
+ data_iter = iter(train_loader)
438
+
439
+ print(f"\n{'='*60}")
440
+ print(f" SFT STARTED (step {step} → {args.max_steps})")
441
+ print(f"{'='*60}\n")
442
+
443
+ pbar = tqdm(
444
+ initial=step, total=args.max_steps,
445
+ desc="SFT", unit="step", dynamic_ncols=True,
446
+ )
447
+
448
+ while True:
449
+ # ---- Stop conditions --------------------------------------- #
450
+ if stop_flag["stop"]:
451
+ break
452
+ if args.max_steps is not None and step >= args.max_steps:
453
+ print(f"\n [DONE] Reached max_steps={args.max_steps}")
454
+ break
455
+
456
+ optimizer.zero_grad(set_to_none=True)
457
+ accum_loss = 0.0
458
+
459
+ # ---- Gradient accumulation micro-steps --------------------- #
460
+ for _ in range(args.grad_accum):
461
+ try:
462
+ x, y = next(data_iter)
463
+ except StopIteration:
464
+ data_iter = iter(train_loader)
465
+ x, y = next(data_iter)
466
+
467
+ x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
468
+
469
+ with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
470
+ logits, _ = model(x) # (B, T, V) — don't use built-in loss
471
+ # Shift logits and labels by 1 to predict the next token
472
+ shift_logits = logits[..., :-1, :].contiguous()
473
+ shift_labels = y[..., 1:].contiguous()
474
+ # Use ignore_index=-100 so only assistant tokens drive the loss
475
+ loss = F.cross_entropy(
476
+ shift_logits.view(-1, shift_logits.size(-1)),
477
+ shift_labels.view(-1),
478
+ ignore_index=-100,
479
+ ) / args.grad_accum # scale for accumulation
480
+
481
+ scaler.scale(loss).backward()
482
+ accum_loss += loss.item()
483
+
484
+ # ---- Grad clip --------------------------------------------- #
485
+ if args.grad_clip > 0:
486
+ scaler.unscale_(optimizer)
487
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
488
+ else:
489
+ grad_norm = float("nan")
490
+
491
+ # ---- LR ---------------------------------------------------- #
492
+ lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
493
+ for pg in optimizer.param_groups:
494
+ pg["lr"] = lr
495
+
496
+ # ---- Optimizer step ---------------------------------------- #
497
+ scaler.step(optimizer)
498
+ scaler.update()
499
+
500
+ step += 1
501
+ running_loss = accum_loss
502
+
503
+ t_now = time.time()
504
+ elapsed_step = t_now - t_step_start
505
+ t_step_start = t_now
506
+
507
+ pbar.update(1)
508
+ pbar.set_postfix({"loss": f"{running_loss:.4f}", "lr": f"{lr:.1e}"})
509
+
510
+ # ---- Logging ----------------------------------------------- #
511
+ if step % args.log_every == 0:
512
+ entry = {
513
+ "step": step,
514
+ "loss": round(running_loss, 6),
515
+ "lr": lr,
516
+ "grad_norm": round(float(grad_norm), 4)
517
+ if not math.isnan(float(grad_norm)) else None,
518
+ "elapsed_s": round(t_now - t_start, 1),
519
+ }
520
+ if device.type == "cuda":
521
+ entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
522
+ logger.log(**entry)
523
+
524
+ # ---- Validation -------------------------------------------- #
525
+ if step % args.val_every == 0:
526
+ v_ctx = autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp)
527
+ val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, v_ctx)
528
+ tqdm.write(
529
+ f" [STEP {step:5d}] train={running_loss:.4f} "
530
+ f"val={val_loss:.4f} lr={lr:.1e}"
531
+ )
532
+ logger.log(step=step, val_loss=round(val_loss, 6))
533
+
534
+ # ---- Checkpoint -------------------------------------------- #
535
+ if step % args.save_every == 0:
536
+ ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
537
+ save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)
538
+
539
+ # ================================================================ #
540
+ # FINAL SAVE
541
+ # ================================================================ #
542
+ pbar.close()
543
+ steps_done = step - start_step
544
+ if steps_done > 0:
545
+ ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
546
+ save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)
547
+ else:
548
+ print("\n [SKIP] No steps taken — skipping checkpoint save.")
549
+
550
+ total_time = time.time() - t_start
551
+ print(f"\n{'='*60}")
552
+ print(f" SFT COMPLETE")
553
+ print(f"{'='*60}")
554
+ print(f" Steps done : {steps_done}")
555
+ print(f" Final loss : {running_loss:.4f}")
556
+ print(f" Total time : {total_time/60:.1f} min")
557
+ print(f" Run dir : {args.run_dir}")
558
+ print(f"\nStart chatting:")
559
+ print(f" python finetune/chat.py --run_dir {args.run_dir}")
560
+
561
+
562
+ if __name__ == "__main__":
563
+ train()
model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # model/__init__.py
2
+ from model.config import ModelConfig, SLLM_100M, SLLM_150M
3
+ from model.model import SLLM
4
+
5
+ __all__ = ["ModelConfig", "SLLM_100M", "SLLM_150M", "SLLM"]
model/attention.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model/attention.py
3
+
4
+ Causal Multi-Head Self-Attention with RoPE.
5
+
6
+ Architecture:
7
+ Input x (B, T, d_model)
8
+ -> Linear projections Q, K, V (no bias)
9
+ -> Reshape to (B, n_heads, T, head_dim)
10
+ -> Apply RoPE to Q and K
11
+ -> Scaled dot-product attention with causal mask
12
+ -> Reshape back to (B, T, d_model)
13
+ -> Output projection O (no bias)
14
+
15
+ Uses torch.nn.functional.scaled_dot_product_attention (Flash Attention
16
+ when available via PyTorch 2.0+) for memory-efficient attention.
17
+ The causal mask is handled by is_causal=True — no need to materialize
18
+ an explicit O(T^2) mask tensor.
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ from model.config import ModelConfig
26
+ from model.rope import RoPECache, apply_rope
27
+
28
+
29
+ class CausalSelfAttention(nn.Module):
30
+
31
+ def __init__(self, config: ModelConfig):
32
+ super().__init__()
33
+ self.n_heads = config.n_heads
34
+ self.head_dim = config.head_dim
35
+ self.d_model = config.d_model
36
+ self.dropout = config.dropout
37
+
38
+ # Q, K, V projections fused into one matrix for efficiency
39
+ # Output: (B, T, 3 * d_model), then split
40
+ self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
41
+
42
+ # Output projection
43
+ self.o_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
44
+
45
+ # Attention dropout (applied inside sdpa)
46
+ self.attn_dropout = config.dropout
47
+
48
+ # RoPE cache — lives as a buffer (moves to GPU automatically)
49
+ self.rope = RoPECache(
50
+ head_dim = config.head_dim,
51
+ max_seq_len = config.context_length,
52
+ theta = config.rope_theta,
53
+ )
54
+
55
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ Args:
58
+ x : (B, T, d_model)
59
+
60
+ Returns:
61
+ out : (B, T, d_model)
62
+ """
63
+ B, T, C = x.shape # C = d_model
64
+
65
+ # ---- QKV projection ---------------------------------------- #
66
+ qkv = self.qkv_proj(x) # (B, T, 3*C)
67
+ q, k, v = qkv.split(self.d_model, dim=-1) # each: (B, T, C)
68
+
69
+ # ---- Reshape to (B, n_heads, T, head_dim) ------------------ #
70
+ q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
71
+ k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
72
+ v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
73
+
74
+ # ---- Apply RoPE to Q and K --------------------------------- #
75
+ cos, sin = self.rope.get(T) # (T, head_dim)
76
+ q, k = apply_rope(q, k, cos, sin)
77
+
78
+ # ---- Scaled dot-product attention (Flash Attention) -------- #
79
+ # is_causal=True handles the causal mask internally — no mask alloc.
80
+ # dropout_p only applies during training.
81
+ attn_out = F.scaled_dot_product_attention(
82
+ q, k, v,
83
+ attn_mask = None,
84
+ dropout_p = self.attn_dropout if self.training else 0.0,
85
+ is_causal = True,
86
+ ) # (B, n_heads, T, head_dim)
87
+
88
+ # ---- Merge heads ------------------------------------------- #
89
+ # contiguous() needed before view after transpose
90
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)
91
+
92
+ # ---- Output projection ------------------------------------- #
93
+ return self.o_proj(attn_out) # (B, T, d_model)
94
+
95
+
96
+ # ------------------------------------------------------------------ #
97
+ # QUICK CHECK
98
+ # ------------------------------------------------------------------ #
99
+
100
+ if __name__ == "__main__":
101
+ from model.config import SLLM_100M
102
+
103
+ cfg = SLLM_100M
104
+ attn = CausalSelfAttention(cfg)
105
+ print(f"Attention params : {sum(p.numel() for p in attn.parameters())/1e6:.2f}M")
106
+
107
+ B, T = 2, 64
108
+ x = torch.randn(B, T, cfg.d_model)
109
+ out = attn(x)
110
+
111
+ print(f"Input shape : {x.shape}")
112
+ print(f"Output shape : {out.shape}")
113
+ assert out.shape == (B, T, cfg.d_model), "Shape mismatch!"
114
+ print("PASS")
model/block.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model/block.py
3
+
4
+ Single Transformer Block (pre-norm LLaMA-style).
5
+
6
+ Pre-Norm vs Post-Norm:
7
+ GPT-2 (post-norm): x = x + Attention(LayerNorm(x)) <- less stable
8
+ LLaMA (pre-norm): x = LayerNorm(x); x = x + Attention(x) <- more stable
9
+
10
+ We use PRE-NORM with RMSNorm for training stability at scale.
11
+
12
+ Block structure:
13
+ x -> RMSNorm -> CausalSelfAttention -> (+residual)
14
+ -> RMSNorm -> SwiGLU MLP -> (+residual)
15
+ -> output
16
+
17
+ Note: Residual connections bypass both norm and sublayer, which allows
18
+ gradients to flow directly to earlier layers during backprop.
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from model.config import ModelConfig
25
+ from model.norm import RMSNorm
26
+ from model.attention import CausalSelfAttention
27
+ from model.mlp import SwiGLU
28
+
29
+
30
+ class TransformerBlock(nn.Module):
31
+
32
+ def __init__(self, config: ModelConfig):
33
+ super().__init__()
34
+
35
+ # Pre-attention norm
36
+ self.norm_attn = RMSNorm(config.d_model)
37
+
38
+ # Causal self-attention with RoPE
39
+ self.attn = CausalSelfAttention(config)
40
+
41
+ # Pre-FFN norm
42
+ self.norm_mlp = RMSNorm(config.d_model)
43
+
44
+ # SwiGLU feed-forward
45
+ self.mlp = SwiGLU(config)
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Args:
50
+ x : (B, T, d_model)
51
+
52
+ Returns:
53
+ x : (B, T, d_model)
54
+ """
55
+ # Attention sub-layer with residual
56
+ x = x + self.attn(self.norm_attn(x))
57
+
58
+ # FFN sub-layer with residual
59
+ x = x + self.mlp(self.norm_mlp(x))
60
+
61
+ return x
62
+
63
+
64
+ # ------------------------------------------------------------------ #
65
+ # QUICK CHECK
66
+ # ------------------------------------------------------------------ #
67
+
68
+ if __name__ == "__main__":
69
+ from model.config import SLLM_100M
70
+
71
+ cfg = SLLM_100M
72
+ block = TransformerBlock(cfg)
73
+
74
+ n = sum(p.numel() for p in block.parameters())
75
+ print(f"Block params : {n/1e6:.3f}M")
76
+
77
+ B, T = 2, 64
78
+ x = torch.randn(B, T, cfg.d_model)
79
+ out = block(x)
80
+
81
+ print(f"Input shape : {x.shape}")
82
+ print(f"Output shape : {out.shape}")
83
+ assert out.shape == x.shape, "Shape mismatch!"
84
+ print("PASS")
model/config.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model/config.py
3
+
4
+ ModelConfig dataclass + preset configs for SLLM-100M and SLLM-150M.
5
+ All hyperparameters live here so every other module imports from one place.
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+
10
+
11
+ def _swiglu_d_ff(d_model: int) -> int:
12
+ """
13
+ SwiGLU hidden dimension.
14
+ LLaMA formula: round_up_256( int(2/3 * 4 * d_model) )
15
+ """
16
+ raw = int(2 / 3 * 4 * d_model)
17
+ return ((raw + 255) // 256) * 256 # round up to nearest 256
18
+
19
+
20
+ @dataclass
21
+ class ModelConfig:
22
+ # ---- Vocabulary ------------------------------------------------- #
23
+ vocab_size: int = 32_000 # must match trained tokenizer
24
+
25
+ # ---- Sequence --------------------------------------------------- #
26
+ context_length: int = 1024 # max tokens per sequence
27
+
28
+ # ---- Transformer dimensions ------------------------------------- #
29
+ d_model: int = 768 # embedding / hidden dim
30
+ n_heads: int = 12 # number of attention heads
31
+ n_layers: int = 12 # number of transformer blocks
32
+
33
+ # ---- FFN -------------------------------------------------------- #
34
+ # SwiGLU d_ff is auto-computed from d_model if not set explicitly
35
+ d_ff: int = 0 # 0 = auto
36
+
37
+ # ---- Regularization --------------------------------------------- #
38
+ dropout: float = 0.0 # 0.0 for pre-training
39
+
40
+ # ---- Misc ------------------------------------------------------- #
41
+ bias: bool = False # no bias (cleaner, matches LLaMA)
42
+ rope_theta: float = 10_000.0 # RoPE base frequency
43
+
44
+ def __post_init__(self):
45
+ # Auto-compute d_ff if not set
46
+ if self.d_ff == 0:
47
+ self.d_ff = _swiglu_d_ff(self.d_model)
48
+
49
+ # Sanity checks
50
+ assert self.d_model % self.n_heads == 0, (
51
+ f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
52
+ )
53
+
54
+ @property
55
+ def head_dim(self) -> int:
56
+ return self.d_model // self.n_heads
57
+
58
+ def count_params(self) -> int:
59
+ """Returns total trainable parameter count (with tied embeddings)."""
60
+ embed = self.vocab_size * self.d_model
61
+ attn = 4 * self.d_model * self.d_model # Q, K, V, O
62
+ mlp = 3 * self.d_model * self.d_ff # gate, up, down
63
+ norms = 2 * self.d_model # pre-attn + pre-mlp
64
+ per_block = attn + mlp + norms
65
+ final_norm = self.d_model
66
+ return embed + self.n_layers * per_block + final_norm
67
+
68
+ def __repr__(self) -> str:
69
+ n = self.count_params()
70
+ return (
71
+ f"ModelConfig("
72
+ f"d={self.d_model}, h={self.n_heads}, l={self.n_layers}, "
73
+ f"ff={self.d_ff}, ctx={self.context_length}, "
74
+ f"params={n/1e6:.1f}M)"
75
+ )
76
+
77
+
78
+ # ------------------------------------------------------------------ #
79
+ # PRESET CONFIGS
80
+ # ------------------------------------------------------------------ #
81
+
82
+ SLLM_100M = ModelConfig(
83
+ vocab_size = 32_000,
84
+ context_length = 1024,
85
+ d_model = 768,
86
+ n_heads = 12,
87
+ n_layers = 12,
88
+ # d_ff auto = 2048
89
+ )
90
+
91
+ SLLM_150M = ModelConfig(
92
+ vocab_size = 32_000,
93
+ context_length = 1024,
94
+ d_model = 1024,
95
+ n_heads = 16,
96
+ n_layers = 9,
97
+ # d_ff auto = 2816
98
+ )
99
+
100
+
101
+ # ------------------------------------------------------------------ #
102
+ # QUICK CHECK
103
+ # ------------------------------------------------------------------ #
104
+
105
+ if __name__ == "__main__":
106
+ for cfg in [SLLM_100M, SLLM_150M]:
107
+ print(cfg)
108
+ print(f" head_dim : {cfg.head_dim}")
109
+ print(f" d_ff : {cfg.d_ff}")
110
+ print()
model/mlp.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model/mlp.py
3
+
4
+ SwiGLU Feed-Forward Network — used in LLaMA, PaLM, Mistral, etc.
5
+
6
+ Standard FFN (GPT-2):
7
+ out = dropout(W2 * GELU(W1 * x))
8
+
9
+ SwiGLU FFN (LLaMA):
10
+ gate = W_gate * x # linear gate
11
+ up = W_up * x # linear up-proj
12
+ hidden = SiLU(gate) * up # element-wise gating (learned)
13
+ out = W_down * hidden # down-proj back to d_model
14
+
15
+ SiLU (Sigmoid Linear Unit):
16
+ SiLU(x) = x * sigmoid(x)
17
+
18
+ Why SwiGLU is better:
19
+ - The gating mechanism (SiLU(gate) * up) gives the model a learned
20
+ way to activate or suppress each hidden dimension independently.
21
+ - Empirically outperforms GELU/ReLU FFNs at the same parameter count.
22
+ - d_ff is set to int(2/3 * 4 * d_model) rounded to nearest 256.
23
+ This compensates for having 3 matrices instead of 2, keeping
24
+ total parameter count comparable to a standard 4x FFN.
25
+ """
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+
31
+ from model.config import ModelConfig
32
+
33
+
34
+ class SwiGLU(nn.Module):
35
+
36
+ def __init__(self, config: ModelConfig):
37
+ super().__init__()
38
+
39
+ d_model = config.d_model
40
+ d_ff = config.d_ff
41
+
42
+ # Three weight matrices — no bias
43
+ self.gate = nn.Linear(d_model, d_ff, bias=config.bias) # gate projection
44
+ self.up = nn.Linear(d_model, d_ff, bias=config.bias) # up projection
45
+ self.down = nn.Linear(d_ff, d_model, bias=config.bias) # down projection
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Args:
50
+ x : (B, T, d_model)
51
+ Returns:
52
+ out : (B, T, d_model)
53
+ """
54
+ # SiLU = x * sigmoid(x) (also called swish)
55
+ # Element-wise gating: SiLU(gate) acts as a learned activation mask on up
56
+ return self.down(F.silu(self.gate(x)) * self.up(x))
57
+
58
+
59
+ # ------------------------------------------------------------------ #
60
+ # QUICK CHECK
61
+ # ------------------------------------------------------------------ #
62
+
63
+ if __name__ == "__main__":
64
+ from model.config import SLLM_100M
65
+
66
+ cfg = SLLM_100M
67
+ mlp = SwiGLU(cfg)
68
+
69
+ n_params = sum(p.numel() for p in mlp.parameters())
70
+ print(f"SwiGLU d_model={cfg.d_model} d_ff={cfg.d_ff}")
71
+ print(f" gate : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
72
+ print(f" up : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
73
+ print(f" down : {cfg.d_ff} x {cfg.d_model} = {cfg.d_ff * cfg.d_model:,}")
74
+ print(f" total MLP params : {n_params/1e6:.3f}M")
75
+
76
+ B, T = 2, 64
77
+ x = torch.randn(B, T, cfg.d_model)
78
+ out = mlp(x)
79
+
80
+ print(f"Input shape : {x.shape}")
81
+ print(f"Output shape : {out.shape}")
82
+ assert out.shape == x.shape, "Shape mismatch!"
83
+ print("PASS")
model/model.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model/model.py
3
+
4
+ SLLM — Small Language Model (decoder-only Transformer).
5
+
6
+ Full architecture:
7
+ tokens (B, T)
8
+ -> Embedding (vocab_size -> d_model)
9
+ -> N x TransformerBlock (attention + FFN)
10
+ -> Final RMSNorm
11
+ -> LM Head (Linear d_model -> vocab_size) <- weight-TIED to embedding
12
+
13
+ Weight tying:
14
+ The embedding matrix and the LM head output matrix share the same weights.
15
+ - Halves memory for the embedding/output layers.
16
+ - A standard practice since GPT-2 (Press & Wolf, 2016).
17
+
18
+ Weight initialization:
19
+ - Embeddings: std=0.02 (GPT-2 convention)
20
+ - Linear layers: std=0.02
21
+ - Output projections (attn.o_proj, mlp.down): std = 0.02/sqrt(2*n_layers)
22
+ - Scaled down per GPT-2/NanoGPT: at initialization, the residual
23
+ stream grows as sqrt(n_layers), so we scale residual contributions down.
24
+
25
+ Forward:
26
+ Returns logits (B, T, vocab_size).
27
+ Loss is computed externally in the training loop for flexibility.
28
+ """
29
+
30
+ import math
31
+ import torch
32
+ import torch.nn as nn
33
+ from torch.utils.checkpoint import checkpoint
34
+
35
+ from model.config import ModelConfig
36
+ from model.norm import RMSNorm
37
+ from model.block import TransformerBlock
38
+
39
+
40
+ class SLLM(nn.Module):
41
+
42
+ def __init__(self, config: ModelConfig):
43
+ super().__init__()
44
+ self.config = config
45
+
46
+ # ---- Token embedding --------------------------------------- #
47
+ self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
48
+
49
+ # ---- Transformer blocks ------------------------------------ #
50
+ self.blocks = nn.ModuleList([
51
+ TransformerBlock(config) for _ in range(config.n_layers)
52
+ ])
53
+
54
+ # ---- Final norm -------------------------------------------- #
55
+ self.norm = RMSNorm(config.d_model)
56
+
57
+ # ---- LM Head ----------------------------------------------- #
58
+ # Linear: d_model -> vocab_size, no bias
59
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
60
+
61
+ # ---- Weight tying ------------------------------------------ #
62
+ # Share embedding weights with lm_head
63
+ self.lm_head.weight = self.token_emb.weight
64
+
65
+ # ---- Gradient checkpointing flag --------------------------- #
66
+ # Enabled via enable_gradient_checkpointing() to save VRAM
67
+ self._gradient_checkpointing = False
68
+
69
+ # ---- Initialize weights ------------------------------------ #
70
+ self.apply(self._init_weights)
71
+
72
+ def _init_weights(self, module: nn.Module):
73
+ """
74
+ Custom weight initialization.
75
+ - Normal(0, 0.02) for Linear and Embedding
76
+ - Scaled residual projections: std *= 1/sqrt(2 * n_layers)
77
+ """
78
+ if isinstance(module, nn.Linear):
79
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
80
+ if module.bias is not None:
81
+ nn.init.zeros_(module.bias)
82
+
83
+ elif isinstance(module, nn.Embedding):
84
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
85
+
86
+ # Scale down residual projections (attn output + mlp down)
87
+ # Accessed by name: o_proj and down
88
+ if isinstance(module, nn.Linear):
89
+ if getattr(module, '_is_residual', False):
90
+ scale = 0.02 / math.sqrt(2 * self.config.n_layers)
91
+ nn.init.normal_(module.weight, mean=0.0, std=scale)
92
+
93
+ def _mark_residual_projections(self):
94
+ """
95
+ Mark output projections so _init_weights can scale them.
96
+ Called after __init__ to tag the specific layers.
97
+ """
98
+ for block in self.blocks:
99
+ block.attn.o_proj._is_residual = True
100
+ block.mlp.down._is_residual = True
101
+ self.apply(self._init_weights)
102
+
103
+ def forward(
104
+ self,
105
+ input_ids: torch.Tensor,
106
+ targets: torch.Tensor = None,
107
+ ):
108
+ """
109
+ Args:
110
+ input_ids : (B, T) — integer token IDs
111
+ targets : (B, T) — optional, for loss computation
112
+
113
+ Returns:
114
+ logits : (B, T, vocab_size)
115
+ loss : scalar CrossEntropy loss if targets given, else None
116
+ """
117
+ B, T = input_ids.shape
118
+ assert T <= self.config.context_length, (
119
+ f"Sequence length {T} exceeds context_length {self.config.context_length}"
120
+ )
121
+
122
+ # ---- Embedding --------------------------------------------- #
123
+ x = self.token_emb(input_ids) # (B, T, d_model)
124
+
125
+ # ---- Transformer blocks ------------------------------------ #
126
+ for block in self.blocks:
127
+ if self._gradient_checkpointing and self.training:
128
+ # Recompute activations during backward to save VRAM
129
+ # use_reentrant=False is the modern recommended API
130
+ x = checkpoint(block, x, use_reentrant=False)
131
+ else:
132
+ x = block(x)
133
+
134
+ # ---- Final norm -------------------------------------------- #
135
+ x = self.norm(x) # (B, T, d_model)
136
+
137
+ # ---- LM Head ----------------------------------------------- #
138
+ logits = self.lm_head(x) # (B, T, vocab_size)
139
+
140
+ # ---- Loss -------------------------------------------------- #
141
+ loss = None
142
+ if targets is not None:
143
+ # Flatten for cross-entropy: (B*T, vocab_size) vs (B*T,)
144
+ loss = nn.functional.cross_entropy(
145
+ logits.view(-1, logits.size(-1)),
146
+ targets.view(-1),
147
+ )
148
+
149
+ return logits, loss
150
+
151
+ @torch.no_grad()
152
+ def generate(
153
+ self,
154
+ input_ids: torch.Tensor,
155
+ max_new_tokens: int,
156
+ temperature: float = 1.0,
157
+ top_k: int = None,
158
+ ) -> torch.Tensor:
159
+ """
160
+ Autoregressive text generation (greedy or top-k sampling).
161
+
162
+ Args:
163
+ input_ids : (B, T) prompt tokens
164
+ max_new_tokens : number of tokens to generate
165
+ temperature : softmax temperature (1.0 = neutral, <1 = sharper)
166
+ top_k : if set, sample from top-k tokens only
167
+
168
+ Returns:
169
+ (B, T + max_new_tokens) token IDs
170
+ """
171
+ self.eval()
172
+ for _ in range(max_new_tokens):
173
+
174
+ # Crop context if longer than max
175
+ ctx = input_ids
176
+ if ctx.shape[1] > self.config.context_length:
177
+ ctx = ctx[:, -self.config.context_length:]
178
+
179
+ # Forward pass — only need last logit
180
+ logits, _ = self(ctx)
181
+ logits = logits[:, -1, :] / temperature # (B, vocab_size)
182
+
183
+ # Optional top-k filtering
184
+ if top_k is not None:
185
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
186
+ logits[logits < v[:, [-1]]] = float('-inf')
187
+
188
+ # Sample from distribution
189
+ probs = torch.softmax(logits, dim=-1)
190
+ next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
191
+
192
+ input_ids = torch.cat([input_ids, next_token], dim=1)
193
+
194
+ return input_ids
195
+
196
+ def enable_gradient_checkpointing(self):
197
+ """
198
+ Enables gradient checkpointing to reduce VRAM usage.
199
+ Recomputes activations during the backward pass instead of
200
+ storing them — trades ~30% more compute for ~40% less memory.
201
+ Essential for fitting 100M+ models on 4GB VRAM.
202
+ """
203
+ self._gradient_checkpointing = True
204
+
205
+ def count_params(self, non_embedding: bool = False) -> int:
206
+ """
207
+ Returns parameter count.
208
+
209
+ Args:
210
+ non_embedding: if True, exclude embedding parameters
211
+ (common in LLM reporting since embeddings scale
212
+ with vocab size and not model capacity)
213
+ """
214
+ total = sum(p.numel() for p in self.parameters())
215
+ if non_embedding:
216
+ total -= self.token_emb.weight.numel()
217
+ return total
218
+
219
+
220
+ # ------------------------------------------------------------------ #
221
+ # QUICK CHECK
222
+ # ------------------------------------------------------------------ #
223
+
224
+ if __name__ == "__main__":
225
+ from model.config import SLLM_100M, SLLM_150M
226
+
227
+ for name, cfg in [("SLLM-100M", SLLM_100M), ("SLLM-150M", SLLM_150M)]:
228
+ model = SLLM(cfg)
229
+
230
+ total = model.count_params()
231
+ non_emb = model.count_params(non_embedding=True)
232
+ print(f"{name}")
233
+ print(f" total params : {total/1e6:.1f}M")
234
+ print(f" non-embedding params : {non_emb/1e6:.1f}M")
235
+ print(f" embedding params : {(total-non_emb)/1e6:.1f}M")
236
+
237
+ # Forward pass check
238
+ B, T = 2, 64
239
+ ids = torch.randint(0, cfg.vocab_size, (B, T))
240
+ targets = torch.randint(0, cfg.vocab_size, (B, T))
241
+
242
+ logits, loss = model(ids, targets)
243
+ print(f" logits shape : {logits.shape}")
244
+ print(f" loss : {loss.item():.4f}")
245
+ print()
model/norm.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model/norm.py
3
+
4
+ RMSNorm — Root Mean Square Layer Normalization.
5
+ Used in LLaMA-style transformers instead of standard LayerNorm.
6
+
7
+ Key difference from LayerNorm:
8
+ - No mean subtraction (centering)
9
+ - No bias term
10
+ - Only re-scales with a single learned gain vector (weight)
11
+ - ~40% faster in practice (no mean computation)
12
+
13
+ Formula:
14
+ RMSNorm(x) = x / RMS(x) * weight
15
+ where RMS(x) = sqrt( mean(x^2) + eps )
16
+ """
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
+ class RMSNorm(nn.Module):
23
+
24
+ def __init__(self, d_model: int, eps: float = 1e-6):
25
+ """
26
+ Args:
27
+ d_model : hidden dimension (size of last axis of input)
28
+ eps : small constant for numerical stability
29
+ """
30
+ super().__init__()
31
+ self.eps = eps
32
+ self.weight = nn.Parameter(torch.ones(d_model)) # learnable gain
33
+
34
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
35
+ # x: (..., d_model)
36
+ # compute RMS along last dimension, keepdim for broadcasting
37
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ # cast to float32 for stable norm, then back to input dtype
41
+ output = self._norm(x.float()).type_as(x)
42
+ return output * self.weight
43
+
44
+
45
+ # ------------------------------------------------------------------ #
46
+ # QUICK CHECK
47
+ # ------------------------------------------------------------------ #
48
+
49
+ if __name__ == "__main__":
50
+ torch.manual_seed(0)
51
+ B, T, D = 2, 16, 768
52
+ x = torch.randn(B, T, D)
53
+ norm = RMSNorm(D)
54
+
55
+ out = norm(x)
56
+ print(f"Input shape : {x.shape}")
57
+ print(f"Output shape : {out.shape}")
58
+ print(f"Output dtype : {out.dtype}")
59
+
60
+ # Verify: each vector should be approximately unit RMS after norm (before weight)
61
+ rms_before = x.pow(2).mean(dim=-1).sqrt()
62
+ rms_after = out.pow(2).mean(dim=-1).sqrt()
63
+ print(f"RMS before norm : {rms_before.mean():.3f}")
64
+ print(f"RMS after norm : {rms_after.mean():.3f} (weight=1 so should be ~1.0)")
65
+ print("PASS" if torch.allclose(rms_after, torch.ones_like(rms_after), atol=1e-4) else "FAIL")
model/rope.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model/rope.py
3
+
4
+ Rotary Position Embedding (RoPE) — Su et al. 2021 (RoFormer).
5
+ Used in LLaMA, Mistral, Gemma, etc.
6
+
7
+ Core idea:
8
+ Instead of adding position embeddings to token vectors, we ROTATE
9
+ the query and key vectors in attention using position-dependent angles.
10
+
11
+ - Relative positions are encoded implicitly via dot-product invariance.
12
+ - Works for any sequence length (extrapolates beyond training length).
13
+ - Only applied to Q and K, NOT V.
14
+
15
+ Implementation:
16
+ 1. Precompute cos/sin tables for all positions up to max_seq_len.
17
+ Shape: (max_seq_len, head_dim)
18
+
19
+ 2. At forward time, slice cos/sin to the current seq_len and
20
+ apply rotation to Q and K.
21
+
22
+ Rotation formula (pairs of dims):
23
+ Given a vector x with dims [x0, x1, x2, x3, ...]:
24
+ Pair each consecutive two dims: (x0,x1), (x2,x3), ...
25
+ Rotate each pair by angle theta_i * position:
26
+ [x0*cos - x1*sin, x0*sin + x1*cos, ...]
27
+
28
+ Equivalent implementation using rotate_half:
29
+ rotated = concat([-x_second_half, x_first_half]) # swapped halves
30
+ out = x * cos + rotated * sin
31
+ """
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ from typing import Tuple
36
+
37
+
38
+ def precompute_rope_freqs(
39
+ head_dim: int,
40
+ max_seq_len: int,
41
+ theta: float = 10_000.0,
42
+ device: torch.device = None,
43
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
44
+ """
45
+ Precompute RoPE cosine and sine tables.
46
+
47
+ Args:
48
+ head_dim : dimension of each attention head (must be even)
49
+ max_seq_len : max sequence length to precompute
50
+ theta : RoPE base frequency (default 10_000, use 500_000 for long context)
51
+ device : torch device
52
+
53
+ Returns:
54
+ cos : (max_seq_len, head_dim)
55
+ sin : (max_seq_len, head_dim)
56
+ """
57
+ assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
58
+
59
+ # Inverse frequencies: shape (head_dim // 2,)
60
+ # inv_freq[i] = 1 / theta^(2i / head_dim)
61
+ i = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
62
+ inv_freq = 1.0 / (theta ** (i / head_dim))
63
+
64
+ # Position indices: shape (max_seq_len,)
65
+ positions = torch.arange(max_seq_len, dtype=torch.float32, device=device)
66
+
67
+ # Outer product: (max_seq_len, head_dim // 2)
68
+ freqs = torch.outer(positions, inv_freq)
69
+
70
+ # Duplicate along last dim to match head_dim:
71
+ # (max_seq_len, head_dim // 2) -> (max_seq_len, head_dim)
72
+ # cos/sin applied to [x0,x1,x2,x3,...] as [theta0,theta0, theta1,theta1, ...]
73
+ freqs = torch.cat([freqs, freqs], dim=-1)
74
+
75
+ return freqs.cos(), freqs.sin()
76
+
77
+
78
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
79
+ """
80
+ Rotates pairs of dimensions in the last axis.
81
+ Splits last dim in half, negates the second half, then swaps:
82
+ [x0..xN/2, xN/2..xN] -> [-xN/2..xN, x0..xN/2]
83
+
84
+ Args:
85
+ x: (..., head_dim)
86
+ Returns:
87
+ rotated: (..., head_dim)
88
+ """
89
+ half = x.shape[-1] // 2
90
+ x1 = x[..., :half] # first half
91
+ x2 = x[..., half:] # second half
92
+ return torch.cat([-x2, x1], dim=-1)
93
+
94
+
95
+ def apply_rope(
96
+ q: torch.Tensor,
97
+ k: torch.Tensor,
98
+ cos: torch.Tensor,
99
+ sin: torch.Tensor,
100
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
101
+ """
102
+ Apply RoPE rotation to query and key tensors.
103
+
104
+ Args:
105
+ q : (B, n_heads, T, head_dim)
106
+ k : (B, n_heads, T, head_dim)
107
+ cos : (T, head_dim) - precomputed from precompute_rope_freqs
108
+ sin : (T, head_dim) - precomputed from precompute_rope_freqs
109
+
110
+ Returns:
111
+ q_rot, k_rot : same shapes as inputs
112
+ """
113
+ # Broadcast cos/sin from (T, head_dim) to (1, 1, T, head_dim)
114
+ cos = cos.unsqueeze(0).unsqueeze(0)
115
+ sin = sin.unsqueeze(0).unsqueeze(0)
116
+
117
+ q_rot = (q * cos) + (rotate_half(q) * sin)
118
+ k_rot = (k * cos) + (rotate_half(k) * sin)
119
+ return q_rot, k_rot
120
+
121
+
122
+ class RoPECache(nn.Module):
123
+ """
124
+ Module that holds the RoPE cos/sin cache as a buffer.
125
+ Not a learnable module — just stores precomputed freqs and moves them
126
+ to the right device automatically via register_buffer.
127
+ """
128
+
129
+ def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10_000.0):
130
+ super().__init__()
131
+ cos, sin = precompute_rope_freqs(head_dim, max_seq_len, theta)
132
+ # register_buffer: not a parameter, but moves with .to(device)
133
+ self.register_buffer("cos", cos, persistent=True)
134
+ self.register_buffer("sin", sin, persistent=True)
135
+
136
+ def get(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Slice cos/sin to current sequence length."""
138
+ return self.cos[:seq_len], self.sin[:seq_len]
139
+
140
+
141
+ # ------------------------------------------------------------------ #
142
+ # QUICK CHECK
143
+ # ------------------------------------------------------------------ #
144
+
145
+ if __name__ == "__main__":
146
+ torch.manual_seed(0)
147
+
148
+ B, n_heads, T, head_dim = 2, 12, 16, 64
149
+
150
+ cos, sin = precompute_rope_freqs(head_dim, max_seq_len=1024)
151
+ cos_T = cos[:T]
152
+ sin_T = sin[:T]
153
+
154
+ q = torch.randn(B, n_heads, T, head_dim)
155
+ k = torch.randn(B, n_heads, T, head_dim)
156
+
157
+ q_rot, k_rot = apply_rope(q, k, cos_T, sin_T)
158
+
159
+ print(f"q shape : {q.shape}")
160
+ print(f"q_rot shape : {q_rot.shape}")
161
+ print(f"k_rot shape : {k_rot.shape}")
162
+
163
+ # Verify: rotation should preserve norm (|x| = |Rx|)
164
+ q_norm = q.norm(dim=-1)
165
+ q_rot_norm = q_rot.norm(dim=-1)
166
+ print(f"Norm preserved (q): {torch.allclose(q_norm, q_rot_norm, atol=1e-5)}")
167
+
168
+ # Test RoPECache
169
+ cache = RoPECache(head_dim=64, max_seq_len=1024)
170
+ c, s = cache.get(T)
171
+ print(f"Cache cos shape: {c.shape}")
172
+ print("PASS")
model_explained.md ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Folder — Plain Language Explanation
2
+
3
+ The `model/` folder builds a **GPT-style decoder-only transformer** from scratch,
4
+ piece by piece. Each file is one component. Here's how they stack:
5
+
6
+ ```
7
+ tokens (integers)
8
+
9
+
10
+ ┌─────────────┐
11
+ │ Embedding │ config.py defines the shape of everything
12
+ └──────┬──────┘
13
+
14
+ ▼ ×N layers
15
+ ┌──────────────────────────────────────┐
16
+ │ TransformerBlock │ block.py
17
+ │ │
18
+ │ ┌──────────┐ ┌──────────────┐ │
19
+ │ │ RMSNorm │ │ RMSNorm │ │ norm.py
20
+ │ └────┬─────┘ └──────┬───────┘ │
21
+ │ │ │ │
22
+ │ ┌────▼─────┐ ┌──────▼───────┐ │
23
+ │ │Attention │ │ SwiGLU MLP │ │ attention.py / mlp.py
24
+ │ │ + RoPE │ │ │ │ rope.py
25
+ │ └────┬─────┘ └──────┬───────┘ │
26
+ │ │ (+residual) │ (+residual)│
27
+ └────────┼─────────────────┼───────────┘
28
+ │ │
29
+ └────────┬────────┘
30
+
31
+
32
+ ┌──────────┐
33
+ │ RMSNorm │ final norm
34
+ └────┬─────┘
35
+
36
+ ┌────▼─────┐
37
+ │ LM Head │ Linear → vocab_size logits
38
+ └──────────┘
39
+ ```
40
+
41
+ ---
42
+
43
+ ## 1. `config.py` — The Blueprint
44
+
45
+ **What it does:** Stores all the numbers that define the model size.
46
+ Nothing computes anything here — it's just a settings object.
47
+
48
+ ```python
49
+ @dataclass
50
+ class ModelConfig:
51
+ vocab_size = 32_000 # how many tokens exist
52
+ context_length = 1024 # max sequence length
53
+ d_model = 1024 # width of every vector throughout the model
54
+ n_heads = 16 # how many attention heads
55
+ n_layers = 9 # how many transformer blocks stacked
56
+ d_ff = 2816 # width of the MLP hidden layer (auto-computed)
57
+ ```
58
+
59
+ **Why these numbers?**
60
+ - `d_model` is the "resolution" of the model — bigger = more expressive but more memory
61
+ - `n_heads` splits each attention layer into parallel sub-attentions
62
+ - `head_dim = d_model / n_heads = 64` — each head sees 64-dim slices
63
+ - `d_ff` for SwiGLU = `round_256( 2/3 × 4 × d_model )` — compensates for having 3 matrices instead of 2
64
+
65
+ **Presets defined here:**
66
+ ```
67
+ SLLM_100M: d=768, h=12, l=12 → 109.5M params
68
+ SLLM_150M: d=1024, h=16, l=9 → 148.4M params
69
+ ```
70
+
71
+ ---
72
+
73
+ ## 2. `norm.py` — RMSNorm
74
+
75
+ **What it does:** Normalizes vectors so they don't explode or vanish during training.
76
+ Used before every attention and MLP layer.
77
+
78
+ **Standard LayerNorm (GPT-2):**
79
+ ```
80
+ 1. Compute mean of x
81
+ 2. Subtract mean (centering)
82
+ 3. Divide by std
83
+ 4. Scale by learned weight
84
+ 5. Add learned bias
85
+ ```
86
+
87
+ **RMSNorm (LLaMA / our model):**
88
+ ```
89
+ 1. Compute RMS = sqrt( mean(x²) ) ← no mean subtraction!
90
+ 2. Divide by RMS
91
+ 3. Scale by learned weight ← no bias!
92
+ ```
93
+
94
+ **Why simpler is better:**
95
+ - No mean subtraction → ~40% faster
96
+ - No bias → fewer parameters
97
+ - Works just as well in practice
98
+ - LLaMA, Mistral, Gemma all use it
99
+
100
+ ```python
101
+ # What it computes:
102
+ output = (x / sqrt(mean(x²) + 1e-6)) * weight
103
+ # ↑ normalize ↑ rescale with learned gain
104
+ ```
105
+
106
+ The `weight` starts at all-ones (no change at init) and is learned during training.
107
+
108
+ ---
109
+
110
+ ## 3. `rope.py` — Rotary Position Embedding (RoPE)
111
+
112
+ **The problem it solves:** Transformers have no built-in sense of position.
113
+ Without position encoding, `"cat sat on mat"` and `"mat on sat cat"` look identical.
114
+
115
+ **How older models solved it (GPT-2):**
116
+ Added a fixed learned vector to each token: `token[i] += position_embedding[i]`
117
+ Problem: can't generalize beyond the training length.
118
+
119
+ **What RoPE does instead:**
120
+ Instead of adding position info to token vectors, it **rotates** the Query and Key
121
+ vectors in attention by an angle that depends on their position.
122
+
123
+ ```
124
+ Token at position 3 → rotate Q and K by angle θ₃
125
+ Token at position 7 → rotate Q and K by angle θ₇
126
+ ```
127
+
128
+ When you compute attention score `Q·K`, the rotation cancels out in a way that
129
+ encodes *relative distance* between tokens, not absolute positions.
130
+
131
+ **Why this is better:**
132
+ - No extra parameters (pure math, no learned table)
133
+ - Works beyond training length (extrapolates)
134
+ - Used in LLaMA, Mistral, GPT-4 (likely), Gemma
135
+
136
+ **How the code works:**
137
+ ```python
138
+ # Step 1: precompute a table of cos/sin values for every position
139
+ cos, sin = precompute_rope_freqs(head_dim=64, max_seq_len=1024)
140
+ # cos/sin shape: (1024, 64)
141
+
142
+ # Step 2: at forward time, rotate Q and K
143
+ q_rotated = q * cos + rotate_half(q) * sin
144
+ k_rotated = k * cos + rotate_half(k) * sin
145
+
146
+ # rotate_half(x): splits x in half, negates second half, swaps
147
+ # [a, b, c, d] → [-c, -d, a, b]
148
+ ```
149
+
150
+ V (values) are **not** rotated — only Q and K get position encoding.
151
+
152
+ ---
153
+
154
+ ## 4. `attention.py` — Causal Self-Attention
155
+
156
+ **What it does:** Lets every token look at all *previous* tokens and decide
157
+ which ones are relevant to predict the next token.
158
+
159
+ **The full flow:**
160
+
161
+ ```
162
+ Input x: (Batch, Tokens, d_model)
163
+ e.g. (2, 1024, 1024)
164
+
165
+
166
+ QKV projection: one big Linear(d_model → 3×d_model)
167
+
168
+ ├─── Q: (2, 1024, 1024) — "what am I looking for?"
169
+ ├─── K: (2, 1024, 1024) — "what do I contain?"
170
+ └─── V: (2, 1024, 1024) — "what do I send if attended to?"
171
+
172
+
173
+ Reshape to heads: (2, 16_heads, 1024, 64_head_dim)
174
+
175
+
176
+ Apply RoPE to Q and K ← position encoding happens here
177
+
178
+
179
+ Scaled Dot-Product Attention:
180
+ scores = Q @ K^T / sqrt(64) # how much does each token attend to each other
181
+ mask = causal mask # can only look LEFT (past), not right (future)
182
+ weights = softmax(scores + mask)
183
+ out = weights @ V # weighted sum of values
184
+
185
+
186
+ Reshape back: (2, 1024, 1024)
187
+
188
+
189
+ Output projection: Linear(d_model → d_model)
190
+ ```
191
+
192
+ **Causal mask** — this is what makes it a *language model* (predicts next token):
193
+ ```
194
+ Position: 0 1 2 3
195
+ Token 0: [✓ ✗ ✗ ✗] can only see itself
196
+ Token 1: [✓ ✓ ✗ ✗] can see 0,1
197
+ Token 2: [✓ ✓ ✓ ✗] can see 0,1,2
198
+ Token 3: [✓ ✓ ✓ ✓] can see all
199
+ ```
200
+
201
+ **Flash Attention:** We use `F.scaled_dot_product_attention(..., is_causal=True)`
202
+ which is PyTorch 2.0's built-in Flash Attention — it never materializes the full
203
+ O(T²) attention matrix in memory. Much faster and uses far less VRAM.
204
+
205
+ ---
206
+
207
+ ## 5. `mlp.py` — SwiGLU Feed-Forward Network
208
+
209
+ **What it does:** After attention (which mixes *between* tokens), the MLP
210
+ transforms each token *independently* — it's where most of the model's
211
+ "knowledge" is stored.
212
+
213
+ **Standard MLP (GPT-2):**
214
+ ```python
215
+ out = W2 @ GELU(W1 @ x) # 2 matrices
216
+ ```
217
+
218
+ **SwiGLU (LLaMA / our model):**
219
+ ```python
220
+ gate = W_gate @ x # linear
221
+ up = W_up @ x # linear
222
+ hidden = SiLU(gate) * up # element-wise gate ← the key difference
223
+ out = W_down @ hidden # 3 matrices total
224
+ ```
225
+
226
+ **What is SiLU?**
227
+ ```
228
+ SiLU(x) = x × sigmoid(x)
229
+ ```
230
+ It's a smooth version of ReLU — never exactly zero, has a small negative region.
231
+
232
+ **Why gating matters:**
233
+ - `SiLU(gate)` acts as a learned on/off switch for each hidden dimension
234
+ - The model learns to activate only the neurons relevant to each input
235
+ - Empirically outperforms GELU at the same parameter count
236
+ - Used in LLaMA, PaLM, Mistral
237
+
238
+ **The d_ff formula:**
239
+ ```
240
+ d_ff = round_up_256( int(2/3 × 4 × d_model) )
241
+
242
+ For 150M: round_up_256( int(2/3 × 4 × 1024) ) = round_up_256(2730) = 2816
243
+ ```
244
+ The `2/3` factor compensates for having 3 matrices instead of 2 — keeps
245
+ total parameter count equal to a standard 4× FFN.
246
+
247
+ ---
248
+
249
+ ## 6. `block.py` — TransformerBlock
250
+
251
+ **What it does:** Wraps attention + MLP into one reusable block.
252
+ The model is just N copies of this block stacked.
253
+
254
+ ```python
255
+ def forward(x):
256
+ # Attention sub-layer
257
+ x = x + attention( rmsnorm(x) ) # pre-norm + residual
258
+
259
+ # MLP sub-layer
260
+ x = x + mlp( rmsnorm(x) ) # pre-norm + residual
261
+
262
+ return x
263
+ ```
264
+
265
+ **Two key ideas:**
266
+
267
+ **1. Pre-norm (normalize BEFORE the sublayer):**
268
+ ```
269
+ Pre-norm (LLaMA): x → norm → attention → + original x
270
+ Post-norm (GPT-2): x → attention → + original x → norm
271
+ ```
272
+ Pre-norm is more stable at large scale — gradients flow more cleanly.
273
+
274
+ **2. Residual connections (`x + sublayer(x)`):**
275
+ The output of each sublayer is *added* back to the input, not replacing it.
276
+ This means:
277
+ - Gradients can skip directly to earlier layers during backprop
278
+ - The model learns *corrections* to the input, not transformations from scratch
279
+ - Allows stacking many layers without vanishing gradients
280
+
281
+ ---
282
+
283
+ ## 7. `model.py` — SLLM (The Full Model)
284
+
285
+ **What it does:** Assembles everything into the complete language model.
286
+
287
+ ```
288
+ tokens: (B, T) ← integer IDs like [423, 1829, 55, ...]
289
+
290
+
291
+ token_emb: Embedding(32000 → 1024)
292
+ │ converts each integer to a 1024-dim vector
293
+
294
+ blocks[0]: TransformerBlock ─┐
295
+ blocks[1]: TransformerBlock │ 9 blocks for 150M
296
+ ... │
297
+ blocks[8]: TransformerBlock ─┘
298
+
299
+
300
+ norm: RMSNorm(1024) ← final stabilization
301
+
302
+
303
+ lm_head: Linear(1024 → 32000)
304
+ │ produces a score for each possible next token
305
+
306
+ logits: (B, T, 32000) ← unnormalized scores
307
+ ```
308
+
309
+ **Weight tying:**
310
+ The `token_emb` matrix and `lm_head` matrix **share the same weights**.
311
+ ```python
312
+ self.lm_head.weight = self.token_emb.weight
313
+ ```
314
+ - Same matrix used for: embedding lookup (input) AND output projection
315
+ - Saves 32M parameters (32000 × 1024)
316
+ - Works because: if token X has a similar embedding to the current hidden state,
317
+ it should also score highly as the next token prediction
318
+
319
+ **Loss computation:**
320
+ ```python
321
+ # Cross-entropy: at each position, predict the NEXT token
322
+ # Input: [The, cat, sat, on] → predicts [cat, sat, on, mat]
323
+ # targets = input shifted by 1
324
+ loss = cross_entropy(logits.view(-1, 32000), targets.view(-1))
325
+ ```
326
+
327
+ **Gradient checkpointing** (`enable_gradient_checkpointing()`):
328
+ Normally PyTorch saves all intermediate activations during forward pass to use
329
+ in backprop. For 9 layers with batch_size=2 and seq_len=1024, that's ~1.5GB.
330
+
331
+ With gradient checkpointing:
332
+ - Activations are **NOT saved** during forward pass
333
+ - During backward pass, they are **recomputed on-the-fly**
334
+ - Result: ~40% less VRAM, ~30% slower training
335
+ - Essential for fitting 150M on a 4GB GPU
336
+
337
+ **Weight initialization:**
338
+ ```python
339
+ # All Linear and Embedding weights: Normal(mean=0, std=0.02)
340
+ # Residual projections (o_proj, mlp.down): scaled down by 1/sqrt(2 × n_layers)
341
+ ```
342
+ The residual scaling prevents the residual stream from growing too large
343
+ at initialization when many layers add to it.
344
+
345
+ ---
346
+
347
+ ## How it all fits together — One forward pass
348
+
349
+ ```
350
+ "The cat sat" → tokenizer → [423, 1829, 55]
351
+
352
+ token_emb: [423]→[0.1,-0.3,...] (1024 floats)
353
+ [1829]→[0.8, 0.2,...] (1024 floats)
354
+ [55] →[-0.1,0.4,...] (1024 floats)
355
+
356
+ Block 0:
357
+ norm → Q,K,V projections → RoPE rotation → Flash Attention → output proj → + residual
358
+ norm → gate,up projections → SiLU(gate)*up → down proj → + residual
359
+
360
+ Block 1..8: same
361
+
362
+ Final norm → LM head → 32000 scores per position
363
+
364
+ softmax → probabilities → sample next token
365
+ ```
366
+
367
+ **Total parameters (150M):**
368
+ ```
369
+ Embedding: 32000 × 1024 = 32.8M
370
+ Per block: attn(4.2M) + mlp(8.6M) + norms(~0M) = 12.85M
371
+ 9 blocks: 9 × 12.85M = 115.6M
372
+ Final norm: 1024 = ~0M
373
+ LM head: TIED to embedding = 0M (reuses same weights)
374
+ ─────────────────────────────────────────
375
+ TOTAL: 148.4M params
376
+ ```
plot_training.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ plot_training.py — Training Visualization Dashboard
3
+
4
+ Reads train_log.jsonl and renders a clean, dark-mode training dashboard.
5
+
6
+ Usage:
7
+ # Static plot of completed/current run
8
+ python plot_training.py --run_dir runs/run_001
9
+
10
+ # Live mode: refresh every 5 seconds while training runs
11
+ python plot_training.py --run_dir runs/run_001 --live
12
+
13
+ # Compare multiple runs
14
+ python plot_training.py --run_dir runs/run_001 runs/run_002
15
+
16
+ Dashboard panels:
17
+ 1. Training Loss (raw + EMA smoothed)
18
+ 2. Validation Loss (if available)
19
+ 3. Learning Rate schedule
20
+ 4. Tokens / second (throughput)
21
+ 5. VRAM usage (if logged)
22
+ 6. Gradient norm (if logged)
23
+ """
24
+
25
+ import os
26
+ import sys
27
+ import json
28
+ import time
29
+ import argparse
30
+ from pathlib import Path
31
+
32
+ import matplotlib
33
+ import matplotlib.pyplot as plt
34
+ import matplotlib.gridspec as gridspec
35
+ import matplotlib.ticker as ticker
36
+ import numpy as np
37
+
38
+
39
+ # ------------------------------------------------------------------ #
40
+ # STYLE
41
+ # ------------------------------------------------------------------ #
42
+
43
+ DARK_BG = "#0d1117"
44
+ PANEL_BG = "#161b22"
45
+ GRID_COLOR = "#21262d"
46
+ TEXT_COLOR = "#c9d1d9"
47
+ MUTED_COLOR = "#6e7681"
48
+ ACCENT_BLUE = "#58a6ff"
49
+ ACCENT_GREEN = "#3fb950"
50
+ ACCENT_ORANGE= "#d29922"
51
+ ACCENT_RED = "#f85149"
52
+ ACCENT_PURPLE= "#bc8cff"
53
+ ACCENT_TEAL = "#39d353"
54
+
55
+ matplotlib.rcParams.update({
56
+ "figure.facecolor": DARK_BG,
57
+ "axes.facecolor": PANEL_BG,
58
+ "axes.edgecolor": GRID_COLOR,
59
+ "axes.labelcolor": TEXT_COLOR,
60
+ "axes.titlecolor": TEXT_COLOR,
61
+ "xtick.color": MUTED_COLOR,
62
+ "ytick.color": MUTED_COLOR,
63
+ "grid.color": GRID_COLOR,
64
+ "grid.linestyle": "--",
65
+ "grid.linewidth": 0.5,
66
+ "grid.alpha": 0.7,
67
+ "legend.facecolor": PANEL_BG,
68
+ "legend.edgecolor": GRID_COLOR,
69
+ "legend.labelcolor": TEXT_COLOR,
70
+ "text.color": TEXT_COLOR,
71
+ "font.family": "DejaVu Sans",
72
+ "font.size": 10,
73
+ "axes.titlesize": 11,
74
+ "axes.labelsize": 10,
75
+ })
76
+
77
+
78
+ # ------------------------------------------------------------------ #
79
+ # DATA LOADING
80
+ # ------------------------------------------------------------------ #
81
+
82
+ def load_log(log_path: str) -> dict:
83
+ """
84
+ Loads train_log.jsonl and returns separate arrays for each metric.
85
+ Returns dict of metric_name -> list of values, aligned by step.
86
+ """
87
+ train_steps = []
88
+ train_loss = []
89
+ val_steps = []
90
+ val_loss = []
91
+ lr_steps = []
92
+ lr_vals = []
93
+ tok_steps = []
94
+ tok_vals = []
95
+ vram_steps = []
96
+ vram_vals = []
97
+ grad_steps = []
98
+ grad_vals = []
99
+
100
+ if not os.path.exists(log_path):
101
+ return None
102
+
103
+ with open(log_path, "r") as f:
104
+ for line in f:
105
+ line = line.strip()
106
+ if not line:
107
+ continue
108
+ try:
109
+ entry = json.loads(line)
110
+ except json.JSONDecodeError:
111
+ continue
112
+
113
+ step = entry.get("step")
114
+ if step is None:
115
+ continue
116
+
117
+ if "loss" in entry:
118
+ train_steps.append(step)
119
+ train_loss.append(entry["loss"])
120
+
121
+ if "val_loss" in entry:
122
+ val_steps.append(step)
123
+ val_loss.append(entry["val_loss"])
124
+
125
+ if "lr" in entry:
126
+ lr_steps.append(step)
127
+ lr_vals.append(entry["lr"])
128
+
129
+ if "tok_per_sec" in entry:
130
+ tok_steps.append(step)
131
+ tok_vals.append(entry["tok_per_sec"])
132
+
133
+ if "vram_gb" in entry:
134
+ vram_steps.append(step)
135
+ vram_vals.append(entry["vram_gb"])
136
+
137
+ if "grad_norm" in entry and entry["grad_norm"] is not None:
138
+ grad_steps.append(step)
139
+ grad_vals.append(entry["grad_norm"])
140
+
141
+ return {
142
+ "train": (train_steps, train_loss),
143
+ "val": (val_steps, val_loss),
144
+ "lr": (lr_steps, lr_vals),
145
+ "tok": (tok_steps, tok_vals),
146
+ "vram": (vram_steps, vram_vals),
147
+ "grad": (grad_steps, grad_vals),
148
+ }
149
+
150
+
151
+ def ema_smooth(values: list, alpha: float = 0.9) -> list:
152
+ """Exponential moving average smoothing."""
153
+ if not values:
154
+ return values
155
+ smoothed = [values[0]]
156
+ for v in values[1:]:
157
+ smoothed.append(alpha * smoothed[-1] + (1 - alpha) * v)
158
+ return smoothed
159
+
160
+
161
+ # ------------------------------------------------------------------ #
162
+ # PLOTTING
163
+ # ------------------------------------------------------------------ #
164
+
165
+ def make_dashboard(data_dict: dict, run_names: list, save_path: str = None):
166
+ """
167
+ Renders a multi-panel training dashboard.
168
+
169
+ Args:
170
+ data_dict : dict of run_name -> metrics dict
171
+ run_names : list of run display names
172
+ save_path : if set, saves figure to this path instead of showing
173
+ """
174
+ fig = plt.figure(figsize=(16, 10), facecolor=DARK_BG)
175
+ fig.suptitle(
176
+ "SLLM Training Dashboard",
177
+ fontsize=16,
178
+ fontweight="bold",
179
+ color=TEXT_COLOR,
180
+ y=0.98,
181
+ )
182
+
183
+ # 3x2 grid of panels
184
+ gs = gridspec.GridSpec(3, 2, figure=fig, hspace=0.45, wspace=0.3,
185
+ left=0.06, right=0.97, top=0.93, bottom=0.06)
186
+
187
+ ax_loss = fig.add_subplot(gs[0, 0])
188
+ ax_val = fig.add_subplot(gs[0, 1])
189
+ ax_lr = fig.add_subplot(gs[1, 0])
190
+ ax_tok = fig.add_subplot(gs[1, 1])
191
+ ax_vram = fig.add_subplot(gs[2, 0])
192
+ ax_grad = fig.add_subplot(gs[2, 1])
193
+
194
+ colors = [ACCENT_BLUE, ACCENT_GREEN, ACCENT_ORANGE, ACCENT_PURPLE]
195
+
196
+ has_val = False
197
+ has_vram = False
198
+ has_grad = False
199
+
200
+ for idx, (run_name, data) in enumerate(data_dict.items()):
201
+ if data is None:
202
+ continue
203
+ color = colors[idx % len(colors)]
204
+
205
+ # --- Train loss ------------------------------------------ #
206
+ steps, loss = data["train"]
207
+ if steps:
208
+ smoothed = ema_smooth(loss, alpha=0.92)
209
+ ax_loss.plot(steps, loss, color=color, alpha=0.25, linewidth=0.8)
210
+ ax_loss.plot(steps, smoothed, color=color, alpha=1.0, linewidth=1.8,
211
+ label=run_name)
212
+ # Annotate final loss
213
+ ax_loss.annotate(
214
+ f"{smoothed[-1]:.4f}",
215
+ xy=(steps[-1], smoothed[-1]),
216
+ xytext=(5, 0), textcoords="offset points",
217
+ color=color, fontsize=8, va="center",
218
+ )
219
+
220
+ # --- Val loss -------------------------------------------- #
221
+ vsteps, vloss = data["val"]
222
+ if vsteps:
223
+ has_val = True
224
+ ax_val.plot(vsteps, vloss, color=color, linewidth=2, marker="o",
225
+ markersize=4, label=run_name)
226
+ ax_val.annotate(
227
+ f"{vloss[-1]:.4f}",
228
+ xy=(vsteps[-1], vloss[-1]),
229
+ xytext=(5, 0), textcoords="offset points",
230
+ color=color, fontsize=8, va="center",
231
+ )
232
+
233
+ # --- LR -------------------------------------------------- #
234
+ lsteps, lvals = data["lr"]
235
+ if lsteps:
236
+ ax_lr.plot(lsteps, lvals, color=color, linewidth=1.5, label=run_name)
237
+
238
+ # --- Throughput ------------------------------------------ #
239
+ tsteps, tvals = data["tok"]
240
+ if tsteps:
241
+ avg_tok = np.mean(tvals)
242
+ ax_tok.plot(tsteps, tvals, color=color, alpha=0.6, linewidth=1.0)
243
+ ax_tok.axhline(avg_tok, color=color, linewidth=1.5, linestyle="--",
244
+ label=f"{run_name} (avg {avg_tok:.0f})")
245
+
246
+ # --- VRAM ------------------------------------------------- #
247
+ vsteps2, vvals = data["vram"]
248
+ if vsteps2:
249
+ has_vram = True
250
+ ax_vram.plot(vsteps2, vvals, color=color, linewidth=1.5, label=run_name)
251
+
252
+ # --- Grad norm ------------------------------------------- #
253
+ gsteps, gvals = data["grad"]
254
+ if gsteps:
255
+ has_grad = True
256
+ smoothed_g = ema_smooth(gvals, alpha=0.85)
257
+ ax_grad.plot(gsteps, gvals, color=color, alpha=0.2, linewidth=0.8)
258
+ ax_grad.plot(gsteps, smoothed_g, color=color, linewidth=1.5, label=run_name)
259
+
260
+ # --- Style panels -------------------------------------------- #
261
+ def _style(ax, title, xlabel, ylabel, legend=True):
262
+ ax.set_title(title, fontweight="bold", pad=8)
263
+ ax.set_xlabel(xlabel)
264
+ ax.set_ylabel(ylabel)
265
+ ax.grid(True)
266
+ ax.tick_params(which="both", length=3)
267
+ if legend and ax.get_legend_handles_labels()[0]:
268
+ ax.legend(fontsize=8, loc="upper right")
269
+
270
+ _style(ax_loss, "Training Loss (EMA smoothed)", "Step", "Loss")
271
+ _style(ax_lr, "Learning Rate Schedule", "Step", "LR")
272
+ _style(ax_tok, "Throughput", "Step", "Tokens / sec")
273
+
274
+ if has_val:
275
+ _style(ax_val, "Validation Loss", "Step", "Val Loss")
276
+ else:
277
+ ax_val.text(0.5, 0.5, "No validation data yet",
278
+ ha="center", va="center", transform=ax_val.transAxes,
279
+ color=MUTED_COLOR, fontsize=11)
280
+ ax_val.set_title("Validation Loss", fontweight="bold", pad=8)
281
+
282
+ if has_vram:
283
+ _style(ax_vram, "VRAM Usage", "Step", "GB")
284
+ ax_vram.axhline(4.0, color=ACCENT_RED, linewidth=1, linestyle=":", alpha=0.6, label="4 GB limit")
285
+ ax_vram.legend(fontsize=8)
286
+ else:
287
+ ax_vram.text(0.5, 0.5, "No VRAM data\n(requires CUDA)", ha="center", va="center",
288
+ transform=ax_vram.transAxes, color=MUTED_COLOR, fontsize=11)
289
+ ax_vram.set_title("VRAM Usage", fontweight="bold", pad=8)
290
+
291
+ if has_grad:
292
+ _style(ax_grad, "Gradient Norm (EMA smoothed)", "Step", "Norm")
293
+ else:
294
+ ax_grad.text(0.5, 0.5, "No gradient norm data", ha="center", va="center",
295
+ transform=ax_grad.transAxes, color=MUTED_COLOR, fontsize=11)
296
+ ax_grad.set_title("Gradient Norm", fontweight="bold", pad=8)
297
+
298
+ # LR scientific notation
299
+ ax_lr.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
300
+ ax_lr.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
301
+
302
+ if save_path:
303
+ plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=DARK_BG)
304
+ print(f"[PLOT] Saved to {save_path}")
305
+ else:
306
+ plt.show()
307
+
308
+
309
+ # ------------------------------------------------------------------ #
310
+ # CLI
311
+ # ------------------------------------------------------------------ #
312
+
313
+ def parse_args():
314
+ p = argparse.ArgumentParser(description="SLLM Training Dashboard")
315
+ p.add_argument("--run_dir", nargs="+", default=["runs/run_001"],
316
+ help="One or more run directories to plot")
317
+ p.add_argument("--live", action="store_true",
318
+ help="Refresh plot every --interval seconds (live mode)")
319
+ p.add_argument("--interval", type=int, default=10,
320
+ help="Refresh interval in seconds for --live mode")
321
+ p.add_argument("--save", type=str, default=None,
322
+ help="Save plot to this path instead of showing interactively")
323
+ return p.parse_args()
324
+
325
+
326
+ def main():
327
+ args = parse_args()
328
+
329
+ run_dirs = args.run_dir
330
+ run_names = [Path(d).name for d in run_dirs]
331
+
332
+ def _reload_and_plot():
333
+ data_dict = {}
334
+ for name, run_dir in zip(run_names, run_dirs):
335
+ log_path = os.path.join(run_dir, "train_log.jsonl")
336
+ data = load_log(log_path)
337
+ if data is None:
338
+ print(f"[WARN] No log found at: {log_path}")
339
+ data_dict[name] = data
340
+
341
+ # Check if any data was loaded
342
+ total_steps = sum(
343
+ len(d["train"][0]) for d in data_dict.values() if d
344
+ )
345
+
346
+ if total_steps == 0:
347
+ print("[PLOT] No data logged yet. Waiting...")
348
+ return
349
+
350
+ steps_info = {n: len(d["train"][0]) for n, d in data_dict.items() if d}
351
+ print(f"[PLOT] Plotting {steps_info} train steps")
352
+
353
+ plt.close("all")
354
+ make_dashboard(data_dict, run_names, save_path=args.save)
355
+
356
+ if args.live:
357
+ print(f"[LIVE] Refreshing every {args.interval}s (Ctrl+C to stop)")
358
+ matplotlib.use("TkAgg") if sys.platform == "win32" else None
359
+ try:
360
+ while True:
361
+ _reload_and_plot()
362
+ plt.pause(args.interval)
363
+ except KeyboardInterrupt:
364
+ print("\n[LIVE] Stopped.")
365
+ else:
366
+ _reload_and_plot()
367
+
368
+
369
+ if __name__ == "__main__":
370
+ main()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt — SLLM project
2
+ # Install into the 'pytorch' conda env:
3
+ # conda run -n pytorch pip install -r requirements.txt
4
+
5
+ # Core ML
6
+ torch>=2.3.0
7
+ torchvision
8
+
9
+ # Data
10
+ datasets>=2.14.0 # HuggingFace datasets (streaming)
11
+ tokenizers>=0.15.0 # fast BPE tokenizer
12
+ transformers>=4.40.0 # PreTrainedTokenizerFast
13
+
14
+ # Utilities
15
+ numpy>=1.26.0
16
+ tqdm
17
+ matplotlib # training plots
18
+ rich # pretty terminal output (optional)
19
+
20
+ # Dev
run.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Start training (first run):
2
+
3
+ python train.py ^
4
+ --config 150M ^
5
+ --data_dir tokenizer/data ^
6
+ --batch_size 2 ^
7
+ --grad_accum 16 ^
8
+ --grad_checkpoint ^
9
+ --dtype bf16 ^
10
+ --max_steps 5000 ^
11
+ --run_dir runs/sllm_150m ^
12
+ --log_every 10 ^
13
+ --save_every 500 ^
14
+ --val_every 500 ^
15
+ --val_steps 20 ^
16
+ --warmup_steps 200
17
+
18
+
19
+
20
+ Resume from where you stopped:
21
+
22
+ python train.py --resume --data_dir tokenizer/data --batch_size 2 --grad_accum 16 --grad_checkpoint --dtype bf16 --extra_steps 5000 --run_dir runs/sllm_150m --log_every 10 --save_every 500 --val_every 500 --val_steps 20 --warmup_steps 200
23
+
24
+
25
+
26
+ Plot while training (in a second terminal):
27
+ conda activate pytorch
28
+ cd c:\geetesh\aimldl\projects\sllm
29
+ python plot_training.py --run_dir runs/sllm_150m --live --interval 30
30
+
31
+
32
+ python finetune/prepare_data.py
33
+ python finetune/sft_train.py --base_ckpt runs/sllm_150m/ckpt_0011500.pt --run_dir runs/sllm_150m_chat --max_steps 2500 --batch_size 4 --grad_accum 8 --grad_checkpoint
34
+ python finetune/chat.py --run_dir runs/sllm_150m_chat
test_chatmodel.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test_chatmodel.py — Interactive CLI chat and evaluation for the fine-tuned SLLM chat model.
3
+
4
+ Usage:
5
+ python test_chatmodel.py --run_dir runs/sllm_150m_chat
6
+ python test_chatmodel.py --run_dir runs/sllm_150m_chat --mode sample
7
+
8
+ In interactive mode:
9
+ Type your message and press Enter.
10
+ Special commands:
11
+ /reset Clear conversation history
12
+ /system <text> Change the system prompt
13
+ /quit Exit the chat
14
+ """
15
+
16
+ import os
17
+ import sys
18
+ import argparse
19
+ from pathlib import Path
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.amp import autocast
24
+ from transformers import PreTrainedTokenizerFast
25
+
26
+ # Add project root to path
27
+ PROJECT_ROOT = Path(__file__).resolve().parent
28
+ sys.path.insert(0, str(PROJECT_ROOT))
29
+
30
+ from model.config import SLLM_150M
31
+ from model.model import SLLM
32
+
33
+ DEFAULT_SYSTEM = "You are a helpful, concise assistant."
34
+ DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat")
35
+
36
+
37
+ # ------------------------------------------------------------------ #
38
+ # HELPERS
39
+ # ------------------------------------------------------------------ #
40
+
41
+ def find_latest_ckpt(run_dir: str) -> str:
42
+ """Returns path to the most recent SFT or base checkpoint in run_dir."""
43
+ if not os.path.isdir(run_dir):
44
+ raise FileNotFoundError(f"Run directory '{run_dir}' does not exist.")
45
+
46
+ ckpts = sorted([
47
+ f for f in os.listdir(run_dir)
48
+ if (f.startswith("ckpt_sft_") or f.startswith("ckpt_")) and f.endswith(".pt")
49
+ ])
50
+ if not ckpts:
51
+ raise FileNotFoundError(
52
+ f"No checkpoints found in '{run_dir}'.\n"
53
+ f"Please ensure you have trained the model or point to the correct folder."
54
+ )
55
+ return os.path.join(run_dir, ckpts[-1])
56
+
57
+
58
+ def resize_token_embeddings(model: SLLM, new_vocab_size: int):
59
+ """Resizes the token embeddings matrix to support added special tokens."""
60
+ old_size = model.config.vocab_size
61
+ if new_vocab_size == old_size:
62
+ return
63
+ d_model = model.config.d_model
64
+ device = model.token_emb.weight.device
65
+ dtype = model.token_emb.weight.dtype
66
+ old_weight = model.token_emb.weight.data.clone()
67
+ mean_vec = old_weight.mean(dim=0)
68
+
69
+ new_weight = torch.zeros(new_vocab_size, d_model, dtype=dtype, device=device)
70
+ new_weight[:old_size] = old_weight
71
+ new_weight[old_size:] = mean_vec.unsqueeze(0).expand(new_vocab_size - old_size, -1)
72
+
73
+ new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype)
74
+ new_emb.weight.data = new_weight
75
+ model.token_emb = new_emb
76
+ model.lm_head.weight = model.token_emb.weight
77
+ model.config.vocab_size = new_vocab_size
78
+ print(f" [INFO] Resized model vocab embedding from {old_size:,} to {new_vocab_size:,}")
79
+
80
+
81
+ def load_model_and_tokenizer(run_dir: str, device: torch.device):
82
+ """Loads tokenizer and the latest model checkpoint."""
83
+ # ---- Tokenizer ------------------------------------------------- #
84
+ # Look in finetune/data or tokenizer/fineweb_edu_tokenizer
85
+ data_tok_dir = PROJECT_ROOT / "finetune" / "data"
86
+ base_tok_dir = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"
87
+
88
+ if os.path.exists(data_tok_dir / "tokenizer.json"):
89
+ tok_path = str(data_tok_dir)
90
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
91
+ print(f" Tokenizer: Loaded extended tokenizer from '{tok_path}'")
92
+ elif os.path.exists(base_tok_dir):
93
+ tok_path = str(base_tok_dir)
94
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
95
+ tokenizer.add_special_tokens({
96
+ "additional_special_tokens": ["<|im_start|>", "<|im_end|>"]
97
+ })
98
+ print(f" Tokenizer: Loaded base tokenizer from '{tok_path}' and added ChatML tokens")
99
+ else:
100
+ raise FileNotFoundError("Could not find a tokenizer directory.")
101
+
102
+ # ---- Checkpoint ------------------------------------------------ #
103
+ try:
104
+ ckpt_path = find_latest_ckpt(run_dir)
105
+ except FileNotFoundError:
106
+ # Fall back to base pretraining checkpoint if SFT directory is empty
107
+ print(f" [WARN] No checkpoint found in '{run_dir}'. Trying pretraining base run...")
108
+ base_dir = PROJECT_ROOT / "runs" / "sllm_150m"
109
+ ckpt_path = find_latest_ckpt(str(base_dir))
110
+
111
+ print(f" Loading checkpoint: {ckpt_path}")
112
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
113
+
114
+ # ---- Model ----------------------------------------------------- #
115
+ model = SLLM(SLLM_150M).to(device)
116
+ saved_vocab = ckpt.get("vocab_size", len(tokenizer))
117
+ resize_token_embeddings(model, saved_vocab)
118
+
119
+ model.load_state_dict(ckpt["model_state_dict"])
120
+ model.eval()
121
+
122
+ step = ckpt.get("step", "?")
123
+ loss = ckpt.get("loss", float("nan"))
124
+ return model, tokenizer, ckpt_path, step, loss
125
+
126
+
127
+ # ------------------------------------------------------------------ #
128
+ # PROMPT BUILDING
129
+ # ------------------------------------------------------------------ #
130
+
131
+ def build_prompt(history: list[dict], system_prompt: str,
132
+ tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
133
+ """Formats conversation history as ChatML and tokenizes it."""
134
+ text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
135
+ for turn in history:
136
+ text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
137
+ # Prime the model to respond as assistant
138
+ text += "<|im_start|>assistant\n"
139
+
140
+ ids = tokenizer.encode(text, add_special_tokens=False)
141
+ return torch.tensor([ids], dtype=torch.long)
142
+
143
+
144
+ # ------------------------------------------------------------------ #
145
+ # GENERATION
146
+ # ------------------------------------------------------------------ #
147
+
148
+ @torch.no_grad()
149
+ def generate_response(
150
+ model: SLLM,
151
+ input_ids: torch.Tensor,
152
+ tokenizer: PreTrainedTokenizerFast,
153
+ max_new_tokens: int = 200,
154
+ temperature: float = 0.7,
155
+ top_k: int = 40,
156
+ top_p: float = 0.9,
157
+ device: torch.device = None,
158
+ dtype_torch: torch.dtype = torch.float32,
159
+ use_amp: bool = False,
160
+ ) -> str:
161
+ """Generates a response from the model using top-k/top-p sampling."""
162
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
163
+ eos_id = tokenizer.eos_token_id
164
+
165
+ ids = input_ids.to(device)
166
+ generated = []
167
+
168
+ for _ in range(max_new_tokens):
169
+ # Crop context to model window
170
+ ctx = ids if ids.shape[1] <= model.config.context_length \
171
+ else ids[:, -model.config.context_length:]
172
+
173
+ with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
174
+ logits, _ = model(ctx) # (1, T, V)
175
+
176
+ # Pull last token logits
177
+ logits = logits[:, -1, :]
178
+
179
+ if temperature == 0.0:
180
+ # Greedy
181
+ next_token = logits.argmax(dim=-1, keepdim=True)
182
+ else:
183
+ logits = logits / max(temperature, 1e-8)
184
+
185
+ # Top-k filtering
186
+ if top_k and top_k > 0:
187
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
188
+ logits[logits < v[:, [-1]]] = float("-inf")
189
+
190
+ # Top-p (nucleus) filtering
191
+ if top_p < 1.0:
192
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True)
193
+ cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
194
+ sorted_logits[cumprobs - torch.softmax(sorted_logits, dim=-1) > top_p] = float("-inf")
195
+ logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)
196
+
197
+ probs = torch.softmax(logits, dim=-1)
198
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
199
+
200
+ tok_id = next_token.item()
201
+
202
+ # Stop if end of message or end of stream token is generated
203
+ if tok_id == im_end_id or tok_id == eos_id:
204
+ break
205
+
206
+ generated.append(tok_id)
207
+ ids = torch.cat([ids, next_token], dim=1)
208
+
209
+ return tokenizer.decode(generated, skip_special_tokens=True).strip()
210
+
211
+
212
+ # ------------------------------------------------------------------ #
213
+ # MODES
214
+ # ------------------------------------------------------------------ #
215
+
216
+ def run_interactive(model, tokenizer, device, dtype_torch, use_amp, args):
217
+ system_prompt = args.system
218
+ history = []
219
+
220
+ print("\n" + "=" * 60)
221
+ print(" CHAT MODE (Interactive)")
222
+ print("=" * 60)
223
+ print(f" System prompt : {system_prompt}")
224
+ print(" Commands : /reset to clear memory | /system <prompt> | /quit to exit")
225
+ print("─" * 60 + "\n")
226
+
227
+ while True:
228
+ try:
229
+ user_input = input("You: ").strip()
230
+ except (EOFError, KeyboardInterrupt):
231
+ print("\nBye!")
232
+ break
233
+
234
+ if not user_input:
235
+ continue
236
+
237
+ # Check for commands
238
+ if user_input.lower() in ("/quit", "/exit", "quit", "exit"):
239
+ print("Bye!")
240
+ break
241
+
242
+ if user_input.lower() == "/reset":
243
+ history = []
244
+ print(" [Conversation history reset]\n")
245
+ continue
246
+
247
+ if user_input.lower().startswith("/system "):
248
+ new_sys = user_input[8:].strip()
249
+ if new_sys:
250
+ system_prompt = new_sys
251
+ history = []
252
+ print(f" [System prompt updated. History cleared.]\n")
253
+ continue
254
+
255
+ # Add to history and build ChatML prompt
256
+ history.append({"role": "user", "content": user_input})
257
+ input_ids = build_prompt(history, system_prompt, tokenizer)
258
+
259
+ # Trim conversation window if it exceeds model context length
260
+ while input_ids.shape[1] > model.config.context_length - args.max_new_tokens - 10:
261
+ if len(history) > 2:
262
+ history = history[2:] # Remove oldest user + assistant turn
263
+ input_ids = build_prompt(history, system_prompt, tokenizer)
264
+ else:
265
+ break
266
+
267
+ print("SLLM: ", end="", flush=True)
268
+ response = generate_response(
269
+ model, input_ids, tokenizer,
270
+ max_new_tokens=args.max_new_tokens,
271
+ temperature=args.temperature,
272
+ top_k=args.top_k,
273
+ top_p=args.top_p,
274
+ device=device,
275
+ dtype_torch=dtype_torch,
276
+ use_amp=use_amp,
277
+ )
278
+ print(response + "\n")
279
+ history.append({"role": "assistant", "content": response})
280
+
281
+
282
+ def run_sample(model, tokenizer, device, dtype_torch, use_amp, args):
283
+ sample_prompts = [
284
+ "Hello! Who are you?",
285
+ "What is the capital of France?",
286
+ "Write a quick, 3-line poem about a small robot learning to speak.",
287
+ "Explain gravity in one simple sentence.",
288
+ ]
289
+
290
+ print("\n" + "=" * 60)
291
+ print(" SAMPLE EVALUATION MODE")
292
+ print("=" * 60)
293
+ print(f" System prompt: {args.system}")
294
+ print("─" * 60)
295
+
296
+ for prompt in sample_prompts:
297
+ print(f"\n[PROMPT] : {prompt}")
298
+ history = [{"role": "user", "content": prompt}]
299
+ input_ids = build_prompt(history, args.system, tokenizer)
300
+
301
+ print("[SLLM] : ", end="", flush=True)
302
+ response = generate_response(
303
+ model, input_ids, tokenizer,
304
+ max_new_tokens=args.max_new_tokens,
305
+ temperature=args.temperature,
306
+ top_k=args.top_k,
307
+ top_p=args.top_p,
308
+ device=device,
309
+ dtype_torch=dtype_torch,
310
+ use_amp=use_amp,
311
+ )
312
+ print(response)
313
+ print("\n" + "─" * 60 + "\n")
314
+
315
+
316
+ # ------------------------------------------------------------------ #
317
+ # MAIN
318
+ # ------------------------------------------------------------------ #
319
+
320
+ def main():
321
+ p = argparse.ArgumentParser(description="SLLM Chat Checker")
322
+ p.add_argument("--run_dir", type=str, default=DEFAULT_RUN_DIR)
323
+ p.add_argument("--mode", type=str, default="interactive", choices=["interactive", "sample"])
324
+ p.add_argument("--temperature", type=float, default=0.7)
325
+ p.add_argument("--top_k", type=int, default=40)
326
+ p.add_argument("--top_p", type=float, default=0.9)
327
+ p.add_argument("--max_new_tokens", type=int, default=200)
328
+ p.add_argument("--system", type=str, default=DEFAULT_SYSTEM)
329
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
330
+ args = p.parse_args()
331
+
332
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
333
+ print(f"\nDevice : {device}")
334
+ if device.type == "cuda":
335
+ print(f"GPU : {torch.cuda.get_device_name(0)}")
336
+
337
+ # Precision setup
338
+ use_amp = False
339
+ if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
340
+ dtype_torch = torch.bfloat16
341
+ use_amp = True
342
+ elif args.dtype == "fp16" and device.type == "cuda":
343
+ dtype_torch = torch.float16
344
+ use_amp = True
345
+ else:
346
+ dtype_torch = torch.float32
347
+ print(f"dtype : {args.dtype}")
348
+
349
+ # Load Model and Tokenizer
350
+ try:
351
+ model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
352
+ print(f" Step : {step}")
353
+ if not torch.isnan(torch.tensor(loss)):
354
+ print(f" Loss : {loss:.4f}")
355
+ except Exception as e:
356
+ print(f"\n[ERROR] Failed to load chat model: {e}")
357
+ return
358
+
359
+ if args.mode == "interactive":
360
+ run_interactive(model, tokenizer, device, dtype_torch, use_amp, args)
361
+ elif args.mode == "sample":
362
+ run_sample(model, tokenizer, device, dtype_torch, use_amp, args)
363
+
364
+
365
+ if __name__ == "__main__":
366
+ main()
test_checkpoint.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test_checkpoint.py — Load a checkpoint and run inference / inspect it.
3
+
4
+ QUICK START: Edit the variables in the CONFIG section below, then run:
5
+ python test_checkpoint.py
6
+
7
+ Modes:
8
+ INTERACTIVE — Chat loop: type prompts, model responds.
9
+ SAMPLE — Auto-generate N samples from fixed prompts and exit.
10
+ INSPECT — Just print checkpoint info (no generation).
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import torch
16
+ from torch.amp import autocast
17
+
18
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
19
+ from model.config import SLLM_100M, SLLM_150M, ModelConfig
20
+ from model.model import SLLM
21
+
22
+ # ================================================================== #
23
+ # ✏️ EDIT THESE VARIABLES
24
+ # ================================================================== #
25
+
26
+ # --- Checkpoint to load -------------------------------------------
27
+ # Point to any .pt file inside a runs/ subfolder.
28
+ # Examples:
29
+ # RUN_DIR = "runs/sllm_150m" # loads latest .pt in this folder
30
+ # CKPT_FILE = None # set to a specific filename to override
31
+ # CKPT_FILE = "ckpt_0002000.pt" # or pick a specific step
32
+ RUN_DIR = "runs/sllm_150m"
33
+ CKPT_FILE = None # None = auto-pick latest checkpoint in RUN_DIR
34
+
35
+ # --- Model config --------------------------------------------------
36
+ # Must match what you trained with: "100M" or "150M"
37
+ CONFIG = "150M"
38
+
39
+ # --- Generation settings ------------------------------------------
40
+ MAX_NEW_TOKENS = 100 # tokens to generate per prompt
41
+ TEMPERATURE = 0.8 # 0.0 = greedy, 1.0 = random, 0.8 = balanced
42
+ TOP_K = 50 # keep only top-k logits (0 = disabled)
43
+ TOP_P = 0.95 # nucleus sampling threshold (1.0 = disabled)
44
+
45
+ # --- Mode ---------------------------------------------------------
46
+ # "interactive" : chat loop in the terminal
47
+ # "sample" : run SAMPLE_PROMPTS list and exit
48
+ # "inspect" : just print checkpoint metadata, no generation
49
+ MODE = "sample"
50
+
51
+ # --- Prompts for SAMPLE mode --------------------------------------
52
+ SAMPLE_PROMPTS = [
53
+ "Once upon a time",
54
+ "The meaning of life is",
55
+ "In the year 2050,",
56
+ ]
57
+
58
+ # --- dtype --------------------------------------------------------
59
+ # "bf16" (recommended on RTX cards), "fp16", or "fp32"
60
+ DTYPE = "bf16"
61
+
62
+ # ================================================================== #
63
+ # INTERNALS (no need to edit below)
64
+ # ================================================================== #
65
+
66
+ def resolve_checkpoint(run_dir: str, ckpt_file) -> str:
67
+ """Return full path to the checkpoint file."""
68
+ if ckpt_file is not None:
69
+ path = os.path.join(run_dir, ckpt_file)
70
+ if not os.path.isfile(path):
71
+ raise FileNotFoundError(f"Checkpoint not found: {path}")
72
+ return path
73
+
74
+ # Auto-pick latest
75
+ if not os.path.isdir(run_dir):
76
+ raise FileNotFoundError(f"Run directory not found: {run_dir}")
77
+ ckpts = sorted([
78
+ f for f in os.listdir(run_dir)
79
+ if f.startswith("ckpt_") and f.endswith(".pt")
80
+ ])
81
+ if not ckpts:
82
+ raise FileNotFoundError(f"No checkpoints found in: {run_dir}")
83
+ return os.path.join(run_dir, ckpts[-1])
84
+
85
+
86
+ def load_model(ckpt_path: str, config_name: str, device, dtype_torch):
87
+ """Load model weights from checkpoint."""
88
+ cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}
89
+ cfg = cfg_map[config_name]
90
+
91
+ print(f"\n Config : {cfg}")
92
+ model = SLLM(cfg).to(device)
93
+
94
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
95
+
96
+ # Prefer config_name stored in checkpoint (override CLI if available)
97
+ ckpt_cfg_name = ckpt.get("config_name", config_name)
98
+ if ckpt_cfg_name != config_name:
99
+ print(f" [WARN] Checkpoint config_name='{ckpt_cfg_name}' "
100
+ f"differs from CONFIG='{config_name}'. "
101
+ f"Using checkpoint's config: '{ckpt_cfg_name}'")
102
+ cfg = cfg_map[ckpt_cfg_name]
103
+ model = SLLM(cfg).to(device)
104
+
105
+ model.load_state_dict(ckpt["model_state_dict"])
106
+ model.eval()
107
+
108
+ step = ckpt.get("step", "?")
109
+ loss = ckpt.get("loss", float("nan"))
110
+ return model, cfg, step, loss
111
+
112
+
113
+ @torch.no_grad()
114
+ def generate(model, prompt_ids: list[int], cfg: ModelConfig, device,
115
+ dtype_torch, use_amp: bool,
116
+ max_new_tokens: int, temperature: float,
117
+ top_k: int, top_p: float) -> list[int]:
118
+ """Token-by-token autoregressive generation."""
119
+ ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
120
+ ctx_len = cfg.context_length
121
+
122
+ for _ in range(max_new_tokens):
123
+ # Crop to context window
124
+ ids_crop = ids[:, -ctx_len:]
125
+
126
+ with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
127
+ logits, _ = model(ids_crop)
128
+
129
+ # Logits for the last position
130
+ logits = logits[:, -1, :] # (1, vocab)
131
+
132
+ if temperature == 0.0:
133
+ # Greedy
134
+ next_id = logits.argmax(dim=-1, keepdim=True)
135
+ else:
136
+ logits = logits / temperature
137
+
138
+ # Top-K filtering
139
+ if top_k > 0:
140
+ vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
141
+ logits[logits < vals[:, [-1]]] = float("-inf")
142
+
143
+ # Top-P (nucleus) filtering
144
+ if top_p < 1.0:
145
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True)
146
+ cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
147
+ # Remove tokens with cumulative prob > top_p
148
+ sorted_logits[cumprobs - torch.softmax(sorted_logits, dim=-1) > top_p] = float("-inf")
149
+ logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)
150
+
151
+ probs = torch.softmax(logits, dim=-1)
152
+ next_id = torch.multinomial(probs, num_samples=1)
153
+
154
+ ids = torch.cat([ids, next_id], dim=1)
155
+
156
+ return ids[0].tolist()
157
+
158
+
159
+ def char_tokenize(text: str) -> list[int]:
160
+ """
161
+ Fallback character-level tokenizer.
162
+ Your model uses a real tokenizer — swap this out with yours if available.
163
+ Each char maps to its Unicode code point (capped at vocab_size - 1).
164
+ """
165
+ return [min(ord(c), 31_999) for c in text]
166
+
167
+
168
+ def char_detokenize(ids: list[int]) -> str:
169
+ """Reverse of char_tokenize."""
170
+ return "".join(chr(i) if 32 <= i < 127 else "?" for i in ids)
171
+
172
+
173
+ def try_load_sentencepiece(tokenizer_dir="tokenizer/fineweb_edu_tokenizer"):
174
+ """Load the HuggingFace PreTrainedTokenizerFast used during training."""
175
+ try:
176
+ from transformers import PreTrainedTokenizerFast
177
+ tok = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
178
+ encode = lambda text: tok.encode(text)
179
+ decode = lambda ids: tok.decode(ids, skip_special_tokens=True)
180
+ print(f" Tokenizer: HuggingFace tokenizer loaded from '{tokenizer_dir}'")
181
+ print(f" vocab_size={tok.vocab_size:,} eos_id={tok.eos_token_id}")
182
+ return encode, decode
183
+ except Exception as e:
184
+ print(f" Tokenizer: Could not load HuggingFace tokenizer ({e})")
185
+ print(" Falling back to char tokenizer — output will be garbled!")
186
+ return char_tokenize, char_detokenize
187
+
188
+
189
+ def run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode):
190
+ print("\n" + "="*60)
191
+ print(" INTERACTIVE MODE (type 'quit' or 'exit' to stop)")
192
+ print("="*60)
193
+ print(f" max_new_tokens : {MAX_NEW_TOKENS}")
194
+ print(f" temperature : {TEMPERATURE}")
195
+ print(f" top_k / top_p : {TOP_K} / {TOP_P}")
196
+ print()
197
+
198
+ while True:
199
+ try:
200
+ prompt = input("Prompt> ").strip()
201
+ except (EOFError, KeyboardInterrupt):
202
+ print("\n Exiting.")
203
+ break
204
+
205
+ if prompt.lower() in ("quit", "exit", ""):
206
+ print(" Exiting.")
207
+ break
208
+
209
+ prompt_ids = encode(prompt)
210
+ output_ids = generate(
211
+ model, prompt_ids, cfg, device, dtype_torch, use_amp,
212
+ MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P,
213
+ )
214
+ # Only show the newly generated tokens
215
+ new_ids = output_ids[len(prompt_ids):]
216
+ print(f"\nGenerated: {decode(new_ids)}\n")
217
+
218
+
219
+ def run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode):
220
+ print("\n" + "="*60)
221
+ print(" SAMPLE MODE")
222
+ print("="*60)
223
+ for i, prompt in enumerate(SAMPLE_PROMPTS, 1):
224
+ print(f"\n[{i}] Prompt : {prompt!r}")
225
+ prompt_ids = encode(prompt)
226
+ output_ids = generate(
227
+ model, prompt_ids, cfg, device, dtype_torch, use_amp,
228
+ MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P,
229
+ )
230
+ new_ids = output_ids[len(prompt_ids):]
231
+ print(f" Output : {decode(new_ids)}")
232
+
233
+
234
+ def run_inspect(ckpt_path, step, loss, cfg):
235
+ print("\n" + "="*60)
236
+ print(" INSPECT MODE")
237
+ print("="*60)
238
+ print(f" Checkpoint : {ckpt_path}")
239
+ print(f" Step : {step}")
240
+ print(f" Loss : {loss:.4f}" if isinstance(loss, float) else f" Loss: {loss}")
241
+ print(f" Config : {cfg}")
242
+ print(f" Params : {cfg.count_params()/1e6:.1f}M")
243
+ print()
244
+
245
+
246
+ def main():
247
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
248
+ print(f"\nDevice : {device}")
249
+ if device.type == "cuda":
250
+ print(f"GPU : {torch.cuda.get_device_name(0)}")
251
+ print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
252
+
253
+ # dtype setup
254
+ use_amp = False
255
+ if DTYPE == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
256
+ dtype_torch = torch.bfloat16
257
+ use_amp = True
258
+ elif DTYPE == "fp16" and device.type == "cuda":
259
+ dtype_torch = torch.float16
260
+ use_amp = True
261
+ else:
262
+ dtype_torch = torch.float32
263
+ print(f"dtype : {DTYPE}")
264
+
265
+ # Resolve checkpoint path
266
+ ckpt_path = resolve_checkpoint(RUN_DIR, CKPT_FILE)
267
+ print(f"\nCheckpoint: {ckpt_path}")
268
+
269
+ # Load model
270
+ model, cfg, step, loss = load_model(ckpt_path, CONFIG, device, dtype_torch)
271
+ print(f" Loaded : step={step}, loss={loss:.4f}")
272
+ print(f" Params : {model.count_params()/1e6:.1f}M")
273
+
274
+ if MODE == "inspect":
275
+ run_inspect(ckpt_path, step, loss, cfg)
276
+ return
277
+
278
+ # Load tokenizer
279
+ encode, decode = try_load_sentencepiece()
280
+
281
+ if MODE == "interactive":
282
+ run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode)
283
+ elif MODE == "sample":
284
+ run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode)
285
+ else:
286
+ print(f" [ERROR] Unknown MODE: '{MODE}'. Use 'interactive', 'sample', or 'inspect'.")
287
+
288
+
289
+ if __name__ == "__main__":
290
+ main()
tokenizer/bpe.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import Tokenizer, AddedToken
2
+ from tokenizers.models import BPE
3
+ from tokenizers.trainers import BpeTrainer
4
+ from tokenizers.pre_tokenizers import Sequence, ByteLevel
5
+ from tokenizers.decoders import ByteLevel as ByteLevelDecoder
6
+
7
+ from pretokenizer import get_pretokenizer
8
+
9
+ VOCAB_SIZE = 32_000
10
+ MIN_FREQUENCY = 3
11
+ SPECIAL_TOKENS = ["<|endoftext|>"]
12
+
13
+ def build_tokenizer() -> Tokenizer:
14
+ """
15
+ Builds and returns an untrained tokenizer with all components configured.
16
+ Call .train_from_iterator() or .train() on the returned object to train it.
17
+
18
+ Pipeline:
19
+ Raw text
20
+ -> Normalizer (handled externally in our normalize() fn)
21
+ -> Pre-tokenizer (custom regex splits + byte level conversion)
22
+ -> BPE Model (learns merge rules during training)
23
+ -> Decoder (reverses byte level for human readable output)
24
+ """
25
+
26
+ # ---- 1. BPE Model ------------------------------------------------
27
+ # unk_token=None because byte-level means we NEVER have unknowns
28
+ # every character always maps to at least one byte token
29
+ model = BPE(
30
+ unk_token=None, # no unknown token - byte fallback handles everything
31
+ byte_fallback=True, # unknown chars represented as <0xXX> byte tokens
32
+ # e.g. ∇ -> <0xE2><0x88><0x87>
33
+ )
34
+
35
+ tokenizer = Tokenizer(model)
36
+
37
+ # ---- 2. Pre-tokenizer --------------------------------------------
38
+ # Sequence chains two pre-tokenizers in order:
39
+ #
40
+ # Step A: Our custom regex splits text into meaningful chunks
41
+ # (contractions, abbreviations, numbers, operators etc.)
42
+ #
43
+ # Step B: ByteLevel converts each chunk's characters to their
44
+ # byte representation using a 256-char printable alphabet
45
+ # e.g. é (bytes 0xC3 0xA9) -> "é"
46
+ #
47
+ # add_prefix_space=False because our regex already handles
48
+ # whitespace explicitly as its own token category
49
+ tokenizer.pre_tokenizer = Sequence([
50
+ get_pretokenizer(), # Step A - our regex
51
+ ByteLevel(add_prefix_space=False), # Step B - byte conversion
52
+ ])
53
+
54
+ # ---- 3. Decoder --------------------------------------------------
55
+ # Reverses the ByteLevel encoding so output is human readable
56
+ # Without this tokenizer.decode() would return "é" instead of "é"
57
+ tokenizer.decoder = ByteLevelDecoder()
58
+
59
+ return tokenizer
60
+
61
+
62
+ # ------------------------------------------------------------------ #
63
+ # TRAINER CONFIG
64
+ # ------------------------------------------------------------------ #
65
+
66
+ def build_trainer() -> BpeTrainer:
67
+ """
68
+ Configures the BPE trainer.
69
+
70
+ vocab_size breakdown:
71
+ 256 base byte tokens (one per possible byte value, always present)
72
+ + 31,743 learned BPE merge tokens
73
+ + 1 special token (<|endoftext|>)
74
+ = 32,000 total
75
+
76
+ The trainer automatically accounts for the 256 base tokens,
77
+ so setting vocab_size=32_000 gives you the right final count.
78
+ """
79
+ return BpeTrainer(
80
+ vocab_size=VOCAB_SIZE,
81
+ min_frequency=MIN_FREQUENCY,
82
+ special_tokens=SPECIAL_TOKENS,
83
+
84
+ # show_progress shows a progress bar during training
85
+ show_progress=True,
86
+
87
+ # initial_alphabet tells the trainer to include all 256 bytes
88
+ # as base tokens before any merges happen
89
+ # This is what guarantees byte-level fallback works
90
+ initial_alphabet=ByteLevel.alphabet(),
91
+ )
92
+
93
+ # CONVENIENCE: get special token IDs after training
94
+
95
+ def get_special_token_ids(tokenizer: Tokenizer) -> dict:
96
+ """
97
+ Returns a dict of special token string -> token ID.
98
+ Call this AFTER training to get the final IDs.
99
+
100
+ Example:
101
+ ids = get_special_token_ids(tokenizer)
102
+ eot_id = ids["<|endoftext|>"] # typically 0
103
+ """
104
+ return {
105
+ token: tokenizer.token_to_id(token)
106
+ for token in SPECIAL_TOKENS
107
+ }
108
+
109
+ # QUICK SANITY CHECK
110
+
111
+ if __name__ == "__main__":
112
+ print("Building tokenizer...")
113
+ tokenizer = build_tokenizer()
114
+
115
+ print("Building trainer...")
116
+ trainer = build_trainer()
117
+
118
+ # Verify pre-tokenizer chain is set up correctly
119
+ print("\nPre-tokenizer chain:")
120
+ print(f" {tokenizer.pre_tokenizer}")
121
+
122
+ # Verify decoder is set
123
+ print(f"\nDecoder:")
124
+ print(f" {tokenizer.decoder}")
125
+
126
+ # Verify trainer config
127
+ print(f"\nTrainer config:")
128
+ print(f" vocab_size : {trainer.vocab_size}")
129
+ print(f" min_frequency : {trainer.min_frequency}")
130
+ print(f" special_tokens: {trainer.special_tokens}")
131
+ print(f" base alphabet : {len(ByteLevel.alphabet())} byte tokens")
132
+
133
+ print("\nAll good - ready to train.")
134
+ print("Next step: pipe FineWeb-Edu text into tokenizer.train_from_iterator()")
tokenizer/fineweb_edu_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/fineweb_edu_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>"
5
+ }
tokenizer/fineweb_edu_tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/fineweb_edu_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<|endoftext|>",
4
+ "eos_token": "<|endoftext|>",
5
+ "model_max_length": 1024,
6
+ "pad_token": "<|endoftext|>",
7
+ "padding_side": "right",
8
+ "tokenizer_class": "TokenizersBackend",
9
+ "truncation_side": "right",
10
+ "unk_token": null
11
+ }
tokenizer/normalizer.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import html
3
+ import unicodedata
4
+
5
+ def normalization(text):
6
+
7
+ # Strip HTML tags (note: won't catch multiline tags)
8
+ text = re.sub(r'<[^>]+>', ' ', text)
9
+
10
+ # HTML entity decoding
11
+ text = html.unescape(text)
12
+
13
+ # NFC normalization
14
+ text = unicodedata.normalize('NFC', text)
15
+
16
+ # Control characters — including \x7f (DEL)
17
+ text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
18
+
19
+ # Unicode line/paragraph separators → newline (structural, not removed)
20
+ text = re.sub(r'[\u2028\u2029]', '\n', text)
21
+
22
+ # Zero-width characters
23
+ text = re.sub(r'[\u200b\u200c\u200d\ufeff\u00ad]', '', text)
24
+
25
+ # Replacement character
26
+ text = text.replace('\ufffd', '')
27
+
28
+ # Normalize line endings
29
+ text = text.replace('\r\n', '\n')
30
+ text = text.replace('\r', '\n')
31
+
32
+ # Collapse spaces only (preserve leading tabs for indentation)
33
+ text = re.sub(r' +', ' ', text)
34
+
35
+ # Trailing spaces/tabs at end of line
36
+ text = re.sub(r'[ \t]+\n', '\n', text)
37
+
38
+ # Collapse excess newlines
39
+ text = re.sub(r'\n{3,}', '\n\n', text)
40
+
41
+ text = text.strip()
42
+ return text
tokenizer/post_processor.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers.processors import TemplateProcessing
2
+ from tokenizers import Tokenizer
3
+
4
+
5
+ # ------------------------------------------------------------------ #
6
+ # POST-PROCESSOR
7
+ # Runs after BPE encoding, appends <|endoftext|> to every sequence
8
+ # ------------------------------------------------------------------ #
9
+
10
+ def add_post_processor(tokenizer: Tokenizer) -> Tokenizer:
11
+ """
12
+ Adds a post-processor to the tokenizer that appends
13
+ <|endoftext|> to every encoded sequence.
14
+
15
+ Must be called AFTER training because we need the real
16
+ token ID of <|endoftext|> from the trained vocab.
17
+
18
+ Args:
19
+ tokenizer: a trained Tokenizer object
20
+
21
+ Returns:
22
+ The same tokenizer with post-processor attached
23
+ """
24
+
25
+ # Get the real ID from the trained vocab
26
+ # This is why we can only do this after training
27
+ eot_id = tokenizer.token_to_id("<|endoftext|>")
28
+
29
+ if eot_id is None:
30
+ raise ValueError(
31
+ "<|endoftext|> not found in vocab. "
32
+ "Make sure the tokenizer is trained before adding post-processor."
33
+ )
34
+
35
+ # TemplateProcessing defines the final sequence structure
36
+ # using a simple template syntax:
37
+ #
38
+ # $A -> the encoded sequence (single sequence)
39
+ # $A $B -> two sequences (for pair tasks like QA)
40
+ # <|endoftext|>:ID -> insert this special token with its ID
41
+ #
42
+ # Our template:
43
+ # single : [tokens...] <|endoftext|>
44
+ # pair : [tokens_A...] <|endoftext|> [tokens_B...] <|endoftext|>
45
+ #
46
+ # pair template handles future use cases like
47
+ # question-context pairs without needing to change the tokenizer
48
+
49
+ tokenizer.post_processor = TemplateProcessing(
50
+ single="$A <|endoftext|>:0",
51
+ pair="$A <|endoftext|>:0 $B:1 <|endoftext|>:0",
52
+ special_tokens=[
53
+ ("<|endoftext|>", eot_id),
54
+ ],
55
+ )
56
+
57
+ print(f"Post-processor added: <|endoftext|> (ID: {eot_id}) appended to sequences")
58
+
59
+ return tokenizer
60
+
61
+
62
+ # ------------------------------------------------------------------ #
63
+ # VERIFICATION
64
+ # ------------------------------------------------------------------ #
65
+
66
+ def verify_post_processor(tokenizer: Tokenizer):
67
+ """
68
+ Verifies the post-processor is working correctly.
69
+ Checks that <|endoftext|> appears at end of every encoded sequence.
70
+ """
71
+
72
+ eot_id = tokenizer.token_to_id("<|endoftext|>")
73
+ eot_token = "<|endoftext|>"
74
+
75
+ print("\n" + "="*60)
76
+ print(" POST-PROCESSOR VERIFICATION")
77
+ print("="*60 + "\n")
78
+
79
+ test_cases = [
80
+ # Single documents
81
+ "The mitochondria is the powerhouse of the cell.",
82
+ "CO2 levels rose by 1.5e-3 ppm.",
83
+ # Short edge cases
84
+ "Hi.",
85
+ "42",
86
+ ]
87
+
88
+ all_passed = True
89
+
90
+ for text in test_cases:
91
+ encoded = tokenizer.encode(text)
92
+ last_token = encoded.tokens[-1]
93
+ last_id = encoded.ids[-1]
94
+ passed = last_token == eot_token and last_id == eot_id
95
+
96
+ if not passed:
97
+ all_passed = False
98
+
99
+ status = "PASS" if passed else "FAIL"
100
+ print(f"[{status}] {repr(text)}")
101
+ print(f" tokens : {encoded.tokens}")
102
+ print(f" last : {last_token!r} (ID: {last_id})")
103
+ print()
104
+
105
+ # Verify pair encoding
106
+ encoded_pair = tokenizer.encode("question here", "answer here")
107
+ pair_ids = encoded_pair.ids
108
+ eot_positions = [i for i, id in enumerate(pair_ids) if id == eot_id]
109
+
110
+ print(f"Pair encoding test:")
111
+ print(f" tokens : {encoded_pair.tokens}")
112
+ print(f" eot positions: {eot_positions}")
113
+ print(f" expected : 2 eot tokens (one after each sequence)")
114
+ print(f" [{'PASS' if len(eot_positions) == 2 else 'FAIL'}]")
115
+
116
+ print(f"\nAll tests passed: {all_passed}")
117
+
118
+
119
+ # ------------------------------------------------------------------ #
120
+ # HOW THIS FITS INTO THE FULL PIPELINE
121
+ # ------------------------------------------------------------------ #
122
+
123
+ # The correct order when building your full tokenizer:
124
+ #
125
+ # 1. build_tokenizer() <- sets up model + pre-tokenizer + decoder
126
+ # 2. train_from_iterator() <- trains BPE, assigns real vocab IDs
127
+ # 3. add_post_processor() <- NOW we can add post-processor (needs real IDs)
128
+ # 4. tokenizer.save() <- saves everything including post-processor
129
+ #
130
+ # Loading later:
131
+ # tokenizer = Tokenizer.from_file("fineweb_edu_tokenizer.json")
132
+ # <- post-processor is automatically restored, no extra steps
133
+
134
+
135
+ if __name__ == "__main__":
136
+ import sys
137
+
138
+ # Load a trained tokenizer from disk to test
139
+ # Pass the path as argument: python post_processor.py fineweb_edu_tokenizer.json
140
+ # Or it will try the default path
141
+
142
+ path = sys.argv[1] if len(sys.argv) > 1 else "fineweb_edu_tokenizer.json"
143
+
144
+ print(f"Loading tokenizer from: {path}")
145
+ tokenizer = Tokenizer.from_file(path)
146
+
147
+ tokenizer = add_post_processor(tokenizer)
148
+ verify_post_processor(tokenizer)
149
+
150
+ # Save with post-processor included
151
+ tokenizer.save(path)
152
+ print(f"\nTokenizer re-saved with post-processor to: {path}")
tokenizer/pretokenizer.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from tokenizers.pre_tokenizers import PreTokenizer, Split
3
+ from tokenizers import Regex
4
+
5
+ # Each category is defined separately so its easy to understand, modify, or debug individually
6
+
7
+
8
+ # 1. Contractions
9
+ # Matches: 's 't 're 've 'll 'm 'd
10
+ # Example: "don't" -> ["don", "'t"]
11
+ CONTRACTIONS = r"'(?:s|t|re|ve|ll|m|d)"
12
+
13
+ # 2. Abbreviations
14
+ # Matches: letter(s) separated by dots, optional trailing dot
15
+ # Example: "U.S.A" -> ["U.S.A"]
16
+ # "e.g." -> ["e.g."]
17
+ # "Ph.D" -> ["Ph.D"]
18
+ # \b = word boundary, ensures we dont partially match inside a word
19
+ ABBREVIATIONS = r"\b[A-Za-z](?:\.[A-Za-z])+\.?"
20
+
21
+ # 3. Scientific Notation
22
+ # Matches: number, optional decimal, e/E, optional sign, exponent
23
+ # Example: "1.5e-3" -> ["1.5e-3"]
24
+ # "3e10" -> ["3e10"]
25
+ # "2.0E+4" -> ["2.0E+4"]
26
+ # Must come BEFORE decimals otherwise "1.5" in "1.5e-3" matches first
27
+ SCIENTIFIC = r"\d+\.?\d*[eE][+-]?\d+"
28
+
29
+ # 4. Decimal Numbers
30
+ # Matches: digits, dot, digits
31
+ # Example: "3.14" -> ["3.14"]
32
+ # "0.001" -> ["0.001"]
33
+ # Must come BEFORE integers otherwise "3" in "3.14" matches first
34
+ DECIMALS = r"\d+\.\d+"
35
+
36
+ # 5. Integers
37
+ # Matches: any sequence of digits
38
+ # Example: "42" -> ["42"]
39
+ # "1984" -> ["1984"]
40
+ # Comes last among numbers since scientific and decimal match first
41
+ INTEGERS = r"\d+"
42
+
43
+ # 6. Multi-character Operators
44
+ # Matches: common programming operators that are 2 characters
45
+ # Example: "==" -> ["=="] "!=" -> ["!="]
46
+ # "->" -> ["->"] "+=" -> ["+="]
47
+ # Must come BEFORE single punctuation catch-all
48
+ # [-+*/]= matches +=, -=, *=, /= in one pattern
49
+ OPERATORS = r"==|!=|->|<=|>=|\*\*|//|[-+*/]="
50
+
51
+ # 7. Snake Case Identifiers
52
+ # Matches: words that contain underscores (code identifiers)
53
+ # Example: "snake_case" -> ["snake_case"]
54
+ # "var_name_2" -> ["var_name_2"]
55
+ # "_private" -> ["_private"]
56
+ # Must come BEFORE regular words otherwise "snake" matches first
57
+ SNAKE_CASE = r"[A-Za-z_][A-Za-z0-9_]*"
58
+
59
+ # 8. Regular Unicode Words
60
+ # Matches: any sequence of word characters (letters, digits)
61
+ # \w+ in unicode mode covers non-english letters too
62
+ # Example: "hello" -> ["hello"]
63
+ # "café" -> ["café"]
64
+ WORDS = r"\w+"
65
+
66
+ # 9. Whitespace
67
+ # Newlines are matched separately from spaces/tabs
68
+ # This preserves document structure (paragraph breaks etc.)
69
+ # Example: "\n\n" -> ["\n\n"] " " -> [" "]
70
+ WHITESPACE = r"\n+|[ \t]+"
71
+
72
+ # 10. Punctuation Catch-all
73
+ # Matches any single non-whitespace character that nothing above caught
74
+ # Example: "!" -> ["!"] "@" -> ["@"] "." -> ["."]
75
+ PUNCTUATION = r"[^\s]"
76
+
77
+ # ------------------------------------------------------------------ #
78
+ # Combine all patterns in ORDER - first match wins
79
+ # ------------------------------------------------------------------ #
80
+
81
+ PRETOKENIZER_PATTERN = "|".join([
82
+ CONTRACTIONS, # 1 - most specific first
83
+ ABBREVIATIONS, # 2 - before plain words
84
+ SCIENTIFIC, # 3 - before decimals
85
+ DECIMALS, # 4 - before integers
86
+ INTEGERS, # 5
87
+ OPERATORS, # 6 - before single punctuation
88
+ SNAKE_CASE, # 7 - before plain words
89
+ WORDS, # 8
90
+ WHITESPACE, # 9
91
+ PUNCTUATION, # 10 - catch everything else
92
+ ])
93
+
94
+
95
+ def get_pretokenizer():
96
+ """
97
+ Returns a HuggingFace Split pre-tokenizer using our custom regex.
98
+
99
+ Split behavior:
100
+ - pattern : the regex to split/match on
101
+ - behavior : "removed" -> splits on matches and discards them
102
+ "isolated" -> splits on matches and keeps them as tokens
103
+ "merged_with_previous" / "merged_with_next"
104
+
105
+ We use "isolated" because we WANT to keep whitespace, operators,
106
+ punctuation etc. as their own tokens rather than discard them.
107
+ """
108
+ return Split(
109
+ pattern=Regex(PRETOKENIZER_PATTERN),
110
+ behavior="isolated",
111
+ invert=True # invert=True means: match the pattern and KEEP matches as tokens
112
+ # (rather than treating matches as split points)
113
+ )
114
+
115
+
116
+ # ------------------------------------------------------------------ #
117
+ # Quick test - run this file directly to verify behavior
118
+ # ------------------------------------------------------------------ #
119
+
120
+ if __name__ == "__main__":
121
+ from tokenizers import Tokenizer
122
+ from tokenizers.models import BPE
123
+
124
+ # Build a bare tokenizer just to test the pre-tokenizer
125
+ tokenizer = Tokenizer(BPE())
126
+ tokenizer.pre_tokenizer = get_pretokenizer()
127
+
128
+ test_cases = [
129
+ # Contractions
130
+ ("Contractions", "don't she'll they've"),
131
+ # Abbreviations
132
+ ("Abbreviations", "U.S.A has a Ph.D e.g. this"),
133
+ # Scientific notation
134
+ ("Scientific", "the value is 1.5e-3 and 2.0E+4"),
135
+ # Decimals
136
+ ("Decimals", "pi is 3.14159 and e is 2.718"),
137
+ # Integers
138
+ ("Integers", "there are 1000 students in 2024"),
139
+ # Operators
140
+ ("Operators", "if x==0 or y!=1 then z+=2"),
141
+ # Snake case
142
+ ("Snake case", "my_variable and snake_case_name"),
143
+ # Mixed real world
144
+ ("Real world", "The CO2 level is 415.2 ppm\n\nSee e.g. Smith et al."),
145
+ # Code like
146
+ ("Code-like", "def my_func(x):\n return x**2 + 1"),
147
+ ]
148
+
149
+ print(f"\n{'='*60}")
150
+ print(f" PRE-TOKENIZER TEST")
151
+ print(f"{'='*60}\n")
152
+
153
+ for label, text in test_cases:
154
+ tokens = tokenizer.pre_tokenizer.pre_tokenize_str(text)
155
+ token_strings = [t[0] for t in tokens] # tokens are (string, offset) tuples
156
+ print(f"[{label}]")
157
+ print(f" Input : {repr(text)}")
158
+ print(f" Tokens : {token_strings}")
159
+ print()
tokenizer/tempCodeRunnerFile.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ with open(os.path.join(save_dir, "special_tokens_map.json"), "w") as f:
3
+ json.dump(special_tokens_map, f, indent=2)
4
+
5
+ print("special_tokens_map.json written manually")
tokenizer/tokenize_dataset.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tokenize_dataset.py — Parallel tokenization pipeline
3
+
4
+ Architecture:
5
+ Main thread : stream HF dataset → filter → normalize → batch texts
6
+ Worker pool : N_WORKERS processes, each with own loaded tokenizer,
7
+ tokenize batches concurrently using ProcessPoolExecutor
8
+ Main thread : collect results IN ORDER → route train/val → flush shards
9
+
10
+ Why this is faster:
11
+ Old code: stream → [normalize] → [tokenize 1000 docs, 1 CPU] → write
12
+ New code: stream → [normalize] → [tokenize 1000 docs × N cores] → write
13
+
14
+ On 12-core machine: expect 6-10× speedup on tokenization step.
15
+ Bottleneck shifts to HF streaming bandwidth, not CPU.
16
+
17
+ Notes:
18
+ - Workers are initialized ONCE with the tokenizer loaded (no repeated disk reads)
19
+ - Results collected in SUBMISSION ORDER so train/val routing is deterministic
20
+ - Sliding window of MAX_PENDING futures keeps all cores busy without
21
+ unbounded memory growth
22
+ - Ctrl+C safe: flushes remaining buffers before exit
23
+ """
24
+
25
+ import os
26
+ import sys
27
+ import time
28
+ import warnings
29
+ import numpy as np
30
+ from collections import deque
31
+ from concurrent.futures import ProcessPoolExecutor
32
+ from datasets import load_dataset
33
+ from transformers import PreTrainedTokenizerFast, logging as hf_logging
34
+ from tqdm import tqdm
35
+
36
+ # Import normalizer from same directory
37
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
+ from normalizer import normalization
39
+
40
+ hf_logging.set_verbosity_error()
41
+ warnings.filterwarnings("ignore")
42
+
43
+
44
+ # ------------------------------------------------------------------ #
45
+ # CONSTANTS
46
+ # ------------------------------------------------------------------ #
47
+
48
+ DATASET_NAME = "HuggingFaceFW/fineweb-edu"
49
+ DATASET_SUBSET = "CC-MAIN-2014-49"
50
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
51
+ TOKENIZER_DIR = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")
52
+ DATA_DIR = os.path.join(SCRIPT_DIR, "data")
53
+
54
+ MIN_QUALITY = 3
55
+ SHARD_SIZE = 100_000_000 # tokens per shard (~190 MB at uint16)
56
+ BATCH_SIZE = 2_000 # docs per tokenization task (↑ from 1000)
57
+ VAL_RATIO = 100 # every 100th accepted doc → val
58
+ SHUFFLE_BUFFER = 10_000
59
+ MIN_DOC_LENGTH = 100
60
+ DTYPE = np.uint16
61
+ MAX_TOKENS = 3_200_000_000
62
+
63
+ # Parallel workers: leave 2 cores for OS + HF streaming
64
+ N_WORKERS = max(1, os.cpu_count() - 2)
65
+
66
+ # How many tokenization futures to keep in-flight at once
67
+ # = N_WORKERS × 2 keeps the pipeline full without excess memory
68
+ MAX_PENDING = N_WORKERS * 2
69
+
70
+
71
+ # ------------------------------------------------------------------ #
72
+ # WORKER PROCESS — loaded once per process at startup
73
+ # ------------------------------------------------------------------ #
74
+
75
+ # Module-level tokenizer in each worker process
76
+ _worker_tokenizer = None
77
+
78
+
79
+ def _worker_init(tokenizer_dir: str):
80
+ """
81
+ Called ONCE per worker process at startup.
82
+ Loads the tokenizer into the worker's global state.
83
+ Subsequent calls to _tokenize_worker_fn reuse this loaded tokenizer.
84
+ """
85
+ global _worker_tokenizer
86
+ import warnings
87
+ from transformers import PreTrainedTokenizerFast, logging as hf_log
88
+ hf_log.set_verbosity_error()
89
+ warnings.filterwarnings("ignore")
90
+ _worker_tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
91
+
92
+
93
+ def _tokenize_worker_fn(texts: list) -> list:
94
+ """
95
+ Tokenizes a batch of pre-normalized texts in a worker process.
96
+ Returns a list of token-ID lists, one per document.
97
+ Each doc ends with <|endoftext|> (added by add_special_tokens=True).
98
+
99
+ Args:
100
+ texts : list of normalized strings (already filtered, normalized)
101
+
102
+ Returns:
103
+ list of list[int] — token IDs per document
104
+ """
105
+ global _worker_tokenizer
106
+ encoded = _worker_tokenizer(
107
+ texts,
108
+ add_special_tokens = True, # appends <|endoftext|>
109
+ truncation = False, # keep full document
110
+ padding = False, # no padding (we pack shards)
111
+ return_attention_mask= False, # not needed
112
+ )
113
+ return encoded["input_ids"]
114
+
115
+
116
+ # ------------------------------------------------------------------ #
117
+ # SHARD HELPERS
118
+ # ------------------------------------------------------------------ #
119
+
120
+ def get_shard_path(split: str, shard_idx: int) -> str:
121
+ return os.path.join(DATA_DIR, f"{split}_{shard_idx:03d}.bin")
122
+
123
+
124
+ def save_shard(tokens: list, split: str, shard_idx: int):
125
+ arr = np.array(tokens, dtype=DTYPE)
126
+ path = get_shard_path(split, shard_idx)
127
+ arr.tofile(path)
128
+ size_mb = arr.nbytes / 1024 / 1024
129
+ tqdm.write(f" saved {split}_{shard_idx:03d}.bin | {len(tokens):,} tokens | {size_mb:.1f} MB")
130
+
131
+
132
+ # ------------------------------------------------------------------ #
133
+ # ROUTE BATCH RESULTS → train / val buffers
134
+ # ------------------------------------------------------------------ #
135
+
136
+ def route_results(
137
+ all_ids : list,
138
+ doc_count_start: int,
139
+ train_buffer : list,
140
+ val_buffer : list,
141
+ train_tokens : int,
142
+ val_tokens : int,
143
+ total_tokens : int,
144
+ ) -> tuple:
145
+ """
146
+ Routes tokenized docs to train or val buffer by doc index.
147
+ Returns updated (train_buffer, val_buffer, train_tokens, val_tokens, total_tokens, batch_tok_count).
148
+ """
149
+ batch_tok_count = 0
150
+
151
+ for i, ids in enumerate(all_ids):
152
+ doc_num = doc_count_start + i
153
+
154
+ if doc_num % VAL_RATIO == 0: # every 100th doc → val
155
+ val_buffer.extend(ids)
156
+ val_tokens += len(ids)
157
+ else:
158
+ train_buffer.extend(ids)
159
+ train_tokens += len(ids)
160
+
161
+ total_tokens += len(ids)
162
+ batch_tok_count += len(ids)
163
+
164
+ return train_buffer, val_buffer, train_tokens, val_tokens, total_tokens, batch_tok_count
165
+
166
+
167
+ # ------------------------------------------------------------------ #
168
+ # MAIN PARALLEL TOKENIZATION PIPELINE
169
+ # ------------------------------------------------------------------ #
170
+
171
+ def tokenize_dataset():
172
+ os.makedirs(DATA_DIR, exist_ok=True)
173
+
174
+ print(f"Loading tokenizer from: {TOKENIZER_DIR}")
175
+ print(f" workers : {N_WORKERS} of {os.cpu_count()} CPUs")
176
+
177
+ print(f"\nLoading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
178
+ ds = load_dataset(
179
+ DATASET_NAME,
180
+ name = DATASET_SUBSET,
181
+ split = "train",
182
+ streaming = True,
183
+ ).shuffle(buffer_size=SHUFFLE_BUFFER, seed=42)
184
+
185
+ # ---- State ------------------------------------------------------ #
186
+ train_buffer = []
187
+ val_buffer = []
188
+ train_shard = 0
189
+ val_shard = 0
190
+ total_docs = 0
191
+ skipped_docs = 0
192
+ total_tokens = 0
193
+ train_tokens = 0
194
+ val_tokens = 0
195
+ batch_texts = [] # accumulating next batch to submit
196
+ batch_doc_start = 0 # doc index at start of current batch_texts
197
+
198
+ # pending: deque of (future, doc_count_start)
199
+ # We always pop from the LEFT (oldest submission) to preserve order
200
+ pending = deque()
201
+ cap_reached = False
202
+
203
+ # ---- Progress bars ----------------------------------------------- #
204
+ token_bar = tqdm(
205
+ total=MAX_TOKENS,
206
+ desc="tokens",
207
+ unit="tok",
208
+ unit_scale=True,
209
+ unit_divisor=1000,
210
+ colour="green",
211
+ position=0,
212
+ )
213
+ doc_bar = tqdm(
214
+ desc="docs ",
215
+ unit="doc",
216
+ unit_scale=True,
217
+ colour="blue",
218
+ position=1,
219
+ )
220
+
221
+ t_start = time.time()
222
+
223
+ # ------------------------------------------------------------------ #
224
+ # DRAIN HELPER — collect the oldest pending future and process it
225
+ # ------------------------------------------------------------------ #
226
+
227
+ def drain_one():
228
+ nonlocal train_buffer, val_buffer, train_shard, val_shard
229
+ nonlocal total_tokens, train_tokens, val_tokens
230
+
231
+ if not pending:
232
+ return False
233
+
234
+ future, doc_start = pending.popleft()
235
+ all_ids = future.result() # blocks until this task done
236
+
237
+ (train_buffer, val_buffer,
238
+ train_tokens, val_tokens,
239
+ total_tokens, batch_tok) = route_results(
240
+ all_ids, doc_start,
241
+ train_buffer, val_buffer,
242
+ train_tokens, val_tokens, total_tokens,
243
+ )
244
+
245
+ token_bar.update(batch_tok)
246
+ token_bar.set_postfix({
247
+ "train": f"{train_tokens/1e9:.2f}B",
248
+ "val" : f"{val_tokens/1e6:.0f}M",
249
+ "shards": train_shard,
250
+ })
251
+
252
+ # Flush train shards
253
+ while len(train_buffer) >= SHARD_SIZE:
254
+ save_shard(train_buffer[:SHARD_SIZE], "train", train_shard)
255
+ train_buffer = train_buffer[SHARD_SIZE:]
256
+ train_shard += 1
257
+
258
+ # Flush val shards
259
+ while len(val_buffer) >= SHARD_SIZE:
260
+ save_shard(val_buffer[:SHARD_SIZE], "val", val_shard)
261
+ val_buffer = val_buffer[SHARD_SIZE:]
262
+ val_shard += 1
263
+
264
+ return True
265
+
266
+ # ------------------------------------------------------------------ #
267
+ # MAIN LOOP with ProcessPoolExecutor
268
+ # ------------------------------------------------------------------ #
269
+
270
+ print(f"\nStarting tokenization...")
271
+ print(f" token target : {MAX_TOKENS:,}")
272
+ print(f" shard size : {SHARD_SIZE:,} tokens")
273
+ print(f" batch size : {BATCH_SIZE} docs")
274
+ print(f" val ratio : every {VAL_RATIO}th doc")
275
+ print(f" quality : int_score >= {MIN_QUALITY}\n")
276
+
277
+ with ProcessPoolExecutor(
278
+ max_workers = N_WORKERS,
279
+ initializer = _worker_init,
280
+ initargs = (TOKENIZER_DIR,),
281
+ ) as executor:
282
+
283
+ for doc in ds:
284
+
285
+ # ---- Quality filter ------------------------------------ #
286
+ if doc["int_score"] < MIN_QUALITY:
287
+ skipped_docs += 1
288
+ doc_bar.set_postfix({"skipped": skipped_docs})
289
+ continue
290
+
291
+ # ---- Length + normalize -------------------------------- #
292
+ text = doc["text"]
293
+ if len(text) < MIN_DOC_LENGTH:
294
+ skipped_docs += 1
295
+ doc_bar.set_postfix({"skipped": skipped_docs})
296
+ continue
297
+
298
+ text = normalization(text)
299
+ if len(text) < MIN_DOC_LENGTH:
300
+ skipped_docs += 1
301
+ doc_bar.set_postfix({"skipped": skipped_docs})
302
+ continue
303
+
304
+ batch_texts.append(text)
305
+ total_docs += 1
306
+ doc_bar.update(1)
307
+
308
+ # ---- Submit batch when full ---------------------------- #
309
+ if len(batch_texts) == BATCH_SIZE:
310
+ # Record which doc index this batch starts at
311
+ doc_start = total_docs - BATCH_SIZE
312
+
313
+ future = executor.submit(_tokenize_worker_fn, batch_texts)
314
+ pending.append((future, doc_start))
315
+ batch_texts = []
316
+
317
+ # ---- Backpressure: drain oldest if queue full ------- #
318
+ # This prevents unbounded memory accumulation
319
+ # while keeping all N_WORKERS busy
320
+ while len(pending) >= MAX_PENDING:
321
+ drain_one()
322
+
323
+ # ---- Check token cap -------------------------------- #
324
+ if total_tokens >= MAX_TOKENS:
325
+ tqdm.write(f"\nToken cap reached: {total_tokens:,} tokens from {total_docs:,} docs")
326
+ cap_reached = True
327
+ break
328
+
329
+ # ---- Submit any remaining partial batch -------------------- #
330
+ if batch_texts and not cap_reached:
331
+ doc_start = total_docs - len(batch_texts)
332
+ future = executor.submit(_tokenize_worker_fn, batch_texts)
333
+ pending.append((future, doc_start))
334
+
335
+ # ---- Drain all remaining pending futures ------------------- #
336
+ while pending:
337
+ drain_one()
338
+
339
+ # ---- Close progress bars --------------------------------------- #
340
+ token_bar.close()
341
+ doc_bar.close()
342
+
343
+ # ---- Save remaining partial shards ----------------------------- #
344
+ if train_buffer:
345
+ save_shard(train_buffer, "train", train_shard)
346
+ train_shard += 1
347
+
348
+ if val_buffer:
349
+ save_shard(val_buffer, "val", val_shard)
350
+ val_shard += 1
351
+
352
+ # ---- Final summary --------------------------------------------- #
353
+ print(f"\n{'='*60}")
354
+ print(f" TOKENIZATION COMPLETE")
355
+ print(f"{'='*60}")
356
+ print(f" total docs : {total_docs:,}")
357
+ print(f" skipped docs : {skipped_docs:,}")
358
+ print(f" total tokens : {total_tokens:,}")
359
+ print(f" train tokens : {train_tokens:,}")
360
+ print(f" val tokens : {val_tokens:,}")
361
+ print(f" train shards : {train_shard}")
362
+ print(f" val shards : {val_shard}")
363
+ print(f" data dir : {os.path.abspath(DATA_DIR)}")
364
+
365
+
366
+ # ------------------------------------------------------------------ #
367
+ # LOAD SHARDS DURING TRAINING (unchanged)
368
+ # ------------------------------------------------------------------ #
369
+
370
+ def load_shard(split: str, shard_idx: int) -> np.ndarray:
371
+ """
372
+ Loads a shard as a memory-mapped numpy array.
373
+ The full shard never loads into RAM at once.
374
+
375
+ Usage during training:
376
+ shard = load_shard("train", 0)
377
+ chunk = shard[i : i + 1024]
378
+ """
379
+ path = get_shard_path(split, shard_idx)
380
+ return np.memmap(path, dtype=DTYPE, mode="r")
381
+
382
+
383
+ # ------------------------------------------------------------------ #
384
+ # ENTRY POINT
385
+ # ------------------------------------------------------------------ #
386
+
387
+ if __name__ == "__main__":
388
+ # Windows requires this guard for multiprocessing with spawn start method
389
+ tokenize_dataset()
tokenizer/traintokenizer.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from tokenizers import Tokenizer
3
+
4
+ # Import our components
5
+ from normalizer import normalization # our normalize function
6
+ from bpe import build_tokenizer, build_trainer, get_special_token_ids
7
+
8
+ from post_processor import add_post_processor
9
+ # ------------------------------------------------------------------ #
10
+ # CONSTANTS
11
+ # ------------------------------------------------------------------ #
12
+
13
+ DATASET_NAME = "HuggingFaceFW/fineweb-edu"
14
+ DATASET_SUBSET = "CC-MAIN-2014-49"
15
+ MIN_QUALITY = 3 # int_score >= 3 only
16
+ MAX_TOKENS = 25_000_000 # ~100M characters worth, enough for BPE training
17
+ # FineWeb-Edu tokens avg 4-5 chars each
18
+ MIN_DOC_LENGTH = 100 # skip very short documents, likely boilerplate
19
+ import os
20
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
21
+ SAVE_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer")
22
+
23
+
24
+ # ------------------------------------------------------------------ #
25
+ # DATA GENERATOR
26
+ # ------------------------------------------------------------------ #
27
+
28
+ def fineweb_edu_iterator(
29
+ max_tokens: int = MAX_TOKENS,
30
+ min_quality: int = MIN_QUALITY,
31
+ min_length: int = MIN_DOC_LENGTH,
32
+ ):
33
+ """
34
+ Streams FineWeb-Edu documents, filters by quality,
35
+ normalizes text, and yields clean strings for BPE training.
36
+
37
+ Args:
38
+ max_tokens : stop after consuming this many tokens total
39
+ min_quality : only yield docs with int_score >= this value
40
+ min_length : skip docs shorter than this many characters
41
+
42
+ Yields:
43
+ str: normalized, clean document text
44
+ """
45
+
46
+ print(f"Loading dataset stream: {DATASET_NAME} / {DATASET_SUBSET}")
47
+ ds = load_dataset(
48
+ DATASET_NAME,
49
+ name=DATASET_SUBSET,
50
+ split="train",
51
+ streaming=True,
52
+ )
53
+
54
+ tokens_seen = 0 # running total of tokens consumed
55
+ docs_yielded = 0 # how many docs passed all filters
56
+ docs_skipped = 0 # how many docs were filtered out
57
+
58
+ for doc in ds:
59
+
60
+ # ---- Stop condition ----------------------------------------
61
+ if tokens_seen >= max_tokens:
62
+ break
63
+
64
+ # ---- Quality filter ----------------------------------------
65
+ # int_score is 0-5, we want educational quality >= 3
66
+ if doc["int_score"] < min_quality:
67
+ docs_skipped += 1
68
+ continue
69
+
70
+ # ---- Extract and normalize ---------------------------------
71
+ text = doc["text"]
72
+
73
+ # Skip very short documents before normalization
74
+ # (saves compute on boilerplate/empty docs)
75
+ if len(text) < min_length:
76
+ docs_skipped += 1
77
+ continue
78
+
79
+ # Run our normalization pipeline
80
+ text = normalization(text)
81
+
82
+ # Skip if normalization made it too short
83
+ # (e.g. doc was mostly HTML tags or control chars)
84
+ if len(text) < min_length:
85
+ docs_skipped += 1
86
+ continue
87
+
88
+ # ---- Track progress ----------------------------------------
89
+ tokens_seen += doc["token_count"]
90
+ docs_yielded += 1
91
+
92
+ # Log progress every 100k documents
93
+ if docs_yielded % 100_000 == 0:
94
+ print(
95
+ f" docs yielded: {docs_yielded:,} | "
96
+ f"docs skipped: {docs_skipped:,} | "
97
+ f"tokens seen: {tokens_seen:,} / {max_tokens:,} "
98
+ f"({100 * tokens_seen / max_tokens:.1f}%)"
99
+ )
100
+
101
+ yield text
102
+
103
+ # Final stats
104
+ print(f"\nStream complete:")
105
+ print(f" docs yielded : {docs_yielded:,}")
106
+ print(f" docs skipped : {docs_skipped:,}")
107
+ print(f" tokens seen : {tokens_seen:,}")
108
+
109
+
110
+ # ------------------------------------------------------------------ #
111
+ # TRAINING
112
+ # ------------------------------------------------------------------ #
113
+
114
+ def train_tokenizer() -> Tokenizer:
115
+ """
116
+ Builds, trains, and saves the tokenizer.
117
+
118
+ Returns:
119
+ Trained Tokenizer object
120
+ """
121
+
122
+ # Build untrained tokenizer and trainer
123
+ tokenizer = build_tokenizer()
124
+ trainer = build_trainer()
125
+
126
+ print("\nStarting BPE training...")
127
+ print(f" vocab size : {trainer.vocab_size:,}")
128
+ print(f" min frequency : {trainer.min_frequency}")
129
+ print(f" quality filter: int_score >= {MIN_QUALITY}")
130
+ print(f" max tokens : {MAX_TOKENS:,}\n")
131
+
132
+ # train_from_iterator expects an iterable of strings
133
+ # our generator yields one clean document string at a time
134
+ tokenizer.train_from_iterator(
135
+ iterator=fineweb_edu_iterator(),
136
+ trainer=trainer,
137
+ length=MAX_TOKENS, # optional hint for progress bar accuracy
138
+ )
139
+
140
+ print("\nTraining complete.")
141
+
142
+ tokenizer = add_post_processor(tokenizer)
143
+
144
+ # Print special token IDs
145
+ ids = get_special_token_ids(tokenizer)
146
+ print(f"\nSpecial token IDs:")
147
+ for token, token_id in ids.items():
148
+ print(f" {token} -> {token_id}")
149
+
150
+
151
+ # Save tokenizer to disk
152
+ tokenizer.save(f"{SAVE_PATH}.json")
153
+ print(f"\nTokenizer saved to: {SAVE_PATH}.json")
154
+
155
+ return tokenizer
156
+
157
+
158
+ # ------------------------------------------------------------------ #
159
+ # QUICK VERIFICATION after training
160
+ # ------------------------------------------------------------------ #
161
+
162
+ def verify_tokenizer(tokenizer: Tokenizer):
163
+ """
164
+ Runs a few quick checks after training to verify correctness.
165
+ """
166
+ print("\n" + "="*60)
167
+ print(" TOKENIZER VERIFICATION")
168
+ print("="*60 + "\n")
169
+
170
+ test_cases = [
171
+ "The mitochondria is the powerhouse of the cell.",
172
+ "CO2 levels rose by 1.5e-3 ppm in 2024.",
173
+ "def compute_loss(y_pred, y_true):\n return (y_pred - y_true)**2",
174
+ "U.S.A has a Ph.D program e.g. at MIT.",
175
+ "don't they've she'll",
176
+ "∇f(x) = 0 is a necessary condition.", # tests byte fallback
177
+ ]
178
+
179
+ for text in test_cases:
180
+ encoded = tokenizer.encode(text)
181
+ decoded = tokenizer.decode(encoded.ids)
182
+ n_tokens = len(encoded.ids)
183
+
184
+ print(f"Input : {repr(text)}")
185
+ print(f"Tokens : {encoded.tokens}")
186
+ print(f"IDs : {encoded.ids}")
187
+ print(f"N tokens: {n_tokens}")
188
+ print(f"Decoded : {repr(decoded)}")
189
+ print(f"Lossless: {text == decoded}")
190
+ print()
191
+
192
+ # Verify vocab size
193
+ vocab_size = tokenizer.get_vocab_size()
194
+ print(f"Final vocab size: {vocab_size:,}")
195
+
196
+ # Verify endoftext token exists
197
+ eot_id = tokenizer.token_to_id("<|endoftext|>")
198
+ print(f"<|endoftext|> ID: {eot_id}")
199
+
200
+
201
+ # ------------------------------------------------------------------ #
202
+ # ENTRY POINT
203
+ # ------------------------------------------------------------------ #
204
+
205
+ if __name__ == "__main__":
206
+ tokenizer = train_tokenizer()
207
+ verify_tokenizer(tokenizer)
tokenizer/wrap_tokenizer.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import Tokenizer
2
+ from transformers import PreTrainedTokenizerFast
3
+ import json
4
+ import os
5
+
6
+ # ------------------------------------------------------------------ #
7
+ # CONSTANTS
8
+ # ------------------------------------------------------------------ #
9
+
10
+ import os
11
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
12
+ TOKENIZER_PATH = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer.json")
13
+ SAVE_DIR = os.path.join(SCRIPT_DIR, "fineweb_edu_tokenizer") # output folder
14
+ MODEL_MAX_LENGTH = 1024 # context length
15
+ PADDING_SIDE = "right" # causal LM standard
16
+
17
+
18
+ # ------------------------------------------------------------------ #
19
+ # WRAP
20
+ # ------------------------------------------------------------------ #
21
+
22
+ def wrap_tokenizer(
23
+ tokenizer_path: str = TOKENIZER_PATH,
24
+ save_dir: str = SAVE_DIR,
25
+ ) -> PreTrainedTokenizerFast:
26
+ """
27
+ Wraps a trained HuggingFace Tokenizer as a PreTrainedTokenizerFast.
28
+
29
+ This gives us:
30
+ - datasets.map() compatibility for bulk tokenization
31
+ - HuggingFace Trainer + DataCollator compatibility
32
+ - Automatic padding, truncation, attention masks
33
+ - from_pretrained() loading support
34
+ - return_tensors="pt" for PyTorch tensors
35
+
36
+ Args:
37
+ tokenizer_path : path to trained tokenizer .json file
38
+ save_dir : folder to save the wrapped tokenizer
39
+
40
+ Returns:
41
+ PreTrainedTokenizerFast ready for training
42
+ """
43
+
44
+ print(f"Loading trained tokenizer from: {tokenizer_path}")
45
+ base_tokenizer = Tokenizer.from_file(tokenizer_path)
46
+
47
+ # ---- Wrap --------------------------------------------------------
48
+ # We map <|endoftext|> to all three roles:
49
+ #
50
+ # eos_token - end of sequence marker, used during generation
51
+ # to know when to stop
52
+ #
53
+ # bos_token - beginning of sequence, GPT-2 style uses eos
54
+ # for both since there is no separate BOS token
55
+ #
56
+ # pad_token - safe to reuse eos here because we are packing
57
+ # sequences and will never actually pad during
58
+ # pretraining. Defined so HuggingFace doesn't
59
+ # complain about missing pad token
60
+ #
61
+ # unk_token - None because byte-level means no unknowns ever
62
+
63
+ tokenizer = PreTrainedTokenizerFast(
64
+ tokenizer_object=base_tokenizer,
65
+
66
+ # Special token mappings
67
+ eos_token="<|endoftext|>",
68
+ bos_token="<|endoftext|>",
69
+ pad_token="<|endoftext|>",
70
+ unk_token=None,
71
+
72
+ # Context length
73
+ model_max_length=MODEL_MAX_LENGTH,
74
+
75
+ # Padding behavior
76
+ padding_side=PADDING_SIDE,
77
+
78
+ # Truncation side - truncate from the right
79
+ # (keep the beginning of the sequence, drop the end)
80
+ truncation_side="right",
81
+ )
82
+
83
+ tokenizer.add_special_tokens({
84
+ "eos_token": "<|endoftext|>",
85
+ "bos_token": "<|endoftext|>",
86
+ "pad_token": "<|endoftext|>",
87
+ })
88
+ special_tokens_map = {
89
+ "bos_token": "<|endoftext|>",
90
+ "eos_token": "<|endoftext|>",
91
+ "pad_token": "<|endoftext|>",
92
+ }
93
+ os.makedirs(save_dir, exist_ok=True)
94
+
95
+ with open(os.path.join(save_dir, "special_tokens_map.json"), "w") as f:
96
+ json.dump(special_tokens_map, f, indent=2)
97
+
98
+ print("special_tokens_map.json written manually")
99
+ # ---- Save --------------------------------------------------------
100
+ # Saves three files to save_dir/:
101
+ # tokenizer.json - the trained BPE tokenizer
102
+ # tokenizer_config.json - max length, pad token, special tokens
103
+ # special_tokens_map.json - maps eos/bos/pad to actual tokens
104
+ tokenizer.save_pretrained(save_dir)
105
+ print(f"Tokenizer saved to: {save_dir}/")
106
+ print(f" tokenizer.json")
107
+ print(f" tokenizer_config.json")
108
+ print(f" special_tokens_map.json")
109
+
110
+ return tokenizer
111
+
112
+
113
+ # ------------------------------------------------------------------ #
114
+ # VERIFICATION
115
+ # ------------------------------------------------------------------ #
116
+
117
+ def verify_wrapped_tokenizer(tokenizer: PreTrainedTokenizerFast):
118
+ """
119
+ Verifies the wrapped tokenizer behaves correctly.
120
+ Tests encoding, decoding, padding, truncation and batch encoding.
121
+ """
122
+
123
+ print("\n" + "="*60)
124
+ print(" WRAPPED TOKENIZER VERIFICATION")
125
+ print("="*60 + "\n")
126
+
127
+ eot_id = tokenizer.eos_token_id
128
+
129
+ # ---- 1. Basic config -----------------------------------------
130
+ print("Config:")
131
+ print(f" vocab size : {tokenizer.vocab_size:,}")
132
+ print(f" model_max_length : {tokenizer.model_max_length}")
133
+ print(f" padding_side : {tokenizer.padding_side}")
134
+ print(f" eos_token : {tokenizer.eos_token!r} (ID: {eot_id})")
135
+ print(f" bos_token : {tokenizer.bos_token!r}")
136
+ print(f" pad_token : {tokenizer.pad_token!r} (ID: {tokenizer.pad_token_id})")
137
+ print(f" unk_token : {tokenizer.unk_token!r}")
138
+ print()
139
+
140
+ # ---- 2. Basic encode/decode ----------------------------------
141
+ text = "The mitochondria is the powerhouse of the cell."
142
+ encoded = tokenizer(text)
143
+ decoded = tokenizer.decode(encoded["input_ids"])
144
+
145
+ print("Basic encode/decode:")
146
+ print(f" input : {repr(text)}")
147
+ print(f" input_ids: {encoded['input_ids']}")
148
+ print(f" decoded : {repr(decoded)}")
149
+ print()
150
+
151
+ # ---- 3. Padding ----------------------------------------------
152
+ # Batch of two sequences with different lengths
153
+ # shorter one should be right-padded to match the longer
154
+ batch = [
155
+ "Short sentence.",
156
+ "This is a much longer sentence that has more tokens in it.",
157
+ ]
158
+
159
+ encoded_batch = tokenizer(
160
+ batch,
161
+ padding=True, # pad to longest in batch
162
+ return_tensors="pt", # return PyTorch tensors
163
+ )
164
+
165
+ print("Batch padding (right padding):")
166
+ print(f" input_ids shape : {encoded_batch['input_ids'].shape}")
167
+ print(f" attention_mask shape : {encoded_batch['attention_mask'].shape}")
168
+ print(f" input_ids[0] : {encoded_batch['input_ids'][0].tolist()}")
169
+ print(f" input_ids[1] : {encoded_batch['input_ids'][1].tolist()}")
170
+ print(f" attention_mask[0] : {encoded_batch['attention_mask'][0].tolist()}")
171
+ print()
172
+
173
+ # ---- 4. Truncation -------------------------------------------
174
+ # Sequence longer than model_max_length should be truncated
175
+ long_text = "word " * 2000 # 2000 words >> 1024 tokens
176
+ encoded_long = tokenizer(
177
+ long_text,
178
+ truncation=True,
179
+ max_length=MODEL_MAX_LENGTH,
180
+ )
181
+
182
+ print("Truncation:")
183
+ print(f" input length : {len(long_text.split())} words")
184
+ print(f" token count : {len(encoded_long['input_ids'])} (max: {MODEL_MAX_LENGTH})")
185
+ print(f" truncated : {len(encoded_long['input_ids']) <= MODEL_MAX_LENGTH}")
186
+ print()
187
+
188
+ # ---- 5. Load from disk and verify ----------------------------
189
+ print("Loading from disk:")
190
+ reloaded = PreTrainedTokenizerFast.from_pretrained(SAVE_DIR)
191
+ reloaded_ids = reloaded(text)["input_ids"]
192
+ original_ids = encoded["input_ids"]
193
+ match = reloaded_ids == original_ids
194
+
195
+ print(f" from_pretrained() : OK")
196
+ print(f" IDs match original: {match}")
197
+
198
+
199
+ # ------------------------------------------------------------------ #
200
+ # ENTRY POINT
201
+ # ------------------------------------------------------------------ #
202
+
203
+ if __name__ == "__main__":
204
+ tokenizer = wrap_tokenizer()
205
+ verify_wrapped_tokenizer(tokenizer)
206
+
207
+ print("\n" + "="*60)
208
+ print(" USAGE EXAMPLES")
209
+ print("="*60)
210
+ print("""
211
+ # Load anywhere with one line
212
+ from transformers import PreTrainedTokenizerFast
213
+ tokenizer = PreTrainedTokenizerFast.from_pretrained("fineweb_edu_tokenizer")
214
+
215
+ # Single encode
216
+ ids = tokenizer("Hello world")["input_ids"]
217
+
218
+ # Batch encode with padding and tensors
219
+ batch = tokenizer(
220
+ ["sentence one", "sentence two"],
221
+ padding=True,
222
+ truncation=True,
223
+ max_length=1024,
224
+ return_tensors="pt",
225
+ )
226
+
227
+ # Decode
228
+ text = tokenizer.decode(ids, skip_special_tokens=True)
229
+
230
+ # Get eos token id (use as document separator when packing)
231
+ eot_id = tokenizer.eos_token_id
232
+ """)
tokenizer_walkthrough.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Walkthrough: SLLM Custom BPE Tokenizer
2
+
3
+ This document explains the architecture, execution pipeline, and design choices of the custom **Byte-Pair Encoding (BPE)** tokenizer implemented in the `tokenizer/` directory of the `sllm` project.
4
+
5
+ ---
6
+
7
+ ## 🏗️ Overall Architecture & Pipeline
8
+
9
+ The SLLM tokenizer is a custom-built BPE tokenizer tailored for pre-training small language models on the educational subset of HuggingFace's **FineWeb-Edu** dataset. It integrates custom text normalization, a regex-based pre-tokenization strategy, standard BPE training with byte-level fallback, and packaging utility scripts for high-performance training.
10
+
11
+ ```mermaid
12
+ graph TD
13
+ A[Raw Text Stream] --> B[normalizer.py: Normalization]
14
+ B --> C[pretokenizer.py: Custom Regex Split]
15
+ C --> D[bpe.py: Byte-Level Encoding]
16
+ D --> E[traintokenizer.py: BPE Trainer]
17
+ E --> F[post_processor.py: Template Post-Processing]
18
+ F --> G[wrap_tokenizer.py: PreTrainedTokenizerFast Wrapper]
19
+ G --> H[tokenize_dataset.py: Packed binary .bin Shards]
20
+ ```
21
+
22
+ ---
23
+
24
+ ## 📁 Component-by-Component Breakdown
25
+
26
+ ### 1. `normalizer.py` (Text Normalization)
27
+ Before any splitting occurs, the raw input text is standardized and cleaned to eliminate noise while preserving syntax and code structure:
28
+ * **HTML Stripping & Decoding**: Removes HTML tags using regex and decodes HTML entities (e.g., `&amp;` $\rightarrow$ `&`).
29
+ * **Unicode Normalization**: Performs **NFC** normalization to ensure characters like accented letters are represented consistently.
30
+ * **Noise Removal**: Eliminates raw control characters, zero-width characters (e.g., zero-width spaces/joins), and the Unicode replacement character (`\ufffd`).
31
+ * **Whitespace Control**:
32
+ * Collapses multiple consecutive spaces into a single space (preserving leading tabs for code indentation).
33
+ * Cleans trailing whitespaces at the end of lines.
34
+ * Collapses 3 or more consecutive newlines into exactly two newlines (`\n\n`) to preserve paragraph structure.
35
+
36
+ ---
37
+
38
+ ### 2. `pretokenizer.py` (Custom Regex Segmentation)
39
+ Instead of relying on standard GPT-2/Llama pre-tokenization, this model implements a custom, ordered, priority-based regex pre-tokenizer:
40
+ 1. **Contractions**: `'s`, `'t`, `'re`, `'ve`, `'ll`, `'m`, `'d`.
41
+ 2. **Abbreviations**: Acronyms and shorthand (e.g., `U.S.A`, `e.g.`, `Ph.D`).
42
+ 3. **Scientific Notation**: E.g., `1.5e-3`, `3e10`, `2.0E+4` (evaluated *before* decimals to avoid splitting).
43
+ 4. **Decimal Numbers**: E.g., `3.14` (evaluated *before* integers).
44
+ 5. **Integers**: E.g., `42`, `1984`.
45
+ 6. **Multi-character Operators**: Common coding operators like `==`, `!=`, `->`, `<=`, `>=`, `**`, `//`, `+=`, `-=`, `*=`, `/=`.
46
+ 7. **Snake Case Identifiers**: E.g., `snake_case`, `_private` (evaluated *before* plain words for clean code representation).
47
+ 8. **Regular Unicode Words**: Alphanumeric words covering non-English languages.
48
+ 9. **Whitespace**: Preserves sequences of spaces/tabs separately from newlines to keep structural formatting.
49
+ 10. **Punctuation Catch-all**: Individual punctuation characters.
50
+
51
+ > [!NOTE]
52
+ > The pre-tokenizer uses HuggingFace's `Split` pre-tokenizer with `behavior="isolated"` and `invert=True`, meaning matched strings are isolated and kept as distinct, individual tokens instead of being discarded as delimiters.
53
+
54
+ ---
55
+
56
+ ### 3. `bpe.py` (BPE Model Configuration)
57
+ Defines the base tokenizer pipeline:
58
+ * **Byte Fallback**: Configures the BPE model with `unk_token=None` and `byte_fallback=True`. This guarantees that *every* character maps to at least one byte-level token, resulting in **zero out-of-vocabulary (OOV)** issues.
59
+ * **Pre-Tokenizer Chain**: Sequentially runs the custom Regex pre-tokenizer followed by `ByteLevel(add_prefix_space=False)` to translate character segments to their corresponding byte values.
60
+ * **Decoder**: Instantiates the standard `ByteLevelDecoder` to reverse byte conversions, allowing human-readable decoded strings.
61
+ * **Trainer Config**: Builds a `BpeTrainer` specifying a vocabulary of `32,000` tokens, minimum merge frequency of `3`, and initial alphabet containing all `256` bytes to enforce the fallback capability.
62
+
63
+ ---
64
+
65
+ ### 4. `post_processor.py` (Sequence Endings)
66
+ Once BPE rules have been learned and vocabulary IDs are assigned:
67
+ * Attaches `TemplateProcessing` to automatically append `<|endoftext|>` to every sequence.
68
+ * For single documents, it maps to `[tokens...] <|endoftext|>`.
69
+ * For sequence pairs (useful in downstream tasks like Question-Answering), it automatically maps to `[tokens_A...] <|endoftext|> [tokens_B...] <|endoftext|>`.
70
+
71
+ ---
72
+
73
+ ### 5. `traintokenizer.py` (BPE Training Loop)
74
+ * Streams the educational subset of `HuggingFaceFW/fineweb-edu` (`CC-MAIN-2014-49` split).
75
+ * Filters out low-quality documents (requires educational score `int_score >= 3`) and documents shorter than 100 characters.
76
+ * Feeds documents iteratively into BPE training via `train_from_iterator()`.
77
+ * Adds the post-processor and runs comprehensive verification checks against edge cases (equations, scientific numbers, code snippets, byte fallbacks, and contractions).
78
+
79
+ ---
80
+
81
+ ### 6. `wrap_tokenizer.py` (HuggingFace Integration)
82
+ Wraps the trained HuggingFace BPE model into `PreTrainedTokenizerFast` from `transformers`:
83
+ * Associates `<|endoftext|>` as the `bos_token`, `eos_token`, and `pad_token`.
84
+ * Enables compatibility with the `datasets.map()` bulk utility, the HuggingFace Trainer, and PyTorch dataloaders.
85
+ * Standardizes right-padding, right-truncation, and context length configurations (`model_max_length=1024`).
86
+
87
+ ---
88
+
89
+ ### 7. `tokenize_dataset.py` (Dataset Packing)
90
+ A highly optimized bulk-tokenization utility:
91
+ * Tokenizes the streamed FineWeb-Edu dataset up to a target cap (e.g., `3.2` Billion tokens).
92
+ * Performs a 99% train and 1% validation split (every 100th document is routed to the validation buffer).
93
+ * Concatenates/packs documents sequentially (using `<|endoftext|>` as the document boundary) and writes them to disk as high-performance flat binary shards (`.bin` files of `np.uint16` type).
94
+ * Standard shard size is `100,000,000` tokens.
95
+ * Provides a memory-mapped helper `load_shard(split, shard_idx)` using `np.memmap` so that models can stream training batches without loading multi-gigabyte files into RAM.
96
+
97
+ ---
98
+
99
+ ## 💡 Key Design Highlights
100
+
101
+ > [!TIP]
102
+ > **Why Byte Fallback is Critical**: By initializing the alphabet with 256 unique byte values and enabling fallback, characters like math symbols ($\nabla$) or emojis don't fail or return an `<unk>` token; instead, they represent themselves as their raw UTF-8 bytes (e.g., $\nabla$ is parsed perfectly as `<0xE2><0x88><0x87>`).
103
+
104
+ > [!TIP]
105
+ > **Code-Aware Features**: The combination of preserving leading tabs in `normalizer.py`, isolating multi-character operators (`==`, `!=`, etc.), and protecting `snake_case` variables guarantees high-fidelity, compact token representation when the language model is trained on code.
train.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train.py — SLLM Training Loop
3
+
4
+ Supports:
5
+ --max_steps N Run for exactly N steps then save checkpoint and exit.
6
+ Omit to train indefinitely (until Ctrl+C or data exhausted).
7
+ --resume Resume from the latest checkpoint in --run_dir.
8
+ --config 100M|150M Choose model config (default: 100M).
9
+ --synthetic Use synthetic data (for testing without real shards).
10
+
11
+ Features:
12
+ - bf16 mixed precision (autocast) + GradScaler for stable training
13
+ - Gradient accumulation: --grad_accum N steps per optimizer update
14
+ - Gradient checkpointing: --grad_checkpoint to save VRAM
15
+ - Cosine LR schedule with linear warmup
16
+ - Checkpoint save every --save_every steps (and on clean exit/Ctrl+C)
17
+ - Metric logging to <run_dir>/train_log.jsonl (one JSON line per log step)
18
+ - Real-time terminal progress with tqdm
19
+
20
+ Recommended for RTX 3050 4GB:
21
+ python train.py --config 100M --batch_size 4 --grad_accum 8 \\
22
+ --grad_checkpoint --max_steps 1000
23
+
24
+ Run for N steps, stop, then resume:
25
+ python train.py --max_steps 500 --run_dir runs/my_run
26
+ python train.py --max_steps 500 --run_dir runs/my_run --resume
27
+ """
28
+
29
+ import os
30
+ import sys
31
+ import json
32
+ import math
33
+ import time
34
+ import signal
35
+ import argparse
36
+
37
+ import torch
38
+ import torch.nn.functional as F
39
+ from torch.amp import autocast, GradScaler
40
+ from tqdm import tqdm
41
+
42
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
43
+ from model.config import SLLM_100M, SLLM_150M, ModelConfig
44
+ from model.model import SLLM
45
+ from data.dataloader import build_dataloader
46
+
47
+
48
+ # ------------------------------------------------------------------ #
49
+ # ARG PARSING
50
+ # ------------------------------------------------------------------ #
51
+
52
+ def parse_args():
53
+ p = argparse.ArgumentParser(description="SLLM Training Loop")
54
+
55
+ # Run management
56
+ p.add_argument("--run_dir", type=str, default="runs/run_001", help="Directory for checkpoints and logs")
57
+ p.add_argument("--run_name", type=str, default=None, help="Override run name (defaults to run_dir basename)")
58
+ p.add_argument("--resume", action="store_true", help="Resume from latest checkpoint in run_dir")
59
+ p.add_argument("--max_steps", type=int, default=None, help="Absolute step target — stop when step reaches this number.")
60
+ p.add_argument("--extra_steps", type=int, default=None, help="Run N MORE steps from current checkpoint (relative). Converted to --max_steps internally.")
61
+
62
+ # Model
63
+ p.add_argument("--config", type=str, default="100M", choices=["100M", "150M"])
64
+
65
+ # Data
66
+ p.add_argument("--data_dir", type=str, default="tokenizer/data")
67
+ p.add_argument("--synthetic", action="store_true", help="Use synthetic random data (for testing)")
68
+ p.add_argument("--num_workers",type=int, default=2)
69
+
70
+ # Training
71
+ p.add_argument("--batch_size", type=int, default=4, help="Per-device batch size")
72
+ p.add_argument("--grad_accum", type=int, default=8, help="Gradient accumulation steps")
73
+ p.add_argument("--max_lr", type=float, default=3e-4)
74
+ p.add_argument("--min_lr", type=float, default=3e-5)
75
+ p.add_argument("--warmup_steps", type=int, default=100)
76
+ p.add_argument("--weight_decay", type=float, default=0.1)
77
+ p.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping norm (0 = disabled)")
78
+
79
+ # Memory
80
+ p.add_argument("--grad_checkpoint", action="store_true", help="Enable gradient checkpointing (saves VRAM, slower)")
81
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
82
+
83
+ # Logging / Saving
84
+ p.add_argument("--log_every", type=int, default=10, help="Log metrics every N optimizer steps")
85
+ p.add_argument("--save_every", type=int, default=500, help="Save checkpoint every N optimizer steps")
86
+ p.add_argument("--val_every", type=int, default=250, help="Run validation every N optimizer steps")
87
+ p.add_argument("--val_steps", type=int, default=20, help="Number of val batches to average")
88
+
89
+ return p.parse_args()
90
+
91
+
92
+ # ------------------------------------------------------------------ #
93
+ # LEARNING RATE SCHEDULE
94
+ # ------------------------------------------------------------------ #
95
+
96
+ def get_lr(step: int, warmup_steps: int, total_steps: int, max_lr: float, min_lr: float) -> float:
97
+ """
98
+ Linear warmup then cosine decay.
99
+ If total_steps is None (training indefinitely), uses a fixed 10k step decay window.
100
+ """
101
+ # Linear warmup
102
+ if step < warmup_steps:
103
+ return max_lr * (step + 1) / warmup_steps
104
+
105
+ # After decay: hold at min_lr
106
+ decay_steps = total_steps if total_steps else 10_000
107
+ if step >= decay_steps:
108
+ return min_lr
109
+
110
+ # Cosine decay
111
+ progress = (step - warmup_steps) / max(1, decay_steps - warmup_steps)
112
+ coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
113
+ return min_lr + coeff * (max_lr - min_lr)
114
+
115
+
116
+ # ------------------------------------------------------------------ #
117
+ # OPTIMIZER (AdamW with selective weight decay)
118
+ # ------------------------------------------------------------------ #
119
+
120
+ def build_optimizer(model: SLLM, lr: float, weight_decay: float) -> torch.optim.AdamW:
121
+ """
122
+ AdamW with weight decay applied only to 2D params (Linear weights).
123
+ Excludes: embeddings, norms (RMSNorm weight vectors), biases.
124
+
125
+ This is the standard approach from GPT-2/NanoGPT.
126
+ """
127
+ decay_params = []
128
+ no_decay_params = []
129
+
130
+ for name, param in model.named_parameters():
131
+ if not param.requires_grad:
132
+ continue
133
+ # 2D tensors (weight matrices) get weight decay
134
+ if param.dim() >= 2:
135
+ decay_params.append(param)
136
+ else:
137
+ # 1D: norm weights, biases, embeddings
138
+ no_decay_params.append(param)
139
+
140
+ optim_groups = [
141
+ {"params": decay_params, "weight_decay": weight_decay},
142
+ {"params": no_decay_params, "weight_decay": 0.0},
143
+ ]
144
+
145
+ n_decay = sum(p.numel() for p in decay_params)
146
+ n_no_decay = sum(p.numel() for p in no_decay_params)
147
+ print(f" Optimizer: {n_decay/1e6:.1f}M decay params | {n_no_decay/1e6:.1f}M no-decay params")
148
+
149
+ return torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=True)
150
+
151
+
152
+ # ------------------------------------------------------------------ #
153
+ # CHECKPOINT SAVE / LOAD
154
+ # ------------------------------------------------------------------ #
155
+
156
+ def save_checkpoint(path: str, model: SLLM, optimizer, step: int, args, loss: float):
157
+ os.makedirs(os.path.dirname(path), exist_ok=True)
158
+ torch.save({
159
+ "step": step,
160
+ "model_state_dict": model.state_dict(),
161
+ "optimizer_state_dict": optimizer.state_dict(),
162
+ "loss": loss,
163
+ "config_name": args.config,
164
+ }, path)
165
+ print(f"\n [CKPT] Saved checkpoint: {path} (step={step}, loss={loss:.4f})")
166
+
167
+
168
+ def load_checkpoint(run_dir: str, model: SLLM, optimizer, device):
169
+ """Loads the latest checkpoint from run_dir. Returns step number."""
170
+ ckpts = sorted([
171
+ f for f in os.listdir(run_dir)
172
+ if f.startswith("ckpt_") and f.endswith(".pt")
173
+ ])
174
+ if not ckpts:
175
+ raise FileNotFoundError(f"No checkpoints found in {run_dir}")
176
+
177
+ path = os.path.join(run_dir, ckpts[-1])
178
+ ckpt = torch.load(path, map_location=device, weights_only=False)
179
+
180
+ model.load_state_dict(ckpt["model_state_dict"])
181
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
182
+
183
+ step = ckpt["step"]
184
+ loss = ckpt.get("loss", float("nan"))
185
+ print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
186
+ return step
187
+
188
+
189
+ # ------------------------------------------------------------------ #
190
+ # VALIDATION
191
+ # ------------------------------------------------------------------ #
192
+
193
+ @torch.no_grad()
194
+ def estimate_val_loss(model, val_loader, val_steps: int, device, dtype_ctx) -> float:
195
+ model.eval()
196
+ losses = []
197
+ for i, (x, y) in enumerate(val_loader):
198
+ if i >= val_steps:
199
+ break
200
+ x, y = x.to(device), y.to(device)
201
+ with dtype_ctx:
202
+ _, loss = model(x, y)
203
+ losses.append(loss.item())
204
+ model.train()
205
+ return sum(losses) / len(losses) if losses else float("nan")
206
+
207
+
208
+ # ------------------------------------------------------------------ #
209
+ # METRIC LOGGING
210
+ # ------------------------------------------------------------------ #
211
+
212
+ class MetricLogger:
213
+ """Appends one JSON line per step to train_log.jsonl."""
214
+
215
+ def __init__(self, log_path: str):
216
+ self.log_path = log_path
217
+ os.makedirs(os.path.dirname(log_path), exist_ok=True)
218
+ # Don't clear existing log when resuming — append
219
+ print(f" [LOG] Logging to: {log_path}")
220
+
221
+ def log(self, **kwargs):
222
+ with open(self.log_path, "a") as f:
223
+ f.write(json.dumps(kwargs) + "\n")
224
+
225
+
226
+ # ------------------------------------------------------------------ #
227
+ # MAIN TRAINING LOOP
228
+ # ------------------------------------------------------------------ #
229
+
230
+ def train():
231
+ args = parse_args()
232
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
233
+ print(f"\nDevice : {device}")
234
+ if device.type == "cuda":
235
+ print(f"GPU : {torch.cuda.get_device_name(0)}")
236
+ print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
237
+
238
+ # ---- dtype context --------------------------------------------- #
239
+ if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
240
+ dtype_torch = torch.bfloat16
241
+ dtype_name = "bf16"
242
+ elif args.dtype == "fp16" and device.type == "cuda":
243
+ dtype_torch = torch.float16
244
+ dtype_name = "fp16"
245
+ else:
246
+ dtype_torch = torch.float32
247
+ dtype_name = "fp32"
248
+
249
+ print(f"dtype : {dtype_name}")
250
+ use_amp = dtype_torch in (torch.float16, torch.bfloat16)
251
+ dtype_ctx = autocast(device_type=device.type, dtype=dtype_torch) if use_amp else torch.no_grad().__class__()
252
+ scaler = GradScaler(enabled=(dtype_torch == torch.float16)) # bf16 doesn't need scaler
253
+
254
+ # ---- Auto-detect config on resume ------------------------------ #
255
+ if args.resume:
256
+ try:
257
+ ckpts = sorted([
258
+ f for f in os.listdir(args.run_dir)
259
+ if f.startswith("ckpt_") and f.endswith(".pt")
260
+ ])
261
+ if ckpts:
262
+ ckpt_path = os.path.join(args.run_dir, ckpts[-1])
263
+ _tmp_ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
264
+ if "config_name" in _tmp_ckpt and _tmp_ckpt["config_name"] != args.config:
265
+ print(f" [CKPT] Auto-switching config from '{args.config}' to '{_tmp_ckpt['config_name']}' to match checkpoint.")
266
+ args.config = _tmp_ckpt["config_name"]
267
+ del _tmp_ckpt
268
+ except Exception:
269
+ pass
270
+
271
+ # ---- Model ----------------------------------------------------- #
272
+ cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}
273
+ cfg = cfg_map[args.config]
274
+ model = SLLM(cfg).to(device)
275
+
276
+ if args.grad_checkpoint:
277
+ model.enable_gradient_checkpointing()
278
+ print(" Gradient checkpointing: ON")
279
+
280
+ print(f"\nModel : SLLM-{args.config} ({model.count_params()/1e6:.1f}M params)")
281
+ print(f"Config : {cfg}")
282
+
283
+ # ---- Optimizer ------------------------------------------------- #
284
+ optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)
285
+
286
+ # ---- Data ------------------------------------------------------ #
287
+ train_loader = build_dataloader(
288
+ data_dir = args.data_dir,
289
+ split = "train",
290
+ context_length = cfg.context_length,
291
+ batch_size = args.batch_size,
292
+ num_workers = args.num_workers,
293
+ use_synthetic = args.synthetic,
294
+ vocab_size = cfg.vocab_size,
295
+ )
296
+ val_loader = build_dataloader(
297
+ data_dir = args.data_dir,
298
+ split = "val",
299
+ context_length = cfg.context_length,
300
+ batch_size = args.batch_size,
301
+ num_workers = 0,
302
+ use_synthetic = args.synthetic,
303
+ vocab_size = cfg.vocab_size,
304
+ )
305
+
306
+ # ---- Run directory --------------------------------------------- #
307
+ os.makedirs(args.run_dir, exist_ok=True)
308
+ log_path = os.path.join(args.run_dir, "train_log.jsonl")
309
+ logger = MetricLogger(log_path)
310
+
311
+ # ---- Resume ---------------------------------------------------- #
312
+ start_step = 0
313
+ if args.resume:
314
+ try:
315
+ start_step = load_checkpoint(args.run_dir, model, optimizer, device)
316
+ except FileNotFoundError as e:
317
+ print(f" [WARN] {e} — starting from scratch.")
318
+
319
+ # ---- Effective batch size info --------------------------------- #
320
+ eff_batch = args.batch_size * args.grad_accum
321
+ tokens_per_step = eff_batch * cfg.context_length
322
+ print(f"\nTraining:")
323
+ # ---- Resolve extra_steps -> max_steps -------------------------- #
324
+ if args.extra_steps is not None:
325
+ if args.max_steps is not None:
326
+ print(" [WARN] Both --extra_steps and --max_steps given. --extra_steps takes priority.")
327
+ args.max_steps = start_step + args.extra_steps
328
+ print(f" [INFO] --extra_steps {args.extra_steps} → running until step {args.max_steps}")
329
+
330
+ print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} -> effective={eff_batch})")
331
+ print(f" tokens/step : {tokens_per_step:,}")
332
+ print(f" max_steps : {args.max_steps or 'unlimited'} (absolute step target)")
333
+ print(f" start_step : {start_step}")
334
+ print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else 'unlimited'}")
335
+ print(f" save_every : {args.save_every}")
336
+ print(f" log_every : {args.log_every}")
337
+
338
+ # ---- Early exit if already past max_steps ---------------------- #
339
+ if args.max_steps is not None and start_step >= args.max_steps:
340
+ print(f"\n [WARN] start_step ({start_step}) >= max_steps ({args.max_steps}).")
341
+ print(f" Nothing to train. Use --extra_steps N to run N more steps.")
342
+ print(f"\nExample: python train.py --resume --run_dir {args.run_dir} --extra_steps 5000")
343
+ return
344
+
345
+ # ---- Graceful Ctrl+C handler ----------------------------------- #
346
+ stop_flag = {"stop": False}
347
+ def _signal_handler(sig, frame):
348
+ print("\n [SIGNAL] Ctrl+C received — will save checkpoint and exit after current step.")
349
+ stop_flag["stop"] = True
350
+ signal.signal(signal.SIGINT, _signal_handler)
351
+
352
+ # ---- Training loop --------------------------------------------- #
353
+ model.train()
354
+ step = start_step
355
+ micro_step = 0 # within grad_accum window
356
+ running_loss = 0.0 # accumulated for logging
357
+ t_start = time.time()
358
+ t_step_start = time.time()
359
+ data_iter = iter(train_loader)
360
+
361
+ print(f"\n{'='*60}")
362
+ print(f" TRAINING STARTED (step {step} -> {args.max_steps or '∞'})")
363
+ print(f"{'='*60}\n")
364
+
365
+ pbar = tqdm(
366
+ initial=step,
367
+ total=args.max_steps,
368
+ desc="Training",
369
+ unit="step",
370
+ dynamic_ncols=True,
371
+ )
372
+
373
+ while True:
374
+ # ---- Stop conditions --------------------------------------- #
375
+ if stop_flag["stop"]:
376
+ break
377
+ if args.max_steps is not None and step >= args.max_steps:
378
+ print(f"\n [DONE] Reached max_steps={args.max_steps}")
379
+ break
380
+
381
+ optimizer.zero_grad(set_to_none=True)
382
+ accum_loss = 0.0
383
+
384
+ # ---- Gradient accumulation micro-steps --------------------- #
385
+ for micro in range(args.grad_accum):
386
+ # Get next batch
387
+ try:
388
+ x, y = next(data_iter)
389
+ except StopIteration:
390
+ data_iter = iter(train_loader)
391
+ x, y = next(data_iter)
392
+
393
+ x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
394
+
395
+ # Forward + loss (inside AMP context)
396
+ with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
397
+ logits, loss = model(x, y)
398
+ # Scale loss by grad_accum so gradients average correctly
399
+ loss = loss / args.grad_accum
400
+
401
+ # Backward
402
+ scaler.scale(loss).backward()
403
+ accum_loss += loss.item()
404
+
405
+ # ---- Gradient clipping ------------------------------------- #
406
+ if args.grad_clip > 0:
407
+ scaler.unscale_(optimizer)
408
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
409
+ else:
410
+ grad_norm = float("nan")
411
+
412
+ # ---- LR update --------------------------------------------- #
413
+ lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
414
+ for pg in optimizer.param_groups:
415
+ pg["lr"] = lr
416
+
417
+ # ---- Optimizer step ---------------------------------------- #
418
+ scaler.step(optimizer)
419
+ scaler.update()
420
+
421
+ step += 1
422
+ running_loss = accum_loss # loss for this step
423
+
424
+ # ---- Tokens per second ------------------------------------- #
425
+ t_now = time.time()
426
+ elapsed = t_now - t_step_start
427
+ t_step_start = t_now
428
+ tok_per_sec = tokens_per_step / max(elapsed, 1e-6)
429
+
430
+ # ---- Progress bar update ----------------------------------- #
431
+ pbar.update(1)
432
+ pbar.set_postfix({
433
+ "loss": f"{running_loss:.4f}",
434
+ "lr": f"{lr:.2e}",
435
+ "tok/s": f"{tok_per_sec:.0f}",
436
+ })
437
+
438
+ # ---- Logging ----------------------------------------------- #
439
+ if step % args.log_every == 0:
440
+ log_entry = {
441
+ "step": step,
442
+ "loss": round(running_loss, 6),
443
+ "lr": lr,
444
+ "grad_norm": round(float(grad_norm), 4) if not math.isnan(float(grad_norm)) else None,
445
+ "tok_per_sec": round(tok_per_sec, 1),
446
+ "elapsed_s": round(t_now - t_start, 1),
447
+ }
448
+ if device.type == "cuda":
449
+ log_entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
450
+ logger.log(**log_entry)
451
+
452
+ # ---- Validation -------------------------------------------- #
453
+ if step % args.val_every == 0:
454
+ val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp))
455
+ tqdm.write(f" [STEP {step:6d}] train_loss={running_loss:.4f} val_loss={val_loss:.4f} lr={lr:.2e}")
456
+ logger.log(step=step, val_loss=round(val_loss, 6))
457
+
458
+ # ---- Checkpoint -------------------------------------------- #
459
+ if step % args.save_every == 0:
460
+ ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
461
+ save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)
462
+
463
+ # ---- Final checkpoint on exit (only if we actually ran steps) -- #
464
+ pbar.close()
465
+ steps_done = step - start_step
466
+ if steps_done > 0:
467
+ ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
468
+ save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)
469
+ else:
470
+ print("\n [SKIP] No steps were taken — skipping final checkpoint save.")
471
+
472
+ total_time = time.time() - t_start
473
+ print(f"\n{'='*60}")
474
+ print(f" TRAINING COMPLETE")
475
+ print(f"{'='*60}")
476
+ print(f" Steps completed : {step - start_step}")
477
+ print(f" Final loss : {running_loss:.4f}")
478
+ print(f" Total time : {total_time/60:.1f} min")
479
+ print(f" Run dir : {args.run_dir}")
480
+ print(f"\nTo resume: python train.py --resume --run_dir {args.run_dir} --max_steps <N>")
481
+ print(f"To plot : python plot_training.py --run_dir {args.run_dir}")
482
+
483
+
484
+ if __name__ == "__main__":
485
+ train()