danieldk HF Staff commited on
Commit
b1e6bf1
·
verified ·
1 Parent(s): e6a53d5

Build uploaded using `kernels`.

Browse files
build/torch-rocm/_ops.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- ops = torch.ops._finegrained_fp8_2836236
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_finegrained_fp8_2836236::{op_name}"
 
1
  import torch
2
+ ops = torch.ops._finegrained_fp8_e62632c
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
+ return f"_finegrained_fp8_e62632c::{op_name}"
build/torch-rocm/batched.py CHANGED
@@ -19,6 +19,14 @@ from .act_quant import fp8_act_quant
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
 
 
 
 
 
 
 
 
22
  @triton.jit
23
  def w8a8_block_fp8_matmul_batched_kernel(
24
  A, # (S, K) raw BF16/FP16 activations
@@ -103,6 +111,14 @@ def w8a8_block_fp8_matmul_batched_kernel(
103
  tl.store(c_ptrs, c)
104
 
105
 
 
 
 
 
 
 
 
 
106
  @triton.jit
107
  def w8a8_tensor_fp8_matmul_batched_kernel(
108
  A, # (S, K) pre-quantized FP8 activations
 
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=w, num_stages=s)
25
+ for w in [2, 4, 8, 16]
26
+ for s in [2, 3, 4, 5]
27
+ ],
28
+ key=["N", "K"],
29
+ )
30
  @triton.jit
31
  def w8a8_block_fp8_matmul_batched_kernel(
32
  A, # (S, K) raw BF16/FP16 activations
 
111
  tl.store(c_ptrs, c)
112
 
113
 
114
+ @triton.autotune(
115
+ configs=[
116
+ triton.Config({}, num_warps=w, num_stages=s)
117
+ for w in [2, 4, 8, 16]
118
+ for s in [2, 3, 4, 5]
119
+ ],
120
+ key=["N", "K"],
121
+ )
122
  @triton.jit
123
  def w8a8_tensor_fp8_matmul_batched_kernel(
124
  A, # (S, K) pre-quantized FP8 activations
build/torch-rocm/grouped.py CHANGED
@@ -19,6 +19,14 @@ from .act_quant import fp8_act_quant
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
 
 
 
 
 
 
 
 
22
  @triton.jit
23
  def w8a8_block_fp8_matmul_grouped_kernel(
24
  A, # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
@@ -118,9 +126,7 @@ def w8a8_block_fp8_matmul_grouped_kernel(
118
  a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
119
  # ---- same as baseline from here ----
120
  b = tl.load(b_ptrs)
121
- k_start = k * block_k
122
- offs_ks = k_start // block_k
123
- b_s = tl.load(Bs_ptrs + offs_ks * stride_Bsk)
124
  accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
125
  a_ptrs += block_k * stride_ak
126
  b_ptrs += block_k * stride_bk
@@ -137,6 +143,14 @@ def w8a8_block_fp8_matmul_grouped_kernel(
137
  tl.store(c_ptrs, c, mask=c_mask)
138
 
139
 
 
 
 
 
 
 
 
 
140
  @triton.jit
141
  def w8a8_tensor_fp8_matmul_grouped_kernel(
142
  A, # (S, K) pre-quantized FP8 activations
@@ -162,6 +176,7 @@ def w8a8_tensor_fp8_matmul_grouped_kernel(
162
  block_k: tl.constexpr,
163
  BLOCK_SIZE_M: tl.constexpr,
164
  NUM_EXPERTS: tl.constexpr,
 
165
  ):
166
  """Tensor-scale grouped FP8 expert matmul kernel.
167
 
@@ -177,7 +192,7 @@ def w8a8_tensor_fp8_matmul_grouped_kernel(
177
 
178
  lo = 0
179
  hi = NUM_EXPERTS
180
- for _ in tl.static_range(NUM_EXPERTS.bit_length()):
181
  mid = (lo + hi) >> 1
182
  mid_val = tl.load(TileOffsets + mid)
183
  is_left = mid_val <= pid_m
@@ -241,6 +256,7 @@ def _w8a8_block_fp8_matmul_grouped(
241
  offsets: torch.Tensor,
242
  tokens_per_expert: torch.Tensor,
243
  block_size: list[int] | None,
 
244
  ) -> torch.Tensor:
245
  """Internal block-scale grouped FP8 matmul op.
246
 
@@ -299,11 +315,15 @@ def _w8a8_block_fp8_matmul_grouped(
299
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
300
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
301
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
302
- # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
303
- # Using a static upper bound keeps the grid size data-independent, which is
304
- # required for cuda-graph compatibility. Programs beyond the real tile count
305
- # exit immediately via the early-return guard inside the kernel.
306
- max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
 
 
 
 
307
 
308
  grid = (max_M_tiles, triton.cdiv(N, block_n))
309
  wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
@@ -345,6 +365,7 @@ def _w8a8_tensor_fp8_matmul_grouped(
345
  offsets: torch.Tensor,
346
  tokens_per_expert: torch.Tensor,
347
  block_size: list[int] | None,
 
348
  ) -> torch.Tensor:
349
  """Tensor-scale grouped FP8 matmul for sorted routed experts.
350
 
@@ -387,7 +408,15 @@ def _w8a8_tensor_fp8_matmul_grouped(
387
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
388
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
389
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
390
- max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
 
 
 
 
 
 
 
 
391
 
392
  qA, As = fp8_act_quant(A, K)
393
  grid = (max_M_tiles, triton.cdiv(N, block_n))
@@ -415,6 +444,7 @@ def _w8a8_tensor_fp8_matmul_grouped(
415
  block_k=block_k,
416
  BLOCK_SIZE_M=BLOCK_SIZE_M,
417
  NUM_EXPERTS=E,
 
418
  )
419
 
420
  return C
@@ -427,6 +457,7 @@ def w8a8_block_fp8_matmul_grouped(
427
  offsets: torch.Tensor,
428
  tokens_per_expert: torch.Tensor,
429
  block_size: list[int] | None,
 
430
  ) -> torch.Tensor:
431
  """Grouped W8A8 FP8 matmul for MoE expert dispatch with fused activation quantization.
432
 
@@ -443,12 +474,15 @@ def w8a8_block_fp8_matmul_grouped(
443
  offsets: Cumulative token counts per expert ``[E]`` (i.e. ``cumsum(tokens_per_expert)``).
444
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
445
  block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
 
 
 
446
 
447
  Returns:
448
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
449
  """
450
  return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped(
451
- A, B, Bs, offsets, tokens_per_expert, block_size
452
  )
453
 
454
 
@@ -459,6 +493,7 @@ def w8a8_tensor_fp8_matmul_grouped(
459
  offsets: torch.Tensor,
460
  tokens_per_expert: torch.Tensor,
461
  block_size: list[int] | None,
 
462
  ) -> torch.Tensor:
463
  """Tensor-scale grouped W8A8 FP8 matmul for MoE expert dispatch.
464
 
@@ -469,12 +504,13 @@ def w8a8_tensor_fp8_matmul_grouped(
469
  offsets: Cumulative token counts per expert ``[E]``.
470
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
471
  block_size: Kept for API consistency; tensor path derives tile sizes from ``N`` and ``K``.
 
472
 
473
  Returns:
474
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
475
  """
476
  return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped(
477
- A, B, Bs, offsets, tokens_per_expert, block_size
478
  )
479
 
480
 
@@ -485,6 +521,7 @@ def w8a8_fp8_matmul_grouped(
485
  offsets: torch.Tensor,
486
  tokens_per_expert: torch.Tensor,
487
  block_size: list[int] | None,
 
488
  ) -> torch.Tensor:
489
  """Unified grouped W8A8 FP8 matmul dispatcher.
490
 
@@ -500,9 +537,9 @@ def w8a8_fp8_matmul_grouped(
500
  block_size[0] == B.size(1) and block_size[1] == B.size(2)
501
  ):
502
  return w8a8_tensor_fp8_matmul_grouped(
503
- A, B, Bs, offsets, tokens_per_expert, block_size
504
  )
505
 
506
  return w8a8_block_fp8_matmul_grouped(
507
- A, B, Bs, offsets, tokens_per_expert, block_size
508
  )
 
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=w, num_stages=s)
25
+ for w in [2, 4, 8, 16]
26
+ for s in [2, 3, 4, 5]
27
+ ],
28
+ key=["N", "K"],
29
+ )
30
  @triton.jit
31
  def w8a8_block_fp8_matmul_grouped_kernel(
32
  A, # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
 
126
  a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
127
  # ---- same as baseline from here ----
128
  b = tl.load(b_ptrs)
129
+ b_s = tl.load(Bs_ptrs + k * stride_Bsk)
 
 
130
  accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
131
  a_ptrs += block_k * stride_ak
132
  b_ptrs += block_k * stride_bk
 
143
  tl.store(c_ptrs, c, mask=c_mask)
144
 
145
 
146
+ @triton.autotune(
147
+ configs=[
148
+ triton.Config({}, num_warps=w, num_stages=s)
149
+ for w in [2, 4, 8, 16]
150
+ for s in [2, 3, 4, 5]
151
+ ],
152
+ key=["N", "K"],
153
+ )
154
  @triton.jit
155
  def w8a8_tensor_fp8_matmul_grouped_kernel(
156
  A, # (S, K) pre-quantized FP8 activations
 
176
  block_k: tl.constexpr,
177
  BLOCK_SIZE_M: tl.constexpr,
178
  NUM_EXPERTS: tl.constexpr,
179
+ NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
180
  ):
181
  """Tensor-scale grouped FP8 expert matmul kernel.
182
 
 
192
 
193
  lo = 0
194
  hi = NUM_EXPERTS
195
+ for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
196
  mid = (lo + hi) >> 1
197
  mid_val = tl.load(TileOffsets + mid)
198
  is_left = mid_val <= pid_m
 
256
  offsets: torch.Tensor,
257
  tokens_per_expert: torch.Tensor,
258
  block_size: list[int] | None,
259
+ allow_sync: bool = False,
260
  ) -> torch.Tensor:
261
  """Internal block-scale grouped FP8 matmul op.
262
 
 
315
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
316
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
317
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
318
+ if allow_sync:
319
+ # Exact tile count via CPU/GPU sync — no wasted programs.
320
+ max_M_tiles = int(tile_offsets[-1].item())
321
+ else:
322
+ # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
323
+ # Using a static upper bound keeps the grid size data-independent, which is
324
+ # required for cuda-graph compatibility. Programs beyond the real tile count
325
+ # exit immediately via the early-return guard inside the kernel.
326
+ max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
327
 
328
  grid = (max_M_tiles, triton.cdiv(N, block_n))
329
  wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
 
365
  offsets: torch.Tensor,
366
  tokens_per_expert: torch.Tensor,
367
  block_size: list[int] | None,
368
+ allow_sync: bool = False,
369
  ) -> torch.Tensor:
370
  """Tensor-scale grouped FP8 matmul for sorted routed experts.
371
 
 
408
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
409
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
410
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
411
+ if allow_sync:
412
+ # Exact tile count via CPU/GPU sync — no wasted programs.
413
+ max_M_tiles = int(tile_offsets[-1].item())
414
+ else:
415
+ # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
416
+ # Using a static upper bound keeps the grid size data-independent, which is
417
+ # required for cuda-graph compatibility. Programs beyond the real tile count
418
+ # exit immediately via the early-return guard inside the kernel.
419
+ max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
420
 
421
  qA, As = fp8_act_quant(A, K)
422
  grid = (max_M_tiles, triton.cdiv(N, block_n))
 
444
  block_k=block_k,
445
  BLOCK_SIZE_M=BLOCK_SIZE_M,
446
  NUM_EXPERTS=E,
447
+ NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
448
  )
449
 
450
  return C
 
457
  offsets: torch.Tensor,
458
  tokens_per_expert: torch.Tensor,
459
  block_size: list[int] | None,
460
+ allow_sync: bool = False,
461
  ) -> torch.Tensor:
462
  """Grouped W8A8 FP8 matmul for MoE expert dispatch with fused activation quantization.
463
 
 
474
  offsets: Cumulative token counts per expert ``[E]`` (i.e. ``cumsum(tokens_per_expert)``).
475
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
476
  block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
477
+ allow_sync: If True, read back the exact tile count from the GPU
478
+ (avoids wasted programs). If False, use a data-independent upper bound
479
+ (required for CUDA-graph / torch.compile compatibility).
480
 
481
  Returns:
482
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
483
  """
484
  return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped(
485
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
486
  )
487
 
488
 
 
493
  offsets: torch.Tensor,
494
  tokens_per_expert: torch.Tensor,
495
  block_size: list[int] | None,
496
+ allow_sync: bool = False,
497
  ) -> torch.Tensor:
498
  """Tensor-scale grouped W8A8 FP8 matmul for MoE expert dispatch.
499
 
 
504
  offsets: Cumulative token counts per expert ``[E]``.
505
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
506
  block_size: Kept for API consistency; tensor path derives tile sizes from ``N`` and ``K``.
507
+ allow_sync: If True, sync for exact grid; if False, use upper bound.
508
 
509
  Returns:
510
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
511
  """
512
  return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped(
513
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
514
  )
515
 
516
 
 
521
  offsets: torch.Tensor,
522
  tokens_per_expert: torch.Tensor,
523
  block_size: list[int] | None,
524
+ allow_sync: bool = False,
525
  ) -> torch.Tensor:
526
  """Unified grouped W8A8 FP8 matmul dispatcher.
527
 
 
537
  block_size[0] == B.size(1) and block_size[1] == B.size(2)
538
  ):
539
  return w8a8_tensor_fp8_matmul_grouped(
540
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
541
  )
542
 
543
  return w8a8_block_fp8_matmul_grouped(
544
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
545
  )
build/torch-xpu/_ops.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- ops = torch.ops._finegrained_fp8_2836236
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_finegrained_fp8_2836236::{op_name}"
 
1
  import torch
2
+ ops = torch.ops._finegrained_fp8_e62632c
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
+ return f"_finegrained_fp8_e62632c::{op_name}"
build/torch-xpu/batched.py CHANGED
@@ -19,6 +19,14 @@ from .act_quant import fp8_act_quant
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
 
 
 
 
 
 
 
 
22
  @triton.jit
23
  def w8a8_block_fp8_matmul_batched_kernel(
24
  A, # (S, K) raw BF16/FP16 activations
@@ -103,6 +111,14 @@ def w8a8_block_fp8_matmul_batched_kernel(
103
  tl.store(c_ptrs, c)
104
 
105
 
 
 
 
 
 
 
 
 
106
  @triton.jit
107
  def w8a8_tensor_fp8_matmul_batched_kernel(
108
  A, # (S, K) pre-quantized FP8 activations
 
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=w, num_stages=s)
25
+ for w in [2, 4, 8, 16]
26
+ for s in [2, 3, 4, 5]
27
+ ],
28
+ key=["N", "K"],
29
+ )
30
  @triton.jit
31
  def w8a8_block_fp8_matmul_batched_kernel(
32
  A, # (S, K) raw BF16/FP16 activations
 
111
  tl.store(c_ptrs, c)
112
 
113
 
114
+ @triton.autotune(
115
+ configs=[
116
+ triton.Config({}, num_warps=w, num_stages=s)
117
+ for w in [2, 4, 8, 16]
118
+ for s in [2, 3, 4, 5]
119
+ ],
120
+ key=["N", "K"],
121
+ )
122
  @triton.jit
123
  def w8a8_tensor_fp8_matmul_batched_kernel(
124
  A, # (S, K) pre-quantized FP8 activations
build/torch-xpu/grouped.py CHANGED
@@ -19,6 +19,14 @@ from .act_quant import fp8_act_quant
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
 
 
 
 
 
 
 
 
22
  @triton.jit
23
  def w8a8_block_fp8_matmul_grouped_kernel(
24
  A, # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
@@ -118,9 +126,7 @@ def w8a8_block_fp8_matmul_grouped_kernel(
118
  a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
119
  # ---- same as baseline from here ----
120
  b = tl.load(b_ptrs)
121
- k_start = k * block_k
122
- offs_ks = k_start // block_k
123
- b_s = tl.load(Bs_ptrs + offs_ks * stride_Bsk)
124
  accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
125
  a_ptrs += block_k * stride_ak
126
  b_ptrs += block_k * stride_bk
@@ -137,6 +143,14 @@ def w8a8_block_fp8_matmul_grouped_kernel(
137
  tl.store(c_ptrs, c, mask=c_mask)
138
 
139
 
 
 
 
 
 
 
 
 
140
  @triton.jit
141
  def w8a8_tensor_fp8_matmul_grouped_kernel(
142
  A, # (S, K) pre-quantized FP8 activations
@@ -162,6 +176,7 @@ def w8a8_tensor_fp8_matmul_grouped_kernel(
162
  block_k: tl.constexpr,
163
  BLOCK_SIZE_M: tl.constexpr,
164
  NUM_EXPERTS: tl.constexpr,
 
165
  ):
166
  """Tensor-scale grouped FP8 expert matmul kernel.
167
 
@@ -177,7 +192,7 @@ def w8a8_tensor_fp8_matmul_grouped_kernel(
177
 
178
  lo = 0
179
  hi = NUM_EXPERTS
180
- for _ in tl.static_range(NUM_EXPERTS.bit_length()):
181
  mid = (lo + hi) >> 1
182
  mid_val = tl.load(TileOffsets + mid)
183
  is_left = mid_val <= pid_m
@@ -241,6 +256,7 @@ def _w8a8_block_fp8_matmul_grouped(
241
  offsets: torch.Tensor,
242
  tokens_per_expert: torch.Tensor,
243
  block_size: list[int] | None,
 
244
  ) -> torch.Tensor:
245
  """Internal block-scale grouped FP8 matmul op.
246
 
@@ -299,11 +315,15 @@ def _w8a8_block_fp8_matmul_grouped(
299
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
300
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
301
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
302
- # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
303
- # Using a static upper bound keeps the grid size data-independent, which is
304
- # required for cuda-graph compatibility. Programs beyond the real tile count
305
- # exit immediately via the early-return guard inside the kernel.
306
- max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
 
 
 
 
307
 
308
  grid = (max_M_tiles, triton.cdiv(N, block_n))
309
  wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
@@ -345,6 +365,7 @@ def _w8a8_tensor_fp8_matmul_grouped(
345
  offsets: torch.Tensor,
346
  tokens_per_expert: torch.Tensor,
347
  block_size: list[int] | None,
 
348
  ) -> torch.Tensor:
349
  """Tensor-scale grouped FP8 matmul for sorted routed experts.
350
 
@@ -387,7 +408,15 @@ def _w8a8_tensor_fp8_matmul_grouped(
387
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
388
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
389
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
390
- max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
 
 
 
 
 
 
 
 
391
 
392
  qA, As = fp8_act_quant(A, K)
393
  grid = (max_M_tiles, triton.cdiv(N, block_n))
@@ -415,6 +444,7 @@ def _w8a8_tensor_fp8_matmul_grouped(
415
  block_k=block_k,
416
  BLOCK_SIZE_M=BLOCK_SIZE_M,
417
  NUM_EXPERTS=E,
 
418
  )
419
 
420
  return C
@@ -427,6 +457,7 @@ def w8a8_block_fp8_matmul_grouped(
427
  offsets: torch.Tensor,
428
  tokens_per_expert: torch.Tensor,
429
  block_size: list[int] | None,
 
430
  ) -> torch.Tensor:
431
  """Grouped W8A8 FP8 matmul for MoE expert dispatch with fused activation quantization.
432
 
@@ -443,12 +474,15 @@ def w8a8_block_fp8_matmul_grouped(
443
  offsets: Cumulative token counts per expert ``[E]`` (i.e. ``cumsum(tokens_per_expert)``).
444
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
445
  block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
 
 
 
446
 
447
  Returns:
448
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
449
  """
450
  return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped(
451
- A, B, Bs, offsets, tokens_per_expert, block_size
452
  )
453
 
454
 
@@ -459,6 +493,7 @@ def w8a8_tensor_fp8_matmul_grouped(
459
  offsets: torch.Tensor,
460
  tokens_per_expert: torch.Tensor,
461
  block_size: list[int] | None,
 
462
  ) -> torch.Tensor:
463
  """Tensor-scale grouped W8A8 FP8 matmul for MoE expert dispatch.
464
 
@@ -469,12 +504,13 @@ def w8a8_tensor_fp8_matmul_grouped(
469
  offsets: Cumulative token counts per expert ``[E]``.
470
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
471
  block_size: Kept for API consistency; tensor path derives tile sizes from ``N`` and ``K``.
 
472
 
473
  Returns:
474
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
475
  """
476
  return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped(
477
- A, B, Bs, offsets, tokens_per_expert, block_size
478
  )
479
 
480
 
@@ -485,6 +521,7 @@ def w8a8_fp8_matmul_grouped(
485
  offsets: torch.Tensor,
486
  tokens_per_expert: torch.Tensor,
487
  block_size: list[int] | None,
 
488
  ) -> torch.Tensor:
489
  """Unified grouped W8A8 FP8 matmul dispatcher.
490
 
@@ -500,9 +537,9 @@ def w8a8_fp8_matmul_grouped(
500
  block_size[0] == B.size(1) and block_size[1] == B.size(2)
501
  ):
502
  return w8a8_tensor_fp8_matmul_grouped(
503
- A, B, Bs, offsets, tokens_per_expert, block_size
504
  )
505
 
506
  return w8a8_block_fp8_matmul_grouped(
507
- A, B, Bs, offsets, tokens_per_expert, block_size
508
  )
 
19
  from torch.library import triton_op, wrap_triton
20
 
21
 
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=w, num_stages=s)
25
+ for w in [2, 4, 8, 16]
26
+ for s in [2, 3, 4, 5]
27
+ ],
28
+ key=["N", "K"],
29
+ )
30
  @triton.jit
31
  def w8a8_block_fp8_matmul_grouped_kernel(
32
  A, # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
 
126
  a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
127
  # ---- same as baseline from here ----
128
  b = tl.load(b_ptrs)
129
+ b_s = tl.load(Bs_ptrs + k * stride_Bsk)
 
 
130
  accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
131
  a_ptrs += block_k * stride_ak
132
  b_ptrs += block_k * stride_bk
 
143
  tl.store(c_ptrs, c, mask=c_mask)
144
 
145
 
146
+ @triton.autotune(
147
+ configs=[
148
+ triton.Config({}, num_warps=w, num_stages=s)
149
+ for w in [2, 4, 8, 16]
150
+ for s in [2, 3, 4, 5]
151
+ ],
152
+ key=["N", "K"],
153
+ )
154
  @triton.jit
155
  def w8a8_tensor_fp8_matmul_grouped_kernel(
156
  A, # (S, K) pre-quantized FP8 activations
 
176
  block_k: tl.constexpr,
177
  BLOCK_SIZE_M: tl.constexpr,
178
  NUM_EXPERTS: tl.constexpr,
179
+ NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
180
  ):
181
  """Tensor-scale grouped FP8 expert matmul kernel.
182
 
 
192
 
193
  lo = 0
194
  hi = NUM_EXPERTS
195
+ for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
196
  mid = (lo + hi) >> 1
197
  mid_val = tl.load(TileOffsets + mid)
198
  is_left = mid_val <= pid_m
 
256
  offsets: torch.Tensor,
257
  tokens_per_expert: torch.Tensor,
258
  block_size: list[int] | None,
259
+ allow_sync: bool = False,
260
  ) -> torch.Tensor:
261
  """Internal block-scale grouped FP8 matmul op.
262
 
 
315
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
316
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
317
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
318
+ if allow_sync:
319
+ # Exact tile count via CPU/GPU sync — no wasted programs.
320
+ max_M_tiles = int(tile_offsets[-1].item())
321
+ else:
322
+ # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
323
+ # Using a static upper bound keeps the grid size data-independent, which is
324
+ # required for cuda-graph compatibility. Programs beyond the real tile count
325
+ # exit immediately via the early-return guard inside the kernel.
326
+ max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
327
 
328
  grid = (max_M_tiles, triton.cdiv(N, block_n))
329
  wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
 
365
  offsets: torch.Tensor,
366
  tokens_per_expert: torch.Tensor,
367
  block_size: list[int] | None,
368
+ allow_sync: bool = False,
369
  ) -> torch.Tensor:
370
  """Tensor-scale grouped FP8 matmul for sorted routed experts.
371
 
 
408
  BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
409
  tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
410
  tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
411
+ if allow_sync:
412
+ # Exact tile count via CPU/GPU sync — no wasted programs.
413
+ max_M_tiles = int(tile_offsets[-1].item())
414
+ else:
415
+ # Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
416
+ # Using a static upper bound keeps the grid size data-independent, which is
417
+ # required for cuda-graph compatibility. Programs beyond the real tile count
418
+ # exit immediately via the early-return guard inside the kernel.
419
+ max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
420
 
421
  qA, As = fp8_act_quant(A, K)
422
  grid = (max_M_tiles, triton.cdiv(N, block_n))
 
444
  block_k=block_k,
445
  BLOCK_SIZE_M=BLOCK_SIZE_M,
446
  NUM_EXPERTS=E,
447
+ NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
448
  )
449
 
450
  return C
 
457
  offsets: torch.Tensor,
458
  tokens_per_expert: torch.Tensor,
459
  block_size: list[int] | None,
460
+ allow_sync: bool = False,
461
  ) -> torch.Tensor:
462
  """Grouped W8A8 FP8 matmul for MoE expert dispatch with fused activation quantization.
463
 
 
474
  offsets: Cumulative token counts per expert ``[E]`` (i.e. ``cumsum(tokens_per_expert)``).
475
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
476
  block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
477
+ allow_sync: If True, read back the exact tile count from the GPU
478
+ (avoids wasted programs). If False, use a data-independent upper bound
479
+ (required for CUDA-graph / torch.compile compatibility).
480
 
481
  Returns:
482
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
483
  """
484
  return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped(
485
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
486
  )
487
 
488
 
 
493
  offsets: torch.Tensor,
494
  tokens_per_expert: torch.Tensor,
495
  block_size: list[int] | None,
496
+ allow_sync: bool = False,
497
  ) -> torch.Tensor:
498
  """Tensor-scale grouped W8A8 FP8 matmul for MoE expert dispatch.
499
 
 
504
  offsets: Cumulative token counts per expert ``[E]``.
505
  tokens_per_expert: Number of tokens routed to each expert ``[E]``.
506
  block_size: Kept for API consistency; tensor path derives tile sizes from ``N`` and ``K``.
507
+ allow_sync: If True, sync for exact grid; if False, use upper bound.
508
 
509
  Returns:
510
  Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
511
  """
512
  return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped(
513
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
514
  )
515
 
516
 
 
521
  offsets: torch.Tensor,
522
  tokens_per_expert: torch.Tensor,
523
  block_size: list[int] | None,
524
+ allow_sync: bool = False,
525
  ) -> torch.Tensor:
526
  """Unified grouped W8A8 FP8 matmul dispatcher.
527
 
 
537
  block_size[0] == B.size(1) and block_size[1] == B.size(2)
538
  ):
539
  return w8a8_tensor_fp8_matmul_grouped(
540
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
541
  )
542
 
543
  return w8a8_block_fp8_matmul_grouped(
544
+ A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
545
  )