File size: 1,329 Bytes
"""EMA target encoder with a cosine-annealed tau schedule (0.996 -> 0.9999 over the first 30% of total steps).
Per Weimann & Conrad (T1-1) and I-JEPA (T1-2).
"""
from __future__ import annotations
import copy
import math
import torch
from torch import nn
def ema_tau(step: int, total_steps: int, start: float = 0.996, end: float = 0.9999,
            warmup_frac: float = 0.30) -> float:
    """Return the EMA momentum (tau) to use at training step ``step``.

    Tau follows a half-cosine ramp from ``start`` (at step 0) up to ``end``
    across the first ``warmup_frac`` fraction of ``total_steps``; from the end
    of that ramp onwards it is held constant at ``end``.

    Args:
        step: Current step index (0-based).
        total_steps: Planned total number of steps.
        start: Tau at the beginning of the ramp.
        end: Tau once the ramp has finished.
        warmup_frac: Fraction of ``total_steps`` covered by the ramp.

    Returns:
        The tau value for ``step``.
    """
    # At least one ramp step, so the division below can never hit zero.
    ramp_steps = max(1, int(total_steps * warmup_frac))
    if step < ramp_steps:
        progress = step / ramp_steps
        # Cosine weight is 1 at progress=0 (tau == start), 0 at progress=1 (tau == end).
        cosine_weight = 0.5 * (1 + math.cos(math.pi * progress))
        return end - (end - start) * cosine_weight
    return end
class EMA(nn.Module):
    """Hold a frozen target copy of an online encoder, updated by EMA in place.

    The constructor deep-copies ``online`` into ``self.target``, disables
    gradients on every target parameter and puts the target in eval mode.
    The online network itself is not stored; it is passed to :meth:`update`
    each time the target should be refreshed.
    """

    def __init__(self, online: nn.Module):
        """Create the target as a gradient-free, eval-mode deep copy of ``online``."""
        super().__init__()
        self.target = copy.deepcopy(online)
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.target.train(False)

    def train(self, mode: bool = True) -> "EMA":
        """Set this wrapper's training flag while keeping the target in eval mode.

        ``nn.Module.train`` recurses into registered submodules, so without
        this override a call to ``.train()`` on this wrapper (or on any parent
        module containing it) would undo the eval mode set in ``__init__`` and
        re-enable train-time behaviour such as dropout inside the target.

        Args:
            mode: Training flag applied to the wrapper itself.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.target.train(False)
        return self

    @torch.no_grad()
    def update(self, online: nn.Module, tau: float) -> None:
        """Blend the online weights into the target in place.

        Each target parameter becomes ``tau * target + (1 - tau) * online``;
        buffers are copied from the online network verbatim. Parameters and
        buffers are paired positionally in iteration order, so ``online`` is
        expected to have the same structure as the module given to
        ``__init__`` (``zip`` silently stops at the shorter sequence).

        Args:
            online: The network whose current weights are blended in.
            tau: EMA momentum; the weight kept from the existing target.
        """
        for p_t, p_o in zip(self.target.parameters(), online.parameters()):
            p_t.data.mul_(tau).add_(p_o.data, alpha=1 - tau)
        # Buffers are copied rather than averaged.
        for b_t, b_o in zip(self.target.buffers(), online.buffers()):
            b_t.data.copy_(b_o.data)

    def forward(self, *args, **kwargs):
        """Run the target network on the given inputs with gradients disabled."""
        with torch.no_grad():
            return self.target(*args, **kwargs)
|