Instructions to use kernels-community/triton-kernels with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/triton-kernels with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/triton-kernels")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| Example: Mask compaction + Dual GEMM integration (MoE-style). | |
| Before dual GEMM: compact (Yv, Yi) per row based on BitMask. | |
| Then use compacted tensors for routing into expert weights. | |
| ROCm note: tl.store with dynamic write_indx may fail on ROCm Triton. | |
| If so, use the PyTorch fallback in mask_compaction.py. | |
| """ | |
| import os | |
| import sys | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| import torch | |
| from mask_compaction import masked_compaction, masked_compaction_torch_fallback | |
| from amd_dual_gemm_swiglu import dual_gemm_swiglu | |
def example_integration():
    """Sketch: compact routing outputs, then run dual GEMM on routed experts.

    Steps:
      1. Simulate routing values ``Yv`` and expert indices ``Yi`` of shape
         ``[M, K]`` plus a per-row ``BitMask`` of shape ``[M, ceil(K/32)]``.
      2. Compact ``(Yv, Yi)`` with ``masked_compaction``; if that call raises,
         fall back to ``masked_compaction_torch_fallback``.
      3. Slice/pad the compacted values to ``[M, K]``, zero out NaN slots, and
         run ``dual_gemm_swiglu`` against two random ``[K, N]`` weight matrices.

    Progress is reported via ``print``. Returns ``None``; when
    ``torch.cuda.is_available()`` is false it prints ``"No GPU"`` and returns
    immediately without touching any kernel.
    """
    if not torch.cuda.is_available():
        print("No GPU")
        return
    device = "cuda"

    M, K, N = 256, 64, 128  # tokens, hidden, expert dim
    num_experts = 4
    sentinel = float("nan")  # value handed to the compaction routines

    # Simulate routing: Yv [M, K] values, Yi [M, K] expert indices (0..num_experts-1)
    torch.manual_seed(42)
    Yv = torch.randn(M, K, device=device, dtype=torch.float16) * 0.1
    Yi = torch.randint(0, num_experts, (M, K), device=device, dtype=torch.int32)

    # BitMask [M, ceil(K/32)]: 1 = use, 0 = discard (e.g. from load balance)
    # NOTE(review): torch.ones sets every 32-bit word to the integer 1, i.e.
    # only bit 0 of each word is set. If "keep everything" was intended for the
    # words not overwritten below, the fill value should be -1 -- confirm.
    BitMask = torch.ones(M, (K + 31) // 32, device=device, dtype=torch.int32)
    BitMask[:, 0] = 0x55555555  # example: alternating bits

    # 1) Compact (Yv, Yi) per row based on BitMask.
    # The broad except is deliberate: any Triton failure (see the ROCm note in
    # the module docstring) should route to the PyTorch fallback.
    try:
        RetYv, _RetYi = masked_compaction(Yv, Yi, BitMask, sentinel=sentinel)
        print("Compaction: Triton kernel OK")
    except Exception as e:
        print(f"Compaction: Triton failed ({e}), using PyTorch fallback")
        RetYv, _RetYi = masked_compaction_torch_fallback(Yv, Yi, BitMask, sentinel=sentinel)
    # _RetYi (compacted expert indices) is intentionally unused: this sketch
    # uses a single flat weight pair instead of per-expert weights.

    # 2) Use compacted indices for routing into expert weights
    # Expert weights: B1[E,K,N], B2[E,K,N] or similar. For simplicity, flat GEMM:
    # A = routed activations [M, K], B1/B2 = expert weights [K, N]
    # This is a simplified sketch; real MoE has per-expert B.
    B1 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1
    B2 = torch.randn(K, N, device=device, dtype=torch.float16) * 0.1

    # Use RetYv as activations (compacted); pad/truncate to [M, K] if needed
    A = RetYv[:, :K]
    if A.shape[1] < K:
        A = torch.nn.functional.pad(A, (0, K - A.shape[1]), value=0)
    # The compaction was asked to use NaN as its sentinel; presumably discarded
    # slots carry that value. A single NaN in a row of A would turn the whole
    # corresponding GEMM output row into NaN, so neutralise them to zero first.
    A = A.masked_fill(torch.isnan(A), 0.0).contiguous()

    # 3) Dual GEMM
    out = dual_gemm_swiglu(A, B1, B2)
    print(f"Dual GEMM output: {out.shape}")
    print("Done.")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    example_integration()