Instructions to use Motif-Technologies/activation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/activation with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("Motif-Technologies/activation")
- Notebooks
- Google Colab
- Kaggle
fix: rename stale references and clean up Triton remnants
Browse files
- Fix broken imports in benchmarks (grouped_fused_mul_poly_norm → fused_mul_grouped_poly_norm)
- Rename GroupedTritonModule → GroupedCUDAModule in benchmarks
- Rename _run_triton → _run_cuda in tests, update docstrings
- Remove stale Triton comments in tests
- Fix return type annotations in __init__.py (None → torch.Tensor)
- Update README with new function name
- Remove TRITON_PRINT_AUTOTUNING env var from profile_bwd.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- README.md +1 -1
- benchmarks/cases/grouped_mul_poly.py +21 -8
- benchmarks/profile_bwd.py +146 -0
- tests/test_fused_mul_grouped_poly_norm.py +9 -11
- torch-ext/activation/__init__.py +4 -4
README.md
CHANGED
|
@@ -58,7 +58,7 @@ Activation is a python package that contains custom CUDA-based activation kernel
|
|
| 58 |
- Fused as:
|
| 59 |
|
| 60 |
```python
|
| 61 |
-
out =
|
| 62 |
```
|
| 63 |
|
| 64 |
## Usage
|
|
|
|
| 58 |
- Fused as:
|
| 59 |
|
| 60 |
```python
|
| 61 |
+
out = fused_mul_grouped_poly_norm(x, mul, weight, bias, offsets, eps)
|
| 62 |
```
|
| 63 |
|
| 64 |
## Usage
|
benchmarks/cases/grouped_mul_poly.py
CHANGED
|
@@ -1,15 +1,19 @@
|
|
| 1 |
import torch
|
| 2 |
import torch._functorch.config
|
|
|
|
| 3 |
from common.diff_engine import DiffCase
|
| 4 |
|
| 5 |
torch._functorch.config.donated_buffer = False
|
| 6 |
|
| 7 |
from grouped_poly_norm import (
|
| 8 |
-
|
| 9 |
-
|
| 10 |
)
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class GroupedRefModule(torch.nn.Module):
|
|
@@ -24,13 +28,13 @@ class GroupedRefModule(torch.nn.Module):
|
|
| 24 |
self.expert_offset = expert_offset
|
| 25 |
|
| 26 |
def forward(self, x, mul):
|
| 27 |
-
return
|
| 28 |
self.offsets, self.eps,
|
| 29 |
expert_offset=self.expert_offset)
|
| 30 |
|
| 31 |
|
| 32 |
-
class
|
| 33 |
-
"""Wraps the
|
| 34 |
|
| 35 |
def __init__(self, weight, bias, offsets, eps, expert_offset=0):
|
| 36 |
super().__init__()
|
|
@@ -41,7 +45,7 @@ class GroupedTritonModule(torch.nn.Module):
|
|
| 41 |
self.expert_offset = expert_offset
|
| 42 |
|
| 43 |
def forward(self, x, mul):
|
| 44 |
-
return
|
| 45 |
self.offsets, self.eps,
|
| 46 |
expert_offset=self.expert_offset)
|
| 47 |
|
|
@@ -105,13 +109,22 @@ class GroupedMulPoly(DiffCase):
|
|
| 105 |
return torch.compile(m)
|
| 106 |
|
| 107 |
def make_cuda(self, I):
|
| 108 |
-
return
|
| 109 |
I["weight"].detach().clone(),
|
| 110 |
I["bias"].detach().clone(),
|
| 111 |
I["offsets"],
|
| 112 |
I["eps"],
|
| 113 |
)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def forward(self, obj, I):
|
| 116 |
return obj(I["x"], I["mul"])
|
| 117 |
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch._functorch.config
|
| 3 |
+
import torch._inductor
|
| 4 |
from common.diff_engine import DiffCase
|
| 5 |
|
| 6 |
torch._functorch.config.donated_buffer = False
|
| 7 |
|
| 8 |
from grouped_poly_norm import (
|
| 9 |
+
fused_mul_grouped_poly_norm,
|
| 10 |
+
fused_mul_grouped_poly_norm_ref,
|
| 11 |
)
|
| 12 |
|
| 13 |
+
# 384 / 8 (EP) = 48 experts per rank
|
| 14 |
+
# total_tokens = bs * sl, which equals per-rank tokens
|
| 15 |
+
# since top_k=8 and EP=8, each rank sees all tokens once
|
| 16 |
+
NUM_EXPERTS = 48
|
| 17 |
|
| 18 |
|
| 19 |
class GroupedRefModule(torch.nn.Module):
|
|
|
|
| 28 |
self.expert_offset = expert_offset
|
| 29 |
|
| 30 |
def forward(self, x, mul):
|
| 31 |
+
return fused_mul_grouped_poly_norm_ref(x, mul, self.weight, self.bias,
|
| 32 |
self.offsets, self.eps,
|
| 33 |
expert_offset=self.expert_offset)
|
| 34 |
|
| 35 |
|
| 36 |
+
class GroupedCUDAModule(torch.nn.Module):
|
| 37 |
+
"""Wraps the CUDA kernel for grouped FusedMulPolyNorm."""
|
| 38 |
|
| 39 |
def __init__(self, weight, bias, offsets, eps, expert_offset=0):
|
| 40 |
super().__init__()
|
|
|
|
| 45 |
self.expert_offset = expert_offset
|
| 46 |
|
| 47 |
def forward(self, x, mul):
|
| 48 |
+
return fused_mul_grouped_poly_norm(x, mul, self.weight, self.bias,
|
| 49 |
self.offsets, self.eps,
|
| 50 |
expert_offset=self.expert_offset)
|
| 51 |
|
|
|
|
| 109 |
return torch.compile(m)
|
| 110 |
|
| 111 |
def make_cuda(self, I):
|
| 112 |
+
return GroupedCUDAModule(
|
| 113 |
I["weight"].detach().clone(),
|
| 114 |
I["bias"].detach().clone(),
|
| 115 |
I["offsets"],
|
| 116 |
I["eps"],
|
| 117 |
)
|
| 118 |
|
| 119 |
+
def make_compiled_cuda(self, I):
|
| 120 |
+
m = GroupedCUDAModule(
|
| 121 |
+
I["weight"].detach().clone(),
|
| 122 |
+
I["bias"].detach().clone(),
|
| 123 |
+
I["offsets"],
|
| 124 |
+
I["eps"],
|
| 125 |
+
)
|
| 126 |
+
return torch.compile(m)
|
| 127 |
+
|
| 128 |
def forward(self, obj, I):
|
| 129 |
return obj(I["x"], I["mul"])
|
| 130 |
|
benchmarks/profile_bwd.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Profiling script for grouped polynorm backward kernel using torch.profiler."""
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import torch.cuda
|
| 5 |
+
from torch.profiler import profile, ProfilerActivity
|
| 6 |
+
from grouped_poly_norm import fused_mul_grouped_poly_norm
|
| 7 |
+
|
| 8 |
+
torch.set_default_device("cuda")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_inputs(N, D, num_experts):
    """Build seeded random inputs for the grouped poly-norm profiling run.

    Each of the ``N`` tokens is assigned to an expert by sampling uniformly
    with replacement; the per-expert token counts are then accumulated into
    cumulative end offsets.

    Args:
        N: Number of tokens (rows of ``x`` and ``m``).
        D: Hidden dimension (columns of ``x`` and ``m``).
        num_experts: Number of experts; sets the length of ``offsets`` and
            the leading dimension of ``w`` and ``b``.

    Returns:
        Tuple ``(x, m, w, b, offsets)``. ``x`` and ``m`` are ``(N, D)``
        bfloat16, ``w`` is ``(num_experts, 3)`` bfloat16 and ``b`` is
        ``(num_experts, 1)`` bfloat16; all four are leaf tensors with
        ``requires_grad=True``. ``offsets`` holds the cumulative token
        count per expert, so its last entry equals ``N``.
    """
    # Fixed seed: every call returns the same tensors for the same arguments.
    torch.manual_seed(42)
    probs = torch.ones(num_experts) / num_experts
    assignments = torch.multinomial(probs, N, replacement=True)
    counts = torch.bincount(assignments, minlength=num_experts).tolist()
    offsets = torch.cumsum(torch.tensor(counts, dtype=torch.int32), dim=0)

    # Scale first, then mark as requiring grad, so x and m are leaf tensors
    # just like w and b below. Passing requires_grad=True to randn and then
    # multiplying would hand back a non-leaf product instead.
    x = (torch.randn(N, D, dtype=torch.bfloat16) * 0.5).requires_grad_(True)
    m = (torch.randn(N, D, dtype=torch.bfloat16) * 0.5).requires_grad_(True)
    w = (torch.ones(num_experts, 3, dtype=torch.bfloat16) / 3
         ).requires_grad_(True)
    b = (torch.randn(num_experts, 1, dtype=torch.bfloat16) * 0.01
         ).requires_grad_(True)
    return x, m, w, b, offsets
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _inspect_triton_cache(cache_dir):
    """Print register/spill info found in the kernel cache under *cache_dir*.

    Reads the ten most recently modified ``*.json`` metadata files and runs
    ``cuobjdump -res-usage`` on the five most recently modified ``*.cubin``
    files. Unreadable files are skipped; the cubin loop stops at the first
    ``cuobjdump`` failure.
    """
    # Function-scope imports: only this optional inspection step needs them.
    # `os` was previously missing, so every call here died with a NameError.
    import glob
    import json
    import os
    import subprocess

    # Find metadata JSON files
    json_files = sorted(glob.glob(f"{cache_dir}/**/*.json", recursive=True),
                        key=os.path.getmtime, reverse=True)
    print(f"\nFound {len(json_files)} compiled kernel metadata files")
    for jf in json_files[:10]:
        try:
            with open(jf) as f:
                meta = json.load(f)
        except (OSError, ValueError):
            # Best effort: an unreadable or malformed cache entry is skipped.
            continue
        if not isinstance(meta, dict):
            continue
        n_regs = meta.get('num_regs', meta.get('n_regs', None))
        n_spills = meta.get('num_spills', meta.get('n_spills', None))
        name = meta.get('name', os.path.basename(jf))
        shared = meta.get('shared', None)
        if n_regs is not None:
            print(f" {name}: regs={n_regs}, spills={n_spills}, shared={shared}")

    # Also try cuobjdump on recent cubins
    cubin_files = sorted(glob.glob(f"{cache_dir}/**/*.cubin", recursive=True),
                         key=os.path.getmtime, reverse=True)
    print(f"\nFound {len(cubin_files)} cubins, inspecting latest:")
    for cb in cubin_files[:5]:
        try:
            result = subprocess.run(
                ["cuobjdump", "-res-usage", cb],
                capture_output=True, text=True, timeout=5)
        except (OSError, subprocess.SubprocessError) as e:
            # cuobjdump missing or hung: no point trying the remaining cubins.
            print(f" cuobjdump failed: {e}")
            break
        if result.returncode == 0 and result.stdout.strip():
            print(f"\n {os.path.basename(cb)}:")
            for line in result.stdout.strip().split('\n'):
                print(f" {line}")


def _print_theoretical_occupancy(props, threads_per_block=128):
    """Print blocks/SM and occupancy for a range of registers-per-thread."""
    print("\n=== Theoretical Occupancy (num_warps=4, 128 threads/block) ===")
    max_threads = props.max_threads_per_multi_processor
    total_regs = props.regs_per_multiprocessor
    max_blocks_by_threads = max_threads // threads_per_block
    for n_regs in [64, 96, 128, 160, 192, 224, 256]:
        regs_per_block = n_regs * threads_per_block
        max_blocks_by_regs = total_regs // regs_per_block
        # 32 is the hard-coded cap on resident blocks per SM.
        blocks = min(max_blocks_by_regs, max_blocks_by_threads, 32)
        active_threads = blocks * threads_per_block
        occupancy = active_threads / max_threads * 100
        print(f" {n_regs:3d} regs/thread -> {blocks:2d} blocks/SM -> "
              f"{active_threads:4d} threads -> {occupancy:.1f}% occupancy")


def main():
    """Profile the backward pass of ``fused_mul_grouped_poly_norm``.

    Command-line options: ``--tokens``, ``--dim``, ``--experts`` and
    ``--output`` (path prefix for the exported chrome trace).

    Steps: warm up forward+backward, run one forward, profile 100 backward
    passes over the retained graph with ``torch.profiler``, print the kernel
    table, export a chrome trace, then print device properties, kernel-cache
    register info and a theoretical occupancy table.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--tokens", type=int, default=4096)
    parser.add_argument("--dim", type=int, default=1280)
    parser.add_argument("--experts", type=int, default=48)
    parser.add_argument("--output", type=str, default="/tmp/profile")
    args = parser.parse_args()

    N, D, num_experts = args.tokens, args.dim, args.experts

    # Warmup (fresh inputs each time to avoid graph reuse issues)
    for _ in range(3):
        x, m, w, b, offsets = make_inputs(N, D, num_experts)
        out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
        out.sum().backward()
    torch.cuda.synchronize()

    # Profiled: mimic do_bench — forward once, backward multiple times
    # with retain_graph
    x, m, w, b, offsets = make_inputs(N, D, num_experts)
    out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
    gin = [x, m, w, b]
    g = [torch.randn_like(out)]

    # Warmup backward
    for _ in range(5):
        torch.autograd.grad(out, gin, g, retain_graph=True, allow_unused=True)
    torch.cuda.synchronize()

    with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True,
    ) as prof:
        for _ in range(100):
            torch.autograd.grad(out, gin, g, retain_graph=True,
                                allow_unused=True)
        # Synchronize inside the profiled region so all queued GPU work
        # finishes before the profiler stops.
        torch.cuda.synchronize()

    # Print kernel-level stats
    print(f"\n=== Kernel Table (N={N}, D={D}) ===")
    print(prof.key_averages().table(
        sort_by="cuda_time_total", row_limit=20))

    # Export chrome trace
    trace_path = f"{args.output}_trace_N{N}.json"
    prof.export_chrome_trace(trace_path)
    print(f"\nTrace exported to {trace_path}")

    # === Occupancy analysis from Triton kernel metadata ===
    # NOTE(review): this script profiles what the benchmarks call the CUDA
    # kernel; the Triton cache may only hold unrelated kernels — confirm
    # this section is still wanted.
    print(f"\n=== Occupancy Analysis ===")

    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}")
    print(f"SMs: {props.multi_processor_count}")
    print(f"Max threads/SM: {props.max_threads_per_multi_processor}")
    print(f"Regs/SM: {props.regs_per_multiprocessor}")
    print(f"Shared mem/block: {props.shared_memory_per_block} bytes")

    # Get register info from Triton compiled cubins. This is optional
    # diagnostics, so any failure is reported and the script carries on.
    try:
        _inspect_triton_cache(os.path.expanduser("~/.triton/cache"))
    except Exception as e:
        print(f"Cache inspection error: {e}")

    # Calculate theoretical occupancy for different register counts
    _print_theoretical_occupancy(props)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
main()
|
tests/test_fused_mul_grouped_poly_norm.py
CHANGED
|
@@ -17,8 +17,6 @@ D = [256, 1280]
|
|
| 17 |
NUM_EXPERTS_LIST = [8, 384]
|
| 18 |
EXPERT_OFFSETS = [0, 4]
|
| 19 |
SEEDS = [0]
|
| 20 |
-
# Triton kernels launch on the current CUDA device and do not
|
| 21 |
-
# auto-dispatch to the tensor's device like CUDA extensions.
|
| 22 |
# Only test on cuda:0 to avoid cross-device issues.
|
| 23 |
CUDA_DEVICES = ["cuda:0"]
|
| 24 |
|
|
@@ -77,9 +75,9 @@ def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0,
|
|
| 77 |
return grads + (s.grad,) if s is not None else grads + (None,)
|
| 78 |
|
| 79 |
|
| 80 |
-
def
|
| 81 |
scores=None, hidden_clamp=None):
|
| 82 |
-
"""Run
|
| 83 |
inp = input_t.clone().detach().requires_grad_(True)
|
| 84 |
m = mul_t.clone().detach().requires_grad_(True)
|
| 85 |
w = weight.clone().detach().requires_grad_(True)
|
|
@@ -112,7 +110,7 @@ def test_fused_mul_grouped_poly_norm_forward(
|
|
| 112 |
seed: int,
|
| 113 |
device: str,
|
| 114 |
) -> None:
|
| 115 |
-
"""
|
| 116 |
torch.set_default_device(device)
|
| 117 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 118 |
num_tokens, d, num_experts, dtype, device, seed,
|
|
@@ -151,7 +149,7 @@ def test_fused_mul_grouped_poly_norm_backward(
|
|
| 151 |
seed: int,
|
| 152 |
device: str,
|
| 153 |
) -> None:
|
| 154 |
-
"""
|
| 155 |
torch.set_default_device(device)
|
| 156 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 157 |
num_tokens, d, num_experts, dtype, device, seed,
|
|
@@ -159,7 +157,7 @@ def test_fused_mul_grouped_poly_norm_backward(
|
|
| 159 |
|
| 160 |
_, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
|
| 161 |
input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
|
| 162 |
-
_, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri, _ =
|
| 163 |
input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
|
| 164 |
|
| 165 |
if dtype == torch.float32:
|
|
@@ -213,7 +211,7 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
|
|
| 213 |
_, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t, mul_t, weight, bias,
|
| 214 |
offsets,
|
| 215 |
expert_offset=expert_offset)
|
| 216 |
-
_, _, _, w_grad_tri, b_grad_tri, _ =
|
| 217 |
offsets,
|
| 218 |
expert_offset=expert_offset)
|
| 219 |
|
|
@@ -250,7 +248,7 @@ def test_fused_mul_grouped_poly_norm_no_nan_inf(
|
|
| 250 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 251 |
4096, 256, 8, dtype, device, expert_offset=expert_offset)
|
| 252 |
|
| 253 |
-
out, inp_grad, mul_grad, w_grad, b_grad, _ =
|
| 254 |
input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
|
| 255 |
|
| 256 |
assert not out.isnan().any(), "Output contains NaN"
|
|
@@ -306,7 +304,7 @@ def test_fused_mul_grouped_poly_norm_scores_backward(
|
|
| 306 |
|
| 307 |
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 308 |
input_t, mul_t, weight, bias, offsets, scores=scores)
|
| 309 |
-
out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri =
|
| 310 |
input_t, mul_t, weight, bias, offsets, scores=scores)
|
| 311 |
|
| 312 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
|
@@ -372,7 +370,7 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
|
|
| 372 |
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 373 |
input_t, mul_t, weight, bias, offsets,
|
| 374 |
scores=scores, hidden_clamp=hidden_clamp)
|
| 375 |
-
out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri =
|
| 376 |
input_t, mul_t, weight, bias, offsets,
|
| 377 |
scores=scores, hidden_clamp=hidden_clamp)
|
| 378 |
|
|
|
|
| 17 |
NUM_EXPERTS_LIST = [8, 384]
|
| 18 |
EXPERT_OFFSETS = [0, 4]
|
| 19 |
SEEDS = [0]
|
|
|
|
|
|
|
| 20 |
# Only test on cuda:0 to avoid cross-device issues.
|
| 21 |
CUDA_DEVICES = ["cuda:0"]
|
| 22 |
|
|
|
|
| 75 |
return grads + (s.grad,) if s is not None else grads + (None,)
|
| 76 |
|
| 77 |
|
| 78 |
+
def _run_cuda(input_t, mul_t, weight, bias, offsets, expert_offset=0,
|
| 79 |
scores=None, hidden_clamp=None):
|
| 80 |
+
"""Run CUDA forward + backward, return output and grads."""
|
| 81 |
inp = input_t.clone().detach().requires_grad_(True)
|
| 82 |
m = mul_t.clone().detach().requires_grad_(True)
|
| 83 |
w = weight.clone().detach().requires_grad_(True)
|
|
|
|
| 110 |
seed: int,
|
| 111 |
device: str,
|
| 112 |
) -> None:
|
| 113 |
+
"""CUDA forward output should match PyTorch reference."""
|
| 114 |
torch.set_default_device(device)
|
| 115 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 116 |
num_tokens, d, num_experts, dtype, device, seed,
|
|
|
|
| 149 |
seed: int,
|
| 150 |
device: str,
|
| 151 |
) -> None:
|
| 152 |
+
"""CUDA backward gradients should match PyTorch reference."""
|
| 153 |
torch.set_default_device(device)
|
| 154 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 155 |
num_tokens, d, num_experts, dtype, device, seed,
|
|
|
|
| 157 |
|
| 158 |
_, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
|
| 159 |
input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
|
| 160 |
+
_, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri, _ = _run_cuda(
|
| 161 |
input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
|
| 162 |
|
| 163 |
if dtype == torch.float32:
|
|
|
|
| 211 |
_, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t, mul_t, weight, bias,
|
| 212 |
offsets,
|
| 213 |
expert_offset=expert_offset)
|
| 214 |
+
_, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(input_t, mul_t, weight, bias,
|
| 215 |
offsets,
|
| 216 |
expert_offset=expert_offset)
|
| 217 |
|
|
|
|
| 248 |
input_t, mul_t, weight, bias, offsets = _make_inputs(
|
| 249 |
4096, 256, 8, dtype, device, expert_offset=expert_offset)
|
| 250 |
|
| 251 |
+
out, inp_grad, mul_grad, w_grad, b_grad, _ = _run_cuda(
|
| 252 |
input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
|
| 253 |
|
| 254 |
assert not out.isnan().any(), "Output contains NaN"
|
|
|
|
| 304 |
|
| 305 |
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 306 |
input_t, mul_t, weight, bias, offsets, scores=scores)
|
| 307 |
+
out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
|
| 308 |
input_t, mul_t, weight, bias, offsets, scores=scores)
|
| 309 |
|
| 310 |
atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
|
|
|
|
| 370 |
out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
|
| 371 |
input_t, mul_t, weight, bias, offsets,
|
| 372 |
scores=scores, hidden_clamp=hidden_clamp)
|
| 373 |
+
out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
|
| 374 |
input_t, mul_t, weight, bias, offsets,
|
| 375 |
scores=scores, hidden_clamp=hidden_clamp)
|
| 376 |
|
torch-ext/activation/__init__.py
CHANGED
|
@@ -12,7 +12,7 @@ def poly_norm(
|
|
| 12 |
weight: torch.Tensor,
|
| 13 |
bias: torch.Tensor,
|
| 14 |
eps: float = 1e-6,
|
| 15 |
-
) ->
|
| 16 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
| 17 |
|
| 18 |
|
|
@@ -22,7 +22,7 @@ def fused_mul_poly_norm(
|
|
| 22 |
weight: torch.Tensor,
|
| 23 |
bias: torch.Tensor,
|
| 24 |
eps: float = 1e-6,
|
| 25 |
-
) ->
|
| 26 |
return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
|
| 27 |
|
| 28 |
|
|
@@ -30,7 +30,7 @@ def rms_norm(
|
|
| 30 |
x: torch.Tensor,
|
| 31 |
weight: torch.Tensor,
|
| 32 |
eps: float = 1e-6,
|
| 33 |
-
) ->
|
| 34 |
return RMSNormFunction.apply(x, weight, eps)
|
| 35 |
|
| 36 |
|
|
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
|
|
| 39 |
residual: torch.Tensor,
|
| 40 |
weight: torch.Tensor,
|
| 41 |
eps: float = 1e-6,
|
| 42 |
-
) ->
|
| 43 |
return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
|
| 44 |
|
| 45 |
|
|
|
|
| 12 |
weight: torch.Tensor,
|
| 13 |
bias: torch.Tensor,
|
| 14 |
eps: float = 1e-6,
|
| 15 |
+
) -> torch.Tensor:
|
| 16 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
| 17 |
|
| 18 |
|
|
|
|
| 22 |
weight: torch.Tensor,
|
| 23 |
bias: torch.Tensor,
|
| 24 |
eps: float = 1e-6,
|
| 25 |
+
) -> torch.Tensor:
|
| 26 |
return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
|
| 27 |
|
| 28 |
|
|
|
|
| 30 |
x: torch.Tensor,
|
| 31 |
weight: torch.Tensor,
|
| 32 |
eps: float = 1e-6,
|
| 33 |
+
) -> torch.Tensor:
|
| 34 |
return RMSNormFunction.apply(x, weight, eps)
|
| 35 |
|
| 36 |
|
|
|
|
| 39 |
residual: torch.Tensor,
|
| 40 |
weight: torch.Tensor,
|
| 41 |
eps: float = 1e-6,
|
| 42 |
+
) -> torch.Tensor:
|
| 43 |
return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
|
| 44 |
|
| 45 |
|