"""
Exponential Moving Average (EMA) for model weights.

EMA provides smoother training dynamics and often better final performance
by maintaining a moving average of model parameters.
"""
|
|
| import logging |
| from typing import Dict |
| import torch |
| import torch.nn as nn |
|
|
# Module-level logger, named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
class EMA:
    """
    Exponential Moving Average for model parameters.

    Maintains a shadow copy of model weights that is updated with an
    exponential moving average after each training step.

    Typical usage::

        ema = EMA(model)
        # ... optimizer.step() ...
        ema.update()
        ema.apply_shadow()   # evaluate with averaged weights
        ema.restore()        # back to raw weights for training
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999, device: str = "cuda"):
        """
        Args:
            model: Model to create EMA for.
            decay: EMA decay factor in [0, 1] (higher = slower update, more stable).
            device: Device to store shadow weights on.

        Raises:
            ValueError: If ``decay`` is outside [0, 1] — any other value makes
                the moving average diverge or oscillate silently.
        """
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay}")
        self.model = model
        self.decay = decay
        self.device = device
        # name -> averaged weight tensor, kept on self.device
        self.shadow: Dict[str, torch.Tensor] = {}
        # name -> original weight tensor, populated only between
        # apply_shadow() and restore()
        self.backup: Dict[str, torch.Tensor] = {}
        self.register()

    def register(self):
        """Register all trainable parameters for EMA.

        Snapshots the current values of every ``requires_grad`` parameter
        as the initial shadow weights. Calling this again re-seeds the
        average from the model's current weights.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone().to(self.device)
        logger.debug(f"Registered {len(self.shadow)} parameters for EMA")

    @torch.no_grad()
    def update(self):
        """Update shadow weights: shadow = decay * shadow + (1 - decay) * param.

        Performed in place via ``lerp_``, which avoids the two temporary
        tensors per parameter (weighted sum + clone) that the naive
        out-of-place formulation allocates on every training step.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                # lerp_(end, w) computes self + w * (end - self)
                #   == decay * shadow + (1 - decay) * param  for w = 1 - decay
                self.shadow[name].lerp_(param.data.to(self.device), 1.0 - self.decay)

    def apply_shadow(self):
        """Apply shadow weights to model (for evaluation).

        Backs up the current weights so :meth:`restore` can undo the swap.

        Raises:
            RuntimeError: If called again before :meth:`restore` — a second
                call would overwrite the backup with shadow weights and
                permanently lose the training weights.
        """
        if self.backup:
            raise RuntimeError(
                "apply_shadow() called twice without restore(); "
                "this would overwrite the backed-up training weights"
            )
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name].to(param.device))

    def restore(self):
        """Restore original model weights saved by :meth:`apply_shadow`.

        A no-op if :meth:`apply_shadow` was not called first.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self) -> Dict:
        """Get EMA state for checkpointing (shadow tensors moved to CPU)."""
        return {
            "shadow": {k: v.cpu() for k, v in self.shadow.items()},
            "decay": self.decay,
        }

    def load_state_dict(self, state_dict: Dict):
        """Load EMA state from checkpoint.

        Only parameters already registered on this instance are loaded;
        extra keys in the checkpoint are ignored, missing keys keep their
        freshly-registered values.
        """
        self.decay = state_dict.get("decay", self.decay)
        shadow = state_dict.get("shadow", {})
        for name in self.shadow:
            if name in shadow:
                self.shadow[name] = shadow[name].to(self.device)
        logger.info("Loaded EMA state from checkpoint")
|
|