Cell: utils | deps: torch, numpy | 30.61s
"""Simple utilities for running the models."""
+import torch
+
+def to_dtype(dtype_str: str):
+    """Convert string to torch dtype."""
+    if dtype_str == "float16":
+        return torch.float16
+    if dtype_str == "bfloat16":
+        return torch.bfloat16
+    return torch.float32
+
+def tensor_stats(t: torch.Tensor) -> str:
+    """Generate stats string for a tensor."""
+    return (f"shape={tuple(t.shape)}, "
+            f"dtype={t.dtype}, "
+            f"device={t.device}, "
+            f"mean={t.mean().item():.6f}, "
+            f"std={t.std().item():.6f}")
+
+def set_seed(seed: int):
+    """Set seeds for reproducibility."""
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
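For reference, a minimal usage sketch of these helpers (assuming this cell is importable as the module `utils`, the same way later cells import `bench_utils` and `config`):

# Hypothetical usage of the utils cell above
import torch
from utils import to_dtype, tensor_stats, set_seed

set_seed(0)
x = torch.randn(4, 8, dtype=to_dtype("float32"))
print(tensor_stats(x))  # e.g. shape=(4, 8), dtype=torch.float32, device=cpu, mean=..., std=...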
+
Cell: bench_utils | deps: torch, numpy | 31.57s
"""Reusable benchmarking utilities for performance testing."""
+import time
+import numpy as np
+from contextlib import contextmanager
+from typing import Callable, Dict, Tuple, Any, Optional
+import torch
+
+def to_dtype(dtype_str: str):
+    """Convert string to torch dtype."""
+    if dtype_str == "float16":
+        return torch.float16
+    if dtype_str == "bfloat16":
+        return torch.bfloat16
+    return torch.float32
+
+def _sync(device: str):
+    """Synchronize device if CUDA."""
+    if device == "cuda":
+        torch.cuda.synchronize()
+
+def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
+    """Compute comprehensive latency and throughput statistics."""
+    lat_ms = np.array([t * 1000.0 for t in times_s])
+    lat_ms_sorted = np.sort(lat_ms)
+    n = len(lat_ms)
+
+    stats = {
+        "avg_ms": np.mean(lat_ms),
+        "min_ms": np.min(lat_ms),
+        "max_ms": np.max(lat_ms),
+        "std_ms": np.std(lat_ms),
+        "p50_ms": np.percentile(lat_ms, 50),
+        "p95_ms": np.percentile(lat_ms, 95),
+        "p99_ms": np.percentile(lat_ms, 99),
+        "num_iters": n
+    }
+
+    if tokens is not None and n > 0:
+        avg_s = np.mean(times_s)
+        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
+        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])
+
+    return stats
+
+def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
+    """Format timing statistics for display."""
+    lines = [
+        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
+        f"Iterations: {stats.get('num_iters', 0)}",
+        "\nLatency Statistics:",
+        f"  Average: {stats['avg_ms']:.3f} ms",
+        f"  Min:     {stats['min_ms']:.3f} ms",
+        f"  Max:     {stats['max_ms']:.3f} ms", 
+        f"  Std Dev: {stats['std_ms']:.3f} ms",
+        "\nPercentiles:",
+        f"  P50 (median): {stats['p50_ms']:.3f} ms",
+        f"  P95:          {stats['p95_ms']:.3f} ms",
+        f"  P99:          {stats['p99_ms']:.3f} ms",
+    ]
+
+    if tokens is not None and 'tokens_per_s' in stats:
+        lines.extend([
+            "\nThroughput:",
+            f"  Tokens/sec: {stats['tokens_per_s']:.1f}",
+            f"  Std Dev:    {stats.get('throughput_variance', 0):.1f}",
+        ])
+
+    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
+    return "\n".join(lines)
+
+def _bench_engine(
+    call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype
+) -> Tuple[Any, list]:
+    """Core benchmarking engine with warmup and timing."""
+    use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
+
+    # Warmup phase
+    print(f"\nWarming up ({warmup} iterations)...")
+    with torch.inference_mode():
+        for _ in range(max(0, warmup)):
+            if use_autocast:
+                with torch.autocast(device_type="cuda", dtype=dtype):
+                    _ = call()
+            else:
+                _ = call()
+        _sync(device)
+
+    # Measurement phase
+    print(f"Benchmarking ({iters} iterations)...")
+    times_s = []
+    last = None
+    with torch.inference_mode():
+        for i in range(max(1, iters)):
+            start = time.perf_counter()
+            if use_autocast:
+                with torch.autocast(device_type="cuda", dtype=dtype):
+                    last = call()
+            else:
+                last = call()
+            _sync(device)
+            end = time.perf_counter()
+            times_s.append(end - start)
+
+            # Progress indicator every 20% of iterations
+            if i > 0 and i % max(1, iters // 5) == 0:
+                pct = (i / iters) * 100
+                avg_so_far = np.mean(times_s[:i]) * 1000
+                print(f"  Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")
+
+    return last, times_s
+
+def tensor_stats(t: torch.Tensor) -> str:
+    """Generate comprehensive stats string for a tensor."""
+    return (f"shape={tuple(t.shape)}, "
+            f"dtype={t.dtype}, "
+            f"device={t.device}, "
+            f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
+            f"mean={t.mean().item():.6f}, "
+            f"std={t.std().item():.6f}, "
+            f"norm={t.norm().item():.6f}")
+
@contextmanager
def bench_context(
    *, warmup: int = 25, iters: int = 100, device="cuda", dtype=torch.float32,
    tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None
):
    """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats)."""

    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
        # Log configuration
        if verbose:
            print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐")
            print(f"│ Warmup: {warmup:<15} Iters: {iters}              │")
            if tokens:
                print(f"│ Tokens: {tokens}                                        │")
            print(f"└────────────────────────────────────────────────────────┘")

        # Log input if it's a tensor
        if verbose and args and isinstance(args[0], torch.Tensor):
            print(f"\nInput: {tensor_stats(args[0])}")

        call = lambda: fn(*args, **kwargs)
        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype)

        # Log output if it's a tensor or a tuple containing tensors
        if verbose:
            print("\nOutput tensors:")
            if isinstance(result, torch.Tensor):
                print(f"  Primary: {tensor_stats(result)}")
            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
                print(f"  Primary: {tensor_stats(result[0])}")
                if len(result) > 1:
                    if isinstance(result[1], torch.Tensor):
                        print(f"  Auxiliary: {tensor_stats(result[1])}")
                    else:
                        print(f"  Auxiliary: {type(result[1]).__name__}")

        # Compute and display statistics
        stats = _compute_stats(times_s, tokens=tokens)
        if verbose:
            print(_format_timing_stats(stats, tokens))

        # Save to JSON if requested
        if save_json:
            import json
            if isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
                output_sum = float(result[0].sum().item())
            elif isinstance(result, torch.Tensor):
                output_sum = float(result.sum().item())
            else:
                output_sum = None
            json_data = {
                "implementation": save_json.replace(".json", ""),
                "config": {
                    "warmup": warmup,
                    "iters": iters,
                    "device": str(device),  # torch.device is not JSON-serializable
                    "dtype": str(dtype),
                    "tokens": tokens
                },
                "stats": stats,
                "output_sum": output_sum
            }
            with open(save_json, 'w') as f:
                json.dump(json_data, f, indent=2)
            if verbose:
                print(f"\nSaved benchmark results to {save_json}")

        return result, stats

    yield runner
def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
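A minimal usage sketch for the runner (hypothetical, not one of the notebook's cells): time a bare matmul, assuming `bench_utils` is importable as above and a CUDA device is present.

# Sketch: time a 512x512 matmul with the bench_context runner
import torch
from bench_utils import bench_context, set_seed

set_seed(0)
a = torch.randn(512, 512, device="cuda")
with bench_context(warmup=5, iters=20, device="cuda", verbose=False) as bench:
    result, stats = bench(torch.mm, a, a)
print(f"avg={stats['avg_ms']:.3f} ms, p99={stats['p99_ms']:.3f} ms")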

This notebook runs the Yamoe, binned, GPT-OSS, and MegaBlocks MoE implementations with identical inputs and shared weights, verifying that they produce consistent outputs and comparing their latency and throughput.
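Each benchmarked cell below saves its results to a `*_results.json` artifact; a small sketch (hypothetical, not one of the notebook's cells) of how those saved runs can be compared afterwards:

# Sketch: compare saved benchmark artifacts across implementations
import json

for name in ["yamoe_results.json", "binned_results.json", "gptoss_results.json"]:
    with open(name) as f:
        r = json.load(f)
    print(f"{r['implementation']:>24}: avg={r['stats']['avg_ms']:.3f} ms, "
          f"tok/s={r['stats'].get('tokens_per_s', 0):.0f}, output_sum={r['output_sum']}")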

Cell: config | deps: torch, numpy | 37.88s
"""Shared configuration for both implementations."""
+import torch
+
+# Model configuration
+NUM_EXPERTS = 128
+HIDDEN_SIZE = 1152
+INTERMEDIATE_SIZE = 3072
+TOP_K = 4
+
+# Input configuration
+BATCH_SIZE = 1
+SEQ_LEN = 100
+DTYPE = "float32"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Seeds for reproducibility
+WEIGHT_SEED = 999
+EXPERT_SEED = 777
+INPUT_SEED = 123
+GENERAL_SEED = 42
+
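With this configuration, each forward pass routes BATCH_SIZE * SEQ_LEN * TOP_K = 1 * 100 * 4 = 400 expert assignments across 128 experts, about 3.1 tokens per expert on average, which is why the small per-expert capacity used later (12 for Yamoe) is comfortable under near-uniform routing.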
Cell: save_data | deps: torch, numpy | 38.59s
"""
+Generate deterministic shared weights once and save as artifacts so
+both implementations load identical parameters.
+"""
+import torch
+from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
+
+def save_shared_weights():
+    # Router: Kaiming uniform as used by both, bias zeros
+    torch.manual_seed(WEIGHT_SEED)
+    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
+    torch.nn.init.kaiming_uniform_(router_weight)
+    router_bias = torch.zeros(NUM_EXPERTS)
+
+    # Experts: normal(0, 0.02), biases zeros
+    torch.manual_seed(EXPERT_SEED)
+    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
+    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
+    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
+    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)
+
+    # Save artifacts
+    torch.save(router_weight, 'router_weight.pt')
+    torch.save(router_bias, 'router_bias.pt')
+    torch.save(gate_up_proj, 'gate_up_proj.pt')
+    torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
+    torch.save(down_proj, 'down_proj.pt')
+    torch.save(down_proj_bias, 'down_proj_bias.pt')
+
+    print("Saved shared weights to artifacts")
+    print(f"Router weight sum: {router_weight.sum().item():.6f}")
+    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+    print(f"Down sum: {down_proj.sum().item():.6f}")
+
+save_shared_weights()
+
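A quick sanity check (sketch, not a notebook cell): torch.save/torch.load round-trips a tensor exactly, so reloading an artifact should reproduce the checksum printed above.

# Sketch: verify an artifact round-trips exactly
import torch
w = torch.load('router_weight.pt')
print(f"Reloaded router weight sum: {w.sum().item():.6f}")  # expected: 12.588732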
Saved shared weights to artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

Yamoe Implementation


This section runs the Yamoe MoE implementation with optimized Triton kernels.
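The implementations in this notebook share the same router logic: take the top-k logits per token, softmax only those k values, and scatter them back into a dense (tokens, num_experts) score matrix. A tiny worked example of that scatter pattern:

import torch
logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])              # 1 token, 4 experts
vals, idx = torch.topk(logits, k=2, dim=-1)                # vals=[[3., 2.]], idx=[[1, 2]]
probs = torch.softmax(vals, dim=1)                         # ~[[0.731, 0.269]]
scores = torch.zeros_like(logits).scatter_(1, idx, probs)  # [[0., 0.731, 0.269, 0.]]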

Cell: yamoe_run | deps: torch, kernels, numpy | 35.75s
import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
class YamoeRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices


class YamoeMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = YamoeRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Load the Yamoe kernel from the Hugging Face hub
        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

        # Max tokens buffered per expert: ~4x the average load for this config
        # (100 tokens * top_k 4 / 128 experts ~= 3.1 assignments per expert)
        self.expert_capacity = 12

        # Expert weights - use the loaded weights
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Get routing decisions
        routing_weights, router_indices = self.router(hidden_states)

        # Reshape for the Yamoe kernel
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        routing_weights_flat = routing_weights.view(-1, self.num_experts)

        # Call the Yamoe optimized kernel
        output = self.yamoe.experts(
            hidden_states_flat,
            router_indices,
            routing_weights_flat,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            self.expert_capacity,
            self.num_experts,
            self.top_k,
        )

        # Reshape output back to (batch, seq, hidden)
        output = output.view(batch_size, seq_len, hidden_dim)

        return output, routing_weights
# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== Yamoe Implementation ===")
# Initialize model with loaded weights
model = YamoeMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate input
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== Yamoe Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50              │
│ Tokens: 100                                        │
└────────────────────────────────────────────────────────┘

Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 8.633 ms)
  Progress: 40% complete (avg: 8.627 ms)
  Progress: 60% complete (avg: 8.629 ms)
  Progress: 80% complete (avg: 8.630 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 8.631 ms
  Min:     8.526 ms
  Max:     8.661 ms
  Std Dev: 0.022 ms

Percentiles:
  P50 (median): 8.636 ms
  P95:          8.653 ms
  P99:          8.658 ms

Throughput:
  Tokens/sec: 11586.6
  Std Dev:    29.1
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to yamoe_results.json

Output sum: -0.597250
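Sanity check on the reported throughput: 100 tokens / 8.631 ms is approximately 11,586 tokens/s, matching the Tokens/sec line above.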

Artifacts:

yamoe_results.json

Binned Implementation


This section runs the binned implementation that manually handles token gathering/scattering.
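The binned path sorts the flattened (token, slot) assignments by expert id and uses a cumulative-sum `bins` array so that expert e owns the sorted slice `order[bins[e-1]:bins[e]]`. A tiny worked example of what `sort_tokens_by_expert` computes:

import torch
router_indices = torch.tensor([[1, 3], [0, 1]])  # 2 tokens, top_k=2
flat = router_indices.flatten()                  # [1, 3, 0, 1]
vals, order = torch.sort(flat)                   # vals=[0, 1, 1, 3], order=[2, 0, 3, 1]
counts = torch.bincount(vals, minlength=4)       # [1, 2, 0, 1] assignments per expert
bins = torch.cumsum(counts, dim=0)               # [1, 3, 3, 4]
# expert 1 owns order[1:3] = [0, 3], i.e. token 0 slot 0 and token 1 slot 1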

Cell: binned_run | deps: torch, numpy | 42.05s
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Gather tokens into per-expert buffers of size expert_capacity (reference loop)."""
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k  # flat position encodes (token, slot)
            out[e, i] = x[tok]
    return out

def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Scatter expert outputs back to (token, slot), scale by routing weight, sum over slots."""
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)

def sort_tokens_by_expert(router_indices, num_experts):
    """Sort flattened (token, slot) assignments by expert id; bins[e] is the end offset for expert e."""
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert
def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[1], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj)
    gate_up += gate_up_proj_bias[..., None, :]

    # Gate and up projections are interleaved along the last dim
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    # Gated activation: gate * sigmoid(1.702 * gate)
    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)
    flat_router = router_indices.view(-1, K)
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)

    return y.view(B, S, H)
class BinnedRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

class BinnedMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = BinnedRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_capacity = 256

        # Expert weights - use the loaded weights
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)

        output = binned_experts_ref(
            hidden_states,
            router_indices,
            router_scores,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            self.expert_capacity,
        )

        return output, router_scores

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== Binned Implementation ===")
# Initialize model with loaded weights
model = BinnedMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate the same input as Yamoe
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== Binned Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50              │
│ Tokens: 100                                        │
└────────────────────────────────────────────────────────┘

Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 104.222 ms)
  Progress: 40% complete (avg: 104.671 ms)
  Progress: 60% complete (avg: 105.372 ms)
  Progress: 80% complete (avg: 105.570 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 105.618 ms
  Min:     103.417 ms
  Max:     107.809 ms
  Std Dev: 1.458 ms

Percentiles:
  P50 (median): 105.048 ms
  P95:          107.729 ms
  P99:          107.790 ms

Throughput:
  Tokens/sec: 946.8
  Std Dev:    13.0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to binned_results.json

Output sum: -0.597248
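On this workload the Yamoe kernel is roughly 12x faster than the binned reference (8.631 ms vs 105.618 ms average), while the output sums agree to within ~2e-6 (-0.597250 vs -0.597248), consistent with float32 reduction-order differences.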

Artifacts:

binned_results.json

GPT-OSS Implementation


This section runs the GPT-OSS MoE implementation with manual expert loop handling.
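The expert loop builds a one-hot mask per (token, slot) assignment and transposes it so each expert can look up its token indices with torch.where; a tiny worked example of that indexing:

import torch
router_indices = torch.tensor([[1, 3], [0, 1]])                    # 2 tokens, top_k=2
mask = torch.nn.functional.one_hot(router_indices, num_classes=4)  # (tokens, top_k, experts)
mask = mask.permute(2, 1, 0)                                       # (experts, top_k, tokens)
expert_hit = (mask.sum(dim=(-1, -2)) > 0).nonzero()                # experts 0, 1, 3 were hit
_, token_idx = torch.where(mask[1])                                # tokens routed to expert 1: [0, 1]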

Cell: gptoss_run | deps: torch, numpy | 37.86s
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class GptOssRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

class GptOssExperts(nn.Module):
    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        if hidden_states.device.type == "cpu" or self.training:
            # Sparse path: loop over experts that received at least one token
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = F.one_hot(router_indices, num_classes=num_experts)
                expert_mask = expert_mask.permute(2, 1, 0)  # (experts, top_k, tokens)
                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

            for expert_idx in expert_hit:
                expert_idx = expert_idx[0]
                with torch.no_grad():
                    _, token_idx = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[token_idx]
                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
        else:
            # Dense path: run every expert on every token, then weight and sum
            hidden_states = hidden_states.repeat(num_experts, 1)
            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
            next_states = next_states + self.down_proj_bias[..., None, :]
            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
            next_states = next_states.sum(dim=0)
        return next_states

class GptOssMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = GptOssRouter(router_weight, router_bias)
        self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)
        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
        return routed_out, router_scores

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation ===")
# Initialize model with loaded weights
model = GptOssMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")

# Generate the same input as other implementations
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== GPT-OSS Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50              │
│ Tokens: 100                                        │
└────────────────────────────────────────────────────────┘

Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 46.973 ms)
  Progress: 40% complete (avg: 47.262 ms)
  Progress: 60% complete (avg: 47.067 ms)
  Progress: 80% complete (avg: 46.985 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 47.135 ms
  Min:     46.582 ms
  Max:     47.895 ms
  Std Dev: 0.503 ms

Percentiles:
  P50 (median): 46.789 ms
  P95:          47.801 ms
  P99:          47.856 ms

Throughput:
  Tokens/sec: 2121.6
  Std Dev:    22.5
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to gptoss_results.json

Output sum: -0.597250
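The dense batched path runs every expert on every token and lands between the other two: 47.135 ms average (2,122 tokens/s), about 5.5x slower than Yamoe but roughly 2.2x faster than the binned Python loops, again with a matching output sum.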

Artifacts:

gptoss_results.json

GPT-OSS Implementation (Training Mode)


This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.
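In GptOssExperts.forward above, the loop path is selected when `hidden_states.device.type == "cpu" or self.training`; the variant below simply inlines that loop unconditionally (and the cell also calls `model.train()`), so the expert loop runs even on CUDA.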

Cell: gptoss_training_run | deps: torch, numpy | 36.75s
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class GptOssTrainingRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

class GptOssTrainingExperts(nn.Module):
    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        # Force the training-mode path (expert loop instead of batched)
        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
        with torch.no_grad():
            expert_mask = F.one_hot(router_indices, num_classes=num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)  # (experts, top_k, tokens)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            gated_output = (up + 1) * glu
            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        next_states = next_states.view(batch_size, -1, self.hidden_size)
        return next_states

class GptOssTrainingMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = GptOssTrainingRouter(router_weight, router_bias)
        self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)
        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
        return routed_out, router_scores

# Run the model
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
# Initialize model with loaded weights and force training mode
model = GptOssTrainingMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

# Set to training mode to force the expert loop path
model.train()

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
print(f"Model training mode: {model.training}")

# Generate the same input as other implementations
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263

=== GPT-OSS Implementation (Training Mode - Expert Loop) ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340
Model training mode: True

┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10              Iters: 50              │
│ Tokens: 100                                        │
└────────────────────────────────────────────────────────┘

Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163

Warming up (10 iterations)...
Benchmarking (50 iterations)...
  Progress: 20% complete (avg: 48.328 ms)
  Progress: 40% complete (avg: 48.764 ms)
  Progress: 60% complete (avg: 48.825 ms)
  Progress: 80% complete (avg: 48.769 ms)

Output tensors:
  Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
  Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632

━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50

Latency Statistics:
  Average: 48.630 ms
  Min:     47.535 ms
  Max:     49.414 ms
  Std Dev: 0.559 ms

Percentiles:
  P50 (median): 48.395 ms
  P95:          49.346 ms
  P99:          49.390 ms

Throughput:
  Tokens/sec: 2056.3
  Std Dev:    23.6
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Saved benchmark results to gptoss_training_results.json

Output sum: -0.597250
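Forcing the expert loop lands within ~3% of the dense batched path on this tiny workload (48.630 ms vs 47.135 ms average): the loop launches a pair of small matmuls per active expert, which at only 100 tokens costs about as much as computing every expert densely.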

MegaBlocks Implementation


This section runs the MegaBlocks MoE implementation with optimized kernels fetched from the Hugging Face Hub.
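The heart of the cell below is the kernel download itself; stripped to its essentials, the call pattern is just a few lines (a minimal sketch using only what the cell actually imports):

    from kernels import get_kernel

    # Fetch the community MegaBlocks kernels from the Hugging Face Hub;
    # the returned module exposes prebuilt layers such as MegaBlocksMoeMLP.
    megablocks = get_kernel("kernels-community/megablocks")
    model = megablocks.layers.MegaBlocksMoeMLP()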

Cell: megablocks_run | deps: torch, numpy, kernels | 43.51s
import torch
+from torch import nn
+from torch.nn import functional as F
+from kernels import get_kernel, get_local_kernel
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+from collections import namedtuple
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+print(f"Loading weights from: {data_dir}")
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+def build_megablocks_model(device: torch.device):
+    # Download optimized kernels from the Hugging Face hub
+    megablocks = get_kernel("kernels-community/megablocks")
+
+    # During local development, a locally built copy can be loaded instead:
+    # megablocks = get_local_kernel(
+    #     Path("/home/ubuntu/Projects/megablocks-moe/build"), "megablocks")
+
+    model = megablocks.layers.MegaBlocksMoeMLP()
+
+    # Use a namedtuple class as a lightweight attribute container for the
+    # expert weights; the attributes are assigned on the class object below.
+    model.experts = namedtuple(
+        "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
+    )
+
+    # Use loaded router weights for consistency
+    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
+    with torch.no_grad():
+        model.router.weight.copy_(router_weight)
+        model.router.bias.copy_(router_bias)
+
+    # Attach loaded expert weights and MoE hyperparameters to the container
+    e = model.experts
+    e.alpha = 1.702  # gate activation constant: sigmoid(1.702 * x) approximates GELU
+    e.capacity_factor = 4  # per-expert token capacity multiplier
+    e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
+    e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
+    e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
+    e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
+    e.hidden_size = HIDDEN_SIZE
+
+    # Log weight statistics for comparison
+    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
+    print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
+    print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
+
+    return model
+
+# Create a wrapper to match the interface of other implementations
+class MegaBlocksMoEWrapper(nn.Module):
+    def __init__(self, megablocks_model):
+        super().__init__()
+        self.model = megablocks_model
+
+    def forward(self, hidden_states):
+        # MegaBlocks expects input shaped (batch, seq_len, hidden_dim) and
+        # already returns routing weights alongside the output, matching the
+        # (output, routing_weights) interface of the other implementations.
+        output, routing_weights = self.model(hidden_states)
+        return output, routing_weights
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== MegaBlocks Implementation ===")
+# Build MegaBlocks model with loaded weights
+megablocks_model = build_megablocks_model(device)
+model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
+
+# Generate the same input as other implementations
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json") as bench:
+    output, stats = bench(model, x)
+    print(f"\nOutput sum: {output[0].sum().item():.6f}")
+
Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73 +Loaded shared weights from artifacts +Router weight sum: 12.588732 +Gate/up sum: 1026.601807 +Down sum: 206.729263 + +=== MegaBlocks Implementation === +[MegaBlocks] Router weight sum: 12.588732 +[MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807 +[MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340 + +┌─ Benchmark Configuration ─────────────────────────────┐ +│ Warmup: 10 Iters: 50 │ +│ Tokens: 100 │ +└────────────────────────────────────────────────────────┘ + +Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 + +Warming up (10 iterations)... +Benchmarking (50 iterations)... + Progress: 20% complete (avg: 0.867 ms) + Progress: 40% complete (avg: 0.853 ms) + Progress: 60% complete (avg: 1.181 ms) + Progress: 80% complete (avg: 3.026 ms) + +Output tensors: + Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 + Auxiliary: shape=(100, 4), dtype=torch.float32, device=cuda:0, range=[0.220910, 0.294473], mean=0.250000, std=0.010777, norm=5.004632 + +━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ +Iterations: 50 + +Latency Statistics: + Average: 4.133 ms + Min: 0.823 ms + Max: 8.589 ms + Std Dev: 3.781 ms + +Percentiles: + P50 (median): 0.864 ms + P95: 8.579 ms + P99: 8.589 ms + +Throughput: + Tokens/sec: 24194.9 + Std Dev: 52511.7 +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +Saved benchmark results to megablocks_results.json + +Output sum: -0.597249 +
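A note on the weight wiring in the cell above: `MegaBlocksMoeMLP` reads the expert tensors from attributes of `model.experts`, so any attribute container works; the cell uses a `namedtuple` class purely as a namespace. As an illustration only (not what the cell runs), `types.SimpleNamespace` would serve the same role:

    from types import SimpleNamespace
    import torch

    # Hypothetical alternative to the namedtuple-class namespace used above;
    # gate_up_proj, down_proj, etc. are the same artifacts loaded earlier.
    experts = SimpleNamespace(
        alpha=1.702,
        capacity_factor=4,
        gate_up_proj=torch.nn.Parameter(gate_up_proj.clone().to(device)),
        gate_up_proj_bias=torch.nn.Parameter(gate_up_proj_bias.clone().to(device)),
        down_proj=torch.nn.Parameter(down_proj.clone().to(device)),
        down_proj_bias=torch.nn.Parameter(down_proj_bias.clone().to(device)),
        hidden_size=HIDDEN_SIZE,
    )
    model.experts = experts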

Artifacts:

megablocks_results.json

Performance Visualization


This section reads every saved benchmark result and renders a three-panel comparison chart: average latency, P95 latency, and throughput.
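Each benchmark cell saves its statistics to a small JSON file via `bench_context(..., save_json=...)`. The fields the plotting code below depends on look roughly like this (a sketch inferred from the accesses in the code, illustrated with the MegaBlocks numbers; the files may contain additional keys):

    megablocks_results = {
        "implementation": "megablocks_results",
        "stats": {
            "avg_ms": 4.13,         # mean latency, used for ranking
            "p95_ms": 8.58,         # tail latency
            "tokens_per_s": 24195,  # throughput (optional; .get() defaults to 0)
        },
    }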

Cell: visualization | deps: matplotlib | 3.96s
import json
+import matplotlib.pyplot as plt
+import numpy as np
+from pathlib import Path
+import os
+
+# Discover the upstream artifact directories (one per benchmark cell) from env
+yamoe_dir = os.environ.get('UVNOTE_INPUT_YAMOE_RUN', '.')
+binned_dir = os.environ.get('UVNOTE_INPUT_BINNED_RUN', '.')
+gptoss_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_RUN', '.')
+gptoss_training_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_TRAINING_RUN', '.')
+megablocks_dir = os.environ.get('UVNOTE_INPUT_MEGABLOCKS_RUN', '.')
+
+# Expected per-implementation result files
+result_files = [
+    Path(yamoe_dir) / "yamoe_results.json",
+    Path(binned_dir) / "binned_results.json", 
+    Path(gptoss_dir) / "gptoss_results.json",
+    Path(gptoss_training_dir) / "gptoss_training_results.json",
+    Path(megablocks_dir) / "megablocks_results.json"
+]
+
+# Load all benchmark results
+results = {}
+for file in result_files:
+    if Path(file).exists():
+        with open(file, 'r') as f:
+            data = json.load(f)
+            results[data['implementation']] = data
+        print(f"Loaded {file}")
+    else:
+        print(f"Missing {file}")
+
+if not results:
+    print("No benchmark results found. Run the benchmark cells first.")
+else:
+    # Extract data for plotting
+    implementations = list(results.keys())
+    avg_latencies = [results[impl]['stats']['avg_ms'] for impl in implementations]
+    p95_latencies = [results[impl]['stats']['p95_ms'] for impl in implementations]
+    throughputs = [results[impl]['stats'].get('tokens_per_s', 0) for impl in implementations]
+
+    # Create figure with subplots
+    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
+    fig.suptitle('MoE Implementation Performance Comparison', fontsize=16, fontweight='bold')
+
+    # Colors for each implementation
+    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57'][:len(implementations)]
+
+    # 1. Average Latency Chart
+    bars1 = ax1.bar(implementations, avg_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
+    ax1.set_title('Average Latency', fontweight='bold', fontsize=14)
+    ax1.set_ylabel('Latency (ms)', fontweight='bold')
+    ax1.tick_params(axis='x', rotation=45)
+    ax1.grid(axis='y', alpha=0.3)
+
+    # Add value labels on bars
+    for bar, val in zip(bars1, avg_latencies):
+        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(avg_latencies)*0.01,
+                f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')
+
+    # 2. P95 Latency Chart
+    bars2 = ax2.bar(implementations, p95_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
+    ax2.set_title('95th Percentile Latency', fontweight='bold', fontsize=14)
+    ax2.set_ylabel('Latency (ms)', fontweight='bold')
+    ax2.tick_params(axis='x', rotation=45)
+    ax2.grid(axis='y', alpha=0.3)
+
+    # Add value labels on bars
+    for bar, val in zip(bars2, p95_latencies):
+        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(p95_latencies)*0.01,
+                f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')
+
+    # 3. Throughput Chart
+    bars3 = ax3.bar(implementations, throughputs, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
+    ax3.set_title('Throughput', fontweight='bold', fontsize=14)
+    ax3.set_ylabel('Tokens/sec', fontweight='bold')
+    ax3.tick_params(axis='x', rotation=45)
+    ax3.grid(axis='y', alpha=0.3)
+
+    # Add value labels on bars
+    for bar, val in zip(bars3, throughputs):
+        if val > 0:  # Only show label if throughput was calculated
+            ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(throughputs)*0.01,
+                    f'{val:.0f}', ha='center', va='bottom', fontweight='bold')
+
+    plt.tight_layout()
+    plt.savefig("moe_performance_comparison.png", dpi=300)
+
+    # Print summary table
+    print("\nPerformance Summary:")
+    print(f"{'Implementation':<30} {'Avg (ms)':<12} {'P95 (ms)':<12} {'Tokens/sec':<12} {'Relative Speed':<15}")
+    print("-"*80)
+
+    # Sort by average latency for relative speed calculation
+    sorted_results = sorted(results.items(), key=lambda x: x[1]['stats']['avg_ms'])
+    fastest_latency = sorted_results[0][1]['stats']['avg_ms']
+
+    for impl, data in sorted_results:
+        avg_ms = data['stats']['avg_ms']
+        p95_ms = data['stats']['p95_ms']
+        tokens_s = data['stats'].get('tokens_per_s', 0)
+        relative_speed = fastest_latency / avg_ms
+
+        print(f"{impl:<30} {avg_ms:>8.2f}    {p95_ms:>8.2f}    {tokens_s:>8.0f}      {relative_speed:>6.2f}x")
+
+    print(f"\nFastest: {sorted_results[0][0]} ({sorted_results[0][1]['stats']['avg_ms']:.2f}ms avg)")
+    if len(sorted_results) > 1:
+        print(f"Slowest: {sorted_results[-1][0]} ({sorted_results[-1][1]['stats']['avg_ms']:.2f}ms avg)")
+        speedup = sorted_results[-1][1]['stats']['avg_ms'] / sorted_results[0][1]['stats']['avg_ms']
+        print(f"Max Speedup: {speedup:.1f}x")
+
+ +
Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/c5c8a351e1080ea89737c25df783e5c81cd76df0f2b017cedfd813e3bdf2f9f9/yamoe_results.json +Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/af01d090b967f1cb05cacea7795553418933b27fc2f188da52f7c4642e456c24/binned_results.json +Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/cf359ebbdbfd10241ce11898ee298eefd5da768c42d502b034caf3ba5b16aed6/gptoss_results.json +Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/28eb2a85c2dc94e627a0c6373b55120bd67c549ef80cd5b5e94ae756ecd11aff/gptoss_training_results.json +Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/a712c225c474c8776a91d23a96a2d4dd5dde0716ed16f6eb0dce9d92b65e06b8/megablocks_results.json + +Performance Summary: +Implementation Avg (ms) P95 (ms) Tokens/sec Relative Speed +-------------------------------------------------------------------------------- +megablocks_results 4.13 8.58 24195 1.00x +yamoe_results 8.63 8.65 11587 0.48x +gptoss_results 47.14 47.80 2122 0.09x +gptoss_training_results 48.63 49.35 2056 0.08x +binned_results 105.62 107.73 947 0.04x + +Fastest: megablocks_results (4.13ms avg) +Slowest: binned_results (105.62ms avg) +Max Speedup: 25.6x +
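One caveat visible in the numbers above: the MegaBlocks run is bimodal (median 0.86 ms but P95 8.58 ms), so its 4.13 ms average is dominated by a slow tail, and ranking by median would widen its lead. A sketch of re-ranking by median, assuming the saved stats include a `p50_ms` field as the per-cell printouts suggest:

    import json
    from pathlib import Path

    # Re-rank the saved results by median latency instead of the mean.
    # Assumes each stats dict records "p50_ms"; adjust if the schema differs.
    stats = {}
    for path in Path(".").glob("*_results.json"):
        data = json.loads(path.read_text())
        stats[data["implementation"]] = data["stats"]

    for impl, s in sorted(stats.items(), key=lambda kv: kv[1]["p50_ms"]):
        print(f"{impl:<30} p50={s['p50_ms']:>8.3f} ms   avg={s['avg_ms']:>8.3f} ms")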

Artifacts:

moe_performance_comparison.png