# 3d_model/ylff/utils/ema.py
# Author: Azan
# Clean deployment build (squashed); commit 7a87926
"""
Exponential Moving Average (EMA) for model weights.
EMA provides smoother training dynamics and often better final performance
by maintaining a moving average of model parameters.
"""
import logging
from typing import Dict
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
class EMA:
    """
    Exponential Moving Average of model parameters.

    Maintains a "shadow" copy of all trainable weights, updated after each
    training step as::

        shadow = decay * shadow + (1 - decay) * param

    The shadow weights typically give smoother training dynamics and better
    final performance. For evaluation, swap them into the model with
    ``apply_shadow()`` and put the live weights back with ``restore()``.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999, device: str = "cuda"):
        """
        Args:
            model: Model to create EMA for.
            decay: EMA decay factor in [0, 1) (higher = slower update, more stable).
            device: Device to store shadow weights on.

        Raises:
            ValueError: If ``decay`` is not in [0, 1).
        """
        if not 0.0 <= decay < 1.0:
            raise ValueError(f"decay must be in [0, 1), got {decay}")
        self.model = model
        self.decay = decay
        self.device = device
        self.shadow: Dict[str, torch.Tensor] = {}  # name -> EMA weight
        self.backup: Dict[str, torch.Tensor] = {}  # live weights saved by apply_shadow()
        self.register()

    def register(self) -> None:
        """Snapshot all trainable parameters as the initial shadow weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # detach() keeps the shadow copy out of the autograd graph.
                self.shadow[name] = param.detach().clone().to(self.device)
        logger.debug("Registered %d parameters for EMA", len(self.shadow))

    @torch.no_grad()
    def update(self) -> None:
        """Update shadow weights in place with the exponential moving average."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                # In-place mul_/add_ is mathematically identical to
                # decay*shadow + (1-decay)*param but avoids allocating a new
                # tensor (plus a redundant clone) on every training step.
                self.shadow[name].mul_(self.decay).add_(
                    param.detach().to(self.device), alpha=1.0 - self.decay
                )

    def apply_shadow(self) -> None:
        """Copy shadow weights into the model (for evaluation).

        Saves the live weights so ``restore()`` can undo this. Calling this
        twice without an intervening ``restore()`` overwrites the backup.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.detach().clone()
                param.data.copy_(self.shadow[name].to(param.device))

    def restore(self) -> None:
        """Restore the live model weights saved by ``apply_shadow()``."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self) -> Dict:
        """Return EMA state (shadow weights on CPU + decay) for checkpointing."""
        return {
            "shadow": {k: v.cpu() for k, v in self.shadow.items()},
            "decay": self.decay,
        }

    def load_state_dict(self, state_dict: Dict):
        """Load EMA state from a checkpoint produced by ``state_dict()``.

        Only parameters already registered on this instance are restored;
        extra entries in the checkpoint are ignored.
        """
        self.decay = state_dict.get("decay", self.decay)
        shadow = state_dict.get("shadow", {})
        for name in self.shadow:
            if name in shadow:
                self.shadow[name] = shadow[name].to(self.device)
        logger.info("Loaded EMA state from checkpoint")