File size: 11,322 Bytes
a9b4621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
#!/usr/bin/env python3
"""
Self-contained E2E training benchmark: Dense vs PyLoop-sparse vs Triton-sparse.
Includes all Triton kernels inline. Runs d_model ∈ {512, 1024, 2048}.
"""

import math, os, time, urllib.request
import torch, torch.nn as nn, torch.nn.functional as F
import triton, triton.language as tl
import tiktoken

device = "cuda"  # models, batches and sampled chunk ids are all placed on the GPU
BS = 8  # sequences per training batch
BLK = 256  # tokens per sequence; also the model's maximum context length

# ═══════════ DATA ═══════════

# Download Tiny Shakespeare once into the working directory. Fetch under a
# temporary name and rename only on success, so an interrupted download cannot
# leave a truncated input.txt that the exists() check would accept forever.
if not os.path.exists('input.txt'):
    urllib.request.urlretrieve('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', 'input.txt.part')
    os.replace('input.txt.part', 'input.txt')
enc = tiktoken.get_encoding('gpt2')
# The context manager closes the handle deterministically; explicit UTF-8 keeps
# decoding independent of the platform's locale default.
with open('input.txt', encoding='utf-8') as _fh:
    tokens = torch.tensor(enc.encode(_fh.read()), dtype=torch.long)
_n_train = int(0.9 * len(tokens))  # first 90% of tokens train, last 10% validate
train_data = tokens[:_n_train]
val_data = tokens[_n_train:]
V = enc.n_vocab  # vocabulary size used by the GPT embedding and LM head

def get_batch(data, gen=None):
    """Sample a random batch of next-token prediction pairs from a 1-D token tensor.

    Draws ``BS`` window start offsets with ``gen`` (the default RNG when None)
    and returns ``(inputs, targets)``, each of shape (BS, BLK) and moved to
    ``device``; ``targets`` is the same window shifted forward by one token.
    """
    starts = torch.randint(len(data) - BLK - 1, (BS,), generator=gen)
    # Build all windows with one gather instead of stacking per-row slices.
    window = starts[:, None] + torch.arange(BLK)
    inputs = data[window]
    targets = data[window + 1]
    return inputs.to(device), targets.to(device)

# ═══════════ TRITON KERNELS ═══════════

# Weight-gradient kernel for a linear layer, restricted to "active" output chunks:
#     dW[r, :] = sum_m dY[m, r] * X[m, :]   for every row r of an active chunk,
# i.e. the rows of dY^T @ X belonging to the chunk ids in chunk_ids_ptr. Rows of
# inactive chunks are never written (sparse_bwd_dW pre-zeroes dW).
#
# Launch grid (built in sparse_bwd_dW):
#   axis 0 -> (active chunk, BN-row tile inside that chunk) pairs
#   axis 1 -> BK-column tile of d_in
# The autotuner picks BN/BK/BM per (M, d_in, CS) and times candidate configs by
# launching the kernel repeatedly. That is safe here because each launch *stores*
# its dW tile rather than accumulating into it.
# d_out is accepted but not used by the kernel body.
# NOTE(review): for fp32 inputs tl.dot may default to TF32 precision depending on
# the Triton version, so results can differ slightly from the dense PyTorch path.
@triton.autotune(
    configs=[
        triton.Config({'BN': 32, 'BK': 64,  'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64,  'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 128, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 32, 'BK': 128, 'BM': 64}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64,  'BM': 64}, num_stages=4, num_warps=4),
    ],
    key=['M', 'd_in', 'CS'],
)
@triton.jit
def _sparse_bwd_dW_kernel(
    X_ptr, dY_ptr, dW_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_xm, stride_xk, stride_dym, stride_dyn, stride_dwn, stride_dwk,
    CS: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
):
    pid0 = tl.program_id(0); pid1 = tl.program_id(1)
    # Split axis 0 into (which active chunk, which BN-row tile inside it).
    N_BLOCKS = tl.cdiv(CS, BN)
    cli = pid0 // N_BLOCKS; nbi = pid0 % N_BLOCKS; kbi = pid1
    # Defensive guard: the wrapper launches exactly num_active * N_BLOCKS programs on axis 0.
    if cli >= num_active: return
    # cs0: first absolute output row of this chunk.
    cidx = tl.load(chunk_ids_ptr + cli); cs0 = cidx * CS
    # rn: row offsets inside the chunk; rk: input-feature (column) offsets.
    rn = nbi * BN + tl.arange(0, BN); rk = kbi * BK + tl.arange(0, BK)
    # na: absolute output rows; nm / km mask tile tails that run past CS / d_in.
    na = cs0 + rn; nm = rn < CS; km = rk < d_in
    # fp32 accumulator for this (BN, BK) output tile.
    acc = tl.zeros((BN, BK), dtype=tl.float32)
    # Reduce over the M rows (presumably the flattened leading dims of the activations) in BM-row slabs.
    for ms in range(0, M, BM):
        rm = ms + tl.arange(0, BM); mm = rm < M
        # (BM, BK) slab of X.
        x = tl.load(X_ptr + rm[:, None]*stride_xm + rk[None, :]*stride_xk, mask=mm[:, None] & km[None, :], other=0.0)
        # (BM, BN) slab of dY, restricted to this chunk's output columns.
        dy = tl.load(dY_ptr + rm[:, None]*stride_dym + na[None, :]*stride_dyn, mask=mm[:, None] & nm[None, :], other=0.0)
        # (BN, BM) @ (BM, BK), accumulated into acc.
        acc = tl.dot(tl.trans(dy), x, acc=acc)
    # Cast to dW's element type and write the tile; the mask keeps partial tiles in bounds.
    tl.store(dW_ptr + na[:, None]*stride_dwn + rk[None, :]*stride_dwk, acc.to(dW_ptr.dtype.element_ty), mask=nm[:, None] & km[None, :])

def sparse_bwd_dW(X, dY, active, cs, d_out):
    """Weight gradient ``dY.T @ X`` computed only for the active output chunks.

    Args:
        X: 2-D activations of shape (M, d_in).
        dY: upstream gradient whose rows align with X's rows; column ``j``
            corresponds to output row ``j`` of the weight.
        active: 1-D tensor of chunk ids to compute.
        cs: number of consecutive output rows per chunk.
        d_out: number of rows of the returned gradient.

    Returns:
        A (d_out, d_in) tensor with X's dtype and device. Rows outside the
        active chunks are zero. Launches a Triton kernel, so the tensors are
        expected to live on the GPU.
    """
    n_rows, d_in = X.shape
    n_active = active.shape[0]
    grad_w = torch.zeros(d_out, d_in, device=X.device, dtype=X.dtype)
    if n_active == 0:
        return grad_w
    chunk_ids = active.to(torch.int32).contiguous()

    def grid(meta):
        # axis 0: (active chunk, BN-row tile) pairs; axis 1: BK-column tiles.
        return (n_active * triton.cdiv(cs, meta['BN']), triton.cdiv(d_in, meta['BK']))

    _sparse_bwd_dW_kernel[grid](
        X, dY, grad_w, chunk_ids,
        n_rows, d_in, d_out, n_active,
        X.stride(0), X.stride(1),
        dY.stride(0), dY.stride(1),
        grad_w.stride(0), grad_w.stride(1),
        CS=cs,
    )
    return grad_w

# Bias-gradient kernel restricted to active output chunks:
#     dB[c] = sum_m dY[m, c]   for every column c of an active chunk.
# One program per (active chunk, column within chunk); sparse_bwd_dbias launches
# num_active * CS programs. Each program walks a single column of dY in BM-row
# slabs, so consecutive loads are stride_dym apart.
# NOTE(review): with a row-major dY this access is presumably uncoalesced. That is
# acceptable for a bias vector, but it is not a fast pattern.
# dB is indexed without a stride argument, so it must be contiguous (the wrapper
# allocates it with torch.zeros). d_out is accepted but unused.
@triton.jit
def _sparse_bwd_dbias_kernel(
    dY_ptr, dB_ptr, chunk_ids_ptr, M, d_out, num_active,
    stride_dym, stride_dyn, CS: tl.constexpr, BM: tl.constexpr,
):
    pid = tl.program_id(0)
    # Decompose the flat program id into (active-chunk index, column offset in chunk).
    cl = pid // CS; ci = pid % CS
    # Defensive guard: the wrapper launches exactly num_active * CS programs.
    if cl >= num_active: return
    # ca: absolute output column (bias entry) handled by this program.
    cidx = tl.load(chunk_ids_ptr + cl); ca = cidx * CS + ci
    # Scalar running sum, carried across loop iterations.
    acc = 0.0
    for ms in range(0, M, BM):
        rm = ms + tl.arange(0, BM); mm = rm < M
        # Masked tail rows contribute 0 to the sum.
        acc += tl.sum(tl.load(dY_ptr + rm*stride_dym + ca*stride_dyn, mask=mm, other=0.0))
    # Cast to dB's element type and write the single entry.
    tl.store(dB_ptr + ca, acc.to(dB_ptr.dtype.element_ty))

def sparse_bwd_dbias(dY, active, cs, d_out):
    """Bias gradient (column sums of ``dY``) computed only for the active output chunks.

    Args:
        dY: 2-D upstream gradient; column ``j`` corresponds to bias entry ``j``.
        active: 1-D tensor of chunk ids to compute.
        cs: number of consecutive outputs per chunk.
        d_out: length of the returned gradient.

    Returns:
        A length-``d_out`` tensor with dY's dtype and device that is zero
        outside the active chunks. Launches a Triton kernel, so the tensors are
        expected to live on the GPU.
    """
    n_rows = dY.shape[0]
    n_active = active.shape[0]
    grad_b = torch.zeros(d_out, device=dY.device, dtype=dY.dtype)
    if n_active == 0:
        return grad_b
    chunk_ids = active.to(torch.int32).contiguous()
    # One program per (active chunk, column in chunk); each sums BM rows per loop step.
    launch_grid = (n_active * cs,)
    _sparse_bwd_dbias_kernel[launch_grid](
        dY, grad_b, chunk_ids, n_rows, d_out, n_active,
        dY.stride(0), dY.stride(1),
        CS=cs, BM=128,
    )
    return grad_b

# ═══════════ AUTOGRAD ═══════════

class TritonSparse(torch.autograd.Function):
    """Linear layer whose weight/bias gradients are computed only for active chunks (Triton).

    Forward is a plain ``F.linear(x, w, b)``. In backward, the input gradient
    is dense, while the weight and bias gradients are produced by the Triton
    kernels only for the output rows in the chunks listed by ``active``
    (``cs`` rows each) and are zero elsewhere.
    """

    @staticmethod
    def forward(ctx, x, w, b, active, cs, sdx):
        ctx.save_for_backward(x, w, active)
        ctx.has_bias = b is not None
        ctx.sdx = sdx  # stored on ctx but not read by backward
        ctx.cs = cs
        return F.linear(x, w, b)

    @staticmethod
    def backward(ctx, gy):
        x, w, active = ctx.saved_tensors
        chunk = ctx.cs
        d_out, d_in = w.shape
        # Flatten leading dims so the kernels see plain 2-D matrices.
        x2d = x.reshape(-1, d_in)
        g2d = gy.reshape(-1, d_out)
        grad_w = sparse_bwd_dW(x2d, g2d, active, chunk, d_out)
        grad_b = sparse_bwd_dbias(g2d, active, chunk, d_out) if ctx.has_bias else None
        grad_x = g2d @ w  # input gradient stays dense
        # No gradients for active, cs, sdx.
        return grad_x.reshape(x.shape), grad_w, grad_b, None, None, None

class PyLoopSparse(torch.autograd.Function):
    """Linear layer with chunk-sparse weight/bias gradients built by a Python loop.

    Forward is a plain ``F.linear(x, w, b)``. In backward, the input gradient
    is dense; the weight and bias gradients are filled with one matmul / sum
    per active chunk of ``cs`` output rows and are zero elsewhere.
    """

    @staticmethod
    def forward(ctx, x, w, b, active, cs, sdx):
        ctx.save_for_backward(x, w, active)
        ctx.has_bias = b is not None
        ctx.sdx = sdx  # stored on ctx but not read by backward
        ctx.cs = cs
        return F.linear(x, w, b)

    @staticmethod
    def backward(ctx, gy):
        x, w, active = ctx.saved_tensors
        width = ctx.cs
        x2d = x.reshape(-1, x.shape[-1])
        g2d = gy.reshape(-1, gy.shape[-1])
        grad_w = torch.zeros_like(w)
        if ctx.has_bias:
            grad_b = torch.zeros(w.shape[0], device=w.device, dtype=w.dtype)
        else:
            grad_b = None
        grad_x = g2d @ w
        # .tolist() copies the chunk ids to the host (synchronizes when active is on the GPU).
        for chunk in active.tolist():
            rows = slice(chunk * width, (chunk + 1) * width)
            g_rows = g2d[:, rows]
            grad_w[rows] = g_rows.t() @ x2d
            if grad_b is not None:
                grad_b[rows] = g_rows.sum(0)
        return grad_x.reshape(x.shape), grad_w, grad_b, None, None, None

# ═══════════ MODEL ═══════════

class SparseFFN(nn.Module):
    """Transformer MLP (d -> 4d -> GELU -> d, then dropout) with a switchable output projection.

    ``mode`` selects how ``proj`` is applied once ``active_chunks`` is set:
    ``'dense'`` uses the plain ``nn.Linear``, ``'pyloop'`` uses
    ``PyLoopSparse``, and any other value uses ``TritonSparse``. The sparse
    variants restrict the ``proj`` weight/bias gradients to the listed chunks
    of ``cs`` rows. With ``active_chunks`` left as None, the dense path is
    always used.
    """

    def __init__(self, d, cs=64):
        super().__init__()
        self.fc = nn.Linear(d, 4 * d)
        self.proj = nn.Linear(4 * d, d)
        self.do = nn.Dropout(0.1)
        self.cs = cs
        self.mode = 'dense'
        self.active_chunks = None

    def forward(self, x):
        hidden = F.gelu(self.fc(x))
        if self.mode == 'dense' or self.active_chunks is None:
            return self.do(self.proj(hidden))
        sparse_fn = PyLoopSparse if self.mode == 'pyloop' else TritonSparse
        out = sparse_fn.apply(hidden, self.proj.weight, self.proj.bias,
                              self.active_chunks, self.cs, False)
        return self.do(out)

class Attn(nn.Module):
    """Causal multi-head self-attention with dropout on the attention weights.

    Args:
        d: model width, split evenly into ``nh`` heads of size ``d // nh``.
        nh: number of heads.
        bs: maximum sequence length covered by the causal mask.
    """

    def __init__(self, d, nh, bs):
        super().__init__()
        self.nh = nh
        self.hd = d // nh
        self.qkv = nn.Linear(d, 3 * d)
        self.proj = nn.Linear(d, d)
        self.do = nn.Dropout(0.1)
        # Lower-triangular ones; forward adds its log(): 0 keeps a score, -inf masks it.
        causal = torch.tril(torch.ones(bs, bs))
        self.register_buffer('mask', causal.view(1, 1, bs, bs))

    def forward(self, x):
        B, T, C = x.shape

        def heads(t):
            # (B, T, C) -> (B, nh, T, hd)
            return t.view(B, T, self.nh, self.hd).transpose(1, 2)

        q, k, v = map(heads, self.qkv(x).split(C, 2))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.hd)
        scores = scores + self.mask[:, :, :T, :T].log()
        weights = self.do(F.softmax(scores, dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(merged)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""

    def __init__(self, d, nh, bs):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = Attn(d, nh, bs)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = SparseFFN(d)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """Decoder-only transformer language model over a vocabulary of size ``V``.

    Args:
        d: model width.
        nl: number of transformer blocks.
        nh: attention heads per block.
        bs: maximum sequence length (size of the learned position table).
    """

    def __init__(self, d, nl, nh, bs):
        super().__init__()
        self.te = nn.Embedding(V, d)
        self.pe = nn.Embedding(bs, d)
        self.blocks = nn.ModuleList([Block(d, nh, bs) for _layer in range(nl)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, V)

    def forward(self, idx, tgt=None):
        """Return ``(logits, loss)``; ``loss`` is None when ``tgt`` is not given."""
        h = self.te(idx) + self.pe(torch.arange(idx.shape[1], device=idx.device))[None]
        for block in self.blocks:
            h = block(h)
        logits = self.head(self.ln(h))
        if tgt is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1))
        return logits, loss

    def get_ffns(self):
        """Return each block's SparseFFN module, in block order."""
        return [block.mlp for block in self.blocks]

    def nparams(self):
        """Return the total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())

# ═══════════ RUN ═══════════

STEPS = 500
af = 0.10
cs = 64


def _configure_ffns(ffns, mode):
    """Switch every FFN to *mode*; for sparse modes draw fresh active chunks.

    Dense mode clears ``active_chunks``. Sparse modes pick
    ``max(1, int(af * n_chunks))`` distinct chunk ids (sorted) out of the
    ``proj.out_features // cs`` chunks of each FFN's output projection, using
    the default generator on ``device``.
    """
    for ffn in ffns:
        ffn.mode = mode
        if mode == 'dense':
            ffn.active_chunks = None
            continue
        nc = ffn.proj.out_features // cs
        k = max(1, int(af * nc))
        ffn.active_chunks = torch.randperm(nc, device=device)[:k].sort().values


def _fresh_model(d, nl, nh):
    """Return a freshly seeded ``(model, AdamW optimizer)`` pair.

    Re-seeding with 42 before construction gives every mode identical initial
    weights and the same subsequent RNG stream.
    """
    torch.manual_seed(42)
    model = GPT(d, nl, nh, BLK).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
    return model, opt


def _train_step(model, opt, x, y):
    """Run one forward/backward/optimizer step on (x, y) and return the loss tensor."""
    opt.zero_grad()
    _, loss = model(x, y)
    loss.backward()
    opt.step()
    return loss


def _warmup(d, nl, nh, mode):
    """Run one untimed training step in *mode* on a throwaway model.

    Pays one-time costs (CUDA/cuBLAS initialisation, Triton compilation and
    autotuning) outside the timed loop. The model and optimizer are locals,
    so their memory is released when this function returns.
    """
    model, opt = _fresh_model(d, nl, nh)
    _configure_ffns(model.get_ffns(), mode)
    x, y = get_batch(train_data, torch.Generator().manual_seed(99999))
    _train_step(model, opt, x, y)
    torch.cuda.synchronize()


if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

print(f"E2E training: {STEPS} steps, B={BS}, T={BLK}, active_frac={af}, chunk_size={cs}")
print(f"{'d_model':>7} | {'Mode':>8} | {'Params':>8} | {'ms/step':>10} | {'vs Dense':>10} | {'val_loss':>10} | {'train_loss':>10}")
print("-"*80)

for d in [512, 1024, 2048]:
    nh = 8; nl = 6
    results = {}

    for mode in ['dense', 'pyloop', 'triton']:
        # Warm up EVERY mode, not just Triton. Otherwise the first mode timed
        # (dense, at d=512) also pays CUDA context / cuBLAS start-up costs,
        # which skews the "vs Dense" column.
        _warmup(d, nl, nh, mode)
        torch.cuda.empty_cache()

        model, opt = _fresh_model(d, nl, nh)
        npar = model.nparams()
        ffns = model.get_ffns()

        torch.cuda.synchronize()
        t0 = time.perf_counter()
        last_loss = 0.0

        for step in range(STEPS):
            _configure_ffns(ffns, mode)
            x, y = get_batch(train_data, torch.Generator().manual_seed(step))
            last_loss = _train_step(model, opt, x, y).item()

            if step % 100 == 0:
                print(f"  [{mode}] d={d} step {step}/{STEPS} loss={last_loss:.4f}")

        torch.cuda.synchronize()
        ms = 1000 * (time.perf_counter() - t0) / STEPS

        # Eval (always through the dense path)
        model.eval()
        _configure_ffns(ffns, 'dense')
        with torch.no_grad():
            vl = sum(model(*get_batch(val_data, torch.Generator().manual_seed(9999+i)))[1].item() for i in range(20))/20

        results[mode] = (ms, vl, last_loss, npar)
        # ffns holds the FFN submodules (and their .grad tensors). Drop it too so
        # empty_cache() can actually return this model's memory before the next one.
        del model, opt, ffns
        torch.cuda.empty_cache()

    d_ms = results['dense'][0]
    for mode in ['dense', 'pyloop', 'triton']:
        ms, vl, tl_, np_ = results[mode]
        sp = d_ms / ms
        print(f"{d:>7} | {mode:>8} | {np_/1e6:>7.1f}M | {ms:>9.1f}ms | {sp:>9.2f}x | {vl:>9.4f} | {tl_:>9.4f}")
    print()