Cell: utils | deps: torch, numpy

"""Simple utilities for running the models."""
import torch

def to_dtype(dtype_str: str):
    """Convert string to torch dtype."""
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "bfloat16":
        return torch.bfloat16
    return torch.float32

def tensor_stats(t: torch.Tensor) -> str:
    """Generate stats string for a tensor."""
    return (f"shape={tuple(t.shape)}, "
            f"dtype={t.dtype}, "
            f"device={t.device}, "
            f"mean={t.mean().item():.6f}, "
            f"std={t.std().item():.6f}")

def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
Cell: bench_utils | deps: torch, numpy

"""Reusable benchmarking utilities for performance testing."""
import time
import numpy as np
from contextlib import contextmanager, nullcontext
from typing import Callable, Dict, Tuple, Any, Optional
import torch

def to_dtype(dtype_str: str):
    """Convert string to torch dtype."""
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "bfloat16":
        return torch.bfloat16
    return torch.float32

def _sync(device) -> None:
    """Synchronize if running on a CUDA device (accepts a str or torch.device)."""
    if str(device).startswith("cuda"):
        torch.cuda.synchronize()

def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
    """Compute comprehensive latency and throughput statistics."""
    lat_ms = np.array([t * 1000.0 for t in times_s])
    n = len(lat_ms)

    stats = {
        "avg_ms": np.mean(lat_ms),
        "min_ms": np.min(lat_ms),
        "max_ms": np.max(lat_ms),
        "std_ms": np.std(lat_ms),
        "p50_ms": np.percentile(lat_ms, 50),
        "p95_ms": np.percentile(lat_ms, 95),
        "p99_ms": np.percentile(lat_ms, 99),
        "num_iters": n
    }

    if tokens is not None and n > 0:
        avg_s = np.mean(times_s)
        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])

    return stats

def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
    """Format timing statistics for display."""
    lines = [
        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
        f"Iterations: {stats.get('num_iters', 0)}",
        "\nLatency Statistics:",
        f"  Average: {stats['avg_ms']:.3f} ms",
        f"  Min:     {stats['min_ms']:.3f} ms",
        f"  Max:     {stats['max_ms']:.3f} ms",
        f"  Std Dev: {stats['std_ms']:.3f} ms",
        "\nPercentiles:",
        f"  P50 (median): {stats['p50_ms']:.3f} ms",
        f"  P95:          {stats['p95_ms']:.3f} ms",
        f"  P99:          {stats['p99_ms']:.3f} ms",
    ]

    if tokens is not None and 'tokens_per_s' in stats:
        lines.extend([
            "\nThroughput:",
            f"  Tokens/sec: {stats['tokens_per_s']:.1f}",
            f"  Std Dev:    {stats.get('throughput_variance', 0):.1f}",
        ])

    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
    return "\n".join(lines)

def _bench_engine(
    call: Callable[..., Any], *, warmup: int, iters: int, device, dtype,
    input_gen: Optional[Callable[[], Any]] = None,
) -> Tuple[Any, list]:
    """Core benchmarking engine with warmup and timing."""
    use_autocast = str(device).startswith("cuda") and dtype in (torch.float16, torch.bfloat16)

    def run_once():
        """Run one forward call, under autocast when applicable."""
        ctx = torch.autocast(device_type="cuda", dtype=dtype) if use_autocast else nullcontext()
        with ctx:
            return call(input_gen()) if input_gen is not None else call()

    # Warmup phase
    print(f"\nWarming up ({warmup} iterations)...")
    with torch.inference_mode():
        for _ in range(max(0, warmup)):
            run_once()
        _sync(device)

    # Measurement phase
    print(f"Benchmarking ({iters} iterations)...")
    times_s = []
    last = None
    with torch.inference_mode():
        for i in range(max(1, iters)):
            start = time.perf_counter()
            last = run_once()
            _sync(device)
            end = time.perf_counter()
            times_s.append(end - start)

            # Progress indicator every 20% of iterations
            if i > 0 and i % max(1, iters // 5) == 0:
                pct = (i / iters) * 100
                avg_so_far = np.mean(times_s[:i]) * 1000
                print(f"  Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")

    return last, times_s

def tensor_stats(t: torch.Tensor) -> str:
    """Generate comprehensive stats string for a tensor."""
    return (f"shape={tuple(t.shape)}, "
            f"dtype={t.dtype}, "
            f"device={t.device}, "
            f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
            f"mean={t.mean().item():.6f}, "
            f"std={t.std().item():.6f}, "
            f"norm={t.norm().item():.6f}")

@contextmanager
def bench_context(
    *, warmup: int = 25, iters: int = 100, device="cuda", dtype=torch.float32,
    tokens: Optional[int] = None, verbose: bool = True,
    save_json: Optional[str] = None, vary_inputs: bool = True
):
    """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).

    If vary_inputs=True, the first argument should be a base tensor; it is varied
    each iteration by adding a small deterministic increment to prevent caching artifacts.
    """

    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
        # Log configuration
        if verbose:
            print("\n┌─ Benchmark Configuration ─────────────────────────────┐")
            print(f"│ Warmup: {warmup:<15} Iters: {iters}              │")
            if tokens:
                print(f"│ Tokens: {tokens}                                        │")
            if vary_inputs:
                print("│ Input Variation: Enabled (prevents caching artifacts)  │")
            print("└────────────────────────────────────────────────────────┘")

        # Set up input generation
        input_gen = None
        if vary_inputs and args and isinstance(args[0], torch.Tensor):
            base_input = args[0].clone()
            iteration_counter = [0]  # list gives the closure a mutable cell

            def generate_varied_input():
                """Generate an input tensor varied per iteration to prevent caching."""
                # Add a small deterministic increment: 0.001 * iteration_number
                varied_input = base_input + (iteration_counter[0] * 0.001)
                iteration_counter[0] += 1
                return varied_input

            input_gen = generate_varied_input
            call = lambda x: fn(x, *args[1:], **kwargs)

            # Log base input stats
            if verbose:
                print(f"\nBase Input: {tensor_stats(base_input)}")
                print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
        else:
            # Legacy mode - static inputs
            call = lambda: fn(*args, **kwargs)
            if verbose and args and isinstance(args[0], torch.Tensor):
                print(f"\nInput: {tensor_stats(args[0])}")

        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)

        # Log output if it's a tensor or a tuple containing tensors
        if verbose:
            print("\nOutput tensors:")
            if isinstance(result, torch.Tensor):
                print(f"  Primary: {tensor_stats(result)}")
            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
                print(f"  Primary: {tensor_stats(result[0])}")
                if len(result) > 1:
                    if isinstance(result[1], torch.Tensor):
                        print(f"  Auxiliary: {tensor_stats(result[1])}")
                    else:
                        print(f"  Auxiliary: {type(result[1]).__name__}")

        # Compute and display statistics
        stats = _compute_stats(times_s, tokens=tokens)
        if verbose:
            print(_format_timing_stats(stats, tokens))

        # Save to JSON if requested
        if save_json:
            import json
            if isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
                output_sum = float(result[0].sum().item())
            elif isinstance(result, torch.Tensor):
                output_sum = float(result.sum().item())
            else:
                output_sum = None
            json_data = {
                "implementation": save_json.replace(".json", ""),
                "config": {
                    "warmup": warmup,
                    "iters": iters,
                    "device": str(device),
                    "dtype": str(dtype),
                    "tokens": tokens,
                    "vary_inputs": vary_inputs
                },
                "stats": stats,
                "output_sum": output_sum
            }
            with open(save_json, 'w') as f:
                json.dump(json_data, f, indent=2)
            if verbose:
                print(f"\nSaved benchmark results to {save_json}")

        return result, stats

    yield runner

def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

This notebook benchmarks multiple MoE implementations with varied inputs across iterations to prevent unrealistic caching artifacts and measure true performance characteristics.
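To make the runner's contract concrete, here is a minimal usage sketch; the toy matmul stands in for a real model and is not one of the benchmarked cells:

import torch
from bench_utils import bench_context, set_seed

set_seed(0)
x = torch.randn(64, 1152)

def toy_fn(t):
    return t @ t.T  # stand-in for a model forward pass

# vary_inputs=True perturbs x each iteration so cached results can't flatter the timing
with bench_context(warmup=5, iters=20, device="cpu", tokens=64, vary_inputs=True) as bench:
    result, stats = bench(toy_fn, x)
    print(f"{stats['avg_ms']:.3f} ms avg, {stats['tokens_per_s']:.0f} tok/s")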

Cell: config | deps: torch, numpy

"""Shared configuration for both implementations."""
import torch

# Model configuration
NUM_EXPERTS = 128
HIDDEN_SIZE = 1152
INTERMEDIATE_SIZE = 3072
TOP_K = 4

# Input configuration
BATCH_SIZE = 1
SEQ_LEN = 100
DTYPE = "float32"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seeds for reproducibility
WEIGHT_SEED = 999
EXPERT_SEED = 777
INPUT_SEED = 123
GENERAL_SEED = 42
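As a sanity check on scale, the expert tensors saved below imply roughly half a billion parameters. A quick back-of-the-envelope sketch from the constants above (the shapes match the save_data cell, which uses 2 * HIDDEN_SIZE for the fused gate/up projection):

from config import NUM_EXPERTS, HIDDEN_SIZE

gate_up = NUM_EXPERTS * (HIDDEN_SIZE * 2 * HIDDEN_SIZE + 2 * HIDDEN_SIZE)  # weights + biases
down = NUM_EXPERTS * (HIDDEN_SIZE * HIDDEN_SIZE + HIDDEN_SIZE)
print(f"Expert params: {(gate_up + down) / 1e6:.0f}M")  # ~510M, about 2 GB at float32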
Cell: save_data | deps: torch, numpy

"""
Generate deterministic shared weights once and save as artifacts so
both implementations load identical parameters.
"""
import torch
from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED

def save_shared_weights():
    # Router: Kaiming uniform as used by both implementations, bias zeros
    torch.manual_seed(WEIGHT_SEED)
    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
    torch.nn.init.kaiming_uniform_(router_weight)
    router_bias = torch.zeros(NUM_EXPERTS)

    # Experts: normal(0, 0.02), biases zeros
    torch.manual_seed(EXPERT_SEED)
    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)

    # Save artifacts
    torch.save(router_weight, 'router_weight.pt')
    torch.save(router_bias, 'router_bias.pt')
    torch.save(gate_up_proj, 'gate_up_proj.pt')
    torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
    torch.save(down_proj, 'down_proj.pt')
    torch.save(down_proj_bias, 'down_proj_bias.pt')

    print("Saved shared weights to artifacts")
    print(f"Router weight sum: {router_weight.sum().item():.6f}")
    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
    print(f"Down sum: {down_proj.sum().item():.6f}")

save_shared_weights()
Output:

Saved shared weights to artifacts
Router weight sum: 12.588735
Gate/up sum: 1026.601807
Down sum: 206.729279

Yamoe Implementation


This section runs the Yamoe MoE implementation with optimized Triton kernels.

Cell: yamoe_run | deps: torch, kernels, numpy | FAILED

import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from the environment
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class YamoeRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

def ceil_div(a, b):
    return (a + b - 1) // b

class YamoeMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = YamoeRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Load the Yamoe kernel from the hub
        # (a local build can be swapped in via get_local_kernel)
        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

        # Expert weights - use the loaded shared weights
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Get routing decisions
        routing_weights, router_indices = self.router(hidden_states)

        # Reshape for the Yamoe kernel; capacity is based on the total token count
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        routing_weights_flat = routing_weights.view(-1, self.num_experts)
        expert_capacity = ceil_div(batch_size * seq_len * self.top_k, self.num_experts)

        # Call the Yamoe optimized kernel
        output = self.yamoe.experts(
            hidden_states_flat,
            router_indices,
            routing_weights_flat,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            expert_capacity,
            self.num_experts,
            self.top_k,
        )

        # Reshape the output back to (batch, seq, hidden)
        output = output.view(batch_size, seq_len, hidden_dim)

        return output, routing_weights

# Run the model
set_seed(GENERAL_SEED)

device = torch.device("cuda")  # the Yamoe kernel requires CUDA
dtype = to_dtype(DTYPE)

print("\n=== Yamoe Implementation ===")
# Initialize model with loaded weights
model = YamoeMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate input
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Output:

Loading weights from: /home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks_yamoe/.uvnote/cache/57bbe537b6c3412d45373a8967728666b60b8687c5d1f5d0decc3ba51923edde
Loaded shared weights from artifacts
Router weight sum: 12.588735
Gate/up sum: 1026.601807
Down sum: 206.729279

=== Yamoe Implementation ===
Traceback (most recent call last):
  File "/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks_yamoe/.uvnote/cells/yamoe_run.py", line 115, in <module>
    router_weight.to(device),
    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/yamoe-run-07f6c9b004377cec/lib/python3.11/site-packages/torch/cuda/__init__.py", line 412, in _lazy_init
    torch._C._cuda_init()
RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

Binned Implementation


This section runs the binned implementation that manually handles token gathering/scattering.

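The cell's code is not captured in this snapshot. As a rough illustration of the gather/scatter pattern such an implementation uses, here is a hypothetical sketch: assignments are sorted by expert, gathered into fixed-capacity bins, processed with batched matmuls, then scattered back with their routing weights. All names are ours, and the activation and capacity handling are assumptions, not the cell's actual code:

import torch
import torch.nn.functional as F

def binned_moe_sketch(x, scores, indices, gate_up, gate_up_b, down, down_b, top_k, capacity):
    """x: (T, H); scores: (T, E); indices: (T, top_k). All names hypothetical."""
    T, H = x.shape
    E = gate_up.shape[0]
    flat_e = indices.reshape(-1)                        # expert id per (token, k) assignment
    order = torch.argsort(flat_e, stable=True)          # group assignments by expert
    tok = order // top_k                                # source token of each sorted assignment
    counts = torch.bincount(flat_e, minlength=E)        # assignments per expert
    starts = torch.cumsum(counts, 0) - counts           # where each expert's block begins
    pos = torch.arange(T * top_k, device=x.device) - starts[flat_e[order]]
    keep = pos < capacity                               # drop overflow past the capacity
    slot = flat_e[order][keep] * capacity + pos[keep]   # flat slot in the padded bins
    bins = x.new_zeros(E * capacity, H)
    bins[slot] = x[tok[keep]]                           # gather tokens into bins
    h = torch.bmm(bins.view(E, capacity, H), gate_up) + gate_up_b[:, None, :]
    gate, up = h.chunk(2, dim=-1)
    h = F.silu(gate) * up                               # activation choice is an assumption
    h = torch.bmm(h, down) + down_b[:, None, :]
    out = x.new_zeros(T, H)
    w = scores[tok[keep], flat_e[order][keep]].unsqueeze(-1)
    out.index_add_(0, tok[keep], h.view(E * capacity, H)[slot] * w)  # weighted scatter-add
    return out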

GPT-OSS Implementation


This section runs the GPT-OSS MoE implementation with manual expert loop handling.

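Again the cell body is not captured here. The "manual expert loop" referenced above generally masks tokens per expert instead of binning them; a hypothetical sketch, with names and activation as assumptions:

import torch
import torch.nn.functional as F

def expert_loop_sketch(x, scores, gate_up, gate_up_b, down, down_b):
    """x: (T, H); scores: (T, E) with zeros for unselected experts."""
    T, H = x.shape
    E = gate_up.shape[0]
    out = torch.zeros_like(x)
    for e in range(E):
        token_ids = torch.nonzero(scores[:, e], as_tuple=True)[0]  # tokens routed to e
        if token_ids.numel() == 0:
            continue  # skip experts that received no tokens
        h = x[token_ids]
        gate, up = (h @ gate_up[e] + gate_up_b[e]).chunk(2, dim=-1)
        h = (F.silu(gate) * up) @ down[e] + down_b[e]               # activation is an assumption
        out.index_add_(0, token_ids, h * scores[token_ids, e].unsqueeze(-1))
    return out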

GPT-OSS Implementation (Training Mode)


This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.

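The only change from the previous cell is presumably toggling the module's mode before benchmarking, along the lines of this hypothetical snippet (it assumes the MoE module branches on self.training internally, and the save_json name is ours):

# Force the expert-loop branch by enabling training mode.
# bench_context still wraps calls in inference_mode, so no gradients are recorded.
model.train()
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens,
                   save_json="gptoss_training_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)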

MegaBlocks Implementation


This section runs the MegaBlocks MoE implementation with optimized kernels from the Hugging Face hub.

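The loading pattern presumably mirrors the Yamoe cell, which fetches its kernel with get_kernel("drbh/yamoe", revision="v0.2.0"). A hypothetical sketch; the hub repo id below is an assumption, not confirmed by this document:

from kernels import get_kernel

# Hypothetical repo id for the MegaBlocks kernels on the Hugging Face hub
megablocks = get_kernel("kernels-community/megablocks")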

Performance Visualization


This section reads all benchmark results and creates a comprehensive performance comparison chart.

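Each benchmarked cell writes a "<name>_results.json" file via bench_context(save_json=...), so the aggregation step can be as simple as this sketch (the matplotlib usage is an assumption about how the chart is produced):

import json, glob
import matplotlib.pyplot as plt

results = {}
for path in glob.glob("*_results.json"):
    with open(path) as f:
        data = json.load(f)
    results[data["implementation"]] = data["stats"]

names = sorted(results)
plt.bar(names, [results[n]["avg_ms"] for n in names])
plt.ylabel("Average latency (ms)")
plt.title("MoE implementation comparison")
plt.savefig("comparison.png", dpi=150)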