| """ |
| utils.grokfast — accelerated grokking by amplifying slow-varying gradient |
| components (Lee et al. 2024, arXiv:2405.20233). |
| |
| Maintain an exponential moving average (EMA) of each parameter's gradient |
| across steps. The paper hypothesises that this slow-varying component |
| drives generalisation; adding it back into the live gradient (scaled by |
| `lamb`) was reported to accelerate the grokking transition by over 50×. |
| """ |
|
|
| from __future__ import annotations |
|
|
|
|
def gradfilter_ema(
    model,
    grads_ema: dict | None = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> dict:
    """
    Amplify the slow (EMA-filtered) component of each parameter's gradient.

    Call this AFTER `loss.backward()` and BEFORE `optimizer.step()`.
    For every parameter with `requires_grad` that received a gradient:

        ema  <- alpha * ema + (1 - alpha) * grad   (ema <- grad on first sight)
        grad <- grad + lamb * ema

    The EMA is updated from the raw gradient, before amplification.
    Parameters whose `.grad` is None this step are skipped, and any existing
    EMA entry for them is left untouched. Both the gradients and the stored
    EMA tensors are updated in place.

    Args:
        model: the network whose gradients are filtered; anything exposing
            `named_parameters()` (presumably a `torch.nn.Module`).
        grads_ema: dict {param_name: ema_grad} returned by the previous call,
            or None (the default) on the first call.
        alpha: EMA decay; values close to 1 (e.g. 0.98) average over many
            steps and emphasise persistent gradient directions.
        lamb: amplification factor for the slow component.

    Returns:
        The updated `grads_ema` dict (the same object that was passed in,
        when one was given) — pass it back in on the next step.
    """
    import torch  # local import: the module has no top-level torch import

    if grads_ema is None:
        grads_ema = {}

    # no_grad + in-place ops replace the discouraged `.data` idiom. They avoid
    # allocating two temporary tensors per parameter on every step, and they
    # keep `p.grad` on its existing storage rather than rebinding it.
    with torch.no_grad():
        for name, p in model.named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            grad = p.grad
            ema = grads_ema.get(name)
            if ema is None:
                # Clone so the EMA never aliases the gradient amplified below.
                ema = grads_ema[name] = grad.detach().clone()
            else:
                ema.mul_(alpha).add_(grad, alpha=1 - alpha)
            grad.add_(ema, alpha=lamb)

    return grads_ema
|
|