"""
Training script for SmolLM2-135M using PyTorch Lightning.
Training strategy from paper:
- AdamW optimizer with (β1, β2) = (0.9, 0.95)
- Warmup Stable Decay (WSD) learning rate schedule:
- 2,000-step warmup phase
- Peak learning rate: 5.0 × 10^-4 (stable phase)
- Decay phase: reduce LR to zero over 10% of total training steps
"""
import sys
import logging
from pathlib import Path
from datetime import datetime
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from transformers import AutoTokenizer, AutoConfig
from model import SmolLM2, SmolConfig
# Setup logging
def setup_logging(log_dir: Path):
"""Setup text file logging."""
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler(sys.stdout)
]
)
return logging.getLogger(__name__), log_file
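# NOTE: `logger` is assigned as a module-level global in main() (see `global logger`
# there) so the Lightning module and callbacks below can log to the same handlers.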
class TextDataset(Dataset):
"""Dataset for text data."""
def __init__(self, text_file: str, tokenizer, block_size: int = 512):
self.tokenizer = tokenizer
self.block_size = block_size
# Read and tokenize text
with open(text_file, 'r', encoding='utf-8') as f:
text = f.read()
# Tokenize
tokens = tokenizer.encode(text, add_special_tokens=False)
self.data = torch.tensor(tokens, dtype=torch.long)
def __len__(self):
return len(self.data) - self.block_size
def __getitem__(self, idx):
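        # Take block_size + 1 tokens so that inputs and targets are the same
        # window shifted by one position (standard next-token prediction):
        #   x = tokens[idx : idx + block_size]
        #   y = tokens[idx + 1 : idx + block_size + 1]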
chunk = self.data[idx:idx + self.block_size + 1]
x = chunk[:-1]
y = chunk[1:]
return x, y
class WarmupStableDecayLR(L.Callback):
"""
Warmup Stable Decay (WSD) learning rate schedule.
    - Warmup: 2,000 steps in the paper; since we train for only 5,000 steps here, we use 20% of total steps (1,000) as warmup
- Stable: maintain peak LR
- Decay: reduce to zero over 10% of total steps
"""
def __init__(self, warmup_steps: int = 2000, peak_lr: float = 5e-4, total_steps: int = 5000):
super().__init__()
self.warmup_steps = warmup_steps
self.peak_lr = peak_lr
self.total_steps = total_steps
self.decay_steps = int(0.1 * total_steps) # 10% of total steps
self.stable_steps = total_steps - warmup_steps - self.decay_steps
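        # Example with the constructor defaults (warmup=2000, total=5000):
        # decay_steps=500 and stable_steps=2500, so the phases are
        #   warmup [0, 2000) -> stable [2000, 4500) -> decay [4500, 5000)
        # e.g. lr is 2.5e-4 at step 1000 (mid-warmup) and 2.5e-4 again at step 4750 (mid-decay).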
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
current_step = trainer.global_step
if current_step < self.warmup_steps:
# Warmup phase: linear increase
lr = self.peak_lr * (current_step / self.warmup_steps)
elif current_step < self.warmup_steps + self.stable_steps:
# Stable phase: maintain peak LR
lr = self.peak_lr
else:
# Decay phase: linear decrease to zero
decay_start = self.warmup_steps + self.stable_steps
decay_progress = (current_step - decay_start) / self.decay_steps
lr = self.peak_lr * (1.0 - decay_progress)
        # Update the learning rate on every optimizer the trainer manages.
        # trainer.optimizers always exposes the raw torch optimizers, so this
        # handles single- and multi-optimizer setups uniformly.
        for optimizer in trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
class SmolLM2Module(L.LightningModule):
"""PyTorch Lightning module for SmolLM2 training."""
def __init__(
self,
config: SmolConfig,
tokenizer,
block_size: int = 512,
warmup_steps: int = 2000,
peak_lr: float = 5e-4,
total_steps: int = 5000,
predict_every: int = 500,
):
super().__init__()
self.save_hyperparameters(ignore=['tokenizer'])
self.config = config
self.tokenizer = tokenizer
self.block_size = block_size
self.warmup_steps = warmup_steps
self.peak_lr = peak_lr
self.total_steps = total_steps
self.predict_every = predict_every
# Initialize model
self.model = SmolLM2(config)
# Loss function
self.criterion = nn.CrossEntropyLoss()
# For generation
self.example_prompt = "First Citizen:"
    def forward(self, input_ids, attention_mask=None):
        # use_cache=False: the KV cache is only needed for generation, not training.
        logits, _ = self.model(input_ids, attention_mask=attention_mask, use_cache=False)
        return logits
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
# Reshape for loss calculation
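        # logits: (batch, seq_len, vocab_size) -> (batch*seq_len, vocab_size); y: (batch, seq_len) -> (batch*seq_len,)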
loss = self.criterion(logits.view(-1, logits.size(-1)), y.view(-1))
# Logging
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
# Generate text every predict_every steps
if (self.global_step + 1) % self.predict_every == 0:
# Log scalar loss to text log so it shows up with generations
logger.info(f"Step {self.global_step + 1} | train_loss={loss.item():.4f}")
self.generate_and_log()
return loss
def generate_and_log(self):
"""Generate text and log it."""
self.model.eval()
with torch.no_grad():
# Tokenize prompt
prompt_ids = self.tokenizer.encode(
self.example_prompt,
return_tensors='pt',
add_special_tokens=False
).to(self.device)
# Generate
generated_ids = self.model.generate(
prompt_ids,
max_new_tokens=50,
temperature=0.8,
top_k=50,
)
# Decode
generated_text = self.tokenizer.decode(
generated_ids[0].cpu().tolist(),
skip_special_tokens=True
)
# Log to console and file
logger.info(f"\n{'='*80}")
logger.info(f"Step {self.global_step + 1} - Generated text:")
logger.info(f"{generated_text}")
logger.info(f"{'='*80}\n")
self.model.train()
def configure_optimizers(self):
"""Configure optimizer with AdamW."""
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.peak_lr, # Will be adjusted by scheduler
betas=(0.9, 0.95),
weight_decay=0.01,
)
# WSD scheduler (implemented as callback)
return optimizer
def on_train_start(self):
"""Log model summary at training start."""
# Count parameters
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
logger.info("\n" + "="*80)
logger.info("MODEL SUMMARY")
logger.info("="*80)
logger.info(f"Model: SmolLM2-135M")
logger.info(f"Total parameters: {total_params:,}")
logger.info(f"Trainable parameters: {trainable_params:,}")
logger.info(f"Block size: {self.block_size}")
logger.info(f"Warmup steps: {self.warmup_steps}")
logger.info(f"Peak learning rate: {self.peak_lr}")
logger.info(f"Total training steps: {self.total_steps}")
logger.info(f"Predict every: {self.predict_every} steps")
logger.info("="*80 + "\n")
def main():
# Configuration
data_file = Path("../data/input.txt").resolve()
output_dir = Path("./checkpoints")
log_dir = Path("./logs")
block_size = 512
batch_size = 4
num_workers = 8
max_steps = 5000
predict_every = 500
resume_from_checkpoint = "checkpoints/smollm2-step=03500-train_loss=0.1352.ckpt" # Set to checkpoint path to resume, or None for fresh training
# Training hyperparameters from paper
warmup_steps = 1000
peak_lr = 5e-4
total_steps = max_steps
# Setup logging
global logger
logger, log_file = setup_logging(log_dir)
logger.info(f"Logging to: {log_file}")
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
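    # GPT-style tokenizers often ship without a pad token; reuse EOS so padding works.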
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Allow SmolConfig to be deserialized from Lightning checkpoints when torch.load
# uses weights_only=True default (torch>=2.6). This is safe because the class
# is defined locally in this file.
try:
torch.serialization.add_safe_globals([SmolConfig]) # type: ignore[attr-defined]
except Exception:
# Fallback for torch versions without add_safe_globals; Lightning will still
# load normally when weights_only=False.
pass
# Load config and create model config
logger.info("Loading model config...")
hf_config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
config = SmolConfig.from_hf(hf_config)
# Create dataset
logger.info(f"Loading dataset from: {data_file}")
dataset = TextDataset(data_file, tokenizer, block_size=block_size)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
)
# Create Lightning module
logger.info("Initializing model...")
model = SmolLM2Module(
config=config,
tokenizer=tokenizer,
block_size=block_size,
warmup_steps=warmup_steps,
peak_lr=peak_lr,
total_steps=total_steps,
predict_every=predict_every,
)
# Additional callback to ensure checkpoint at final step
class FinalCheckpointCallback(L.Callback):
def on_train_end(self, trainer, pl_module):
# Save final checkpoint
final_checkpoint_path = output_dir / f"smollm2-final-step-{trainer.global_step:05d}.ckpt"
trainer.save_checkpoint(str(final_checkpoint_path))
logger.info(f"Final checkpoint saved: {final_checkpoint_path}")
final_checkpoint_callback = FinalCheckpointCallback()
# Setup callbacks
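    # Lightning expands '{step:05d}' / '{train_loss:.4f}' into 'step=NNNNN' /
    # 'train_loss=X.XXXX', yielding names like the resume path above.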
checkpoint_callback = ModelCheckpoint(
dirpath=output_dir,
filename='smollm2-{step:05d}-{train_loss:.4f}',
monitor='train_loss',
save_top_k=3,
mode='min',
every_n_train_steps=predict_every,
save_last=True,
save_on_train_epoch_end=False, # Save based on steps, not epochs
)
lr_monitor = LearningRateMonitor(logging_interval='step')
wsd_scheduler = WarmupStableDecayLR(
warmup_steps=warmup_steps,
peak_lr=peak_lr,
total_steps=total_steps,
)
# Setup TensorBoard logger
tb_logger = TensorBoardLogger(
save_dir=log_dir,
name='tensorboard',
)
# Create trainer
trainer = L.Trainer(
max_steps=max_steps,
callbacks=[checkpoint_callback, lr_monitor, wsd_scheduler, final_checkpoint_callback],
logger=tb_logger,
accelerator='auto',
devices='auto',
        # Select precision by device: bf16 mixed precision on CUDA,
        # full fp32 elsewhere (MPS in particular supports only 32-true).
        precision='bf16-mixed' if torch.cuda.is_available() else '32-true',
gradient_clip_val=1.0,
log_every_n_steps=50,
enable_checkpointing=True,
)
# Train
logger.info("Starting training...")
if resume_from_checkpoint and Path(resume_from_checkpoint).exists():
logger.info(f"Resuming from checkpoint: {resume_from_checkpoint}")
trainer.fit(model, dataloader, ckpt_path=resume_from_checkpoint)
else:
trainer.fit(model, dataloader)
logger.info("Training completed!")
logger.info(f"Best checkpoint: {checkpoint_callback.best_model_path}")
logger.info(f"Last checkpoint: {checkpoint_callback.last_model_path}")
if __name__ == "__main__":
main()