# Provenance (converted from scrape residue): author akagtag, initial commit 4e75170.
"""
src/continual/ewc.py — Elastic Weight Consolidation (EWC)
Implements EWC penalty term to prevent catastrophic forgetting when
fine-tuning on new generator classes.
EWC loss:
L_total = L_task_new + λ * Σ_i F_i * (θ_i - θ*_i)²
where:
F_i = diagonal Fisher information for parameter i on the *old* task
θ*_i = parameter i's value immediately after old-task training
λ = importance weight (default 400, tunable in configs/)
Reference:
Kirkpatrick et al. "Overcoming catastrophic forgetting in neural networks."
PNAS 2017. https://arxiv.org/abs/1612.00796
Usage:
ewc = EWC(model, dataloader_old_task)
loss = ewc.penalty(model) # add to task-specific loss before backward()
"""
from __future__ import annotations
import copy
from typing import Iterator, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
class EWC:
    """
    Elastic Weight Consolidation penalty tracker.

    At construction time this snapshots the model's parameters (θ*) and
    estimates the diagonal Fisher information on the previous task's data.
    ``penalty`` then measures how far the current parameters have drifted
    from θ*, weighted per-parameter by the Fisher estimate.

    Parameters
    ----------
    model:
        PyTorch module whose important weights we want to preserve.
    dataloader:
        DataLoader over the *previous* task's training data (used to
        estimate the diagonal Fisher information).
    criterion:
        Loss function used to compute gradients.
        Defaults to CrossEntropyLoss.
    num_samples:
        Maximum number of samples used to estimate Fisher. Fewer samples
        = faster, noisier estimate.
    importance:
        λ — EWC penalty weight. Higher → less forgetting, slower adaptation.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        criterion: Optional[nn.Module] = None,
        num_samples: int = 200,
        importance: float = 400.0,
    ) -> None:
        self.model = model
        self.importance = importance
        self._criterion = criterion or nn.CrossEntropyLoss()
        # Snapshot θ* (old-task parameters) — detached clones so later
        # fine-tuning cannot mutate the snapshot in place.
        self._params_star: dict[str, Tensor] = {
            name: param.clone().detach()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        # Estimate diagonal Fisher on the old task's data.
        self._fisher: dict[str, Tensor] = self._estimate_fisher(
            model, dataloader, num_samples
        )

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------
    def penalty(self, model: nn.Module) -> Tensor:
        """
        Compute the EWC penalty λ * Σ_i F_i * (θ_i - θ*_i)² for the
        *current* model parameters.

        Returns a scalar tensor that can be added to the task loss
        (it participates in autograd through ``param``).
        """
        device = next(model.parameters()).device
        loss = torch.zeros((), device=device)
        for name, param in model.named_parameters():
            # Skip frozen params and params that did not exist at snapshot
            # time (e.g. a newly added classification head).
            if not param.requires_grad or name not in self._fisher:
                continue
            fisher = self._fisher[name].to(device)
            star = self._params_star[name].to(device)
            # Out-of-place add keeps the autograd graph tidy.
            loss = loss + (fisher * (param - star).pow(2)).sum()
        return self.importance * loss

    # ------------------------------------------------------------------
    # Private
    # ------------------------------------------------------------------
    def _estimate_fisher(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        num_samples: int,
    ) -> dict[str, Tensor]:
        """
        Estimate the diagonal (empirical) Fisher as the average squared
        gradient of the old-task loss.

        Gradients are taken per *batch*, so the correct normalizer is the
        number of gradient evaluations (batches). The previous version
        divided the accumulated squared batch-gradients by the sample
        count, which shrank the estimate by roughly the batch size.
        """
        fisher: dict[str, Tensor] = {
            name: torch.zeros_like(param)
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        was_training = model.training
        model.eval()  # freeze dropout / batch-norm statistics for the estimate
        device = next(model.parameters()).device  # hoisted out of the loop
        n_seen = 0
        n_batches = 0
        for images, labels in dataloader:
            if n_seen >= num_samples:
                break
            images, labels = images.to(device), labels.to(device)
            model.zero_grad()
            loss = self._criterion(model(images), labels)
            loss.backward()
            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    fisher[name] += param.grad.detach().pow(2)
            n_seen += images.size(0)
            n_batches += 1
        # Average squared gradients over the number of gradient evaluations.
        if n_batches > 0:
            for name in fisher:
                fisher[name] /= n_batches
        # Leave the model as we found it: clear the grads we created and
        # restore the original train/eval mode.
        model.zero_grad()
        if was_training:
            model.train()
        return fisher