| """EMA target encoder with cosine-annealed tau schedule (0.996 -> 0.9999 over first 30%). |
| |
| Per Weimann & Conrad (T1-1) and I-JEPA (T1-2). |
| """ |
| from __future__ import annotations |
|
|
| import copy |
| import math |
|
|
| import torch |
| from torch import nn |
|
|
|
|
def ema_tau(step: int, total_steps: int, start: float = 0.996, end: float = 0.9999,
            warmup_frac: float = 0.30) -> float:
    """Cosine ramp of the EMA coefficient: ``start`` at step 0, ``end`` once
    ``warmup_frac`` of ``total_steps`` has elapsed, constant afterwards."""
    warmup = max(1, int(total_steps * warmup_frac))
    if step >= warmup:
        return end
    # Half-cosine interpolation: (1 + cos(pi * t)) / 2 falls from 1 to 0 as t
    # goes 0 -> 1, so the result rises monotonically from `start` to `end`.
    t = step / warmup
    return end - 0.5 * (end - start) * (1 + math.cos(math.pi * t))


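# Endpoint check (hypothetical step counts, for illustration only): with the
# defaults and total_steps = 100_000, warmup ends at step 30_000, so
#   ema_tau(0, 100_000)      -> 0.996     (ramp start)
#   ema_tau(15_000, 100_000) -> ~0.99795  (cosine midpoint: end - 0.5 * (end - start))
#   ema_tau(30_000, 100_000) -> 0.9999    (ramp complete; held thereafter)

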
class EMA(nn.Module):
    """Holds a detached target copy of an online encoder, updated in-place by EMA."""

    def __init__(self, online: nn.Module):
        super().__init__()
        # The target starts as an exact deep copy and never receives gradients.
        self.target = copy.deepcopy(online)
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.target.eval()

    @torch.no_grad()
    def update(self, online: nn.Module, tau: float) -> None:
        """In-place update: target <- tau * target + (1 - tau) * online."""
        for p_t, p_o in zip(self.target.parameters(), online.parameters()):
            p_t.mul_(tau).add_(p_o, alpha=1 - tau)
        # Buffers (e.g. BatchNorm running stats) are copied verbatim, not averaged.
        for b_t, b_o in zip(self.target.buffers(), online.buffers()):
            b_t.copy_(b_o)

    def forward(self, *args, **kwargs):
        """Run the target network; gradients are never tracked through it."""
        with torch.no_grad():
            return self.target(*args, **kwargs)


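# Minimal usage sketch (assumed toy encoder and hyperparameters, for
# illustration only; the real training loop lives elsewhere): the online
# network chases its own EMA target, with tau annealed by `ema_tau`.
if __name__ == "__main__":
    torch.manual_seed(0)
    online = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    ema = EMA(online)
    opt = torch.optim.SGD(online.parameters(), lr=0.1)
    total_steps = 100
    for step in range(total_steps):
        x = torch.randn(4, 8)
        target_out = ema(x)  # detached target forward
        loss = (online(x) - target_out).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Drag the target toward the online weights with the scheduled tau.
        ema.update(online, ema_tau(step, total_steps))
    print(f"final tau: {ema_tau(total_steps, total_steps):.4f}")  # -> 0.9999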