# SmolLM2-135M Training Guide
This directory contains the training code for the SmolLM2-135M model.
## Files
- `model.py`: Model definition with KV cache support for inference
- `train.py`: Main training script (trains for 5000 steps)
  - Run with a checkpoint path to resume training for 50 additional steps
## Setup
Install required packages:
```bash
pip install torch lightning transformers tensorboard
```
## Training
### Phase 1: Initial Training (5000 steps)
Run the main training script:
```bash
python train.py
```
This will:
- Train the model for 5000 steps
- Generate text predictions every 500 steps
- Save checkpoints every 500 steps
- Log training metrics to TensorBoard and a text file
- Save the final checkpoint at step 5000
### Phase 2: Resume Training (50 additional steps)
After Phase 1 completes, run:
```bash
python train.py
```
This time, set the checkpoint path and set the number of steps to 50 before running, so that training resumes for 50 additional steps. The short run is only meant to show that training continues from where it stopped.
This will:
- Load the checkpoint from Phase 1
- Train for 50 additional steps
- Save the final checkpoint
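As a rough illustration of the resume step, the sketch below picks the checkpoint with the highest step number and computes the new step limit. The `smollm2-<step>-*.ckpt` naming is inferred from the usage example in this README, and the helper names `latest_checkpoint` and `resume_max_steps` are hypothetical. How `train.py` actually receives the checkpoint path is not shown here.

```python
import re
from pathlib import Path

# Assumed naming scheme: smollm2-<step>-<anything>.ckpt
CKPT_PATTERN = re.compile(r"smollm2-(\d+)-.*\.ckpt$")


def latest_checkpoint(ckpt_dir="checkpoints"):
    """Return (path, step) for the checkpoint with the highest step, or None."""
    best = None
    for path in Path(ckpt_dir).glob("smollm2-*.ckpt"):
        match = CKPT_PATTERN.search(path.name)
        if match is None:
            continue  # skip files that do not follow the assumed scheme
        step = int(match.group(1))
        if best is None or step > best[1]:
            best = (path, step)
    return best


def resume_max_steps(ckpt_step, extra_steps=50):
    """Total step limit so that training runs `extra_steps` past the checkpoint."""
    return ckpt_step + extra_steps
```

If `train.py` drives a Lightning `Trainer` (an assumption; the README only lists `lightning` as a dependency), the path would typically go to `trainer.fit(..., ckpt_path=str(path))`, with `max_steps` set to the total of 5050, since Lightning counts `max_steps` from step 0 rather than from the resume point.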
## Training Configuration
The training uses the following hyperparameters (from the SmolLM2 paper):
- **Optimizer**: AdamW with (β₁, β₂) = (0.9, 0.95)
- **Learning Rate Schedule**: Warmup Stable Decay (WSD)
  - Warmup: 2000 steps
  - Peak LR: 5.0 × 10⁻⁴
  - Stable phase: maintains peak LR
  - Decay: reduces to zero over 10% of total steps
- **Block size**: 512 tokens
- **Batch size**: 4
- **Precision**: bfloat16 (if GPU available), float32 otherwise
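The WSD schedule listed above can be sketched as a plain function of the step number. This is a minimal sketch, not the code in `train.py`: linear warmup and linear decay are assumptions (the list only gives the warmup length, peak LR, and decay fraction), the decay is assumed to occupy the final 10% of steps, and the name `wsd_lr` is hypothetical.

```python
def wsd_lr(step, total_steps=5000, warmup_steps=2000, peak_lr=5e-4, decay_frac=0.1):
    """Learning rate at `step` under a Warmup-Stable-Decay schedule.

    Assumes linear warmup from 0 and linear decay to 0 over the final
    `decay_frac` of training.
    """
    decay_steps = max(int(total_steps * decay_frac), 1)
    decay_start = total_steps - decay_steps
    if step < warmup_steps:
        return peak_lr * step / warmup_steps   # warmup: ramp from 0 to peak
    if step < decay_start:
        return peak_lr                         # stable: hold the peak LR
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / decay_steps   # decay: ramp from peak to 0
```

With the defaults, the rate climbs until step 2000, stays at 5.0 × 10⁻⁴ through step 4500, and reaches zero at step 5000. One way to use such a function with PyTorch is `torch.optim.lr_scheduler.LambdaLR`, passing `lambda s: wsd_lr(s) / 5e-4` so the multiplier is relative to the optimizer's base LR.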
## Outputs
- **Checkpoints**: Saved in `./checkpoints/`
- **TensorBoard logs**: Saved in `./logs/tensorboard/`
- **Text logs**: Saved in `./logs/training_*.log`
## Model Features
The model includes:
- **KV Cache**: Efficient inference using key-value caching
- **Generation**: Text generation with top-k and top-p sampling
- **Checkpointing**: Full state saving for resuming training
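To make the top-k and top-p sampling step concrete, here is a minimal sketch of the filtering on plain Python lists rather than tensors. It is not taken from `model.py`: the function name `top_k_top_p_probs`, the order of the two filters, and the default `top_p=0.9` are assumptions chosen for illustration.

```python
import math


def top_k_top_p_probs(logits, top_k=50, top_p=0.9, temperature=0.8):
    """Return sampling probabilities after temperature, top-k and top-p filtering."""
    scaled = [x / temperature for x in logits]
    # Softmax with max-subtraction for numerical stability.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices sorted from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    order = order[:top_k]  # top-k: keep only the k most likely tokens

    # top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the surviving tokens; all others get probability 0.
    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered
```

A next token could then be drawn with `random.choices(range(len(filtered)), weights=filtered)`. Lower temperatures sharpen the distribution before filtering, so fewer tokens survive the top-p cut.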
## Usage Example
```python
import glob

import torch
from transformers import AutoConfig, AutoTokenizer

from model import SmolLM2, SmolConfig

# Load config
hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = SmolConfig.from_hf(hf_config)

# Create model
model = SmolLM2(config)

# Load checkpoint (torch.load does not expand wildcards, so resolve the pattern first)
ckpt_path = sorted(glob.glob("checkpoints/smollm2-00500-*.ckpt"))[-1]
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # switch to inference mode

# Generate text
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
prompt = "First Citizen:"
input_ids = tokenizer.encode(prompt, return_tensors="pt")
generated_ids = model.generate(
    input_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
)
generated_text = tokenizer.decode(generated_ids[0])
print(generated_text)
```