#!/usr/bin/env python3
"""End-to-end training benchmark: Dense vs PyLoop vs Triton sparse backward."""
import math, os, time, urllib.request
import torch, torch.nn as nn, torch.nn.functional as F
import tiktoken
# Import our Triton kernels from the module
from triton_sparse import (
    TritonChunkedSparseLinear, PythonLoopSparseLinear,
    sparse_bwd_dW, sparse_bwd_dX, sparse_bwd_dbias
)
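# Assumed interface (defined in triton_sparse, not shown here): both autograd
# Functions take (x, weight, bias, active_chunks, chunk_size, <bool flag>) and
# are expected to produce only the output-feature chunks listed in
# active_chunks, e.g.:
#   out = TritonChunkedSparseLinear.apply(h, W, b, active_chunks, 64, False)
# The sparse_bwd_* kernels are imported alongside; this script exercises them
# only indirectly, through the Functions' backward passes.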
device = 'cuda'
BS, BLK = 8, 256  # batch size, sequence (block) length
# Data
if not os.path.exists('input.txt'):
    urllib.request.urlretrieve('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', 'input.txt')
enc = tiktoken.get_encoding('gpt2')
tokens = torch.tensor(enc.encode(open('input.txt').read()), dtype=torch.long)
train_data = tokens[:int(0.9*len(tokens))]
val_data = tokens[int(0.9*len(tokens)):]
V = enc.n_vocab
def get_batch(data, gen=None):
    ix = torch.randint(len(data)-BLK-1, (BS,), generator=gen)
    return (torch.stack([data[i:i+BLK] for i in ix]).to(device),
            torch.stack([data[i+1:i+BLK+1] for i in ix]).to(device))
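# get_batch returns (x, y), each of shape (BS, BLK) = (8, 256); y is x shifted
# one token to the right, so y[:, t] is the prediction target for x[:, t].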
# Model
class SparseFFN(nn.Module):
    """FFN whose down-projection can run dense, through the Python-loop sparse
    autograd Function, or through the Triton chunked-sparse kernel."""
    def __init__(self, d, cs=64):
        super().__init__()
        self.fc = nn.Linear(d, 4*d)
        self.proj = nn.Linear(4*d, d)
        self.do = nn.Dropout(0.1)
        self.cs = cs                  # chunk size along proj's output features
        self.mode = 'dense'
        self.active_chunks = None     # indices of output chunks to compute
    def forward(self, x):
        h = F.gelu(self.fc(x))
        if self.mode == 'dense' or self.active_chunks is None:
            return self.do(self.proj(h))
        elif self.mode == 'pyloop':
            return self.do(PythonLoopSparseLinear.apply(
                h, self.proj.weight, self.proj.bias, self.active_chunks, self.cs, False))
        else:  # triton
            return self.do(TritonChunkedSparseLinear.apply(
                h, self.proj.weight, self.proj.bias, self.active_chunks, self.cs, False))
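# For reference, a minimal dense sketch of what the chunked-sparse projection is
# assumed to compute: a full F.linear whose output is zeroed outside the active
# output-feature chunks. Illustrative only; the benchmark never calls it.
def dense_reference_sparse_proj(h, weight, bias, active_chunks, cs):
    out = F.linear(h, weight, bias)                      # full dense projection
    mask = torch.zeros(out.shape[-1], dtype=torch.bool, device=out.device)
    for c in active_chunks.tolist():                     # mark active chunks
        mask[c*cs:(c+1)*cs] = True
    return out * mask                                    # zero inactive chunks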
class Attn(nn.Module):
    """Standard causal multi-head self-attention."""
    def __init__(self, d, nh, bs):
        super().__init__()
        self.nh, self.hd = nh, d//nh
        self.qkv = nn.Linear(d, 3*d)
        self.proj = nn.Linear(d, d)
        self.do = nn.Dropout(0.1)
        self.register_buffer('mask', torch.tril(torch.ones(bs,bs)).view(1,1,bs,bs))
    def forward(self, x):
        B,T,C = x.shape
        q,k,v = self.qkv(x).split(C,2)
        q = q.view(B,T,self.nh,self.hd).transpose(1,2)
        k = k.view(B,T,self.nh,self.hd).transpose(1,2)
        v = v.view(B,T,self.nh,self.hd).transpose(1,2)
        att = (q @ k.transpose(-2,-1)) / math.sqrt(self.hd)
        att = att.masked_fill(self.mask[:,:,:T,:T]==0, float('-inf'))
        att = self.do(F.softmax(att, dim=-1))
        return self.proj((att @ v).transpose(1,2).contiguous().view(B,T,C))
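# In training mode the masked softmax attention above is equivalent to
# F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
# (PyTorch >= 2.0); the explicit form is used here, and only the FFN
# projection differs across the three benchmark modes.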
class Block(nn.Module):
    def __init__(self, d, nh, bs):
        super().__init__()
        self.ln1=nn.LayerNorm(d); self.attn=Attn(d,nh,bs)
        self.ln2=nn.LayerNorm(d); self.mlp=SparseFFN(d)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
class GPT(nn.Module):
    def __init__(self, d, nl, nh, bs):
        super().__init__()
        self.te=nn.Embedding(V,d); self.pe=nn.Embedding(bs,d)
        self.blocks=nn.ModuleList([Block(d,nh,bs) for _ in range(nl)])
        self.ln=nn.LayerNorm(d); self.head=nn.Linear(d,V)
    def forward(self, idx, tgt=None):
        x = self.te(idx)+self.pe(torch.arange(idx.shape[1],device=idx.device))[None]
        for b in self.blocks: x = b(x)
        lo = self.head(self.ln(x))
        loss = F.cross_entropy(lo.view(-1,lo.size(-1)), tgt.view(-1)) if tgt is not None else None
        return lo, loss
    def get_ffns(self):
        return [b.mlp for b in self.blocks]
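# Rough scale (untied embeddings): token embedding and head each hold V*d
# parameters (~50257*d), and each block adds ~12*d^2 (4*d^2 attention +
# 8*d^2 FFN). For d=512, nl=6 that is about 2*25.7M + 6*3.1M ≈ 70M parameters.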
# Run
STEPS = 100
af = 0.10   # nominal fraction of active output chunks
cs = 64     # chunk size (must match SparseFFN default)
print(f"End-to-end training: {STEPS} steps, B={BS}, T={BLK}, active_frac={af}")
print(f"{'d_model':>7} | {'Mode':>8} | {'ms/step':>10} | {'vs Dense':>10} | {'val_loss':>10}")
print("-"*60)
for d in [512, 1024, 2048]:
    nh = 8; nl = 6
    results = {}
    for mode in ['dense', 'pyloop', 'triton']:
        torch.manual_seed(42)  # identical init for all three modes
        model = GPT(d, nl, nh, BLK).to(device)
        opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
        ffns = model.get_ffns()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for step in range(STEPS):
            if mode != 'dense':
                for ffn in ffns:
                    ffn.mode = mode
                    # proj: Linear(4d, d) -> weight shape (d, 4d), out_features=d
                    nc = ffn.proj.out_features // cs
                    k = max(1, int(af * nc))
                    ffn.active_chunks = torch.randperm(nc, device=device)[:k].sort().values
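                    # Effective active fraction per width (af=0.10, cs=64):
                    #   d=512  -> nc=8,  k=max(1,int(0.8))=1 -> 1/8  = 12.5%
                    #   d=1024 -> nc=16, k=int(1.6)=1        -> 1/16 = 6.25%
                    #   d=2048 -> nc=32, k=int(3.2)=3        -> 3/32 ≈ 9.4%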
            else:
                for ffn in ffns:
                    ffn.mode = 'dense'; ffn.active_chunks = None
            x, y = get_batch(train_data, torch.Generator().manual_seed(step))
            opt.zero_grad()
            _, loss = model(x, y)
            loss.backward()
            opt.step()
        torch.cuda.synchronize()
        ms = 1000 * (time.perf_counter() - t0) / STEPS
        # Eval: always dense, averaged over 20 fixed validation batches
        model.eval()
        for ffn in ffns: ffn.mode = 'dense'; ffn.active_chunks = None
        with torch.no_grad():
            vl = sum(model(*get_batch(val_data, torch.Generator().manual_seed(9999+i)))[1].item() for i in range(20))/20
        results[mode] = (ms, vl)
        del model; torch.cuda.empty_cache()
    d_ms = results['dense'][0]
    for mode in ['dense', 'pyloop', 'triton']:
        ms, vl = results[mode]
        sp = d_ms / ms
        print(f"{d:>7} | {mode:>8} | {ms:>9.1f}ms | {sp:>9.2f}x | {vl:>9.4f}")
    print()