""" src/continual/ewc.py — Elastic Weight Consolidation (EWC) Implements EWC penalty term to prevent catastrophic forgetting when fine-tuning on new generator classes. EWC loss: L_total = L_task_new + λ * Σ_i F_i * (θ_i - θ*_i)² where: F_i = diagonal Fisher information for parameter i on the *old* task θ*_i = parameter i's value immediately after old-task training λ = importance weight (default 400, tunable in configs/) Reference: Kirkpatrick et al. "Overcoming catastrophic forgetting in neural networks." PNAS 2017. https://arxiv.org/abs/1612.00796 Usage: ewc = EWC(model, dataloader_old_task) loss = ewc.penalty(model) # add to task-specific loss before backward() """ from __future__ import annotations import copy from typing import Iterator, Optional import torch import torch.nn as nn from torch import Tensor from torch.utils.data import DataLoader class EWC: """ Elastic Weight Consolidation penalty tracker. Parameters ---------- model: PyTorch module whose important weights we want to preserve. dataloader: DataLoader over the *previous* task's training data (used to estimate the diagonal Fisher information). criterion: Loss function used to compute per-sample gradients. Defaults to CrossEntropyLoss. num_samples: Maximum number of samples used to estimate Fisher. Fewer samples = faster, noisier estimate. importance: λ — EWC penalty weight. Higher → less forgetting, slower adaptation. """ def __init__( self, model: nn.Module, dataloader: DataLoader, criterion: Optional[nn.Module] = None, num_samples: int = 200, importance: float = 400.0, ) -> None: self.model = model self.importance = importance self._criterion = criterion or nn.CrossEntropyLoss() # Snapshot θ* (old-task parameters) self._params_star: dict[str, Tensor] = { name: param.clone().detach() for name, param in model.named_parameters() if param.requires_grad } # Estimate diagonal Fisher self._fisher: dict[str, Tensor] = self._estimate_fisher( model, dataloader, num_samples ) # ------------------------------------------------------------------ # Public # ------------------------------------------------------------------ def penalty(self, model: nn.Module) -> Tensor: """ Compute the EWC penalty for the *current* model parameters. Returns a scalar tensor that can be added to the task loss. """ device = next(model.parameters()).device loss = torch.tensor(0.0, device=device) for name, param in model.named_parameters(): if not param.requires_grad: continue if name not in self._fisher: continue fisher = self._fisher[name].to(device) star = self._params_star[name].to(device) loss += (fisher * (param - star).pow(2)).sum() return self.importance * loss # ------------------------------------------------------------------ # Private # ------------------------------------------------------------------ def _estimate_fisher( self, model: nn.Module, dataloader: DataLoader, num_samples: int, ) -> dict[str, Tensor]: """Estimate diagonal Fisher via squared gradients on old-task data.""" fisher: dict[str, Tensor] = { name: torch.zeros_like(param) for name, param in model.named_parameters() if param.requires_grad } model.eval() n_seen = 0 for batch in dataloader: if n_seen >= num_samples: break images, labels = batch device = next(model.parameters()).device images, labels = images.to(device), labels.to(device) model.zero_grad() output = model(images) loss = self._criterion(output, labels) loss.backward() for name, param in model.named_parameters(): if param.requires_grad and param.grad is not None: fisher[name] += param.grad.detach().pow(2) n_seen += images.size(0) # Normalize if n_seen > 0: for name in fisher: fisher[name] /= n_seen return fisher