Kernels
danieldk HF Staff commited on
Commit
9396a12
·
verified ·
1 Parent(s): 8108956

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +233 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from collections import namedtuple
4
+
5
+ from kernels.benchmark import Benchmark
6
+
7
+
8
def moe_mlp_reference(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    gate_up_proj: torch.Tensor,
    gate_up_proj_bias: torch.Tensor,
    down_proj: torch.Tensor,
    down_proj_bias: torch.Tensor,
    top_k: int = 4,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference implementation of the MoE MLP.

    Routes each token to its ``top_k`` experts (softmax over the top-k
    router logits), runs a gated MLP per expert with interleaved
    gate/up columns and a clamped SwiGLU-style activation, and scatters
    the weighted expert outputs back per token.

    Returns:
        A tuple of (output reshaped like ``x``, per-token routing weights).
    """
    orig_shape = x.shape
    d_model = x.shape[-1]
    n_experts = router_weight.shape[0]

    # Collapse leading dims to a flat (num_tokens, d_model) view.
    tokens = x.view(-1, d_model)

    # Router: top-k logits per token, normalized with softmax.
    router_logits = F.linear(tokens, router_weight, router_bias)
    topk_vals, topk_idx = torch.topk(router_logits, top_k, dim=-1)
    routing_weights = F.softmax(topk_vals, dim=-1)

    accum = torch.zeros_like(tokens)

    # Build a (num_experts, top_k, num_tokens) assignment mask and find
    # which experts received at least one token.
    with torch.no_grad():
        assign_mask = F.one_hot(topk_idx, num_classes=n_experts)
        assign_mask = assign_mask.permute(2, 1, 0)
        active_experts = (assign_mask.sum(dim=(-1, -2)) > 0).nonzero()

    for hit in active_experts:
        e = hit[0]
        with torch.no_grad():
            slot_idx, tok_idx = torch.where(assign_mask[e])

        selected = tokens[tok_idx]

        # Fused up projection; even columns gate, odd columns are "up".
        fused = selected @ gate_up_proj[e] + gate_up_proj_bias[e]
        gate = fused[..., ::2].clamp(min=None, max=limit)
        up = fused[..., 1::2].clamp(min=-limit, max=limit)

        # Clamped SwiGLU-style activation.
        act = gate * torch.sigmoid(gate * alpha)
        expert_out = ((up + 1) * act) @ down_proj[e] + down_proj_bias[e]

        # Scale by this expert's routing weight at the token's top-k slot.
        scale = routing_weights[tok_idx, slot_idx][:, None]
        accum.index_add_(0, tok_idx, (expert_out * scale).to(tokens.dtype))

    return accum.view(orig_shape), routing_weights
76
+
77
+
78
class MegaBlocksMoeBenchmark(Benchmark):
    """Benchmark for the MegaBlocks MoE MLP kernel.

    Two scenarios share one model/weight construction path:
    ``base`` (8 tokens) and ``large`` (512 tokens, with a capacity
    factor to avoid token dropping). Both are verified against the
    pure-PyTorch ``moe_mlp_reference``.
    """

    seed: int = 42

    def _setup_common(
        self, batch: int, seq: int, capacity_factor: float | None = None
    ) -> None:
        """Build weights, input, and the kernel model for one scenario.

        Args:
            batch: Batch dimension of the input tensor.
            seq: Sequence dimension of the input tensor.
            capacity_factor: If given, added as an extra field on the
                experts namedtuple (used by the large config to avoid
                token dropping).
        """
        # Config matching readme_example.py.
        ne, hs, isz = 128, 1152, 3072

        # Router weights (kaiming-initialized) and zero bias.
        self.router_weight = torch.randn(
            ne, hs, device=self.device, dtype=torch.float32
        )
        torch.nn.init.kaiming_uniform_(self.router_weight)
        self.router_bias = torch.zeros(ne, device=self.device, dtype=torch.float32)

        # Expert weights, scaled down to keep activations well-behaved.
        self.gate_up_proj = (
            torch.randn(ne, hs, isz, device=self.device, dtype=torch.float32) * 0.02
        )
        self.gate_up_proj_bias = torch.zeros(
            ne, isz, device=self.device, dtype=torch.float32
        )
        self.down_proj = (
            torch.randn(ne, isz // 2, hs, device=self.device, dtype=torch.float32)
            * 0.02
        )
        self.down_proj_bias = torch.zeros(
            ne, hs, device=self.device, dtype=torch.float32
        )

        # Input
        self.x = (
            torch.randn(seq, batch, hs, device=self.device, dtype=torch.float32) * 0.1
        )

        # Model under test: kernel MoE layer with the same router weights.
        self.model = self.kernel.layers.MegaBlocksMoeMLP()
        self.model.router = torch.nn.Linear(hs, ne, device=self.device)
        self.model.router.weight.data = self.router_weight.clone()
        self.model.router.bias.data = self.router_bias.clone()

        # Experts are attached as a namedtuple; the large config carries
        # an extra capacity_factor field.
        fields = [
            "gate_up_proj",
            "gate_up_proj_bias",
            "down_proj",
            "down_proj_bias",
            "hidden_size",
            "num_experts",
        ]
        values = {
            "gate_up_proj": torch.nn.Parameter(self.gate_up_proj.clone()),
            "gate_up_proj_bias": torch.nn.Parameter(self.gate_up_proj_bias.clone()),
            "down_proj": torch.nn.Parameter(self.down_proj.clone()),
            "down_proj_bias": torch.nn.Parameter(self.down_proj_bias.clone()),
            "hidden_size": hs,
            "num_experts": ne,
        }
        if capacity_factor is not None:
            fields.append("capacity_factor")
            values["capacity_factor"] = capacity_factor
        Experts = namedtuple("Experts", fields)
        self.model.experts = Experts(**values)

        self.out = torch.empty(seq, batch, hs, device=self.device, dtype=torch.float32)

    def _reference(self) -> torch.Tensor:
        """Run the pure-PyTorch reference on the current weights/input."""
        ref_out, _ = moe_mlp_reference(
            self.x,
            self.router_weight,
            self.router_bias,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            top_k=4,
        )
        return ref_out

    def setup(self):
        # Config matching readme_example.py: 8 tokens.
        self._setup_common(batch=8, seq=1)

    def benchmark_base(self):
        self.out, self.expert_weights = self.model(self.x)

    def verify_base(self) -> torch.Tensor:
        return self._reference()

    def setup_large(self):
        # Larger config with more tokens; higher capacity to avoid
        # token dropping.
        self._setup_common(batch=32, seq=16, capacity_factor=4.0)

    def benchmark_large(self):
        self.out, self.expert_weights = self.model(self.x)

    def verify_large(self) -> torch.Tensor:
        return self._reference()