danieldk (HF Staff) committed
Commit b6db418 · 1 Parent(s): e0e5abe

Revert "Build uploaded using `kernels`."

This reverts commit e0e5abe72d54407c849a04974181545bf03eb02a.

build/torch-cuda/__init__.py DELETED
@@ -1,13 +0,0 @@
- from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts
- from . import parallel_experts
- from . import kernels
- from . import layers
-
- __all__ = [
-     "flatten_sort_count",
-     "parallel_linear",
-     "ParallelExperts",
-     "parallel_experts",
-     "kernels",
-     "layers"
- ]
build/torch-cuda/_ops.py DELETED
@@ -1,8 +0,0 @@
- import torch
- ops = torch.ops._scattermoe_05b9d77
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_scattermoe_05b9d77::{op_name}"
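
Note: the deleted helper simply prepends the build-specific namespace. A minimal sketch of what it produces (the op name "gather" is purely illustrative, not an op this build defines):

    qualified = add_op_namespace_prefix("gather")
    assert qualified == "_scattermoe_05b9d77::gather"
    # Strings of this form are how torch.library APIs refer to a registered op.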
build/torch-cuda/kernels/__init__.py DELETED
@@ -1,3 +0,0 @@
- from . import ops
-
- __all__ = ["ops"]
build/torch-cuda/kernels/ops.py DELETED
@@ -1,457 +0,0 @@
- import torch
- import triton
- import triton.language as tl
- from typing import Optional
-
- BLOCK_M = 128
- ALLOW_TF32 = True
-
-
-
- @triton.jit
- def _compute_expert_block(
-     E_idx, E_mask,
-     M_in_idx,
-     N_block, N_mask,
-     X_ptr, stride_xm, stride_xk,
-     W_ptr, stride_we, stride_wk, stride_wn,
-     K,
-     acc,
-     no_k_mask,
-     BLOCK_K,
-     allow_tf32=True,
- ):
-
-     K_block = tl.arange(0, BLOCK_K)
-     X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
-     W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
-     iters = tl.cdiv(K, BLOCK_K)
-
-     for K_block_id in range(iters):
-         if no_k_mask:
-             x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
-             w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
-         else:
-             K_mask = (K_block_id * BLOCK_K + K_block) < K
-             x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
-             w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
-
-         X_blk_ptrs += BLOCK_K * stride_xk
-         W_blk_ptrs += BLOCK_K * stride_wk
-         acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
-     return acc
-
-
- def _scatter2scatter_configs():
-     return [
-         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
-     ]
-
- @triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
-     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
- })
- @triton.jit
- def _scatter2scatter(
-     X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr,
-     W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr,
-     Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr,
-     B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr,
-     grouped_idx_ptr, expert_idxs_ptr,
-     # block_start_idx_ptr,
-     FAN_OUT: tl.constexpr,
-     M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
-     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
-     # OUT_M,
-     allow_tf32: tl.constexpr,
-     x_grouped: tl.constexpr, y_grouped: tl.constexpr,
-     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
- ):
-     pid = tl.program_id(axis=0)
-
-     N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
-     M_block_id = pid // N_BLOCK_COUNT
-     N_block_id = pid % N_BLOCK_COUNT
-
-     M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
-     N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-     N_mask = N_block < N
-     M_boundary_mask = M_block < (FAN_OUT * M)
-     E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
-
-     no_k_mask = K % BLOCK_K == 0
-
-     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
-     E_first_idx = tl.min(E_idxs)
-     E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
-     M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
-     for E_idx in range(E_first_idx, E_last_idx + 1):
-         E_mask = E_idxs == E_idx
-         E_M_idx = M_idx
-         if x_grouped:
-             M_in_idx = M_block
-         else:
-             M_in_idx = E_M_idx // FAN_OUT
-         acc = _compute_expert_block(
-             E_idx, E_mask,
-             M_in_idx, N_block, N_mask,
-             X_ptr, stride_xm, stride_xk,
-             W_ptr, stride_we, stride_wk, stride_wn,
-             K,
-             acc,
-             no_k_mask,
-             BLOCK_K,
-             allow_tf32=allow_tf32,
-         )
-
-     if B_ptr is not None:
-         B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
-         acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
-
-     if y_grouped:
-         M_out_idx = M_block
-     else:
-         M_out_idx = M_idx
-     Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
-     tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
-
- def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
-                     b=None,
-                     x_grouped=False, y_grouped=False,
-                     out=None):
-     assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
-     assert sorted_scattered_idxs.size(0) == X.size(0) * k
-     # Pre-kernel setup
-     y_dim = W.size(-1)
-     L_scattered = sorted_expert_idxs.size(0)
-     if out is None:
-         output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
-     else:
-         assert out.size(0) == L_scattered and out.size(1) == y_dim
-         output = out
-
-     scatter2scatter_compileable(output, W, X, k, sorted_expert_idxs, sorted_scattered_idxs,
-                                 b, x_grouped, y_grouped)
-     return output
-
-
- @torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
- def scatter2scatter_compileable(
-     output: torch.Tensor,
-     W: torch.Tensor,
-     X: torch.Tensor,
-     k: int,
-     sorted_expert_idxs: torch.Tensor,
-     sorted_scattered_idxs: torch.Tensor,
-     b: Optional[torch.Tensor],
-     x_grouped: bool, y_grouped: bool) -> None:
-     def grid(META):
-         grid_num = (
-             triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"]) *
-             triton.cdiv(META['N'], META['BLOCK_N']),
-         )
-         return grid_num
-
-     if b is None:
-         b = None
-         stride_be = stride_bk = 0
-     else:
-         stride_be, stride_bk = b.stride()
-
-     _scatter2scatter[grid](
-         # X_ptr, stride_xm, stride_xk,
-         X, X.stride(0), X.stride(1),
-         # W_ptr, stride_we, stride_wk, stride_wn,
-         W, W.stride(0), W.stride(1), W.stride(2),
-         # Y_ptr, stride_ym, stride_yn,
-         output, output.stride(0), output.stride(1),
-         # B_ptr, stride_be, stride_bk
-         b, stride_be, stride_bk,
-         grouped_idx_ptr=sorted_scattered_idxs,
-         expert_idxs_ptr=sorted_expert_idxs,
-         # block_start_idx_ptr=padded_block_idxs,
-         FAN_OUT=k,
-         M=X.size(0),
-         K=X.size(1),
-         N=output.size(1), E=W.size(0),
-         BLOCK_M=BLOCK_M,
-         ACC_TYPE=tl.float32,
-         allow_tf32=ALLOW_TF32,
-         x_grouped=x_grouped, y_grouped=y_grouped,
-     )
-
-
- def _config_XtY():
-     return [
-         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
-     ]
-
- def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
-     DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
-     DW = DWt.permute(0, 2, 1)
-     if has_bias:
-         Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
-     else:
-         Db = None
-     groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
-     return DW, Db
-
-
- @torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"})
- def groupXtY_compileable(
-     E: int,
-     DW: torch.Tensor,
-     Db: Optional[torch.Tensor],
-     DY: torch.Tensor,
-     X: torch.Tensor,
-     expert_offsets: torch.Tensor) -> None:
-     def grid(META):
-         grid = (
-             E * triton.cdiv(META['K'], META['BLOCK_K']),
-             triton.cdiv(META['N'], META['BLOCK_N']),
-         )
-         return grid
-
-     if Db is None:
-         stride_dbe = 0
-         stride_dbn = 0
-     else:
-         stride_dbe, stride_dbn = Db.stride()
-
-     _groupXtY[grid](
-         # DY_ptr, stride_dym, stride_dyk,
-         DY, DY.stride(0), DY.stride(1),
-         # X_ptr, stride_xm, stride_xn,
-         X, X.stride(0), X.stride(1),
-         # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-         DW, DW.stride(0), DW.stride(1), DW.stride(2),
-         # Db_ptr, stride_dwe, stride_dbn,
-         Db, stride_dbe, stride_dbn,
-         # expert_offsets_ptr,
-         expert_offsets,
-         # K: tl.constexpr, N: tl.constexpr,
-         M=DY.size(0), N=DY.size(-1), K=X.size(-1),
-         # ACC_TYPE: tl.constexpr,
-         ACC_TYPE=tl.float32,
-         allow_tf32=ALLOW_TF32
-     )
-
-
- @triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
-     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
- })
- @triton.jit
- def _groupXtY(
-     DY_ptr, stride_dym, stride_dyk,
-     X_ptr, stride_xm, stride_xn,
-     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-     Db_ptr, stride_dbe, stride_dbn,
-     expert_offsets_ptr,
-     M, K: tl.constexpr, N: tl.constexpr,
-     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
-     allow_tf32: tl.constexpr,
-     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
- ):
-     pid0 = tl.program_id(axis=0)
-     pid1 = tl.program_id(axis=1)
-     num0 = tl.num_programs(0)
-     num1 = tl.num_programs(1)
-     # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
-     pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
-
-     K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
-     E_idx = pid0 // K_BLOCK_COUNT
-     K_block_id = pid0 % K_BLOCK_COUNT
-     N_block_id = pid1
-
-     if E_idx == 0:
-         start_idx = 0
-     else:
-         start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
-     end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
-
-
-     if end_idx > start_idx:
-         M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
-
-         K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
-         K_mask = K_block < K
-         K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
-
-         N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-         N_mask = N_block < N
-         N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
-
-         M_idxs = M_block
-         xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
-         dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
-         if (Db_ptr is not None) and (K_block_id == 0):
-             _xty_and_bias(
-                 E_idx, start_idx, end_idx,
-                 M_block,
-                 K_block, K_mask, N_block, N_mask,
-                 dy_blk_ptrs, stride_dym,
-                 xt_blk_ptrs, stride_xm,
-                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-                 Db_ptr, stride_dbe, stride_dbn,
-                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-                 allow_tf32, NO_K_MASK, NO_N_MASK,
-                 compute_bias=True
-             )
-         else:
-             _xty_and_bias(
-                 E_idx, start_idx, end_idx,
-                 M_block,
-                 K_block, K_mask, N_block, N_mask,
-                 dy_blk_ptrs, stride_dym,
-                 xt_blk_ptrs, stride_xm,
-                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-                 Db_ptr, stride_dbe, stride_dbn,
-                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-                 allow_tf32, NO_K_MASK, NO_N_MASK,
-                 compute_bias=False
-             )
-
-
- @triton.jit
- def _xty_and_bias(
-     E_idx, start_idx, end_idx,
-     M_block,
-     K_block, K_mask, N_block, N_mask,
-     dy_blk_ptrs, stride_dym,
-     xt_blk_ptrs, stride_xm,
-     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-     Db_ptr, stride_dbe, stride_dbn,
-     BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-     allow_tf32, NO_K_MASK, NO_N_MASK,
-     compute_bias: tl.constexpr
- ):
-
-     if compute_bias:
-         db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
-     else:
-         db_acc = None
-
-     acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
-     iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
-     for i in range(0, iters):
-         M_mask = (i * BLOCK_M + M_block) < end_idx
-         if NO_K_MASK:
-             xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
-         else:
-             xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
-         if NO_N_MASK:
-             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
-         else:
-             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
-
-         acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
-
-         xt_blk_ptrs += BLOCK_M * stride_xm
-         dy_blk_ptrs += BLOCK_M * stride_dym
-
-         if compute_bias:
-             db_acc += tl.sum(dy, axis=0)
-
-     DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
-     acc = acc.to(DW_blk_ptrs.dtype.element_ty)
-     tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
-     if compute_bias:
-         Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
-         tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
-
-
-
- def _config_grouping():
-     return [
-         triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
-         # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
-         # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
-     ]
-
- def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
-     N = sorted_expert_idxs.size(0)
-     K = A.size(1)
-     assert A.size(0) * fan_out == N
-     if out is not None:
-         Y = out
-     else:
-         Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
-     group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
-     return Y
-
-
- @torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
- def group_compileable(
-     A: torch.Tensor,
-     K: int,
-     N: int,
-     Y: torch.Tensor,
-     coeff: torch.Tensor, has_coeff: bool,
-     fan_out: int,
-     sorted_expert_idxs: torch.Tensor) -> None:
-     def grid(META):
-         grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
-         return grid_num
-     _group[grid](
-         # A_ptr, stride_an, stride_ai,
-         A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out,
-         # Y_ptr, stride_yn, stride_yk,
-         Y, Y.stride(0), Y.stride(1),
-         # grouped_idx_ptr,
-         sorted_expert_idxs,
-         # N: tl.constexpr, K: tl.constexpr,
-         N, K
-     )
-
-
- @triton.autotune(configs=_config_grouping(), key=['K'])
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
- })
- @triton.jit
- def _group(
-     src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
-     tgt_ptr, stride_tn, stride_ti,
-     grouped_idx_ptr,
-     N, K: tl.constexpr,
-     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     NO_K_MASK: tl.constexpr
- ):
-     pid = tl.program_id(axis=0)
-
-     N_block_id = pid
-     N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-     N_mask = N_blk < N
-     N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
-     N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
-
-     K_blk = tl.arange(0, BLOCK_K)
-     src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
-     tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
-
-     if has_coeff:
-         c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
-
-     iters = tl.cdiv(K, BLOCK_K)
-     for i in range(0, iters):
-         if NO_K_MASK or i < iters - 1:
-             block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
-             if has_coeff:
-                 block *= c
-             tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
-
-         else:
-             K_mask = (i * BLOCK_K + K_blk) < K
-             mask = N_mask[:, None] & K_mask[None, :]
-             block = tl.load(src_blk_ptrs, mask=mask)
-             if has_coeff:
-                 block *= c
-             tl.store(tgt_blk_ptrs, block, mask=mask)
-         src_blk_ptrs += BLOCK_K * stride_sk
-         tgt_blk_ptrs += BLOCK_K * stride_ti
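
Note: `scatter2scatter` is the grouped-GEMM entry point of this file, and `group` is the gather/scale helper used by the autograd backward pass. A minimal usage sketch, assuming the build is importable as `scattermoe` (tensor sizes, dtype, and the CUDA device are illustrative assumptions, not part of the diff):

    import torch
    from scattermoe.kernels.ops import scatter2scatter, group  # assumed import path

    E, K_dim, N_dim, T, k = 8, 64, 128, 16, 2   # experts, in/out dims, tokens, top-k
    X = torch.randn(T, K_dim, device="cuda")
    W = torch.randn(E, K_dim, N_dim, device="cuda")
    expert_idxs = torch.randint(0, E, (T, k), device="cuda")
    # Sort the flattened assignments, as flatten_sort_count does in parallel_experts.py:
    sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())

    Y = scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k)
    # Y: (T * k, N_dim), one output row per (token, expert) assignment.
    Xg = group(X, sorted_scattered_idxs, fan_out=k)
    # Xg: (T * k, K_dim), input rows replicated and reordered by expert.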
build/torch-cuda/kernels/single.py DELETED
@@ -1,59 +0,0 @@
- import torch
- import triton
- import triton.language as tl
-
- @triton.jit
- def _single2scatter(
-     X_ptr, stride_xm, stride_xk,
-     W_ptr, stride_we, stride_wk, stride_wn,
-     Y_ptr, stride_ym, stride_yn,
-     expert_idxs_ptr,
-     FAN_OUT: tl.constexpr,
-     K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
-     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
- ):
-     pid0 = tl.program_id(axis=0)
-     pid1 = tl.program_id(axis=1)
-
-     N_block_id = pid0
-     if FAN_OUT == 1:
-         in_idx = pid1
-     else:
-         in_idx = 0
-     out_idx = pid1
-
-     K_block = tl.arange(0, BLOCK_K)
-     N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
-     E_idx = tl.load(expert_idxs_ptr + pid1)
-     X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
-     W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
-     acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
-     for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
-         x = tl.load(X_blk_ptrs)
-         w = tl.load(W_blk_ptrs)
-         acc += tl.sum(x * w, axis=0)[None, :]
-         X_blk_ptrs += BLOCK_K * stride_xk
-         W_blk_ptrs += BLOCK_K * stride_wk
-     Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
-     tl.store(Y_blk_ptrs, acc)
-
- def single2scatter(X, W, expert_idxs):
-     E, xdim, ydim = W.size()
-     k = expert_idxs.size(1)
-     assert X.size(0) == k or X.size(0) == 1
-     Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
-     BLOCK_N = 128
-     BLOCK_K = 128
-     grid = ydim // BLOCK_N, k
-     _single2scatter[grid](
-         X, X.stride(0), X.stride(1),
-         W, W.stride(0), W.stride(1), W.stride(2),
-         Y, Y.stride(0), Y.stride(1),
-         expert_idxs,
-         FAN_OUT=Y.size(0) // X.size(0),
-         K=xdim, N=ydim, E=E,
-         BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
-         ACC_TYPE=tl.float32
-     )
-     return Y
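
Note: `_single2scatter` is a fast path for a single input row fanned out to k experts (e.g. batch-1 decoding). A minimal sketch (sizes and the CUDA device are illustrative assumptions; `ydim` should be a multiple of BLOCK_N=128 here, since the grid size uses floor division):

    import torch
    # from scattermoe.kernels.single import single2scatter  # assumed import path

    E, xdim, ydim, k = 8, 128, 256, 2
    X = torch.randn(1, xdim, device="cuda")           # one token, shared by all k experts
    W = torch.randn(E, xdim, ydim, device="cuda")
    expert_idxs = torch.randint(0, E, (1, k), device="cuda")
    Y = single2scatter(X, W, expert_idxs)             # (k, ydim), one row per expert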
build/torch-cuda/layers.py DELETED
@@ -1,52 +0,0 @@
- import torch
- from torch.nn import functional as F
- from torch import nn
-
- from . import parallel_linear, flatten_sort_count
-
- class ScatterMoEGatedMLP(nn.Module):
-     def forward(self, layer_input):
-         """
-         Forward pass of the mixture of experts layer.
-
-         Args:
-             layer_input (Tensor):
-                 Input tensor.
-
-         Returns:
-             Tensor:
-                 Output tensor.
-             Tensor:
-                 Router logits.
-         """
-         bsz, length, emb_size = layer_input.size()
-         layer_input = layer_input.reshape(-1, emb_size)
-         # compute the top_k routing decision
-         router_logits = self.router.layer(layer_input)
-         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-         routing_weights, selected_experts = torch.topk(routing_weights, self.router.top_k, dim=-1)
-         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-         routing_weights = routing_weights.to(layer_input.dtype)
-         sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
-             flatten_sort_count(selected_experts, num_experts=self.router.num_experts)
-
-         # compute experts
-         gates, h = parallel_linear(
-             layer_input, self.input_linear.weight.transpose(2, 1),
-             self.router.top_k,
-             sorted_expert_idxs, sorted_scattered_idxs,
-             expert_offsets,
-             grouped_in=False, grouped_out=True,
-         ).chunk(2, dim=-1)
-         h = self.activation(gates) * h
-         layer_output = parallel_linear(
-             h, self.output_linear.weight.transpose(2, 1),
-             1,
-             sorted_expert_idxs, sorted_scattered_idxs,
-             expert_offsets,
-             grouped_in=True, grouped_out=False,
-             gates=routing_weights
-         )
-         layer_output = layer_output.view(bsz, length, emb_size)
-         return layer_output
-
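
Note: `ScatterMoEGatedMLP` defines only `forward()`; it expects `router`, `input_linear`, `output_linear`, and `activation` to be supplied by the integrating model. A hedged wiring sketch (class shape and sizes are illustrative assumptions; `ParallelExperts` is from the deleted `parallel_experts.py`):

    from torch import nn
    # Assumed imports from this build: ScatterMoEGatedMLP, ParallelExperts

    class Router(nn.Module):
        def __init__(self, emb_size, num_experts, top_k):
            super().__init__()
            self.layer = nn.Linear(emb_size, num_experts, bias=False)
            self.top_k = top_k
            self.num_experts = num_experts

    class MyGatedMLP(ScatterMoEGatedMLP):
        def __init__(self, emb_size, hidden_size, num_experts, top_k):
            super().__init__()
            self.router = Router(emb_size, num_experts, top_k)
            # input_linear emits gate and value halves, hence 2 * hidden_size:
            self.input_linear = ParallelExperts(num_experts, emb_size, 2 * hidden_size)
            self.output_linear = ParallelExperts(num_experts, hidden_size, emb_size)
            self.activation = nn.SiLU()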
build/torch-cuda/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch-cuda/parallel_experts.py DELETED
@@ -1,182 +0,0 @@
- import torch
- import torch.nn as nn
- from . import kernels
- from typing import Optional
-
- @torch.library.custom_op("scattermoe::bincount", mutates_args={})
- def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
-     return x.bincount(minlength=minlength)
-
- @compileable_bincount.register_fake
- def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
-     return torch.empty(minlength, dtype=torch.long, device=x.device)
-
- @torch.compile
- def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
-     with torch.no_grad():
-         flattened_expert_idxs = expert_idxs.flatten()
-         sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
-         expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
-         expert_offsets = expert_counts.cumsum(-1)
-         return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
-
-
-
- class ParallelLinear(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x: torch.Tensor, expert_weights: torch.Tensor, k: int,
-         sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
-         expert_offsets: torch.Tensor,
-         expert_biases: Optional[torch.Tensor]=None,
-         gates: Optional[torch.Tensor]=None,
-         grouped_in: bool =False, grouped_out: bool=False,
-     ):
-         with torch.device(x.device):
-             output = kernels.ops.scatter2scatter(
-                 X=x, W=expert_weights,
-                 b=expert_biases, k=k,
-                 sorted_expert_idxs=sorted_expert_idxs,
-                 sorted_scattered_idxs=sorted_scattered_idxs,
-                 x_grouped=grouped_in, y_grouped=grouped_out
-             )
-             if gates is not None:
-                 output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
-                 output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
-             else:
-                 output_expanded = None
-
-             ctx.save_for_backward(
-                 x, expert_weights,
-                 expert_biases,
-                 sorted_expert_idxs,
-                 sorted_scattered_idxs,
-                 expert_offsets,
-                 gates,
-                 output_expanded
-             )
-             ctx.grouped_in = grouped_in
-             ctx.grouped_out = grouped_out
-             ctx.k = k
-             return output
-     @staticmethod
-     def backward(ctx, grad_out: torch.Tensor):
-         with torch.device(grad_out.device):
-             (x, expert_weights, expert_biases,
-              sorted_expert_idxs,
-              sorted_scattered_idxs,
-              expert_offsets,
-              gates, output_expanded) = ctx.saved_tensors
-             k = ctx.k
-             grouped_in = ctx.grouped_in
-             grouped_out = ctx.grouped_out
-             # print("backward")
-
-             if gates is not None:
-                 # calculate gates gradient
-                 # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
-                 d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
-                 gates_flat = gates.flatten()
-                 gate_fan = gates.size(1)
-                 grouped_grad_out = output_expanded.flatten(0, 1)  # reuse expanded buffer later
-             else:
-                 d_gates = None
-                 gates_flat = None
-                 gate_fan = 1
-                 grouped_grad_out = None
-
-             if grouped_out:
-                 grouped_grad_out = grad_out
-             else:
-                 grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
-                                                      fan_out=gate_fan, coeff=gates_flat,
-                                                      out=grouped_grad_out)
-             if grouped_in:
-                 grouped_x = x
-                 d_expanded_input = None
-             else:
-                 grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
-                 d_expanded_input = grouped_x
-
-             d_weights, d_biases = kernels.ops.group_bwd_W(
-                 DY=grouped_grad_out, X=grouped_x,
-                 expert_offsets=expert_offsets,
-                 E=expert_weights.size(0),
-                 has_bias=expert_biases is not None
-             )
-
-
-             d_expanded_input = kernels.ops.scatter2scatter(
-                 X=grouped_grad_out, x_grouped=True,
-                 W=expert_weights.permute(0, 2, 1),
-                 sorted_expert_idxs=sorted_expert_idxs,
-                 sorted_scattered_idxs=sorted_scattered_idxs,
-                 k=1,
-                 y_grouped=grouped_in,
-                 out=d_expanded_input  # Reuse grouped_x buffer
-             )
-
-             if k == 1:
-                 d_input = d_expanded_input
-             else:
-                 d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
-             # print("backward end.")
-             return (
-                 # x, expert_weights,
-                 d_input, d_weights,
-                 # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
-                 None, None, None, None,
-                 # bias, gates
-                 d_biases, d_gates,
-                 # grouped_in, grouped_out,
-                 None, None
-             )
-
- def parallel_linear(inputs, expert_weights, k,
-                     sorted_expert_idxs, sorted_scattered_idxs,
-                     expert_offsets,
-                     expert_biases=None,
-                     gates=None, grouped_in=False, grouped_out=False):
-     results = ParallelLinear.apply(inputs, expert_weights, k,
-                                    sorted_expert_idxs, sorted_scattered_idxs,
-                                    expert_offsets,
-                                    expert_biases,
-                                    gates, grouped_in, grouped_out)
-     return results
-
- class ParallelExperts(nn.Module):
-     def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
-         super().__init__()
-         self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
-
-         if bias:
-             self.bias = nn.Parameter(torch.empty(num_experts, output_size))
-         else:
-             self.bias = None
-
-         self.num_experts = num_experts
-         self.input_size = input_size
-         self.output_size = output_size
-         self.reset_parameters()
-
-     def extra_repr(self):
-         return 'num_experts={}, input_size={}, output_size={}'.format(
-             self.num_experts, self.input_size, self.output_size)
-
-     def reset_parameters(self) -> None:
-         nn.init.normal_(self.weight, std=0.02)
-         if self.bias is not None:
-             nn.init.zeros_(self.bias)
-
-     def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
-                 expert_offsets,
-                 gates=None, grouped_in=False, grouped_out=False):
-
-         results = parallel_linear(
-             inputs, self.weight.permute(0, 2, 1), k,
-             sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
-             expert_biases=self.bias,
-             gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
-         )
-         return results
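
Note: `flatten_sort_count` and `ParallelExperts` above are the public API that the package `__init__.py` re-exports. A minimal end-to-end sketch (sizes and the CUDA device are illustrative assumptions):

    import torch
    # from scattermoe import flatten_sort_count, ParallelExperts  # assumed import

    experts = ParallelExperts(num_experts=8, input_size=64, output_size=128).cuda()
    x = torch.randn(16, 64, device="cuda")
    selected = torch.randint(0, 8, (16, 2), device="cuda")    # top-2 routing decisions
    sorted_e, sorted_s, offsets = flatten_sort_count(selected, num_experts=8)
    y = experts(x, 2, sorted_e, sorted_s, offsets)            # (16 * 2, 128), scattered order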
build/torch-cuda/scattermoe/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
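
Note: this shim re-exports the variant's top-level `__init__.py` under a path-hashed module name, so several builds can coexist in `sys.modules`. A small sketch of the resulting behaviour (paths are illustrative):

    m1 = _import_from_path(Path("build/torch-cuda/__init__.py"))
    m2 = _import_from_path(Path("build/torch-rocm/__init__.py"))
    assert m1 is not m2   # distinct module objects, keyed by path hash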
build/torch-rocm/__init__.py DELETED
@@ -1,13 +0,0 @@
- from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts
- from . import parallel_experts
- from . import kernels
- from . import layers
-
- __all__ = [
-     "flatten_sort_count",
-     "parallel_linear",
-     "ParallelExperts",
-     "parallel_experts",
-     "kernels",
-     "layers"
- ]
build/torch-rocm/_ops.py DELETED
@@ -1,8 +0,0 @@
- import torch
- ops = torch.ops._scattermoe_05b9d77
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_scattermoe_05b9d77::{op_name}"
build/torch-rocm/kernels/__init__.py DELETED
@@ -1,3 +0,0 @@
- from . import ops
-
- __all__ = ["ops"]
build/torch-rocm/kernels/ops.py DELETED
@@ -1,457 +0,0 @@
- import torch
- import triton
- import triton.language as tl
- from typing import Optional
-
- BLOCK_M = 128
- ALLOW_TF32 = True
-
-
-
- @triton.jit
- def _compute_expert_block(
-     E_idx, E_mask,
-     M_in_idx,
-     N_block, N_mask,
-     X_ptr, stride_xm, stride_xk,
-     W_ptr, stride_we, stride_wk, stride_wn,
-     K,
-     acc,
-     no_k_mask,
-     BLOCK_K,
-     allow_tf32=True,
- ):
-
-     K_block = tl.arange(0, BLOCK_K)
-     X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
-     W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
-     iters = tl.cdiv(K, BLOCK_K)
-
-     for K_block_id in range(iters):
-         if no_k_mask:
-             x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
-             w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
-         else:
-             K_mask = (K_block_id * BLOCK_K + K_block) < K
-             x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
-             w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
-
-         X_blk_ptrs += BLOCK_K * stride_xk
-         W_blk_ptrs += BLOCK_K * stride_wk
-         acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
-     return acc
-
-
- def _scatter2scatter_configs():
-     return [
-         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
-     ]
-
- @triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
-     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
- })
- @triton.jit
- def _scatter2scatter(
-     X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr,
-     W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr,
-     Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr,
-     B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr,
-     grouped_idx_ptr, expert_idxs_ptr,
-     # block_start_idx_ptr,
-     FAN_OUT: tl.constexpr,
-     M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
-     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
-     # OUT_M,
-     allow_tf32: tl.constexpr,
-     x_grouped: tl.constexpr, y_grouped: tl.constexpr,
-     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
- ):
-     pid = tl.program_id(axis=0)
-
-     N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
-     M_block_id = pid // N_BLOCK_COUNT
-     N_block_id = pid % N_BLOCK_COUNT
-
-     M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
-     N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-     N_mask = N_block < N
-     M_boundary_mask = M_block < (FAN_OUT * M)
-     E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
-
-     no_k_mask = K % BLOCK_K == 0
-
-     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
-     E_first_idx = tl.min(E_idxs)
-     E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
-     M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
-     for E_idx in range(E_first_idx, E_last_idx + 1):
-         E_mask = E_idxs == E_idx
-         E_M_idx = M_idx
-         if x_grouped:
-             M_in_idx = M_block
-         else:
-             M_in_idx = E_M_idx // FAN_OUT
-         acc = _compute_expert_block(
-             E_idx, E_mask,
-             M_in_idx, N_block, N_mask,
-             X_ptr, stride_xm, stride_xk,
-             W_ptr, stride_we, stride_wk, stride_wn,
-             K,
-             acc,
-             no_k_mask,
-             BLOCK_K,
-             allow_tf32=allow_tf32,
-         )
-
-     if B_ptr is not None:
-         B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
-         acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
-
-     if y_grouped:
-         M_out_idx = M_block
-     else:
-         M_out_idx = M_idx
-     Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
-     tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
-
- def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
-                     b=None,
-                     x_grouped=False, y_grouped=False,
-                     out=None):
-     assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
-     assert sorted_scattered_idxs.size(0) == X.size(0) * k
-     # Pre-kernel setup
-     y_dim = W.size(-1)
-     L_scattered = sorted_expert_idxs.size(0)
-     if out is None:
-         output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
-     else:
-         assert out.size(0) == L_scattered and out.size(1) == y_dim
-         output = out
-
-     scatter2scatter_compileable(output, W, X, k, sorted_expert_idxs, sorted_scattered_idxs,
-                                 b, x_grouped, y_grouped)
-     return output
-
-
- @torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
- def scatter2scatter_compileable(
-     output: torch.Tensor,
-     W: torch.Tensor,
-     X: torch.Tensor,
-     k: int,
-     sorted_expert_idxs: torch.Tensor,
-     sorted_scattered_idxs: torch.Tensor,
-     b: Optional[torch.Tensor],
-     x_grouped: bool, y_grouped: bool) -> None:
-     def grid(META):
-         grid_num = (
-             triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"]) *
-             triton.cdiv(META['N'], META['BLOCK_N']),
-         )
-         return grid_num
-
-     if b is None:
-         b = None
-         stride_be = stride_bk = 0
-     else:
-         stride_be, stride_bk = b.stride()
-
-     _scatter2scatter[grid](
-         # X_ptr, stride_xm, stride_xk,
-         X, X.stride(0), X.stride(1),
-         # W_ptr, stride_we, stride_wk, stride_wn,
-         W, W.stride(0), W.stride(1), W.stride(2),
-         # Y_ptr, stride_ym, stride_yn,
-         output, output.stride(0), output.stride(1),
-         # B_ptr, stride_be, stride_bk
-         b, stride_be, stride_bk,
-         grouped_idx_ptr=sorted_scattered_idxs,
-         expert_idxs_ptr=sorted_expert_idxs,
-         # block_start_idx_ptr=padded_block_idxs,
-         FAN_OUT=k,
-         M=X.size(0),
-         K=X.size(1),
-         N=output.size(1), E=W.size(0),
-         BLOCK_M=BLOCK_M,
-         ACC_TYPE=tl.float32,
-         allow_tf32=ALLOW_TF32,
-         x_grouped=x_grouped, y_grouped=y_grouped,
-     )
-
-
- def _config_XtY():
-     return [
-         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
-     ]
-
- def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
-     DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
-     DW = DWt.permute(0, 2, 1)
-     if has_bias:
-         Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
-     else:
-         Db = None
-     groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
-     return DW, Db
-
-
- @torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"})
- def groupXtY_compileable(
-     E: int,
-     DW: torch.Tensor,
-     Db: Optional[torch.Tensor],
-     DY: torch.Tensor,
-     X: torch.Tensor,
-     expert_offsets: torch.Tensor) -> None:
-     def grid(META):
-         grid = (
-             E * triton.cdiv(META['K'], META['BLOCK_K']),
-             triton.cdiv(META['N'], META['BLOCK_N']),
-         )
-         return grid
-
-     if Db is None:
-         stride_dbe = 0
-         stride_dbn = 0
-     else:
-         stride_dbe, stride_dbn = Db.stride()
-
-     _groupXtY[grid](
-         # DY_ptr, stride_dym, stride_dyk,
-         DY, DY.stride(0), DY.stride(1),
-         # X_ptr, stride_xm, stride_xn,
-         X, X.stride(0), X.stride(1),
-         # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-         DW, DW.stride(0), DW.stride(1), DW.stride(2),
-         # Db_ptr, stride_dwe, stride_dbn,
-         Db, stride_dbe, stride_dbn,
-         # expert_offsets_ptr,
-         expert_offsets,
-         # K: tl.constexpr, N: tl.constexpr,
-         M=DY.size(0), N=DY.size(-1), K=X.size(-1),
-         # ACC_TYPE: tl.constexpr,
-         ACC_TYPE=tl.float32,
-         allow_tf32=ALLOW_TF32
-     )
-
-
- @triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
-     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
- })
- @triton.jit
- def _groupXtY(
-     DY_ptr, stride_dym, stride_dyk,
-     X_ptr, stride_xm, stride_xn,
-     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-     Db_ptr, stride_dbe, stride_dbn,
-     expert_offsets_ptr,
-     M, K: tl.constexpr, N: tl.constexpr,
-     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
-     allow_tf32: tl.constexpr,
-     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
- ):
-     pid0 = tl.program_id(axis=0)
-     pid1 = tl.program_id(axis=1)
-     num0 = tl.num_programs(0)
-     num1 = tl.num_programs(1)
-     # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
-     pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
-
-     K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
-     E_idx = pid0 // K_BLOCK_COUNT
-     K_block_id = pid0 % K_BLOCK_COUNT
-     N_block_id = pid1
-
-     if E_idx == 0:
-         start_idx = 0
-     else:
-         start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
-     end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
-
-
-     if end_idx > start_idx:
-         M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
-
-         K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
-         K_mask = K_block < K
-         K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
-
-         N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-         N_mask = N_block < N
-         N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
-
-         M_idxs = M_block
-         xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
-         dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
-         if (Db_ptr is not None) and (K_block_id == 0):
-             _xty_and_bias(
-                 E_idx, start_idx, end_idx,
-                 M_block,
-                 K_block, K_mask, N_block, N_mask,
-                 dy_blk_ptrs, stride_dym,
-                 xt_blk_ptrs, stride_xm,
-                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-                 Db_ptr, stride_dbe, stride_dbn,
-                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-                 allow_tf32, NO_K_MASK, NO_N_MASK,
-                 compute_bias=True
-             )
-         else:
-             _xty_and_bias(
-                 E_idx, start_idx, end_idx,
-                 M_block,
-                 K_block, K_mask, N_block, N_mask,
-                 dy_blk_ptrs, stride_dym,
-                 xt_blk_ptrs, stride_xm,
-                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-                 Db_ptr, stride_dbe, stride_dbn,
-                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-                 allow_tf32, NO_K_MASK, NO_N_MASK,
-                 compute_bias=False
-             )
-
-
- @triton.jit
- def _xty_and_bias(
-     E_idx, start_idx, end_idx,
-     M_block,
-     K_block, K_mask, N_block, N_mask,
-     dy_blk_ptrs, stride_dym,
-     xt_blk_ptrs, stride_xm,
-     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-     Db_ptr, stride_dbe, stride_dbn,
-     BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-     allow_tf32, NO_K_MASK, NO_N_MASK,
-     compute_bias: tl.constexpr
- ):
-
-     if compute_bias:
-         db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
-     else:
-         db_acc = None
-
-     acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
-     iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
-     for i in range(0, iters):
-         M_mask = (i * BLOCK_M + M_block) < end_idx
-         if NO_K_MASK:
-             xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
-         else:
-             xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
-         if NO_N_MASK:
-             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
-         else:
-             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
-
-         acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
-
-         xt_blk_ptrs += BLOCK_M * stride_xm
-         dy_blk_ptrs += BLOCK_M * stride_dym
-
-         if compute_bias:
-             db_acc += tl.sum(dy, axis=0)
-
-     DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
-     acc = acc.to(DW_blk_ptrs.dtype.element_ty)
-     tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
-     if compute_bias:
-         Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
-         tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
-
-
-
- def _config_grouping():
-     return [
-         triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
-         # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
-         # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
-     ]
-
- def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
-     N = sorted_expert_idxs.size(0)
-     K = A.size(1)
-     assert A.size(0) * fan_out == N
-     if out is not None:
-         Y = out
-     else:
-         Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
-     group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
-     return Y
-
-
- @torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
- def group_compileable(
-     A: torch.Tensor,
-     K: int,
-     N: int,
-     Y: torch.Tensor,
-     coeff: torch.Tensor, has_coeff: bool,
-     fan_out: int,
-     sorted_expert_idxs: torch.Tensor) -> None:
-     def grid(META):
-         grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
-         return grid_num
-     _group[grid](
-         # A_ptr, stride_an, stride_ai,
-         A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out,
-         # Y_ptr, stride_yn, stride_yk,
-         Y, Y.stride(0), Y.stride(1),
-         # grouped_idx_ptr,
-         sorted_expert_idxs,
-         # N: tl.constexpr, K: tl.constexpr,
-         N, K
-     )
-
-
- @triton.autotune(configs=_config_grouping(), key=['K'])
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
- })
- @triton.jit
- def _group(
-     src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
-     tgt_ptr, stride_tn, stride_ti,
-     grouped_idx_ptr,
-     N, K: tl.constexpr,
-     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     NO_K_MASK: tl.constexpr
- ):
-     pid = tl.program_id(axis=0)
-
-     N_block_id = pid
-     N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-     N_mask = N_blk < N
-     N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
-     N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
-
-     K_blk = tl.arange(0, BLOCK_K)
-     src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
-     tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
-
-     if has_coeff:
-         c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
-
-     iters = tl.cdiv(K, BLOCK_K)
-     for i in range(0, iters):
-         if NO_K_MASK or i < iters - 1:
-             block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
-             if has_coeff:
-                 block *= c
-             tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
-
-         else:
-             K_mask = (i * BLOCK_K + K_blk) < K
-             mask = N_mask[:, None] & K_mask[None, :]
-             block = tl.load(src_blk_ptrs, mask=mask)
-             if has_coeff:
-                 block *= c
-             tl.store(tgt_blk_ptrs, block, mask=mask)
-         src_blk_ptrs += BLOCK_K * stride_sk
-         tgt_blk_ptrs += BLOCK_K * stride_ti
build/torch-rocm/kernels/single.py DELETED
@@ -1,59 +0,0 @@
- import torch
- import triton
- import triton.language as tl
-
- @triton.jit
- def _single2scatter(
-     X_ptr, stride_xm, stride_xk,
-     W_ptr, stride_we, stride_wk, stride_wn,
-     Y_ptr, stride_ym, stride_yn,
-     expert_idxs_ptr,
-     FAN_OUT: tl.constexpr,
-     K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
-     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
- ):
-     pid0 = tl.program_id(axis=0)
-     pid1 = tl.program_id(axis=1)
-
-     N_block_id = pid0
-     if FAN_OUT == 1:
-         in_idx = pid1
-     else:
-         in_idx = 0
-     out_idx = pid1
-
-     K_block = tl.arange(0, BLOCK_K)
-     N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
-     E_idx = tl.load(expert_idxs_ptr + pid1)
-     X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
-     W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
-     acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
-     for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
-         x = tl.load(X_blk_ptrs)
-         w = tl.load(W_blk_ptrs)
-         acc += tl.sum(x * w, axis=0)[None, :]
-         X_blk_ptrs += BLOCK_K * stride_xk
-         W_blk_ptrs += BLOCK_K * stride_wk
-     Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
-     tl.store(Y_blk_ptrs, acc)
-
- def single2scatter(X, W, expert_idxs):
-     E, xdim, ydim = W.size()
-     k = expert_idxs.size(1)
-     assert X.size(0) == k or X.size(0) == 1
-     Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
-     BLOCK_N = 128
-     BLOCK_K = 128
-     grid = ydim // BLOCK_N, k
-     _single2scatter[grid](
-         X, X.stride(0), X.stride(1),
-         W, W.stride(0), W.stride(1), W.stride(2),
-         Y, Y.stride(0), Y.stride(1),
-         expert_idxs,
-         FAN_OUT=Y.size(0) // X.size(0),
-         K=xdim, N=ydim, E=E,
-         BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
-         ACC_TYPE=tl.float32
-     )
-     return Y
build/torch-rocm/layers.py DELETED
@@ -1,52 +0,0 @@
- import torch
- from torch.nn import functional as F
- from torch import nn
-
- from . import parallel_linear, flatten_sort_count
-
- class ScatterMoEGatedMLP(nn.Module):
-     def forward(self, layer_input):
-         """
-         Forward pass of the mixture of experts layer.
-
-         Args:
-             layer_input (Tensor):
-                 Input tensor.
-
-         Returns:
-             Tensor:
-                 Output tensor.
-             Tensor:
-                 Router logits.
-         """
-         bsz, length, emb_size = layer_input.size()
-         layer_input = layer_input.reshape(-1, emb_size)
-         # compute the top_k routing decision
-         router_logits = self.router.layer(layer_input)
-         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-         routing_weights, selected_experts = torch.topk(routing_weights, self.router.top_k, dim=-1)
-         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-         routing_weights = routing_weights.to(layer_input.dtype)
-         sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
-             flatten_sort_count(selected_experts, num_experts=self.router.num_experts)
-
-         # compute experts
-         gates, h = parallel_linear(
-             layer_input, self.input_linear.weight.transpose(2, 1),
-             self.router.top_k,
-             sorted_expert_idxs, sorted_scattered_idxs,
-             expert_offsets,
-             grouped_in=False, grouped_out=True,
-         ).chunk(2, dim=-1)
-         h = self.activation(gates) * h
-         layer_output = parallel_linear(
-             h, self.output_linear.weight.transpose(2, 1),
-             1,
-             sorted_expert_idxs, sorted_scattered_idxs,
-             expert_offsets,
-             grouped_in=True, grouped_out=False,
-             gates=routing_weights
-         )
-         layer_output = layer_output.view(bsz, length, emb_size)
-         return layer_output
-
build/torch-rocm/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
build/torch-rocm/parallel_experts.py DELETED
@@ -1,182 +0,0 @@
- import torch
- import torch.nn as nn
- from . import kernels
- from typing import Optional
-
- @torch.library.custom_op("scattermoe::bincount", mutates_args={})
- def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
-     return x.bincount(minlength=minlength)
-
- @compileable_bincount.register_fake
- def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
-     return torch.empty(minlength, dtype=torch.long, device=x.device)
-
- @torch.compile
- def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
-     with torch.no_grad():
-         flattened_expert_idxs = expert_idxs.flatten()
-         sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
-         expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
-         expert_offsets = expert_counts.cumsum(-1)
-         return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
-
-
-
- class ParallelLinear(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x: torch.Tensor, expert_weights: torch.Tensor, k: int,
-         sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
-         expert_offsets: torch.Tensor,
-         expert_biases: Optional[torch.Tensor]=None,
-         gates: Optional[torch.Tensor]=None,
-         grouped_in: bool =False, grouped_out: bool=False,
-     ):
-         with torch.device(x.device):
-             output = kernels.ops.scatter2scatter(
-                 X=x, W=expert_weights,
-                 b=expert_biases, k=k,
-                 sorted_expert_idxs=sorted_expert_idxs,
-                 sorted_scattered_idxs=sorted_scattered_idxs,
-                 x_grouped=grouped_in, y_grouped=grouped_out
-             )
-             if gates is not None:
-                 output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
-                 output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
-             else:
-                 output_expanded = None
-
-             ctx.save_for_backward(
-                 x, expert_weights,
-                 expert_biases,
-                 sorted_expert_idxs,
-                 sorted_scattered_idxs,
-                 expert_offsets,
-                 gates,
-                 output_expanded
-             )
-             ctx.grouped_in = grouped_in
-             ctx.grouped_out = grouped_out
-             ctx.k = k
-             return output
-     @staticmethod
-     def backward(ctx, grad_out: torch.Tensor):
-         with torch.device(grad_out.device):
-             (x, expert_weights, expert_biases,
-              sorted_expert_idxs,
-              sorted_scattered_idxs,
-              expert_offsets,
-              gates, output_expanded) = ctx.saved_tensors
-             k = ctx.k
-             grouped_in = ctx.grouped_in
-             grouped_out = ctx.grouped_out
-             # print("backward")
-
-             if gates is not None:
-                 # calculate gates gradient
-                 # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
-                 d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
-                 gates_flat = gates.flatten()
-                 gate_fan = gates.size(1)
-                 grouped_grad_out = output_expanded.flatten(0, 1)  # reuse expanded buffer later
-             else:
-                 d_gates = None
-                 gates_flat = None
-                 gate_fan = 1
-                 grouped_grad_out = None
-
-             if grouped_out:
-                 grouped_grad_out = grad_out
-             else:
-                 grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
-                                                      fan_out=gate_fan, coeff=gates_flat,
-                                                      out=grouped_grad_out)
-             if grouped_in:
-                 grouped_x = x
-                 d_expanded_input = None
-             else:
-                 grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
-                 d_expanded_input = grouped_x
-
-             d_weights, d_biases = kernels.ops.group_bwd_W(
-                 DY=grouped_grad_out, X=grouped_x,
-                 expert_offsets=expert_offsets,
-                 E=expert_weights.size(0),
-                 has_bias=expert_biases is not None
-             )
-
-
-             d_expanded_input = kernels.ops.scatter2scatter(
-                 X=grouped_grad_out, x_grouped=True,
-                 W=expert_weights.permute(0, 2, 1),
-                 sorted_expert_idxs=sorted_expert_idxs,
-                 sorted_scattered_idxs=sorted_scattered_idxs,
-                 k=1,
-                 y_grouped=grouped_in,
-                 out=d_expanded_input  # Reuse grouped_x buffer
-             )
-
-             if k == 1:
-                 d_input = d_expanded_input
-             else:
-                 d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
-             # print("backward end.")
-             return (
-                 # x, expert_weights,
-                 d_input, d_weights,
-                 # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
-                 None, None, None, None,
-                 # bias, gates
-                 d_biases, d_gates,
-                 # grouped_in, grouped_out,
-                 None, None
-             )
-
- def parallel_linear(inputs, expert_weights, k,
-                     sorted_expert_idxs, sorted_scattered_idxs,
-                     expert_offsets,
-                     expert_biases=None,
-                     gates=None, grouped_in=False, grouped_out=False):
-     results = ParallelLinear.apply(inputs, expert_weights, k,
-                                    sorted_expert_idxs, sorted_scattered_idxs,
-                                    expert_offsets,
-                                    expert_biases,
-                                    gates, grouped_in, grouped_out)
-     return results
-
- class ParallelExperts(nn.Module):
-     def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
-         super().__init__()
-         self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
-
-         if bias:
-             self.bias = nn.Parameter(torch.empty(num_experts, output_size))
-         else:
-             self.bias = None
-
-         self.num_experts = num_experts
-         self.input_size = input_size
-         self.output_size = output_size
-         self.reset_parameters()
-
-     def extra_repr(self):
-         return 'num_experts={}, input_size={}, output_size={}'.format(
-             self.num_experts, self.input_size, self.output_size)
-
-     def reset_parameters(self) -> None:
-         nn.init.normal_(self.weight, std=0.02)
-         if self.bias is not None:
-             nn.init.zeros_(self.bias)
-
-     def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
-                 expert_offsets,
-                 gates=None, grouped_in=False, grouped_out=False):
-
-         results = parallel_linear(
-             inputs, self.weight.permute(0, 2, 1), k,
-             sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
-             expert_biases=self.bias,
-             gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
-         )
-         return results
build/torch-rocm/scattermoe/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is: after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique, using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
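This shim exists so that the same `__init__.py` can be loaded once per build variant: keying `sys.modules` on a hash of the file path keeps each loaded copy distinct. A minimal sketch of the effect, with hypothetical paths:

```python
from pathlib import Path

# Hypothetical paths; each load registers under its own path-hashed name,
# so the two module objects do not alias each other in sys.modules.
mod_a = _import_from_path(Path("build/torch-rocm/__init__.py"))
mod_b = _import_from_path(Path("build/torch-xpu/__init__.py"))
assert mod_a is not mod_b
```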
build/torch-xpu/__init__.py DELETED
@@ -1,13 +0,0 @@
- from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts
- from . import parallel_experts
- from . import kernels
- from . import layers
-
- __all__ = [
-     "flatten_sort_count",
-     "parallel_linear",
-     "ParallelExperts",
-     "parallel_experts",
-     "kernels",
-     "layers"
- ]
build/torch-xpu/_ops.py DELETED
@@ -1,8 +0,0 @@
- import torch
- ops = torch.ops._scattermoe_05b9d77
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_scattermoe_05b9d77::{op_name}"
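The helper just qualifies an op name with the build-specific namespace, for example:

```python
add_op_namespace_prefix("scatter2scatter")
# -> "_scattermoe_05b9d77::scatter2scatter"
```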
build/torch-xpu/kernels/__init__.py DELETED
@@ -1,3 +0,0 @@
- from . import ops
-
- __all__ = ["ops"]
build/torch-xpu/kernels/ops.py DELETED
@@ -1,457 +0,0 @@
- import torch
- import triton
- import triton.language as tl
- from typing import Optional
-
- BLOCK_M = 128
- ALLOW_TF32 = True
-
-
-
- @triton.jit
- def _compute_expert_block(
-     E_idx, E_mask,
-     M_in_idx,
-     N_block, N_mask,
-     X_ptr, stride_xm, stride_xk,
-     W_ptr, stride_we, stride_wk, stride_wn,
-     K,
-     acc,
-     no_k_mask,
-     BLOCK_K,
-     allow_tf32=True,
- ):
-
-     K_block = tl.arange(0, BLOCK_K)
-     X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
-     W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
-     iters = tl.cdiv(K, BLOCK_K)
-
-     for K_block_id in range(iters):
-         if no_k_mask:
-             x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
-             w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
-         else:
-             K_mask = (K_block_id * BLOCK_K + K_block) < K
-             x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
-             w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
-
-         X_blk_ptrs += BLOCK_K * stride_xk
-         W_blk_ptrs += BLOCK_K * stride_wk
-         acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
-     return acc
-
-
- def _scatter2scatter_configs():
-     return [
-         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
-     ]
-
- @triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
-     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
- })
- @triton.jit
- def _scatter2scatter(
-     X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr,
-     W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr,
-     Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr,
-     B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr,
-     grouped_idx_ptr, expert_idxs_ptr,
-     # block_start_idx_ptr,
-     FAN_OUT: tl.constexpr,
-     M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
-     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
-     # OUT_M,
-     allow_tf32: tl.constexpr,
-     x_grouped: tl.constexpr, y_grouped: tl.constexpr,
-     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
- ):
-     pid = tl.program_id(axis=0)
-
-     N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
-     M_block_id = pid // N_BLOCK_COUNT
-     N_block_id = pid % N_BLOCK_COUNT
-
-     M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
-     N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-     N_mask = N_block < N
-     M_boundary_mask = M_block < (FAN_OUT * M)
-     E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
-
-     no_k_mask = K % BLOCK_K == 0
-
-     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
-     E_first_idx = tl.min(E_idxs)
-     E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
-     M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
-     for E_idx in range(E_first_idx, E_last_idx + 1):
-         E_mask = E_idxs == E_idx
-         E_M_idx = M_idx
-         if x_grouped:
-             M_in_idx = M_block
-         else:
-             M_in_idx = E_M_idx // FAN_OUT
-         acc = _compute_expert_block(
-             E_idx, E_mask,
-             M_in_idx, N_block, N_mask,
-             X_ptr, stride_xm, stride_xk,
-             W_ptr, stride_we, stride_wk, stride_wn,
-             K,
-             acc,
-             no_k_mask,
-             BLOCK_K,
-             allow_tf32=allow_tf32,
-         )
-
-     if B_ptr is not None:
-         B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
-         acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
-
-     if y_grouped:
-         M_out_idx = M_block
-     else:
-         M_out_idx = M_idx
-     Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
-     tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
-
- def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
-                     b=None,
-                     x_grouped=False, y_grouped=False,
-                     out=None):
-     assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
-     assert sorted_scattered_idxs.size(0) == X.size(0) * k
-     # Pre-kernel setup
-     y_dim = W.size(-1)
-     L_scattered = sorted_expert_idxs.size(0)
-     if out is None:
-         output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
-     else:
-         assert out.size(0) == L_scattered and out.size(1) == y_dim
-         output = out
-
-     scatter2scatter_compileable(output, W, X, k, sorted_expert_idxs, sorted_scattered_idxs,
-                                 b, x_grouped, y_grouped)
-     return output
-
-
- @torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
- def scatter2scatter_compileable(
-         output: torch.Tensor,
-         W: torch.Tensor,
-         X: torch.Tensor,
-         k: int,
-         sorted_expert_idxs: torch.Tensor,
-         sorted_scattered_idxs: torch.Tensor,
-         b: Optional[torch.Tensor],
-         x_grouped: bool, y_grouped: bool) -> None:
-     def grid(META):
-         grid_num = (
-             triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"]) *
-             triton.cdiv(META['N'], META['BLOCK_N']),
-         )
-         return grid_num
-
-     if b is None:
-         b = None
-         stride_be = stride_bk = 0
-     else:
-         stride_be, stride_bk = b.stride()
-
-     _scatter2scatter[grid](
-         # X_ptr, stride_xm, stride_xk,
-         X, X.stride(0), X.stride(1),
-         # W_ptr, stride_we, stride_wk, stride_wn,
-         W, W.stride(0), W.stride(1), W.stride(2),
-         # Y_ptr, stride_ym, stride_yn,
-         output, output.stride(0), output.stride(1),
-         # B_ptr, stride_be, stride_bk
-         b, stride_be, stride_bk,
-         grouped_idx_ptr=sorted_scattered_idxs,
-         expert_idxs_ptr=sorted_expert_idxs,
-         # block_start_idx_ptr=padded_block_idxs,
-         FAN_OUT=k,
-         M=X.size(0),
-         K=X.size(1),
-         N=output.size(1), E=W.size(0),
-         BLOCK_M=BLOCK_M,
-         ACC_TYPE=tl.float32,
-         allow_tf32=ALLOW_TF32,
-         x_grouped=x_grouped, y_grouped=y_grouped,
-     )
-
-
- def _config_XtY():
-     return [
-         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
-     ]
-
- def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
-     DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
-     DW = DWt.permute(0, 2, 1)
-     if has_bias:
-         Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
-     else:
-         Db = None
-     groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
-     return DW, Db
-
-
- @torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"})
- def groupXtY_compileable(
-         E: int,
-         DW: torch.Tensor,
-         Db: Optional[torch.Tensor],
-         DY: torch.Tensor,
-         X: torch.Tensor,
-         expert_offsets: torch.Tensor) -> None:
-     def grid(META):
-         grid = (
-             E * triton.cdiv(META['K'], META['BLOCK_K']),
-             triton.cdiv(META['N'], META['BLOCK_N']),
-         )
-         return grid
-
-     if Db is None:
-         stride_dbe = 0
-         stride_dbn = 0
-     else:
-         stride_dbe, stride_dbn = Db.stride()
-
-     _groupXtY[grid](
-         # DY_ptr, stride_dym, stride_dyk,
-         DY, DY.stride(0), DY.stride(1),
-         # X_ptr, stride_xm, stride_xn,
-         X, X.stride(0), X.stride(1),
-         # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-         DW, DW.stride(0), DW.stride(1), DW.stride(2),
-         # Db_ptr, stride_dbe, stride_dbn,
-         Db, stride_dbe, stride_dbn,
-         # expert_offsets_ptr,
-         expert_offsets,
-         # K: tl.constexpr, N: tl.constexpr,
-         M=DY.size(0), N=DY.size(-1), K=X.size(-1),
-         # ACC_TYPE: tl.constexpr,
-         ACC_TYPE=tl.float32,
-         allow_tf32=ALLOW_TF32
-     )
-
-
- @triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
-     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
- })
- @triton.jit
- def _groupXtY(
-     DY_ptr, stride_dym, stride_dyk,
-     X_ptr, stride_xm, stride_xn,
-     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-     Db_ptr, stride_dbe, stride_dbn,
-     expert_offsets_ptr,
-     M, K: tl.constexpr, N: tl.constexpr,
-     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
-     allow_tf32: tl.constexpr,
-     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
- ):
-     pid0 = tl.program_id(axis=0)
-     pid1 = tl.program_id(axis=1)
-     num0 = tl.num_programs(0)
-     num1 = tl.num_programs(1)
-     # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
-     pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
-
-     K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
-     E_idx = pid0 // K_BLOCK_COUNT
-     K_block_id = pid0 % K_BLOCK_COUNT
-     N_block_id = pid1
-
-     if E_idx == 0:
-         start_idx = 0
-     else:
-         start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
-     end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
-
-
-     if end_idx > start_idx:
-         M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
-
-         K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
-         K_mask = K_block < K
-         K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
-
-         N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-         N_mask = N_block < N
-         N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
-
-         M_idxs = M_block
-         xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
-         dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
-         if (Db_ptr is not None) and (K_block_id == 0):
-             _xty_and_bias(
-                 E_idx, start_idx, end_idx,
-                 M_block,
-                 K_block, K_mask, N_block, N_mask,
-                 dy_blk_ptrs, stride_dym,
-                 xt_blk_ptrs, stride_xm,
-                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-                 Db_ptr, stride_dbe, stride_dbn,
-                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-                 allow_tf32, NO_K_MASK, NO_N_MASK,
-                 compute_bias=True
-             )
-         else:
-             _xty_and_bias(
-                 E_idx, start_idx, end_idx,
-                 M_block,
-                 K_block, K_mask, N_block, N_mask,
-                 dy_blk_ptrs, stride_dym,
-                 xt_blk_ptrs, stride_xm,
-                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-                 Db_ptr, stride_dbe, stride_dbn,
-                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-                 allow_tf32, NO_K_MASK, NO_N_MASK,
-                 compute_bias=False
-             )
-
-
- @triton.jit
- def _xty_and_bias(
-     E_idx, start_idx, end_idx,
-     M_block,
-     K_block, K_mask, N_block, N_mask,
-     dy_blk_ptrs, stride_dym,
-     xt_blk_ptrs, stride_xm,
-     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
-     Db_ptr, stride_dbe, stride_dbn,
-     BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
-     allow_tf32, NO_K_MASK, NO_N_MASK,
-     compute_bias: tl.constexpr
- ):
-
-     if compute_bias:
-         db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
-     else:
-         db_acc = None
-
-     acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
-     iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
-     for i in range(0, iters):
-         M_mask = (i * BLOCK_M + M_block) < end_idx
-         if NO_K_MASK:
-             xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
-         else:
-             xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
-         if NO_N_MASK:
-             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
-         else:
-             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
-
-         acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
-
-         xt_blk_ptrs += BLOCK_M * stride_xm
-         dy_blk_ptrs += BLOCK_M * stride_dym
-
-         if compute_bias:
-             db_acc += tl.sum(dy, axis=0)
-
-     DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
-     acc = acc.to(DW_blk_ptrs.dtype.element_ty)
-     tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
-     if compute_bias:
-         Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
-         tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
-
-
-
- def _config_grouping():
-     return [
-         triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
-         # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
-         # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
-     ]
-
- def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
-     N = sorted_expert_idxs.size(0)
-     K = A.size(1)
-     assert A.size(0) * fan_out == N
-     if out is not None:
-         Y = out
-     else:
-         Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
-     group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
-     return Y
-
-
- @torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
- def group_compileable(
-         A: torch.Tensor,
-         K: int,
-         N: int,
-         Y: torch.Tensor,
-         coeff: torch.Tensor, has_coeff: bool,
-         fan_out: int,
-         sorted_expert_idxs: torch.Tensor) -> None:
-     def grid(META):
-         grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
-         return grid_num
-     _group[grid](
-         # A_ptr, stride_an, stride_ai,
-         A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out,
-         # Y_ptr, stride_yn, stride_yk,
-         Y, Y.stride(0), Y.stride(1),
-         # grouped_idx_ptr,
-         sorted_expert_idxs,
-         # N: tl.constexpr, K: tl.constexpr,
-         N, K
-     )
-
-
- @triton.autotune(configs=_config_grouping(), key=['K'])
- @triton.heuristics({
-     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
- })
- @triton.jit
- def _group(
-     src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
-     tgt_ptr, stride_tn, stride_ti,
-     grouped_idx_ptr,
-     N, K: tl.constexpr,
-     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     NO_K_MASK: tl.constexpr
- ):
-     pid = tl.program_id(axis=0)
-
-     N_block_id = pid
-     N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
-     N_mask = N_blk < N
-     N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
-     N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
-
-     K_blk = tl.arange(0, BLOCK_K)
-     src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
-     tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
-
-     if has_coeff:
-         c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
-
-     iters = tl.cdiv(K, BLOCK_K)
-     for i in range(0, iters):
-         if NO_K_MASK or i < iters - 1:
-             block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
-             if has_coeff:
-                 block *= c
-             tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
-
-         else:
-             K_mask = (i * BLOCK_K + K_blk) < K
-             mask = N_mask[:, None] & K_mask[None, :]
-             block = tl.load(src_blk_ptrs, mask=mask)
-             if has_coeff:
-                 block *= c
-             tl.store(tgt_blk_ptrs, block, mask=mask)
-         src_blk_ptrs += BLOCK_K * stride_sk
-         tgt_blk_ptrs += BLOCK_K * stride_ti
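The shape contract of the three host-side entry points above, as a small sketch (the sizes and the XPU device are assumptions for illustration):

```python
import torch

n, k, e, d_in, d_out = 8, 2, 4, 16, 32                    # tokens, top-k, experts, dims
X = torch.randn(n, d_in, device="xpu")
W = torch.randn(e, d_in, d_out, device="xpu")
flat_idxs = torch.randint(0, e, (n * k,), device="xpu")
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flat_idxs)
expert_offsets = torch.bincount(sorted_expert_idxs, minlength=e).cumsum(-1)

Y = scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k)
assert Y.shape == (n * k, d_out)                          # one row per (token, expert) pair

Xg = group(X, sorted_scattered_idxs, fan_out=k)           # expert-contiguous copy of X
DW, Db = group_bwd_W(DY=Y, X=Xg, expert_offsets=expert_offsets, E=e, has_bias=False)
assert DW.shape == (e, d_in, d_out) and Db is None
```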
build/torch-xpu/kernels/single.py DELETED
@@ -1,59 +0,0 @@
- import torch
- import triton
- import triton.language as tl
-
- @triton.jit
- def _single2scatter(
-     X_ptr, stride_xm, stride_xk,
-     W_ptr, stride_we, stride_wk, stride_wn,
-     Y_ptr, stride_ym, stride_yn,
-     expert_idxs_ptr,
-     FAN_OUT: tl.constexpr,
-     K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
-     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-     ACC_TYPE: tl.constexpr,
- ):
-     pid0 = tl.program_id(axis=0)
-     pid1 = tl.program_id(axis=1)
-
-     N_block_id = pid0
-     if FAN_OUT == 1:
-         in_idx = pid1
-     else:
-         in_idx = 0
-     out_idx = pid1
-
-     K_block = tl.arange(0, BLOCK_K)
-     N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
-     E_idx = tl.load(expert_idxs_ptr + pid1)
-     X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
-     W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
-     acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
-     for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
-         x = tl.load(X_blk_ptrs)
-         w = tl.load(W_blk_ptrs)
-         acc += tl.sum(x * w, axis=0)[None, :]
-         X_blk_ptrs += BLOCK_K * stride_xk
-         W_blk_ptrs += BLOCK_K * stride_wk
-     Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
-     tl.store(Y_blk_ptrs, acc)
-
- def single2scatter(X, W, expert_idxs):
-     E, xdim, ydim = W.size()
-     k = expert_idxs.size(1)
-     assert X.size(0) == k or X.size(0) == 1
-     Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
-     BLOCK_N = 128
-     BLOCK_K = 128
-     grid = ydim // BLOCK_N, k
-     _single2scatter[grid](
-         X, X.stride(0), X.stride(1),
-         W, W.stride(0), W.stride(1), W.stride(2),
-         Y, Y.stride(0), Y.stride(1),
-         expert_idxs,
-         FAN_OUT=Y.size(0) // X.size(0),
-         K=xdim, N=ydim, E=E,
-         BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
-         ACC_TYPE=tl.float32
-     )
-     return Y
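`single2scatter` covers the decode-style case of a single input row fanned out to its top-k experts. A call sketch with assumed sizes; note the kernel loads without K/N masks, so `xdim` and `ydim` should be multiples of the 128-wide blocks:

```python
import torch

E, d_in, d_out, k = 4, 128, 256, 2        # multiples of BLOCK_K/BLOCK_N by assumption
X = torch.randn(1, d_in, device="xpu")    # one token (assumed XPU device)
W = torch.randn(E, d_in, d_out, device="xpu")
expert_idxs = torch.randint(0, E, (1, k), device="xpu")

Y = single2scatter(X, W, expert_idxs)     # FAN_OUT = k when X has a single row
assert Y.shape == (k, d_out)              # one output row per selected expert
```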
build/torch-xpu/layers.py DELETED
@@ -1,52 +0,0 @@
- import torch
- from torch.nn import functional as F
- from torch import nn
-
- from . import parallel_linear, flatten_sort_count
-
- class ScatterMoEGatedMLP(nn.Module):
-     def forward(self, layer_input):
-         """
-         Forward pass of the mixture-of-experts layer.
-
-         Args:
-             layer_input (Tensor):
-                 Input tensor.
-
-         Returns:
-             Tensor:
-                 Output tensor.
-         """
-         bsz, length, emb_size = layer_input.size()
-         layer_input = layer_input.reshape(-1, emb_size)
-         # compute the top_k routing decision
-         router_logits = self.router.layer(layer_input)
-         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-         routing_weights, selected_experts = torch.topk(routing_weights, self.router.top_k, dim=-1)
-         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-         routing_weights = routing_weights.to(layer_input.dtype)
-         sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
-             flatten_sort_count(selected_experts, num_experts=self.router.num_experts)
-
-         # compute experts
-         gates, h = parallel_linear(
-             layer_input, self.input_linear.weight.transpose(2, 1),
-             self.router.top_k,
-             sorted_expert_idxs, sorted_scattered_idxs,
-             expert_offsets,
-             grouped_in=False, grouped_out=True,
-         ).chunk(2, dim=-1)
-         h = self.activation(gates) * h
-         layer_output = parallel_linear(
-             h, self.output_linear.weight.transpose(2, 1),
-             1,
-             sorted_expert_idxs, sorted_scattered_idxs,
-             expert_offsets,
-             grouped_in=True, grouped_out=False,
-             gates=routing_weights
-         )
-         layer_output = layer_output.view(bsz, length, emb_size)
-         return layer_output
-
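`ScatterMoEGatedMLP` defines no `__init__`; it expects the subclass (or owner) to provide `self.router` (with `.layer`, `.top_k`, `.num_experts`), `self.input_linear` and `self.output_linear` (`ParallelExperts` modules from the hunks in this commit, holding the fused gate+up and the down projections), and `self.activation`. A minimal subclass sketch under those assumptions, with illustrative sizes:

```python
from torch import nn

class Router(nn.Module):
    # Minimal shim exposing the attributes the forward pass reads.
    def __init__(self, emb_size, num_experts, top_k):
        super().__init__()
        self.layer = nn.Linear(emb_size, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k

class GatedMLP(ScatterMoEGatedMLP):
    def __init__(self, emb_size=64, hidden=128, num_experts=4, top_k=2):
        super().__init__()
        self.router = Router(emb_size, num_experts, top_k)
        # fused gate+up projection, hence 2 * hidden output features
        self.input_linear = ParallelExperts(num_experts, emb_size, 2 * hidden)
        self.output_linear = ParallelExperts(num_experts, hidden, emb_size)
        self.activation = nn.SiLU()
```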
build/torch-xpu/metadata.json DELETED
@@ -1 +0,0 @@
- {"python-depends":[]}
 
 
build/torch-xpu/parallel_experts.py DELETED
@@ -1,182 +0,0 @@
- import torch
- import torch.nn as nn
- from . import kernels
- from typing import Optional
-
- @torch.library.custom_op("scattermoe::bincount", mutates_args={})
- def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
-     return x.bincount(minlength=minlength)
-
- @compileable_bincount.register_fake
- def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
-     return torch.empty(minlength, dtype=torch.long, device=x.device)
-
- @torch.compile
- def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
-     with torch.no_grad():
-         flattened_expert_idxs = expert_idxs.flatten()
-         sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
-         expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
-         expert_offsets = expert_counts.cumsum(-1)
-     return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
-
-
-
- class ParallelLinear(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x: torch.Tensor, expert_weights: torch.Tensor, k: int,
-         sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
-         expert_offsets: torch.Tensor,
-         expert_biases: Optional[torch.Tensor] = None,
-         gates: Optional[torch.Tensor] = None,
-         grouped_in: bool = False, grouped_out: bool = False,
-     ):
-         with torch.device(x.device):
-             output = kernels.ops.scatter2scatter(
-                 X=x, W=expert_weights,
-                 b=expert_biases, k=k,
-                 sorted_expert_idxs=sorted_expert_idxs,
-                 sorted_scattered_idxs=sorted_scattered_idxs,
-                 x_grouped=grouped_in, y_grouped=grouped_out
-             )
-             if gates is not None:
-                 output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
-                 output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
-             else:
-                 output_expanded = None
-
-             ctx.save_for_backward(
-                 x, expert_weights,
-                 expert_biases,
-                 sorted_expert_idxs,
-                 sorted_scattered_idxs,
-                 expert_offsets,
-                 gates,
-                 output_expanded
-             )
-             ctx.grouped_in = grouped_in
-             ctx.grouped_out = grouped_out
-             ctx.k = k
-         return output
-     @staticmethod
-     def backward(ctx, grad_out: torch.Tensor):
-         with torch.device(grad_out.device):
-             (x, expert_weights, expert_biases,
-              sorted_expert_idxs,
-              sorted_scattered_idxs,
-              expert_offsets,
-              gates, output_expanded) = ctx.saved_tensors
-             k = ctx.k
-             grouped_in = ctx.grouped_in
-             grouped_out = ctx.grouped_out
-             # print("backward")
-
-             if gates is not None:
-                 # calculate gates gradient
-                 # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
-                 d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
-                 gates_flat = gates.flatten()
-                 gate_fan = gates.size(1)
-                 grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later
-             else:
-                 d_gates = None
-                 gates_flat = None
-                 gate_fan = 1
-                 grouped_grad_out = None
-
-             if grouped_out:
-                 grouped_grad_out = grad_out
-             else:
-                 grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
-                                                      fan_out=gate_fan, coeff=gates_flat,
-                                                      out=grouped_grad_out)
-             if grouped_in:
-                 grouped_x = x
-                 d_expanded_input = None
-             else:
-                 grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
-                 d_expanded_input = grouped_x
-
-             d_weights, d_biases = kernels.ops.group_bwd_W(
-                 DY=grouped_grad_out, X=grouped_x,
-                 expert_offsets=expert_offsets,
-                 E=expert_weights.size(0),
-                 has_bias=expert_biases is not None
-             )
-
-
-             d_expanded_input = kernels.ops.scatter2scatter(
-                 X=grouped_grad_out, x_grouped=True,
-                 W=expert_weights.permute(0, 2, 1),
-                 sorted_expert_idxs=sorted_expert_idxs,
-                 sorted_scattered_idxs=sorted_scattered_idxs,
-                 k=1,
-                 y_grouped=grouped_in,
-                 out=d_expanded_input # Reuse grouped_x buffer
-             )
-
-             if k == 1:
-                 d_input = d_expanded_input
-             else:
-                 d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
-             # print("backward end.")
-         return (
-             # x, expert_weights,
-             d_input, d_weights,
-             # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
-             None, None, None, None,
-             # bias, gates
-             d_biases, d_gates,
-             # grouped_in, grouped_out,
-             None, None
-         )
-
- def parallel_linear(inputs, expert_weights, k,
-                     sorted_expert_idxs, sorted_scattered_idxs,
-                     expert_offsets,
-                     expert_biases=None,
-                     gates=None, grouped_in=False, grouped_out=False):
-     results = ParallelLinear.apply(inputs, expert_weights, k,
-                                    sorted_expert_idxs, sorted_scattered_idxs,
-                                    expert_offsets,
-                                    expert_biases,
-                                    gates, grouped_in, grouped_out)
-     return results
-
- class ParallelExperts(nn.Module):
-     def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
-         super().__init__()
-         self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
-
-         if bias:
-             self.bias = nn.Parameter(torch.empty(num_experts, output_size))
-         else:
-             self.bias = None
-
-         self.num_experts = num_experts
-         self.input_size = input_size
-         self.output_size = output_size
-         self.reset_parameters()
-
-     def extra_repr(self):
-         return 'num_experts={}, input_size={}, output_size={}'.format(
-             self.num_experts, self.input_size, self.output_size)
-
-     def reset_parameters(self) -> None:
-         nn.init.normal_(self.weight, std=0.02)
-         if self.bias is not None:
-             nn.init.zeros_(self.bias)
-
-     def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
-                 expert_offsets,
-                 gates=None, grouped_in=False, grouped_out=False):
-
-         results = parallel_linear(
-             inputs, self.weight.permute(0, 2, 1), k,
-             sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
-             expert_biases=self.bias,
-             gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
-         )
-         return results
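A tiny worked example of `flatten_sort_count`, the routing-metadata helper the kernels consume (the order within one expert depends on the sort and may vary):

```python
import torch

expert_idxs = torch.tensor([[1, 0], [2, 1]])   # 2 tokens, top-2 expert choices
s_exp, s_scat, offsets = flatten_sort_count(expert_idxs, num_experts=3)
# s_exp    -> tensor([0, 1, 1, 2])        expert ids in ascending order
# s_scat   -> e.g. tensor([1, 0, 3, 2])   positions in the flattened [1, 0, 2, 1]
# offsets  -> tensor([1, 3, 4])           cumulative token count per expert
```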
build/torch-xpu/scattermoe/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is: after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique, using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))