"""EMA target encoder with cosine-annealed tau schedule (0.996 -> 0.9999 over first 30%).

Per Weimann & Conrad (T1-1) and I-JEPA (T1-2).
"""
from __future__ import annotations

import copy
import math

import torch
from torch import nn


def ema_tau(step: int, total_steps: int, start: float = 0.996,
            end: float = 0.9999, warmup_frac: float = 0.30) -> float:
    """Return the EMA decay ``tau`` to use at ``step``.

    ``tau`` follows a half-cosine ramp from ``start`` (at step 0) up to
    ``end`` over the first ``warmup_frac`` of ``total_steps``, and is held at
    ``end`` from then on.

    Args:
        step: Current step, counted from 0.
        total_steps: Total number of training steps.
        start: Decay at step 0.
        end: Decay reached at the end of the ramp and held afterwards.
        warmup_frac: Fraction of ``total_steps`` spent ramping.

    Returns:
        The decay value, suitable for ``EMA.update(online, tau)``.
    """
    # Clamp to >= 1 so a zero/tiny total_steps never divides by zero below.
    warmup = max(1, int(total_steps * warmup_frac))
    if step >= warmup:
        return end
    t = step / warmup
    # 0.5 * (1 + cos(pi * t)) falls 1 -> 0 as t goes 0 -> 1,
    # so the result rises from `start` to `end`.
    return end - 0.5 * (end - start) * (1 + math.cos(math.pi * t))


class EMA(nn.Module):
    """Wraps an online encoder + a detached target copy updated in-place.

    The target is a deep copy of ``online`` taken at construction time. Its
    parameters never require grad, and it is pinned to eval mode regardless
    of ``train()`` / ``eval()`` calls on this wrapper or its parents.
    """

    def __init__(self, online: nn.Module):
        super().__init__()
        self.target = copy.deepcopy(online)
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.target.train(False)

    def train(self, mode: bool = True) -> EMA:
        """Set this wrapper's training flag while keeping the target in eval mode.

        ``nn.Module.train`` recurses into child modules, so without this
        override a parent ``model.train()`` would silently switch the target
        back into train mode (e.g. enabling dropout in it).
        ``nn.Module.eval()`` routes through this method as well.
        """
        super().train(mode)
        self.target.train(False)
        return self

    @staticmethod
    def _check_same_count(target_tensors: list, online_tensors: list,
                          kind: str) -> None:
        """Raise ValueError if the two tensor lists differ in length."""
        if len(target_tensors) != len(online_tensors):
            raise ValueError(
                f"EMA.update: target has {len(target_tensors)} {kind} but "
                f"online has {len(online_tensors)}; modules do not match"
            )

    @torch.no_grad()
    def update(self, online: nn.Module, tau: float) -> None:
        """Blend ``online`` into the target in place.

        Each target parameter becomes ``tau * target + (1 - tau) * online``.
        Buffers are copied from ``online`` verbatim rather than averaged.
        Parameters and buffers are paired by module iteration order.

        Args:
            online: Module with the same parameter/buffer layout as the target.
            tau: Decay in the blend above (e.g. a value from ``ema_tau``).

        Raises:
            ValueError: If ``online`` has a different number of parameters or
                buffers than the target. Both counts are checked before any
                tensor is modified.
        """
        t_params = list(self.target.parameters())
        o_params = list(online.parameters())
        t_bufs = list(self.target.buffers())
        o_bufs = list(online.buffers())
        # Validate up front: a plain zip would silently drop the extras.
        self._check_same_count(t_params, o_params, "parameters")
        self._check_same_count(t_bufs, o_bufs, "buffers")

        for p_t, p_o in zip(t_params, o_params):
            p_t.data.mul_(tau).add_(p_o.data, alpha=1 - tau)
        for b_t, b_o in zip(t_bufs, o_bufs):
            b_t.data.copy_(b_o.data)

    def forward(self, *args, **kwargs):
        """Run the target module under ``torch.no_grad()`` and return its output."""
        with torch.no_grad():
            return self.target(*args, **kwargs)