"""
src/continual/ewc.py — Elastic Weight Consolidation (EWC)

Implements EWC penalty term to prevent catastrophic forgetting when
fine-tuning on new generator classes.

EWC loss:

    L_total = L_task_new + λ * Σ_i F_i * (θ_i - θ*_i)²

where:
    F_i  = diagonal Fisher information for parameter i on the *old* task
    θ*_i = parameter i's value immediately after old-task training
    λ    = importance weight (default 400, tunable in configs/)

Reference:
    Kirkpatrick et al. "Overcoming catastrophic forgetting in neural networks."
    PNAS 2017. https://arxiv.org/abs/1612.00796

Usage:
    ewc = EWC(model, dataloader_old_task)
    loss = ewc.penalty(model)  # add to task-specific loss before backward()
"""
from __future__ import annotations

import copy
from typing import Iterator, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
class EWC:
    """
    Elastic Weight Consolidation penalty tracker.

    Snapshots the model's parameters (θ*) and estimates the diagonal
    Fisher information on the *previous* task's data, then exposes a
    quadratic penalty λ * Σ_i F_i * (θ_i - θ*_i)² that discourages the
    model from drifting away from weights important to the old task.

    Parameters
    ----------
    model:
        PyTorch module whose important weights we want to preserve.
    dataloader:
        DataLoader over the *previous* task's training data (used to
        estimate the diagonal Fisher information).
    criterion:
        Loss function used to compute gradients for the Fisher estimate.
        Defaults to CrossEntropyLoss.
    num_samples:
        Maximum number of samples used to estimate Fisher. Fewer samples
        = faster, noisier estimate.
    importance:
        λ — EWC penalty weight. Higher → less forgetting, slower adaptation.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        criterion: Optional[nn.Module] = None,
        num_samples: int = 200,
        importance: float = 400.0,
    ) -> None:
        self.model = model
        self.importance = importance
        self._criterion = criterion or nn.CrossEntropyLoss()
        # Snapshot θ* (old-task parameters) before any new-task updates.
        self._params_star: dict[str, Tensor] = {
            name: param.clone().detach()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        # Estimate diagonal Fisher information on the old task's data.
        self._fisher: dict[str, Tensor] = self._estimate_fisher(
            model, dataloader, num_samples
        )

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------
    def penalty(self, model: nn.Module) -> Tensor:
        """
        Compute the EWC penalty λ * Σ_i F_i * (θ_i - θ*_i)² for the
        *current* model parameters.

        Returns a scalar tensor on the model's device; add it to the
        task-specific loss before calling ``backward()``.
        """
        device = next(model.parameters()).device
        loss = torch.zeros((), device=device)
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Parameters added after __init__ have no Fisher entry; skip.
            if name not in self._fisher:
                continue
            fisher = self._fisher[name].to(device)
            star = self._params_star[name].to(device)
            loss = loss + (fisher * (param - star).pow(2)).sum()
        return self.importance * loss

    # ------------------------------------------------------------------
    # Private
    # ------------------------------------------------------------------
    def _estimate_fisher(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        num_samples: int,
    ) -> dict[str, Tensor]:
        """
        Estimate the diagonal Fisher via squared criterion gradients on
        old-task data.

        Honors the ``num_samples`` cap exactly (the final batch is
        truncated rather than overshooting), restores the model's
        original train/eval mode, and clears leftover gradients so a
        later ``optimizer.step()`` cannot consume stale Fisher grads.
        """
        fisher: dict[str, Tensor] = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        device = next(model.parameters()).device  # loop-invariant: hoist
        was_training = model.training  # remember caller's mode
        model.eval()
        n_seen = 0
        try:
            for images, labels in dataloader:
                if n_seen >= num_samples:
                    break
                # Truncate the final batch so at most `num_samples`
                # samples contribute, as documented.
                remaining = num_samples - n_seen
                images = images[:remaining].to(device)
                labels = labels[:remaining].to(device)
                model.zero_grad(set_to_none=True)
                loss = self._criterion(model(images), labels)
                loss.backward()
                for name, param in model.named_parameters():
                    if param.requires_grad and param.grad is not None:
                        fisher[name] += param.grad.detach().pow(2)
                n_seen += images.size(0)
        finally:
            # Leave no stale gradients behind; restore train/eval mode.
            model.zero_grad(set_to_none=True)
            model.train(was_training)
        # Normalize by samples actually seen (guard: empty dataloader).
        if n_seen > 0:
            for name in fisher:
                fisher[name] /= n_seen
        return fisher