File size: 1,429 Bytes
9d2fc01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
utils.grokfast — accelerated grokking by amplifying slow-varying gradient
components (Lee et al. 2024, arXiv:2405.20233).

Maintain an EMA of gradients across steps; the slow-EMA component
corresponds to the generalising circuit. Adding it back into the live
gradient (scaled by `lamb`) accelerates the grokking transition 20-100×.
"""

from __future__ import annotations


def gradfilter_ema(model, grads_ema, alpha: float = 0.98, lamb: float = 2.0):
    """
    Blend an exponential moving average of past gradients into `p.grad`.

    Intended to run between `loss.backward()` and `optimizer.step()`, so
    the optimizer consumes the amplified gradients.

    For every named parameter that requires grad and currently has a
    gradient:
      * the first time a name is seen, its EMA is seeded with a copy of
        the current gradient;
      * afterwards the EMA becomes `ema * alpha + grad * (1 - alpha)`;
      * the parameter's gradient is then replaced by `grad + ema * lamb`.

    Parameters that do not require grad, or whose `.grad` is None, are
    skipped and leave the EMA state untouched.

    Args:
        model:     object exposing `named_parameters()` (the network whose
                   gradients are filtered).
        grads_ema: dict {param_name: ema_grad} from the previous step, or
                   None on the first call (a fresh dict is created).
        alpha:     EMA decay; values close to 1 make the average change
                   slowly.
        lamb:      multiplier applied to the EMA before it is added to the
                   live gradient.

    Returns:
        The EMA dict. When a dict was passed in, that same object is
        updated in place and returned; feed it back in on the next step.
    """
    grads_ema = {} if grads_ema is None else grads_ema

    for name, p in model.named_parameters():
        # Guard clause: nothing to filter for frozen / gradient-less params.
        if not p.requires_grad or p.grad is None:
            continue

        current = p.grad.data

        if name in grads_ema:
            slow = grads_ema[name] * alpha + current.detach() * (1 - alpha)
        else:
            # Seed with a clone so the EMA never aliases the live gradient.
            slow = current.detach().clone()
        grads_ema[name] = slow

        # Rebind the gradient's data to the amplified (out-of-place) result.
        p.grad.data = current + slow * lamb

    return grads_ema