# Provenance: uploaded via huggingface_hub (commit 31e2456, verified).
"""EMA target encoder with cosine-annealed tau schedule (0.996 -> 0.9999 over first 30%).
Per Weimann & Conrad (T1-1) and I-JEPA (T1-2).
"""
from __future__ import annotations
import copy
import math
import torch
from torch import nn
def ema_tau(step: int, total_steps: int, start: float = 0.996, end: float = 0.9999,
            warmup_frac: float = 0.30) -> float:
    """Momentum (tau) for the EMA target encoder at a given optimizer step.

    Follows a half-cosine ramp from ``start`` up to ``end`` over the first
    ``warmup_frac`` of ``total_steps``; once the ramp is complete, tau is
    held constant at ``end`` for the rest of training.
    """
    ramp_steps = max(1, int(total_steps * warmup_frac))
    if step < ramp_steps:
        progress = step / ramp_steps
        # cosine goes 1 -> 0 as progress goes 0 -> 1, so tau goes start -> end.
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return end - (end - start) * cosine
    return end
class EMA(nn.Module):
    """Wraps an online encoder + a detached target copy updated in-place.

    The target starts as a deep copy of the online network, is frozen
    (gradients disabled, eval mode), and is pulled toward the online
    weights by :meth:`update` via an exponential moving average.
    """

    def __init__(self, online: nn.Module):
        super().__init__()
        # Deep-copy so the target owns its own parameter storage.
        self.target = copy.deepcopy(online)
        for p in self.target.parameters():
            p.requires_grad_(False)
        # Eval mode: freezes dropout / batch-norm behavior in the target.
        self.target.train(False)

    @torch.no_grad()
    def update(self, online: nn.Module, tau: float) -> None:
        """EMA step: target = tau * target + (1 - tau) * online.

        Buffers (e.g. batch-norm running stats) are copied verbatim
        rather than averaged. Direct in-place ops are safe here because
        the whole method runs under ``no_grad`` — the deprecated ``.data``
        escape hatch the original used is unnecessary and discouraged.
        """
        for p_t, p_o in zip(self.target.parameters(), online.parameters()):
            p_t.mul_(tau).add_(p_o, alpha=1 - tau)
        for b_t, b_o in zip(self.target.buffers(), online.buffers()):
            b_t.copy_(b_o)

    def forward(self, *args, **kwargs):
        """Run the frozen target encoder without building an autograd graph."""
        with torch.no_grad():
            return self.target(*args, **kwargs)