File size: 7,833 Bytes
9396a12 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 | import torch
import torch.nn.functional as F
from collections import namedtuple
from kernels.benchmark import Benchmark
def moe_mlp_reference(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    gate_up_proj: torch.Tensor,
    gate_up_proj_bias: torch.Tensor,
    down_proj: torch.Tensor,
    down_proj_bias: torch.Tensor,
    top_k: int = 4,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for a top-k routed mixture-of-experts MLP.

    Each token is sent to the ``top_k`` experts with the highest router
    logits. Softmax is taken over the selected logits only, giving the
    routing weights. Each expert applies an interleaved gate/up projection,
    clamps both branches, applies ``(up + 1) * gate * sigmoid(alpha * gate)``,
    and then runs a down projection. A token's result is the
    routing-weighted sum of its experts' outputs.

    Args:
        x: Input whose last dimension is the hidden size. All leading
            dimensions are flattened into a single token axis. The input
            does not need to be contiguous.
        router_weight: Router matrix, ``(num_experts, hidden_size)``
            (applied with ``F.linear``).
        router_bias: Router bias, ``(num_experts,)``.
        gate_up_proj: Per-expert fused projection, right-multiplied, so
            ``gate_up_proj[e]`` is ``(hidden_size, 2 * inter)``. Even output
            columns form the gate branch and odd columns the up branch.
        gate_up_proj_bias: Per-expert bias, ``(num_experts, 2 * inter)``.
        down_proj: Per-expert down projection, ``(num_experts, inter, hidden_size)``.
        down_proj_bias: Per-expert bias, ``(num_experts, hidden_size)``.
        top_k: Number of experts each token is routed to.
        alpha: Scale applied to the gate inside the sigmoid.
        limit: Clamp bound. The gate is clamped only from above at ``limit``;
            the up branch is clamped to ``[-limit, limit]``.

    Returns:
        ``(output, routing_weights)``: ``output`` has the same shape as ``x``;
        ``routing_weights`` is ``(num_tokens, top_k)``.
    """
    in_shape = x.shape
    num_experts = router_weight.shape[0]
    hidden_size = x.shape[-1]
    # reshape (rather than view) so non-contiguous inputs, e.g. a transposed
    # (seq, batch, hidden) tensor, are accepted; it copies only when needed.
    hidden_states = x.reshape(-1, hidden_size)

    # Router: keep the top-k logits per token and normalise only over them.
    logits = F.linear(hidden_states, router_weight, router_bias)
    expert_logits, router_indices = torch.topk(logits, top_k, dim=-1)
    routing_weights = F.softmax(expert_logits, dim=-1)

    next_states = torch.zeros_like(hidden_states)

    with torch.no_grad():
        # one_hot gives (num_tokens, top_k, num_experts); permute so that
        # expert_mask[e] is a (top_k, num_tokens) assignment table.
        expert_mask = F.one_hot(router_indices, num_classes=num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)
        # Only visit experts that received at least one token.
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

    for expert_idx in expert_hit:
        expert_idx = expert_idx[0]
        with torch.no_grad():
            # For every assignment to this expert: which top-k slot and token.
            top_k_idx, token_idx = torch.where(expert_mask[expert_idx])
        current_state = hidden_states[token_idx]

        # Fused up projection; gate and up are interleaved column-wise.
        gate_up = (
            current_state @ gate_up_proj[expert_idx] + gate_up_proj_bias[expert_idx]
        )
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]

        # The gate is only bounded above; up is bounded on both sides.
        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)

        # SwiGLU-like activation.
        glu = gate * torch.sigmoid(gate * alpha)
        gated_output = (up + 1) * glu

        out = gated_output @ down_proj[expert_idx] + down_proj_bias[expert_idx]

        # Scale by the routing weight stored at the matching top-k position.
        weights_for_expert = routing_weights[token_idx, top_k_idx]
        weighted_output = out * weights_for_expert[:, None]
        next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))

    return next_states.view(in_shape), routing_weights
class MegaBlocksMoeBenchmark(Benchmark):
    """Benchmark ``MegaBlocksMoeMLP`` against ``moe_mlp_reference``.

    Workloads:
        base:  seq=1,  batch=8   (8 tokens)
        large: seq=16, batch=32  (512 tokens); the experts container also
               carries ``capacity_factor=4.0``.

    Both workloads use 128 experts, hidden size 1152 and a fused gate/up
    width of 3072. All tensors are float32 and are created on ``self.device``.
    """

    seed: int = 42

    def _build_workload(self, batch: int, seq: int, capacity_factor=None) -> None:
        """Create weights, input, the kernel model and the output buffer.

        Args:
            batch: Batch dimension of the input ``(seq, batch, hidden)``.
            seq: Sequence dimension of the input.
            capacity_factor: When not ``None``, added as a ``capacity_factor``
                field on the experts namedtuple. Otherwise that field is
                omitted, as in the base workload.

        The order of RNG draws below is significant: each tensor depends on
        every draw made before it.
        """
        # Config matching readme_example.py
        ne, hs, isz = 128, 1152, 3072

        # Router. kaiming_uniform_ overwrites the randn values. The randn draw
        # is kept because it advances the RNG that later tensors are drawn from.
        self.router_weight = torch.randn(
            ne, hs, device=self.device, dtype=torch.float32
        )
        torch.nn.init.kaiming_uniform_(self.router_weight)
        self.router_bias = torch.zeros(ne, device=self.device, dtype=torch.float32)

        # Expert weights. gate_up is fused and interleaved, so the down
        # projection consumes isz // 2 features.
        self.gate_up_proj = (
            torch.randn(ne, hs, isz, device=self.device, dtype=torch.float32) * 0.02
        )
        self.gate_up_proj_bias = torch.zeros(
            ne, isz, device=self.device, dtype=torch.float32
        )
        self.down_proj = (
            torch.randn(ne, isz // 2, hs, device=self.device, dtype=torch.float32)
            * 0.02
        )
        self.down_proj_bias = torch.zeros(
            ne, hs, device=self.device, dtype=torch.float32
        )

        # Input
        self.x = (
            torch.randn(seq, batch, hs, device=self.device, dtype=torch.float32) * 0.1
        )

        # Kernel model wired to copies of the reference weights.
        self.model = self.kernel.layers.MegaBlocksMoeMLP()
        self.model.router = torch.nn.Linear(hs, ne, device=self.device)
        self.model.router.weight.data = self.router_weight.clone()
        self.model.router.bias.data = self.router_bias.clone()

        expert_fields = {
            "gate_up_proj": torch.nn.Parameter(self.gate_up_proj.clone()),
            "gate_up_proj_bias": torch.nn.Parameter(self.gate_up_proj_bias.clone()),
            "down_proj": torch.nn.Parameter(self.down_proj.clone()),
            "down_proj_bias": torch.nn.Parameter(self.down_proj_bias.clone()),
            "hidden_size": hs,
            "num_experts": ne,
        }
        if capacity_factor is not None:
            expert_fields["capacity_factor"] = capacity_factor
        Experts = namedtuple("Experts", list(expert_fields))
        self.model.experts = Experts(**expert_fields)

        self.out = torch.empty(seq, batch, hs, device=self.device, dtype=torch.float32)

    def _reference_output(self) -> torch.Tensor:
        """Run the pure-PyTorch reference on the current workload's tensors."""
        ref_out, _ = moe_mlp_reference(
            self.x,
            self.router_weight,
            self.router_bias,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            top_k=4,
        )
        return ref_out

    def setup(self):
        """Base workload: 8 tokens, no explicit capacity factor."""
        self._build_workload(batch=8, seq=1)

    def benchmark_base(self):
        self.out, self.expert_weights = self.model(self.x)

    def verify_base(self) -> torch.Tensor:
        return self._reference_output()

    def setup_large(self):
        """Larger workload: 512 tokens with a generous capacity factor."""
        # Higher capacity to avoid token dropping
        self._build_workload(batch=32, seq=16, capacity_factor=4.0)

    def benchmark_large(self):
        self.out, self.expert_weights = self.model(self.x)

    def verify_large(self) -> torch.Tensor:
        return self._reference_output()
|