SmolLM2-135M Training Guide
This directory contains the training code for the SmolLM2-135M model.
Files
- model.py: Model definition with KV cache support for inference
- train.py: Main training script (trains for 5000 steps); run it with a checkpoint path to resume training for 50 additional steps
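The KV cache in model.py lets generation reuse the keys and values of earlier tokens instead of recomputing them at every step. As a minimal sketch (single head, no batching, and not the repository's implementation), the hypothetical helpers below compare cached one-token-at-a-time attention against full causal attention. Both paths should produce the same outputs.

```python
import torch
import torch.nn.functional as F


def full_causal_attention(q, k, v):
    # q, k, v: (seq_len, head_dim); standard causal self-attention over the whole sequence
    scores = q @ k.T / q.shape[-1] ** 0.5
    mask = torch.triu(torch.ones(scores.shape, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


def cached_attention_step(q_t, k_t, v_t, cache):
    # q_t, k_t, v_t: (1, head_dim) for the newest token only.
    # Append this token's key/value to the cache instead of recomputing past ones.
    if cache is None:
        k_all, v_all = k_t, v_t
    else:
        k_all = torch.cat([cache[0], k_t], dim=0)
        v_all = torch.cat([cache[1], v_t], dim=0)
    # The newest token may attend to every cached position, so no mask is needed.
    scores = q_t @ k_all.T / q_t.shape[-1] ** 0.5
    out = F.softmax(scores, dim=-1) @ v_all
    return out, (k_all, v_all)


torch.manual_seed(0)
seq_len, head_dim = 6, 8
x = torch.randn(seq_len, head_dim, dtype=torch.float64)
w_q, w_k, w_v = (torch.randn(head_dim, head_dim, dtype=torch.float64) for _ in range(3))
q, k, v = x @ w_q, x @ w_k, x @ w_v

reference = full_causal_attention(q, k, v)

cache = None
steps = []
for t in range(seq_len):
    # Feed one token at a time; earlier keys/values come from the cache.
    out_t, cache = cached_attention_step(q[t : t + 1], k[t : t + 1], v[t : t + 1], cache)
    steps.append(out_t)
incremental = torch.cat(steps, dim=0)

print(torch.allclose(reference, incremental))
```

The cache grows by one key and one value row per generated token, which is why per-step attention cost during generation scales with the current length rather than recomputing the whole prefix.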
Setup
Install required packages:
pip install torch lightning transformers tensorboard
Training
Phase 1: Initial Training (5000 steps)
Run the main training script:
python train.py
This will:
- Train the model for 5000 steps
- Generate text predictions every 500 steps
- Save checkpoints every 500 steps
- Log training metrics to TensorBoard and a text log file
- Save the final checkpoint at step 5000
Phase 2: Resume Training (50 additional steps)
After Phase 1 completes, run:
python train.py
This time, first set the checkpoint path to the Phase 1 checkpoint and set the number of steps to 50. The run then resumes training for 50 additional steps, showing that training continues from where it stopped.
This will:
- Load the checkpoint from Phase 1
- Train for 50 additional steps
- Save the final checkpoint
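Resuming only continues where training stopped if the checkpoint holds the optimizer state and step counter as well as the weights. The training script appears to rely on Lightning (it is in the install list, and checkpoints are .ckpt files with a state_dict key), whose Trainer restores this state itself. The plain-PyTorch sketch below only illustrates the idea with a tiny stand-in model; make_model_and_optimizer, train_steps, save_state and load_state are hypothetical names, and "Phase 1" is shrunk to 20 steps.

```python
import os
import tempfile

import torch
from torch import nn


def make_model_and_optimizer():
    # Tiny stand-in model; AdamW betas follow the configuration listed in this guide.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, betas=(0.9, 0.95))
    return model, optimizer


def train_steps(model, optimizer, start_step, num_steps):
    # Run num_steps optimizer updates on random data and return the new global step.
    for _ in range(num_steps):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return start_step + num_steps


def save_state(path, model, optimizer, step):
    # Save weights, optimizer moments, and the step counter together.
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step},
        path,
    )


def load_state(path, model, optimizer):
    # Restore everything save_state wrote and return the saved step.
    state = torch.load(path)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["step"]


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "phase1.pt")

    # Phase 1 stand-in: train, then save the full state.
    model, optimizer = make_model_and_optimizer()
    step = train_steps(model, optimizer, start_step=0, num_steps=20)
    save_state(ckpt, model, optimizer, step)

    # Phase 2: fresh objects, restore the state, and continue for 50 more steps.
    model2, optimizer2 = make_model_and_optimizer()
    resumed_step = load_state(ckpt, model2, optimizer2)
    restored_ok = torch.equal(model.weight, model2.weight) and len(optimizer2.state) > 0
    final_step = train_steps(model2, optimizer2, start_step=resumed_step, num_steps=50)
    print(resumed_step, final_step)
```

Restoring only the weights would restart AdamW's moment estimates and the step counter (and therefore the learning-rate schedule) from zero, which is what the Phase 2 run is meant to rule out.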
Training Configuration
The training uses the following hyperparameters (from the SmolLM2 paper):
- Optimizer: AdamW with (β₁, β₂) = (0.9, 0.95)
- Learning Rate Schedule: Warmup Stable Decay (WSD)
- Warmup: 2000 steps
- Peak LR: 5.0 × 10⁻⁴
- Stable phase: maintains peak LR
- Decay: reduces to zero over 10% of total steps
- Block size: 512 tokens
- Batch size: 4
- Precision: bfloat16 if a GPU is available, otherwise float32
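The WSD schedule above can be written as a function of the optimizer step. This is a sketch under stated assumptions: total_steps = 5000 (the Phase 1 run), linear warmup and linear decay (the guide does not specify the shapes), and the function name wsd_lr is hypothetical.

```python
def wsd_lr(step, total_steps=5000, warmup_steps=2000, peak_lr=5e-4, decay_fraction=0.1):
    """Warmup-Stable-Decay learning rate for a given optimizer step (linear shapes assumed)."""
    decay_steps = int(total_steps * decay_fraction)
    decay_start = total_steps - decay_steps
    if step < warmup_steps:
        # Linear warmup from 0 to peak_lr.
        return peak_lr * step / warmup_steps
    if step < decay_start:
        # Stable phase: hold the peak learning rate.
        return peak_lr
    # Linear decay from peak_lr to 0 over the final decay_steps steps.
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / decay_steps


for s in (0, 1000, 2000, 4000, 4500, 4750, 5000):
    print(s, wsd_lr(s))
```

To drive an optimizer with it, one option is torch.optim.lr_scheduler.LambdaLR with the optimizer's base learning rate set to the peak and lr_lambda=lambda step: wsd_lr(step) / 5e-4. For scale: at block size 512 and batch size 4, each step processes 512 × 4 = 2,048 tokens, so the 5000-step run covers about 10.2M tokens, assuming no gradient accumulation.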
Outputs
- Checkpoints: saved in ./checkpoints/
- TensorBoard logs: saved in ./logs/tensorboard/
- Text logs: saved in ./logs/training_*.log
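For Phase 2 you need the path of a checkpoint in ./checkpoints/. A small helper such as the hypothetical latest_checkpoint below can pick the one with the highest step. It assumes file names of the form smollm2-<zero-padded step>-<suffix>.ckpt, inferred from the pattern used in the usage example; adjust the prefix or regex if the script names files differently.

```python
import glob
import os
import re
import tempfile


def latest_checkpoint(ckpt_dir="./checkpoints", prefix="smollm2-"):
    """Return the path of the checkpoint with the highest step number, or None."""
    # Assumed naming: <prefix><step digits>-<anything>.ckpt
    pattern = re.compile(re.escape(prefix) + r"(\d+)-")
    best_step, best_path = -1, None
    for path in glob.glob(os.path.join(ckpt_dir, prefix + "*.ckpt")):
        match = pattern.match(os.path.basename(path))
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Usage with throwaway files standing in for real checkpoints.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("smollm2-00500-a.ckpt", "smollm2-05000-b.ckpt", "smollm2-01000-c.ckpt"):
        open(os.path.join(tmp, name), "w").close()
    newest = latest_checkpoint(tmp)
    print(os.path.basename(newest))
```

Comparing parsed integers rather than raw strings keeps the choice correct even if step numbers are not zero-padded.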
Model Features
The model includes:
- KV Cache: Efficient inference using key-value caching
- Generation: Text generation with top-k and top-p sampling
- Checkpointing: Full state saving for resuming training
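The generation feature is described as supporting top-k and top-p sampling. The hypothetical sample_next_token below is an illustrative PyTorch sketch of how those filters are commonly applied to one decoding step's logits, not the code in model.py: temperature scales the logits, top-k keeps the k highest-scoring tokens, and top-p (nucleus sampling) keeps the smallest set of tokens whose probability mass reaches p.

```python
import torch


def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None, generator=None):
    """Sample one token id per row from logits of shape (batch, vocab)."""
    logits = logits / max(temperature, 1e-8)
    if top_k is not None:
        # Keep only the k largest logits; everything else becomes -inf.
        k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p is not None:
        # Sort descending and accumulate probability mass.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cumulative = sorted_probs.cumsum(dim=-1)
        # Drop a token if the mass before it already reaches top_p (the top token always survives).
        remove_sorted = (cumulative - sorted_probs) >= top_p
        # Map the removal mask back from sorted order to vocabulary order.
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)


logits = torch.tensor([[1.0, 5.0, 2.0, 0.5]])
gen = torch.Generator().manual_seed(0)
print(sample_next_token(logits, temperature=0.8, top_k=50, generator=gen))
```

Applying top-k before top-p, as here, means the nucleus is computed over the already-truncated distribution; other orderings are possible.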
Usage Example
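import glob
import torch
from model import SmolLM2, SmolConfig
from transformers import AutoTokenizer, AutoConfig
# Load the architecture config from the Hugging Face Hub
hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = SmolConfig.from_hf(hf_config)
# Create the model
model = SmolLM2(config)
# Load a checkpoint; torch.load does not expand wildcards, so resolve the pattern first
ckpt_path = sorted(glob.glob("checkpoints/smollm2-00500-*.ckpt"))[0]
# weights_only=False: Lightning .ckpt files hold more than tensors (only use with trusted files)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
# Lightning checkpoints store the model weights under 'state_dict'
model.load_state_dict(checkpoint['state_dict'])
model.eval()
# Generate text
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
prompt = "First Citizen:"
input_ids = tokenizer.encode(prompt, return_tensors='pt')
with torch.no_grad():
    generated_ids = model.generate(
        input_ids,
        max_new_tokens=100,
        temperature=0.8,
        top_k=50,
    )
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)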