| """Fused cross-entropy: streams over the VOCAB dimension (online-softmax) so the |
| [N x V] logit matrix is NEVER materialized -- only [N x vchunk]. Custom backward |
| recomputes softmax per vocab-chunk (grad = softmax - onehot). This is the |
| DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to |
| the output head instead of network depth.""" |
| import torch |
|
|
| class FusedCE(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, h, W, tgt, vchunk=16384): |
| with torch.cuda.amp.autocast(enabled=False): |
| hf = h.float() |
| Wf = W.float() |
| N, d = h.shape |
| V = W.shape[0] |
| m = torch.full((N,), -1e30, device=h.device, dtype=torch.float32) |
| s = torch.zeros(N, device=h.device, dtype=torch.float32) |
| zt = torch.zeros(N, device=h.device, dtype=torch.float32) |
| for c in range(0, V, vchunk): |
| lg = hf @ Wf[c:c+vchunk].T |
| cm = lg.max(1).values |
| nm = torch.maximum(m, cm) |
| s = s * torch.exp(m - nm) + torch.exp(lg - nm[:, None]).sum(1) |
| m = nm |
| ic = (tgt >= c) & (tgt < c+vchunk) |
| if ic.any(): |
| zt[ic] = lg[ic, tgt[ic] - c].float() |
| lse = m + torch.log(s) |
| ctx.save_for_backward(h, W, tgt, lse) |
| ctx.vchunk = vchunk |
| return (lse - zt).mean() |
|
|
| @staticmethod |
| def backward(ctx, go): |
| h, W, tgt, lse = ctx.saved_tensors |
| vc = ctx.vchunk |
| N, d = h.shape |
| V = W.shape[0] |
| with torch.cuda.amp.autocast(enabled=False): |
| hf = h.float() |
| Wc_all = W.float() |
| gh = torch.zeros_like(hf) |
| gW = torch.zeros(W.shape, device=W.device, dtype=torch.float32) |
| sc = float(go) / N |
| for c in range(0, V, vc): |
| Wc = Wc_all[c:c+vc] |
| p = torch.exp(hf @ Wc.T - lse[:, None]) |
| ic = (tgt >= c) & (tgt < c+vc) |
| if ic.any(): |
| p[ic, tgt[ic] - c] -= 1.0 |
| p *= sc |
| gh += p @ Wc |
| gW[c:c+vc] += p.T @ hf |
| return gh.to(h.dtype), gW.to(W.dtype), None, None |
|
|
| def fused_ce(h, W, tgt, vchunk=16384): |
| return FusedCE.apply(h.reshape(-1, h.size(-1)), W, tgt.reshape(-1), vchunk) |
|
|