import math
from typing import Any, Optional, Union

import torch
from lightning.pytorch.utilities.rank_zero import rank_zero_info
from torch.optim.swa_utils import get_ema_avg_fn

from src.callbacks.lightning_weight_averaging import WeightAveraging


class WarmupEMAWeightAveraging(WeightAveraging):
    """EMA weight averaging that only starts after a warmup portion of training.

    The EMA decay and the first update step are resolved lazily in :meth:`setup`
    from the trainer's total step count:

    * ``update_starting_at_step`` defaults to ``int(warmup_pct * total_steps)``.
    * ``decay`` defaults to ``1 - decay_numerator / active_steps`` clamped to
      ``[0.9, 0.99999]``, where ``active_steps`` is the number of steps during
      which averaging is active.
    """

    def __init__(
        self,
        warmup_pct: float,
        enabled: bool = True,
        decay: Optional[float] = None,
        decay_numerator: float = 20.0,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        use_buffers: bool = True,
    ) -> None:
        """Store the configuration; resolution happens in :meth:`setup`.

        Args:
            warmup_pct: Fraction of total training steps to skip before EMA
                updates begin (used only when ``update_starting_at_step`` is None).
            enabled: If False, the callback does nothing at all.
            decay: Explicit EMA decay; when None it is derived in ``setup``.
            decay_numerator: Numerator used when deriving the decay automatically.
            update_every_n_steps: Update the averaged weights every N steps.
            update_starting_at_step: Explicit start step; overrides ``warmup_pct``.
            device: Device on which the averaged model is kept (passed to base).
            use_buffers: Whether module buffers are averaged as well (passed to base).
        """
        super().__init__(device=device, use_buffers=use_buffers)
        self.enabled = enabled
        self.warmup_pct = warmup_pct
        self.decay = decay
        self.decay_numerator = decay_numerator
        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step
        # Resolved in setup() once the trainer's schedule length is known.
        self.resolved_decay: Optional[float] = None
        self.resolved_start_step: Optional[int] = None

    def setup(self, trainer, pl_module, stage: str) -> None:
        """Resolve decay / start step and install the EMA ``avg_fn``.

        FIX: ``trainer.estimated_stepping_batches`` can be ``float("inf")``
        (e.g. iterable datasets without ``__len__``); the previous guard only
        checked ``total_steps <= 0``, so ``int(total_steps * warmup_pct)``
        raised ``OverflowError``. Non-finite totals now use the same 100000
        fallback as non-positive ones.
        """
        if not self.enabled:
            rank_zero_info("WarmupEMAWeightAveraging is disabled by config.")
            return
        if stage != "fit":
            # Nothing to resolve outside of fitting; defer to the base class.
            super().setup(trainer, pl_module, stage)
            return

        if trainer.max_steps and trainer.max_steps > 0:
            total_steps = trainer.max_steps
        else:
            total_steps = trainer.estimated_stepping_batches
        # Fall back when the schedule length is unknown (<= 0) or unbounded (inf/nan).
        if total_steps <= 0 or not math.isfinite(total_steps):
            total_steps = 100000

        warmup_steps = int(total_steps * self.warmup_pct)
        if self.update_starting_at_step is None:
            self.resolved_start_step = warmup_steps
        else:
            self.resolved_start_step = self.update_starting_at_step

        if self.decay is None:
            # Aim for roughly `decay_numerator` effective models averaged over
            # the active window; clamp to a sane EMA range.
            active_steps = max(1, total_steps - self.resolved_start_step)
            computed_decay = 1.0 - (self.decay_numerator / active_steps)
            self.resolved_decay = min(0.99999, max(0.9, computed_decay))
        else:
            self.resolved_decay = self.decay

        # Must be in place before the base class constructs its averaged model.
        self._kwargs["avg_fn"] = get_ema_avg_fn(decay=self.resolved_decay)
        super().setup(trainer, pl_module, stage)
        rank_zero_info(
            "WarmupEMAWeightAveraging configured: "
            f"total_steps={total_steps}, warmup_steps={warmup_steps}, "
            f"start_step={self.resolved_start_step}, decay={self.resolved_decay:.8f}"
        )

    def should_update(
        self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None
    ) -> bool:
        """Return True when the averaged model should be updated at ``step_idx``.

        Epoch-based triggers (``epoch_idx``) never fire for this callback.
        Returns False until setup() has resolved the start step (e.g. when
        disabled or outside the "fit" stage).
        """
        if step_idx is None:
            return False
        if not self.enabled:
            return False
        if self.resolved_start_step is None:
            return False
        if step_idx < self.resolved_start_step:
            return False
        if self.update_every_n_steps <= 0:
            return False
        return step_idx % self.update_every_n_steps == 0

    def state_dict(self) -> dict[str, Any]:
        """Persist the resolved schedule alongside the base-class state."""
        state = super().state_dict()
        state["resolved_decay"] = self.resolved_decay
        state["resolved_start_step"] = self.resolved_start_step
        return state

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore resolved schedule values (missing keys load as None)."""
        super().load_state_dict(state_dict)
        self.resolved_decay = state_dict.get("resolved_decay")
        self.resolved_start_step = state_dict.get("resolved_start_step")