Scott/Codex
Upgrade VRAM-first DiffusionBlocks trainer
df559be
Raw
History Blame Contribute Delete
2.44 kB
"""Fused cross-entropy: streams over the VOCAB dimension (online-softmax) so the
[N x V] logit matrix is NEVER materialized -- only [N x vchunk]. Custom backward
recomputes softmax per vocab-chunk (grad = softmax - onehot). This is the
DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to
the output head instead of network depth."""
import torch
class FusedCE(torch.autograd.Function):
    """Mean cross-entropy of ``h @ W.T`` against ``tgt``, computed in vocab chunks.

    Forward streams over the rows of ``W`` ``vchunk`` at a time, keeping a running
    per-row max and a running sum of exponentials (online log-sum-exp).  Only one
    ``[N, vchunk]`` logit tile exists at a time.  Backward recomputes each tile
    from the saved inputs and the per-row log-sum-exp.  The logit gradient per
    tile is ``(softmax - onehot) * go / N``.

    Inputs to ``apply(h, W, tgt, vchunk)``:
        h:      ``[N, d]`` hidden rows (any float dtype; the math runs in fp32).
        W:      ``[V, d]`` output weight (any float dtype).
        tgt:    ``[N]`` integer target indices.  There is no ignore-index handling;
                targets are assumed to lie in ``[0, V)``.  A row whose target never
                falls in a tile contributes its full log-sum-exp.
        vchunk: number of vocab rows per tile.
    Returns a 0-dim fp32 tensor, the mean loss over the N rows.  Gradients come
    back in the dtypes of ``h`` and ``W``.
    """

    @staticmethod
    def forward(ctx, h, W, tgt, vchunk=16384):
        N, _ = h.shape  # unpacking also rejects non-2-D h
        V = W.shape[0]
        if tgt.shape != (N,):
            # Without this, a size-1 tgt would silently broadcast in torch.where.
            raise ValueError(
                f"tgt must have shape ({N},) to match h, got {tuple(tgt.shape)}"
            )
        # Keep every logit in fp32 even if the caller is running under autocast.
        with torch.autocast("cuda", enabled=False), torch.autocast("cpu", enabled=False):
            hf = h.float()
            idx = tgt.long()
            # Running per-row max of the logits seen so far.
            m = torch.full((N,), -1e30, device=h.device, dtype=torch.float32)
            # Running sum of exp(logit - m).
            s = torch.zeros(N, device=h.device, dtype=torch.float32)
            # Logit of each row's target class.
            zt = torch.zeros(N, device=h.device, dtype=torch.float32)
            for c in range(0, V, vchunk):
                # Upcast one weight slice at a time instead of a full fp32 copy of W.
                lg = hf @ W[c:c + vchunk].float().T  # [N, width] transient only
                width = lg.shape[1]
                nm = torch.maximum(m, lg.amax(1))
                # Rescale the old sum to the new max, then add this tile's terms.
                s = s * torch.exp(m - nm) + torch.exp(lg - nm[:, None]).sum(1)
                m = nm
                # Gather the target logit at a clamped column.  Keep it only for rows
                # whose target lies in this tile; this avoids a mask.any() host sync.
                local = idx - c
                hit = (local >= 0) & (local < width)
                picked = lg.gather(1, local.clamp(0, width - 1)[:, None]).squeeze(1)
                zt = torch.where(hit, picked, zt)
            lse = m + torch.log(s)
        ctx.save_for_backward(h, W, tgt, lse)
        ctx.vchunk = vchunk
        return (lse - zt).mean()

    @staticmethod
    def backward(ctx, go):
        h, W, tgt, lse = ctx.saved_tensors
        vc = ctx.vchunk
        need_h, need_W = ctx.needs_input_grad[:2]
        N = h.shape[0]
        V = W.shape[0]
        gh = gW = None
        with torch.autocast("cuda", enabled=False), torch.autocast("cpu", enabled=False):
            hf = h.float()
            idx = tgt.long()
            # d(mean)/d(row loss) = go / N, kept as a device tensor.  float(go) would
            # force a GPU->CPU sync and raise ZeroDivisionError when N == 0.
            scale = go.float() / N
            gh_acc = torch.zeros_like(hf) if need_h else None
            if need_W:
                # Vocab tiles are disjoint, so each slice of gW is written exactly
                # once.  Store it straight in W's dtype; no fp32 [V, d] buffer.
                gW = torch.empty_like(W)
            for c in range(0, V, vc):
                Wc = W[c:c + vc].float()
                p = torch.exp(hf @ Wc.T - lse[:, None])  # softmax tile [N, width]
                width = p.shape[1]
                local = idx - c
                hit = (local >= 0) & (local < width)
                # softmax - onehot: add -1 at the target column of rows whose target
                # is in this tile.  Other rows get -0.0, a no-op.  No host sync.
                p.scatter_add_(
                    1, local.clamp(0, width - 1)[:, None], -hit.to(p.dtype)[:, None]
                )
                p.mul_(scale)
                if need_h:
                    gh_acc += p @ Wc
                if need_W:
                    gW[c:c + vc] = p.T @ hf
            if need_h:
                gh = gh_acc.to(h.dtype)
        # No gradients for tgt (integer indices) or vchunk (Python int).
        return gh, gW, None, None
def fused_ce(h, W, tgt, vchunk=16384):
    """Chunked cross-entropy of ``h @ W.T`` against ``tgt`` via ``FusedCE``.

    ``h`` may have any number of leading dims; they are collapsed so each row
    along the last dim becomes one sample.  ``tgt`` is flattened to a 1-D vector
    of labels, one per row.  ``vchunk`` is forwarded unchanged as the tile width.
    """
    hidden = h.size(-1)
    rows = h.reshape(-1, hidden)
    labels = tgt.reshape(-1)
    return FusedCE.apply(rows, W, labels, vchunk)