|
|
""" |
|
|
Training loop for SLM. |
|
|
|
|
|
Handles the complete training process including: |
|
|
- Mixed precision training |
|
|
- Gradient accumulation |
|
|
- Checkpointing |
|
|
- Logging |
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
import json |
|
|
from dataclasses import dataclass, asdict |
|
|
from typing import Optional, Dict, Any, Callable |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .loss import LanguageModelingLoss, compute_perplexity, compute_accuracy |
|
|
from .optimizer import create_optimizer, create_scheduler, clip_grad_norm |
|
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyperparameters and runtime settings read by ``Trainer``.

    Every field has a default, so ``TrainingConfig()`` is a usable
    configuration. Field order is part of the positional constructor
    signature and must not be changed.
    """

    # --- Optimization ---------------------------------------------------
    # Passed to create_optimizer().
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    # Passed to create_scheduler().
    warmup_ratio: float = 0.1
    min_lr_ratio: float = 0.1
    # Passed to clip_grad_norm() before every optimizer step.
    max_grad_norm: float = 1.0
    # Passed to LanguageModelingLoss.
    label_smoothing: float = 0.0

    # --- Training schedule ----------------------------------------------
    num_epochs: int = 5
    # Number of micro-batches whose gradients are summed per optimizer step.
    gradient_accumulation_steps: int = 4
    # Mixed precision is only enabled when this is True AND the device is CUDA.
    fp16: bool = True

    # --- Checkpointing --------------------------------------------------
    checkpoint_dir: str = "checkpoints"
    # Save a "step_<N>" checkpoint every N optimizer steps (<= 0 disables).
    save_steps: int = 1000
    # Maximum number of "step_<N>" checkpoints kept on disk (<= 0 keeps all).
    save_total_limit: int = 3

    # --- Evaluation / logging -------------------------------------------
    # Run validation every N optimizer steps (<= 0 disables).
    eval_steps: int = 500
    # Emit a training log line every N optimizer steps.
    logging_steps: int = 10

    # --- Early stopping -------------------------------------------------
    # Evaluations without improvement before stopping (0 disables).
    early_stopping_patience: int = 5
    # Validation loss must drop by more than this to count as an improvement.
    early_stopping_threshold: float = 0.01

    # --- Device ---------------------------------------------------------
    # "auto" picks cuda, then mps, then cpu; any other value is handed to
    # torch.device() unchanged.
    device: str = "auto"

    # --- Compilation ----------------------------------------------------
    # NOTE(review): not read by Trainer in this file; presumably consumed by
    # the caller that builds the model -- confirm.
    compile_model: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain dict keyed by field name."""
        config_dict = asdict(self)
        return config_dict
|
|
|
|
|
|
|
|
class Trainer:
    """Training loop for SLM model.

    Ties together the model, loss, optimizer, learning-rate scheduler and
    (on CUDA only) a mixed-precision gradient scaler. Provides epoch
    training with gradient accumulation, validation, early stopping,
    checkpoint save/load and optional Weights & Biases logging.
    """

    def __init__(
        self,
        model: nn.Module,
        config: "TrainingConfig",
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        wandb_project: Optional[str] = None,
    ):
        """Initialize trainer.

        Args:
            model: The model to train. Its vocabulary size is read from
                ``model.config.vocab_size`` when the model has a ``config``
                attribute, otherwise from
                ``model.embed_tokens.num_embeddings``.
            config: Training configuration
            train_dataloader: Training data loader
            val_dataloader: Optional validation data loader
            wandb_project: Optional W&B project name for logging

        Raises:
            ValueError: If ``config.gradient_accumulation_steps`` is < 1.
        """
        if config.gradient_accumulation_steps < 1:
            raise ValueError(
                "gradient_accumulation_steps must be >= 1, got "
                f"{config.gradient_accumulation_steps}"
            )

        self.config = config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        self.device = self._resolve_device(config.device)
        print(f"Training on device: {self.device}")

        self.model = model.to(self.device)

        if hasattr(model, "config"):
            self.vocab_size = model.config.vocab_size
        else:
            self.vocab_size = model.embed_tokens.num_embeddings

        self.loss_fn = LanguageModelingLoss(
            vocab_size=self.vocab_size,
            label_smoothing=config.label_smoothing,
        )

        # steps_per_epoch counts micro-batches; total_steps counts optimizer
        # updates. Each epoch performs ceil(batches / accumulation) updates
        # because the trailing partial accumulation group is flushed too, so
        # the scheduler horizon matches the updates actually taken.
        self.steps_per_epoch = len(train_dataloader)
        self.total_steps = config.num_epochs * self._num_update_steps(
            self.steps_per_epoch, config.gradient_accumulation_steps
        )

        self.optimizer = create_optimizer(
            self.model,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        self.scheduler = create_scheduler(
            self.optimizer,
            num_training_steps=self.total_steps,
            warmup_ratio=config.warmup_ratio,
            min_lr_ratio=config.min_lr_ratio,
        )

        # Mixed precision is only used on CUDA devices.
        self.use_amp = config.fp16 and self.device.type == "cuda"
        self.scaler = GradScaler() if self.use_amp else None

        # Progress counters (restored by load_checkpoint when resuming).
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float("inf")

        # Early-stopping state.
        self.early_stopping_counter = 0
        self.should_stop = False

        os.makedirs(config.checkpoint_dir, exist_ok=True)

        # W&B is optional: a missing package only disables remote logging.
        self.wandb = None
        if wandb_project:
            try:
                import wandb
            except ImportError:
                print("wandb not installed, skipping logging")
            else:
                wandb.init(project=wandb_project, config=config.to_dict())
                self.wandb = wandb

    @staticmethod
    def _resolve_device(requested: str) -> torch.device:
        """Map a device string to a ``torch.device``; "auto" prefers cuda, then mps, then cpu."""
        if requested != "auto":
            return torch.device(requested)
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _num_update_steps(num_batches: int, accumulation_steps: int) -> int:
        """Return the optimizer updates per epoch: ceil(num_batches / accumulation_steps)."""
        return (num_batches + accumulation_steps - 1) // accumulation_steps

    @staticmethod
    def _extract_logits(outputs):
        """Return the logits from a model output (tensor, object with ``.logits``, or sequence)."""
        if isinstance(outputs, torch.Tensor):
            return outputs
        if hasattr(outputs, "logits"):
            return outputs.logits
        return outputs[0]

    def train(self) -> Dict[str, Any]:
        """Run the full training loop.

        Starts at ``self.epoch`` (non-zero after ``load_checkpoint``) and runs
        up to ``config.num_epochs``. After every epoch the model is validated
        (when a validation loader exists) and an ``epoch_<N>`` checkpoint is
        written.

        Returns:
            Dictionary with training results: ``total_steps`` (optimizer
            updates performed), ``training_time`` (seconds) and
            ``best_val_loss``.
        """
        patience = self.config.early_stopping_patience
        banner = "=" * 60

        print(f"\n{banner}")
        print("STARTING TRAINING")
        print(f"{banner}")
        print(f"Total epochs: {self.config.num_epochs}")
        print(f"Steps per epoch: {self.steps_per_epoch}")
        print(f"Total optimization steps: {self.total_steps}")
        print(f"Gradient accumulation: {self.config.gradient_accumulation_steps}")
        print(f"Mixed precision: {self.use_amp}")
        if patience > 0:
            print(f"Early stopping: patience={patience}")
        print(f"{banner}\n")

        training_start = time.time()

        start_epoch = self.epoch
        if start_epoch > 0:
            print(f"Resuming from epoch {start_epoch + 1}")

        for epoch in range(start_epoch, self.config.num_epochs):
            self.epoch = epoch
            epoch_loss = self._train_epoch()

            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs} - Loss: {epoch_loss:.4f}")

            if self.val_dataloader is not None:
                val_metrics = self.evaluate()
                print(
                    f"Validation - Loss: {val_metrics['loss']:.4f}, "
                    f"PPL: {val_metrics['perplexity']:.2f}"
                )

                if self._record_validation_loss(val_metrics["loss"]):
                    print(" New best model saved!")
                else:
                    print(
                        " No improvement. Early stopping: "
                        f"{self.early_stopping_counter}/{patience}"
                    )
                    if patience > 0 and self.early_stopping_counter >= patience:
                        print(
                            "\nEarly stopping triggered after "
                            f"{self.early_stopping_counter} evaluations without improvement."
                        )
                        self.should_stop = True

            self.save_checkpoint(f"epoch_{epoch + 1}")

            if self.should_stop:
                print("Stopping training early.")
                break

        training_time = time.time() - training_start
        print(f"\n{banner}")
        print("TRAINING COMPLETE")
        print(f"Total time: {training_time / 3600:.2f} hours")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        if self.should_stop:
            print(f"Stopped early at epoch {self.epoch + 1}")
        print(f"{banner}")

        return {
            "total_steps": self.global_step,
            "training_time": training_time,
            "best_val_loss": self.best_val_loss,
        }

    def _record_validation_loss(self, val_loss: float) -> bool:
        """Update best-loss / early-stopping state; return True when ``val_loss`` is a new best.

        A loss counts as an improvement only when it beats the best loss by
        more than ``config.early_stopping_threshold``. On improvement the
        "best" checkpoint is saved and the counter resets; otherwise the
        counter is incremented.
        """
        improved = val_loss < self.best_val_loss - self.config.early_stopping_threshold
        if improved:
            self.best_val_loss = val_loss
            self.early_stopping_counter = 0
            self.save_checkpoint("best")
        else:
            self.early_stopping_counter += 1
        return improved

    def _optimizer_step(self):
        """Clip gradients, apply one optimizer + scheduler update and clear gradients.

        Returns:
            The value returned by ``clip_grad_norm`` (logged as the gradient norm).
        """
        if self.use_amp:
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            self.scaler.unscale_(self.optimizer)

        grad_norm = clip_grad_norm(self.model, self.config.max_grad_norm)

        if self.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1
        return grad_norm

    def _log_training_step(self, pbar, avg_loss: float, grad_norm) -> None:
        """Report loss, learning rate and gradient norm to the progress bar, stdout and W&B."""
        lr = self.scheduler.get_last_lr()[0]

        pbar.set_postfix({
            'loss': f'{avg_loss:.4f}',
            'lr': f'{lr:.2e}',
            'step': f'{self.global_step}/{self.total_steps}'
        })

        tqdm.write(
            f"Step {self.global_step}/{self.total_steps} | "
            f"Loss: {avg_loss:.4f} | "
            f"LR: {lr:.2e} | "
            f"Grad: {grad_norm:.2f}"
        )

        if self.wandb:
            self.wandb.log({
                "train/loss": avg_loss,
                "train/learning_rate": lr,
                "train/grad_norm": grad_norm,
                "train/epoch": self.epoch,
            }, step=self.global_step)

    def _run_periodic_evaluation(self) -> bool:
        """Run a mid-epoch validation pass; return True when early stopping fires.

        Also sets ``self.should_stop`` when early stopping is triggered.
        """
        val_metrics = self.evaluate()
        print(f" Eval - Loss: {val_metrics['loss']:.4f}, PPL: {val_metrics['perplexity']:.2f}")

        if self.wandb:
            self.wandb.log({
                "eval/loss": val_metrics["loss"],
                "eval/perplexity": val_metrics["perplexity"],
            }, step=self.global_step)

        if self._record_validation_loss(val_metrics["loss"]):
            print(f" New best model! Loss: {self.best_val_loss:.4f}")
            return False

        patience = self.config.early_stopping_patience
        if patience <= 0:
            return False

        print(f" No improvement ({self.early_stopping_counter}/{patience})")
        if self.early_stopping_counter < patience:
            return False

        print("\n Early stopping triggered!")
        self.should_stop = True
        return True

    def _train_epoch(self) -> float:
        """Train for one epoch.

        Gradients are accumulated over ``config.gradient_accumulation_steps``
        micro-batches per optimizer update. The last, possibly shorter, group
        of an epoch is flushed as well, so its gradients are applied instead
        of leaking into the next epoch. Its loss is still divided by the full
        accumulation count, so that update is proportionally smaller.

        Returns:
            Average training loss for the epoch
        """
        self.model.train()

        accumulation_steps = self.config.gradient_accumulation_steps
        logging_steps = self.config.logging_steps
        eval_steps = self.config.eval_steps
        save_steps = self.config.save_steps
        batches_in_epoch = len(self.train_dataloader)

        total_loss = 0.0
        num_batches = 0
        # Running loss since the last log line.
        window_loss = 0.0
        window_batches = 0

        # Start from clean gradients (e.g. after an interrupted epoch).
        self.optimizer.zero_grad()

        pbar = tqdm(
            enumerate(self.train_dataloader),
            total=batches_in_epoch,
            desc=f"Epoch {self.epoch + 1}",
            ncols=100,
        )

        for step, batch in pbar:
            input_ids = batch["input_ids"].to(self.device)
            labels = batch["labels"].to(self.device)

            with autocast(enabled=self.use_amp):
                logits = self._extract_logits(self.model(input_ids))
                loss = self.loss_fn(logits, labels)
                # Scale so the summed gradients of one accumulation group
                # correspond to the mean loss over that group.
                loss = loss / accumulation_steps

            if self.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            unscaled_loss = loss.item() * accumulation_steps
            window_loss += unscaled_loss
            window_batches += 1
            total_loss += unscaled_loss
            num_batches += 1

            batches_seen = step + 1
            is_update_step = (
                batches_seen % accumulation_steps == 0
                or batches_seen == batches_in_epoch
            )

            if is_update_step:
                grad_norm = self._optimizer_step()

                # Guarded like eval_steps/save_steps so 0 disables logging
                # instead of raising ZeroDivisionError.
                if logging_steps > 0 and self.global_step % logging_steps == 0:
                    avg_loss = window_loss / max(window_batches, 1)
                    self._log_training_step(pbar, avg_loss, grad_norm)
                    window_loss = 0.0
                    window_batches = 0

                if (
                    eval_steps > 0
                    and self.global_step % eval_steps == 0
                    and self.val_dataloader is not None
                ):
                    if self._run_periodic_evaluation():
                        break

                if save_steps > 0 and self.global_step % save_steps == 0:
                    self.save_checkpoint(f"step_{self.global_step}")

            if self.should_stop:
                break

        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on validation data.

        The model is put in eval mode for the pass and returned to whatever
        mode it was in before the call, even if evaluation raises.

        Returns:
            Dictionary with evaluation metrics: ``loss`` and ``accuracy``
            (means of the per-batch values) and ``perplexity`` (computed from
            the mean loss).

        Raises:
            ValueError: If the trainer has no validation dataloader.
        """
        if self.val_dataloader is None:
            raise ValueError("evaluate() requires a validation dataloader")

        was_training = self.model.training
        self.model.eval()

        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        try:
            for batch in self.val_dataloader:
                input_ids = batch["input_ids"].to(self.device)
                labels = batch["labels"].to(self.device)

                with autocast(enabled=self.use_amp):
                    logits = self._extract_logits(self.model(input_ids))
                    loss = self.loss_fn(logits, labels)

                total_loss += loss.item()
                total_accuracy += compute_accuracy(logits, labels).item()
                num_batches += 1
        finally:
            self.model.train(was_training)

        avg_loss = total_loss / max(num_batches, 1)
        avg_accuracy = total_accuracy / max(num_batches, 1)

        return {
            "loss": avg_loss,
            "perplexity": compute_perplexity(torch.tensor(avg_loss)).item(),
            "accuracy": avg_accuracy,
        }

    def save_checkpoint(self, name: str):
        """Save a checkpoint.

        Writes ``model.pt`` (model weights), ``optimizer.pt`` (optimizer,
        scheduler, optional grad-scaler state and progress counters) and
        ``config.json`` into ``<checkpoint_dir>/<name>``, then prunes old
        ``step_*`` checkpoints.

        Args:
            name: Checkpoint name (e.g., "best", "epoch_1", "step_1000")
        """
        checkpoint_path = os.path.join(self.config.checkpoint_dir, name)
        os.makedirs(checkpoint_path, exist_ok=True)

        torch.save(self.model.state_dict(), os.path.join(checkpoint_path, "model.pt"))

        training_state = {
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "global_step": self.global_step,
            "epoch": self.epoch,
            "best_val_loss": self.best_val_loss,
            "early_stopping_counter": self.early_stopping_counter,
        }
        if self.scaler is not None:
            # Keep the dynamic loss scale so AMP resumes where it left off.
            training_state["scaler"] = self.scaler.state_dict()
        torch.save(training_state, os.path.join(checkpoint_path, "optimizer.pt"))

        config_path = os.path.join(checkpoint_path, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        print(f"Saved checkpoint: {checkpoint_path}")

        self._cleanup_checkpoints()

    def load_checkpoint(self, checkpoint_path: str):
        """Load a checkpoint.

        Restores model weights and, when ``optimizer.pt`` is present, the
        optimizer, scheduler, grad-scaler and progress counters. When the
        checkpoint directory itself is named ``epoch_<N>`` the epoch counter
        is advanced by one, because that epoch had already finished.

        NOTE: ``torch.load`` unpickles data; only load checkpoints from
        trusted sources.

        Args:
            checkpoint_path: Path to checkpoint directory
        """
        model_path = os.path.join(checkpoint_path, "model.pt")
        state_dict = torch.load(model_path, map_location=self.device)

        # Keys carry an "_orig_mod." prefix when the state dict was saved from
        # a compiled model; strip it so the keys match a plain module.
        if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
            print(" Detected compiled model checkpoint, removing _orig_mod. prefix...")
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

        self.model.load_state_dict(state_dict)

        optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
        if os.path.exists(optimizer_path):
            state = torch.load(optimizer_path, map_location=self.device)
            self.optimizer.load_state_dict(state["optimizer"])
            self.scheduler.load_state_dict(state["scheduler"])
            self.global_step = state["global_step"]
            self.epoch = state["epoch"]
            self.best_val_loss = state.get("best_val_loss", float("inf"))
            self.early_stopping_counter = state.get("early_stopping_counter", 0)

            # Older checkpoints have no "scaler" entry; keep the fresh scaler then.
            if self.scaler is not None and "scaler" in state:
                self.scaler.load_state_dict(state["scaler"])

            # Look only at the checkpoint directory's own name so a parent
            # directory that merely contains "epoch_" does not skip an epoch.
            checkpoint_name = os.path.basename(os.path.normpath(checkpoint_path))
            if checkpoint_name.startswith("epoch_"):
                self.epoch += 1
                print(f" Checkpoint was end-of-epoch, will start from epoch {self.epoch + 1}")

        print(f"Loaded checkpoint: {checkpoint_path}")
        print(f" Resuming from step {self.global_step}, epoch {self.epoch}")
        print(f" Best val loss so far: {self.best_val_loss:.4f}")

    def _cleanup_checkpoints(self):
        """Remove old checkpoints to save disk space.

        Only ``step_<number>`` directories are considered; the oldest (lowest
        step) are deleted until at most ``config.save_total_limit`` remain.
        Other directories ("best", "epoch_*", non-numeric "step_*") are never
        touched. A limit <= 0 disables cleanup.
        """
        limit = self.config.save_total_limit
        if limit <= 0:
            return

        import shutil

        step_checkpoints = []
        for entry in Path(self.config.checkpoint_dir).iterdir():
            if not entry.is_dir():
                continue
            prefix, _, step_text = entry.name.partition("_")
            if prefix == "step" and step_text.isdecimal():
                step_checkpoints.append((int(step_text), entry))

        step_checkpoints.sort(key=lambda item: item[0])

        # limit > 0 here, so [:-limit] is everything except the newest `limit`.
        for _, old_checkpoint in step_checkpoints[:-limit]:
            print(f"Removing old checkpoint: {old_checkpoint}")
            shutil.rmtree(old_checkpoint)
|
|
|