Kernels
wyldecat Claude Opus 4.6 (1M context) commited on
Commit
5a9d09d
·
1 Parent(s): 906e125

fix: rename stale references and clean up Triton remnants

Browse files

- Fix broken imports in benchmarks (grouped_fused_mul_poly_norm → fused_mul_grouped_poly_norm)
- Rename GroupedTritonModule → GroupedCUDAModule in benchmarks
- Rename _run_triton → _run_cuda in tests, update docstrings
- Remove stale Triton comments in tests
- Fix return type annotations in __init__.py (None → torch.Tensor)
- Update README with new function name
- Remove TRITON_PRINT_AUTOTUNING env var from profile_bwd.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

README.md CHANGED
@@ -58,7 +58,7 @@ Activation is a python package that contains custom CUDA-based activation kernel
58
  - Fused as:
59
 
60
  ```python
61
- out = grouped_fused_mul_poly_norm(x, mul, weight, bias, offsets, eps)
62
  ```
63
 
64
  ## Usage
 
58
  - Fused as:
59
 
60
  ```python
61
+ out = fused_mul_grouped_poly_norm(x, mul, weight, bias, offsets, eps)
62
  ```
63
 
64
  ## Usage
benchmarks/cases/grouped_mul_poly.py CHANGED
@@ -1,15 +1,19 @@
1
  import torch
2
  import torch._functorch.config
 
3
  from common.diff_engine import DiffCase
4
 
5
  torch._functorch.config.donated_buffer = False
6
 
7
  from grouped_poly_norm import (
8
- grouped_fused_mul_poly_norm,
9
- grouped_fused_mul_poly_norm_ref,
10
  )
11
 
12
- NUM_EXPERTS = 384
 
 
 
13
 
14
 
15
  class GroupedRefModule(torch.nn.Module):
@@ -24,13 +28,13 @@ class GroupedRefModule(torch.nn.Module):
24
  self.expert_offset = expert_offset
25
 
26
  def forward(self, x, mul):
27
- return grouped_fused_mul_poly_norm_ref(x, mul, self.weight, self.bias,
28
  self.offsets, self.eps,
29
  expert_offset=self.expert_offset)
30
 
31
 
32
- class GroupedTritonModule(torch.nn.Module):
33
- """Wraps the Triton kernel for grouped FusedMulPolyNorm."""
34
 
35
  def __init__(self, weight, bias, offsets, eps, expert_offset=0):
36
  super().__init__()
@@ -41,7 +45,7 @@ class GroupedTritonModule(torch.nn.Module):
41
  self.expert_offset = expert_offset
42
 
43
  def forward(self, x, mul):
44
- return grouped_fused_mul_poly_norm(x, mul, self.weight, self.bias,
45
  self.offsets, self.eps,
46
  expert_offset=self.expert_offset)
47
 
@@ -105,13 +109,22 @@ class GroupedMulPoly(DiffCase):
105
  return torch.compile(m)
106
 
107
  def make_cuda(self, I):
108
- return GroupedTritonModule(
109
  I["weight"].detach().clone(),
110
  I["bias"].detach().clone(),
111
  I["offsets"],
112
  I["eps"],
113
  )
114
 
 
 
 
 
 
 
 
 
 
115
  def forward(self, obj, I):
116
  return obj(I["x"], I["mul"])
117
 
 
1
  import torch
2
  import torch._functorch.config
3
+ import torch._inductor
4
  from common.diff_engine import DiffCase
5
 
6
  torch._functorch.config.donated_buffer = False
7
 
8
  from grouped_poly_norm import (
9
+ fused_mul_grouped_poly_norm,
10
+ fused_mul_grouped_poly_norm_ref,
11
  )
12
 
13
+ # 384 / 8 (EP) = 48 experts per rank
14
+ # total_tokens = bs * sl, which equals per-rank tokens
15
+ # since top_k=8 and EP=8, each rank sees all tokens once
16
+ NUM_EXPERTS = 48
17
 
18
 
19
  class GroupedRefModule(torch.nn.Module):
 
28
  self.expert_offset = expert_offset
29
 
30
  def forward(self, x, mul):
31
+ return fused_mul_grouped_poly_norm_ref(x, mul, self.weight, self.bias,
32
  self.offsets, self.eps,
33
  expert_offset=self.expert_offset)
34
 
35
 
36
+ class GroupedCUDAModule(torch.nn.Module):
37
+ """Wraps the CUDA kernel for grouped FusedMulPolyNorm."""
38
 
39
  def __init__(self, weight, bias, offsets, eps, expert_offset=0):
40
  super().__init__()
 
45
  self.expert_offset = expert_offset
46
 
47
  def forward(self, x, mul):
48
+ return fused_mul_grouped_poly_norm(x, mul, self.weight, self.bias,
49
  self.offsets, self.eps,
50
  expert_offset=self.expert_offset)
51
 
 
109
  return torch.compile(m)
110
 
111
  def make_cuda(self, I):
112
+ return GroupedCUDAModule(
113
  I["weight"].detach().clone(),
114
  I["bias"].detach().clone(),
115
  I["offsets"],
116
  I["eps"],
117
  )
118
 
119
+ def make_compiled_cuda(self, I):
120
+ m = GroupedCUDAModule(
121
+ I["weight"].detach().clone(),
122
+ I["bias"].detach().clone(),
123
+ I["offsets"],
124
+ I["eps"],
125
+ )
126
+ return torch.compile(m)
127
+
128
  def forward(self, obj, I):
129
  return obj(I["x"], I["mul"])
130
 
benchmarks/profile_bwd.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Profiling script for grouped polynorm backward kernel using torch.profiler."""
2
+ import argparse
3
+ import torch
4
+ import torch.cuda
5
+ from torch.profiler import profile, ProfilerActivity
6
+ from grouped_poly_norm import fused_mul_grouped_poly_norm
7
+
8
+ torch.set_default_device("cuda")
9
+
10
+
11
+ def make_inputs(N, D, num_experts):
12
+ torch.manual_seed(42)
13
+ probs = torch.ones(num_experts) / num_experts
14
+ assignments = torch.multinomial(probs, N, replacement=True)
15
+ counts = torch.bincount(assignments, minlength=num_experts).tolist()
16
+ offsets = torch.cumsum(
17
+ torch.tensor(counts, dtype=torch.int32), dim=0)
18
+
19
+ x = torch.randn(N, D, dtype=torch.bfloat16, requires_grad=True) * 0.5
20
+ m = torch.randn(N, D, dtype=torch.bfloat16, requires_grad=True) * 0.5
21
+ w = (torch.ones(num_experts, 3, dtype=torch.bfloat16) / 3
22
+ ).requires_grad_(True)
23
+ b = (torch.randn(num_experts, 1, dtype=torch.bfloat16) * 0.01
24
+ ).requires_grad_(True)
25
+ return x, m, w, b, offsets
26
+
27
+
28
+ def main():
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument("--tokens", type=int, default=4096)
31
+ parser.add_argument("--dim", type=int, default=1280)
32
+ parser.add_argument("--experts", type=int, default=48)
33
+ parser.add_argument("--output", type=str, default="/tmp/profile")
34
+ args = parser.parse_args()
35
+
36
+ N, D, num_experts = args.tokens, args.dim, args.experts
37
+
38
+ # Warmup (fresh inputs each time to avoid graph reuse issues)
39
+ for _ in range(3):
40
+ x, m, w, b, offsets = make_inputs(N, D, num_experts)
41
+ out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
42
+ out.sum().backward()
43
+ torch.cuda.synchronize()
44
+
45
+ # Profiled: mimic do_bench — forward once, backward multiple times with retain_graph
46
+ x, m, w, b, offsets = make_inputs(N, D, num_experts)
47
+ out = fused_mul_grouped_poly_norm(x, m, w, b, offsets)
48
+ gin = [x, m] + [w, b]
49
+ g = [torch.randn_like(out)]
50
+
51
+ # Warmup backward
52
+ for _ in range(5):
53
+ torch.autograd.grad(out, gin, g, retain_graph=True, allow_unused=True)
54
+ torch.cuda.synchronize()
55
+
56
+ with profile(
57
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
58
+ record_shapes=True,
59
+ with_stack=True,
60
+ ) as prof:
61
+ for _ in range(100):
62
+ torch.autograd.grad(out, gin, g, retain_graph=True, allow_unused=True)
63
+ torch.cuda.synchronize()
64
+
65
+ # Print kernel-level stats
66
+ print(f"\n=== Kernel Table (N={N}, D={D}) ===")
67
+ print(prof.key_averages().table(
68
+ sort_by="cuda_time_total", row_limit=20))
69
+
70
+ # Export chrome trace
71
+ trace_path = f"{args.output}_trace_N{N}.json"
72
+ prof.export_chrome_trace(trace_path)
73
+ print(f"\nTrace exported to {trace_path}")
74
+
75
+ # === Occupancy analysis from Triton kernel metadata ===
76
+ print(f"\n=== Occupancy Analysis ===")
77
+
78
+ props = torch.cuda.get_device_properties(0)
79
+ print(f"GPU: {props.name}")
80
+ print(f"SMs: {props.multi_processor_count}")
81
+ print(f"Max threads/SM: {props.max_threads_per_multi_processor}")
82
+ print(f"Regs/SM: {props.regs_per_multiprocessor}")
83
+ print(f"Shared mem/block: {props.shared_memory_per_block} bytes")
84
+
85
+ # Get register info from Triton compiled cubins
86
+ try:
87
+ import glob
88
+ import json
89
+ import subprocess
90
+ cache_dir = os.path.expanduser("~/.triton/cache")
91
+
92
+ # Find metadata JSON files
93
+ json_files = sorted(glob.glob(f"{cache_dir}/**/*.json", recursive=True),
94
+ key=os.path.getmtime, reverse=True)
95
+ print(f"\nFound {len(json_files)} compiled kernel metadata files")
96
+ for jf in json_files[:10]:
97
+ try:
98
+ with open(jf) as f:
99
+ meta = json.load(f)
100
+ if isinstance(meta, dict):
101
+ n_regs = meta.get('num_regs', meta.get('n_regs', None))
102
+ n_spills = meta.get('num_spills', meta.get('n_spills', None))
103
+ name = meta.get('name', os.path.basename(jf))
104
+ shared = meta.get('shared', None)
105
+ if n_regs is not None:
106
+ print(f" {name}: regs={n_regs}, spills={n_spills}, shared={shared}")
107
+ except Exception:
108
+ pass
109
+
110
+ # Also try cuobjdump on recent cubins
111
+ cubin_files = sorted(glob.glob(f"{cache_dir}/**/*.cubin", recursive=True),
112
+ key=os.path.getmtime, reverse=True)
113
+ print(f"\nFound {len(cubin_files)} cubins, inspecting latest:")
114
+ for cb in cubin_files[:5]:
115
+ try:
116
+ result = subprocess.run(
117
+ ["cuobjdump", "-res-usage", cb],
118
+ capture_output=True, text=True, timeout=5)
119
+ if result.returncode == 0 and result.stdout.strip():
120
+ print(f"\n {os.path.basename(cb)}:")
121
+ for line in result.stdout.strip().split('\n'):
122
+ print(f" {line}")
123
+ except Exception as e:
124
+ print(f" cuobjdump failed: {e}")
125
+ break
126
+ except Exception as e:
127
+ print(f"Cache inspection error: {e}")
128
+
129
+ # Calculate theoretical occupancy for different register counts
130
+ print("\n=== Theoretical Occupancy (num_warps=4, 128 threads/block) ===")
131
+ threads_per_block = 128
132
+ max_threads = props.max_threads_per_multi_processor
133
+ total_regs = props.regs_per_multiprocessor
134
+ for n_regs in [64, 96, 128, 160, 192, 224, 256]:
135
+ regs_per_block = n_regs * threads_per_block
136
+ max_blocks_by_regs = total_regs // regs_per_block
137
+ max_blocks_by_threads = max_threads // threads_per_block
138
+ blocks = min(max_blocks_by_regs, max_blocks_by_threads, 32)
139
+ active_threads = blocks * threads_per_block
140
+ occupancy = active_threads / max_threads * 100
141
+ print(f" {n_regs:3d} regs/thread -> {blocks:2d} blocks/SM -> "
142
+ f"{active_threads:4d} threads -> {occupancy:.1f}% occupancy")
143
+
144
+
145
+ if __name__ == "__main__":
146
+ main()
tests/test_fused_mul_grouped_poly_norm.py CHANGED
@@ -17,8 +17,6 @@ D = [256, 1280]
17
  NUM_EXPERTS_LIST = [8, 384]
18
  EXPERT_OFFSETS = [0, 4]
19
  SEEDS = [0]
20
- # Triton kernels launch on the current CUDA device and do not
21
- # auto-dispatch to the tensor's device like CUDA extensions.
22
  # Only test on cuda:0 to avoid cross-device issues.
23
  CUDA_DEVICES = ["cuda:0"]
24
 
@@ -77,9 +75,9 @@ def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0,
77
  return grads + (s.grad,) if s is not None else grads + (None,)
78
 
79
 
80
- def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0,
81
  scores=None, hidden_clamp=None):
82
- """Run Triton/CUDA forward + backward, return output and grads."""
83
  inp = input_t.clone().detach().requires_grad_(True)
84
  m = mul_t.clone().detach().requires_grad_(True)
85
  w = weight.clone().detach().requires_grad_(True)
@@ -112,7 +110,7 @@ def test_fused_mul_grouped_poly_norm_forward(
112
  seed: int,
113
  device: str,
114
  ) -> None:
115
- """Triton forward output should match PyTorch reference."""
116
  torch.set_default_device(device)
117
  input_t, mul_t, weight, bias, offsets = _make_inputs(
118
  num_tokens, d, num_experts, dtype, device, seed,
@@ -151,7 +149,7 @@ def test_fused_mul_grouped_poly_norm_backward(
151
  seed: int,
152
  device: str,
153
  ) -> None:
154
- """Triton backward gradients should match PyTorch reference."""
155
  torch.set_default_device(device)
156
  input_t, mul_t, weight, bias, offsets = _make_inputs(
157
  num_tokens, d, num_experts, dtype, device, seed,
@@ -159,7 +157,7 @@ def test_fused_mul_grouped_poly_norm_backward(
159
 
160
  _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
161
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
162
- _, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri, _ = _run_triton(
163
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
164
 
165
  if dtype == torch.float32:
@@ -213,7 +211,7 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
213
  _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t, mul_t, weight, bias,
214
  offsets,
215
  expert_offset=expert_offset)
216
- _, _, _, w_grad_tri, b_grad_tri, _ = _run_triton(input_t, mul_t, weight, bias,
217
  offsets,
218
  expert_offset=expert_offset)
219
 
@@ -250,7 +248,7 @@ def test_fused_mul_grouped_poly_norm_no_nan_inf(
250
  input_t, mul_t, weight, bias, offsets = _make_inputs(
251
  4096, 256, 8, dtype, device, expert_offset=expert_offset)
252
 
253
- out, inp_grad, mul_grad, w_grad, b_grad, _ = _run_triton(
254
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
255
 
256
  assert not out.isnan().any(), "Output contains NaN"
@@ -306,7 +304,7 @@ def test_fused_mul_grouped_poly_norm_scores_backward(
306
 
307
  out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
308
  input_t, mul_t, weight, bias, offsets, scores=scores)
309
- out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_triton(
310
  input_t, mul_t, weight, bias, offsets, scores=scores)
311
 
312
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
@@ -372,7 +370,7 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
372
  out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
373
  input_t, mul_t, weight, bias, offsets,
374
  scores=scores, hidden_clamp=hidden_clamp)
375
- out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_triton(
376
  input_t, mul_t, weight, bias, offsets,
377
  scores=scores, hidden_clamp=hidden_clamp)
378
 
 
17
  NUM_EXPERTS_LIST = [8, 384]
18
  EXPERT_OFFSETS = [0, 4]
19
  SEEDS = [0]
 
 
20
  # Only test on cuda:0 to avoid cross-device issues.
21
  CUDA_DEVICES = ["cuda:0"]
22
 
 
75
  return grads + (s.grad,) if s is not None else grads + (None,)
76
 
77
 
78
+ def _run_cuda(input_t, mul_t, weight, bias, offsets, expert_offset=0,
79
  scores=None, hidden_clamp=None):
80
+ """Run CUDA forward + backward, return output and grads."""
81
  inp = input_t.clone().detach().requires_grad_(True)
82
  m = mul_t.clone().detach().requires_grad_(True)
83
  w = weight.clone().detach().requires_grad_(True)
 
110
  seed: int,
111
  device: str,
112
  ) -> None:
113
+ """CUDA forward output should match PyTorch reference."""
114
  torch.set_default_device(device)
115
  input_t, mul_t, weight, bias, offsets = _make_inputs(
116
  num_tokens, d, num_experts, dtype, device, seed,
 
149
  seed: int,
150
  device: str,
151
  ) -> None:
152
+ """CUDA backward gradients should match PyTorch reference."""
153
  torch.set_default_device(device)
154
  input_t, mul_t, weight, bias, offsets = _make_inputs(
155
  num_tokens, d, num_experts, dtype, device, seed,
 
157
 
158
  _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
159
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
160
+ _, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri, _ = _run_cuda(
161
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
162
 
163
  if dtype == torch.float32:
 
211
  _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t, mul_t, weight, bias,
212
  offsets,
213
  expert_offset=expert_offset)
214
+ _, _, _, w_grad_tri, b_grad_tri, _ = _run_cuda(input_t, mul_t, weight, bias,
215
  offsets,
216
  expert_offset=expert_offset)
217
 
 
248
  input_t, mul_t, weight, bias, offsets = _make_inputs(
249
  4096, 256, 8, dtype, device, expert_offset=expert_offset)
250
 
251
+ out, inp_grad, mul_grad, w_grad, b_grad, _ = _run_cuda(
252
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
253
 
254
  assert not out.isnan().any(), "Output contains NaN"
 
304
 
305
  out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
306
  input_t, mul_t, weight, bias, offsets, scores=scores)
307
+ out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
308
  input_t, mul_t, weight, bias, offsets, scores=scores)
309
 
310
  atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
 
370
  out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
371
  input_t, mul_t, weight, bias, offsets,
372
  scores=scores, hidden_clamp=hidden_clamp)
373
+ out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_cuda(
374
  input_t, mul_t, weight, bias, offsets,
375
  scores=scores, hidden_clamp=hidden_clamp)
376
 
torch-ext/activation/__init__.py CHANGED
@@ -12,7 +12,7 @@ def poly_norm(
12
  weight: torch.Tensor,
13
  bias: torch.Tensor,
14
  eps: float = 1e-6,
15
- ) -> None:
16
  return PolyNormFunction.apply(x, weight, bias, eps)
17
 
18
 
@@ -22,7 +22,7 @@ def fused_mul_poly_norm(
22
  weight: torch.Tensor,
23
  bias: torch.Tensor,
24
  eps: float = 1e-6,
25
- ) -> None:
26
  return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
27
 
28
 
@@ -30,7 +30,7 @@ def rms_norm(
30
  x: torch.Tensor,
31
  weight: torch.Tensor,
32
  eps: float = 1e-6,
33
- ) -> None:
34
  return RMSNormFunction.apply(x, weight, eps)
35
 
36
 
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
39
  residual: torch.Tensor,
40
  weight: torch.Tensor,
41
  eps: float = 1e-6,
42
- ) -> None:
43
  return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
44
 
45
 
 
12
  weight: torch.Tensor,
13
  bias: torch.Tensor,
14
  eps: float = 1e-6,
15
+ ) -> torch.Tensor:
16
  return PolyNormFunction.apply(x, weight, bias, eps)
17
 
18
 
 
22
  weight: torch.Tensor,
23
  bias: torch.Tensor,
24
  eps: float = 1e-6,
25
+ ) -> torch.Tensor:
26
  return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
27
 
28
 
 
30
  x: torch.Tensor,
31
  weight: torch.Tensor,
32
  eps: float = 1e-6,
33
+ ) -> torch.Tensor:
34
  return RMSNormFunction.apply(x, weight, eps)
35
 
36
 
 
39
  residual: torch.Tensor,
40
  weight: torch.Tensor,
41
  eps: float = 1e-6,
42
+ ) -> torch.Tensor:
43
  return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
44
 
45