theapemachine commited on
Commit
a9b4621
Β·
verified Β·
1 Parent(s): ae901d9

Upload e2e_full.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. e2e_full.py +262 -0
e2e_full.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Self-contained E2E training benchmark: Dense vs PyLoop-sparse vs Triton-sparse.
4
+ Includes all Triton kernels inline. Runs d_model ∈ {512, 1024, 2048}.
5
+ """
6
+
7
+ import math, os, time, urllib.request
8
+ import torch, torch.nn as nn, torch.nn.functional as F
9
+ import triton, triton.language as tl
10
+ import tiktoken
11
+
12
+ device = 'cuda'
13
+ BS, BLK = 8, 256
14
+
15
+ # ═══════════ DATA ═══════════
16
+
17
+ if not os.path.exists('input.txt'):
18
+ urllib.request.urlretrieve('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', 'input.txt')
19
+ enc = tiktoken.get_encoding('gpt2')
20
+ tokens = torch.tensor(enc.encode(open('input.txt').read()), dtype=torch.long)
21
+ train_data = tokens[:int(0.9*len(tokens))]
22
+ val_data = tokens[int(0.9*len(tokens)):]
23
+ V = enc.n_vocab
24
+
25
+ def get_batch(data, gen=None):
26
+ ix = torch.randint(len(data)-BLK-1, (BS,), generator=gen)
27
+ return (torch.stack([data[i:i+BLK] for i in ix]).to(device),
28
+ torch.stack([data[i+1:i+BLK+1] for i in ix]).to(device))
29
+
30
+ # ═══════════ TRITON KERNELS ═══════════
31
+
32
+ @triton.autotune(
33
+ configs=[
34
+ triton.Config({'BN': 32, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
35
+ triton.Config({'BN': 64, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
36
+ triton.Config({'BN': 64, 'BK': 128, 'BM': 32}, num_stages=3, num_warps=4),
37
+ triton.Config({'BN': 32, 'BK': 128, 'BM': 64}, num_stages=3, num_warps=4),
38
+ triton.Config({'BN': 64, 'BK': 64, 'BM': 64}, num_stages=4, num_warps=4),
39
+ ],
40
+ key=['M', 'd_in', 'CS'],
41
+ )
42
+ @triton.jit
43
+ def _sparse_bwd_dW_kernel(
44
+ X_ptr, dY_ptr, dW_ptr, chunk_ids_ptr,
45
+ M, d_in, d_out, num_active,
46
+ stride_xm, stride_xk, stride_dym, stride_dyn, stride_dwn, stride_dwk,
47
+ CS: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr, BM: tl.constexpr,
48
+ ):
49
+ pid0 = tl.program_id(0); pid1 = tl.program_id(1)
50
+ N_BLOCKS = tl.cdiv(CS, BN)
51
+ cli = pid0 // N_BLOCKS; nbi = pid0 % N_BLOCKS; kbi = pid1
52
+ if cli >= num_active: return
53
+ cidx = tl.load(chunk_ids_ptr + cli); cs0 = cidx * CS
54
+ rn = nbi * BN + tl.arange(0, BN); rk = kbi * BK + tl.arange(0, BK)
55
+ na = cs0 + rn; nm = rn < CS; km = rk < d_in
56
+ acc = tl.zeros((BN, BK), dtype=tl.float32)
57
+ for ms in range(0, M, BM):
58
+ rm = ms + tl.arange(0, BM); mm = rm < M
59
+ x = tl.load(X_ptr + rm[:, None]*stride_xm + rk[None, :]*stride_xk, mask=mm[:, None] & km[None, :], other=0.0)
60
+ dy = tl.load(dY_ptr + rm[:, None]*stride_dym + na[None, :]*stride_dyn, mask=mm[:, None] & nm[None, :], other=0.0)
61
+ acc = tl.dot(tl.trans(dy), x, acc=acc)
62
+ tl.store(dW_ptr + na[:, None]*stride_dwn + rk[None, :]*stride_dwk, acc.to(dW_ptr.dtype.element_ty), mask=nm[:, None] & km[None, :])
63
+
64
+ def sparse_bwd_dW(X, dY, active, cs, d_out):
65
+ M, d_in = X.shape; na = active.shape[0]
66
+ dW = torch.zeros(d_out, d_in, device=X.device, dtype=X.dtype)
67
+ if na == 0: return dW
68
+ cids = active.to(torch.int32).contiguous()
69
+ grid = lambda META: (na * triton.cdiv(cs, META['BN']), triton.cdiv(d_in, META['BK']))
70
+ _sparse_bwd_dW_kernel[grid](X, dY, dW, cids, M, d_in, d_out, na,
71
+ X.stride(0), X.stride(1), dY.stride(0), dY.stride(1), dW.stride(0), dW.stride(1), CS=cs)
72
+ return dW
73
+
74
+ @triton.jit
75
+ def _sparse_bwd_dbias_kernel(
76
+ dY_ptr, dB_ptr, chunk_ids_ptr, M, d_out, num_active,
77
+ stride_dym, stride_dyn, CS: tl.constexpr, BM: tl.constexpr,
78
+ ):
79
+ pid = tl.program_id(0)
80
+ cl = pid // CS; ci = pid % CS
81
+ if cl >= num_active: return
82
+ cidx = tl.load(chunk_ids_ptr + cl); ca = cidx * CS + ci
83
+ acc = 0.0
84
+ for ms in range(0, M, BM):
85
+ rm = ms + tl.arange(0, BM); mm = rm < M
86
+ acc += tl.sum(tl.load(dY_ptr + rm*stride_dym + ca*stride_dyn, mask=mm, other=0.0))
87
+ tl.store(dB_ptr + ca, acc.to(dB_ptr.dtype.element_ty))
88
+
89
+ def sparse_bwd_dbias(dY, active, cs, d_out):
90
+ M = dY.shape[0]; na = active.shape[0]
91
+ dB = torch.zeros(d_out, device=dY.device, dtype=dY.dtype)
92
+ if na == 0: return dB
93
+ cids = active.to(torch.int32).contiguous()
94
+ _sparse_bwd_dbias_kernel[(na * cs,)](dY, dB, cids, M, d_out, na, dY.stride(0), dY.stride(1), CS=cs, BM=128)
95
+ return dB
96
+
97
+ # ═══════════ AUTOGRAD ═══════════
98
+
99
+ class TritonSparse(torch.autograd.Function):
100
+ @staticmethod
101
+ def forward(ctx, x, w, b, active, cs, sdx):
102
+ ctx.save_for_backward(x, w, active); ctx.has_bias = b is not None; ctx.sdx = sdx; ctx.cs = cs
103
+ return F.linear(x, w, b)
104
+ @staticmethod
105
+ def backward(ctx, gy):
106
+ x, w, active = ctx.saved_tensors; cs = ctx.cs; do, di = w.shape
107
+ xf = x.reshape(-1, di); gf = gy.reshape(-1, do)
108
+ gw = sparse_bwd_dW(xf, gf, active, cs, do)
109
+ gb = sparse_bwd_dbias(gf, active, cs, do) if ctx.has_bias else None
110
+ gx = gf @ w # dense dX
111
+ return gx.reshape(x.shape), gw, gb, None, None, None
112
+
113
+ class PyLoopSparse(torch.autograd.Function):
114
+ @staticmethod
115
+ def forward(ctx, x, w, b, active, cs, sdx):
116
+ ctx.save_for_backward(x, w, active); ctx.has_bias = b is not None; ctx.sdx = sdx; ctx.cs = cs
117
+ return F.linear(x, w, b)
118
+ @staticmethod
119
+ def backward(ctx, gy):
120
+ x, w, active = ctx.saved_tensors; cs = ctx.cs
121
+ xf = x.reshape(-1, x.shape[-1]); gf = gy.reshape(-1, gy.shape[-1])
122
+ gw = torch.zeros_like(w)
123
+ gb = torch.zeros(w.shape[0], device=w.device, dtype=w.dtype) if ctx.has_bias else None
124
+ gx = gf @ w
125
+ for c in active.tolist():
126
+ s, e = c*cs, (c+1)*cs
127
+ gw[s:e] = gf[:, s:e].t() @ xf
128
+ if ctx.has_bias: gb[s:e] = gf[:, s:e].sum(0)
129
+ return gx.reshape(x.shape), gw, gb, None, None, None
130
+
131
+ # ═══════════ MODEL ═══════════
132
+
133
+ class SparseFFN(nn.Module):
134
+ def __init__(self, d, cs=64):
135
+ super().__init__()
136
+ self.fc = nn.Linear(d, 4*d); self.proj = nn.Linear(4*d, d)
137
+ self.do = nn.Dropout(0.1); self.cs = cs; self.mode = 'dense'; self.active_chunks = None
138
+ def forward(self, x):
139
+ h = F.gelu(self.fc(x))
140
+ if self.mode == 'dense' or self.active_chunks is None:
141
+ return self.do(self.proj(h))
142
+ elif self.mode == 'pyloop':
143
+ return self.do(PyLoopSparse.apply(h, self.proj.weight, self.proj.bias, self.active_chunks, self.cs, False))
144
+ else:
145
+ return self.do(TritonSparse.apply(h, self.proj.weight, self.proj.bias, self.active_chunks, self.cs, False))
146
+
147
+ class Attn(nn.Module):
148
+ def __init__(self, d, nh, bs):
149
+ super().__init__()
150
+ self.nh, self.hd = nh, d//nh
151
+ self.qkv = nn.Linear(d, 3*d); self.proj = nn.Linear(d, d)
152
+ self.do = nn.Dropout(0.1)
153
+ self.register_buffer('mask', torch.tril(torch.ones(bs,bs)).view(1,1,bs,bs))
154
+ def forward(self, x):
155
+ B,T,C = x.shape
156
+ q,k,v = self.qkv(x).split(C,2)
157
+ q=q.view(B,T,self.nh,self.hd).transpose(1,2); k=k.view(B,T,self.nh,self.hd).transpose(1,2); v=v.view(B,T,self.nh,self.hd).transpose(1,2)
158
+ a = self.do(F.softmax((q@k.transpose(-2,-1))/math.sqrt(self.hd)+self.mask[:,:,:T,:T].log(), dim=-1))
159
+ return self.proj((a@v).transpose(1,2).contiguous().view(B,T,C))
160
+
161
+ class Block(nn.Module):
162
+ def __init__(self, d, nh, bs):
163
+ super().__init__()
164
+ self.ln1=nn.LayerNorm(d); self.attn=Attn(d,nh,bs); self.ln2=nn.LayerNorm(d); self.mlp=SparseFFN(d)
165
+ def forward(self, x):
166
+ x = x + self.attn(self.ln1(x)); return x + self.mlp(self.ln2(x))
167
+
168
+ class GPT(nn.Module):
169
+ def __init__(self, d, nl, nh, bs):
170
+ super().__init__()
171
+ self.te=nn.Embedding(V,d); self.pe=nn.Embedding(bs,d)
172
+ self.blocks=nn.ModuleList([Block(d,nh,bs) for _ in range(nl)]); self.ln=nn.LayerNorm(d); self.head=nn.Linear(d,V)
173
+ def forward(self, idx, tgt=None):
174
+ x = self.te(idx)+self.pe(torch.arange(idx.shape[1],device=idx.device))[None]
175
+ for b in self.blocks: x = b(x)
176
+ lo = self.head(self.ln(x))
177
+ return lo, F.cross_entropy(lo.view(-1,lo.size(-1)), tgt.view(-1)) if tgt is not None else None
178
+ def get_ffns(self): return [b.mlp for b in self.blocks]
179
+ def nparams(self): return sum(p.numel() for p in self.parameters())
180
+
181
+ # ═══════════ RUN ═══════════
182
+
183
+ STEPS = 500
184
+ af = 0.10
185
+ cs = 64
186
+
187
+ if torch.cuda.is_available():
188
+ print(f"GPU: {torch.cuda.get_device_name()} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
189
+
190
+ print(f"E2E training: {STEPS} steps, B={BS}, T={BLK}, active_frac={af}, chunk_size={cs}")
191
+ print(f"{'d_model':>7} | {'Mode':>8} | {'Params':>8} | {'ms/step':>10} | {'vs Dense':>10} | {'val_loss':>10} | {'train_loss':>10}")
192
+ print("-"*80)
193
+
194
+ for d in [512, 1024, 2048]:
195
+ nh = 8; nl = 6
196
+ results = {}
197
+
198
+ for mode in ['dense', 'pyloop', 'triton']:
199
+ torch.manual_seed(42)
200
+ model = GPT(d, nl, nh, BLK).to(device)
201
+ npar = model.nparams()
202
+ opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
203
+ ffns = model.get_ffns()
204
+
205
+ # Triton warmup (compile kernels before timing)
206
+ if mode == 'triton':
207
+ for ffn in ffns:
208
+ ffn.mode = mode
209
+ nc = ffn.proj.out_features // cs
210
+ k = max(1, int(af * nc))
211
+ ffn.active_chunks = torch.randperm(nc, device=device)[:k].sort().values
212
+ x, y = get_batch(train_data, torch.Generator().manual_seed(99999))
213
+ opt.zero_grad(); _, loss = model(x, y); loss.backward(); opt.step()
214
+ # Reset model
215
+ torch.manual_seed(42)
216
+ model = GPT(d, nl, nh, BLK).to(device)
217
+ opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
218
+ ffns = model.get_ffns()
219
+
220
+ torch.cuda.synchronize()
221
+ t0 = time.perf_counter()
222
+ last_loss = 0.0
223
+
224
+ for step in range(STEPS):
225
+ if mode != 'dense':
226
+ for ffn in ffns:
227
+ ffn.mode = mode
228
+ nc = ffn.proj.out_features // cs
229
+ k = max(1, int(af * nc))
230
+ ffn.active_chunks = torch.randperm(nc, device=device)[:k].sort().values
231
+ else:
232
+ for ffn in ffns:
233
+ ffn.mode = 'dense'; ffn.active_chunks = None
234
+
235
+ x, y = get_batch(train_data, torch.Generator().manual_seed(step))
236
+ opt.zero_grad()
237
+ _, loss = model(x, y)
238
+ loss.backward()
239
+ opt.step()
240
+ last_loss = loss.item()
241
+
242
+ if step % 100 == 0:
243
+ print(f" [{mode}] d={d} step {step}/{STEPS} loss={last_loss:.4f}")
244
+
245
+ torch.cuda.synchronize()
246
+ ms = 1000 * (time.perf_counter() - t0) / STEPS
247
+
248
+ # Eval
249
+ model.eval()
250
+ for ffn in ffns: ffn.mode = 'dense'; ffn.active_chunks = None
251
+ with torch.no_grad():
252
+ vl = sum(model(*get_batch(val_data, torch.Generator().manual_seed(9999+i)))[1].item() for i in range(20))/20
253
+
254
+ results[mode] = (ms, vl, last_loss, npar)
255
+ del model, opt; torch.cuda.empty_cache()
256
+
257
+ d_ms = results['dense'][0]
258
+ for mode in ['dense', 'pyloop', 'triton']:
259
+ ms, vl, tl_, np_ = results[mode]
260
+ sp = d_ms / ms
261
+ print(f"{d:>7} | {mode:>8} | {np_/1e6:>7.1f}M | {ms:>9.1f}ms | {sp:>9.2f}x | {vl:>9.4f} | {tl_:>9.4f}")
262
+ print()