Kernels
triton-kernels / example_moe_compaction_gemm.py
fmgreco's picture
Add ROCm dual GEMM, MXFP4, mask compaction, group GEMM
199170e
raw
history blame
2.53 kB
#!/usr/bin/env python3
"""
Example: Mask compaction + Dual GEMM integration (MoE-style).
Before dual GEMM: compact (Yv, Yi) per row based on BitMask.
Then use compacted tensors for routing into expert weights.
ROCm note: tl.store with dynamic write_indx may fail on ROCm Triton.
If so, use the PyTorch fallback in mask_compaction.py.
"""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from mask_compaction import masked_compaction, masked_compaction_torch_fallback
from amd_dual_gemm_swiglu import dual_gemm_swiglu
def example_integration() -> None:
    """Sketch: compact routing outputs, then run dual GEMM on routed experts.

    Steps:
      1. Build synthetic routing tensors ``Yv`` / ``Yi`` and a packed ``BitMask``.
      2. Compact ``(Yv, Yi)`` per row with the Triton kernel, falling back to
         the PyTorch implementation if the kernel raises.
      3. Use the compacted values as activations for ``dual_gemm_swiglu``.

    Progress is printed to stdout; nothing is returned. When
    ``torch.cuda.is_available()`` is false the function prints ``"No GPU"``
    and returns immediately.
    """
    if not torch.cuda.is_available():
        print("No GPU")
        return
    device = "cuda"

    M, K, N = 256, 64, 128  # tokens, hidden, expert dim
    num_experts = 4

    # Simulate routing: Yv [M, K] values, Yi [M, K] expert indices (0..num_experts-1)
    torch.manual_seed(42)
    Yv = torch.randn(M, K, device=device, dtype=torch.float16) * 0.1
    Yi = torch.randint(0, num_experts, (M, K), device=device, dtype=torch.int32)

    # BitMask [M, ceil(K/32)]: one bit per column, 1 = use, 0 = discard
    # (e.g. from load balance).
    # NOTE(review): torch.ones stores the integer 1 in every word, i.e. only
    # bit 0 of each word is set; words other than word 0 therefore keep a
    # single column. Use -1 (all bits set) if "keep everything" was intended.
    BitMask = torch.ones(M, (K + 31) // 32, device=device, dtype=torch.int32)
    BitMask[:, 0] = 0x55555555  # example: alternating bits

    # 1) Compact (Yv, Yi) per row based on BitMask.
    # The broad except is deliberate: any Triton failure (see the ROCm note in
    # the module docstring) should fall through to the PyTorch implementation.
    try:
        RetYv, RetYi = masked_compaction(Yv, Yi, BitMask, sentinel=float("nan"))
        print("Compaction: Triton kernel OK")
    except Exception as exc:
        print(f"Compaction: Triton failed ({exc}), using PyTorch fallback")
        RetYv, RetYi = masked_compaction_torch_fallback(
            Yv, Yi, BitMask, sentinel=float("nan")
        )

    # 2) Use compacted indices for routing into expert weights.
    # Expert weights: B1[E,K,N], B2[E,K,N] or similar. For simplicity, flat GEMM:
    # A = routed activations [M, K], B1/B2 = expert weights [K, N].
    # This is a simplified sketch; real MoE has per-expert B (RetYi would select it).
    B1 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1
    B2 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1

    # Use RetYv as activations (compacted); pad/truncate to [M, K] if needed.
    A = RetYv[:, :K]
    if A.shape[1] < K:
        A = torch.nn.functional.pad(A, (0, K - A.shape[1]), value=0)
    # The compaction was asked to use NaN as its sentinel. Presumably unused
    # slots of RetYv hold that sentinel (confirm against mask_compaction.py);
    # a single NaN in a row would make the whole GEMM output row NaN, so zero
    # them out, consistent with the zero padding above.
    A = A.masked_fill(torch.isnan(A), 0.0).contiguous()

    # 3) Dual GEMM
    out = dual_gemm_swiglu(A, B1, B2)
    print(f"Dual GEMM output: {out.shape}")
    print("Done.")
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, not on import.
    example_integration()