File size: 1,329 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""EMA target encoder with a cosine-annealed tau schedule.

tau ramps from 0.996 to 0.9999 along a half-cosine over the first 30% of
``total_steps`` and is held at 0.9999 afterwards.

Per Weimann & Conrad (T1-1) and I-JEPA (T1-2).
"""
from __future__ import annotations

import copy
import math

import torch
from torch import nn


def ema_tau(step: int, total_steps: int, start: float = 0.996, end: float = 0.9999,
            warmup_frac: float = 0.30) -> float:
    """Compute the EMA momentum ``tau`` to use at ``step``.

    Over a warmup window of ``int(total_steps * warmup_frac)`` steps (never
    fewer than 1), ``tau`` follows a half-cosine ramp from ``start`` at step 0
    towards ``end``. At or after the end of the window it is fixed at ``end``.

    Args:
        step: Current step index.
        total_steps: Number of steps the schedule is defined over.
        start: Value of ``tau`` at step 0.
        end: Value of ``tau`` from the end of the warmup window onward.
        warmup_frac: Fraction of ``total_steps`` spent ramping.

    Returns:
        The momentum value for ``step``.
    """
    horizon = int(total_steps * warmup_frac)
    if horizon < 1:
        # A zero-length window would divide by zero below; use one step.
        horizon = 1
    if step >= horizon:
        return end
    progress = step / horizon
    # (1 + cos(pi * progress)) falls from 2 to 0 as progress goes 0 -> 1,
    # so the result moves from `start` to `end`.
    return end - 0.5 * (end - start) * (1 + math.cos(math.pi * progress))


class EMA(nn.Module):
    """Wraps an online encoder + a detached target copy updated in-place.

    The target is a deep copy of ``online`` taken at construction time. Its
    parameters never require gradients, and it is kept in eval mode even when
    this module (or a parent module) is switched with ``.train()``.
    """

    def __init__(self, online: nn.Module):
        super().__init__()
        self.target = copy.deepcopy(online)
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.target.train(False)

    def train(self, mode: bool = True) -> "EMA":
        """Set this module's training flag while keeping the target in eval mode.

        ``nn.Module.train`` recurses into every child. Without this override,
        a parent's ``.train()`` call would silently flip the target (and any
        dropout / batch-norm layers inside it) back into training mode.
        """
        super().train(mode)
        self.target.train(False)
        return self

    @staticmethod
    def _paired(target_tensors, online_tensors, kind: str) -> list:
        """Pair target/online tensors, raising ValueError if counts or shapes differ."""
        t_list = list(target_tensors)
        o_list = list(online_tensors)
        if len(t_list) != len(o_list):
            raise ValueError(
                f"EMA.update: target has {len(t_list)} {kind} but online has {len(o_list)}"
            )
        for i, (t, o) in enumerate(zip(t_list, o_list)):
            # In-place add_/copy_ would broadcast some mismatched shapes silently.
            if t.shape != o.shape:
                raise ValueError(
                    f"EMA.update: {kind} #{i} shape mismatch: "
                    f"target {tuple(t.shape)} vs online {tuple(o.shape)}"
                )
        return list(zip(t_list, o_list))

    @torch.no_grad()
    def update(self, online: nn.Module, tau: float) -> None:
        """Blend ``online`` into the target in place.

        Each target parameter becomes ``tau * p_t + (1 - tau) * p_o``. Buffers
        are copied from ``online`` verbatim.

        Args:
            online: Module whose parameters and buffers line up one-to-one, in
                iteration order, with the target's.
            tau: Momentum. Higher values keep more of the current target.

        Raises:
            ValueError: If the parameter or buffer counts or shapes of
                ``online`` do not match the target. Nothing is modified in
                that case.
        """
        # Validate everything up front so a mismatch cannot leave the target
        # half-updated.
        param_pairs = self._paired(self.target.parameters(), online.parameters(), "parameters")
        buffer_pairs = self._paired(self.target.buffers(), online.buffers(), "buffers")
        for p_t, p_o in param_pairs:
            p_t.data.mul_(tau).add_(p_o.data, alpha=1 - tau)
        for b_t, b_o in buffer_pairs:
            b_t.data.copy_(b_o.data)

    def forward(self, *args, **kwargs):
        """Run the target module without recording gradients."""
        with torch.no_grad():
            return self.target(*args, **kwargs)