Spaces:
Running on Zero
Running on Zero
File size: 9,044 Bytes
"""Generative fusion-chain grammar: [+residual] -> {rms|layer}norm -> ×w(+b) -> epilogue.
The widest sweep of the reduction->epilogue region (where the compiler under-fuses and the 2B
wins). References are COMPOSED; teacher kernels are TEMPLATE-GENERATED (scalar-reduce +
whole-row variants). Everything is harness-filtered downstream — nothing here is trusted.
"""
from __future__ import annotations
import torch
_C = 0.7978845608028654  # sqrt(2/pi): scale inside the tanh-form GELU

# Epilogue table. Each entry maps an activation name to a pair:
#   [0] a callable applied to an fp32 torch tensor (the reference side)
#   [1] a Triton source expression over the fp32 variable `n` (the kernel side)
# The two halves of a pair spell the same formula; on the Triton side tanh is
# written through the identity tanh(x) = 2*sigmoid(2x) - 1.
ACTS = {
    # ---- original three ------------------------------------------------------------------
    "gelu": (
        lambda v: 0.5 * v * (1.0 + torch.tanh(_C * (v + 0.044715 * v * v * v))),
        "(0.5 * n * (1.0 + (2.0 * tl.sigmoid(2.0 * (0.7978845608028654 * (n + 0.044715 * n * n * n))) - 1.0)))",
    ),
    "silu": (
        lambda v: v * torch.sigmoid(v),
        "(n * tl.sigmoid(n))",
    ),
    "relu2": (
        lambda v: torch.relu(v) * torch.relu(v),
        "(tl.maximum(n, 0.0) * tl.maximum(n, 0.0))",
    ),
    # ---- expanded grammar ----------------------------------------------------------------
    "tanh": (
        lambda v: torch.tanh(v),
        "(2.0 * tl.sigmoid(2.0 * n) - 1.0)",
    ),
    "sigmoid": (
        lambda v: torch.sigmoid(v),
        "tl.sigmoid(n)",
    ),
    "relu": (
        lambda v: torch.relu(v),
        "tl.maximum(n, 0.0)",
    ),
    "square": (
        lambda v: v * v,
        "(n * n)",
    ),
    # ---- 2c round 2 ----------------------------------------------------------------------
    "abs": (
        lambda v: torch.abs(v),
        "tl.abs(n)",
    ),
    "softsign": (
        lambda v: v / (1.0 + torch.abs(v)),
        "(n / (1.0 + tl.abs(n)))",
    ),
    "hardsigmoid": (  # F.hardsigmoid
        lambda v: torch.clamp(v + 3.0, 0.0, 6.0) / 6.0,
        "(tl.minimum(tl.maximum(n + 3.0, 0.0), 6.0) / 6.0)",
    ),
    "hardswish": (  # F.hardswish
        lambda v: v * torch.clamp(v + 3.0, 0.0, 6.0) / 6.0,
        "(n * tl.minimum(tl.maximum(n + 3.0, 0.0), 6.0) / 6.0)",
    ),
    # ---- V2 round 3 ----------------------------------------------------------------------
    # exp() arguments are clamped (<= 0 for elu/selu, <= 20 for softplus/mish) before the
    # call, so the branch that torch.where / tl.where discards cannot overflow.
    "leaky_relu": (
        lambda v: torch.where(v > 0, v, 0.01 * v),
        "tl.where(n > 0.0, n, 0.01 * n)",
    ),
    "relu6": (
        lambda v: torch.clamp(v, 0.0, 6.0),
        "tl.minimum(tl.maximum(n, 0.0), 6.0)",
    ),
    "hardtanh": (
        lambda v: torch.clamp(v, -1.0, 1.0),
        "tl.minimum(tl.maximum(n, -1.0), 1.0)",
    ),
    "elu": (
        lambda v: torch.where(v > 0, v, torch.exp(torch.clamp(v, max=0.0)) - 1.0),
        "tl.where(n > 0.0, n, tl.exp(tl.minimum(n, 0.0)) - 1.0)",
    ),
    "selu": (
        lambda v: 1.0507009873554805 * torch.where(
            v > 0, v, 1.6732632423543772 * (torch.exp(torch.clamp(v, max=0.0)) - 1.0)
        ),
        "(1.0507009873554805 * tl.where(n > 0.0, n, 1.6732632423543772 * (tl.exp(tl.minimum(n, 0.0)) - 1.0)))",
    ),
    "softplus": (
        lambda v: torch.where(
            v > 20.0, v, torch.log(1.0 + torch.exp(torch.clamp(v, max=20.0)))
        ),
        "tl.where(n > 20.0, n, tl.log(1.0 + tl.exp(tl.minimum(n, 20.0))))",
    ),
    "mish": (
        lambda v: v * torch.tanh(
            torch.where(v > 20.0, v, torch.log(1.0 + torch.exp(torch.clamp(v, max=20.0))))
        ),
        "(n * (2.0 * tl.sigmoid(2.0 * tl.where(n > 20.0, n, tl.log(1.0 + tl.exp(tl.minimum(n, 20.0))))) - 1.0))",
    ),
    "gelu_erf": (  # erf-form gelu
        lambda v: 0.5 * v * (1.0 + torch.erf(v * 0.7071067811865476)),
        "(0.5 * n * (1.0 + tl.erf(n * 0.7071067811865476)))",
    ),
}
# Axes of the chain grammar, in the order they are enumerated.
NORMS = [
    "rms",
    "layer",
]
RESID = [
    False,  # plain input
    True,   # residual tensor added before the norm
]
# Epilogue names, grouped by the round in which they were added.
ACTNAMES = [
    # original three
    "gelu",
    "silu",
    "relu2",
    # expanded grammar
    "tanh",
    "sigmoid",
    "relu",
    "square",
    # 2c round 2
    "abs",
    "softsign",
    "hardsigmoid",
    "hardswish",
    # V2 round 3
    "leaky_relu",
    "relu6",
    "hardtanh",
    "elu",
    "selu",
    "softplus",
    "mish",
    "gelu_erf",
]
def chain_name(norm, residual, act):
    """Build the chain's display name: ``[add_]{rmsnorm|layernorm}_{act}``.

    Any ``norm`` other than "rms" is named "layernorm".
    """
    prefix = "add_" if residual else ""
    norm_label = "rmsnorm" if norm == "rms" else "layernorm"
    return "_".join([prefix + norm_label, act])
def chain_kind(norm, residual):
    """Return the input-signature tag of a chain: "rms" or "ln", with "add_" prefixed
    when a residual input is present. Any ``norm`` other than "rms" maps to "ln".
    """
    tag = "rms" if norm == "rms" else "ln"
    if residual:
        return "add_" + tag
    return tag
def chain_reference(norm, residual, act, eps=None):
    """Compose the torch reference for one chain and return it as a callable.

    The returned ``ref(*args)`` takes its tensors positionally:
      rms,   no residual: (x, w)
      rms,   residual:    (x, r, w)
      layer, no residual: (x, w, b)
      layer, residual:    (x, r, w, b)
    All math runs on ``.float()`` copies; the result is cast back to the dtype of the
    first argument. When ``eps`` is None it defaults to 1e-6 for "rms" and 1e-5 otherwise.
    """
    if eps is None:
        eps = 1e-6 if norm == "rms" else 1e-5
    epilogue = ACTS[act][0]
    is_rms = norm == "rms"

    def ref(*args):
        # Peel off the activation input(s); whatever remains are the affine parameters.
        if residual:
            x, r, *params = args
            h = x.float() + r.float()
        else:
            x, *params = args
            h = x.float()

        if is_rms:
            (w,) = params
            inv_rms = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
            n = h * inv_rms * w.float()
        else:
            w, b = params
            mu = h.mean(-1, keepdim=True)
            hc = h - mu
            inv_std = torch.rsqrt((hc * hc).mean(-1, keepdim=True) + eps)
            n = hc * inv_std * w.float() + b.float()
        return epilogue(n).to(args[0].dtype)

    return ref
# ---- teacher-kernel templates -----------------------------------------------------------
def _kernel(norm, residual, act_expr, eps, variant):
    """Render the source text of one teacher kernel together with its ``run`` launcher.

    The returned string defines a ``@triton.jit`` function ``_k`` (one program per row,
    indexed by ``tl.program_id(0)``) and a ``run(...)`` wrapper that allocates the output
    with ``torch.empty_like(x)`` and launches ``_k`` on an ``(M,)`` grid. The text carries
    no import statements; it refers to ``torch``, ``triton`` and ``tl`` by name.

    Args:
        norm: "rms" or "layer". "layer" adds the bias pointer/argument and mean-centering.
        residual: when true, a second input is loaded and added to ``x`` before the norm.
        act_expr: epilogue expression over the fp32 variable ``n``; spliced verbatim into
            the ``tl.store`` call.
        eps: epsilon added inside ``rsqrt``; formatted into the ``run`` launch call.
        variant: "scalar" renders a BLOCK=1024 column loop with scalar accumulators; any
            other value renders the whole-row form with a single block of
            ``triton.next_power_of_2(N)`` columns.

    Returns:
        The generated source as a string, ending in a newline.

    Raises:
        ValueError: if ``norm`` is neither "rms" nor "layer" (the rendered kernel would
            use ``b`` without ever receiving ``b_ptr``).
    """
    if norm not in ("rms", "layer"):
        raise ValueError(f"unknown norm {norm!r}; expected 'rms' or 'layer'")
    has_bias = norm == "layer"

    ptrs = "x_ptr, " + ("r_ptr, " if residual else "") + "w_ptr, " + ("b_ptr, " if has_bias else "") + "y_ptr"
    sig = "x, " + ("residual, " if residual else "") + "w" + (", b" if has_bias else "")
    launch = sig + ", y"
    radv = " r_ptr += row * stride;" if residual else ""
    hload = ("tl.load(x_ptr + cols, mask=MM, other=0.0).to(tl.float32)"
             + (" + tl.load(r_ptr + cols, mask=MM, other=0.0).to(tl.float32)" if residual else ""))

    # The bias load sits INSIDE the apply for-loop in the scalar variant (8 spaces) and
    # flat in the function body in the whole-row variant (4 spaces). Both are derived from
    # one statement so the two indents cannot drift; a wrong indent makes the generated
    # source an IndentationError.
    bias_stmt = "b = tl.load(b_ptr + cols, mask=MM, other=0.0).to(tl.float32)\n"
    bload8 = (" " * 8 + bias_stmt) if has_bias else ""
    bload4 = (" " * 4 + bias_stmt) if has_bias else ""

    # NOTE: leading whitespace inside the templates below is part of the generated program.
    if variant == "scalar":
        if norm == "rms":
            normed = "h * rr * w"
            reduce_block = f'''    s = 0.0
    for off in range(0, N, BLOCK):
        cols = off + tl.arange(0, BLOCK); MM = cols < N
        h = {hload}
        s += tl.sum(h * h)
    rr = tl.rsqrt(s / N + eps)'''
        else:
            normed = "(h - mu) * rr * w + b"
            # Two passes: mean first, then the variance of the centered values. Masked
            # lanes load 0.0, so they are re-zeroed after centering to keep them out of v.
            reduce_block = f'''    s = 0.0
    for off in range(0, N, BLOCK):
        cols = off + tl.arange(0, BLOCK); MM = cols < N
        s += tl.sum({hload})
    mu = s / N
    v = 0.0
    for off in range(0, N, BLOCK):
        cols = off + tl.arange(0, BLOCK); MM = cols < N
        d = tl.where(MM, ({hload}) - mu, 0.0); v += tl.sum(d * d)
    rr = tl.rsqrt(v / N + eps)'''
        body = f'''@triton.jit
def _k({ptrs}, stride, N, eps, BLOCK: tl.constexpr):
    row = tl.program_id(0); x_ptr += row * stride;{radv} y_ptr += row * stride
{reduce_block}
    for off in range(0, N, BLOCK):
        cols = off + tl.arange(0, BLOCK); MM = cols < N
        h = {hload}
        w = tl.load(w_ptr + cols, mask=MM, other=0.0).to(tl.float32)
{bload8}        n = {normed}
        tl.store(y_ptr + cols, {act_expr}, mask=MM)
def run({sig}):
    M, N = x.shape; y = torch.empty_like(x)
    _k[(M,)]({launch}, x.stride(0), N, {eps}, BLOCK=1024)
    return y
'''
    else:  # whole-row single block
        if norm == "rms":
            normed = "h * rr * w"
            stat = "    rr = tl.rsqrt(tl.sum(h * h) / N + eps)"
        else:
            normed = "hc * rr * w + b"
            stat = ("    mu = tl.sum(h) / N\n    hc = tl.where(MM, h - mu, 0.0)\n"
                    "    rr = tl.rsqrt(tl.sum(hc * hc) / N + eps)")
        body = f'''@triton.jit
def _k({ptrs}, stride, N, eps, BLOCK: tl.constexpr):
    row = tl.program_id(0); x_ptr += row * stride;{radv} y_ptr += row * stride
    cols = tl.arange(0, BLOCK); MM = cols < N
    h = {hload}
{stat}
    w = tl.load(w_ptr + cols, mask=MM, other=0.0).to(tl.float32)
{bload4}    n = {normed}
    tl.store(y_ptr + cols, {act_expr}, mask=MM)
def run({sig}):
    M, N = x.shape; y = torch.empty_like(x)
    _k[(M,)]({launch}, x.stride(0), N, {eps}, BLOCK=triton.next_power_of_2(N))
    return y
'''
    return body
def chain_structures(norm, residual, act, eps=None):
    """Render the teacher-kernel sources for one chain.

    Args:
        norm: "rms" or "layer".
        residual: whether the chain adds a residual input before the norm.
        act: key into ``ACTS``; its Triton expression becomes the epilogue.
        eps: optional epsilon override. When None it defaults to 1e-6 for "rms" and
            1e-5 otherwise, the same defaults ``chain_reference`` uses, so kernels can
            be rendered to match a reference built with a custom eps.

    Returns:
        A two-element list of source strings: the "scalar" variant, then the "whole" one.
    """
    if eps is None:
        eps = 1e-6 if norm == "rms" else 1e-5
    expr = ACTS[act][1]
    return [_kernel(norm, residual, expr, eps, variant) for variant in ("scalar", "whole")]
def all_chains():
    """Enumerate the full grammar as ``[(name, kind, reference_fn, [kernel_src, ...]), ...]``.

    Iteration order is NORMS (outermost), then RESID, then ACTNAMES (innermost).
    """
    return [
        (
            chain_name(norm, residual, act),
            chain_kind(norm, residual),
            chain_reference(norm, residual, act),
            chain_structures(norm, residual, act),
        )
        for norm in NORMS
        for residual in RESID
        for act in ACTNAMES
    ]
|