# SmolLM2-135M Training Guide

This directory contains the training code for the SmolLM2-135M model.
## Files

- `model.py`: Model definition with KV cache support for inference
- `train.py`: Main training script (trains for 5000 steps)
  - Run with a checkpoint path to resume training for 50 additional steps
## Setup

Install the required packages:

```bash
pip install torch lightning transformers tensorboard
```
## Training

### Phase 1: Initial Training (5000 steps)

Run the main training script:

```bash
python train.py
```

This will:

- Train the model for 5000 steps
- Generate text predictions every 500 steps
- Save checkpoints every 500 steps
- Log training metrics to TensorBoard and a text file
- Save the final checkpoint at step 5000
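The checkpointing and logging behavior above corresponds to standard Lightning components. Below is a minimal sketch of how such a setup typically looks; it is illustrative only (the exact wiring lives in `train.py`, and the commented-out `fit` call refers to objects constructed there):

```python
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

# Save a checkpoint every 500 training steps into ./checkpoints/ and keep all of them
checkpoint_cb = ModelCheckpoint(
    dirpath="./checkpoints",
    every_n_train_steps=500,
    save_top_k=-1,
)

trainer = Trainer(
    max_steps=5000,                                    # Phase 1 length
    logger=TensorBoardLogger(save_dir="./logs/tensorboard"),
    callbacks=[checkpoint_cb],
)
# trainer.fit(lit_module, train_dataloader)            # both come from train.py
```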
### Phase 2: Resume Training (50 additional steps)

After Phase 1 completes, run:

```bash
python train.py
```

This time, set the checkpoint path and set the number of steps to 50 to resume training for 50 additional steps. This demonstrates that training picks up exactly where it stopped.
This will:

- Load the checkpoint from Phase 1
- Train for 50 additional steps
- Save the final checkpoint
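In Lightning, resuming is typically done by passing `ckpt_path` to `Trainer.fit`, with the step budget extended past the checkpoint's recorded global step. A hedged sketch of what Phase 2 amounts to (the actual resume logic and checkpoint filename are in `train.py` and `./checkpoints/`):

```python
from lightning import Trainer

# The Phase 1 checkpoint stores optimizer state, LR-schedule state, and global_step=5000,
# so Lightning continues from step 5000 instead of starting over.
trainer = Trainer(max_steps=5000 + 50)                 # allow 50 additional steps
# trainer.fit(lit_module, train_dataloader,
#             ckpt_path="./checkpoints/<phase-1-checkpoint>.ckpt")  # placeholder path
```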
## Training Configuration

The training uses the following hyperparameters (from the SmolLM2 paper):

- **Optimizer**: AdamW with (β₁, β₂) = (0.9, 0.95)
- **Learning Rate Schedule**: Warmup Stable Decay (WSD), sketched after this list
  - Warmup: 2000 steps
  - Peak LR: 5.0 × 10⁻⁴
  - Stable phase: maintains peak LR
  - Decay: reduces to zero over the final 10% of total steps
- **Block size**: 512 tokens
- **Batch size**: 4
- **Precision**: bfloat16 (if a GPU is available), float32 otherwise
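For reference, the WSD schedule above can be written as a simple function of the step count. This is a minimal sketch assuming linear warmup and linear decay (the decay shape in `train.py` may differ); the constants mirror the hyperparameters listed above:

```python
def wsd_lr(step: int,
           peak_lr: float = 5e-4,
           warmup_steps: int = 2000,
           total_steps: int = 5000,
           decay_frac: float = 0.10) -> float:
    """Warmup Stable Decay: warm up to the peak LR, hold it, then decay to zero."""
    decay_steps = int(total_steps * decay_frac)   # decay over the final 10% of steps
    decay_start = total_steps - decay_steps
    if step < warmup_steps:                       # linear warmup
        return peak_lr * step / warmup_steps
    if step < decay_start:                        # stable phase at the peak LR
        return peak_lr
    remaining = max(total_steps - step, 0)        # linear decay to zero
    return peak_lr * remaining / decay_steps
```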
## Outputs

- **Checkpoints**: Saved in `./checkpoints/`
- **TensorBoard logs**: Saved in `./logs/tensorboard/`
- **Text logs**: Saved in `./logs/training_*.log`
## Model Features

The model includes:

- **KV Cache**: Efficient inference using key-value caching (see the sketch after this list)
- **Generation**: Text generation with top-k and top-p sampling
- **Checkpointing**: Full state saving for resuming training
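To illustrate the KV cache idea: during generation, keys and values computed for earlier tokens are stored and reused, so each new token attends over cached tensors instead of recomputing the whole prefix. The function below is an illustrative sketch, not the actual `model.py` API:

```python
import torch
import torch.nn.functional as F

def attend_with_cache(q, k_new, v_new, cache):
    """One decoding step of attention using a key-value cache.

    q, k_new, v_new: (batch, heads, 1, head_dim) tensors for the newest token.
    cache: dict holding keys/values from previous steps (empty on the first step).
    """
    if "k" in cache:                              # append the new token's K/V to the cache
        k = torch.cat([cache["k"], k_new], dim=2)
        v = torch.cat([cache["v"], v_new], dim=2)
    else:                                         # first step: the cache starts here
        k, v = k_new, v_new
    cache["k"], cache["v"] = k, v                 # reuse on the next decoding step
    # No causal mask is needed: the single query is the last position in the sequence.
    return F.scaled_dot_product_attention(q, k, v)
```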
## Usage Example

```python
import glob

import torch
from transformers import AutoConfig, AutoTokenizer

from model import SmolLM2, SmolConfig

# Load the Hugging Face config and convert it to the local SmolConfig
hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = SmolConfig.from_hf(hf_config)

# Create the model
model = SmolLM2(config)

# Load a checkpoint (resolve the wildcard in the filename first)
ckpt_path = glob.glob("checkpoints/smollm2-00500-*.ckpt")[0]
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])

# Generate text with top-k sampling
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
prompt = "First Citizen:"
input_ids = tokenizer.encode(prompt, return_tensors='pt')
generated_ids = model.generate(
    input_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
)
generated_text = tokenizer.decode(generated_ids[0])
print(generated_text)
```
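If `generate` also accepts a `top_p` argument, as the feature list above suggests, nucleus sampling can be enabled by passing it alongside `top_k` (for example, `top_p=0.9`).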