Cell: utils | deps: torch, numpy | 34.17s
"""Simple utilities for running the models."""
import torch

def to_dtype(dtype_str: str):
    """Convert string to torch dtype."""
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "bfloat16":
        return torch.bfloat16
    return torch.float32

def tensor_stats(t: torch.Tensor) -> str:
    """Generate stats string for a tensor."""
    return (f"shape={tuple(t.shape)}, "
            f"dtype={t.dtype}, "
            f"device={t.device}, "
            f"mean={t.mean().item():.6f}, "
            f"std={t.std().item():.6f}")

def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
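A quick smoke test of these helpers (a minimal sketch; it assumes the cell above is importable as a module named utils):

import torch
from utils import to_dtype, tensor_stats, set_seed

set_seed(0)
x = torch.randn(4, 8, dtype=to_dtype("float32"))
print(tensor_stats(x))  # e.g. shape=(4, 8), dtype=torch.float32, device=cpu, mean=..., std=...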
Cell: bench_utils | deps: torch, numpy | 34.13s
"""Reusable benchmarking utilities for performance testing."""
import time
import numpy as np
from contextlib import contextmanager
from typing import Callable, Dict, Tuple, Any, Optional
import torch

def to_dtype(dtype_str: str):
    """Convert string to torch dtype."""
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "bfloat16":
        return torch.bfloat16
    return torch.float32

def _sync(device: str):
    """Synchronize device if CUDA."""
    if device == "cuda":
        torch.cuda.synchronize()

def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
    """Compute comprehensive latency and throughput statistics."""
    lat_ms = np.array([t * 1000.0 for t in times_s])
    n = len(lat_ms)

    stats = {
        "avg_ms": np.mean(lat_ms),
        "min_ms": np.min(lat_ms),
        "max_ms": np.max(lat_ms),
        "std_ms": np.std(lat_ms),
        "p50_ms": np.percentile(lat_ms, 50),
        "p95_ms": np.percentile(lat_ms, 95),
        "p99_ms": np.percentile(lat_ms, 99),
        "num_iters": n,
    }

    if tokens is not None and n > 0:
        avg_s = np.mean(times_s)
        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])

    return stats

def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
    """Format timing statistics for display."""
    lines = [
        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
        f"Iterations: {stats.get('num_iters', 0)}",
        "\nLatency Statistics:",
        f"  Average: {stats['avg_ms']:.3f} ms",
        f"  Min:     {stats['min_ms']:.3f} ms",
        f"  Max:     {stats['max_ms']:.3f} ms",
        f"  Std Dev: {stats['std_ms']:.3f} ms",
        "\nPercentiles:",
        f"  P50 (median): {stats['p50_ms']:.3f} ms",
        f"  P95:          {stats['p95_ms']:.3f} ms",
        f"  P99:          {stats['p99_ms']:.3f} ms",
    ]

    if tokens is not None and 'tokens_per_s' in stats:
        lines.extend([
            "\nThroughput:",
            f"  Tokens/sec: {stats['tokens_per_s']:.1f}",
            f"  Std Dev:    {stats.get('throughput_variance', 0):.1f}",
        ])

    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
    return "\n".join(lines)

def _bench_engine(
    call: Callable[..., Any], *, warmup: int, iters: int, device: str, dtype,
    input_gen: Optional[Callable[[], Any]] = None
) -> Tuple[Any, list]:
    """Core benchmarking engine with warmup and timing."""
    use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16)

    # Warmup phase
    print(f"\nWarming up ({warmup} iterations)...")
    with torch.inference_mode():
        for _ in range(max(0, warmup)):
            if use_autocast:
                with torch.autocast(device_type="cuda", dtype=dtype):
                    if input_gen is not None:
                        _ = call(input_gen())
                    else:
                        _ = call()
            else:
                if input_gen is not None:
                    _ = call(input_gen())
                else:
                    _ = call()
        _sync(device)

    # Measurement phase
    print(f"Benchmarking ({iters} iterations)...")
    times_s = []
    last = None
    with torch.inference_mode():
        for i in range(max(1, iters)):
            start = time.perf_counter()
            if use_autocast:
                with torch.autocast(device_type="cuda", dtype=dtype):
                    if input_gen is not None:
                        last = call(input_gen())
                    else:
                        last = call()
            else:
                if input_gen is not None:
                    last = call(input_gen())
                else:
                    last = call()
            _sync(device)
            end = time.perf_counter()
            times_s.append(end - start)

            # Progress indicator every 20% of iterations
            if i > 0 and i % max(1, iters // 5) == 0:
                pct = (i / iters) * 100
                avg_so_far = np.mean(times_s[:i]) * 1000
                print(f"  Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")

    return last, times_s

def tensor_stats(t: torch.Tensor) -> str:
    """Generate comprehensive stats string for a tensor."""
    return (f"shape={tuple(t.shape)}, "
            f"dtype={t.dtype}, "
            f"device={t.device}, "
            f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
            f"mean={t.mean().item():.6f}, "
            f"std={t.std().item():.6f}, "
            f"norm={t.norm().item():.6f}")

@contextmanager
def bench_context(
    *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32,
    tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None,
    vary_inputs: bool = True
):
    """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).

    If vary_inputs=True, the first argument should be a base tensor that will be varied each
    iteration by adding a small deterministic increment to prevent caching artifacts.
    """

    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
        # Log configuration
        if verbose:
            print("\n┌─ Benchmark Configuration ─────────────────────────────┐")
            # print(f"│ Device: {device:<15} Dtype: {dtype}              │")
            print(f"│ Warmup: {warmup:<15} Iters: {iters}              │")
            if tokens:
                print(f"│ Tokens: {tokens}                                        │")
            if vary_inputs:
                print("│ Input Variation: Enabled (prevents caching artifacts)  │")
            print("└────────────────────────────────────────────────────────┘")

        # Set up input generation
        input_gen = None
        if vary_inputs and args and isinstance(args[0], torch.Tensor):
            base_input = args[0].clone()
            iteration_counter = [0]  # Use a list for a mutable closure

            def generate_varied_input():
                """Generate an input tensor varied by iteration to prevent caching."""
                # Add a small deterministic increment: 0.001 * iteration_number
                varied_input = base_input + (iteration_counter[0] * 0.001)
                iteration_counter[0] += 1
                return varied_input

            input_gen = generate_varied_input
            call = lambda x: fn(x, *args[1:], **kwargs)

            # Log base input stats
            if verbose:
                print(f"\nBase Input: {tensor_stats(base_input)}")
                print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
        else:
            # Legacy mode - static inputs
            call = lambda: fn(*args, **kwargs)
            if verbose and args and isinstance(args[0], torch.Tensor):
                print(f"\nInput: {tensor_stats(args[0])}")

        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)

        # Log output if it's a tensor or a tuple containing tensors
        if verbose:
            print("\nOutput tensors:")
            if isinstance(result, torch.Tensor):
                print(f"  Primary: {tensor_stats(result)}")
            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
                print(f"  Primary: {tensor_stats(result[0])}")
                if len(result) > 1:
                    if isinstance(result[1], torch.Tensor):
                        print(f"  Auxiliary: {tensor_stats(result[1])}")
                    else:
                        print(f"  Auxiliary: {type(result[1]).__name__}")

        # Compute and display statistics
        stats = _compute_stats(times_s, tokens=tokens)
        if verbose:
            print(_format_timing_stats(stats, tokens))

        # Save to JSON if requested
        if save_json:
            import json
            json_data = {
                "implementation": save_json.replace(".json", ""),
                "config": {
                    "warmup": warmup,
                    "iters": iters,
                    "device": str(device),  # Convert device to string for JSON
                    "dtype": str(dtype),
                    "tokens": tokens,
                    "vary_inputs": vary_inputs
                },
                "stats": stats,
                "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None
            }
            with open(save_json, 'w') as f:
                json.dump(json_data, f, indent=2)
            if verbose:
                print(f"\nSaved benchmark results to {save_json}")

        return result, stats

    yield runner

def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
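In practice, bench_context yields a runner that times any callable and returns (result, stats). A minimal usage sketch (the toy doubling function and CPU device here are illustrative, not from the benchmark cells below):

import torch
from bench_utils import bench_context

def f(x):
    return x * 2.0

x = torch.randn(1, 100, 1152)
with bench_context(warmup=2, iters=5, device="cpu", dtype=torch.float32, tokens=100, verbose=False) as bench:
    _, stats = bench(f, x)
print(f"{stats['avg_ms']:.3f} ms avg, {stats['tokens_per_s']:.0f} tokens/s")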

This notebook benchmarks multiple MoE implementations with varied inputs across iterations to prevent unrealistic caching artifacts and measure true performance characteristics.

Cell: config | deps: torch, numpy | 35.83s
"""Shared configuration for both implementations."""
import torch

# Model configuration
NUM_EXPERTS = 128
HIDDEN_SIZE = 1152
INTERMEDIATE_SIZE = 3072
TOP_K = 4

# Input configuration
BATCH_SIZE = 1
SEQ_LEN = 100
DTYPE = "float32"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seeds for reproducibility
WEIGHT_SEED = 999
EXPERT_SEED = 777
INPUT_SEED = 123
GENERAL_SEED = 42
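Under this configuration each forward pass sees BATCH_SIZE * SEQ_LEN = 100 tokens, and routing produces 100 * TOP_K = 400 token-expert assignments spread over 128 experts (about 3 per expert on average). Note that the capacity heuristic used by several implementations below, ceil_div(batch_size * top_k, num_experts), works out to ceil(1 * 4 / 128) = 1 slot per expert for this shape, so the capacity-limited implementations drop most assignments; this is consistent with their output sum differing from the uncapped GPT-OSS path later on.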
Cell: save_data | deps: torch, numpy | 39.38s
"""
+Generate deterministic shared weights once and save as artifacts so
+both implementations load identical parameters.
+"""
+import torch
+from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
+
+def save_shared_weights():
+    # Router: Kaiming uniform as used by both, bias zeros
+    torch.manual_seed(WEIGHT_SEED)
+    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
+    torch.nn.init.kaiming_uniform_(router_weight)
+    router_bias = torch.zeros(NUM_EXPERTS)
+
+    # Experts: normal(0, 0.02), biases zeros
+    torch.manual_seed(EXPERT_SEED)
+    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
+    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
+    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
+    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)
+
+    # Save artifacts
+    torch.save(router_weight, 'router_weight.pt')
+    torch.save(router_bias, 'router_bias.pt')
+    torch.save(gate_up_proj, 'gate_up_proj.pt')
+    torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
+    torch.save(down_proj, 'down_proj.pt')
+    torch.save(down_proj_bias, 'down_proj_bias.pt')
+
+    print("Saved shared weights to artifacts")
+    print(f"Router weight sum: {router_weight.sum().item():.6f}")
+    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+    print(f"Down sum: {down_proj.sum().item():.6f}")
+
+save_shared_weights()
+
Saved shared weights to artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263
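To double-check that the artifacts round-trip correctly, they can be reloaded and their checksums compared against the printed sums (a minimal sketch; it assumes the .pt files sit in the working directory):

import torch

for name in ["router_weight", "router_bias", "gate_up_proj",
             "gate_up_proj_bias", "down_proj", "down_proj_bias"]:
    t = torch.load(f"{name}.pt")
    print(f"{name}: shape={tuple(t.shape)}, sum={t.sum().item():.6f}")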

Yamoe Implementation


This section runs the Yamoe MoE implementation with optimized Triton kernels.

Cell: yamoe_run | deps: torch, kernels, numpy | 38.93s
import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class YamoeRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

def ceil_div(a, b):
    return (a + b - 1) // b

class YamoeMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = YamoeRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Load the Yamoe kernel from the hub
        # self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe")
        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

        # Expert weights - use the loaded weights
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Get routing decisions
        routing_weights, router_indices = self.router(hidden_states)

        # Reshape for the Yamoe kernel
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        routing_weights_flat = routing_weights.view(-1, self.num_experts)
        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)

        # Call the Yamoe optimized kernel
        output = self.yamoe.experts(
            hidden_states_flat,
            router_indices,
            routing_weights_flat,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            expert_capacity,
            self.num_experts,
            self.top_k,
        )

        # Reshape output back
        output = output.view(batch_size, seq_len, hidden_dim)

        return output, routing_weights

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== Yamoe Implementation ===")
# Initialize model with loaded weights
model = YamoeMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate input
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/f8744f31d9cf720409852d42748815c6d61f005a2a9b297b7b9bf986ed98bb90
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== Yamoe Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50                      │
│ Tokens: 100                                            │
│ Input Variation: Enabled (prevents caching artifacts)  │
└────────────────────────────────────────────────────────┘

Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142
Input Variation: +0.001 * iteration (deterministic)

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 4.247 ms)
  Progress: 40% complete (avg: 4.247 ms)
  Progress: 60% complete (avg: 4.250 ms)
  Progress: 80% complete (avg: 4.249 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 4.249 ms
  Min:     4.131 ms
  Max:     4.305 ms
  Std Dev: 0.028 ms

Percentiles:
  P50 (median): 4.250 ms
  P95:          4.289 ms
  P99:          4.300 ms

Throughput:
  Tokens/sec: 23533.0
  Std Dev:    154.4
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to yamoe_results.json

Output sum: 3.971905

Artifacts:

yamoe_results.json
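All of the routers in this notebook share the same scheme: take the top-k logits per token, softmax over only those k values, and scatter the probabilities back into a dense (tokens, num_experts) score matrix. A tiny standalone sketch of the pattern with toy numbers:

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 3.0, 2.0, 0.0],
                       [0.0, 0.0, 5.0, 4.0]])       # 2 tokens, 4 experts
vals, idx = torch.topk(logits, k=2, dim=-1)         # top-2 logits per token
weights = F.softmax(vals, dim=-1)                   # normalize over the 2 picks only
scores = torch.zeros_like(logits).scatter_(1, idx, weights)
print(scores)  # each row sums to 1 and is nonzero only at the 2 selected experts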

Binned Implementation


This section runs the binned implementation that manually handles token gathering/scattering.

Cell: binned_run | deps: torch, numpy | 39.10s
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

def binned_gather(x, indices, bins, expert_capacity, top_k):
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            out[e, i] = x[tok]
    return out

def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)

def sort_tokens_by_expert(router_indices, num_experts):
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert

def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[1], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj)
    gate_up += gate_up_proj_bias[..., None, :]

    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)
    flat_router = router_indices.view(-1, K)
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)

    return y.view(B, S, H)

class BinnedRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

def ceil_div(a, b):
    return (a + b - 1) // b

class BinnedMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = BinnedRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Expert weights - use the loaded weights
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)
        batch_size = hidden_states.shape[0]
        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)

        output = binned_experts_ref(
            hidden_states,
            router_indices,
            router_scores,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            expert_capacity,
        )

        return output, router_scores

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== Binned Implementation ===")
# Initialize model with loaded weights
model = BinnedMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate the same input as Yamoe
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== Binned Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50                      │
│ Tokens: 100                                            │
│ Input Variation: Enabled (prevents caching artifacts)  │
└────────────────────────────────────────────────────────┘

Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142
Input Variation: +0.001 * iteration (deterministic)

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 38.434 ms)
  Progress: 40% complete (avg: 38.074 ms)
  Progress: 60% complete (avg: 37.541 ms)
  Progress: 80% complete (avg: 36.952 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 36.479 ms
  Min:     33.550 ms
  Max:     39.617 ms
  Std Dev: 1.587 ms

Percentiles:
  P50 (median): 36.436 ms
  P95:          39.168 ms
  P99:          39.480 ms

Throughput:
  Tokens/sec: 2741.3
  Std Dev:    119.0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to binned_results.json

Output sum: 3.971905

Artifacts:

binned_results.json
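The bookkeeping above hinges on sort_tokens_by_expert: sorting the flattened assignments groups them by expert, and bincount followed by cumsum yields each expert's end offset (bins) into the sorted order. A tiny worked sketch with toy shapes:

import torch

router_indices = torch.tensor([[2, 0], [1, 2], [0, 3]])  # 3 tokens, top_k=2
flat = router_indices.flatten()                          # [2, 0, 1, 2, 0, 3]
sorted_vals, sorted_idx = torch.sort(flat)
tokens_per_expert = torch.bincount(sorted_vals, minlength=4)
bins = torch.cumsum(tokens_per_expert, dim=0)
print(sorted_vals)        # tensor([0, 0, 1, 2, 2, 3]): assignments grouped by expert
print(tokens_per_expert)  # tensor([2, 1, 2, 1])
print(bins)               # tensor([2, 3, 5, 6]): end offset of each expert's slice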

GPT-OSS Implementation


This section runs the GPT-OSS MoE implementation, which dispatches between an expert-loop path (on CPU or in training mode) and a batched all-experts path (on CUDA in eval mode).

Cell: gptoss_run | deps: torch, numpy | 39.59s
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class GptOssRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

class GptOssExperts(nn.Module):
    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        if hidden_states.device.type == "cpu" or self.training:
            # Sparse path: loop only over experts that received tokens
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = F.one_hot(router_indices, num_classes=num_experts)
                expert_mask = expert_mask.permute(2, 1, 0)
                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

            for expert_idx in expert_hit[:]:
                expert_idx = expert_idx[0]
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[token_idx]
                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
        else:
            # Dense path: every token goes through every expert, weighted afterwards
            hidden_states = hidden_states.repeat(num_experts, 1)
            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
            next_states = next_states + self.down_proj_bias[..., None, :]
            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
            next_states = next_states.sum(dim=0)
        return next_states

class GptOssMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = GptOssRouter(router_weight, router_bias)
        self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)
        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
        return routed_out, router_scores

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation ===")
# Initialize model with loaded weights
model = GptOssMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")

# Generate the same input as other implementations
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== GPT-OSS Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50                      │
│ Tokens: 100                                            │
│ Input Variation: Enabled (prevents caching artifacts)  │
└────────────────────────────────────────────────────────┘

Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142
Input Variation: +0.001 * iteration (deterministic)

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 48.022 ms)
  Progress: 40% complete (avg: 47.956 ms)
  Progress: 60% complete (avg: 47.209 ms)
  Progress: 80% complete (avg: 46.045 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 45.011 ms
  Min:     39.029 ms
  Max:     49.295 ms
  Std Dev: 2.980 ms

Percentiles:
  P50 (median): 45.672 ms
  P95:          48.489 ms
  P99:          49.056 ms

Throughput:
  Tokens/sec: 2221.7
  Std Dev:    151.3
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to gptoss_results.json

Output sum: 11.532237

Artifacts:

gptoss_results.json
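On CUDA in eval mode, GptOssExperts takes the batched else-branch above: every token is replicated to every expert and pushed through two bmms, so arithmetic cost scales with num_experts rather than top_k. A rough count for this config: 128 experts x 100 tokens x (1152*2304 + 1152*1152) multiply-accumulates is about 51 GMACs per forward, versus roughly 1.6 GMACs if only the 400 routed token-expert pairs were computed, a num_experts / top_k = 32x gap.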

GPT-OSS Implementation (Training Mode)


This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.

Cell: gptoss_training_run | deps: torch, numpy | 39.07s
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class GptOssTrainingRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

class GptOssTrainingExperts(nn.Module):
    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        # Force training mode path (expert loop instead of batched)
        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
        with torch.no_grad():
            expert_mask = F.one_hot(router_indices, num_classes=num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit[:]:
            expert_idx = expert_idx[0]
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            gated_output = (up + 1) * glu
            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        next_states = next_states.view(batch_size, -1, self.hidden_size)
        return next_states

class GptOssTrainingMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = GptOssTrainingRouter(router_weight, router_bias)
        self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)
        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
        return routed_out, router_scores

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
# Initialize model with loaded weights and force training mode
model = GptOssTrainingMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

# Set to training mode to force the expert loop path
model.train()

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
print(f"Model training mode: {model.training}")

# Generate the same input as other implementations
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== GPT-OSS Implementation (Training Mode - Expert Loop) ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340
Model training mode: True

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50                      │
│ Tokens: 100                                            │
│ Input Variation: Enabled (prevents caching artifacts)  │
└────────────────────────────────────────────────────────┘

Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142
Input Variation: +0.001 * iteration (deterministic)

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 48.048 ms)
  Progress: 40% complete (avg: 47.576 ms)
  Progress: 60% complete (avg: 46.769 ms)
  Progress: 80% complete (avg: 45.726 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 44.679 ms
  Min:     38.109 ms
  Max:     49.008 ms
  Std Dev: 2.899 ms

Percentiles:
  P50 (median): 45.400 ms
  P95:          48.408 ms
  P99:          48.790 ms

Throughput:
  Tokens/sec: 2238.2
  Std Dev:    150.3
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to gptoss_training_results.json

Output sum: 11.532237
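Notably, the forced expert-loop path lands within noise of the batched eval path (44.7 ms vs 45.0 ms average, identical output sum 11.532237). At this tiny problem size the two variants appear bound by per-expert kernel launches and memory traffic rather than by the dense path's extra arithmetic.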

MegaBlocks Implementation


This section runs the MegaBlocks MoE implementation with optimized kernels from the Hugging Face hub.

Cell: megablocks_run | deps: torch, numpy, kernels | 40.94s | FAILED
import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
from collections import namedtuple
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

print(f"Loading weights from: {data_dir}")

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

def build_megablocks_model(device: torch.device):
    # Download optimized kernels from the Hugging Face hub
    megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
    model = megablocks.layers.MegaBlocksMoeMLP()

    # Create attribute container for expert weights
    model.experts = namedtuple(
        "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
    )

    # Use loaded router weights for consistency
    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
    with torch.no_grad():
        model.router.weight.copy_(router_weight)
        model.router.bias.copy_(router_bias)

    # Attach loaded expert weights to the experts container
    e = model.experts
    e.alpha = 1.702
    e.capacity_factor = 128
    e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
    e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
    e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
    e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
    e.hidden_size = HIDDEN_SIZE

    # Log weight statistics for comparison
    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
    print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
    print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")

    return model

# Create a wrapper to match the interface of other implementations
class MegaBlocksMoEWrapper(nn.Module):
    def __init__(self, megablocks_model):
        super().__init__()
        self.model = megablocks_model

    def forward(self, hidden_states):
        # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
        output, dummy_routing_weights = self.model(hidden_states)
        return output, dummy_routing_weights

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== MegaBlocks Implementation ===")
# Build MegaBlocks model with loaded weights
megablocks_model = build_megablocks_model(device)
model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)

# Generate the same input as other implementations
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/f8744f31d9cf720409852d42748815c6d61f005a2a9b297b7b9bf986ed98bb90
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== MegaBlocks Implementation ===
[MegaBlocks] Router weight sum: 12.588732
[MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807
[MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50                      │
│ Tokens: 100                                            │
│ Input Variation: Enabled (prevents caching artifacts)  │
└────────────────────────────────────────────────────────┘

Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142
Input Variation: +0.001 * iteration (deterministic)

Warming up (10 iterations)...
Fetching 66 files: 100%|██████████| 66/66 [00:02<00:00, 27.20it/s]
/tmp/tmps8crtj9h/cuda_utils.c:5:10: fatal error: Python.h: No such file or directory
    5 | #include <Python.h>
      |          ^~~~~~~~~~
compilation terminated.
Traceback (most recent call last):
  File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/megablocks_run.py", line 102, in <module>
    output, stats = bench(model, x)
  File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 189, in runner
    result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)
  File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 96, in _bench_engine
    _ = call(input_gen())
  File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 177, in <lambda>
    call = lambda x: fn(x, *args[1:], **kwargs)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/megablocks_run.py", line 81, in forward
    output, dummy_routing_weights = self.model(hidden_states)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 896, in forward
    output, expert_weights_out, *_ = moe_forward(
  File "/tmp/uvnote-run-_d5r222t/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 730, in moe_forward
    x, tokens_per_expert = forward_fn(**forward_args)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 457, in forward_once
    x = permute_and_compute(
  File "/tmp/uvnote-run-_d5r222t/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 401, in permute_and_compute
    x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/autograd/function.py", line 576, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/tmp/uvnote-run-_d5r222t/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py", line 30, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py", line 26, in forward
    return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py", line 419, in binned_gather
    _binned_copy[(num_experts, expert_capacity)](
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 239, in run
    benchmark()
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 228, in benchmark
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 228, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 160, in _bench
    return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
  File "/usr/lib/python3.11/functools.py", line 1001, in __get__
    val = self.func(instance)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 121, in do_bench
    return driver.active.get_benchmarker()
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 30, in __getattr__
    return getattr(self._initialize_obj(), name)
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 26, in _initialize_obj
    self._obj = self._init_fn()
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 12, in _create_driver
    return active_drivers[0]()
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 715, in __init__
    self.utils = CudaUtils()  # TODO: make static
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 62, in __init__
    mod = compile_module_from_src(
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/build.py", line 88, in compile_module_from_src
    so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
  File "/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/build.py", line 51, in _build
    subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
  File "/usr/lib/python3.11/subprocess.py", line 413, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmps8crtj9h/cuda_utils.c', '-O3', '-shared', '-fPIC', '-Wno-psabi', '-o', '/tmp/tmps8crtj9h/cuda_utils.cpython-311-x86_64-linux-gnu.so', '-lcuda', '-L/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/lib', '-L/usr/lib/x86_64-linux-gnu', '-I/tmp/uvnote-run-_d5r222t/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/include', '-I/tmp/tmps8crtj9h', '-I/usr/include/python3.11']' returned non-zero exit status 1.

Performance Visualization


This section reads all of the saved benchmark results and plots a single performance comparison chart; a sketch of the idea follows.
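
The snippet below is a rough sketch of that step, not the document's actual visualization cell: it assumes each benchmark wrote its stats dict (the avg_ms/p50_ms/tokens_per_s fields computed by bench_utils) to a `<name>_results.json` file, so the glob pattern and file naming are assumptions.

```python
"""Hedged sketch: load saved benchmark stats and plot a comparison chart."""
import json
from pathlib import Path
import matplotlib.pyplot as plt

# Assumption: each benchmark dumped its stats dict to "<name>_results.json".
results = {
    p.stem.removesuffix("_results"): json.loads(p.read_text())
    for p in sorted(Path(".").glob("*_results.json"))
}

names = list(results)
avg_ms = [results[n]["avg_ms"] for n in names]  # mean latency per implementation

fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(names, avg_ms)
ax.set_ylabel("avg latency (ms)")  # lower is better
ax.set_title("MoE implementation comparison")
fig.tight_layout()
fig.savefig("comparison.png")
```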
