# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch

from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Gather tokens into per-expert bins of fixed capacity.

    x: [num_tokens, H]; indices: sorted flattened (token, slot) positions;
    bins: [E] inclusive cumulative token counts. Returns [E, expert_capacity, H].
    """
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)  # assignments beyond capacity are dropped
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k  # recover the source token index
            out[e, i] = x[tok]
    return out

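# Illustration (not part of the benchmark itself): with bins = [2, 3, 4] and
# expert_capacity = 1, expert 0 keeps only the first of its two sorted
# assignments; the overflow is silently dropped, which is the usual
# capacity-constrained MoE behavior, and binned_scatter below leaves the
# dropped (token, slot) rows at zero.
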
def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Scatter expert outputs back to token order, applying routing weights."""
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]  # flattened (token, slot)
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)  # combine the top-k expert contributions per token

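# Decoding example: router_indices.flatten() lays out the [B*S, K] index
# matrix row-major, so flat_pos = tok * top_k + slot; e.g. with top_k = 2,
# flat_pos = 5 refers to token 2, slot 1.
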
def sort_tokens_by_expert(router_indices, num_experts):
    """Sort flattened (token, slot) positions by their assigned expert."""
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert

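# Worked example: router_indices = [[1, 0], [0, 2]] with num_experts = 3
# flattens to [1, 0, 0, 2]; sorting gives sorted_values = [0, 0, 1, 2] and
# sorted_indices = [1, 2, 0, 3] (up to tie order among equal experts), so
# tokens_per_expert = [2, 1, 1] and bins = [2, 3, 4]: expert e owns sorted
# positions [bins[e - 1], bins[e]), with an implicit bins[-1] = 0.
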
def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[2], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    # fused gate/up projection: even channels are gate, odd channels are up
    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    # sigmoid-approximated GELU gating
    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)  # [B*S, E]
    flat_router = router_indices.view(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]
    return y.view(B, S, H)

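# Alignment sketch for the gather above: with K = 2 and dense weights
# [[0.1, 0.7, 0.2]] for a token routed to experts [1, 0], torch.gather
# selects [0.7, 0.1], so selected[flat_pos] is exactly the weight that
# binned_scatter applies to that (token, slot) pair.
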
def binned_torch_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    Binned PyTorch implementation of an OpenAI-style MoE.

    Sorts tokens by expert assignment so each expert's tokens can be
    processed with batched matmuls instead of per-token dispatch.
    """
    B, S = hidden_states.shape[0], hidden_states.shape[1]
    K = router_indices.shape[1]
    # Cap each expert at twice the average load (B * S * K / E assignments)
    # so moderate routing imbalance does not drop tokens.
    expert_capacity = (B * S * K * 2) // routing_weights.shape[2]
    return binned_experts_ref(
        hidden_states,
        router_indices,
        routing_weights,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        expert_capacity,
    )

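# Capacity arithmetic, for illustration: with B * S = 128 tokens, K = 4, and
# E = 32 experts, the average load is 128 * 4 / 32 = 16 assignments per
# expert, so expert_capacity = 32; any expert drawing more than 32
# assignments has the overflow dropped by binned_gather.
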
run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="binned_torch",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=binned_torch_openai_moe,
    dtype="float32",
)
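
# Standalone sanity check (a hedged sketch; the tensor shapes below are
# assumptions read off how this file consumes its arguments, not the
# harness's contract). Kept commented out so it does not run alongside the
# benchmark:
#
#   B, S, H, E, K = 2, 4, 16, 4, 2
#   hidden = torch.randn(B, S, H)
#   dense = torch.softmax(torch.randn(B, S, E), dim=-1)      # [B, S, E]
#   idx = dense.view(-1, E).topk(K, dim=-1).indices          # [B*S, K]
#   w_up, b_up = torch.randn(E, H, 2 * H), torch.randn(E, 2 * H)
#   w_dn, b_dn = torch.randn(E, H, H), torch.randn(E, H)
#   out = binned_torch_openai_moe(hidden, idx, dense, w_up, b_up, w_dn, b_dn)
#   assert out.shape == (B, S, H)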