File size: 2,435 Bytes
82d098e
 
 
 
 
 
df559be
82d098e
 
 
df559be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82d098e
 
df559be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82d098e
df559be
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Fused cross-entropy: streams over the VOCAB dimension (online softmax) so the
[N x V] logit matrix is NEVER materialized -- only one [N x vchunk] slice at a
time. The custom backward recomputes the softmax per vocab chunk from the saved
log-sum-exp (grad w.r.t. logits = softmax - onehot), so no [N x V] tensor is kept
for backward either. This is the DiffusionBlocks 'process in chunks, don't hold
the whole thing' idea applied to the output head instead of network depth."""
import torch

def _fp32_region(device):
    """Return a context manager that disables autocast for ``device``'s backend.

    The loss accumulates an online log-sum-exp, so its matmuls must run in fp32
    even when the caller trains under mixed precision. ``torch.autocast`` is the
    device-generic replacement for the deprecated ``torch.cuda.amp.autocast``
    (which also left CPU autocast untouched for CPU tensors). Device types that
    autocast does not support get a no-op context instead.
    """
    try:
        return torch.autocast(device_type=device.type, enabled=False)
    except RuntimeError:
        import contextlib
        return contextlib.nullcontext()


class FusedCE(torch.autograd.Function):
    """Mean cross-entropy of ``h @ W.T`` against ``tgt`` without [N, V] logits.

    ``apply(h, W, tgt, vchunk)``:
        h:      [N, d] hidden states.
        W:      [V, d] output-head weights (logits are ``h @ W.T``).
        tgt:    [N] integer target ids (any integer dtype; used as int64).
        vchunk: vocabulary rows processed per step; transient memory is
                O(N * vchunk) instead of O(N * V). Must be positive.

    Returns a 0-dim fp32 tensor: the mean over rows of
    ``logsumexp(h @ W.T) - logit[tgt]``. All math runs in fp32; gradients are
    cast back to the dtypes of ``h`` and ``W``.

    Targets outside ``[0, V)`` contribute a target logit of 0 (their row loss is
    the bare log-sum-exp) and get no one-hot term in backward.
    NOTE(review): there is no ignore_index masking -- confirm callers never pass
    padding ids such as -100.

    Raises:
        ValueError: if ``vchunk`` is not positive.
    """

    @staticmethod
    def forward(ctx, h, W, tgt, vchunk=16384):
        if vchunk <= 0:
            raise ValueError(f"vchunk must be positive, got {vchunk}")
        N = h.shape[0]
        V = W.shape[0]
        tgt = tgt.long()  # gather/scatter need int64 indices
        with _fp32_region(h.device):
            hf = h.float()
            Wf = W.float()
            run_max = torch.full((N,), -1e30, device=h.device, dtype=torch.float32)
            run_sum = torch.zeros(N, device=h.device, dtype=torch.float32)
            tgt_logit = torch.zeros(N, device=h.device, dtype=torch.float32)
            for c in range(0, V, vchunk):
                lg = hf @ Wf[c:c + vchunk].T  # [N, width] transient only
                width = lg.shape[1]  # < vchunk on the last chunk
                new_max = torch.maximum(run_max, lg.amax(dim=1))
                # Online softmax: rescale the running sum to the new max, then add this chunk.
                run_sum = run_sum * torch.exp(run_max - new_max) + torch.exp(lg - new_max[:, None]).sum(1)
                run_max = new_max
                # Pick up target logits living in this chunk with a masked gather.
                # Unlike `if mask.any():` this never blocks on a device->host sync.
                hit = (tgt >= c) & (tgt < c + width)
                local = (tgt - c).clamp(0, width - 1)
                tgt_logit = torch.where(hit, lg.gather(1, local[:, None]).squeeze(1), tgt_logit)
            lse = run_max + torch.log(run_sum)
        ctx.save_for_backward(h, W, tgt, lse)
        ctx.vchunk = vchunk
        return (lse - tgt_logit).mean()

    @staticmethod
    def backward(ctx, go):
        h, W, tgt, lse = ctx.saved_tensors
        vc = ctx.vchunk
        need_h, need_W = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        V = W.shape[0]
        with _fp32_region(h.device):
            hf = h.float()
            Wf = W.float()
            gh = torch.zeros_like(hf) if need_h else None
            gW = torch.zeros(W.shape, device=W.device, dtype=torch.float32) if need_W else None
            # d(mean)/d(row loss) = go / N; kept on-device (float(go) would force a sync).
            scale = go.float() / hf.shape[0]
            for c in range(0, V, vc):
                Wc = Wf[c:c + vc]
                p = torch.exp(hf @ Wc.T - lse[:, None])  # softmax slice [N, width]
                width = p.shape[1]
                # dCE/dlogits = softmax - onehot(tgt): subtract 1 only on rows whose
                # target falls in this slice (other rows add -0.0, a no-op).
                hit = (tgt >= c) & (tgt < c + width)
                local = (tgt - c).clamp(0, width - 1)
                p.scatter_add_(1, local[:, None], -hit.to(p.dtype)[:, None])
                p *= scale
                if need_h:
                    gh += p @ Wc
                if need_W:
                    gW[c:c + vc] += p.T @ hf
        return (
            gh.to(h.dtype) if need_h else None,
            gW.to(W.dtype) if need_W else None,
            None,
            None,
        )

def fused_ce(h, W, tgt, vchunk=16384):
    """Mean cross-entropy of ``h @ W.T`` against ``tgt`` via the chunked ``FusedCE``.

    All leading dimensions of ``h`` are flattened, so any ``[..., d]`` hidden
    tensor is accepted. ``tgt`` is flattened as well and must supply one target
    id per row of the flattened ``h``. ``vchunk`` is forwarded unchanged as the
    number of vocabulary rows processed per step.
    """
    hidden_size = h.size(-1)
    rows = h.reshape(-1, hidden_size)
    targets = tgt.reshape(-1)
    return FusedCE.apply(rows, W, targets, vchunk)