Commit c175ce3 by Sualeh Qureshi · Parent(s): 0
Committed the training code and model file
Files changed:
- .gitignore +14 -0
- .python-version +1 -0
- README.md +0 -0
- README_TRAINING.md +110 -0
- logs/tensorboard/version_0/events.out.tfevents.1765268407.MAC-QNYQPC2R2T.88043.0 +0 -0
- logs/tensorboard/version_0/hparams.yaml +5 -0
- logs/tensorboard/version_1/events.out.tfevents.1765274926.MAC-QNYQPC2R2T.7268.0 +0 -0
- logs/tensorboard/version_2/events.out.tfevents.1765275552.MAC-QNYQPC2R2T.7768.0 +0 -0
- logs/tensorboard/version_2/hparams.yaml +5 -0
- logs/training_20251209_135005.log +54 -0
- logs/training_20251209_154910.log +35 -0
- main.py +6 -0
- model.py +589 -0
- pyproject.toml +17 -0
- test_model_implementation.py +187 -0
- train.py +360 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,14 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# Checkpoints
checkpoints/
.python-version
ADDED
@@ -0,0 +1 @@
3.12
README.md
ADDED
File without changes (empty file)
README_TRAINING.md
ADDED
@@ -0,0 +1,110 @@
# SmolLM2-135M Training Guide

This directory contains the training code for the SmolLM2-135M model.

## Files

- `model.py`: Model definition with KV cache support for inference
- `train.py`: Main training script (trains for 5000 steps)
  - Run it with a checkpoint path to resume training for 50 additional steps

## Setup

Install required packages:

```bash
pip install torch lightning transformers tensorboard
```

## Training

### Phase 1: Initial Training (5000 steps)

Run the main training script:

```bash
python train.py
```

This will:
- Train the model for 5000 steps
- Generate text predictions every 500 steps
- Save checkpoints every 500 steps
- Log training metrics to TensorBoard and a text file
- Save the final checkpoint at step 5000

### Phase 2: Resume Training (50 additional steps)

After Phase 1 completes, run:

```bash
python train.py
```

This time, set the checkpoint path and set the step count to 50 additional steps. This simply demonstrates that training resumes from where it stopped (a sketch of one way to wire this up follows the list below).

This will:
- Load the checkpoint from Phase 1
- Train for 50 additional steps
- Save the final checkpoint
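How `train.py` picks up the checkpoint is not spelled out in this README. As a minimal sketch of what resuming typically looks like with PyTorch Lightning (assuming `train.py` guards its entry point under `__main__` so its classes are importable; the checkpoint filename and data path below are illustrative, not taken from the repo):

```python
# Minimal resume sketch, assuming the classes defined in train.py (shown later
# in this commit). Filenames and paths here are hypothetical placeholders.
import lightning as L
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer

from model import SmolConfig
from train import SmolLM2Module, TextDataset

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
cfg = SmolConfig.from_hf(AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M"))

module = SmolLM2Module(cfg, tokenizer, total_steps=5050)  # 5000 saved + 50 extra
dataset = TextDataset("data/input.txt", tokenizer, block_size=512)  # assumed path
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Lightning restores the weights, optimizer state, and global step from the
# checkpoint, so training continues from step 5000 rather than from scratch.
trainer = L.Trainer(max_steps=5050)
trainer.fit(module, loader,
            ckpt_path="checkpoints/smollm2-step=05000-train_loss=1.0000.ckpt")
```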
## Training Configuration

The training uses the following hyperparameters (from the SmolLM2 paper):

- **Optimizer**: AdamW with (β₁, β₂) = (0.9, 0.95)
- **Learning Rate Schedule**: Warmup Stable Decay (WSD)
  - Warmup: 2000 steps in the paper; this run uses 1000 steps (20% of the 5000 total steps, as noted in `train.py` and recorded in the logged hparams)
  - Peak LR: 5.0 × 10⁻⁴
  - Stable phase: maintains peak LR
  - Decay: reduces to zero over the final 10% of total steps
- **Block size**: 512 tokens
- **Batch size**: 4
- **Precision**: bfloat16 (if GPU available), float32 otherwise

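For reference, the per-step learning rate produced by the `WarmupStableDecayLR` callback in `train.py` is equivalent to this small function (same arithmetic, pure-Python sketch):

```python
def wsd_lr(step: int, warmup: int = 1000, peak: float = 5e-4, total: int = 5000) -> float:
    """Warmup Stable Decay LR, mirroring the WarmupStableDecayLR callback in train.py."""
    decay = int(0.1 * total)    # decay over the final 10% of steps
    stable_end = total - decay  # stable phase runs from the end of warmup to here
    if step < warmup:
        return peak * step / warmup                    # linear warmup from 0 to peak
    if step < stable_end:
        return peak                                    # hold at peak LR
    return peak * (1.0 - (step - stable_end) / decay)  # linear decay to zero
```

For example, `wsd_lr(500)` gives 2.5e-4 (halfway through warmup), `wsd_lr(3000)` gives 5e-4, and `wsd_lr(4750)` gives 2.5e-4 (halfway through decay). Note the warmup default here is this run's 1000 steps, not the paper's 2000.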
## Outputs

- **Checkpoints**: Saved in `./checkpoints/`
- **TensorBoard logs**: Saved in `./logs/tensorboard/`
- **Text logs**: Saved in `./logs/training_*.log`

## Model Features

The model includes:
- **KV Cache**: Efficient inference using key-value caching
- **Generation**: Text generation with top-k and top-p sampling
- **Checkpointing**: Full state saving for resuming training

## Usage Example

```python
import glob

import torch
from transformers import AutoTokenizer, AutoConfig

from model import SmolLM2, SmolConfig

# Load config
hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = SmolConfig.from_hf(hf_config)

# Create model
model = SmolLM2(config)

# Load checkpoint (resolve the wildcard first; torch.load cannot open a glob pattern)
ckpt_path = sorted(glob.glob("checkpoints/smollm2-00500-*.ckpt"))[-1]
checkpoint = torch.load(ckpt_path, map_location="cpu")

# train.py wraps the network as self.model inside the Lightning module, so the
# saved weights carry a "model." prefix; strip it before loading into SmolLM2.
state_dict = {k.removeprefix("model."): v for k, v in checkpoint["state_dict"].items()}
model.load_state_dict(state_dict)

# Generate text
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
prompt = "First Citizen:"
input_ids = tokenizer.encode(prompt, return_tensors='pt')

generated_ids = model.generate(
    input_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
)

generated_text = tokenizer.decode(generated_ids[0])
print(generated_text)
```
logs/tensorboard/version_0/events.out.tfevents.1765268407.MAC-QNYQPC2R2T.88043.0
ADDED
Binary file (5.59 kB)
logs/tensorboard/version_0/hparams.yaml
ADDED
@@ -0,0 +1,5 @@
block_size: 512
peak_lr: 0.0005
predict_every: 500
total_steps: 5000
warmup_steps: 1000
logs/tensorboard/version_1/events.out.tfevents.1765274926.MAC-QNYQPC2R2T.7268.0
ADDED
Binary file (88 Bytes)
logs/tensorboard/version_2/events.out.tfevents.1765275552.MAC-QNYQPC2R2T.7768.0
ADDED
Binary file (2.8 kB)
logs/tensorboard/version_2/hparams.yaml
ADDED
@@ -0,0 +1,5 @@
block_size: 512
peak_lr: 0.0005
predict_every: 500
total_steps: 3500
warmup_steps: 1000
logs/training_20251209_135005.log
ADDED
@@ -0,0 +1,54 @@
2025-12-09 13:50:05,106 - INFO - Logging to: logs/training_20251209_135005.log
2025-12-09 13:50:05,106 - INFO - Loading tokenizer...
2025-12-09 13:50:05,965 - INFO - Loading model config...
2025-12-09 13:50:06,205 - INFO - Loading dataset from: /Users/qureshsu/Learning/TSAI/ERAV4/session13/data/input.txt
2025-12-09 13:50:06,657 - INFO - Initializing model...
2025-12-09 13:50:07,391 - INFO - Starting training...
2025-12-09 13:50:24,556 - INFO -
================================================================================
2025-12-09 13:50:24,557 - INFO - MODEL SUMMARY
2025-12-09 13:50:24,557 - INFO - ================================================================================
2025-12-09 13:50:24,557 - INFO - Model: SmolLM2-135M
2025-12-09 13:50:24,557 - INFO - Total parameters: 134,515,008
2025-12-09 13:50:24,557 - INFO - Trainable parameters: 134,515,008
2025-12-09 13:50:24,557 - INFO - Block size: 512
2025-12-09 13:50:24,557 - INFO - Warmup steps: 1000
2025-12-09 13:50:24,557 - INFO - Peak learning rate: 0.0005
2025-12-09 13:50:24,557 - INFO - Total training steps: 5000
2025-12-09 13:50:24,557 - INFO - Predict every: 500 steps
2025-12-09 13:50:24,557 - INFO - ================================================================================

2025-12-09 14:05:59,075 - INFO -
================================================================================
2025-12-09 14:05:59,081 - INFO - Step 500 - Generated text:
2025-12-09 14:05:59,081 - INFO - First Citizen:
WhatONEONE:
DUKE VINCENTIO:
DUKE VINCENTIO:
Nay, thou art thou pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow pow
2025-12-09 14:05:59,081 - INFO - ================================================================================

2025-12-09 14:21:21,767 - INFO -
================================================================================
2025-12-09 14:21:21,771 - INFO - Step 1000 - Generated text:
2025-12-09 14:21:21,771 - INFO - First Citizen:
And then, like thee: thou hast thou dost in thy husband'st:
And in thy soldiers, not in thy master's name,
Which then in thy shame: I did thy shame,
Which thou doth know her
2025-12-09 14:21:21,771 - INFO - ================================================================================

2025-12-09 14:37:17,744 - INFO -
================================================================================
2025-12-09 14:37:17,748 - INFO - Step 1500 - Generated text:
2025-12-09 14:37:17,748 - INFO - First Citizen:
I have done a'rt too that, if the king had title to the
Where it shall be the is born to be in the tongue.

Second Citizen:
And so shall I.

ANTONIO:
I
2025-12-09 14:37:17,748 - INFO - ================================================================================
logs/training_20251209_154910.log
ADDED
@@ -0,0 +1,35 @@
2025-12-09 15:49:10,023 - INFO - Logging to: logs/training_20251209_154910.log
2025-12-09 15:49:10,023 - INFO - Loading tokenizer...
2025-12-09 15:49:10,936 - INFO - Loading model config...
2025-12-09 15:49:11,184 - INFO - Loading dataset from: /Users/qureshsu/Learning/TSAI/ERAV4/session13/data/input.txt
2025-12-09 15:49:11,623 - INFO - Initializing model...
2025-12-09 15:49:12,354 - INFO - Starting training...
2025-12-09 15:49:12,357 - INFO - Resuming from checkpoint: checkpoints/smollm2-step=01500-train_loss=3.6240.ckpt
2025-12-09 15:49:30,901 - INFO -
================================================================================
2025-12-09 15:49:30,901 - INFO - MODEL SUMMARY
2025-12-09 15:49:30,901 - INFO - ================================================================================
2025-12-09 15:49:30,901 - INFO - Model: SmolLM2-135M
2025-12-09 15:49:30,901 - INFO - Total parameters: 134,515,008
2025-12-09 15:49:30,901 - INFO - Trainable parameters: 134,515,008
2025-12-09 15:49:30,901 - INFO - Block size: 512
2025-12-09 15:49:30,901 - INFO - Warmup steps: 1000
2025-12-09 15:49:30,901 - INFO - Peak learning rate: 0.0005
2025-12-09 15:49:30,901 - INFO - Total training steps: 3500
2025-12-09 15:49:30,901 - INFO - Predict every: 500 steps
2025-12-09 15:49:30,901 - INFO - ================================================================================

2025-12-09 15:59:45,441 - INFO - Step 2000 | train_loss=0.9070
2025-12-09 15:59:47,487 - INFO -
================================================================================
2025-12-09 15:59:47,487 - INFO - Step 2000 - Generated text:
2025-12-09 15:59:47,488 - INFO - First Citizen:
Why, no; but the Hortenspur, and
To perricks. Thou art said so when a king
Hadst thouable to be ruled, and not to forget
At any man.

First Citizen:
None,
2025-12-09 15:59:47,488 - INFO - ================================================================================
main.py
ADDED
@@ -0,0 +1,6 @@
def main():
    print("Hello from smollm-135!")


if __name__ == "__main__":
    main()
model.py
ADDED
@@ -0,0 +1,589 @@
# Minimal SmolLM2-135M style model implemented in PyTorch.
# Architecture: LLaMA-style decoder-only Transformer with:
# - RMSNorm
# - RoPE positional encoding
# - SwiGLU MLP
# - Grouped (GQA/MQA) attention: num_attention_heads != num_key_value_heads
#
# This file is self-contained (except PyTorch) and can be used as:
#
#   from transformers import AutoConfig
#   from model import SmolConfig, SmolLM2
#
#   cfg = SmolConfig.from_hf(AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M"))
#   model = SmolLM2(cfg)

from dataclasses import dataclass
from typing import Optional, Tuple, List

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# 1. Config
#
# Obtained from HuggingFace using: transformers.AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
#
# Config: SmolLM2-135M
#
# LlamaConfig {
#   "architectures": [
#     "LlamaForCausalLM"
#   ],
#   "attention_bias": false,
#   "attention_dropout": 0.0,
#   "bos_token_id": 0,
#   "dtype": "bfloat16",
#   "eos_token_id": 0,
#   "head_dim": 64,
#   "hidden_act": "silu",
#   "hidden_size": 576,
#   "initializer_range": 0.041666666666666664,
#   "intermediate_size": 1536,
#   "is_llama_config": true,
#   "max_position_embeddings": 8192,
#   "mlp_bias": false,
#   "model_type": "llama",
#   "num_attention_heads": 9,
#   "num_hidden_layers": 30,
#   "num_key_value_heads": 3,
#   "pretraining_tp": 1,
#   "rms_norm_eps": 1e-05,
#   "rope_interleaved": false,
#   "rope_scaling": null,
#   "rope_theta": 100000,
#   "tie_word_embeddings": true,
#   "transformers_version": "4.57.3",
#   "use_cache": true,
#   "vocab_size": 49152
# }
# =========================

@dataclass
class SmolConfig:
    # Core dimensions
    vocab_size: int = 49152              # from HF config
    hidden_size: int = 576               # "hidden_size"
    intermediate_size: int = 1536        # "intermediate_size"
    num_hidden_layers: int = 30          # "num_hidden_layers"
    num_attention_heads: int = 9         # "num_attention_heads"
    num_key_value_heads: int = 3         # "num_key_value_heads"
    max_position_embeddings: int = 8192  # "max_position_embeddings"

    # Positional / RoPE
    rope_theta: float = 100000.0         # "rope_theta"

    # Norm / numerical
    rms_norm_eps: float = 1e-5           # "rms_norm_eps"

    # Biases
    attention_bias: bool = False         # "attention_bias"
    mlp_bias: bool = False               # "mlp_bias"

    # Misc
    dtype: torch.dtype = torch.bfloat16

    @property
    def head_dim(self) -> int:
        # Should be 64 for SmolLM2-135M (576 / 9).
        return self.hidden_size // self.num_attention_heads  # 576 / 9 = 64

    @classmethod
    def from_hf(cls, hf_config) -> "SmolConfig":
        """
        Helper to build this config from a transformers LlamaConfig (the config
        class used by the HuggingFace SmolLM2-135M model).
        Example:
            from transformers import AutoConfig
            hf = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
            cfg = SmolConfig.from_hf(hf)
        Then pass the resulting config to SmolLM2(cfg).
        """
        return cls(
            vocab_size=hf_config.vocab_size,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            num_hidden_layers=hf_config.num_hidden_layers,
            num_attention_heads=hf_config.num_attention_heads,
            num_key_value_heads=getattr(hf_config, "num_key_value_heads",
                                        hf_config.num_attention_heads),
            max_position_embeddings=hf_config.max_position_embeddings,
            rope_theta=getattr(hf_config, "rope_theta", 10000.0),
            rms_norm_eps=hf_config.rms_norm_eps,
            attention_bias=getattr(hf_config, "attention_bias", False),
            mlp_bias=getattr(hf_config, "mlp_bias", False),
            dtype=torch.bfloat16,  # SmolLM2 uses bfloat16
        )

# =========================
# 2. RMSNorm
# =========================

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).
    Used in LLaMA / SmolLM2 instead of LayerNorm.
    """
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., dim)
        # rms = sqrt(mean(x^2)); use rsqrt for stability
        norm = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(norm + self.eps)
        return self.weight * x

# =========================
# 3. RoPE (Rotary Positional Embeddings)
# =========================

def rope_freqs(head_dim: int, base: float, device, dtype):
    """
    Compute inverse frequencies for RoPE.
    """
    half_dim = head_dim // 2
    # Equivalent to: base^{-2i/d}
    freq_seq = torch.arange(half_dim, device=device, dtype=dtype)
    inv_freq = 1.0 / (base ** (freq_seq / half_dim))
    return inv_freq  # shape: (half_dim,)

def build_rope_cache(
    seq_len: int,
    head_dim: int,
    base: float,
    device,
    dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build cosine and sine caches for RoPE.
    Returns:
        cos: (1, 1, seq_len, head_dim/2)
        sin: (1, 1, seq_len, head_dim/2)
    """
    inv_freq = rope_freqs(head_dim, base, device, dtype)   # (half_dim,)
    # Positions
    t = torch.arange(seq_len, device=device, dtype=dtype)  # (seq_len,)
    freqs = torch.outer(t, inv_freq)                       # (seq_len, half_dim)
    cos = freqs.cos()[None, None, :, :]                    # (1,1,seq_len,half_dim)
    sin = freqs.sin()[None, None, :, :]                    # (1,1,seq_len,half_dim)
    return cos, sin

def apply_rope(
    x: torch.Tensor,  # (B, n_head, T, head_dim)
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply RoPE to the last dimension of x.
    cos, sin are broadcast to match (..., head_dim/2).
    """
    b, h, t, d = x.shape
    half = d // 2

    x1 = x[..., :half]  # (B, n_head, T, head_dim/2)
    x2 = x[..., half:]  # (B, n_head, T, head_dim/2)

    # cos/sin: (1,1,T,half) -> broadcast over B,h
    cos_t = cos[..., :t, :]
    sin_t = sin[..., :t, :]

    x1_rot = x1 * cos_t - x2 * sin_t
    x2_rot = x1 * sin_t + x2 * cos_t

    return torch.cat([x1_rot, x2_rot], dim=-1)  # (B, n_head, T, head_dim)

# =========================
# 4. Attention
# =========================

class MultiHeadSelfAttention(nn.Module):
    """
    LLaMA / SmolLM2-style attention with:
    - Q heads = num_attention_heads
    - K/V heads = num_key_value_heads (GQA/MQA)
    - RoPE on Q and K
    - Causal masking
    """
    def __init__(self, config: SmolConfig):
        super().__init__()

        self.config = config
        self.n_heads = config.num_attention_heads     # 9
        self.n_kv_heads = config.num_key_value_heads  # 3
        self.head_dim = config.head_dim               # 64
        self.hidden_size = config.hidden_size         # 576

        assert self.hidden_size == self.n_heads * self.head_dim

        # Projections
        self.q_proj = nn.Linear(
            self.hidden_size,
            self.n_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.n_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.n_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )

        self.o_proj = nn.Linear(
            self.n_heads * self.head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

    def forward(
        self,
        x: torch.Tensor,    # (B, T, C) or (B, 1, C) for inference
        cos: torch.Tensor,  # (1,1,T,head_dim/2) or (1,1,1,head_dim/2) for inference
        sin: torch.Tensor,  # (1,1,T,head_dim/2) or (1,1,1,head_dim/2) for inference
        attention_mask: Optional[torch.Tensor] = None,  # (B, T) or (B,1,1,T)
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # (k_cache, v_cache)
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        B, T, C = x.shape

        # Projections: (B,T,C) -> (B,T,h*d) -> (B,T,h,d) -> (B,h,T,d)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K
        q = apply_rope(q, cos, sin)  # (B, n_heads, T, d)
        k = apply_rope(k, cos, sin)  # (B, n_kv_heads, T, d)
        # v doesn't need RoPE

        # If using KV cache, concatenate with past keys/values
        if past_key_value is not None:
            past_k, past_v = past_key_value
            # past_k, past_v: (B, n_kv_heads, past_len, head_dim)
            k = torch.cat([past_k, k], dim=2)  # (B, n_kv_heads, past_len + T, head_dim)
            v = torch.cat([past_v, v], dim=2)  # (B, n_kv_heads, past_len + T, head_dim)
            seq_len = k.shape[2]
        else:
            seq_len = T

        # Store k, v for cache (before GQA expansion)
        k_cache = k  # (B, n_kv_heads, seq_len, head_dim)
        v_cache = v  # (B, n_kv_heads, seq_len, head_dim)

        # GQA: expand K/V if num_kv_heads < num_heads
        if self.n_kv_heads != self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)  # -> (B, n_heads, seq_len, d)
            v = v.repeat_interleave(repeat_factor, dim=1)  # -> (B, n_heads, seq_len, d)

        # Attention scores: (B,h,T,d) @ (B,h,d,seq_len) -> (B,h,T,seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: prevent attending to future tokens.
        # For inference with KV cache, we only need to handle the current position.
        if past_key_value is None:
            # Full sequence: mask all future positions
            causal_mask = torch.full(
                (T, T), float("-inf"), device=x.device, dtype=x.dtype
            ).triu(1)  # upper triangle (i < j)
            scores = scores + causal_mask.unsqueeze(0).unsqueeze(0)  # (B,h,T,T) + (1,1,T,T)
        else:
            # With KV cache we generate one token at a time (T=1) and attend to
            # all past positions plus the current one, so no extra mask is needed.
            pass

        # Optional attention mask (e.g., padding). Should be additive (0 or -inf).
        if attention_mask is not None:
            # Expect attention_mask as (B, 1, 1, seq_len) or (B, seq_len)
            if attention_mask.dim() == 2:
                # (B, seq_len) -> (B,1,1,seq_len)
                attention_mask = attention_mask[:, None, None, :]
            # Adjust mask shape if needed
            if attention_mask.shape[-1] != seq_len:
                # For inference, we might need to extend the mask
                if past_key_value is not None:
                    # Extend mask to include past positions (all 0s for past, current mask for new token)
                    past_len = past_k.shape[2]
                    extended_mask = torch.zeros(B, 1, 1, seq_len, device=attention_mask.device, dtype=attention_mask.dtype)
                    extended_mask[..., past_len:] = attention_mask[..., -T:]
                    attention_mask = extended_mask
            scores = scores + attention_mask

        # Softmax over last dim (seq_len)
        probs = F.softmax(scores, dim=-1)  # (B,h,T,seq_len)

        # Weighted sum of values
        out = torch.matmul(probs, v)  # (B,h,T,seq_len) @ (B,h,seq_len,d) -> (B,h,T,d)

        # Reshape back: (B,h,T,d) -> (B,T,h,d) -> (B,T,h*d) = (B,T,C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.o_proj(out)  # (B,T,C) -> (B,T,C)

        # Return output and optionally the new KV cache
        present_key_value = None
        if use_cache:
            # Return k_cache, v_cache (before GQA expansion, after RoPE)
            present_key_value = (k_cache, v_cache)

        return out, present_key_value

# =========================
# 5. MLP (SwiGLU)
# =========================

class SmolMLP(nn.Module):
    """
    SwiGLU MLP:
        z = W1(x) -> split -> (x1, x2)
        out = W2( SiLU(x1) * x2 )
    """
    def __init__(self, config: SmolConfig):
        super().__init__()

        self.fc1 = nn.Linear(
            config.hidden_size,
            2 * config.intermediate_size,  # for SwiGLU split (2 x 1536 = 3072)
            bias=config.mlp_bias,
        )

        self.fc2 = nn.Linear(
            config.intermediate_size,  # 1536
            config.hidden_size,        # 576
            bias=config.mlp_bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)              # (B,T,C) -> (B,T,2*intermediate_size) = (B,T,3072)
        x1, x2 = x.chunk(2, dim=-1)  # (B,T,3072) -> (B,T,1536), (B,T,1536)
        return self.fc2(F.silu(x1) * x2)  # (B,T,1536) -> (B,T,hidden_size) = (B,T,576)


# =========================
# 6. Transformer Block
# =========================

class SmolBlock(nn.Module):
    def __init__(self, config: SmolConfig):
        super().__init__()
        self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = MultiHeadSelfAttention(config)
        self.mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SmolMLP(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Pre-norm + residual for attention
        attn_out, present_key_value = self.attn(
            self.attn_norm(x), cos, sin, attention_mask, past_key_value, use_cache
        )
        x = x + attn_out
        # Pre-norm + residual for MLP
        x = x + self.mlp(self.mlp_norm(x))
        return x, present_key_value

# =============================================
# 7. Top-level SmolLM2-135M Model Architecture
# SmolLM2 follows the LLaMA-style decoder-only Transformer architecture.
# =============================================

class SmolLM2(nn.Module):
    """
    SmolLM2-135M-style LLaMA decoder-only language model.

    Usage:
        cfg = SmolConfig()
        model = SmolLM2(cfg)

        input_ids: LongTensor (B, T)
        logits = model(input_ids)
    """
    def __init__(self, config: SmolConfig):
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
        )  # (vocab_size, hidden_size) = (49152, 576)

        self.layers = nn.ModuleList(
            [SmolBlock(config) for _ in range(config.num_hidden_layers)]
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )  # (hidden_size, vocab_size) = (576, 49152)

        # tie weights
        self.lm_head.weight = self.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.Tensor,  # (B, T)
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        B, T = input_ids.shape

        # For inference with KV cache, we might have T=1
        if past_key_values is None:
            assert T <= self.config.max_position_embeddings, (
                f"Sequence length {T} exceeds max_position_embeddings "
                f"{self.config.max_position_embeddings}"
            )
            seq_len = T
        else:
            # With KV cache, current sequence length is past_len + T
            past_len = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
            seq_len = past_len + T
            assert seq_len <= self.config.max_position_embeddings, (
                f"Total sequence length {seq_len} exceeds max_position_embeddings "
                f"{self.config.max_position_embeddings}"
            )

        # Embedding
        x = self.embed_tokens(input_ids)  # (B,T) -> (B,T,C)

        # RoPE cache - build for the full sequence length (past + current)
        cos, sin = build_rope_cache(
            seq_len=seq_len,
            head_dim=self.config.head_dim,
            base=self.config.rope_theta,
            device=x.device,
            dtype=x.dtype,
        )

        # If using KV cache, we only need cos/sin for the current positions
        if past_key_values is not None:
            past_len = past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
            # Slice to get only the current positions for RoPE
            cos = cos[..., past_len:, :]
            sin = sin[..., past_len:, :]

        # Layers
        present_key_values = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = layer(x, cos, sin, attention_mask, past_kv, use_cache)
            if use_cache:
                present_key_values.append(present_kv)

        # Final norm + lm head
        x = self.norm(x)
        logits = self.lm_head(x)  # (B,T,C) -> (B,T,vocab_size)
        return logits, present_key_values

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate text using the KV cache for efficient inference.

        Args:
            input_ids: (B, T) input token ids
            max_new_tokens: maximum number of new tokens to generate
            temperature: sampling temperature
            top_k: top-k sampling (keep top k tokens)
            top_p: nucleus sampling (keep tokens with cumulative probability <= top_p)
            eos_token_id: end-of-sequence token id (stop generation when encountered)

        Returns:
            generated_ids: (B, T + max_new_tokens) generated token ids
        """
        self.eval()
        device = input_ids.device
        B, T = input_ids.shape

        # Start with input_ids
        generated_ids = input_ids.clone()
        past_key_values = None

        for step in range(max_new_tokens):
            # Forward pass with KV cache.
            # On the first iteration, use the full input_ids; afterwards, only the last token.
            if past_key_values is None:
                # First iteration: process full sequence
                current_input = generated_ids
            else:
                # Subsequent iterations: only process the last generated token
                current_input = generated_ids[:, -1:]

            logits, past_key_values = self.forward(
                input_ids=current_input,
                past_key_values=past_key_values,
                use_cache=True,
            )

            # Get logits for the last token (always the last position in logits)
            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            if top_k is not None:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')

            # Apply top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append to generated sequence
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Check for EOS token
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated_ids

# =========================
# 8. Quick self-test
# =========================

if __name__ == "__main__":
    # Tiny sanity check: runs a forward pass on random input
    cfg = SmolConfig()
    model = SmolLM2(cfg)

    B, T = 2, 16
    x = torch.randint(0, cfg.vocab_size, (B, T))

    with torch.no_grad():
        logits, _ = model(x)

    print("Input shape :", x.shape)
    print("Logits shape:", logits.shape)  # should be (2, 16, vocab_size)
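To make the cache plumbing above concrete, here is a minimal sketch of incremental decoding against this `forward` API (greedy `argmax` is used for brevity; `generate()` above does temperature/top-k/top-p sampling):

```python
# Minimal sketch of incremental decoding with the KV-cache API above.
import torch
from model import SmolConfig, SmolLM2

model = SmolLM2(SmolConfig())
model.eval()

ids = torch.randint(0, model.config.vocab_size, (1, 8))  # dummy prompt

with torch.no_grad():
    # Prefill: run the full prompt once and keep the per-layer (k, v) caches.
    logits, cache = model(ids, use_cache=True)
    for _ in range(5):
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        ids = torch.cat([ids, next_id], dim=1)
        # Decode step: feed only the new token; attention reads past keys/values
        # from the cache instead of recomputing them for the whole sequence.
        logits, cache = model(next_id, past_key_values=cache, use_cache=True)

print(ids.shape)  # torch.Size([1, 13])
```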
pyproject.toml
ADDED
@@ -0,0 +1,17 @@
[project]
name = "smollm-135"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "lightning>=2.6.0",
    "tensorboard>=2.20.0",
    "torch>=2.9.1",
    "torchinfo>=1.8.0",
    "torchmetrics>=1.8.2",
    "torchsummary>=1.5.1",
    "torchvision>=0.24.1",
    "tqdm>=4.67.1",
    "transformers>=4.57.3",
]
test_model_implementation.py
ADDED
@@ -0,0 +1,187 @@
import sys
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from model import SmolLM2, SmolConfig  # your implementation


PRETRAINED_NAME = "HuggingFaceTB/SmolLM2-135M"


def build_custom_model():
    """Create our SmolLM2 using the HF config to ensure identical hyperparams."""
    hf_cfg = AutoConfig.from_pretrained(PRETRAINED_NAME)
    cfg = SmolConfig.from_hf(hf_cfg)
    model = SmolLM2(cfg)
    return model, cfg


def build_hf_model():
    """Load the reference HF model."""
    hf_model = AutoModelForCausalLM.from_pretrained(
        PRETRAINED_NAME,
        torch_dtype=torch.float32,  # use float32 for easier comparison
    )
    hf_model.eval()
    return hf_model


def load_weights_from_hf(custom_model: SmolLM2, hf_model: AutoModelForCausalLM):
    """
    Map HF LlamaForCausalLM weights into our SmolLM2 model.

    - HF model structure: hf_model.model (LlamaModel) + hf_model.lm_head
    - Our model: embed_tokens, layers, norm, lm_head
    """
    hf_state = hf_model.state_dict()
    custom_state = custom_model.state_dict()

    # 1. Embeddings
    custom_state["embed_tokens.weight"] = hf_state["model.embed_tokens.weight"]

    # 2. Per-layer mappings
    num_layers = custom_model.config.num_hidden_layers

    for i in range(num_layers):
        # Norms
        custom_state[f"layers.{i}.attn_norm.weight"] = hf_state[
            f"model.layers.{i}.input_layernorm.weight"
        ]
        custom_state[f"layers.{i}.mlp_norm.weight"] = hf_state[
            f"model.layers.{i}.post_attention_layernorm.weight"
        ]

        # Attention projections
        custom_state[f"layers.{i}.attn.q_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.q_proj.weight"
        ]
        custom_state[f"layers.{i}.attn.k_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.k_proj.weight"
        ]
        custom_state[f"layers.{i}.attn.v_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.v_proj.weight"
        ]
        custom_state[f"layers.{i}.attn.o_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.o_proj.weight"
        ]

        # MLP: HF has gate_proj, up_proj, down_proj
        gate = hf_state[f"model.layers.{i}.mlp.gate_proj.weight"]
        up = hf_state[f"model.layers.{i}.mlp.up_proj.weight"]
        down = hf_state[f"model.layers.{i}.mlp.down_proj.weight"]

        # Our fc1 is [gate; up] concatenated along the output dim (dim=0)
        custom_state[f"layers.{i}.mlp.fc1.weight"] = torch.cat([gate, up], dim=0)
        # Our fc2 is down_proj
        custom_state[f"layers.{i}.mlp.fc2.weight"] = down

    # 3. Final norm
    custom_state["norm.weight"] = hf_state["model.norm.weight"]

    # 4. LM head (tied with embeddings, but we still load it)
    custom_state["lm_head.weight"] = hf_state["lm_head.weight"]

    # Now load into the model
    missing, unexpected = custom_model.load_state_dict(custom_state, strict=False)
    return missing, unexpected


def test_weight_loading():
    """
    1. Build the custom SmolLM2 model (our implementation).
    2. Build the HF reference model.
    3. Load HF weights into our model via the mapping.
    4. Run a small test prompt and compare logits.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print("🟦 Building custom model...")
    custom_model, cfg = build_custom_model()
    custom_model.to(device)
    custom_model.eval()

    print("🟦 Building HF reference model...")
    hf_model = build_hf_model()
    hf_model.to(device)

    print("🟦 Mapping HF weights into custom model...")
    missing, unexpected = load_weights_from_hf(custom_model, hf_model)

    print(f"Missing keys    : {len(missing)}")
    print(f"Unexpected keys : {len(unexpected)}")
    if missing:
        print("  Missing examples:", missing[:5])
    if unexpected:
        print("  Unexpected examples:", unexpected[:5])

    if len(missing) > 0:
        print("⚠️ There are missing keys; mapping may be incomplete.")
    else:
        print("✅ All expected parameters were assigned from HF weights.")

    # 5. Test with a dummy input
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME)
    prompt = "Hello, how are you?"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    print("🟦 Running HF model forward...")
    with torch.no_grad():
        hf_logits = hf_model(**inputs).logits  # (B, T, V)

    print("🟦 Running custom model forward...")
    with torch.no_grad():
        custom_logits, _ = custom_model(inputs["input_ids"])

    # 6. Compare logits
    # align dtypes
    hf_logits = hf_logits.to(torch.float32)
    custom_logits = custom_logits.to(torch.float32)

    diff = torch.abs(hf_logits - custom_logits).max().item()
    print(f"🔍 Max absolute difference between logits: {diff:.6f}")

    if diff < 1e-4:
        print("✅ SUCCESS: Outputs match very closely. Implementation is correct.")
    elif diff < 1e-2:
        print("🟡 Outputs are close but not identical; check for small implementation differences (e.g., RoPE details).")
    else:
        print("❌ Outputs differ significantly. Some part of the implementation is likely off.")

    # 7. Print predictions from both models
    print("\n📝 Predictions:")
    print(f"Prompt: '{prompt}'")

    # Get predicted token IDs (argmax on the vocabulary dimension)
    hf_predicted_ids = hf_logits.argmax(dim=-1)          # (B, T)
    custom_predicted_ids = custom_logits.argmax(dim=-1)  # (B, T)

    # Get the next-token prediction (last position)
    hf_next_token_id = hf_predicted_ids[0, -1].item()
    custom_next_token_id = custom_predicted_ids[0, -1].item()

    # Decode the next token
    hf_next_token = tokenizer.decode([hf_next_token_id])
    custom_next_token = tokenizer.decode([custom_next_token_id])

    print(f"HF Model prediction (next token): '{hf_next_token}' (token_id: {hf_next_token_id})")
    print(f"Custom Model prediction (next token): '{custom_next_token}' (token_id: {custom_next_token_id})")

    # Also show full sequence predictions for comparison
    hf_full_prediction = tokenizer.decode(hf_predicted_ids[0])
    custom_full_prediction = tokenizer.decode(custom_predicted_ids[0])
    print(f"\nHF Model full sequence prediction: '{hf_full_prediction}'")
    print(f"Custom Model full sequence prediction: '{custom_full_prediction}'")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python test_model_implementation.py test_weight_loading")
        sys.exit(1)

    mode = sys.argv[1]

    if mode == "test_weight_loading":
        test_weight_loading()
    else:
        print(f"Unknown mode: {mode}")
train.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for SmolLM2-135M using PyTorch Lightning.
|
| 3 |
+
|
| 4 |
+
Training strategy from paper:
|
| 5 |
+
- AdamW optimizer with (β1, β2) = (0.9, 0.95)
|
| 6 |
+
- Warmup Stable Decay (WSD) learning rate schedule:
|
| 7 |
+
- 2,000-step warmup phase
|
| 8 |
+
- Peak learning rate: 5.0 × 10^-4 (stable phase)
|
| 9 |
+
- Decay phase: reduce LR to zero over 10% of total training steps
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from torch.utils.data import Dataset, DataLoader
|
| 20 |
+
import lightning as L
|
| 21 |
+
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
|
| 22 |
+
from lightning.pytorch.loggers import TensorBoardLogger
|
| 23 |
+
from transformers import AutoTokenizer, AutoConfig
|
| 24 |
+
|
| 25 |
+
from model import SmolLM2, SmolConfig
|
| 26 |
+
|
| 27 |
+
# Setup logging
|
| 28 |
+
def setup_logging(log_dir: Path):
|
| 29 |
+
"""Setup text file logging."""
|
| 30 |
+
log_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
+
log_file = log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
level=logging.INFO,
|
| 35 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 36 |
+
handlers=[
|
| 37 |
+
logging.FileHandler(log_file),
|
| 38 |
+
logging.StreamHandler(sys.stdout)
|
| 39 |
+
]
|
| 40 |
+
)
|
| 41 |
+
return logging.getLogger(__name__), log_file
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TextDataset(Dataset):
|
| 45 |
+
"""Dataset for text data."""
|
| 46 |
+
def __init__(self, text_file: str, tokenizer, block_size: int = 512):
|
| 47 |
+
self.tokenizer = tokenizer
|
| 48 |
+
self.block_size = block_size
|
| 49 |
+
|
| 50 |
+
# Read and tokenize text
|
| 51 |
+
with open(text_file, 'r', encoding='utf-8') as f:
|
| 52 |
+
text = f.read()
|
| 53 |
+
|
| 54 |
+
# Tokenize
|
| 55 |
+
tokens = tokenizer.encode(text, add_special_tokens=False)
|
| 56 |
+
self.data = torch.tensor(tokens, dtype=torch.long)
|
| 57 |
+
|
| 58 |
+
def __len__(self):
|
| 59 |
+
return len(self.data) - self.block_size
|
| 60 |
+
|
| 61 |
+
def __getitem__(self, idx):
|
| 62 |
+
chunk = self.data[idx:idx + self.block_size + 1]
|
| 63 |
+
x = chunk[:-1]
|
| 64 |
+
y = chunk[1:]
|
| 65 |
+
return x, y
|
| 66 |
+
|
| 67 |
+
|
class WarmupStableDecayLR(L.Callback):
    """
    Warmup Stable Decay (WSD) learning rate schedule.
    - Warmup: 2,000 steps in the paper; since this run is only a few thousand
      steps long, main() passes a shorter warmup instead.
    - Stable: hold the peak LR.
    - Decay: reduce the LR linearly to zero over the final 10% of total steps.
    """
    def __init__(self, warmup_steps: int = 2000, peak_lr: float = 5e-4, total_steps: int = 5000):
        super().__init__()
        self.warmup_steps = warmup_steps
        self.peak_lr = peak_lr
        self.total_steps = total_steps
        self.decay_steps = int(0.1 * total_steps)  # 10% of total steps
        self.stable_steps = total_steps - warmup_steps - self.decay_steps

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        current_step = trainer.global_step

        if current_step < self.warmup_steps:
            # Warmup phase: linear increase
            lr = self.peak_lr * (current_step / self.warmup_steps)
        elif current_step < self.warmup_steps + self.stable_steps:
            # Stable phase: maintain peak LR
            lr = self.peak_lr
        else:
            # Decay phase: linear decrease to zero, clamped at zero so a run
            # resumed past total_steps cannot produce a negative LR
            decay_start = self.warmup_steps + self.stable_steps
            decay_progress = (current_step - decay_start) / self.decay_steps
            lr = self.peak_lr * max(0.0, 1.0 - decay_progress)

        # Update every param group. trainer.optimizers is the list of raw torch
        # optimizers, which avoids isinstance checks against Lightning's
        # optimizer wrapper.
        for optimizer in trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr


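# Worked example of the schedule above with the values set in main() below
# (warmup_steps=1000, peak_lr=5e-4, total_steps=3500):
#   decay_steps  = int(0.1 * 3500)   = 350
#   stable_steps = 3500 - 1000 - 350 = 2150  (stable region: steps 1000-3149)
#   step  500 (warmup): lr = 5e-4 * 500/1000                = 2.5e-4
#   step 2000 (stable): lr = 5e-4
#   step 3325 (decay):  lr = 5e-4 * (1 - (3325 - 3150)/350) = 2.5e-4
#   step 3500 (decay):  lr = 0
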
class SmolLM2Module(L.LightningModule):
    """PyTorch Lightning module for SmolLM2 training."""

    def __init__(
        self,
        config: SmolConfig,
        tokenizer,
        block_size: int = 512,
        warmup_steps: int = 2000,
        peak_lr: float = 5e-4,
        total_steps: int = 5000,
        predict_every: int = 500,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer'])
        self.config = config
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.warmup_steps = warmup_steps
        self.peak_lr = peak_lr
        self.total_steps = total_steps
        self.predict_every = predict_every

        # Initialize model
        self.model = SmolLM2(config)

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # For generation
        self.example_prompt = "First Citizen:"

    def forward(self, input_ids, attention_mask=None):
        logits, present_key_values = self.model(input_ids, attention_mask=attention_mask, use_cache=False)
        return logits

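    # Note: use_cache=False here because training feeds the whole sequence in a
    # single forward pass; the model's KV cache only pays off for step-by-step
    # decoding inside generate().
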
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)

        # Flatten batch and sequence dimensions for the token-level loss
        loss = self.criterion(logits.view(-1, logits.size(-1)), y.view(-1))

        # Logging (`logger` is the module-global set up in main())
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

        # Generate text every predict_every steps
        if (self.global_step + 1) % self.predict_every == 0:
            # Log the scalar loss to the text log so it shows up alongside generations
            logger.info(f"Step {self.global_step + 1} | train_loss={loss.item():.4f}")
            self.generate_and_log()

        return loss

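    # Shape sketch for the loss above, using the defaults from main()
    # (batch_size=4, block_size=512):
    #   logits: (4, 512, vocab_size) --view--> (2048, vocab_size)
    #   y:      (4, 512)             --view--> (2048,)
    # Cross-entropy is then averaged over all 2048 token positions.
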
    def generate_and_log(self):
        """Generate text and log it."""
        self.model.eval()
        with torch.no_grad():
            # Tokenize prompt
            prompt_ids = self.tokenizer.encode(
                self.example_prompt,
                return_tensors='pt',
                add_special_tokens=False
            ).to(self.device)

            # Generate
            generated_ids = self.model.generate(
                prompt_ids,
                max_new_tokens=50,
                temperature=0.8,
                top_k=50,
            )

            # Decode
            generated_text = self.tokenizer.decode(
                generated_ids[0].cpu().tolist(),
                skip_special_tokens=True
            )

            # Log to console and file
            logger.info(f"\n{'='*80}")
            logger.info(f"Step {self.global_step + 1} - Generated text:")
            logger.info(f"{generated_text}")
            logger.info(f"{'='*80}\n")

        self.model.train()

    def configure_optimizers(self):
        """Configure optimizer with AdamW."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.peak_lr,  # Will be adjusted by scheduler
            betas=(0.9, 0.95),
            weight_decay=0.01,
        )

        # WSD scheduler (implemented as callback)
        return optimizer

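    # Design note: the WSD schedule lives in the WarmupStableDecayLR callback
    # instead of being returned here as a Lightning lr_scheduler config, so this
    # method only has to construct the AdamW optimizer.
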
    def on_train_start(self):
        """Log model summary at training start."""
        # Count parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        logger.info("\n" + "="*80)
        logger.info("MODEL SUMMARY")
        logger.info("="*80)
        logger.info("Model: SmolLM2-135M")
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")
        logger.info(f"Block size: {self.block_size}")
        logger.info(f"Warmup steps: {self.warmup_steps}")
        logger.info(f"Peak learning rate: {self.peak_lr}")
        logger.info(f"Total training steps: {self.total_steps}")
        logger.info(f"Predict every: {self.predict_every} steps")
        logger.info("="*80 + "\n")


def main():
    # Configuration
    data_file = Path("../data/input.txt").resolve()
    output_dir = Path("./checkpoints")
    log_dir = Path("./logs")
    block_size = 512
    batch_size = 4
    num_workers = 8
    max_steps = 3500
    predict_every = 500
    resume_from_checkpoint = "checkpoints/smollm2-step=01500-train_loss=3.6240.ckpt"  # Set to a checkpoint path to resume, or None for fresh training

    # Training hyperparameters (peak LR and betas follow the paper; warmup is
    # scaled down for this short run)
    warmup_steps = 1000
    peak_lr = 5e-4
    total_steps = max_steps

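    # Rough budget sketch for these settings: each step consumes
    # batch_size * block_size = 4 * 512 = 2048 tokens, so 3500 steps see about
    # 7.2M (overlapping, non-unique) tokens.
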
    # Setup logging
    global logger
    logger, log_file = setup_logging(log_dir)
    logger.info(f"Logging to: {log_file}")

    # Load tokenizer
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Allow SmolConfig to be deserialized from Lightning checkpoints when torch.load
    # uses the weights_only=True default (torch>=2.6). This is safe because the class
    # is defined in our local model.py module.
    try:
        torch.serialization.add_safe_globals([SmolConfig])  # type: ignore[attr-defined]
    except Exception:
        # Fallback for torch versions without add_safe_globals; Lightning will still
        # load normally when weights_only=False.
        pass

    # Load config and create model config
    logger.info("Loading model config...")
    hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    config = SmolConfig.from_hf(hf_config)

    # Create dataset
    logger.info(f"Loading dataset from: {data_file}")
    dataset = TextDataset(data_file, tokenizer, block_size=block_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

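    # Note: with overlapping windows, len(dataset) is roughly the token count of
    # input.txt minus block_size, so a single "epoch" is far longer than the
    # max_steps budget; training here is bounded by steps, not epochs.
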
    # Create Lightning module
    logger.info("Initializing model...")
    model = SmolLM2Module(
        config=config,
        tokenizer=tokenizer,
        block_size=block_size,
        warmup_steps=warmup_steps,
        peak_lr=peak_lr,
        total_steps=total_steps,
        predict_every=predict_every,
    )

    # Additional callback to ensure checkpoint at final step
    class FinalCheckpointCallback(L.Callback):
        def on_train_end(self, trainer, pl_module):
            # Save final checkpoint
            final_checkpoint_path = output_dir / f"smollm2-final-step-{trainer.global_step:05d}.ckpt"
            trainer.save_checkpoint(str(final_checkpoint_path))
            logger.info(f"Final checkpoint saved: {final_checkpoint_path}")

    final_checkpoint_callback = FinalCheckpointCallback()

    # Setup callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=output_dir,
        filename='smollm2-{step:05d}-{train_loss:.4f}',
        monitor='train_loss',
        save_top_k=3,
        mode='min',
        every_n_train_steps=predict_every,
        save_last=True,
        save_on_train_epoch_end=False,  # Save based on steps, not epochs
    )

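    # With this filename pattern, checkpoints are written as e.g.
    # "smollm2-step=01500-train_loss=3.6240.ckpt" -- the same form the
    # resume_from_checkpoint path above points at.
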
    lr_monitor = LearningRateMonitor(logging_interval='step')

    wsd_scheduler = WarmupStableDecayLR(
        warmup_steps=warmup_steps,
        peak_lr=peak_lr,
        total_steps=total_steps,
    )

    # Setup TensorBoard logger
    tb_logger = TensorBoardLogger(
        save_dir=log_dir,
        name='tensorboard',
    )

    # Create trainer
    trainer = L.Trainer(
        max_steps=max_steps,
        callbacks=[checkpoint_callback, lr_monitor, wsd_scheduler, final_checkpoint_callback],
        logger=tb_logger,
        accelerator='auto',
        devices='auto',
        # Precision depends on device capabilities: bf16 mixed precision on CUDA,
        # full fp32 everywhere else (e.g. Apple MPS, which only runs 32-true here).
        precision='bf16-mixed' if torch.cuda.is_available() else '32-true',
        gradient_clip_val=1.0,
        log_every_n_steps=50,
        enable_checkpointing=True,
    )

    # Train
    logger.info("Starting training...")
    if resume_from_checkpoint and Path(resume_from_checkpoint).exists():
        logger.info(f"Resuming from checkpoint: {resume_from_checkpoint}")
        trainer.fit(model, dataloader, ckpt_path=resume_from_checkpoint)
    else:
        trainer.fit(model, dataloader)

    logger.info("Training completed!")
    logger.info(f"Best checkpoint: {checkpoint_callback.best_model_path}")
    logger.info(f"Last checkpoint: {checkpoint_callback.last_model_path}")


if __name__ == "__main__":
    main()

uv.lock
ADDED

The diff for this file is too large to render. See raw diff.