Sualeh Qureshi committed on
Commit
c175ce3
·
0 Parent(s):

Committed the training code and model file

.gitignore ADDED
@@ -0,0 +1,14 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ # Checkpoints
+ checkpoints/
+
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
README.md ADDED
File without changes
README_TRAINING.md ADDED
@@ -0,0 +1,110 @@
+ # SmolLM2-135M Training Guide
+
+ This directory contains the training code for the SmolLM2-135M model.
+
+ ## Files
+
+ - `model.py`: Model definition with KV cache support for inference
+ - `train.py`: Main training script (trains for 5000 steps)
+   - Run with a checkpoint path to resume training for 50 additional steps
+
+ ## Setup
+
+ Install the required packages:
+
+ ```bash
+ pip install torch lightning transformers tensorboard
+ ```
+
+ ## Training
+
+ ### Phase 1: Initial Training (5000 steps)
+
+ Run the main training script:
+
+ ```bash
+ python train.py
+ ```
+
+ This will:
+ - Train the model for 5000 steps
+ - Generate text predictions every 500 steps
+ - Save checkpoints every 500 steps
+ - Log training metrics to TensorBoard and a text file
+ - Save the final checkpoint at step 5000
+
+ ### Phase 2: Resume Training (50 additional steps)
+
+ After Phase 1 completes, run:
+
+ ```bash
+ python train.py
+ ```
+
+ This time, set the checkpoint path and set the step count so that training resumes for 50 additional steps; this showcases that training picks up exactly where it stopped. The sketch below shows the relevant settings.
+
+ This will:
+ - Load the checkpoint from Phase 1
+ - Train for 50 additional steps
+ - Save the final checkpoint
+
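+ The resume settings live at the top of `main()` in `train.py`. A minimal sketch of what to change (the checkpoint filename below is taken from the included logs; use one that actually exists in `./checkpoints/`):
+
+ ```python
+ # In train.py, main():
+ # Resuming at step 1500 with max_steps = 1550 trains exactly 50 more steps,
+ # because Lightning's max_steps caps the global step count.
+ max_steps = 1550
+ resume_from_checkpoint = "checkpoints/smollm2-step=01500-train_loss=3.6240.ckpt"
+ # Set resume_from_checkpoint = None for fresh training.
+ ```
+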
+ ## Training Configuration
+
+ The training uses the following hyperparameters (adapted from the SmolLM2 paper):
+
+ - **Optimizer**: AdamW with (β₁, β₂) = (0.9, 0.95)
+ - **Learning Rate Schedule**: Warmup Stable Decay (WSD), sketched below
+   - Warmup: 2000 steps in the paper; this run uses 1000 steps (20% of the 5000 total)
+   - Peak LR: 5.0 × 10⁻⁴
+   - Stable phase: maintains peak LR
+   - Decay: reduces to zero over the final 10% of total steps
+ - **Block size**: 512 tokens
+ - **Batch size**: 4
+ - **Precision**: bfloat16 mixed (if a CUDA GPU is available), float32 otherwise
+
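+ The WSD schedule is a simple piecewise-linear function of the step count. A minimal sketch mirroring the `WarmupStableDecayLR` callback in `train.py`:
+
+ ```python
+ def wsd_lr(step: int, warmup: int = 1000, peak: float = 5e-4, total: int = 5000) -> float:
+     """Warmup -> stable -> linear decay to zero over the last 10% of steps."""
+     decay = int(0.1 * total)        # length of the decay phase
+     stable_end = total - decay      # step at which decay begins
+     if step < warmup:
+         return peak * step / warmup                     # linear warmup
+     if step < stable_end:
+         return peak                                     # stable phase at peak LR
+     return peak * (1.0 - (step - stable_end) / decay)   # linear decay to zero
+ ```
+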
+ ## Outputs
+
+ - **Checkpoints**: Saved in `./checkpoints/`
+ - **TensorBoard logs**: Saved in `./logs/tensorboard/`
+ - **Text logs**: Saved in `./logs/training_*.log`
+
+ ## Model Features
+
+ The model includes:
+ - **KV Cache**: Efficient inference using key-value caching (see the sketch below)
+ - **Generation**: Text generation with top-k and top-p sampling
+ - **Checkpointing**: Full state saving for resuming training
+
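+ During generation, the prompt is processed once to fill the cache, and every later step feeds only the newest token together with the cached keys/values. A minimal greedy-decoding sketch against the `forward` interface in `model.py` (`generate` itself additionally applies temperature and top-k/top-p sampling):
+
+ ```python
+ # Prefill: run the full prompt once and keep the per-layer KV cache.
+ logits, past = model(input_ids, use_cache=True)
+ next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+
+ # Decode: each step only processes the single newest token.
+ for _ in range(50):
+     logits, past = model(next_token, past_key_values=past, use_cache=True)
+     next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+ ```
+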
+ ## Usage Example
+
+ ```python
+ import torch
+ from model import SmolLM2, SmolConfig
+ from transformers import AutoTokenizer, AutoConfig
+
+ # Load config
+ hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+ config = SmolConfig.from_hf(hf_config)
+
+ # Create model
+ model = SmolLM2(config)
+
+ # Load a checkpoint (use a real filename from ./checkpoints/; torch.load does
+ # not expand glob patterns). weights_only=False because the checkpoint stores
+ # a SmolConfig object in its hyperparameters.
+ checkpoint = torch.load(
+     "checkpoints/smollm2-step=01500-train_loss=3.6240.ckpt",
+     map_location="cpu",
+     weights_only=False,
+ )
+ # The Lightning module stores the network under the "model." prefix; strip it
+ state_dict = {k.removeprefix("model."): v for k, v in checkpoint["state_dict"].items()}
+ model.load_state_dict(state_dict)
+
+ # Generate text
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+ prompt = "First Citizen:"
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+ generated_ids = model.generate(
+     input_ids,
+     max_new_tokens=100,
+     temperature=0.8,
+     top_k=50,
+ )
+
+ generated_text = tokenizer.decode(generated_ids[0])
+ print(generated_text)
+ ```
logs/tensorboard/version_0/events.out.tfevents.1765268407.MAC-QNYQPC2R2T.88043.0 ADDED
Binary file (5.59 kB)
 
logs/tensorboard/version_0/hparams.yaml ADDED
@@ -0,0 +1,5 @@
+ block_size: 512
+ peak_lr: 0.0005
+ predict_every: 500
+ total_steps: 5000
+ warmup_steps: 1000
logs/tensorboard/version_1/events.out.tfevents.1765274926.MAC-QNYQPC2R2T.7268.0 ADDED
Binary file (88 Bytes)
 
logs/tensorboard/version_2/events.out.tfevents.1765275552.MAC-QNYQPC2R2T.7768.0 ADDED
Binary file (2.8 kB)
 
logs/tensorboard/version_2/hparams.yaml ADDED
@@ -0,0 +1,5 @@
+ block_size: 512
+ peak_lr: 0.0005
+ predict_every: 500
+ total_steps: 3500
+ warmup_steps: 1000
logs/training_20251209_135005.log ADDED
@@ -0,0 +1,54 @@
+ 2025-12-09 13:50:05,106 - INFO - Logging to: logs/training_20251209_135005.log
+ 2025-12-09 13:50:05,106 - INFO - Loading tokenizer...
+ 2025-12-09 13:50:05,965 - INFO - Loading model config...
+ 2025-12-09 13:50:06,205 - INFO - Loading dataset from: /Users/qureshsu/Learning/TSAI/ERAV4/session13/data/input.txt
+ 2025-12-09 13:50:06,657 - INFO - Initializing model...
+ 2025-12-09 13:50:07,391 - INFO - Starting training...
+ 2025-12-09 13:50:24,556 - INFO -
+ ================================================================================
+ 2025-12-09 13:50:24,557 - INFO - MODEL SUMMARY
+ 2025-12-09 13:50:24,557 - INFO - ================================================================================
+ 2025-12-09 13:50:24,557 - INFO - Model: SmolLM2-135M
+ 2025-12-09 13:50:24,557 - INFO - Total parameters: 134,515,008
+ 2025-12-09 13:50:24,557 - INFO - Trainable parameters: 134,515,008
+ 2025-12-09 13:50:24,557 - INFO - Block size: 512
+ 2025-12-09 13:50:24,557 - INFO - Warmup steps: 1000
+ 2025-12-09 13:50:24,557 - INFO - Peak learning rate: 0.0005
+ 2025-12-09 13:50:24,557 - INFO - Total training steps: 5000
+ 2025-12-09 13:50:24,557 - INFO - Predict every: 500 steps
+ 2025-12-09 13:50:24,557 - INFO - ================================================================================
+
+ 2025-12-09 14:05:59,075 - INFO -
+ ================================================================================
+ 2025-12-09 14:05:59,081 - INFO - Step 500 - Generated text:
+ 2025-12-09 14:05:59,081 - INFO - First Citizen:
+ WhatONEONE:
+ DUKE VINCENTIO:
+ DUKE VINCENTIO:
+ Nay, thou art thou pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow
+ 2025-12-09 14:05:59,081 - INFO - ================================================================================
+
+ 2025-12-09 14:21:21,767 - INFO -
+ ================================================================================
+ 2025-12-09 14:21:21,771 - INFO - Step 1000 - Generated text:
+ 2025-12-09 14:21:21,771 - INFO - First Citizen:
+ And then, like thee: thou hast thou dost in thy husband'st:
+ And in thy soldiers, not in thy master's name,
+ Which then in thy shame: I did thy shame,
+ Which thou doth know her
+ 2025-12-09 14:21:21,771 - INFO - ================================================================================
+
+ 2025-12-09 14:37:17,744 - INFO -
+ ================================================================================
+ 2025-12-09 14:37:17,748 - INFO - Step 1500 - Generated text:
+ 2025-12-09 14:37:17,748 - INFO - First Citizen:
+ I have done a'rt too that, if the king had title to the
+ Where it shall be the is born to be in the tongue.
+
+ Second Citizen:
+ And so shall I.
+
+ ANTONIO:
+ I
+ 2025-12-09 14:37:17,748 - INFO - ================================================================================
+
logs/training_20251209_154910.log ADDED
@@ -0,0 +1,35 @@
+ 2025-12-09 15:49:10,023 - INFO - Logging to: logs/training_20251209_154910.log
+ 2025-12-09 15:49:10,023 - INFO - Loading tokenizer...
+ 2025-12-09 15:49:10,936 - INFO - Loading model config...
+ 2025-12-09 15:49:11,184 - INFO - Loading dataset from: /Users/qureshsu/Learning/TSAI/ERAV4/session13/data/input.txt
+ 2025-12-09 15:49:11,623 - INFO - Initializing model...
+ 2025-12-09 15:49:12,354 - INFO - Starting training...
+ 2025-12-09 15:49:12,357 - INFO - Resuming from checkpoint: checkpoints/smollm2-step=01500-train_loss=3.6240.ckpt
+ 2025-12-09 15:49:30,901 - INFO -
+ ================================================================================
+ 2025-12-09 15:49:30,901 - INFO - MODEL SUMMARY
+ 2025-12-09 15:49:30,901 - INFO - ================================================================================
+ 2025-12-09 15:49:30,901 - INFO - Model: SmolLM2-135M
+ 2025-12-09 15:49:30,901 - INFO - Total parameters: 134,515,008
+ 2025-12-09 15:49:30,901 - INFO - Trainable parameters: 134,515,008
+ 2025-12-09 15:49:30,901 - INFO - Block size: 512
+ 2025-12-09 15:49:30,901 - INFO - Warmup steps: 1000
+ 2025-12-09 15:49:30,901 - INFO - Peak learning rate: 0.0005
+ 2025-12-09 15:49:30,901 - INFO - Total training steps: 3500
+ 2025-12-09 15:49:30,901 - INFO - Predict every: 500 steps
+ 2025-12-09 15:49:30,901 - INFO - ================================================================================
+
+ 2025-12-09 15:59:45,441 - INFO - Step 2000 | train_loss=0.9070
+ 2025-12-09 15:59:47,487 - INFO -
+ ================================================================================
+ 2025-12-09 15:59:47,487 - INFO - Step 2000 - Generated text:
+ 2025-12-09 15:59:47,488 - INFO - First Citizen:
+ Why, no; but the Hortenspur, and
+ To perricks. Thou art said so when a king
+ Hadst thouable to be ruled, and not to forget
+ At any man.
+
+ First Citizen:
+ None,
+ 2025-12-09 15:59:47,488 - INFO - ================================================================================
+
main.py ADDED
@@ -0,0 +1,6 @@
+ def main():
+     print("Hello from smollm-135!")
+
+
+ if __name__ == "__main__":
+     main()
model.py ADDED
@@ -0,0 +1,589 @@
+
+ # Minimal SmolLM2-135M style model implemented in PyTorch.
+ # Architecture: LLaMA-style decoder-only Transformer with:
+ #   - RMSNorm
+ #   - RoPE positional encoding
+ #   - SwiGLU MLP
+ #   - Grouped (GQA/MQA) attention: num_attention_heads != num_key_value_heads
+ #
+ # This file is self-contained (except PyTorch) and can be used as:
+ #
+ #   from transformers import AutoConfig
+ #   from model import SmolConfig, SmolLM2
+ #
+ #   cfg = SmolConfig.from_hf(AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M"))
+ #   model = SmolLM2(cfg)
+
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, List
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # =========================
+ # 1. Config
+
+ # Got the config from HuggingFace using: transformers.AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+ # Config: SmolLM2-135M
+
+ # LlamaConfig {
+ #   "architectures": [
+ #     "LlamaForCausalLM"
+ #   ],
+ #   "attention_bias": false,
+ #   "attention_dropout": 0.0,
+ #   "bos_token_id": 0,
+ #   "dtype": "bfloat16",
+ #   "eos_token_id": 0,
+ #   "head_dim": 64,
+ #   "hidden_act": "silu",
+ #   "hidden_size": 576,
+ #   "initializer_range": 0.041666666666666664,
+ #   "intermediate_size": 1536,
+ #   "is_llama_config": true,
+ #   "max_position_embeddings": 8192,
+ #   "mlp_bias": false,
+ #   "model_type": "llama",
+ #   "num_attention_heads": 9,
+ #   "num_hidden_layers": 30,
+ #   "num_key_value_heads": 3,
+ #   "pretraining_tp": 1,
+ #   "rms_norm_eps": 1e-05,
+ #   "rope_interleaved": false,
+ #   "rope_scaling": null,
+ #   "rope_theta": 100000,
+ #   "tie_word_embeddings": true,
+ #   "transformers_version": "4.57.3",
+ #   "use_cache": true,
+ #   "vocab_size": 49152
+ # }
+ # =========================
+
+ @dataclass
+ class SmolConfig:
+     # Core dimensions
+     vocab_size: int = 49152               # "vocab_size"
+     hidden_size: int = 576                # "hidden_size"
+     intermediate_size: int = 1536         # "intermediate_size"
+     num_hidden_layers: int = 30           # "num_hidden_layers"
+     num_attention_heads: int = 9          # "num_attention_heads"
+     num_key_value_heads: int = 3          # "num_key_value_heads"
+     max_position_embeddings: int = 8192   # "max_position_embeddings"
+
+     # Positional / RoPE
+     rope_theta: float = 100000.0          # "rope_theta"
+
+     # Norm / numerical
+     rms_norm_eps: float = 1e-5            # "rms_norm_eps"
+
+     # Biases
+     attention_bias: bool = False          # "attention_bias"
+     mlp_bias: bool = False                # "mlp_bias"
+
+     # Misc
+     dtype: torch.dtype = torch.bfloat16
+
+     @property
+     def head_dim(self) -> int:
+         # Should be 64 for SmolLM2-135M (576 / 9).
+         return self.hidden_size // self.num_attention_heads  # 576 / 9 = 64
+
+     @classmethod
+     def from_hf(cls, hf_config) -> "SmolConfig":
+         """
+         Helper to build this config from a transformers LlamaConfig (which is the
+         config class of the HuggingFace SmolLM2-135M model).
+         Example:
+             from transformers import AutoConfig
+             hf = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+             cfg = SmolConfig.from_hf(hf)
+         Then pass the resulting config to the SmolLM2 constructor.
+         """
+         return cls(
+             vocab_size=hf_config.vocab_size,
+             hidden_size=hf_config.hidden_size,
+             intermediate_size=hf_config.intermediate_size,
+             num_hidden_layers=hf_config.num_hidden_layers,
+             num_attention_heads=hf_config.num_attention_heads,
+             num_key_value_heads=getattr(hf_config, "num_key_value_heads",
+                                         hf_config.num_attention_heads),
+             max_position_embeddings=hf_config.max_position_embeddings,
+             rope_theta=getattr(hf_config, "rope_theta", 10000.0),
+             rms_norm_eps=hf_config.rms_norm_eps,
+             attention_bias=getattr(hf_config, "attention_bias", False),
+             mlp_bias=getattr(hf_config, "mlp_bias", False),
+             dtype=torch.bfloat16,  # SmolLM2 uses bfloat16
+         )
+
+ # =========================
+ # 2. RMSNorm
+ # =========================
+
+ class RMSNorm(nn.Module):
+     """
+     Root Mean Square Layer Normalization (RMSNorm).
+     Used in LLaMA / SmolLM2 instead of LayerNorm.
+     """
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (..., dim)
+         # rms = sqrt(mean(x^2)); we use rsqrt for stability
+         norm = x.pow(2).mean(dim=-1, keepdim=True)
+         x = x * torch.rsqrt(norm + self.eps)
+         return self.weight * x
+
+ # =========================
+ # 3. RoPE (Rotary Positional Embeddings)
+ # =========================
+
+ def rope_freqs(head_dim: int, base: float, device, dtype):
+     """
+     Compute inverse frequencies for RoPE.
+     """
+     half_dim = head_dim // 2
+     # Equivalent to: base^{ -2i / d }
+     freq_seq = torch.arange(half_dim, device=device, dtype=dtype)
+     inv_freq = 1.0 / (base ** (freq_seq / half_dim))
+     return inv_freq  # shape: (half_dim,)
+
+ def build_rope_cache(
+     seq_len: int,
+     head_dim: int,
+     base: float,
+     device,
+     dtype,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Build cosine and sine caches for RoPE.
+     Returns:
+         cos: (1, 1, seq_len, head_dim/2)
+         sin: (1, 1, seq_len, head_dim/2)
+     """
+     inv_freq = rope_freqs(head_dim, base, device, dtype)   # (half_dim,)
+     # Positions
+     t = torch.arange(seq_len, device=device, dtype=dtype)  # (seq_len,)
+     freqs = torch.outer(t, inv_freq)                       # (seq_len, half_dim)
+     cos = freqs.cos()[None, None, :, :]                    # (1,1,seq_len,half_dim)
+     sin = freqs.sin()[None, None, :, :]                    # (1,1,seq_len,half_dim)
+     return cos, sin
+
+ def apply_rope(
+     x: torch.Tensor,    # (B, n_head, T, head_dim)
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Apply RoPE to the last dimension of x.
+     cos, sin are broadcast to match (..., head_dim/2).
+     """
+     b, h, t, d = x.shape
+     half = d // 2
+
+     x1 = x[..., :half]  # (B, n_head, T, head_dim/2)
+     x2 = x[..., half:]  # (B, n_head, T, head_dim/2)
+
+     # cos/sin: (1,1,T,half) -> broadcast over B,h
+     cos_t = cos[..., :t, :]
+     sin_t = sin[..., :t, :]
+
+     x1_rot = x1 * cos_t - x2 * sin_t
+     x2_rot = x1 * sin_t + x2 * cos_t
+
+     return torch.cat([x1_rot, x2_rot], dim=-1)  # (B, n_head, T, head_dim)
+
+ # =========================
+ # 4. Attention
+ # =========================
+
+ class MultiHeadSelfAttention(nn.Module):
+     """
+     LLaMA / SmolLM2-style attention with:
+       - Q heads = num_attention_heads
+       - K/V heads = num_key_value_heads (GQA/MQA)
+       - RoPE on Q and K
+       - Causal masking
+     """
+     def __init__(self, config: SmolConfig):
+         super().__init__()
+
+         self.config = config
+         self.n_heads = config.num_attention_heads      # 9
+         self.n_kv_heads = config.num_key_value_heads   # 3
+         self.head_dim = config.head_dim                # 64
+         self.hidden_size = config.hidden_size          # 576
+
+         assert self.hidden_size == self.n_heads * self.head_dim
+
+         # Projections
+         self.q_proj = nn.Linear(
+             self.hidden_size,
+             self.n_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.k_proj = nn.Linear(
+             self.hidden_size,
+             self.n_kv_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.v_proj = nn.Linear(
+             self.hidden_size,
+             self.n_kv_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+
+         self.o_proj = nn.Linear(
+             self.n_heads * self.head_dim,
+             self.hidden_size,
+             bias=config.attention_bias,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,    # (B, T, C) or (B, 1, C) for inference
+         cos: torch.Tensor,  # (1,1,T,head_dim/2) or (1,1,1,head_dim/2) for inference
+         sin: torch.Tensor,  # (1,1,T,head_dim/2) or (1,1,1,head_dim/2) for inference
+         attention_mask: Optional[torch.Tensor] = None,  # (B, T) or (B,1,1,T)
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (k_cache, v_cache)
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         B, T, C = x.shape
+
+         # Projections: (B,T,C) -> (B,T,h*d) -> (B,T,h,d) -> (B,h,T,d)
+         q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
+
+         # Apply RoPE to Q and K
+         q = apply_rope(q, cos, sin)  # (B, n_heads, T, d)
+         k = apply_rope(k, cos, sin)  # (B, n_kv_heads, T, d)
+         # v doesn't need RoPE
+
+         # If using the KV cache, concatenate with past keys/values
+         if past_key_value is not None:
+             past_k, past_v = past_key_value
+             # past_k, past_v: (B, n_kv_heads, past_len, head_dim)
+             k = torch.cat([past_k, k], dim=2)  # (B, n_kv_heads, past_len + T, head_dim)
+             v = torch.cat([past_v, v], dim=2)  # (B, n_kv_heads, past_len + T, head_dim)
+             seq_len = k.shape[2]
+         else:
+             seq_len = T
+
+         # Store k, v for the cache (before GQA expansion)
+         k_cache = k  # (B, n_kv_heads, seq_len, head_dim)
+         v_cache = v  # (B, n_kv_heads, seq_len, head_dim)
+
+         # GQA: expand K/V if num_kv_heads < num_heads
+         if self.n_kv_heads != self.n_heads:
+             repeat_factor = self.n_heads // self.n_kv_heads
+             k = k.repeat_interleave(repeat_factor, dim=1)  # (B, n_kv_heads, seq_len, d) -> (B, n_heads, seq_len, d)
+             v = v.repeat_interleave(repeat_factor, dim=1)  # (B, n_kv_heads, seq_len, d) -> (B, n_heads, seq_len, d)
+
+         # Attention scores: (B,h,T,d) @ (B,h,d,seq_len) -> (B,h,T,seq_len)
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+
+         # Causal mask: prevent attending to future tokens.
+         # For inference with a KV cache we generate one token at a time (T=1), so the
+         # current token may attend to all past positions and no extra mask is needed.
+         if past_key_value is None:
+             # Full sequence: mask all future positions
+             causal_mask = torch.full(
+                 (T, T), float("-inf"), device=x.device, dtype=x.dtype
+             ).triu(1)  # upper triangle (i < j)
+             scores = scores + causal_mask.unsqueeze(0).unsqueeze(0)  # (B,h,T,T) + (1,1,T,T)
+
+         # Optional attention mask (e.g., padding). Should be additive (0 or -inf).
+         if attention_mask is not None:
+             # Expect attention_mask as (B, 1, 1, seq_len) or (B, seq_len)
+             if attention_mask.dim() == 2:
+                 # (B, seq_len) -> (B,1,1,seq_len)
+                 attention_mask = attention_mask[:, None, None, :]
+             # Adjust mask shape if needed
+             if attention_mask.shape[-1] != seq_len:
+                 # For inference, we might need to extend the mask
+                 if past_key_value is not None:
+                     # Extend mask to include past positions (all 0s for past, current mask for new token)
+                     past_len = past_k.shape[2]
+                     extended_mask = torch.zeros(B, 1, 1, seq_len, device=attention_mask.device, dtype=attention_mask.dtype)
+                     extended_mask[..., past_len:] = attention_mask[..., -T:]
+                     attention_mask = extended_mask
+             scores = scores + attention_mask
+
+         # Softmax over the last dim (seq_len)
+         probs = F.softmax(scores, dim=-1)  # (B,h,T,seq_len)
+
+         # Weighted sum of values
+         out = torch.matmul(probs, v)  # (B,h,T,seq_len) @ (B,h,seq_len,d) -> (B,h,T,d)
+
+         # Reshape back: (B,h,T,d) -> (B,T,h,d) -> (B,T,C)
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         out = self.o_proj(out)  # (B,T,C) -> (B,T,C)
+
+         # Return output and optionally the new KV cache
+         present_key_value = None
+         if use_cache:
+             # Return k_cache, v_cache (before GQA expansion, after RoPE)
+             present_key_value = (k_cache, v_cache)
+
+         return out, present_key_value
+
+ # =========================
+ # 5. MLP (SwiGLU)
+ # =========================
+ class SmolMLP(nn.Module):
+     """
+     SwiGLU MLP:
+         z = W1(x) -> split -> (x1, x2)
+         out = W2( SiLU(x1) * x2 )
+     """
+     def __init__(self, config: SmolConfig):
+         super().__init__()
+
+         self.fc1 = nn.Linear(
+             config.hidden_size,
+             2 * config.intermediate_size,  # for the SwiGLU split (2 x 1536 = 3072)
+             bias=config.mlp_bias,
+         )
+
+         self.fc2 = nn.Linear(
+             config.intermediate_size,  # 1536
+             config.hidden_size,        # 576
+             bias=config.mlp_bias,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc1(x)                   # (B,T,576) -> (B,T,3072)
+         x1, x2 = x.chunk(2, dim=-1)       # (B,T,3072) -> (B,T,1536), (B,T,1536)
+         return self.fc2(F.silu(x1) * x2)  # (B,T,1536) -> (B,T,576)
+
+
+ # =========================
+ # 6. Transformer Block
+ # =========================
+ class SmolBlock(nn.Module):
+     def __init__(self, config: SmolConfig):
+         super().__init__()
+         self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.attn = MultiHeadSelfAttention(config)
+         self.mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.mlp = SmolMLP(config)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+         # Pre-norm + residual for attention
+         attn_out, present_key_value = self.attn(
+             self.attn_norm(x), cos, sin, attention_mask, past_key_value, use_cache
+         )
+         x = x + attn_out
+         # Pre-norm + residual for MLP
+         x = x + self.mlp(self.mlp_norm(x))
+         return x, present_key_value
+
+ # =============================================
+ # 7. Top-level SmolLM2-135M Model Architecture
+ # SmolLM2 follows the LLaMA-style decoder-only Transformer architecture.
+ # =============================================
+ class SmolLM2(nn.Module):
+     """
+     SmolLM2-135M-style LLaMA decoder-only language model.
+
+     Usage:
+         cfg = SmolConfig()
+         model = SmolLM2(cfg)
+
+         input_ids: LongTensor (B, T)
+         logits = model(input_ids)
+     """
+     def __init__(self, config: SmolConfig):
+         super().__init__()
+         self.config = config
+
+         self.embed_tokens = nn.Embedding(
+             config.vocab_size,
+             config.hidden_size,
+         )  # (vocab_size, hidden_size) = (49152, 576)
+
+         self.layers = nn.ModuleList(
+             [SmolBlock(config) for _ in range(config.num_hidden_layers)]
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.lm_head = nn.Linear(
+             config.hidden_size,
+             config.vocab_size,
+             bias=False,
+         )  # (hidden_size, vocab_size) = (576, 49152)
+
+         # tie weights
+         self.lm_head.weight = self.embed_tokens.weight
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,  # (B, T)
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
+         B, T = input_ids.shape
+
+         # For inference with a KV cache, we might have T=1
+         if past_key_values is None:
+             assert T <= self.config.max_position_embeddings, (
+                 f"Sequence length {T} exceeds max_position_embeddings "
+                 f"{self.config.max_position_embeddings}"
+             )
+             seq_len = T
+         else:
+             # With a KV cache, the current sequence length is past_len + T
+             past_len = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
+             seq_len = past_len + T
+             assert seq_len <= self.config.max_position_embeddings, (
+                 f"Total sequence length {seq_len} exceeds max_position_embeddings "
+                 f"{self.config.max_position_embeddings}"
+             )
+
+         # Embedding
+         x = self.embed_tokens(input_ids)  # (B,T) -> (B,T,C)
+
+         # RoPE cache - build for the full sequence length (past + current)
+         cos, sin = build_rope_cache(
+             seq_len=seq_len,
+             head_dim=self.config.head_dim,
+             base=self.config.rope_theta,
+             device=x.device,
+             dtype=x.dtype,
+         )
+
+         # If using the KV cache, we only need cos/sin for the current positions
+         if past_key_values is not None:
+             past_len = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
+             # Slice to get only the current positions for RoPE
+             cos = cos[..., past_len:, :]
+             sin = sin[..., past_len:, :]
+
+         # Layers
+         present_key_values = [] if use_cache else None
+         for i, layer in enumerate(self.layers):
+             past_kv = past_key_values[i] if past_key_values is not None else None
+             x, present_kv = layer(x, cos, sin, attention_mask, past_kv, use_cache)
+             if use_cache:
+                 present_key_values.append(present_kv)
+
+         # Final norm + LM head
+         x = self.norm(x)
+         logits = self.lm_head(x)  # (B,T,C) -> (B,T,vocab_size)
+         return logits, present_key_values
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 100,
+         temperature: float = 1.0,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         eos_token_id: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Generate text using the KV cache for efficient inference.
+
+         Args:
+             input_ids: (B, T) input token ids
+             max_new_tokens: maximum number of new tokens to generate
+             temperature: sampling temperature
+             top_k: top-k sampling (keep the top k tokens)
+             top_p: nucleus sampling (keep tokens with cumulative probability <= top_p)
+             eos_token_id: end-of-sequence token id (stop generation when encountered)
+
+         Returns:
+             generated_ids: (B, T + max_new_tokens) generated token ids
+         """
+         self.eval()
+         device = input_ids.device
+         B, T = input_ids.shape
+
+         # Start with input_ids
+         generated_ids = input_ids.clone()
+         past_key_values = None
+
+         for step in range(max_new_tokens):
+             # Forward pass with the KV cache. On the first iteration, use the full
+             # input_ids; on subsequent iterations, use only the last token.
+             if past_key_values is None:
+                 # First iteration: process the full sequence
+                 current_input = generated_ids
+             else:
+                 # Subsequent iterations: only process the last generated token
+                 current_input = generated_ids[:, -1:]
+
+             logits, past_key_values = self.forward(
+                 input_ids=current_input,
+                 past_key_values=past_key_values,
+                 use_cache=True,
+             )
+
+             # Get logits for the last token (always the last position in logits)
+             next_token_logits = logits[:, -1, :] / temperature
+
+             # Apply top-k filtering
+             if top_k is not None:
+                 indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                 next_token_logits[indices_to_remove] = float('-inf')
+
+             # Apply top-p (nucleus) filtering
+             if top_p is not None:
+                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                 # Remove tokens with cumulative probability above the threshold
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                 next_token_logits[indices_to_remove] = float('-inf')
+
+             # Sample the next token
+             probs = F.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
+
+             # Append to the generated sequence
+             generated_ids = torch.cat([generated_ids, next_token], dim=1)
+
+             # Check for the EOS token
+             if eos_token_id is not None and (next_token == eos_token_id).all():
+                 break
+
+         return generated_ids
+
+ # =========================
+ # 8. Quick self-test
+ # =========================
+ if __name__ == "__main__":
+     # Tiny sanity check: runs a forward pass on random input
+     cfg = SmolConfig()
+     model = SmolLM2(cfg)
+
+     B, T = 2, 16
+     x = torch.randint(0, cfg.vocab_size, (B, T))
+
+     with torch.no_grad():
+         logits, _ = model(x)
+
+     print("Input shape :", x.shape)
+     print("Logits shape:", logits.shape)  # should be (2, 16, vocab_size)
pyproject.toml ADDED
@@ -0,0 +1,17 @@
+ [project]
+ name = "smollm-135"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "lightning>=2.6.0",
+     "tensorboard>=2.20.0",
+     "torch>=2.9.1",
+     "torchinfo>=1.8.0",
+     "torchmetrics>=1.8.2",
+     "torchsummary>=1.5.1",
+     "torchvision>=0.24.1",
+     "tqdm>=4.67.1",
+     "transformers>=4.57.3",
+ ]
test_model_implementation.py ADDED
@@ -0,0 +1,187 @@
+ import sys
+ import torch
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+ from model import SmolLM2, SmolConfig  # our implementation
+
+
+ PRETRAINED_NAME = "HuggingFaceTB/SmolLM2-135M"
+
+
+ def build_custom_model():
+     """Create our SmolLM2 using the HF config to ensure identical hyperparams."""
+     hf_cfg = AutoConfig.from_pretrained(PRETRAINED_NAME)
+     cfg = SmolConfig.from_hf(hf_cfg)
+     model = SmolLM2(cfg)
+     return model, cfg
+
+
+ def build_hf_model():
+     """Load the reference HF model."""
+     hf_model = AutoModelForCausalLM.from_pretrained(
+         PRETRAINED_NAME,
+         torch_dtype=torch.float32,  # use float32 for easier comparison
+     )
+     hf_model.eval()
+     return hf_model
+
+
+ def load_weights_from_hf(custom_model: SmolLM2, hf_model: AutoModelForCausalLM):
+     """
+     Map HF LlamaForCausalLM weights into our SmolLM2 model.
+
+     - HF model structure: hf_model.model (LlamaModel) + hf_model.lm_head
+     - Our model: embed_tokens, layers, norm, lm_head
+     """
+     hf_state = hf_model.state_dict()
+     custom_state = custom_model.state_dict()
+
+     # 1. Embeddings
+     custom_state["embed_tokens.weight"] = hf_state["model.embed_tokens.weight"]
+
+     # 2. Per-layer mappings
+     num_layers = custom_model.config.num_hidden_layers
+
+     for i in range(num_layers):
+         # Norms
+         custom_state[f"layers.{i}.attn_norm.weight"] = hf_state[
+             f"model.layers.{i}.input_layernorm.weight"
+         ]
+         custom_state[f"layers.{i}.mlp_norm.weight"] = hf_state[
+             f"model.layers.{i}.post_attention_layernorm.weight"
+         ]
+
+         # Attention projections
+         custom_state[f"layers.{i}.attn.q_proj.weight"] = hf_state[
+             f"model.layers.{i}.self_attn.q_proj.weight"
+         ]
+         custom_state[f"layers.{i}.attn.k_proj.weight"] = hf_state[
+             f"model.layers.{i}.self_attn.k_proj.weight"
+         ]
+         custom_state[f"layers.{i}.attn.v_proj.weight"] = hf_state[
+             f"model.layers.{i}.self_attn.v_proj.weight"
+         ]
+         custom_state[f"layers.{i}.attn.o_proj.weight"] = hf_state[
+             f"model.layers.{i}.self_attn.o_proj.weight"
+         ]
+
+         # MLP: HF has gate_proj, up_proj, down_proj
+         gate = hf_state[f"model.layers.{i}.mlp.gate_proj.weight"]
+         up = hf_state[f"model.layers.{i}.mlp.up_proj.weight"]
+         down = hf_state[f"model.layers.{i}.mlp.down_proj.weight"]
+
+         # Our fc1 is [gate; up] concatenated along the output dim (dim=0)
+         custom_state[f"layers.{i}.mlp.fc1.weight"] = torch.cat([gate, up], dim=0)
+         # Our fc2 is down_proj
+         custom_state[f"layers.{i}.mlp.fc2.weight"] = down
+
+     # 3. Final norm
+     custom_state["norm.weight"] = hf_state["model.norm.weight"]
+
+     # 4. LM head (tied with embeddings, but we still load it)
+     custom_state["lm_head.weight"] = hf_state["lm_head.weight"]
+
+     # Now load into the model
+     missing, unexpected = custom_model.load_state_dict(custom_state, strict=False)
+     return missing, unexpected
+
+
+ def test_weight_loading():
+     """
+     1. Build the custom SmolLM2 model (our implementation).
+     2. Build the HF reference model.
+     3. Load HF weights into our model via the mapping.
+     4. Run a small test prompt and compare logits.
+     """
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     print("🟦 Building custom model...")
+     custom_model, cfg = build_custom_model()
+     custom_model.to(device)
+     custom_model.eval()
+
+     print("🟦 Building HF reference model...")
+     hf_model = build_hf_model()
+     hf_model.to(device)
+
+     print("🟦 Mapping HF weights into custom model...")
+     missing, unexpected = load_weights_from_hf(custom_model, hf_model)
+
+     print(f"Missing keys    : {len(missing)}")
+     print(f"Unexpected keys : {len(unexpected)}")
+     if missing:
+         print("  Missing examples:", missing[:5])
+     if unexpected:
+         print("  Unexpected examples:", unexpected[:5])
+
+     if len(missing) > 0:
+         print("⚠️ There are missing keys; the mapping may be incomplete.")
+     else:
+         print("✅ All expected parameters were assigned from HF weights.")
+
+     # 5. Test with a dummy input
+     tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME)
+     prompt = "Hello, how are you?"
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     print("🟦 Running HF model forward...")
+     with torch.no_grad():
+         hf_logits = hf_model(**inputs).logits  # (B, T, V)
+
+     print("🟦 Running custom model forward...")
+     with torch.no_grad():
+         custom_logits, _ = custom_model(inputs["input_ids"])
+
+     # 6. Compare logits
+     # align dtypes
+     hf_logits = hf_logits.to(torch.float32)
+     custom_logits = custom_logits.to(torch.float32)
+
+     diff = torch.abs(hf_logits - custom_logits).max().item()
+     print(f"🔍 Max absolute difference between logits: {diff:.6f}")
+
+     if diff < 1e-4:
+         print("✅ SUCCESS: Outputs match very closely. Implementation is correct.")
+     elif diff < 1e-2:
+         print("🟡 Outputs are close but not identical; check for small implementation differences (e.g., RoPE details).")
+     else:
+         print("❌ Outputs differ significantly. Some part of the implementation is likely off.")
+
+     # 7. Print predictions from both models
+     print("\n📝 Predictions:")
+     print(f"Prompt: '{prompt}'")
+
+     # Get predicted token IDs (argmax on the vocabulary dimension)
+     hf_predicted_ids = hf_logits.argmax(dim=-1)          # (B, T)
+     custom_predicted_ids = custom_logits.argmax(dim=-1)  # (B, T)
+
+     # Get the next-token prediction (last position)
+     hf_next_token_id = hf_predicted_ids[0, -1].item()
+     custom_next_token_id = custom_predicted_ids[0, -1].item()
+
+     # Decode the next token
+     hf_next_token = tokenizer.decode([hf_next_token_id])
+     custom_next_token = tokenizer.decode([custom_next_token_id])
+
+     print(f"HF Model prediction (next token): '{hf_next_token}' (token_id: {hf_next_token_id})")
+     print(f"Custom Model prediction (next token): '{custom_next_token}' (token_id: {custom_next_token_id})")
+
+     # Also show full sequence predictions for comparison
+     hf_full_prediction = tokenizer.decode(hf_predicted_ids[0])
+     custom_full_prediction = tokenizer.decode(custom_predicted_ids[0])
+     print(f"\nHF Model full sequence prediction: '{hf_full_prediction}'")
+     print(f"Custom Model full sequence prediction: '{custom_full_prediction}'")
+
+
+ if __name__ == "__main__":
+     if len(sys.argv) < 2:
+         print("Usage: python test_model_implementation.py test_weight_loading")
+         sys.exit(1)
+
+     mode = sys.argv[1]
+
+     if mode == "test_weight_loading":
+         test_weight_loading()
+     else:
+         print(f"Unknown mode: {mode}")
train.py ADDED
@@ -0,0 +1,360 @@
+ """
+ Training script for SmolLM2-135M using PyTorch Lightning.
+
+ Training strategy from the paper:
+ - AdamW optimizer with (β1, β2) = (0.9, 0.95)
+ - Warmup Stable Decay (WSD) learning rate schedule:
+   - 2,000-step warmup phase
+   - Peak learning rate: 5.0 × 10^-4 (stable phase)
+   - Decay phase: reduce LR to zero over 10% of total training steps
+ """
+
+ import sys
+ import logging
+ from pathlib import Path
+ from datetime import datetime
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ import lightning as L
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
+ from lightning.pytorch.loggers import TensorBoardLogger
+ from transformers import AutoTokenizer, AutoConfig
+
+ from model import SmolLM2, SmolConfig
+
+ # Module-level logger; re-bound in main() once file logging is configured.
+ logger = logging.getLogger(__name__)
+
+ # Setup logging
+ def setup_logging(log_dir: Path):
+     """Set up text file logging."""
+     log_dir.mkdir(parents=True, exist_ok=True)
+     log_file = log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
+
+     logging.basicConfig(
+         level=logging.INFO,
+         format='%(asctime)s - %(levelname)s - %(message)s',
+         handlers=[
+             logging.FileHandler(log_file),
+             logging.StreamHandler(sys.stdout)
+         ]
+     )
+     return logging.getLogger(__name__), log_file
+
+
+ class TextDataset(Dataset):
+     """Dataset for text data."""
+     def __init__(self, text_file: str, tokenizer, block_size: int = 512):
+         self.tokenizer = tokenizer
+         self.block_size = block_size
+
+         # Read and tokenize text
+         with open(text_file, 'r', encoding='utf-8') as f:
+             text = f.read()
+
+         # Tokenize
+         tokens = tokenizer.encode(text, add_special_tokens=False)
+         self.data = torch.tensor(tokens, dtype=torch.long)
+
+     def __len__(self):
+         return len(self.data) - self.block_size
+
+     def __getitem__(self, idx):
+         chunk = self.data[idx:idx + self.block_size + 1]
+         x = chunk[:-1]
+         y = chunk[1:]
+         return x, y
+
+
+ class WarmupStableDecayLR(L.Callback):
+     """
+     Warmup Stable Decay (WSD) learning rate schedule.
+     - Warmup: 2,000 steps in the paper; since we only train for 5,000 steps,
+       we use 20% of total steps as warmup (1,000 steps)
+     - Stable: maintain peak LR
+     - Decay: reduce to zero over 10% of total steps
+     """
+     def __init__(self, warmup_steps: int = 2000, peak_lr: float = 5e-4, total_steps: int = 5000):
+         super().__init__()
+         self.warmup_steps = warmup_steps
+         self.peak_lr = peak_lr
+         self.total_steps = total_steps
+         self.decay_steps = int(0.1 * total_steps)  # 10% of total steps
+         self.stable_steps = total_steps - warmup_steps - self.decay_steps
+
+     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+         current_step = trainer.global_step
+
+         if current_step < self.warmup_steps:
+             # Warmup phase: linear increase
+             lr = self.peak_lr * (current_step / self.warmup_steps)
+         elif current_step < self.warmup_steps + self.stable_steps:
+             # Stable phase: maintain peak LR
+             lr = self.peak_lr
+         else:
+             # Decay phase: linear decrease to zero
+             decay_start = self.warmup_steps + self.stable_steps
+             decay_progress = (current_step - decay_start) / self.decay_steps
+             lr = self.peak_lr * (1.0 - decay_progress)
+
+         # Update the learning rate. pl_module.optimizers() returns a single
+         # (Lightning-wrapped) optimizer or a list of them; normalize to a list.
+         optimizers = pl_module.optimizers()
+         if not isinstance(optimizers, (list, tuple)):
+             optimizers = [optimizers]
+         for opt in optimizers:
+             for param_group in opt.param_groups:
+                 param_group['lr'] = lr
+
+
+ class SmolLM2Module(L.LightningModule):
+     """PyTorch Lightning module for SmolLM2 training."""
+
+     def __init__(
+         self,
+         config: SmolConfig,
+         tokenizer,
+         block_size: int = 512,
+         warmup_steps: int = 2000,
+         peak_lr: float = 5e-4,
+         total_steps: int = 5000,
+         predict_every: int = 500,
+     ):
+         super().__init__()
+         self.save_hyperparameters(ignore=['tokenizer'])
+         self.config = config
+         self.tokenizer = tokenizer
+         self.block_size = block_size
+         self.warmup_steps = warmup_steps
+         self.peak_lr = peak_lr
+         self.total_steps = total_steps
+         self.predict_every = predict_every
+
+         # Initialize model
+         self.model = SmolLM2(config)
+
+         # Loss function
+         self.criterion = nn.CrossEntropyLoss()
+
+         # For generation
+         self.example_prompt = "First Citizen:"
+
+     def forward(self, input_ids, attention_mask=None):
+         logits, present_key_values = self.model(input_ids, attention_mask=attention_mask, use_cache=False)
+         return logits
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self.forward(x)
+
+         # Reshape for loss calculation
+         loss = self.criterion(logits.view(-1, logits.size(-1)), y.view(-1))
+
+         # Logging
+         self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
+
+         # Generate text every predict_every steps
+         if (self.global_step + 1) % self.predict_every == 0:
+             # Log the scalar loss to the text log so it shows up with generations
+             logger.info(f"Step {self.global_step + 1} | train_loss={loss.item():.4f}")
+             self.generate_and_log()
+
+         return loss
+
+     def generate_and_log(self):
+         """Generate text and log it."""
+         self.model.eval()
+         with torch.no_grad():
+             # Tokenize prompt
+             prompt_ids = self.tokenizer.encode(
+                 self.example_prompt,
+                 return_tensors='pt',
+                 add_special_tokens=False
+             ).to(self.device)
+
+             # Generate
+             generated_ids = self.model.generate(
+                 prompt_ids,
+                 max_new_tokens=50,
+                 temperature=0.8,
+                 top_k=50,
+             )
+
+             # Decode
+             generated_text = self.tokenizer.decode(
+                 generated_ids[0].cpu().tolist(),
+                 skip_special_tokens=True
+             )
+
+             # Log to console and file
+             logger.info(f"\n{'='*80}")
+             logger.info(f"Step {self.global_step + 1} - Generated text:")
+             logger.info(f"{generated_text}")
+             logger.info(f"{'='*80}\n")
+
+         self.model.train()
+
+     def configure_optimizers(self):
+         """Configure the AdamW optimizer."""
+         optimizer = torch.optim.AdamW(
+             self.parameters(),
+             lr=self.peak_lr,  # will be adjusted by the scheduler
+             betas=(0.9, 0.95),
+             weight_decay=0.01,
+         )
+
+         # WSD scheduler (implemented as a callback)
+         return optimizer
+
+     def on_train_start(self):
+         """Log a model summary at training start."""
+         # Count parameters
+         total_params = sum(p.numel() for p in self.model.parameters())
+         trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+         logger.info("\n" + "="*80)
+         logger.info("MODEL SUMMARY")
+         logger.info("="*80)
+         logger.info("Model: SmolLM2-135M")
+         logger.info(f"Total parameters: {total_params:,}")
+         logger.info(f"Trainable parameters: {trainable_params:,}")
+         logger.info(f"Block size: {self.block_size}")
+         logger.info(f"Warmup steps: {self.warmup_steps}")
+         logger.info(f"Peak learning rate: {self.peak_lr}")
+         logger.info(f"Total training steps: {self.total_steps}")
+         logger.info(f"Predict every: {self.predict_every} steps")
+         logger.info("="*80 + "\n")
+
+
+ def main():
+     # Configuration
+     data_file = Path("../data/input.txt").resolve()
+     output_dir = Path("./checkpoints")
+     log_dir = Path("./logs")
+     block_size = 512
+     batch_size = 4
+     num_workers = 8
+     max_steps = 3500
+     predict_every = 500
+     resume_from_checkpoint = "checkpoints/smollm2-step=01500-train_loss=3.6240.ckpt"  # set to a checkpoint path to resume, or None for fresh training
+
+     # Training hyperparameters from the paper
+     warmup_steps = 1000
+     peak_lr = 5e-4
+     total_steps = max_steps
+
+     # Setup logging
+     global logger
+     logger, log_file = setup_logging(log_dir)
+     logger.info(f"Logging to: {log_file}")
+
+     # Load tokenizer
+     logger.info("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Allow SmolConfig to be deserialized from Lightning checkpoints when torch.load
+     # uses the weights_only=True default (torch>=2.6). This is safe because the class
+     # is defined locally in this project.
+     try:
+         torch.serialization.add_safe_globals([SmolConfig])  # type: ignore[attr-defined]
+     except Exception:
+         # Fallback for torch versions without add_safe_globals; Lightning will still
+         # load normally when weights_only=False.
+         pass
+
+     # Load the HF config and create the model config
+     logger.info("Loading model config...")
+     hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+     config = SmolConfig.from_hf(hf_config)
+
+     # Create dataset
+     logger.info(f"Loading dataset from: {data_file}")
+     dataset = TextDataset(data_file, tokenizer, block_size=block_size)
+     dataloader = DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=num_workers,
+         pin_memory=True,
+     )
+
+     # Create Lightning module
+     logger.info("Initializing model...")
+     model = SmolLM2Module(
+         config=config,
+         tokenizer=tokenizer,
+         block_size=block_size,
+         warmup_steps=warmup_steps,
+         peak_lr=peak_lr,
+         total_steps=total_steps,
+         predict_every=predict_every,
+     )
+
+     # Additional callback to ensure a checkpoint at the final step
+     class FinalCheckpointCallback(L.Callback):
+         def on_train_end(self, trainer, pl_module):
+             # Save final checkpoint
+             final_checkpoint_path = output_dir / f"smollm2-final-step-{trainer.global_step:05d}.ckpt"
+             trainer.save_checkpoint(str(final_checkpoint_path))
+             logger.info(f"Final checkpoint saved: {final_checkpoint_path}")
+
+     final_checkpoint_callback = FinalCheckpointCallback()
+
+     # Setup callbacks
+     checkpoint_callback = ModelCheckpoint(
+         dirpath=output_dir,
+         filename='smollm2-{step:05d}-{train_loss:.4f}',
+         monitor='train_loss',
+         save_top_k=3,
+         mode='min',
+         every_n_train_steps=predict_every,
+         save_last=True,
+         save_on_train_epoch_end=False,  # save based on steps, not epochs
+     )
+
+     lr_monitor = LearningRateMonitor(logging_interval='step')
+
+     wsd_scheduler = WarmupStableDecayLR(
+         warmup_steps=warmup_steps,
+         peak_lr=peak_lr,
+         total_steps=total_steps,
+     )
+
+     # Setup TensorBoard logger
+     tb_logger = TensorBoardLogger(
+         save_dir=log_dir,
+         name='tensorboard',
+     )
+
+     # Create trainer
+     trainer = L.Trainer(
+         max_steps=max_steps,
+         callbacks=[checkpoint_callback, lr_monitor, wsd_scheduler, final_checkpoint_callback],
+         logger=tb_logger,
+         accelerator='auto',
+         devices='auto',
+         # Set precision depending on device capabilities.
+         # bf16-mixed: CUDA; 32-true: others; MPS supports only 32-true.
+         precision='bf16-mixed' if torch.cuda.is_available() else '32-true',
+         gradient_clip_val=1.0,
+         log_every_n_steps=50,
+         enable_checkpointing=True,
+     )
+
+     # Train
+     logger.info("Starting training...")
+     if resume_from_checkpoint and Path(resume_from_checkpoint).exists():
+         logger.info(f"Resuming from checkpoint: {resume_from_checkpoint}")
+         trainer.fit(model, dataloader, ckpt_path=resume_from_checkpoint)
+     else:
+         trainer.fit(model, dataloader)
+
+     logger.info("Training completed!")
+     logger.info(f"Best checkpoint: {checkpoint_callback.best_model_path}")
+     logger.info(f"Last checkpoint: {checkpoint_callback.last_model_path}")
+
+
+ if __name__ == "__main__":
+     main()
uv.lock ADDED
The diff for this file is too large to render.