diff --git "a/moe_benchmarks/megablocks_yamoe/torch_profile.html" "b/moe_benchmarks/megablocks_yamoe/torch_profile.html"
--- "a/moe_benchmarks/megablocks_yamoe/torch_profile.html"
+++ "b/moe_benchmarks/megablocks_yamoe/torch_profile.html"
@@ -3708,7 +3708,7 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
This section runs the Yamoe MoE implementation with optimized Triton kernels.
-This section runs the binned implementation that manually handles token gathering/scattering.
+import torch
+from torch import nn
+from torch.nn import functional as F
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
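+# Gather each token's hidden state into a fixed-size per-expert buffer.
+# `bins` holds cumulative token counts per expert, so expert e owns the
+# slice [bins[e-1], bins[e]) of the sorted index list; tokens beyond
+# `expert_capacity` are dropped.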
+def binned_gather(x, indices, bins, expert_capacity, top_k):
+ E, H = bins.shape[0], x.shape[1]
+ out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
+ for e in range(E):
+ start = 0 if e == 0 else bins[e - 1]
+ end = bins[e]
+ n = min(end - start, expert_capacity)
+ for i in range(n):
+ flat_pos = indices[start + i]
+ tok = flat_pos // top_k
+ out[e, i] = x[tok]
+ return out
+
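+# Inverse of binned_gather: copy each expert's outputs back to the owning
+# token's (token, slot) position, scaled by its routing weight, then sum
+# over the top_k slots to produce one vector per token.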
+def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
+ E, C, H = x.shape
+ N = indices.shape[0] // top_k
+ out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
+ for e in range(E):
+ start = 0 if e == 0 else bins[e - 1]
+ end = bins[e]
+ n = end - start
+ if n == 0:
+ continue
+ take = min(n, expert_capacity)
+ for i in range(take):
+ flat_pos = indices[start + i]
+ tok = flat_pos // top_k
+ slot = flat_pos % top_k
+ scale = weights[flat_pos] if weights is not None else 1.0
+ out[tok, slot] = x[e, i] * scale
+ return out.sum(dim=1)
+
+def sort_tokens_by_expert(router_indices, num_experts):
+ flat_indices = router_indices.flatten()
+ sorted_values, sorted_indices = torch.sort(flat_indices)
+ tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
+ bins = torch.cumsum(tokens_per_expert, dim=0)
+ return sorted_indices, sorted_values, bins, tokens_per_expert
+
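+# Worked example (illustrative): with top_k=2, num_experts=3 and
+# router_indices = [[0, 2], [1, 0]], flat_indices is [0, 2, 1, 0];
+# sorting gives values [0, 0, 1, 2] at positions [0, 3, 2, 1],
+# tokens_per_expert is [2, 1, 1], and bins = cumsum = [2, 3, 4].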
+def binned_experts_ref(
+ hidden_states,
+ router_indices,
+ routing_weights,
+ gate_up_proj,
+ gate_up_proj_bias,
+ down_proj,
+ down_proj_bias,
+ expert_capacity,
+):
+ B, S, H = hidden_states.shape
+ E, K = routing_weights.shape[1], router_indices.shape[1]
+
+ indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
+ x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
+
+ gate_up = torch.bmm(x, gate_up_proj)
+ gate_up += gate_up_proj_bias[..., None, :]
+
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+
+ # clamp to limit
+ limit = 7.0
+ gate = gate.clamp(min=None, max=limit)
+ up = up.clamp(min=-limit, max=limit)
+
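+ # x * sigmoid(1.702 * x) is the sigmoid approximation of GELU; the
+ # (up + 1) * glu form matches the GPT-OSS gated activation used below.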
+ glu = gate * torch.sigmoid(gate * 1.702)
+ x = (up + 1) * glu
+ x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
+
+ # build routing weights aligned to (token, slot)
+ flat_dense = routing_weights.view(-1, E)
+ flat_router = router_indices.view(-1, K)
+ selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
+
+ # scatter back
+ y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
+
+ return y.view(B, S, H)
+
+class BinnedRouter(nn.Module):
+ def __init__(self, router_weight, router_bias):
+ super().__init__()
+ self.top_k = TOP_K
+ self.num_experts = NUM_EXPERTS
+ self.hidden_dim = HIDDEN_SIZE
+ self.weight = nn.Parameter(router_weight.clone())
+ self.bias = nn.Parameter(router_bias.clone())
+
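+ # Returns a dense (num_tokens, num_experts) score matrix that is zero
+ # outside each token's top-k experts, plus the top-k expert indices.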
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight, self.bias)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
+ router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+ return router_scores, router_indices
+
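+# Integer ceiling division, e.g. ceil_div(7, 3) == 3.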
+def ceil_div(a, b):
+ return (a + b - 1) // b
+
+class BinnedMoEMLP(nn.Module):
+ def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+ super().__init__()
+ self.router = BinnedRouter(router_weight, router_bias)
+ self.num_experts = NUM_EXPERTS
+ self.hidden_size = HIDDEN_SIZE
+ self.top_k = TOP_K
+
+ # Expert weights - use the loaded weights
+ self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
+ self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
+ self.down_proj = nn.Parameter(down_proj.clone())
+ self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
+
+ def forward(self, hidden_states):
+ router_scores, router_indices = self.router(hidden_states)
+ # Size capacity from the total token count (batch * seq), not just the
+ # batch dimension, so the gather does not silently drop routed tokens.
+ num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
+ expert_capacity = ceil_div(num_tokens * self.top_k, self.num_experts)
+
+ output = binned_experts_ref(
+ hidden_states,
+ router_indices,
+ router_scores,
+ self.gate_up_proj,
+ self.gate_up_proj_bias,
+ self.down_proj,
+ self.down_proj_bias,
+ expert_capacity,
+ )
+
+ return output, router_scores
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== Binned Implementation ===")
+# Initialize model with loaded weights
+model = BinnedMoEMLP(
+ router_weight.to(device),
+ router_bias.to(device),
+ gate_up_proj.to(device),
+ gate_up_proj_bias.to(device),
+ down_proj.to(device),
+ down_proj_bias.to(device)
+).to(device=device)
+
+print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
+print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
+print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
+
+# Generate the same input as Yamoe
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
+ output, stats = bench(model, x)
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
+This section runs the GPT-OSS MoE implementation with manual expert loop handling.
+import torch
+from torch import nn
+from torch.nn import functional as F
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+class GptOssRouter(nn.Module):
+ def __init__(self, router_weight, router_bias):
+ super().__init__()
+ self.top_k = TOP_K
+ self.num_experts = NUM_EXPERTS
+ self.hidden_dim = HIDDEN_SIZE
+ self.weight = nn.Parameter(router_weight.clone())
+ self.bias = nn.Parameter(router_bias.clone())
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight, self.bias)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
+ router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+ return router_scores, router_indices
+
+class GptOssExperts(nn.Module):
+ def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+ super().__init__()
+ self.num_experts = NUM_EXPERTS
+ self.hidden_size = HIDDEN_SIZE
+ self.expert_dim = self.hidden_size
+ self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
+ self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
+ self.down_proj = nn.Parameter(down_proj.clone())
+ self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
+ self.alpha = 1.702
+ self.limit = 7.0
+
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
+ num_experts = routing_weights.shape[1]
+
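+ # Two execution paths: a sparse per-expert loop (CPU or training) that only
+ # touches routed tokens, and a dense GPU path that runs every token through
+ # every expert via bmm and lets the routing weights zero out unused outputs.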
+ if hidden_states.device.type == "cpu" or self.training:
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+ for expert_idx in expert_hit[:, 0]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ gated_output = (up + 1) * glu
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+ weighted_output = out * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ else:
+ hidden_states = hidden_states.repeat(num_experts, 1)
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+ next_states = next_states + self.down_proj_bias[..., None, :]
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+ next_states = next_states.sum(dim=0)
+ return next_states
+
+class GptOssMoEMLP(nn.Module):
+ def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+ super().__init__()
+ self.router = GptOssRouter(router_weight, router_bias)
+ self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
+
+ def forward(self, hidden_states):
+ router_scores, router_indices = self.router(hidden_states)
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+ return routed_out, router_scores
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== GPT-OSS Implementation ===")
+# Initialize model with loaded weights
+model = GptOssMoEMLP(
+ router_weight.to(device),
+ router_bias.to(device),
+ gate_up_proj.to(device),
+ gate_up_proj_bias.to(device),
+ down_proj.to(device),
+ down_proj_bias.to(device)
+).to(device=device)
+
+print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
+print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
+print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
+
+# Generate the same input as other implementations
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
+ output, stats = bench(model, x)
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
+This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.
+import torch
+from torch import nn
+from torch.nn import functional as F
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+class GptOssTrainingRouter(nn.Module):
+ def __init__(self, router_weight, router_bias):
+ super().__init__()
+ self.top_k = TOP_K
+ self.num_experts = NUM_EXPERTS
+ self.hidden_dim = HIDDEN_SIZE
+ self.weight = nn.Parameter(router_weight.clone())
+ self.bias = nn.Parameter(router_bias.clone())
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight, self.bias)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
+ router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+ return router_scores, router_indices
+
+class GptOssTrainingExperts(nn.Module):
+ def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+ super().__init__()
+ self.num_experts = NUM_EXPERTS
+ self.hidden_size = HIDDEN_SIZE
+ self.expert_dim = self.hidden_size
+ self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
+ self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
+ self.down_proj = nn.Parameter(down_proj.clone())
+ self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
+ self.alpha = 1.702
+ self.limit = 7.0
+
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
+ num_experts = routing_weights.shape[1]
+
+ # Force training mode path (expert loop instead of batched)
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+ with torch.no_grad():
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+ expert_mask = expert_mask.permute(2, 1, 0)
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+ for expert_idx in expert_hit[:, 0]:
+ with torch.no_grad():
+ _, token_idx = torch.where(expert_mask[expert_idx])
+ current_state = hidden_states[token_idx]
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+ gate = gate.clamp(min=None, max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ gated_output = (up + 1) * glu
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+ weighted_output = out * routing_weights[token_idx, expert_idx, None]
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
+ return next_states
+
+class GptOssTrainingMoEMLP(nn.Module):
+ def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+ super().__init__()
+ self.router = GptOssTrainingRouter(router_weight, router_bias)
+ self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
+
+ def forward(self, hidden_states):
+ router_scores, router_indices = self.router(hidden_states)
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+ return routed_out, router_scores
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
+# Initialize model with loaded weights and force training mode
+model = GptOssTrainingMoEMLP(
+ router_weight.to(device),
+ router_bias.to(device),
+ gate_up_proj.to(device),
+ gate_up_proj_bias.to(device),
+ down_proj.to(device),
+ down_proj_bias.to(device)
+).to(device=device)
+
+# Set to training mode to force expert loop path
+model.train()
+
+print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
+print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
+print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
+print(f"Model training mode: {model.training}")
+
+# Generate the same input as other implementations
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
+ output, stats = bench(model, x)
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
+This section runs the MegaBlocks MoE implementation with optimized kernels from the Hugging Face hub.
+import torch
+from torch import nn
+from torch.nn import functional as F
+from kernels import get_kernel
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+from collections import namedtuple
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+print(f"Loading weights from: {data_dir}")
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+def build_megablocks_model(device: torch.device):
+ # Download optimized kernels from the Hugging Face hub
+ megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
+ model = megablocks.layers.MegaBlocksMoeMLP()
+
+ # Create attribute container for expert weights
+ model.experts = namedtuple(
+ "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
+ )
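+ # namedtuple is used here only as a lightweight attribute namespace; the
+ # fields below are set as class attributes rather than instance fields.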
+
+ # Use loaded router weights for consistency
+ model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
+ with torch.no_grad():
+ model.router.weight.copy_(router_weight)
+ model.router.bias.copy_(router_bias)
+
+ # Attach loaded expert weights to the experts container
+ e = model.experts
+ e.alpha = 1.702
+ e.capacity_factor = 32
+ e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
+ e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
+ e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
+ e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
+ e.hidden_size = HIDDEN_SIZE
+
+ # Log weight statistics for comparison
+ print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
+ print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
+ print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
+
+ return model
+
+# Create a wrapper to match the interface of other implementations
+class MegaBlocksMoEWrapper(nn.Module):
+ def __init__(self, megablocks_model):
+ super().__init__()
+ self.model = megablocks_model
+
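+ # Return (output, routing_weights) so the wrapper matches the
+ # (out, scores) interface of the other implementations above.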
+ def forward(self, hidden_states):
+ # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
+ output, dummy_routing_weights = self.model(hidden_states)
+ return output, dummy_routing_weights
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== MegaBlocks Implementation ===")
+# Build MegaBlocks model with loaded weights
+megablocks_model = build_megablocks_model(device)
+model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
+
+# Generate the same input as other implementations
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
+ output, stats = bench(model, x)
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
+This section reads all benchmark results and creates a comprehensive performance comparison chart.
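+A minimal sketch of that aggregation step follows, assuming each results JSON written by bench_context stores an average latency under an "avg_ms" key; the exact schema comes from bench_utils and is an assumption here.
+import json
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+# Result files saved by the sections above; the Yamoe section is assumed to
+# save a comparable JSON alongside them.
+result_files = {
+    "binned": "binned_results.json",
+    "gpt-oss": "gptoss_results.json",
+    "gpt-oss (train)": "gptoss_training_results.json",
+    "megablocks": "megablocks_results.json",
+}
+
+latencies = {}
+for label, fname in result_files.items():
+    path = Path(fname)
+    if not path.exists():
+        continue
+    with path.open() as f:
+        stats = json.load(f)
+    # "avg_ms" is an assumed field name; adjust to the real bench_utils schema.
+    latencies[label] = stats["avg_ms"]
+
+for label, avg_ms in sorted(latencies.items(), key=lambda kv: kv[1]):
+    print(f"{label:>16}: {avg_ms:.3f} ms")
+
+plt.bar(list(latencies.keys()), list(latencies.values()))
+plt.ylabel("Average latency (ms)")
+plt.title("MoE implementation comparison")
+plt.savefig("moe_comparison.png", dpi=150)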