Spaces:
Running on Zero
Running on Zero
| """Generative fusion-chain grammar: [+residual] -> {rms|layer}norm -> ×w(+b) -> epilogue. | |
| The widest sweep of the reduction->epilogue region (where the compiler under-fuses and the 2B | |
| wins). References are COMPOSED; teacher kernels are TEMPLATE-GENERATED (scalar-reduce + | |
| whole-row variants). Everything is harness-filtered downstream — nothing here is trusted. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
_C = 0.7978845608028654  # sqrt(2/pi), used by the tanh-form GELU reference


def _exp_minus_one_nonpos(v):
    """Return exp(min(v, 0)) - 1, the negative branch shared by ELU and SELU.

    The input is clamped to <= 0 before exp so the branch that torch.where
    discards for positive inputs cannot overflow.
    """
    return torch.exp(torch.clamp(v, max=0.0)) - 1.0


def _softplus_guarded(v):
    """Return log(1 + exp(v)), passing v through unchanged above 20.

    The input is clamped to <= 20 before exp so exp never overflows.
    """
    return torch.where(v > 20.0, v, torch.log(1.0 + torch.exp(torch.clamp(v, max=20.0))))


# Epilogue table. Each entry maps an activation name to a pair:
#   [0] a torch callable applied to an fp32 tensor (the reference), and
#   [1] the Triton expression, written in terms of the fp32 variable `n`.
# Both sides compute the same formula. Note that where the Triton side spells
# tanh as 2*sigmoid(2x)-1, the torch side calls torch.tanh directly, so the two
# are mathematically equal rather than bit-identical.
ACTS = {
    # --- original set --------------------------------------------------------
    "gelu": (
        lambda v: 0.5 * v * (1.0 + torch.tanh(_C * (v + 0.044715 * v * v * v))),
        "(0.5 * n * (1.0 + (2.0 * tl.sigmoid(2.0 * (0.7978845608028654 * (n + 0.044715 * n * n * n))) - 1.0)))",
    ),
    "silu": (lambda v: v * torch.sigmoid(v), "(n * tl.sigmoid(n))"),
    "relu2": (
        lambda v: torch.relu(v) * torch.relu(v),
        "(tl.maximum(n, 0.0) * tl.maximum(n, 0.0))",
    ),
    # --- expanded grammar ----------------------------------------------------
    "tanh": (lambda v: torch.tanh(v), "(2.0 * tl.sigmoid(2.0 * n) - 1.0)"),  # tanh(x) = 2σ(2x) - 1
    "sigmoid": (lambda v: torch.sigmoid(v), "tl.sigmoid(n)"),
    "relu": (lambda v: torch.relu(v), "tl.maximum(n, 0.0)"),
    "square": (lambda v: v * v, "(n * n)"),
    # --- round 2: more activations -------------------------------------------
    "abs": (lambda v: torch.abs(v), "tl.abs(n)"),
    "softsign": (lambda v: v / (1.0 + torch.abs(v)), "(n / (1.0 + tl.abs(n)))"),
    "hardsigmoid": (  # same formula as F.hardsigmoid
        lambda v: torch.clamp(v + 3.0, 0.0, 6.0) / 6.0,
        "(tl.minimum(tl.maximum(n + 3.0, 0.0), 6.0) / 6.0)",
    ),
    "hardswish": (  # same formula as F.hardswish
        lambda v: v * torch.clamp(v + 3.0, 0.0, 6.0) / 6.0,
        "(n * tl.minimum(tl.maximum(n + 3.0, 0.0), 6.0) / 6.0)",
    ),
    # --- round 3: eight more, all guarded against exp overflow ---------------
    "leaky_relu": (
        lambda v: torch.where(v > 0, v, 0.01 * v),
        "tl.where(n > 0.0, n, 0.01 * n)",
    ),
    "relu6": (
        lambda v: torch.clamp(v, 0.0, 6.0),
        "tl.minimum(tl.maximum(n, 0.0), 6.0)",
    ),
    "hardtanh": (
        lambda v: torch.clamp(v, -1.0, 1.0),
        "tl.minimum(tl.maximum(n, -1.0), 1.0)",
    ),
    "elu": (
        lambda v: torch.where(v > 0, v, _exp_minus_one_nonpos(v)),
        "tl.where(n > 0.0, n, tl.exp(tl.minimum(n, 0.0)) - 1.0)",
    ),
    "selu": (
        lambda v: 1.0507009873554805 * torch.where(
            v > 0, v, 1.6732632423543772 * (_exp_minus_one_nonpos(v))),
        "(1.0507009873554805 * tl.where(n > 0.0, n, "
        "1.6732632423543772 * (tl.exp(tl.minimum(n, 0.0)) - 1.0)))",
    ),
    "softplus": (
        lambda v: _softplus_guarded(v),
        "tl.where(n > 20.0, n, tl.log(1.0 + tl.exp(tl.minimum(n, 20.0))))",
    ),
    "mish": (
        lambda v: v * torch.tanh(_softplus_guarded(v)),
        "(n * (2.0 * tl.sigmoid(2.0 * tl.where(n > 20.0, n, "
        "tl.log(1.0 + tl.exp(tl.minimum(n, 20.0))))) - 1.0))",
    ),
    "gelu_erf": (  # erf-form (exact) GELU
        lambda v: 0.5 * v * (1.0 + torch.erf(v * 0.7071067811865476)),
        "(0.5 * n * (1.0 + tl.erf(n * 0.7071067811865476)))",
    ),
}
# Axes of the grammar swept by all_chains().
NORMS = ["rms", "layer"]   # normalisation flavour
RESID = [False, True]      # whether a residual tensor is added before the norm

# Epilogue names in sweep order; each must be a key of ACTS.
ACTNAMES = [
    # original set
    "gelu", "silu", "relu2",
    # expanded grammar
    "tanh", "sigmoid", "relu", "square",
    # round 2
    "abs", "softsign", "hardsigmoid", "hardswish",
    # round 3
    "leaky_relu", "relu6", "hardtanh", "elu", "selu", "softplus", "mish", "gelu_erf",
]
def chain_name(norm, residual, act):
    """Build the display name of a fusion chain.

    Args:
        norm: "rms" selects the "rmsnorm" stem; any other value selects "layernorm".
        residual: truthy to prefix the name with "add_".
        act: activation name appended after an underscore (must be a str).

    Returns:
        A string such as "add_rmsnorm_gelu" or "layernorm_silu".
    """
    prefix = "add_" if residual else ""
    stem = "rmsnorm" if norm == "rms" else "layernorm"
    return prefix + stem + "_" + act
def chain_kind(norm, residual):
    """Return the short tag identifying a chain's input signature.

    Args:
        norm: "rms" yields the "rms" tag; any other value yields "ln".
        residual: truthy to prefix the tag with "add_".

    Returns:
        One of "rms", "ln", "add_rms", "add_ln".
    """
    tag = "rms" if norm == "rms" else "ln"
    if residual:
        return "add_" + tag
    return "" + tag
def chain_reference(norm, residual, act, eps=None):
    """Compose the torch reference for one fusion chain.

    Args:
        norm: "rms" for RMS normalisation; any other value takes the
            layer-norm path (mean-centred, with bias).
        residual: when truthy the reference expects a residual tensor as its
            second positional argument and adds it to the input first.
        act: key into ACTS; element [0] of the entry is used as the epilogue.
        eps: value added under the rsqrt. Defaults to 1e-6 for "rms" and 1e-5
            otherwise.

    Returns:
        A callable ``ref(*args)``. Positional arguments are, in order:
        ``x``, then the residual (only if ``residual``), then the weight, then
        the bias (only on the layer-norm path). All math runs in float32 and
        the result is cast back to the dtype of ``x``.
    """
    if eps is None:
        eps = 1e-6 if norm == "rms" else 1e-5
    epilogue = ACTS[act][0]
    use_rms = norm == "rms"

    def ref(*args):
        # Peel off the input (and the residual, when present); what remains
        # is the affine parameters.
        if residual:
            x, r, *affine = args
            h = x.float() + r.float()
        else:
            x, *affine = args
            h = x.float()

        if use_rms:
            (w,) = affine
            inv = torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
            n = h * inv * w.float()
        else:
            w, b = affine
            mu = h.mean(-1, keepdim=True)
            hc = h - mu
            inv = torch.rsqrt((hc * hc).mean(-1, keepdim=True) + eps)
            n = hc * inv * w.float() + b.float()

        return epilogue(n).to(args[0].dtype)

    return ref
| # ---- teacher-kernel templates ----------------------------------------------------------- | |
| def _kernel(norm, residual, act_expr, eps, variant): | |
| """variant: 'scalar' (loop+scalar accumulator) or 'whole' (single block per row).""" | |
| ptrs = "x_ptr, " + ("r_ptr, " if residual else "") + "w_ptr, " + ("b_ptr, " if norm == "layer" else "") + "y_ptr" | |
| sig = "x, " + ("residual, " if residual else "") + "w" + (", b" if norm == "layer" else "") | |
| launch = "x, " + ("residual, " if residual else "") + "w" + (", b" if norm == "layer" else "") + ", y" | |
| radv = " r_ptr += row * stride;" if residual else "" | |
| hload = ("tl.load(x_ptr + cols, mask=MM, other=0.0).to(tl.float32)" | |
| + (" + tl.load(r_ptr + cols, mask=MM, other=0.0).to(tl.float32)" if residual else "")) | |
| # bias load indent differs: scalar variant loads it INSIDE the apply for-loop (8 spaces), | |
| # whole-row loads it flat (4 spaces). Wrong indent -> IndentationError. | |
| bload8 = " b = tl.load(b_ptr + cols, mask=MM, other=0.0).to(tl.float32)\n" if norm == "layer" else "" | |
| bload4 = " b = tl.load(b_ptr + cols, mask=MM, other=0.0).to(tl.float32)\n" if norm == "layer" else "" | |
| if norm == "rms": | |
| normed = "h * rr * w" | |
| else: | |
| normed = "(h - mu) * rr * w + b" | |
| if variant == "scalar": | |
| if norm == "rms": | |
| reduce_block = f''' s = 0.0 | |
| for off in range(0, N, BLOCK): | |
| cols = off + tl.arange(0, BLOCK); MM = cols < N | |
| h = {hload} | |
| s += tl.sum(h * h) | |
| rr = tl.rsqrt(s / N + eps)''' | |
| else: | |
| reduce_block = f''' s = 0.0 | |
| for off in range(0, N, BLOCK): | |
| cols = off + tl.arange(0, BLOCK); MM = cols < N | |
| s += tl.sum({hload}) | |
| mu = s / N | |
| v = 0.0 | |
| for off in range(0, N, BLOCK): | |
| cols = off + tl.arange(0, BLOCK); MM = cols < N | |
| d = tl.where(MM, ({hload}) - mu, 0.0); v += tl.sum(d * d) | |
| rr = tl.rsqrt(v / N + eps)''' | |
| body = f'''@triton.jit | |
| def _k({ptrs}, stride, N, eps, BLOCK: tl.constexpr): | |
| row = tl.program_id(0); x_ptr += row * stride;{radv} y_ptr += row * stride | |
| {reduce_block} | |
| for off in range(0, N, BLOCK): | |
| cols = off + tl.arange(0, BLOCK); MM = cols < N | |
| h = {hload} | |
| w = tl.load(w_ptr + cols, mask=MM, other=0.0).to(tl.float32) | |
| {bload8} n = {normed} | |
| tl.store(y_ptr + cols, {act_expr}, mask=MM) | |
| def run({sig}): | |
| M, N = x.shape; y = torch.empty_like(x) | |
| _k[(M,)]({launch}, x.stride(0), N, {eps}, BLOCK=1024) | |
| return y | |
| ''' | |
| else: # whole-row single block | |
| if norm == "rms": | |
| stat = " rr = tl.rsqrt(tl.sum(h * h) / N + eps)" | |
| else: | |
| stat = (" mu = tl.sum(h) / N\n hc = tl.where(MM, h - mu, 0.0)\n" | |
| " rr = tl.rsqrt(tl.sum(hc * hc) / N + eps)") | |
| normed = "hc * rr * w + b" | |
| body = f'''@triton.jit | |
| def _k({ptrs}, stride, N, eps, BLOCK: tl.constexpr): | |
| row = tl.program_id(0); x_ptr += row * stride;{radv} y_ptr += row * stride | |
| cols = tl.arange(0, BLOCK); MM = cols < N | |
| h = {hload} | |
| {stat} | |
| w = tl.load(w_ptr + cols, mask=MM, other=0.0).to(tl.float32) | |
| {bload4} n = {normed} | |
| tl.store(y_ptr + cols, {act_expr}, mask=MM) | |
| def run({sig}): | |
| M, N = x.shape; y = torch.empty_like(x) | |
| _k[(M,)]({launch}, x.stride(0), N, {eps}, BLOCK=triton.next_power_of_2(N)) | |
| return y | |
| ''' | |
| return body | |
def chain_structures(norm, residual, act, eps=None):
    """Generate the teacher-kernel sources for one fusion chain.

    Args:
        norm: normalisation flavour, forwarded to ``_kernel``.
        residual: whether the chain adds a residual, forwarded to ``_kernel``.
        act: key into ACTS; element [1] of the entry (the Triton expression)
            is used as the epilogue.
        eps: epsilon baked into the generated launchers. Defaults to 1e-6 for
            "rms" and 1e-5 otherwise — the same defaults ``chain_reference``
            applies — so a reference built with a custom eps can be paired
            with kernels using the same value.

    Returns:
        A two-element list of source strings: the "scalar" variant followed
        by the "whole" variant.
    """
    if eps is None:
        eps = 1e-6 if norm == "rms" else 1e-5
    expr = ACTS[act][1]
    return [_kernel(norm, residual, expr, eps, variant) for variant in ("scalar", "whole")]
def all_chains():
    """Enumerate the full grammar.

    Iterates NORMS (outermost), then RESID, then ACTNAMES (innermost).

    Returns:
        A list of ``(name, kind, reference_fn, [kernel_src, ...])`` tuples,
        one per (norm, residual, activation) combination, built with
        ``chain_name``, ``chain_kind``, ``chain_reference`` and
        ``chain_structures`` respectively.
    """
    return [
        (
            chain_name(norm, residual, act),
            chain_kind(norm, residual),
            chain_reference(norm, residual, act),
            chain_structures(norm, residual, act),
        )
        for norm in NORMS
        for residual in RESID
        for act in ACTNAMES
    ]