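# Page header captured along with this file (not part of the module), kept as comments:
# YMRohit's picture
# Add in-Space offgrid/local mode (ZeroGPU H200 + in-process referee)
# 2f6d104 verified
# Raw
# History Blame Contribute Delete
# 9.04 kB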
"""Generative fusion-chain grammar: [+residual] -> {rms|layer}norm -> ×w(+b) -> epilogue.
The widest sweep of the reduction->epilogue region (where the compiler under-fuses and the 2B
wins). References are COMPOSED; teacher kernels are TEMPLATE-GENERATED (scalar-reduce +
whole-row variants). Everything is harness-filtered downstream — nothing here is trusted.
"""
from __future__ import annotations
import torch
_C = 0.7978845608028654  # sqrt(2/pi), used by the tanh-approximate gelu reference below
# Epilogue activations: name -> (torch fn on an fp32 tensor, Triton expression in the fp32
# variable `n`). Element [0] is the reference used by chain_reference; element [1] is spliced
# into the generated teacher kernels. Both halves of an entry must compute the same function.
ACTS = {
    # tanh-approximate gelu; the Triton side writes tanh(z) as 2*sigmoid(2z) - 1.
    "gelu": (lambda t: 0.5 * t * (1.0 + torch.tanh(_C * (t + 0.044715 * t * t * t))),
             "(0.5 * n * (1.0 + (2.0 * tl.sigmoid(2.0 * (0.7978845608028654 * (n + 0.044715 * n * n * n))) - 1.0)))"),
    "silu": (lambda t: t * torch.sigmoid(t), "(n * tl.sigmoid(n))"),
    "relu2": (lambda t: torch.relu(t) * torch.relu(t), "(tl.maximum(n, 0.0) * tl.maximum(n, 0.0))"),
    # --- expanded grammar (each torch fn EXACTLY matches its triton expr; no approximation) ---
    "tanh": (lambda t: torch.tanh(t), "(2.0 * tl.sigmoid(2.0 * n) - 1.0)"),  # identity tanh(x)=2σ(2x)-1
    "sigmoid": (lambda t: torch.sigmoid(t), "tl.sigmoid(n)"),
    "relu": (lambda t: torch.relu(t), "tl.maximum(n, 0.0)"),
    "square": (lambda t: t * t, "(n * n)"),
    # --- 2c round 2: more real activations (each torch fn EXACTLY matches its triton expr) ---
    "abs": (lambda t: torch.abs(t), "tl.abs(n)"),
    "softsign": (lambda t: t / (1.0 + torch.abs(t)), "(n / (1.0 + tl.abs(n)))"),
    "hardsigmoid": (lambda t: torch.clamp(t + 3.0, 0.0, 6.0) / 6.0,
                    "(tl.minimum(tl.maximum(n + 3.0, 0.0), 6.0) / 6.0)"),  # F.hardsigmoid
    "hardswish": (lambda t: t * torch.clamp(t + 3.0, 0.0, 6.0) / 6.0,
                  "(n * tl.minimum(tl.maximum(n + 3.0, 0.0), 6.0) / 6.0)"),  # F.hardswish
    # --- V2 round 3: 8 more real, numerically-safe activations. Same exactness rule: the
    # torch lambda IS the triton expression (tanh via the 2*sigmoid(2x)-1 identity;
    # softplus uses F.softplus's threshold=20 guard so exp never overflows). -----------
    "leaky_relu": (lambda t: torch.where(t > 0, t, 0.01 * t),
                   "tl.where(n > 0.0, n, 0.01 * n)"),
    "relu6": (lambda t: torch.clamp(t, 0.0, 6.0),
              "tl.minimum(tl.maximum(n, 0.0), 6.0)"),
    "hardtanh": (lambda t: torch.clamp(t, -1.0, 1.0),
                 "tl.minimum(tl.maximum(n, -1.0), 1.0)"),
    # exp's argument is clamped to <= 0 so the unused branch of the where cannot overflow.
    "elu": (lambda t: torch.where(t > 0, t, torch.exp(torch.clamp(t, max=0.0)) - 1.0),
            "tl.where(n > 0.0, n, tl.exp(tl.minimum(n, 0.0)) - 1.0)"),
    "selu": (lambda t: 1.0507009873554805 * torch.where(
                 t > 0, t, 1.6732632423543772 * (torch.exp(torch.clamp(t, max=0.0)) - 1.0)),
             "(1.0507009873554805 * tl.where(n > 0.0, n, "
             "1.6732632423543772 * (tl.exp(tl.minimum(n, 0.0)) - 1.0)))"),
    "softplus": (lambda t: torch.where(t > 20.0, t, torch.log(1.0 + torch.exp(torch.clamp(t, max=20.0)))),
                 "tl.where(n > 20.0, n, tl.log(1.0 + tl.exp(tl.minimum(n, 20.0))))"),
    # mish(x) = x * tanh(softplus(x)), with the same threshold-20 softplus as above.
    "mish": (lambda t: t * torch.tanh(torch.where(
                 t > 20.0, t, torch.log(1.0 + torch.exp(torch.clamp(t, max=20.0))))),
             "(n * (2.0 * tl.sigmoid(2.0 * tl.where(n > 20.0, n, "
             "tl.log(1.0 + tl.exp(tl.minimum(n, 20.0))))) - 1.0))"),
    "gelu_erf": (lambda t: 0.5 * t * (1.0 + torch.erf(t * 0.7071067811865476)),
                 "(0.5 * n * (1.0 + tl.erf(n * 0.7071067811865476)))"),  # EXACT gelu
}
NORMS = ["rms", "layer"]
RESID = [False, True]
# Derived from ACTS (dicts preserve insertion order), so an activation added to ACTS can no
# longer be silently left out of the sweep. The order equals the previous hand-written list.
ACTNAMES = list(ACTS)
def chain_name(norm, residual, act):
    """Return the human-readable chain name, e.g. ``"add_rmsnorm_gelu"``.

    ``"add_"`` is prefixed when ``residual`` is truthy; ``norm == "rms"`` maps to
    ``"rmsnorm"`` and any other value to ``"layernorm"``; ``act`` is appended after ``"_"``.
    """
    prefix = "add_" if residual else ""
    norm_part = "rmsnorm" if norm == "rms" else "layernorm"
    return "".join((prefix, norm_part, "_", act))
def chain_kind(norm, residual):
    """Return the input-signature tag for a chain: ``rms``, ``ln``, ``add_rms`` or ``add_ln``.

    Any ``norm`` other than ``"rms"`` is tagged ``"ln"``; a truthy ``residual`` adds
    the ``"add_"`` prefix.
    """
    base = "rms" if norm == "rms" else "ln"
    if residual:
        return "add_" + base
    return base
def chain_reference(norm, residual, act, eps=None):
    """Compose the fp32 PyTorch reference for one fusion chain.

    The returned callable takes positional tensors ``(x, [residual,] w[, b])``, where ``b``
    is present only for LayerNorm (any ``norm`` other than ``"rms"``). It upcasts to fp32
    and applies ``[x + residual] -> {RMS|Layer}Norm over the last dim -> * w (+ b) ->
    ACTS[act][0]``, then casts the result to the dtype of the first argument.

    ``eps`` defaults to 1e-6 for RMSNorm and 1e-5 for LayerNorm. An unknown ``act``
    raises KeyError here, when the reference is built.
    """
    if eps is None:
        eps = 1e-6 if norm == "rms" else 1e-5
    activation = ACTS[act][0]
    is_rms = norm == "rms"

    def ref(*args):
        # Unpack by arity; a wrong argument count raises ValueError from the unpacking.
        if residual:
            if is_rms:
                x, r, w = args
            else:
                x, r, w, b = args
            h = x.float() + r.float()
        else:
            if is_rms:
                x, w = args
            else:
                x, w, b = args
            h = x.float()

        if is_rms:
            scale = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
            n = h * scale * w.float()
        else:
            centered = h - h.mean(-1, keepdim=True)
            # Biased (population) variance of the centered row.
            scale = torch.rsqrt((centered * centered).mean(-1, keepdim=True) + eps)
            n = centered * scale * w.float() + b.float()
        return activation(n).to(args[0].dtype)

    return ref
# ---- teacher-kernel templates -----------------------------------------------------------
def _kernel(norm, residual, act_expr, eps, variant):
"""variant: 'scalar' (loop+scalar accumulator) or 'whole' (single block per row)."""
ptrs = "x_ptr, " + ("r_ptr, " if residual else "") + "w_ptr, " + ("b_ptr, " if norm == "layer" else "") + "y_ptr"
sig = "x, " + ("residual, " if residual else "") + "w" + (", b" if norm == "layer" else "")
launch = "x, " + ("residual, " if residual else "") + "w" + (", b" if norm == "layer" else "") + ", y"
radv = " r_ptr += row * stride;" if residual else ""
hload = ("tl.load(x_ptr + cols, mask=MM, other=0.0).to(tl.float32)"
+ (" + tl.load(r_ptr + cols, mask=MM, other=0.0).to(tl.float32)" if residual else ""))
# bias load indent differs: scalar variant loads it INSIDE the apply for-loop (8 spaces),
# whole-row loads it flat (4 spaces). Wrong indent -> IndentationError.
bload8 = " b = tl.load(b_ptr + cols, mask=MM, other=0.0).to(tl.float32)\n" if norm == "layer" else ""
bload4 = " b = tl.load(b_ptr + cols, mask=MM, other=0.0).to(tl.float32)\n" if norm == "layer" else ""
if norm == "rms":
normed = "h * rr * w"
else:
normed = "(h - mu) * rr * w + b"
if variant == "scalar":
if norm == "rms":
reduce_block = f''' s = 0.0
for off in range(0, N, BLOCK):
cols = off + tl.arange(0, BLOCK); MM = cols < N
h = {hload}
s += tl.sum(h * h)
rr = tl.rsqrt(s / N + eps)'''
else:
reduce_block = f''' s = 0.0
for off in range(0, N, BLOCK):
cols = off + tl.arange(0, BLOCK); MM = cols < N
s += tl.sum({hload})
mu = s / N
v = 0.0
for off in range(0, N, BLOCK):
cols = off + tl.arange(0, BLOCK); MM = cols < N
d = tl.where(MM, ({hload}) - mu, 0.0); v += tl.sum(d * d)
rr = tl.rsqrt(v / N + eps)'''
body = f'''@triton.jit
def _k({ptrs}, stride, N, eps, BLOCK: tl.constexpr):
row = tl.program_id(0); x_ptr += row * stride;{radv} y_ptr += row * stride
{reduce_block}
for off in range(0, N, BLOCK):
cols = off + tl.arange(0, BLOCK); MM = cols < N
h = {hload}
w = tl.load(w_ptr + cols, mask=MM, other=0.0).to(tl.float32)
{bload8} n = {normed}
tl.store(y_ptr + cols, {act_expr}, mask=MM)
def run({sig}):
M, N = x.shape; y = torch.empty_like(x)
_k[(M,)]({launch}, x.stride(0), N, {eps}, BLOCK=1024)
return y
'''
else: # whole-row single block
if norm == "rms":
stat = " rr = tl.rsqrt(tl.sum(h * h) / N + eps)"
else:
stat = (" mu = tl.sum(h) / N\n hc = tl.where(MM, h - mu, 0.0)\n"
" rr = tl.rsqrt(tl.sum(hc * hc) / N + eps)")
normed = "hc * rr * w + b"
body = f'''@triton.jit
def _k({ptrs}, stride, N, eps, BLOCK: tl.constexpr):
row = tl.program_id(0); x_ptr += row * stride;{radv} y_ptr += row * stride
cols = tl.arange(0, BLOCK); MM = cols < N
h = {hload}
{stat}
w = tl.load(w_ptr + cols, mask=MM, other=0.0).to(tl.float32)
{bload4} n = {normed}
tl.store(y_ptr + cols, {act_expr}, mask=MM)
def run({sig}):
M, N = x.shape; y = torch.empty_like(x)
_k[(M,)]({launch}, x.stride(0), N, {eps}, BLOCK=triton.next_power_of_2(N))
return y
'''
return body
def chain_structures(norm, residual, act, eps=None):
    """Return the teacher-kernel sources for one chain: ``[scalar variant, whole-row variant]``.

    ``eps`` defaults exactly as in ``chain_reference`` (1e-6 for ``norm == "rms"``, 1e-5
    otherwise). Passing the same ``eps`` to both keeps the kernels and the reference in
    agreement. Previously the value was hard-coded here. An unknown ``act`` raises KeyError.
    """
    if eps is None:
        eps = 1e-6 if norm == "rms" else 1e-5
    expr = ACTS[act][1]
    return [_kernel(norm, residual, expr, eps, variant) for variant in ("scalar", "whole")]
def all_chains():
    """Enumerate the full grammar as ``[(name, kind, reference_fn, [kernel_src, ...]), ...]``.

    The iteration order is NORMS, then RESID, then ACTNAMES (innermost).
    """
    return [
        (
            chain_name(norm, residual, act),
            chain_kind(norm, residual),
            chain_reference(norm, residual, act),
            chain_structures(norm, residual, act),
        )
        for norm in NORMS
        for residual in RESID
        for act in ACTNAMES
    ]