""" Knowledge Distillation for MiniMind Train smaller models using larger teacher models. """ import math from typing import Optional, Dict, Any, Callable from pathlib import Path from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.cuda.amp import GradScaler, autocast @dataclass class DistillationConfig: """Configuration for knowledge distillation.""" # Distillation parameters temperature: float = 2.0 alpha_ce: float = 0.5 # Weight for hard label loss alpha_kd: float = 0.5 # Weight for distillation loss alpha_hidden: float = 0.0 # Weight for hidden state matching # Optimization learning_rate: float = 1e-4 min_learning_rate: float = 1e-5 weight_decay: float = 0.1 warmup_steps: int = 500 grad_clip: float = 1.0 # Training num_epochs: int = 5 batch_size: int = 4 gradient_accumulation_steps: int = 8 max_steps: Optional[int] = None # Mixed precision use_amp: bool = True # Checkpointing save_steps: int = 500 output_dir: str = "./distill_outputs" log_steps: int = 10 class DistillationTrainer: """ Knowledge Distillation Trainer. Supports: - Soft label distillation (KL divergence) - Hard label training (CE loss) - Hidden state matching (optional) - Online and offline distillation """ def __init__( self, student_model: nn.Module, teacher_model: Optional[nn.Module] = None, train_dataloader: DataLoader = None, config: Optional[DistillationConfig] = None, projection_layer: Optional[nn.Module] = None, ): self.student = student_model self.teacher = teacher_model self.train_dataloader = train_dataloader self.config = config or DistillationConfig() self.projection_layer = projection_layer # For hidden state matching self.device = next(student_model.parameters()).device if self.teacher is not None: self.teacher.eval() for param in self.teacher.parameters(): param.requires_grad = False self.optimizer = self._create_optimizer() self.scheduler = self._create_scheduler() self.scaler = GradScaler() if self.config.use_amp else None self.global_step = 0 Path(self.config.output_dir).mkdir(parents=True, exist_ok=True) def _create_optimizer(self) -> torch.optim.Optimizer: params = list(self.student.parameters()) if self.projection_layer is not None: params += list(self.projection_layer.parameters()) return torch.optim.AdamW( params, lr=self.config.learning_rate, weight_decay=self.config.weight_decay, ) def _create_scheduler(self): total_steps = self._get_total_steps() def lr_lambda(step): if step < self.config.warmup_steps: return step / max(1, self.config.warmup_steps) progress = (step - self.config.warmup_steps) / max(1, total_steps - self.config.warmup_steps) return max( self.config.min_learning_rate / self.config.learning_rate, 0.5 * (1.0 + math.cos(math.pi * progress)) ) return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) def _get_total_steps(self) -> int: if self.config.max_steps: return self.config.max_steps steps_per_epoch = len(self.train_dataloader) // self.config.gradient_accumulation_steps return steps_per_epoch * self.config.num_epochs def distillation_loss( self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor, student_hidden: Optional[torch.Tensor] = None, teacher_hidden: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """ Compute combined distillation loss. 

        Args:
            student_logits: Student model output logits [B, T, V]
            teacher_logits: Teacher model output logits [B, T, V]
            labels: Ground truth labels [B, T]
            student_hidden: Student hidden states (optional)
            teacher_hidden: Teacher hidden states (optional)

        Returns:
            Dictionary with loss components and total loss
        """
        # Temperature used to soften both distributions
        T = self.config.temperature

        # Soft label loss (KL divergence)
        student_log_probs = F.log_softmax(student_logits / T, dim=-1)
        teacher_probs = F.softmax(teacher_logits / T, dim=-1)
        kd_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="batchmean",
        ) * (T ** 2)

        # Hard label loss (cross entropy against the ground truth, shifted by one)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        ce_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Hidden state matching (optional)
        hidden_loss = torch.tensor(0.0, device=self.device)
        if student_hidden is not None and teacher_hidden is not None and self.projection_layer is not None:
            projected_student = self.projection_layer(student_hidden)
            hidden_loss = F.mse_loss(projected_student, teacher_hidden)

        # Combined loss
        total_loss = (
            self.config.alpha_ce * ce_loss
            + self.config.alpha_kd * kd_loss
            + self.config.alpha_hidden * hidden_loss
        )

        return {
            "total_loss": total_loss,
            "ce_loss": ce_loss,
            "kd_loss": kd_loss,
            "hidden_loss": hidden_loss,
        }

    def train(self) -> Dict[str, float]:
        """Main distillation training loop."""
        self.student.train()
        total_steps = self._get_total_steps()

        print(f"Starting knowledge distillation for {total_steps} steps")
        print(f"  Temperature: {self.config.temperature}")
        print(f"  Alpha CE: {self.config.alpha_ce}, Alpha KD: {self.config.alpha_kd}")

        running_losses = {"total": 0.0, "ce": 0.0, "kd": 0.0}

        for epoch in range(self.config.num_epochs):
            for step, batch in enumerate(self.train_dataloader):
                losses = self._training_step(batch)

                for key in running_losses:
                    running_losses[key] += losses[f"{key}_loss"].item()

                if (step + 1) % self.config.gradient_accumulation_steps == 0:
                    self._optimizer_step()
                    self.global_step += 1

                    if self.global_step % self.config.log_steps == 0:
                        # Losses are accumulated per micro-batch, so average over
                        # log_steps optimizer steps times the accumulation steps
                        denom = self.config.log_steps * self.config.gradient_accumulation_steps
                        avg_losses = {k: v / denom for k, v in running_losses.items()}
                        print(
                            f"Step {self.global_step}/{total_steps} | "
                            f"Total: {avg_losses['total']:.4f} | "
                            f"CE: {avg_losses['ce']:.4f} | "
                            f"KD: {avg_losses['kd']:.4f}"
                        )
                        running_losses = {k: 0.0 for k in running_losses}

                    if self.global_step % self.config.save_steps == 0:
                        self._save_checkpoint()

                    if self.config.max_steps and self.global_step >= self.config.max_steps:
                        break

            if self.config.max_steps and self.global_step >= self.config.max_steps:
                break

        self._save_checkpoint(final=True)
        return {"final_step": self.global_step}

    def _training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Single distillation training step."""
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        labels = batch["labels"].to(self.device)

        # Use pre-computed teacher logits if present (offline distillation),
        # otherwise run the teacher forward pass (online distillation)
        teacher_logits = batch.get("teacher_logits")
        if teacher_logits is not None:
            teacher_logits = teacher_logits.to(self.device)
        elif self.teacher is not None:
            with torch.no_grad():
                _, teacher_logits, _, _ = self.teacher(input_ids, attention_mask)
        else:
            raise ValueError("Either a teacher model or pre-computed teacher_logits is required")

        if self.config.use_amp:
            with autocast(dtype=torch.float16):
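                # Student forward pass and loss are computed under fp16 autocast;
                # the scaled backward pass runs outside the autocast context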
                _, student_logits, _, _ = self.student(input_ids, attention_mask)
                losses = self.distillation_loss(student_logits, teacher_logits, labels)
            loss = losses["total_loss"] / self.config.gradient_accumulation_steps
            self.scaler.scale(loss).backward()
        else:
            _, student_logits, _, _ = self.student(input_ids, attention_mask)
            losses = self.distillation_loss(student_logits, teacher_logits, labels)
            loss = losses["total_loss"] / self.config.gradient_accumulation_steps
            loss.backward()

        return losses

    def _optimizer_step(self):
        if self.config.use_amp:
            self.scaler.unscale_(self.optimizer)

        torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.config.grad_clip)

        if self.config.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()

    def _save_checkpoint(self, final: bool = False):
        name = "final" if final else f"step_{self.global_step}"
        path = Path(self.config.output_dir) / name
        path.mkdir(parents=True, exist_ok=True)

        torch.save(self.student.state_dict(), path / "student_model.pt")
        if self.projection_layer is not None:
            torch.save(self.projection_layer.state_dict(), path / "projection.pt")
        print(f"Checkpoint saved to {path}")


def generate_teacher_logits(
    teacher_model: nn.Module,
    dataloader: DataLoader,
    output_path: str,
    device: str = "cuda",
    top_k: int = 100,  # Only save top-k logits to reduce storage
):
    """
    Pre-generate teacher logits for offline distillation.

    Saves storage by only keeping top-k logits per position.
    """
    teacher_model.eval()
    teacher_model.to(device)

    all_logits = []
    print(f"Generating teacher logits for {len(dataloader)} batches...")

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            _, logits, _, _ = teacher_model(input_ids, attention_mask)

            # Keep only the top-k logits per position
            if top_k > 0 and top_k < logits.shape[-1]:
                topk_values, topk_indices = torch.topk(logits, k=top_k, dim=-1)
                sparse_logits = {
                    "values": topk_values.cpu(),
                    "indices": topk_indices.cpu(),
                }
                all_logits.append(sparse_logits)
            else:
                all_logits.append(logits.cpu())

    torch.save(all_logits, output_path)
    print(f"Teacher logits saved to {output_path}")
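

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the training recipe).
    # The toy models below are hypothetical stand-ins for MiniMind teacher and
    # student models: they only mimic the forward signature assumed throughout
    # this file, i.e. model(input_ids, attention_mask) returning a 4-tuple whose
    # second element is the logits. For offline distillation one would instead
    # pre-compute logits with generate_teacher_logits() and place them in each
    # batch under the "teacher_logits" key.
    from torch.utils.data import Dataset

    VOCAB_SIZE, SEQ_LEN = 64, 16

    class ToyLM(nn.Module):
        """Tiny stand-in language model with the assumed 4-tuple output."""

        def __init__(self, hidden_size: int):
            super().__init__()
            self.embed = nn.Embedding(VOCAB_SIZE, hidden_size)
            self.head = nn.Linear(hidden_size, VOCAB_SIZE)

        def forward(self, input_ids, attention_mask=None):
            hidden = self.embed(input_ids)
            return hidden, self.head(hidden), None, None

    class ToyDataset(Dataset):
        """Random token sequences; labels mirror the inputs for the shifted CE loss."""

        def __len__(self):
            return 64

        def __getitem__(self, idx):
            ids = torch.randint(0, VOCAB_SIZE, (SEQ_LEN,))
            return {"input_ids": ids, "labels": ids.clone()}

    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
    config = DistillationConfig(
        num_epochs=1,
        max_steps=5,
        use_amp=False,  # keep the smoke test CPU-friendly
        gradient_accumulation_steps=2,
        log_steps=1,
    )
    trainer = DistillationTrainer(
        student_model=ToyLM(hidden_size=32),
        teacher_model=ToyLM(hidden_size=64),
        train_dataloader=loader,
        config=config,
    )
    trainer.train()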