"""
MiniMind Training Utilities
Standard training loop with mixed precision and gradient accumulation.
"""

import os
import math
import sys
import time
from typing import Optional, Dict, Any
from pathlib import Path
from dataclasses import asdict, dataclass

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast

# Make the repository root importable so the project-local config resolves.
sys.path.insert(0, str(Path(__file__).parent.parent))
from configs.model_config import Mind2Config


@dataclass
class TrainingConfig:
    """Training configuration."""

    # Optimization
    learning_rate: float = 3e-4
    min_learning_rate: float = 3e-5
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0
    warmup_steps: int = 1000

    # Training
    num_epochs: int = 3
    batch_size: int = 8
    gradient_accumulation_steps: int = 4
    max_steps: Optional[int] = None

    # Mixed precision
    use_amp: bool = True
    amp_dtype: str = "float16"  # float16 or bfloat16

    # Checkpointing
    save_steps: int = 1000
    eval_steps: int = 500
    output_dir: str = "./outputs"
    resume_from: Optional[str] = None  # checkpoint directory restored by Mind2Trainer.__init__

    # Logging
    log_steps: int = 10
    wandb_project: Optional[str] = None


class Mind2Trainer:
    """Trainer for MiniMind models."""

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        config: Optional[TrainingConfig] = None,
    ):
        """Set up optimizer, scheduler, AMP scaler and the output directory.

        Args:
            model: Module to train; its first parameter's device is used for
                all batches.
            train_dataloader: Loader yielding dict batches with ``input_ids``
                and ``labels`` (and optionally ``attention_mask``).
            eval_dataloader: Optional loader used by :meth:`evaluate`.
            config: Training options; defaults to ``TrainingConfig()``.
        """
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.config = config or TrainingConfig()

        self.device = next(model.parameters()).device
        self.global_step = 0
        self.epoch = 0

        # Setup optimizer
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Mixed precision
        self.scaler = GradScaler() if self.config.use_amp else None
        # Any value other than "float16" selects bfloat16.
        self.amp_dtype = torch.float16 if self.config.amp_dtype == "float16" else torch.bfloat16

        # Output directory
        Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)

        # Honour the resume_from option (it was previously declared but unused).
        if self.config.resume_from:
            self.load_checkpoint(self.config.resume_from)

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW optimizer with weight decay.

        Trainable parameters whose name contains ``"bias"`` or ``"norm"`` go
        into a group with zero weight decay; all others use
        ``config.weight_decay``.
        """
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            # "norm" also matches "layernorm", so one substring test suffices.
            if "bias" in name or "norm" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_groups = [
            {"params": decay_params, "weight_decay": self.config.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        return torch.optim.AdamW(
            optimizer_groups,
            lr=self.config.learning_rate,
            betas=(self.config.beta1, self.config.beta2),
        )

    def _create_scheduler(self):
        """Create cosine annealing scheduler with warmup.

        The multiplier ramps linearly from 0 to 1 over ``warmup_steps``, then
        follows a half cosine that is floored at
        ``min_learning_rate / learning_rate``.
        """
        total_steps = self._get_total_steps()
        warmup_steps = self.config.warmup_steps
        min_ratio = self.config.min_learning_rate / self.config.learning_rate

        def lr_lambda(step):
            if step < warmup_steps:
                return step / max(1, warmup_steps)
            progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
            # Clamp so stepping past total_steps cannot walk back up the cosine.
            progress = min(1.0, progress)
            return max(min_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def _get_total_steps(self) -> int:
        """Return the planned number of optimizer steps.

        ``config.max_steps`` wins when set; otherwise it is the number of
        complete accumulation windows per epoch times ``num_epochs``.
        """
        if self.config.max_steps:
            return self.config.max_steps
        steps_per_epoch = len(self.train_dataloader) // self.config.gradient_accumulation_steps
        return steps_per_epoch * self.config.num_epochs

    def train(self) -> Dict[str, float]:
        """Main training loop.

        Runs ``num_epochs`` passes over the training loader, taking one
        optimizer step per ``gradient_accumulation_steps`` micro-batches, and
        logs, evaluates and checkpoints on the configured step intervals.
        A final checkpoint is always written.

        Returns:
            ``{"final_loss": value}`` where ``value`` is the mean micro-batch
            loss of the most recent logging window.
        """
        self.model.train()
        total_steps = self._get_total_steps()
        accum_steps = self.config.gradient_accumulation_steps

        print(f"Starting training for {total_steps} steps")
        print(f"  Batch size: {self.config.batch_size}")
        print(f"  Gradient accumulation: {self.config.gradient_accumulation_steps}")
        print(f"  Effective batch size: {self.config.batch_size * self.config.gradient_accumulation_steps}")

        running_loss = 0.0
        running_batches = 0  # micro-batches summed into running_loss
        last_avg_loss = 0.0
        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            self.epoch = epoch
            pending = 0  # micro-batches accumulated since the last optimizer step

            for batch in self.train_dataloader:
                loss = self._training_step(batch)
                running_loss += loss
                running_batches += 1
                pending += 1

                if pending < accum_steps:
                    continue

                self._optimizer_step()
                pending = 0
                self.global_step += 1

                # Logging
                if self.global_step % self.config.log_steps == 0:
                    # One loss is summed per micro-batch, so average over the
                    # micro-batch count rather than over log_steps.
                    avg_loss = running_loss / max(1, running_batches)
                    last_avg_loss = avg_loss
                    elapsed = max(time.time() - start_time, 1e-8)
                    tokens_per_sec = (
                        self.config.batch_size *
                        self.config.gradient_accumulation_steps *
                        batch["input_ids"].shape[1] *
                        self.config.log_steps / elapsed
                    )
                    print(
                        f"Step {self.global_step}/{total_steps} | "
                        f"Loss: {avg_loss:.4f} | "
                        f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
                        f"Tokens/s: {tokens_per_sec:.0f}"
                    )
                    running_loss = 0.0
                    running_batches = 0
                    start_time = time.time()

                # Evaluation
                if self.eval_dataloader is not None and self.global_step % self.config.eval_steps == 0:
                    eval_loss = self.evaluate()
                    print(f"Eval Loss: {eval_loss:.4f}")
                    self.model.train()

                # Save checkpoint
                if self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint()

                if self.config.max_steps and self.global_step >= self.config.max_steps:
                    break

            if pending:
                # An incomplete accumulation window is dropped (matching the
                # floor division in _get_total_steps) so its gradients do not
                # leak into the first update of the next epoch.
                self.optimizer.zero_grad()

            if self.config.max_steps and self.global_step >= self.config.max_steps:
                break

        self.save_checkpoint(final=True)

        if running_batches:
            last_avg_loss = running_loss / running_batches
        return {"final_loss": last_avg_loss}

    def _training_step(self, batch: Dict[str, torch.Tensor]) -> float:
        """Single training step.

        Runs forward and backward on one micro-batch. The loss is divided by
        ``gradient_accumulation_steps`` before ``backward`` so accumulated
        gradients average over the window.

        Returns:
            The undivided loss of this micro-batch as a Python float.
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        labels = batch["labels"].to(self.device)

        accum_steps = self.config.gradient_accumulation_steps

        # The model is called positionally and must return a 4-tuple whose
        # first element is the loss.
        if self.config.use_amp:
            with autocast(dtype=self.amp_dtype):
                loss, _, _, _ = self.model(input_ids, attention_mask, labels)
                scaled_loss = loss / accum_steps
            self.scaler.scale(scaled_loss).backward()
        else:
            loss, _, _, _ = self.model(input_ids, attention_mask, labels)
            scaled_loss = loss / accum_steps
            scaled_loss.backward()

        return loss.item()

    def _optimizer_step(self):
        """Optimizer step with gradient clipping.

        Under AMP the gradients are unscaled first so clipping sees their
        true magnitude. The scheduler is stepped and gradients are cleared
        afterwards.
        """
        if self.config.use_amp:
            self.scaler.unscale_(self.optimizer)

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)

        if self.config.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()

    @torch.no_grad()
    def evaluate(self) -> float:
        """Evaluate model on eval dataset.

        Switches the model to eval mode (the caller restores train mode) and
        returns the mean per-batch loss, or 0.0 when the loader yields no
        batches.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for batch in self.eval_dataloader:
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)
            labels = batch["labels"].to(self.device)

            loss, _, _, _ = self.model(input_ids, attention_mask, labels)
            total_loss += loss.item()
            num_batches += 1

        return total_loss / max(1, num_batches)

    def save_checkpoint(self, final: bool = False):
        """Save model checkpoint.

        Writes ``model.pt``, ``optimizer.pt`` and ``trainer_state.pt`` into
        ``<output_dir>/final`` or ``<output_dir>/step_<global_step>``.

        Args:
            final: Use the ``final`` directory name instead of the step name.
        """
        checkpoint_name = "final" if final else f"step_{self.global_step}"
        checkpoint_path = Path(self.config.output_dir) / checkpoint_name
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.model.state_dict(), checkpoint_path / "model.pt")
        torch.save(self.optimizer.state_dict(), checkpoint_path / "optimizer.pt")
        torch.save({
            "global_step": self.global_step,
            "epoch": self.epoch,
            # Stored as a plain dict so the file holds only basic types and
            # loads under torch.load's restricted (weights_only) unpickler.
            "config": asdict(self.config),
            # Scheduler/scaler state lets a resumed run continue its LR
            # schedule and loss scale instead of restarting them.
            "scheduler": self.scheduler.state_dict(),
            "scaler": self.scaler.state_dict() if self.scaler is not None else None,
        }, checkpoint_path / "trainer_state.pt")

        print(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint.

        Restores model and optimizer weights plus ``global_step`` and
        ``epoch``. Scheduler and scaler state are restored when present, so
        checkpoints written without them still load.

        Args:
            checkpoint_path: Directory previously written by
                :meth:`save_checkpoint`.
        """
        path = Path(checkpoint_path)

        self.model.load_state_dict(torch.load(path / "model.pt", map_location=self.device))
        self.optimizer.load_state_dict(torch.load(path / "optimizer.pt", map_location=self.device))

        state = torch.load(path / "trainer_state.pt", map_location=self.device)
        self.global_step = state["global_step"]
        self.epoch = state["epoch"]

        scheduler_state = state.get("scheduler")
        if scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)

        # A disabled GradScaler saves an empty dict, which an enabled scaler
        # refuses to load; the truthiness test skips that case.
        scaler_state = state.get("scaler")
        if scaler_state and self.scaler is not None:
            self.scaler.load_state_dict(scaler_state)

        print(f"Checkpoint loaded from {checkpoint_path}")