"""
Knowledge Distillation for MiniMind
Train smaller models using larger teacher models.
"""
import math
from typing import Optional, Dict, Any, Callable
from pathlib import Path
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
@dataclass
class DistillationConfig:
"""Configuration for knowledge distillation."""
# Distillation parameters
temperature: float = 2.0
alpha_ce: float = 0.5 # Weight for hard label loss
alpha_kd: float = 0.5 # Weight for distillation loss
alpha_hidden: float = 0.0 # Weight for hidden state matching
# Optimization
learning_rate: float = 1e-4
min_learning_rate: float = 1e-5
weight_decay: float = 0.1
warmup_steps: int = 500
grad_clip: float = 1.0
# Training
num_epochs: int = 5
batch_size: int = 4
gradient_accumulation_steps: int = 8
max_steps: Optional[int] = None
# Mixed precision
use_amp: bool = True
# Checkpointing
save_steps: int = 500
output_dir: str = "./distill_outputs"
log_steps: int = 10
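

# Illustrative only: a short distillation run might override a few fields, e.g.
#
#     config = DistillationConfig(
#         temperature=4.0,        # softer teacher distribution
#         alpha_ce=0.3,           # weight on the hard-label CE term
#         alpha_kd=0.7,           # weight on the soft-label KD term
#         num_epochs=1,
#         max_steps=1_000,
#     )
#
# These values are placeholders, not tuned MiniMind settings.
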
class DistillationTrainer:
"""
Knowledge Distillation Trainer.
Supports:
- Soft label distillation (KL divergence)
- Hard label training (CE loss)
- Hidden state matching (optional)
- Online and offline distillation
"""
def __init__(
self,
student_model: nn.Module,
teacher_model: Optional[nn.Module] = None,
        train_dataloader: Optional[DataLoader] = None,
config: Optional[DistillationConfig] = None,
projection_layer: Optional[nn.Module] = None,
):
self.student = student_model
self.teacher = teacher_model
self.train_dataloader = train_dataloader
self.config = config or DistillationConfig()
self.projection_layer = projection_layer # For hidden state matching
self.device = next(student_model.parameters()).device
if self.teacher is not None:
self.teacher.eval()
for param in self.teacher.parameters():
param.requires_grad = False
self.optimizer = self._create_optimizer()
self.scheduler = self._create_scheduler()
self.scaler = GradScaler() if self.config.use_amp else None
self.global_step = 0
Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)
    def _create_optimizer(self) -> torch.optim.Optimizer:
        """AdamW over the student parameters (plus the projection layer, if any)."""
params = list(self.student.parameters())
if self.projection_layer is not None:
params += list(self.projection_layer.parameters())
return torch.optim.AdamW(
params,
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay,
)
    def _create_scheduler(self):
        """Linear warmup followed by cosine decay, floored at min_learning_rate."""
total_steps = self._get_total_steps()
def lr_lambda(step):
if step < self.config.warmup_steps:
return step / max(1, self.config.warmup_steps)
progress = (step - self.config.warmup_steps) / max(1, total_steps - self.config.warmup_steps)
return max(
self.config.min_learning_rate / self.config.learning_rate,
0.5 * (1.0 + math.cos(math.pi * progress))
)
return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
    def _get_total_steps(self) -> int:
        """Total optimizer steps: max_steps if set, else steps per epoch times num_epochs."""
if self.config.max_steps:
return self.config.max_steps
steps_per_epoch = len(self.train_dataloader) // self.config.gradient_accumulation_steps
return steps_per_epoch * self.config.num_epochs
def distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
student_hidden: Optional[torch.Tensor] = None,
teacher_hidden: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""
Compute combined distillation loss.
Args:
student_logits: Student model output logits [B, T, V]
teacher_logits: Teacher model output logits [B, T, V]
labels: Ground truth labels [B, T]
student_hidden: Student hidden states (optional)
teacher_hidden: Teacher hidden states (optional)
Returns:
Dictionary with loss components and total loss
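
        Example (illustrative shapes only; ``trainer`` stands for an existing
        DistillationTrainer instance):
            B, T, V = 2, 16, 512
            losses = trainer.distillation_loss(
                student_logits=torch.randn(B, T, V),
                teacher_logits=torch.randn(B, T, V),
                labels=torch.randint(0, V, (B, T)),
            )
            total = losses["total_loss"]  # the term backpropagated in _training_step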
"""
# Temperature-scaled soft labels
T = self.config.temperature
# Soft label loss (KL divergence)
student_log_probs = F.log_softmax(student_logits / T, dim=-1)
teacher_probs = F.softmax(teacher_logits / T, dim=-1)
        # Flatten batch and sequence so "batchmean" averages per token, keeping the
        # KD term on the same scale as the per-token cross-entropy loss below.
        kd_loss = F.kl_div(
            student_log_probs.view(-1, student_log_probs.size(-1)),
            teacher_probs.view(-1, teacher_probs.size(-1)),
            reduction="batchmean",
        ) * (T ** 2)
# Hard label loss (Cross entropy)
shift_logits = student_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
ce_loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
# Hidden state matching (optional)
hidden_loss = torch.tensor(0.0, device=self.device)
if student_hidden is not None and teacher_hidden is not None and self.projection_layer is not None:
projected_student = self.projection_layer(student_hidden)
hidden_loss = F.mse_loss(projected_student, teacher_hidden)
# Combined loss
total_loss = (
self.config.alpha_ce * ce_loss +
self.config.alpha_kd * kd_loss +
self.config.alpha_hidden * hidden_loss
)
return {
"total_loss": total_loss,
"ce_loss": ce_loss,
"kd_loss": kd_loss,
"hidden_loss": hidden_loss,
}
def train(self) -> Dict[str, float]:
"""Main distillation training loop."""
self.student.train()
total_steps = self._get_total_steps()
print(f"Starting knowledge distillation for {total_steps} steps")
print(f" Temperature: {self.config.temperature}")
print(f" Alpha CE: {self.config.alpha_ce}, Alpha KD: {self.config.alpha_kd}")
running_losses = {"total": 0.0, "ce": 0.0, "kd": 0.0}
for epoch in range(self.config.num_epochs):
for step, batch in enumerate(self.train_dataloader):
losses = self._training_step(batch)
                for key in running_losses:
                    running_losses[key] += losses[f"{key}_loss"].item()
if (step + 1) % self.config.gradient_accumulation_steps == 0:
self._optimizer_step()
self.global_step += 1
if self.global_step % self.config.log_steps == 0:
                        # Average over micro-batches: log_steps optimizer steps, each
                        # spanning gradient_accumulation_steps forward/backward passes.
                        denom = self.config.log_steps * self.config.gradient_accumulation_steps
                        avg_losses = {k: v / denom for k, v in running_losses.items()}
print(
f"Step {self.global_step}/{total_steps} | "
f"Total: {avg_losses['total']:.4f} | "
f"CE: {avg_losses['ce']:.4f} | "
f"KD: {avg_losses['kd']:.4f}"
)
running_losses = {k: 0.0 for k in running_losses}
if self.global_step % self.config.save_steps == 0:
self._save_checkpoint()
if self.config.max_steps and self.global_step >= self.config.max_steps:
break
if self.config.max_steps and self.global_step >= self.config.max_steps:
break
self._save_checkpoint(final=True)
return {"final_step": self.global_step}
def _training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Single distillation training step."""
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch.get("attention_mask")
if attention_mask is not None:
attention_mask = attention_mask.to(self.device)
labels = batch["labels"].to(self.device)
# Check for pre-computed teacher logits
teacher_logits = batch.get("teacher_logits")
if teacher_logits is not None:
teacher_logits = teacher_logits.to(self.device)
elif self.teacher is not None:
with torch.no_grad():
_, teacher_logits, _, _ = self.teacher(input_ids, attention_mask)
if self.config.use_amp:
with autocast(dtype=torch.float16):
_, student_logits, _, _ = self.student(input_ids, attention_mask)
losses = self.distillation_loss(student_logits, teacher_logits, labels)
loss = losses["total_loss"] / self.config.gradient_accumulation_steps
self.scaler.scale(loss).backward()
else:
_, student_logits, _, _ = self.student(input_ids, attention_mask)
losses = self.distillation_loss(student_logits, teacher_logits, labels)
loss = losses["total_loss"] / self.config.gradient_accumulation_steps
loss.backward()
return losses
    def _optimizer_step(self):
        """Clip gradients, step the optimizer and LR scheduler, and reset gradients."""
        if self.config.use_amp:
            # Unscale before clipping so the threshold applies to the true gradients.
            self.scaler.unscale_(self.optimizer)
        # Clip every parameter handed to the optimizer (student and, if present,
        # the projection layer used for hidden state matching).
        clip_params = [p for group in self.optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(clip_params, self.config.grad_clip)
        if self.config.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
    def _save_checkpoint(self, final: bool = False):
        """Save student (and projection layer) state dicts under output_dir."""
name = "final" if final else f"step_{self.global_step}"
path = Path(self.config.output_dir) / name
path.mkdir(parents=True, exist_ok=True)
torch.save(self.student.state_dict(), path / "student_model.pt")
if self.projection_layer is not None:
torch.save(self.projection_layer.state_dict(), path / "projection.pt")
print(f"Checkpoint saved to {path}")
def generate_teacher_logits(
teacher_model: nn.Module,
dataloader: DataLoader,
output_path: str,
device: str = "cuda",
top_k: int = 100, # Only save top-k logits to reduce storage
):
"""
Pre-generate teacher logits for offline distillation.
Saves storage by only keeping top-k logits per position.
"""
teacher_model.eval()
teacher_model.to(device)
all_logits = []
print(f"Generating teacher logits for {len(dataloader)} batches...")
with torch.no_grad():
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch.get("attention_mask")
if attention_mask is not None:
attention_mask = attention_mask.to(device)
_, logits, _, _ = teacher_model(input_ids, attention_mask)
# Keep only top-k logits
if top_k > 0 and top_k < logits.shape[-1]:
topk_values, topk_indices = torch.topk(logits, k=top_k, dim=-1)
sparse_logits = {
"values": topk_values.cpu(),
"indices": topk_indices.cpu(),
}
all_logits.append(sparse_logits)
else:
all_logits.append(logits.cpu())
torch.save(all_logits, output_path)
print(f"Teacher logits saved to {output_path}")