# CausalGrok / code/utils/grokfast.py
# Uploaded by nileshsarkar-ai — "Upload code/utils" (commit 9d2fc01, verified).
"""
utils.grokfast — accelerated grokking by amplifying slow-varying gradient
components (Lee et al. 2024, arXiv:2405.20233).

Maintain an exponential moving average (EMA) of the gradients across steps.
The EMA acts as a low-pass filter that keeps the slow-varying gradient
component, which the paper associates with generalisation. Adding it back
into the live gradient (scaled by `lamb`) speeds up the grokking
transition; the paper reports an acceleration of more than 50×.
"""
from __future__ import annotations

import torch
def gradfilter_ema(
    model: torch.nn.Module,
    grads_ema: dict[str, torch.Tensor] | None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> dict[str, torch.Tensor]:
    """
    Apply the Grokfast-EMA gradient filter to ``model`` in place.

    Call this AFTER ``loss.backward()`` and BEFORE ``optimizer.step()``.
    For every trainable parameter that currently has a gradient ``g``:

        ema <- alpha * ema + (1 - alpha) * g      (ema <- g the first time)
        g   <- g + lamb * ema

    The EMA is updated from the raw gradient *before* amplification, so the
    filter state never feeds back on itself.

    Args:
        model: the network whose gradients are filtered. Its ``.grad``
            tensors are modified in place.
        grads_ema: dict {param_name: ema_grad} returned by the previous call,
            or None on the first call. A parameter missing from the dict
            (e.g. one that had no gradient until now) is initialised from a
            copy of its current gradient.
        alpha: EMA decay in [0, 1]; values close to 1 (default 0.98) keep
            only the slow, persistent part of the gradient.
        lamb: amplification factor for the slow component.

    Returns:
        The updated ``grads_ema`` dict (the same object when one was passed
        in) — pass it back in on the next step.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1].
    """
    # Outside [0, 1] the "EMA" diverges or flips sign instead of smoothing.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")
    if grads_ema is None:
        grads_ema = {}

    # Gradient surgery must not be tracked by autograd (this replaces the old
    # ``.data`` access); it is the same pattern optimizers and
    # ``clip_grad_norm_`` use when they touch ``.grad``.
    with torch.no_grad():
        for name, p in model.named_parameters():
            grad = p.grad
            if not p.requires_grad or grad is None:
                continue  # frozen or unused this step: its EMA is left as is
            if name in grads_ema:
                ema = grads_ema[name]
                # In-place update: no new EMA tensor allocated every step.
                ema.mul_(alpha).add_(grad, alpha=1.0 - alpha)
            else:
                # Private copy, so later in-place updates never alias (and
                # corrupt) the live gradient.
                ema = grads_ema[name] = grad.detach().clone()
            # In-place amplification of the live gradient.
            grad.add_(ema, alpha=lamb)
    return grads_ema