"""
Knowledge Distillation for MiniMind

Train smaller models using larger teacher models.
"""

import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast


@dataclass
class DistillationConfig:
    """Configuration for knowledge distillation."""

    # Loss weighting
    temperature: float = 2.0
    alpha_ce: float = 0.5
    alpha_kd: float = 0.5
    alpha_hidden: float = 0.0

    # Optimization
    learning_rate: float = 1e-4
    min_learning_rate: float = 1e-5
    weight_decay: float = 0.1
    warmup_steps: int = 500
    grad_clip: float = 1.0

    # Training schedule
    num_epochs: int = 5
    batch_size: int = 4
    gradient_accumulation_steps: int = 8
    max_steps: Optional[int] = None

    # Mixed precision
    use_amp: bool = True

    # Checkpointing and logging
    save_steps: int = 500
    output_dir: str = "./distill_outputs"
    log_steps: int = 10
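

# Illustrative only: the values below are assumptions, not tuned
# recommendations. A temperature above 1 softens the teacher distribution so
# the student also learns from the relative probabilities of incorrect
# tokens, while alpha_ce and alpha_kd weight the hard-label and soft-label
# terms of the combined loss (they need not sum to 1).
#
#   config = DistillationConfig(temperature=4.0, alpha_ce=0.3, alpha_kd=0.7)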


class DistillationTrainer:
    """
    Knowledge distillation trainer.

    Supports:
    - soft-label distillation (KL divergence against teacher logits)
    - hard-label training (cross-entropy against ground-truth labels)
    - hidden-state matching (optional, via a projection layer)
    - online (live teacher) and offline (pre-computed logits) distillation
    """

    def __init__(
        self,
        student_model: nn.Module,
        teacher_model: Optional[nn.Module] = None,
        train_dataloader: Optional[DataLoader] = None,
        config: Optional[DistillationConfig] = None,
        projection_layer: Optional[nn.Module] = None,
    ):
        self.student = student_model
        self.teacher = teacher_model
        self.train_dataloader = train_dataloader
        self.config = config or DistillationConfig()
        self.projection_layer = projection_layer

        self.device = next(student_model.parameters()).device

        # The teacher only provides targets; freeze it.
        if self.teacher is not None:
            self.teacher.eval()
            for param in self.teacher.parameters():
                param.requires_grad = False

        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()
        self.scaler = GradScaler() if self.config.use_amp else None

        self.global_step = 0
        Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)

    def _create_optimizer(self) -> torch.optim.Optimizer:
        params = list(self.student.parameters())
        if self.projection_layer is not None:
            params += list(self.projection_layer.parameters())

        return torch.optim.AdamW(
            params,
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

    def _create_scheduler(self):
        total_steps = self._get_total_steps()

        def lr_lambda(step):
            # Linear warmup followed by cosine decay, floored at
            # min_learning_rate / learning_rate.
            if step < self.config.warmup_steps:
                return step / max(1, self.config.warmup_steps)
            progress = (step - self.config.warmup_steps) / max(
                1, total_steps - self.config.warmup_steps
            )
            return max(
                self.config.min_learning_rate / self.config.learning_rate,
                0.5 * (1.0 + math.cos(math.pi * progress)),
            )

        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
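
    # Shape of the schedule above (illustrative, with the defaults
    # warmup_steps=500 and min_learning_rate / learning_rate = 0.1):
    # the multiplier ramps linearly from 0 to 1 over the first 500 steps
    # (e.g. lr_lambda(250) == 0.5), then follows a half cosine from 1.0
    # down to the 0.1 floor at total_steps.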

    def _get_total_steps(self) -> int:
        if self.config.max_steps:
            return self.config.max_steps
        steps_per_epoch = max(
            1, len(self.train_dataloader) // self.config.gradient_accumulation_steps
        )
        return steps_per_epoch * self.config.num_epochs

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        student_hidden: Optional[torch.Tensor] = None,
        teacher_hidden: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the combined distillation loss.

        Args:
            student_logits: Student model output logits [B, T, V]
            teacher_logits: Teacher model output logits [B, T, V]
            labels: Ground-truth labels [B, T]
            student_hidden: Student hidden states (optional)
            teacher_hidden: Teacher hidden states (optional)

        Returns:
            Dictionary with the individual loss components and the total loss.
        """
        T = self.config.temperature

        # Soft-label KD loss: KL(teacher || student) at temperature T.
        # Flattening to [B*T, V] makes "batchmean" average over token
        # positions; on a 3-D tensor it would divide by the batch size only,
        # scaling the KD term by the sequence length. The T**2 factor keeps
        # gradient magnitudes comparable across temperatures (Hinton et al.,
        # 2015).
        student_log_probs = F.log_softmax(student_logits / T, dim=-1)
        teacher_probs = F.softmax(teacher_logits / T, dim=-1)
        kd_loss = F.kl_div(
            student_log_probs.view(-1, student_log_probs.size(-1)),
            teacher_probs.view(-1, teacher_probs.size(-1)),
            reduction="batchmean",
        ) * (T ** 2)

        # Hard-label CE loss on next-token prediction (shift by one).
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        ce_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Optional hidden-state matching: project the student's hidden
        # states into the teacher's width and penalize the MSE.
        hidden_loss = torch.tensor(0.0, device=self.device)
        if (
            student_hidden is not None
            and teacher_hidden is not None
            and self.projection_layer is not None
        ):
            projected_student = self.projection_layer(student_hidden)
            hidden_loss = F.mse_loss(projected_student, teacher_hidden)

        total_loss = (
            self.config.alpha_ce * ce_loss
            + self.config.alpha_kd * kd_loss
            + self.config.alpha_hidden * hidden_loss
        )

        return {
            "total_loss": total_loss,
            "ce_loss": ce_loss,
            "kd_loss": kd_loss,
            "hidden_loss": hidden_loss,
        }
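
    # Worked example (illustrative): with the default weights
    # alpha_ce = alpha_kd = 0.5 and alpha_hidden = 0.0, a batch with
    # ce_loss = 2.0 and kd_loss = 1.0 gives
    # total_loss = 0.5 * 2.0 + 0.5 * 1.0 = 1.5.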

    def train(self) -> Dict[str, float]:
        """Main distillation training loop."""
        self.student.train()
        total_steps = self._get_total_steps()

        print(f"Starting knowledge distillation for {total_steps} steps")
        print(f" Temperature: {self.config.temperature}")
        print(f" Alpha CE: {self.config.alpha_ce}, Alpha KD: {self.config.alpha_kd}")

        running_losses = {"total": 0.0, "ce": 0.0, "kd": 0.0}

        for epoch in range(self.config.num_epochs):
            for step, batch in enumerate(self.train_dataloader):
                losses = self._training_step(batch)

                for key in running_losses:
                    value = losses.get(f"{key}_loss")
                    if isinstance(value, torch.Tensor):
                        running_losses[key] += value.item()

                if (step + 1) % self.config.gradient_accumulation_steps == 0:
                    self._optimizer_step()
                    self.global_step += 1

                    if self.global_step % self.config.log_steps == 0:
                        # Losses accumulate once per micro-batch, so average
                        # over micro-batches, not optimizer steps.
                        denom = self.config.log_steps * self.config.gradient_accumulation_steps
                        avg_losses = {k: v / denom for k, v in running_losses.items()}
                        print(
                            f"Step {self.global_step}/{total_steps} | "
                            f"Total: {avg_losses['total']:.4f} | "
                            f"CE: {avg_losses['ce']:.4f} | "
                            f"KD: {avg_losses['kd']:.4f}"
                        )
                        running_losses = {k: 0.0 for k in running_losses}

                    if self.global_step % self.config.save_steps == 0:
                        self._save_checkpoint()

                if self.config.max_steps and self.global_step >= self.config.max_steps:
                    break

            if self.config.max_steps and self.global_step >= self.config.max_steps:
                break

        self._save_checkpoint(final=True)
        return {"final_step": self.global_step}

    def _training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Single distillation training step.

        Expects `batch` to contain "input_ids" and "labels", optionally an
        "attention_mask", and, for offline distillation, pre-computed dense
        "teacher_logits".
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        labels = batch["labels"].to(self.device)

        # Teacher logits: pre-computed (offline) or generated live (online).
        teacher_logits = batch.get("teacher_logits")
        if teacher_logits is not None:
            teacher_logits = teacher_logits.to(self.device)
        elif self.teacher is not None:
            with torch.no_grad():
                _, teacher_logits, _, _ = self.teacher(input_ids, attention_mask)
        else:
            raise ValueError(
                "Distillation requires either a teacher model or "
                "pre-computed teacher_logits in the batch."
            )

        if self.config.use_amp:
            with autocast(dtype=torch.float16):
                _, student_logits, _, _ = self.student(input_ids, attention_mask)
                losses = self.distillation_loss(student_logits, teacher_logits, labels)
            loss = losses["total_loss"] / self.config.gradient_accumulation_steps
            self.scaler.scale(loss).backward()
        else:
            _, student_logits, _, _ = self.student(input_ids, attention_mask)
            losses = self.distillation_loss(student_logits, teacher_logits, labels)
            loss = losses["total_loss"] / self.config.gradient_accumulation_steps
            loss.backward()

        return losses

    def _optimizer_step(self):
        if self.config.use_amp:
            # Unscale first so clipping operates on true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)

        # Clip every parameter the optimizer updates, including the
        # projection layer when hidden-state matching is enabled.
        params = [p for group in self.optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(params, self.config.grad_clip)

        if self.config.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()

    def _save_checkpoint(self, final: bool = False):
        name = "final" if final else f"step_{self.global_step}"
        path = Path(self.config.output_dir) / name
        path.mkdir(parents=True, exist_ok=True)

        torch.save(self.student.state_dict(), path / "student_model.pt")
        if self.projection_layer is not None:
            torch.save(self.projection_layer.state_dict(), path / "projection.pt")

        print(f"Checkpoint saved to {path}")
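

# Typical usage (a sketch: `student`, `teacher`, and `loader` are
# placeholders the caller provides, and both models are assumed to return
# logits as the second element of a 4-tuple, matching the forward calls
# above):
#
#   trainer = DistillationTrainer(
#       student_model=student,
#       teacher_model=teacher,
#       train_dataloader=loader,
#       config=DistillationConfig(temperature=2.0, alpha_ce=0.5, alpha_kd=0.5),
#   )
#   trainer.train()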


def generate_teacher_logits(
    teacher_model: nn.Module,
    dataloader: DataLoader,
    output_path: str,
    device: str = "cuda",
    top_k: int = 100,
):
    """
    Pre-generate teacher logits for offline distillation.

    Saves storage by keeping only the top-k logits per position.
    """
    teacher_model.eval()
    teacher_model.to(device)

    all_logits = []

    print(f"Generating teacher logits for {len(dataloader)} batches...")

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            _, logits, _, _ = teacher_model(input_ids, attention_mask)

            # Keep only the k largest logits per position; the rest carry
            # little probability mass after softmax.
            if 0 < top_k < logits.shape[-1]:
                topk_values, topk_indices = torch.topk(logits, k=top_k, dim=-1)
                all_logits.append({
                    "values": topk_values.cpu(),
                    "indices": topk_indices.cpu(),
                })
            else:
                all_logits.append(logits.cpu())

    torch.save(all_logits, output_path)
    print(f"Teacher logits saved to {output_path}")
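

def _densify_topk_logits(
    sparse: Dict[str, torch.Tensor],
    vocab_size: int,
    fill_value: float = -1e4,
) -> torch.Tensor:
    """Illustrative counterpart to the top-k storage above, not an
    established part of this module's API: scatter saved top-k values back
    into a dense [B, T, V] tensor so they can be fed to
    DistillationTrainer.distillation_loss. Pruned positions get a large
    negative logit and thus near-zero probability after softmax; the fill
    value is an assumption, not a tuned constant.
    """
    values, indices = sparse["values"], sparse["indices"]
    dense = torch.full(
        (*indices.shape[:-1], vocab_size), fill_value, dtype=values.dtype
    )
    dense.scatter_(-1, indices, values)
    return dense


if __name__ == "__main__":
    # Minimal smoke test (illustrative): run the combined loss on random
    # tensors, with a tiny stand-in "student" module so the trainer can
    # build its optimizer. Shapes and sizes here are arbitrary assumptions.
    torch.manual_seed(0)
    student = nn.Linear(8, 32)
    trainer = DistillationTrainer(
        student_model=student,
        config=DistillationConfig(max_steps=10, use_amp=False),
    )
    B, T, V = 2, 5, 32
    losses = trainer.distillation_loss(
        student_logits=torch.randn(B, T, V),
        teacher_logits=torch.randn(B, T, V),
        labels=torch.randint(0, V, (B, T)),
    )
    print({k: round(v.item(), 4) for k, v in losses.items()})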