Kernels
danieldk HF Staff committed on
Commit
578605e
·
verified ·
1 Parent(s): 9d52a27

Build uploaded using `kernels`.

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50) hide show
  1. build/torch210-cxx11-cu126-aarch64-linux/{_megablocks_cuda_6e04dec.abi3.so → _megablocks_cuda_a45325d.abi3.so} +1 -1
  2. build/torch210-cxx11-cu126-aarch64-linux/_ops.py +3 -3
  3. build/torch210-cxx11-cu126-aarch64-linux/megablocks/__init__.py +2 -2
  4. build/torch210-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py +26 -26
  5. build/torch210-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py +362 -362
  6. build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py +45 -45
  7. build/torch210-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py +118 -118
  8. build/torch210-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py +27 -27
  9. build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py +35 -35
  10. build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py +116 -116
  11. build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py +52 -52
  12. build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py +63 -63
  13. build/torch210-cxx11-cu128-aarch64-linux/{_megablocks_cuda_6e04dec.abi3.so → _megablocks_cuda_a45325d.abi3.so} +1 -1
  14. build/torch210-cxx11-cu128-aarch64-linux/_ops.py +3 -3
  15. build/torch210-cxx11-cu128-aarch64-linux/megablocks/__init__.py +2 -2
  16. build/torch210-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py +26 -26
  17. build/torch210-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py +362 -362
  18. build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py +45 -45
  19. build/torch210-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py +118 -118
  20. build/torch210-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py +27 -27
  21. build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py +35 -35
  22. build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py +116 -116
  23. build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py +52 -52
  24. build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py +63 -63
  25. build/torch210-cxx11-cu130-aarch64-linux/{_megablocks_cuda_6e04dec.abi3.so → _megablocks_cuda_a45325d.abi3.so} +1 -1
  26. build/torch210-cxx11-cu130-aarch64-linux/_ops.py +3 -3
  27. build/torch210-cxx11-cu130-aarch64-linux/megablocks/__init__.py +2 -2
  28. build/torch210-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py +26 -26
  29. build/torch210-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py +362 -362
  30. build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py +45 -45
  31. build/torch210-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py +118 -118
  32. build/torch210-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py +27 -27
  33. build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py +35 -35
  34. build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py +116 -116
  35. build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py +52 -52
  36. build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py +63 -63
  37. build/torch211-cxx11-cu126-aarch64-linux/__init__.py +202 -0
  38. build/torch211-cxx11-cu126-aarch64-linux/_layers/__init__.py +10 -0
  39. build/torch211-cxx11-cu126-aarch64-linux/_layers/activation_fn.py +33 -0
  40. build/torch211-cxx11-cu126-aarch64-linux/_layers/all_to_all.py +54 -0
  41. build/torch211-cxx11-cu126-aarch64-linux/_layers/arguments.py +101 -0
  42. build/torch211-cxx11-cu126-aarch64-linux/_layers/common.py +26 -0
  43. build/torch211-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py +42 -0
  44. build/torch211-cxx11-cu126-aarch64-linux/_layers/dmoe.py +337 -0
  45. build/torch211-cxx11-cu126-aarch64-linux/_layers/gelu.py +52 -0
  46. build/torch211-cxx11-cu126-aarch64-linux/_layers/glu.py +244 -0
  47. build/torch211-cxx11-cu126-aarch64-linux/_layers/memory_test.py +103 -0
  48. build/torch211-cxx11-cu126-aarch64-linux/_layers/mlp.py +587 -0
  49. build/torch211-cxx11-cu126-aarch64-linux/_layers/moe.py +507 -0
  50. build/torch211-cxx11-cu126-aarch64-linux/_layers/mpu.py +94 -0
build/torch210-cxx11-cu126-aarch64-linux/{_megablocks_cuda_6e04dec.abi3.so → _megablocks_cuda_a45325d.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d43ea617155587acccc47750e126596b0438c63c7ada6f3607a2ed4603337f72
3
  size 15124328
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77772a8c5e37277e5f14ff74222bfd44b64d78b0ecd36a4c4ce21008acc71d50
3
  size 15124328
build/torch210-cxx11-cu126-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_6e04dec
3
- ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_6e04dec::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_a45325d
3
+ ops = torch.ops._megablocks_cuda_a45325d
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_a45325d::{op_name}"
build/torch210-cxx11-cu126-aarch64-linux/megablocks/__init__.py CHANGED
@@ -1,10 +1,10 @@
1
  import ctypes
 
2
  import sys
3
-
4
- import importlib
5
  from pathlib import Path
6
  from types import ModuleType
7
 
 
8
  def _import_from_path(file_path: Path) -> ModuleType:
9
  # We cannot use the module name as-is, after adding it to `sys.modules`,
10
  # it would also be used for other imports. So, we make a module name that
 
1
  import ctypes
2
+ import importlib.util
3
  import sys
 
 
4
  from pathlib import Path
5
  from types import ModuleType
6
 
7
+
8
  def _import_from_path(file_path: Path) -> ModuleType:
9
  # We cannot use the module name as-is, after adding it to `sys.modules`,
10
  # it would also be used for other imports. So, we make a module name that
build/torch210-cxx11-cu126-aarch64-linux/ops/histogram_benchmark.py CHANGED
@@ -5,7 +5,7 @@ import unittest
5
 
6
  import numpy as np
7
  import torch
8
- from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
@@ -47,31 +47,31 @@ def log_benchmark(arguments, mean_t, std_t):
47
  print('=' * 60)
48
 
49
 
50
- class HistogramBenchmark(parameterized.TestCase):
51
-
52
- @parameterized.parameters(*_HISTOGRAM_TESTS)
53
- def testHistogram(self, n, dtype, max_val):
54
- x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
-
56
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
- arguments = {
58
- 'n': n,
59
- 'dtype': dtype,
60
- 'max_val': max_val,
61
- }
62
- log_benchmark(arguments, mean_t, std_t)
63
-
64
- @parameterized.parameters(*_HISTOGRAM_TESTS)
65
- def testTorchHistogram(self, n, dtype, max_val):
66
- x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
-
68
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
- arguments = {
70
- 'n': n,
71
- 'dtype': dtype,
72
- 'max_val': max_val,
73
- }
74
- log_benchmark(arguments, mean_t, std_t)
75
 
76
 
77
  if __name__ == '__main__':
 
5
 
6
  import numpy as np
7
  import torch
8
+ # from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
 
47
  print('=' * 60)
48
 
49
 
50
+ # class HistogramBenchmark(parameterized.TestCase):
51
+ #
52
+ # @parameterized.parameters(*_HISTOGRAM_TESTS)
53
+ # def testHistogram(self, n, dtype, max_val):
54
+ # x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
+ #
56
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
+ # arguments = {
58
+ # 'n': n,
59
+ # 'dtype': dtype,
60
+ # 'max_val': max_val,
61
+ # }
62
+ # log_benchmark(arguments, mean_t, std_t)
63
+ #
64
+ # @parameterized.parameters(*_HISTOGRAM_TESTS)
65
+ # def testTorchHistogram(self, n, dtype, max_val):
66
+ # x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
+ #
68
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
+ # arguments = {
70
+ # 'n': n,
71
+ # 'dtype': dtype,
72
+ # 'max_val': max_val,
73
+ # }
74
+ # log_benchmark(arguments, mean_t, std_t)
75
 
76
 
77
  if __name__ == '__main__':
build/torch210-cxx11-cu126-aarch64-linux/ops/matmul_benchmark.py CHANGED
@@ -17,7 +17,7 @@ import unittest
17
  from .. import stk
18
 
19
  import torch
20
- from absl.testing import parameterized
21
 
22
  from .. import benchmark_util, ops
23
 
@@ -48,367 +48,367 @@ def log_benchmark(name, arguments, time, std, flops):
48
  print('=' * 60)
49
 
50
 
51
- class MatmulBenchmark(parameterized.TestCase):
52
-
53
- def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
- blocking = 128
55
- padded_tokens, _ = x.size()
56
- assert padded_tokens % blocking == 0
57
- assert fhs % blocking == 0
58
-
59
- # Offsets for the sparse matrix. All rows have the
60
- # same number of nonzero blocks dictated by the
61
- # dimensionality of a single expert.
62
- block_rows = padded_tokens // blocking
63
- blocks_per_row = fhs // blocking
64
- offsets = torch.arange(
65
- 0,
66
- block_rows * blocks_per_row + 1,
67
- blocks_per_row,
68
- dtype=torch.int32,
69
- device=x.device,
70
- )
71
-
72
- # Indices for the sparse matrix. The indices for
73
- # the intermediate matrix are dynamic depending
74
- # on the mapping of tokens to experts.
75
- column_indices = ops.topology(
76
- padded_bins,
77
- blocking,
78
- block_rows,
79
- blocks_per_row,
80
- )
81
- data = torch.empty(
82
- column_indices.numel(),
83
- blocking,
84
- blocking,
85
- dtype=torch.float16,
86
- device=x.device,
87
- )
88
- shape = (padded_tokens, fhs * ne)
89
- row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
- return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
-
92
- def build_input_matrix(self, sl, hs, ne):
93
- x = torch.randn((sl, hs)).cuda().half()
94
-
95
- # Assign tokens to experts uniformly.
96
- top_expert = torch.arange(0, sl).cuda().int() % ne
97
-
98
- bin_ids, indices = ops.sort(top_expert)
99
- tokens_per_expert = ops.histogram(top_expert, ne)
100
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
- out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
- return out, padded_bins
105
-
106
- def build_weight_matrix(self, ne, hs, fhs):
107
- return torch.randn((hs, ne * fhs)).cuda().half()
108
-
109
- @parameterized.parameters(*_MATMUL_TESTS)
110
- def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
- w = transpose_view(w)
115
-
116
- def benchmark():
117
- return stk.ops.sdd(x, w, topo)
118
-
119
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
- arguments = {
121
- 'sequence_length': sl,
122
- 'hidden_size': hs,
123
- 'ffn_hidden_size': fhs,
124
- 'num_experts': ne,
125
- }
126
- log_benchmark(
127
- '0::Fwd::SDD::NT',
128
- arguments,
129
- mean_t,
130
- std_t,
131
- x.numel() * fhs * 2,
132
- )
133
-
134
- @parameterized.parameters(*_MATMUL_TESTS)
135
- def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
-
140
- def benchmark():
141
- return stk.ops.dsd(topo, w)
142
-
143
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
- arguments = {
145
- 'sequence_length': sl,
146
- 'hidden_size': hs,
147
- 'ffn_hidden_size': fhs,
148
- 'num_experts': ne,
149
- }
150
- log_benchmark(
151
- '0::GradX::DSD::NN',
152
- arguments,
153
- mean_t,
154
- std_t,
155
- x.numel() * fhs * 2,
156
- )
157
-
158
- @parameterized.parameters(*_MATMUL_TESTS)
159
- def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
- topo = topo.t()
163
-
164
- def benchmark():
165
- return stk.ops.dsd(topo, x)
166
-
167
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
- arguments = {
169
- 'sequence_length': sl,
170
- 'hidden_size': hs,
171
- 'ffn_hidden_size': fhs,
172
- 'num_experts': ne,
173
- }
174
- log_benchmark(
175
- '0::GradW::DSD::TN',
176
- arguments,
177
- mean_t,
178
- std_t,
179
- x.numel() * fhs * 2,
180
- )
181
-
182
- @parameterized.parameters(*_MATMUL_TESTS)
183
- def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
-
188
- def benchmark():
189
- return stk.ops.dsd(x, w)
190
-
191
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
- arguments = {
193
- 'sequence_length': sl,
194
- 'hidden_size': hs,
195
- 'ffn_hidden_size': fhs,
196
- 'num_experts': ne,
197
- }
198
- log_benchmark(
199
- '1::Fwd::DSD::NN',
200
- arguments,
201
- mean_t,
202
- std_t,
203
- x.nnz * hs * 2,
204
- )
205
-
206
- @parameterized.parameters(*_MATMUL_TESTS)
207
- def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
- out = stk.ops.dsd(x, w)
212
- w = transpose_view(w)
213
-
214
- def benchmark():
215
- return stk.ops.sdd(out, w, x)
216
-
217
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
- arguments = {
219
- 'sequence_length': sl,
220
- 'hidden_size': hs,
221
- 'ffn_hidden_size': fhs,
222
- 'num_experts': ne,
223
- }
224
- log_benchmark(
225
- '1::GradX::SDD::NT',
226
- arguments,
227
- mean_t,
228
- std_t,
229
- x.nnz * hs * 2,
230
- )
231
-
232
- @parameterized.parameters(*_MATMUL_TESTS)
233
- def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
- out = stk.ops.dsd(x, w)
238
- x = x.t()
239
-
240
- def benchmark():
241
- return stk.ops.dsd(x, out)
242
-
243
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
- arguments = {
245
- 'sequence_length': sl,
246
- 'hidden_size': hs,
247
- 'ffn_hidden_size': fhs,
248
- 'num_experts': ne,
249
- }
250
- log_benchmark(
251
- '1::GradW::DSD::TN',
252
- arguments,
253
- mean_t,
254
- std_t,
255
- x.nnz * hs * 2,
256
- )
257
-
258
- @parameterized.parameters(*_MATMUL_TESTS)
259
- def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
- assert (sl % ne) == 0
261
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
- w = torch.randn((ne, hs, fhs)).cuda().half()
263
-
264
- w = w.transpose(1, 2).contiguous()
265
- w = w.transpose(1, 2)
266
-
267
- def benchmark():
268
- return torch.bmm(x, w)
269
-
270
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
- arguments = {
272
- 'sequence_length': sl,
273
- 'hidden_size': hs,
274
- 'ffn_hidden_size': fhs,
275
- 'num_experts': ne,
276
- }
277
- log_benchmark(
278
- '0::Fwd:DDD::NT',
279
- arguments,
280
- mean_t,
281
- std_t,
282
- x.numel() * fhs * 2,
283
- )
284
-
285
- @parameterized.parameters(*_MATMUL_TESTS)
286
- def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
- assert (sl % ne) == 0
288
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
- w = torch.randn((ne, hs, fhs)).cuda().half()
290
- out = torch.bmm(x, w)
291
- w = w.transpose(1, 2).contiguous()
292
-
293
- def benchmark():
294
- return torch.bmm(out, w)
295
-
296
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
- arguments = {
298
- 'sequence_length': sl,
299
- 'hidden_size': hs,
300
- 'ffn_hidden_size': fhs,
301
- 'num_experts': ne,
302
- }
303
- log_benchmark(
304
- '0:GradX:DDD::NN',
305
- arguments,
306
- mean_t,
307
- std_t,
308
- x.numel() * fhs * 2,
309
- )
310
-
311
- @parameterized.parameters(*_MATMUL_TESTS)
312
- def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
- assert (sl % ne) == 0
314
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
- w = torch.randn((ne, hs, fhs)).cuda().half()
316
- out = torch.bmm(x, w)
317
- out = out.transpose(1, 2)
318
-
319
- def benchmark():
320
- return torch.bmm(out, x)
321
-
322
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
- arguments = {
324
- 'sequence_length': sl,
325
- 'hidden_size': hs,
326
- 'ffn_hidden_size': fhs,
327
- 'num_experts': ne,
328
- }
329
- log_benchmark(
330
- '0:GradW:DDD::TN',
331
- arguments,
332
- mean_t,
333
- std_t,
334
- x.numel() * fhs * 2,
335
- )
336
-
337
- @parameterized.parameters(*_MATMUL_TESTS)
338
- def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
- assert (sl % ne) == 0
340
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
- w = torch.randn((ne, fhs, hs)).cuda().half()
342
-
343
- def benchmark():
344
- return torch.bmm(x, w)
345
-
346
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
- arguments = {
348
- 'sequence_length': sl,
349
- 'hidden_size': hs,
350
- 'ffn_hidden_size': fhs,
351
- 'num_experts': ne,
352
- }
353
- log_benchmark(
354
- '1::Fwd::DDD::NN',
355
- arguments,
356
- mean_t,
357
- std_t,
358
- x.numel() * hs * 2,
359
- )
360
-
361
- @parameterized.parameters(*_MATMUL_TESTS)
362
- def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
- assert (sl % ne) == 0
364
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
- w = torch.randn((ne, fhs, hs)).cuda().half()
366
- out = torch.bmm(x, w)
367
- w = torch.transpose(w, 1, 2)
368
-
369
- def benchmark():
370
- return torch.bmm(out, w)
371
-
372
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
- arguments = {
374
- 'sequence_length': sl,
375
- 'hidden_size': hs,
376
- 'ffn_hidden_size': fhs,
377
- 'num_experts': ne,
378
- }
379
- log_benchmark(
380
- '1::GradX::DDD::NT',
381
- arguments,
382
- mean_t,
383
- std_t,
384
- x.numel() * hs * 2,
385
- )
386
-
387
- @parameterized.parameters(*_MATMUL_TESTS)
388
- def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
- assert (sl % ne) == 0
390
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
- w = torch.randn((ne, fhs, hs)).cuda().half()
392
- out = torch.bmm(x, w)
393
- x = torch.transpose(x, 1, 2)
394
-
395
- def benchmark():
396
- return torch.bmm(x, out)
397
-
398
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
- arguments = {
400
- 'sequence_length': sl,
401
- 'hidden_size': hs,
402
- 'ffn_hidden_size': fhs,
403
- 'num_experts': ne,
404
- }
405
- log_benchmark(
406
- '1::GradW::DDD::TN',
407
- arguments,
408
- mean_t,
409
- std_t,
410
- x.numel() * hs * 2,
411
- )
412
 
413
 
414
  if __name__ == '__main__':
 
17
  from .. import stk
18
 
19
  import torch
20
+ # from absl.testing import parameterized
21
 
22
  from .. import benchmark_util, ops
23
 
 
48
  print('=' * 60)
49
 
50
 
51
+ # class MatmulBenchmark(parameterized.TestCase):
52
+ #
53
+ # def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
+ # blocking = 128
55
+ # padded_tokens, _ = x.size()
56
+ # assert padded_tokens % blocking == 0
57
+ # assert fhs % blocking == 0
58
+ #
59
+ # # Offsets for the sparse matrix. All rows have the
60
+ # # same number of nonzero blocks dictated by the
61
+ # # dimensionality of a single expert.
62
+ # block_rows = padded_tokens // blocking
63
+ # blocks_per_row = fhs // blocking
64
+ # offsets = torch.arange(
65
+ # 0,
66
+ # block_rows * blocks_per_row + 1,
67
+ # blocks_per_row,
68
+ # dtype=torch.int32,
69
+ # device=x.device,
70
+ # )
71
+ #
72
+ # # Indices for the sparse matrix. The indices for
73
+ # # the intermediate matrix are dynamic depending
74
+ # # on the mapping of tokens to experts.
75
+ # column_indices = ops.topology(
76
+ # padded_bins,
77
+ # blocking,
78
+ # block_rows,
79
+ # blocks_per_row,
80
+ # )
81
+ # data = torch.empty(
82
+ # column_indices.numel(),
83
+ # blocking,
84
+ # blocking,
85
+ # dtype=torch.float16,
86
+ # device=x.device,
87
+ # )
88
+ # shape = (padded_tokens, fhs * ne)
89
+ # row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
+ # return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
+ #
92
+ # def build_input_matrix(self, sl, hs, ne):
93
+ # x = torch.randn((sl, hs)).cuda().half()
94
+ #
95
+ # # Assign tokens to experts uniformly.
96
+ # top_expert = torch.arange(0, sl).cuda().int() % ne
97
+ #
98
+ # bin_ids, indices = ops.sort(top_expert)
99
+ # tokens_per_expert = ops.histogram(top_expert, ne)
100
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
+ # out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
+ # return out, padded_bins
105
+ #
106
+ # def build_weight_matrix(self, ne, hs, fhs):
107
+ # return torch.randn((hs, ne * fhs)).cuda().half()
108
+ #
109
+ # @parameterized.parameters(*_MATMUL_TESTS)
110
+ # def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
+ # w = transpose_view(w)
115
+ #
116
+ # def benchmark():
117
+ # return stk.ops.sdd(x, w, topo)
118
+ #
119
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ # arguments = {
121
+ # 'sequence_length': sl,
122
+ # 'hidden_size': hs,
123
+ # 'ffn_hidden_size': fhs,
124
+ # 'num_experts': ne,
125
+ # }
126
+ # log_benchmark(
127
+ # '0::Fwd::SDD::NT',
128
+ # arguments,
129
+ # mean_t,
130
+ # std_t,
131
+ # x.numel() * fhs * 2,
132
+ # )
133
+ #
134
+ # @parameterized.parameters(*_MATMUL_TESTS)
135
+ # def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
+ #
140
+ # def benchmark():
141
+ # return stk.ops.dsd(topo, w)
142
+ #
143
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
+ # arguments = {
145
+ # 'sequence_length': sl,
146
+ # 'hidden_size': hs,
147
+ # 'ffn_hidden_size': fhs,
148
+ # 'num_experts': ne,
149
+ # }
150
+ # log_benchmark(
151
+ # '0::GradX::DSD::NN',
152
+ # arguments,
153
+ # mean_t,
154
+ # std_t,
155
+ # x.numel() * fhs * 2,
156
+ # )
157
+ #
158
+ # @parameterized.parameters(*_MATMUL_TESTS)
159
+ # def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
+ # topo = topo.t()
163
+ #
164
+ # def benchmark():
165
+ # return stk.ops.dsd(topo, x)
166
+ #
167
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
+ # arguments = {
169
+ # 'sequence_length': sl,
170
+ # 'hidden_size': hs,
171
+ # 'ffn_hidden_size': fhs,
172
+ # 'num_experts': ne,
173
+ # }
174
+ # log_benchmark(
175
+ # '0::GradW::DSD::TN',
176
+ # arguments,
177
+ # mean_t,
178
+ # std_t,
179
+ # x.numel() * fhs * 2,
180
+ # )
181
+ #
182
+ # @parameterized.parameters(*_MATMUL_TESTS)
183
+ # def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
+ #
188
+ # def benchmark():
189
+ # return stk.ops.dsd(x, w)
190
+ #
191
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
+ # arguments = {
193
+ # 'sequence_length': sl,
194
+ # 'hidden_size': hs,
195
+ # 'ffn_hidden_size': fhs,
196
+ # 'num_experts': ne,
197
+ # }
198
+ # log_benchmark(
199
+ # '1::Fwd::DSD::NN',
200
+ # arguments,
201
+ # mean_t,
202
+ # std_t,
203
+ # x.nnz * hs * 2,
204
+ # )
205
+ #
206
+ # @parameterized.parameters(*_MATMUL_TESTS)
207
+ # def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
+ # out = stk.ops.dsd(x, w)
212
+ # w = transpose_view(w)
213
+ #
214
+ # def benchmark():
215
+ # return stk.ops.sdd(out, w, x)
216
+ #
217
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
+ # arguments = {
219
+ # 'sequence_length': sl,
220
+ # 'hidden_size': hs,
221
+ # 'ffn_hidden_size': fhs,
222
+ # 'num_experts': ne,
223
+ # }
224
+ # log_benchmark(
225
+ # '1::GradX::SDD::NT',
226
+ # arguments,
227
+ # mean_t,
228
+ # std_t,
229
+ # x.nnz * hs * 2,
230
+ # )
231
+ #
232
+ # @parameterized.parameters(*_MATMUL_TESTS)
233
+ # def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
+ # out = stk.ops.dsd(x, w)
238
+ # x = x.t()
239
+ #
240
+ # def benchmark():
241
+ # return stk.ops.dsd(x, out)
242
+ #
243
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
+ # arguments = {
245
+ # 'sequence_length': sl,
246
+ # 'hidden_size': hs,
247
+ # 'ffn_hidden_size': fhs,
248
+ # 'num_experts': ne,
249
+ # }
250
+ # log_benchmark(
251
+ # '1::GradW::DSD::TN',
252
+ # arguments,
253
+ # mean_t,
254
+ # std_t,
255
+ # x.nnz * hs * 2,
256
+ # )
257
+ #
258
+ # @parameterized.parameters(*_MATMUL_TESTS)
259
+ # def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
+ # assert (sl % ne) == 0
261
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
263
+ #
264
+ # w = w.transpose(1, 2).contiguous()
265
+ # w = w.transpose(1, 2)
266
+ #
267
+ # def benchmark():
268
+ # return torch.bmm(x, w)
269
+ #
270
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
+ # arguments = {
272
+ # 'sequence_length': sl,
273
+ # 'hidden_size': hs,
274
+ # 'ffn_hidden_size': fhs,
275
+ # 'num_experts': ne,
276
+ # }
277
+ # log_benchmark(
278
+ # '0::Fwd:DDD::NT',
279
+ # arguments,
280
+ # mean_t,
281
+ # std_t,
282
+ # x.numel() * fhs * 2,
283
+ # )
284
+ #
285
+ # @parameterized.parameters(*_MATMUL_TESTS)
286
+ # def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
+ # assert (sl % ne) == 0
288
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
290
+ # out = torch.bmm(x, w)
291
+ # w = w.transpose(1, 2).contiguous()
292
+ #
293
+ # def benchmark():
294
+ # return torch.bmm(out, w)
295
+ #
296
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
+ # arguments = {
298
+ # 'sequence_length': sl,
299
+ # 'hidden_size': hs,
300
+ # 'ffn_hidden_size': fhs,
301
+ # 'num_experts': ne,
302
+ # }
303
+ # log_benchmark(
304
+ # '0:GradX:DDD::NN',
305
+ # arguments,
306
+ # mean_t,
307
+ # std_t,
308
+ # x.numel() * fhs * 2,
309
+ # )
310
+ #
311
+ # @parameterized.parameters(*_MATMUL_TESTS)
312
+ # def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
+ # assert (sl % ne) == 0
314
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
316
+ # out = torch.bmm(x, w)
317
+ # out = out.transpose(1, 2)
318
+ #
319
+ # def benchmark():
320
+ # return torch.bmm(out, x)
321
+ #
322
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
+ # arguments = {
324
+ # 'sequence_length': sl,
325
+ # 'hidden_size': hs,
326
+ # 'ffn_hidden_size': fhs,
327
+ # 'num_experts': ne,
328
+ # }
329
+ # log_benchmark(
330
+ # '0:GradW:DDD::TN',
331
+ # arguments,
332
+ # mean_t,
333
+ # std_t,
334
+ # x.numel() * fhs * 2,
335
+ # )
336
+ #
337
+ # @parameterized.parameters(*_MATMUL_TESTS)
338
+ # def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
+ # assert (sl % ne) == 0
340
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
342
+ #
343
+ # def benchmark():
344
+ # return torch.bmm(x, w)
345
+ #
346
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
+ # arguments = {
348
+ # 'sequence_length': sl,
349
+ # 'hidden_size': hs,
350
+ # 'ffn_hidden_size': fhs,
351
+ # 'num_experts': ne,
352
+ # }
353
+ # log_benchmark(
354
+ # '1::Fwd::DDD::NN',
355
+ # arguments,
356
+ # mean_t,
357
+ # std_t,
358
+ # x.numel() * hs * 2,
359
+ # )
360
+ #
361
+ # @parameterized.parameters(*_MATMUL_TESTS)
362
+ # def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
+ # assert (sl % ne) == 0
364
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
366
+ # out = torch.bmm(x, w)
367
+ # w = torch.transpose(w, 1, 2)
368
+ #
369
+ # def benchmark():
370
+ # return torch.bmm(out, w)
371
+ #
372
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
+ # arguments = {
374
+ # 'sequence_length': sl,
375
+ # 'hidden_size': hs,
376
+ # 'ffn_hidden_size': fhs,
377
+ # 'num_experts': ne,
378
+ # }
379
+ # log_benchmark(
380
+ # '1::GradX::DDD::NT',
381
+ # arguments,
382
+ # mean_t,
383
+ # std_t,
384
+ # x.numel() * hs * 2,
385
+ # )
386
+ #
387
+ # @parameterized.parameters(*_MATMUL_TESTS)
388
+ # def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
+ # assert (sl % ne) == 0
390
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
392
+ # out = torch.bmm(x, w)
393
+ # x = torch.transpose(x, 1, 2)
394
+ #
395
+ # def benchmark():
396
+ # return torch.bmm(x, out)
397
+ #
398
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
+ # arguments = {
400
+ # 'sequence_length': sl,
401
+ # 'hidden_size': hs,
402
+ # 'ffn_hidden_size': fhs,
403
+ # 'num_experts': ne,
404
+ # }
405
+ # log_benchmark(
406
+ # '1::GradW::DDD::TN',
407
+ # arguments,
408
+ # mean_t,
409
+ # std_t,
410
+ # x.numel() * hs * 2,
411
+ # )
412
 
413
 
414
  if __name__ == '__main__':
build/torch210-cxx11-cu126-aarch64-linux/ops/padded_scatter_benchmark.py CHANGED
@@ -4,7 +4,7 @@
4
  import unittest
5
 
6
  import torch
7
- from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
@@ -16,50 +16,50 @@ _PADDED_SCATTER_BENCHMARK = (
16
  )
17
 
18
 
19
- class PaddedScatterTest(parameterized.TestCase):
20
-
21
- @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
- def testPaddedScatter(self, sl, hs, ne, top_k):
23
- # Create the data and indices.
24
- x = torch.randn((sl, hs)).cuda().half()
25
-
26
- # Randomly assign tokens to experts.
27
- top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
- bin_ids, indices = ops.sort(top_expert)
29
- tokens_per_expert = ops.histogram(top_expert, ne)
30
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
-
34
- # Sample weights for the scatter reduce.
35
- weights = torch.rand((sl * top_k,)).cuda().half()
36
-
37
- # Gather the data to prepare for backwards.
38
- x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
-
40
- def benchmark():
41
- return ops.padded_scatter(
42
- x,
43
- indices,
44
- bin_ids,
45
- weights,
46
- bins,
47
- padded_bins,
48
- top_k,
49
- )
50
-
51
- time, std = benchmark_util.benchmark_function(benchmark)
52
- benchmark_util.log_benchmark(
53
- 'Padded Scatter',
54
- {
55
- 'sequence_length': sl,
56
- 'hidden_size': hs,
57
- 'num_experts': ne,
58
- 'top_k': top_k,
59
- },
60
- time,
61
- std,
62
- )
63
 
64
 
65
  if __name__ == '__main__':
 
4
  import unittest
5
 
6
  import torch
7
+ # from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
 
16
  )
17
 
18
 
19
+ # class PaddedScatterTest(parameterized.TestCase):
20
+ #
21
+ # @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
+ # def testPaddedScatter(self, sl, hs, ne, top_k):
23
+ # # Create the data and indices.
24
+ # x = torch.randn((sl, hs)).cuda().half()
25
+ #
26
+ # # Randomly assign tokens to experts.
27
+ # top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
+ # bin_ids, indices = ops.sort(top_expert)
29
+ # tokens_per_expert = ops.histogram(top_expert, ne)
30
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
+ #
34
+ # # Sample weights for the scatter reduce.
35
+ # weights = torch.rand((sl * top_k,)).cuda().half()
36
+ #
37
+ # # Gather the data to prepare for backwards.
38
+ # x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
+ #
40
+ # def benchmark():
41
+ # return ops.padded_scatter(
42
+ # x,
43
+ # indices,
44
+ # bin_ids,
45
+ # weights,
46
+ # bins,
47
+ # padded_bins,
48
+ # top_k,
49
+ # )
50
+ #
51
+ # time, std = benchmark_util.benchmark_function(benchmark)
52
+ # benchmark_util.log_benchmark(
53
+ # 'Padded Scatter',
54
+ # {
55
+ # 'sequence_length': sl,
56
+ # 'hidden_size': hs,
57
+ # 'num_experts': ne,
58
+ # 'top_k': top_k,
59
+ # },
60
+ # time,
61
+ # std,
62
+ # )
63
 
64
 
65
  if __name__ == '__main__':
build/torch210-cxx11-cu126-aarch64-linux/ops/permute_benchmark.py CHANGED
@@ -4,7 +4,7 @@
4
  import unittest
5
 
6
  import torch
7
- from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
@@ -26,123 +26,123 @@ _PERMUTE_TESTS = (
26
  )
27
 
28
 
29
- class PermuteBenchmark(parameterized.TestCase):
30
-
31
- @parameterized.parameters(*_PERMUTE_TESTS)
32
- def testBinnedGather(self, sl, hs, ne):
33
- # NOTE: Capacity factor == 1.
34
- ec = sl // ne
35
-
36
- # Create the data and indices.
37
- x = torch.randn((sl, hs)).cuda().half()
38
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
- bin_ids, indices = ops.sort(top_expert)
40
- tokens_per_expert = ops.histogram(indices, ne)
41
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
-
43
- def benchmark():
44
- return ops.binned_gather(x, indices, bins, ec)
45
-
46
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
- arguments = {
48
- 'sequence_length': sl,
49
- 'hidden_size': hs,
50
- 'num_experts': ne,
51
- }
52
- benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
-
54
- @parameterized.parameters(*_PERMUTE_TESTS)
55
- def testBinnedScatter(self, sl, hs, ne):
56
- # NOTE: Capacity factor == 1.
57
- ec = sl // ne
58
-
59
- # Create the data and indices.
60
- x = torch.randn((sl, hs)).cuda().half()
61
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
- bin_ids, indices = ops.sort(top_expert)
63
- tokens_per_expert = ops.histogram(indices, ne)
64
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
- x = ops.binned_gather(x, indices, bins, ec)
66
-
67
- def benchmark():
68
- return ops.binned_scatter(x, indices, bins)
69
-
70
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
- arguments = {
72
- 'sequence_length': sl,
73
- 'hidden_size': hs,
74
- 'num_experts': ne,
75
- }
76
- benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
-
78
- @parameterized.parameters(*_PERMUTE_TESTS)
79
- def testPaddedGather(self, sl, hs, ne):
80
- # Create the data and indices.
81
- x = torch.randn((sl, hs)).cuda().half()
82
-
83
- # Randomly assign tokens to experts.
84
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
- bin_ids, indices = ops.sort(top_expert)
86
- tokens_per_expert = ops.histogram(top_expert, ne)
87
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
-
91
- def benchmark():
92
- return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
-
94
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
- arguments = {
96
- 'sequence_length': sl,
97
- 'hidden_size': hs,
98
- 'num_experts': ne,
99
- }
100
- benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
-
102
- @parameterized.parameters(*_PERMUTE_TESTS)
103
- def testPaddedScatter(self, sl, hs, ne):
104
- # Create the data and indices.
105
- x = torch.randn((sl, hs)).cuda().half()
106
-
107
- # Randomly assign tokens to experts.
108
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
- bin_ids, indices = ops.sort(top_expert)
110
- tokens_per_expert = ops.histogram(top_expert, ne)
111
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
- x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
-
116
- def benchmark():
117
- return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
-
119
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
- arguments = {
121
- 'sequence_length': sl,
122
- 'hidden_size': hs,
123
- 'num_experts': ne,
124
- }
125
- benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
-
127
- @parameterized.parameters(*_PERMUTE_TESTS)
128
- def testCopy(self, sl, hs, ne):
129
- # NOTE: Capacity factor == 1.
130
- # ec = sl // ne
131
-
132
- # Create the data and indices.
133
- x = torch.randn((sl, hs)).cuda().half()
134
- y = x.clone()
135
-
136
- def benchmark():
137
- return y.copy_(x)
138
-
139
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
- arguments = {
141
- 'sequence_length': sl,
142
- 'hidden_size': hs,
143
- 'num_experts': ne,
144
- }
145
- benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
 
147
 
148
  if __name__ == '__main__':
 
4
  import unittest
5
 
6
  import torch
7
+ # from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
 
26
  )
27
 
28
 
29
+ # class PermuteBenchmark(parameterized.TestCase):
30
+ #
31
+ # @parameterized.parameters(*_PERMUTE_TESTS)
32
+ # def testBinnedGather(self, sl, hs, ne):
33
+ # # NOTE: Capacity factor == 1.
34
+ # ec = sl // ne
35
+ #
36
+ # # Create the data and indices.
37
+ # x = torch.randn((sl, hs)).cuda().half()
38
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
+ # bin_ids, indices = ops.sort(top_expert)
40
+ # tokens_per_expert = ops.histogram(indices, ne)
41
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
+ #
43
+ # def benchmark():
44
+ # return ops.binned_gather(x, indices, bins, ec)
45
+ #
46
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
+ # arguments = {
48
+ # 'sequence_length': sl,
49
+ # 'hidden_size': hs,
50
+ # 'num_experts': ne,
51
+ # }
52
+ # benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
+ #
54
+ # @parameterized.parameters(*_PERMUTE_TESTS)
55
+ # def testBinnedScatter(self, sl, hs, ne):
56
+ # # NOTE: Capacity factor == 1.
57
+ # ec = sl // ne
58
+ #
59
+ # # Create the data and indices.
60
+ # x = torch.randn((sl, hs)).cuda().half()
61
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
+ # bin_ids, indices = ops.sort(top_expert)
63
+ # tokens_per_expert = ops.histogram(indices, ne)
64
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
+ # x = ops.binned_gather(x, indices, bins, ec)
66
+ #
67
+ # def benchmark():
68
+ # return ops.binned_scatter(x, indices, bins)
69
+ #
70
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
+ # arguments = {
72
+ # 'sequence_length': sl,
73
+ # 'hidden_size': hs,
74
+ # 'num_experts': ne,
75
+ # }
76
+ # benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
+ #
78
+ # @parameterized.parameters(*_PERMUTE_TESTS)
79
+ # def testPaddedGather(self, sl, hs, ne):
80
+ # # Create the data and indices.
81
+ # x = torch.randn((sl, hs)).cuda().half()
82
+ #
83
+ # # Randomly assign tokens to experts.
84
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
+ # bin_ids, indices = ops.sort(top_expert)
86
+ # tokens_per_expert = ops.histogram(top_expert, ne)
87
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
+ #
91
+ # def benchmark():
92
+ # return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
+ #
94
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
+ # arguments = {
96
+ # 'sequence_length': sl,
97
+ # 'hidden_size': hs,
98
+ # 'num_experts': ne,
99
+ # }
100
+ # benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
+ #
102
+ # @parameterized.parameters(*_PERMUTE_TESTS)
103
+ # def testPaddedScatter(self, sl, hs, ne):
104
+ # # Create the data and indices.
105
+ # x = torch.randn((sl, hs)).cuda().half()
106
+ #
107
+ # # Randomly assign tokens to experts.
108
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
+ # bin_ids, indices = ops.sort(top_expert)
110
+ # tokens_per_expert = ops.histogram(top_expert, ne)
111
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
+ # x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
+ #
116
+ # def benchmark():
117
+ # return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
+ #
119
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ # arguments = {
121
+ # 'sequence_length': sl,
122
+ # 'hidden_size': hs,
123
+ # 'num_experts': ne,
124
+ # }
125
+ # benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
+ #
127
+ # @parameterized.parameters(*_PERMUTE_TESTS)
128
+ # def testCopy(self, sl, hs, ne):
129
+ # # NOTE: Capacity factor == 1.
130
+ # # ec = sl // ne
131
+ #
132
+ # # Create the data and indices.
133
+ # x = torch.randn((sl, hs)).cuda().half()
134
+ # y = x.clone()
135
+ #
136
+ # def benchmark():
137
+ # return y.copy_(x)
138
+ #
139
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
+ # arguments = {
141
+ # 'sequence_length': sl,
142
+ # 'hidden_size': hs,
143
+ # 'num_experts': ne,
144
+ # }
145
+ # benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
 
147
 
148
  if __name__ == '__main__':
build/torch210-cxx11-cu126-aarch64-linux/ops/sort_benchmark.py CHANGED
@@ -5,7 +5,7 @@ import unittest
5
 
6
  import numpy as np
7
  import torch
8
- from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
@@ -53,32 +53,32 @@ def log_benchmark(arguments, mean_t, std_t):
53
  print('=' * 60)
54
 
55
 
56
- class SortBenchmark(parameterized.TestCase):
57
-
58
- @parameterized.parameters(*_SORT_TESTS)
59
- def testSort(self, n, dtype, max_val):
60
- if max_val is None:
61
- max_val = np.iinfo(numpy_dtype(dtype)).max
62
- end_bit = int(np.ceil(np.log2(max_val)))
63
- x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
-
65
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
- arguments = {
67
- 'n': n,
68
- 'dtype': dtype,
69
- 'max_val': max_val,
70
- }
71
- log_benchmark(arguments, mean_t, std_t)
72
-
73
- @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
- def testTorchSort(self, n):
75
- x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
-
77
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
- arguments = {
79
- 'n': n,
80
- }
81
- log_benchmark(arguments, mean_t, std_t)
82
 
83
 
84
  if __name__ == '__main__':
 
5
 
6
  import numpy as np
7
  import torch
8
+ # from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
 
53
  print('=' * 60)
54
 
55
 
56
+ # class SortBenchmark(parameterized.TestCase):
57
+ #
58
+ # @parameterized.parameters(*_SORT_TESTS)
59
+ # def testSort(self, n, dtype, max_val):
60
+ # if max_val is None:
61
+ # max_val = np.iinfo(numpy_dtype(dtype)).max
62
+ # end_bit = int(np.ceil(np.log2(max_val)))
63
+ # x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
+ #
65
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
+ # arguments = {
67
+ # 'n': n,
68
+ # 'dtype': dtype,
69
+ # 'max_val': max_val,
70
+ # }
71
+ # log_benchmark(arguments, mean_t, std_t)
72
+ #
73
+ # @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
+ # def testTorchSort(self, n):
75
+ # x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
+ #
77
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
+ # arguments = {
79
+ # 'n': n,
80
+ # }
81
+ # log_benchmark(arguments, mean_t, std_t)
82
 
83
 
84
  if __name__ == '__main__':
build/torch210-cxx11-cu126-aarch64-linux/stk/ops/eltwise_ops_test.py CHANGED
@@ -1,7 +1,7 @@
1
  import unittest
2
  import itertools
3
  import torch
4
- from absl.testing import parameterized
5
 
6
  import stk
7
  from stk.ops.linear_ops_test import allclose, _dense_and_sparse
@@ -47,40 +47,40 @@ def _dense_and_sparse_like(x, std=0.1):
47
  return (dense.requires_grad_(True),
48
  sparse.requires_grad_(True))
49
 
50
- @parameterized.parameters(_ELTWISE_OP_TESTS)
51
- class EltwiseOpsTest(parameterized.TestCase):
52
-
53
- def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
-
55
- a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
- b_dense, b = _dense_and_sparse_like(a)
57
-
58
- out = stk.ops.mul(a, b)
59
- expected_out = torch.mul(a_dense, b_dense)
60
-
61
- # Compute the gradients w.r.t. the inputs.
62
- expected_out.sum().backward()
63
- stk.ops.sum(out).backward()
64
-
65
- # Validate the results.
66
- out = stk.ops.to_dense(out)
67
- self.assertEqual(out.dim(), 2)
68
- self.assertEqual(expected_out.size(), out.size())
69
- self.assertTrue(allclose(out, expected_out))
70
-
71
- # LHS gradient.
72
- grad = stk.ops.to_dense(a.grad)
73
- expected_grad = a_dense.grad
74
- self.assertEqual(grad.dim(), 2)
75
- self.assertEqual(expected_grad.size(), grad.size())
76
- self.assertTrue(allclose(grad, expected_grad))
77
-
78
- # RHS gradient.
79
- grad = stk.ops.to_dense(b.grad)
80
- expected_grad = b_dense.grad
81
- self.assertEqual(grad.dim(), 2)
82
- self.assertEqual(expected_grad.size(), grad.size())
83
- self.assertTrue(allclose(grad, expected_grad))
84
 
85
  if __name__ == '__main__':
86
  unittest.main()
 
1
  import unittest
2
  import itertools
3
  import torch
4
+ # from absl.testing import parameterized
5
 
6
  import stk
7
  from stk.ops.linear_ops_test import allclose, _dense_and_sparse
 
47
  return (dense.requires_grad_(True),
48
  sparse.requires_grad_(True))
49
 
50
+ # @parameterized.parameters(_ELTWISE_OP_TESTS)
51
+ # class EltwiseOpsTest(parameterized.TestCase):
52
+ #
53
+ # def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
+ #
55
+ # a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
+ # b_dense, b = _dense_and_sparse_like(a)
57
+ #
58
+ # out = stk.ops.mul(a, b)
59
+ # expected_out = torch.mul(a_dense, b_dense)
60
+ #
61
+ # # Compute the gradients w.r.t. the inputs.
62
+ # expected_out.sum().backward()
63
+ # stk.ops.sum(out).backward()
64
+ #
65
+ # # Validate the results.
66
+ # out = stk.ops.to_dense(out)
67
+ # self.assertEqual(out.dim(), 2)
68
+ # self.assertEqual(expected_out.size(), out.size())
69
+ # self.assertTrue(allclose(out, expected_out))
70
+ #
71
+ # # LHS gradient.
72
+ # grad = stk.ops.to_dense(a.grad)
73
+ # expected_grad = a_dense.grad
74
+ # self.assertEqual(grad.dim(), 2)
75
+ # self.assertEqual(expected_grad.size(), grad.size())
76
+ # self.assertTrue(allclose(grad, expected_grad))
77
+ #
78
+ # # RHS gradient.
79
+ # grad = stk.ops.to_dense(b.grad)
80
+ # expected_grad = b_dense.grad
81
+ # self.assertEqual(grad.dim(), 2)
82
+ # self.assertEqual(expected_grad.size(), grad.size())
83
+ # self.assertTrue(allclose(grad, expected_grad))
84
 
85
  if __name__ == '__main__':
86
  unittest.main()
build/torch210-cxx11-cu126-aarch64-linux/stk/ops/linear_ops_test.py CHANGED
@@ -2,7 +2,7 @@ import unittest
2
  import itertools
3
  import numpy as np
4
  import torch
5
- from absl.testing import parameterized
6
 
7
  import stk
8
 
@@ -96,121 +96,121 @@ def _mask(x, mask):
96
  return x * mask
97
 
98
 
99
- @parameterized.parameters(*_LINEAR_OP_TESTS)
100
- class LinearOpsTest(parameterized.TestCase):
101
-
102
- def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
103
- # Construct the operands.
104
- a_shape = (k, m) if trans_a else (m, k)
105
- a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
106
- b_shape = (n, k) if trans_b else (k, n)
107
- b, bcp = _dense_2x(*b_shape, dtype)
108
-
109
- # Execute the matmul.
110
- out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
111
- expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
112
-
113
- # Compute the gradients w.r.t. the inputs.
114
- expected_out.sum().backward()
115
- out.sum().backward()
116
-
117
- # Validate the results.
118
- self.assertEqual(out.dim(), 2)
119
- self.assertEqual(expected_out.size()[0], out.size()[0])
120
- self.assertEqual(expected_out.size()[1], out.size()[1])
121
- self.assertTrue(allclose(out, expected_out))
122
-
123
- # LHS gradient.
124
- grad = stk.ops.to_dense(a.grad)
125
- expected_grad = _mask(a_dense.grad, a.grad)
126
- self.assertEqual(grad.dim(), 2)
127
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
128
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
129
- self.assertTrue(allclose(grad, expected_grad))
130
-
131
- # RHS gradient.
132
- grad = b.grad
133
- expected_grad = bcp.grad
134
- self.assertEqual(grad.dim(), 2)
135
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
136
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
137
- self.assertTrue(allclose(grad, expected_grad))
138
-
139
- def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
140
- # Construct the operands.
141
- a_shape = (k, m) if trans_a else (m, k)
142
- a, acp = _dense_2x(*a_shape, dtype)
143
- b_shape = (n, k) if trans_b else (k, n)
144
- b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
145
-
146
- # Execute the matmul.
147
- out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
148
- expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
149
-
150
- # Compute the gradients w.r.t. the inputs.
151
- expected_out.sum().backward()
152
- out.sum().backward()
153
-
154
- # Validate the results.
155
- self.assertEqual(out.dim(), 2)
156
- self.assertEqual(expected_out.size()[0], out.size()[0])
157
- self.assertEqual(expected_out.size()[1], out.size()[1])
158
- self.assertTrue(allclose(out, expected_out))
159
-
160
- # LHS gradient.
161
- grad = a.grad
162
- expected_grad = acp.grad
163
- self.assertEqual(grad.dim(), 2)
164
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
165
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
166
- self.assertTrue(allclose(grad, expected_grad))
167
-
168
- # RHS gradient.
169
- grad = stk.ops.to_dense(b.grad)
170
- expected_grad = _mask(b_dense.grad, b.grad)
171
- self.assertEqual(grad.dim(), 2)
172
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
173
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
174
- self.assertTrue(allclose(grad, expected_grad))
175
-
176
- def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
177
- # Construct the operands.
178
- a_shape = (k, m) if trans_a else (m, k)
179
- a, acp = _dense_2x(*a_shape, dtype)
180
- b_shape = (n, k) if trans_b else (k, n)
181
- b, bcp = _dense_2x(*b_shape, dtype)
182
- _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
183
-
184
- # Execute the matmul.
185
- out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
186
- expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
187
-
188
- # Compute the gradients w.r.t. the inputs.
189
- expected_out.sum().backward()
190
- stk.ops.sum(out).backward()
191
-
192
- # Validate the results.
193
- out = stk.ops.to_dense(out)
194
- self.assertEqual(out.dim(), 2)
195
- self.assertEqual(expected_out.size()[0], out.size()[0])
196
- self.assertEqual(expected_out.size()[1], out.size()[1])
197
- self.assertTrue(allclose(out, expected_out))
198
-
199
- # LHS gradient.
200
- grad = a.grad
201
- expected_grad = acp.grad
202
- self.assertEqual(grad.dim(), 2)
203
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
204
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
205
- self.assertTrue(allclose(grad, expected_grad))
206
-
207
- # RHS gradient.
208
- grad = b.grad
209
- expected_grad = bcp.grad
210
- self.assertEqual(grad.dim(), 2)
211
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
212
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
213
- self.assertTrue(allclose(grad, expected_grad))
214
 
215
  if __name__ == '__main__':
216
  unittest.main()
 
2
  import itertools
3
  import numpy as np
4
  import torch
5
+ # from absl.testing import parameterized
6
 
7
  import stk
8
 
 
96
  return x * mask
97
 
98
 
99
+ # @parameterized.parameters(*_LINEAR_OP_TESTS)
100
+ # class LinearOpsTest(parameterized.TestCase):
101
+ #
102
+ # def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
103
+ # # Construct the operands.
104
+ # a_shape = (k, m) if trans_a else (m, k)
105
+ # a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
106
+ # b_shape = (n, k) if trans_b else (k, n)
107
+ # b, bcp = _dense_2x(*b_shape, dtype)
108
+ #
109
+ # # Execute the matmul.
110
+ # out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
111
+ # expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
112
+ #
113
+ # # Compute the gradients w.r.t. the inputs.
114
+ # expected_out.sum().backward()
115
+ # out.sum().backward()
116
+ #
117
+ # # Validate the results.
118
+ # self.assertEqual(out.dim(), 2)
119
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
120
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
121
+ # self.assertTrue(allclose(out, expected_out))
122
+ #
123
+ # # LHS gradient.
124
+ # grad = stk.ops.to_dense(a.grad)
125
+ # expected_grad = _mask(a_dense.grad, a.grad)
126
+ # self.assertEqual(grad.dim(), 2)
127
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
128
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
129
+ # self.assertTrue(allclose(grad, expected_grad))
130
+ #
131
+ # # RHS gradient.
132
+ # grad = b.grad
133
+ # expected_grad = bcp.grad
134
+ # self.assertEqual(grad.dim(), 2)
135
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
136
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
137
+ # self.assertTrue(allclose(grad, expected_grad))
138
+ #
139
+ # def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
140
+ # # Construct the operands.
141
+ # a_shape = (k, m) if trans_a else (m, k)
142
+ # a, acp = _dense_2x(*a_shape, dtype)
143
+ # b_shape = (n, k) if trans_b else (k, n)
144
+ # b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
145
+ #
146
+ # # Execute the matmul.
147
+ # out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
148
+ # expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
149
+ #
150
+ # # Compute the gradients w.r.t. the inputs.
151
+ # expected_out.sum().backward()
152
+ # out.sum().backward()
153
+ #
154
+ # # Validate the results.
155
+ # self.assertEqual(out.dim(), 2)
156
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
157
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
158
+ # self.assertTrue(allclose(out, expected_out))
159
+ #
160
+ # # LHS gradient.
161
+ # grad = a.grad
162
+ # expected_grad = acp.grad
163
+ # self.assertEqual(grad.dim(), 2)
164
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
165
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
166
+ # self.assertTrue(allclose(grad, expected_grad))
167
+ #
168
+ # # RHS gradient.
169
+ # grad = stk.ops.to_dense(b.grad)
170
+ # expected_grad = _mask(b_dense.grad, b.grad)
171
+ # self.assertEqual(grad.dim(), 2)
172
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
173
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
174
+ # self.assertTrue(allclose(grad, expected_grad))
175
+ #
176
+ # def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
177
+ # # Construct the operands.
178
+ # a_shape = (k, m) if trans_a else (m, k)
179
+ # a, acp = _dense_2x(*a_shape, dtype)
180
+ # b_shape = (n, k) if trans_b else (k, n)
181
+ # b, bcp = _dense_2x(*b_shape, dtype)
182
+ # _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
183
+ #
184
+ # # Execute the matmul.
185
+ # out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
186
+ # expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
187
+ #
188
+ # # Compute the gradients w.r.t. the inputs.
189
+ # expected_out.sum().backward()
190
+ # stk.ops.sum(out).backward()
191
+ #
192
+ # # Validate the results.
193
+ # out = stk.ops.to_dense(out)
194
+ # self.assertEqual(out.dim(), 2)
195
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
196
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
197
+ # self.assertTrue(allclose(out, expected_out))
198
+ #
199
+ # # LHS gradient.
200
+ # grad = a.grad
201
+ # expected_grad = acp.grad
202
+ # self.assertEqual(grad.dim(), 2)
203
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
204
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
205
+ # self.assertTrue(allclose(grad, expected_grad))
206
+ #
207
+ # # RHS gradient.
208
+ # grad = b.grad
209
+ # expected_grad = bcp.grad
210
+ # self.assertEqual(grad.dim(), 2)
211
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
212
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
213
+ # self.assertTrue(allclose(grad, expected_grad))
214
 
215
  if __name__ == '__main__':
216
  unittest.main()
build/torch210-cxx11-cu126-aarch64-linux/stk/ops/matrix_ops_test.py CHANGED
@@ -1,61 +1,61 @@
1
  import unittest
2
 
3
- from absl.testing import parameterized
4
  import stk
5
  import torch
6
 
7
 
8
- @parameterized.parameters(
9
- (8, 16, 0.0, 1),
10
- (8, 16, 0.5, 1),
11
- (8, 16, .95, 1),
12
- (16, 8, 0.0, 1),
13
- (16, 8, 0.5, 1),
14
- (16, 8, .95, 1),
15
- (8, 16, 0.0, 8),
16
- (8, 16, 0.5, 8),
17
- (8, 16, 1.0, 8),
18
- (16, 8, 0.0, 8),
19
- (16, 8, 0.5, 8),
20
- (16, 8, 1.0, 8),
21
- (128, 256, 0.5, 16),
22
- (256, 128, 0.75, 32),
23
- (512, 512, .875, 128))
24
- class MatrixOpsTest(parameterized.TestCase):
25
-
26
- def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
27
- mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
28
- x = (torch.randn(rows, cols) * mask).type(torch.float16)
29
-
30
- # Convert the matrix to sparse format.
31
- sparse_x = stk.ops.to_sparse(x, blocking)
32
-
33
- # Validate the matrix.
34
- sparse_x.validate()
35
-
36
- # Validate the shape.
37
- self.assertEqual(sparse_x.dim(), 2)
38
- self.assertEqual(sparse_x.size()[0], rows)
39
- self.assertEqual(sparse_x.size()[1], cols)
40
-
41
- # Validate the sparsity.
42
- numblocks = rows // blocking * cols // blocking
43
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
44
- self.assertEqual(sparse_x.nnz, nnz)
45
-
46
- # Convert back to dense format.
47
- dense_x = stk.ops.to_dense(sparse_x)
48
-
49
- # Validate the shape.
50
- self.assertEqual(dense_x.dim(), 2)
51
- self.assertEqual(dense_x.size()[0], rows)
52
- self.assertEqual(dense_x.size()[1], cols)
53
-
54
- # Validate the sparsity
55
- self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
56
-
57
- # Validate the output.
58
- self.assertTrue(torch.all(torch.eq(x, dense_x)))
59
 
60
 
61
  if __name__ == '__main__':
 
1
  import unittest
2
 
3
+ # from absl.testing import parameterized
4
  import stk
5
  import torch
6
 
7
 
8
+ # @parameterized.parameters(
9
+ # (8, 16, 0.0, 1),
10
+ # (8, 16, 0.5, 1),
11
+ # (8, 16, .95, 1),
12
+ # (16, 8, 0.0, 1),
13
+ # (16, 8, 0.5, 1),
14
+ # (16, 8, .95, 1),
15
+ # (8, 16, 0.0, 8),
16
+ # (8, 16, 0.5, 8),
17
+ # (8, 16, 1.0, 8),
18
+ # (16, 8, 0.0, 8),
19
+ # (16, 8, 0.5, 8),
20
+ # (16, 8, 1.0, 8),
21
+ # (128, 256, 0.5, 16),
22
+ # (256, 128, 0.75, 32),
23
+ # (512, 512, .875, 128))
24
+ # class MatrixOpsTest(parameterized.TestCase):
25
+ #
26
+ # def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
27
+ # mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
28
+ # x = (torch.randn(rows, cols) * mask).type(torch.float16)
29
+ #
30
+ # # Convert the matrix to sparse format.
31
+ # sparse_x = stk.ops.to_sparse(x, blocking)
32
+ #
33
+ # # Validate the matrix.
34
+ # sparse_x.validate()
35
+ #
36
+ # # Validate the shape.
37
+ # self.assertEqual(sparse_x.dim(), 2)
38
+ # self.assertEqual(sparse_x.size()[0], rows)
39
+ # self.assertEqual(sparse_x.size()[1], cols)
40
+ #
41
+ # # Validate the sparsity.
42
+ # numblocks = rows // blocking * cols // blocking
43
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
44
+ # self.assertEqual(sparse_x.nnz, nnz)
45
+ #
46
+ # # Convert back to dense format.
47
+ # dense_x = stk.ops.to_dense(sparse_x)
48
+ #
49
+ # # Validate the shape.
50
+ # self.assertEqual(dense_x.dim(), 2)
51
+ # self.assertEqual(dense_x.size()[0], rows)
52
+ # self.assertEqual(dense_x.size()[1], cols)
53
+ #
54
+ # # Validate the sparsity
55
+ # self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
56
+ #
57
+ # # Validate the output.
58
+ # self.assertTrue(torch.all(torch.eq(x, dense_x)))
59
 
60
 
61
  if __name__ == '__main__':
build/torch210-cxx11-cu126-aarch64-linux/stk/random/random_ops_test.py CHANGED
@@ -1,72 +1,72 @@
1
  import unittest
2
 
3
- from absl.testing import parameterized
4
  from . import random
5
  import torch
6
 
7
 
8
- @parameterized.parameters(
9
- (8, 16, 0.0, 1),
10
- (8, 16, 0.5, 1),
11
- (8, 16, .95, 1),
12
- (16, 8, 0.0, 1),
13
- (16, 8, 0.5, 1),
14
- (16, 8, .95, 1),
15
- (8, 16, 0.0, 8),
16
- (8, 16, 0.5, 8),
17
- (8, 16, 1.0, 8),
18
- (16, 8, 0.0, 8),
19
- (16, 8, 0.5, 8),
20
- (16, 8, 1.0, 8),
21
- (128, 256, 0.5, 16),
22
- (256, 128, 0.75, 32),
23
- (512, 512, .875, 128))
24
- class RandomOpsTest(parameterized.TestCase):
25
-
26
- def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
27
- mask = random.dense_mask(
28
- rows, cols, sparsity, blocking)
29
-
30
- # Validate the shape.
31
- self.assertEqual(mask.dim(), 2)
32
- self.assertEqual(mask.size()[0], rows)
33
- self.assertEqual(mask.size()[1], cols)
34
-
35
- # Validate the sparsity
36
- numblocks = rows // blocking * cols // blocking
37
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
38
- self.assertEqual(
39
- torch.count_nonzero(mask).item(),
40
- nnz)
41
-
42
- # Check values are zero or one.
43
- self.assertTrue(
44
- torch.all(torch.logical_or(
45
- torch.eq(mask, 0),
46
- torch.eq(mask, 1))))
47
-
48
- def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
49
- mask = random.mask(
50
- rows, cols, sparsity, blocking)
51
-
52
- # Validate the matrix.
53
- mask.validate()
54
-
55
- # Validate the shape.
56
- self.assertEqual(mask.dim(), 2)
57
- self.assertEqual(mask.size()[0], rows)
58
- self.assertEqual(mask.size()[1], cols)
59
-
60
- # Validate the sparsity.
61
- numblocks = rows // blocking * cols // blocking
62
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
63
- self.assertEqual(mask.nnz, nnz)
64
-
65
- # Check values are zero or one.
66
- self.assertTrue(
67
- torch.all(torch.logical_or(
68
- torch.eq(mask.data, 0),
69
- torch.eq(mask.data, 1))))
70
 
71
 
72
  if __name__ == '__main__':
 
1
  import unittest
2
 
3
+ # from absl.testing import parameterized
4
  from . import random
5
  import torch
6
 
7
 
8
+ # @parameterized.parameters(
9
+ # (8, 16, 0.0, 1),
10
+ # (8, 16, 0.5, 1),
11
+ # (8, 16, .95, 1),
12
+ # (16, 8, 0.0, 1),
13
+ # (16, 8, 0.5, 1),
14
+ # (16, 8, .95, 1),
15
+ # (8, 16, 0.0, 8),
16
+ # (8, 16, 0.5, 8),
17
+ # (8, 16, 1.0, 8),
18
+ # (16, 8, 0.0, 8),
19
+ # (16, 8, 0.5, 8),
20
+ # (16, 8, 1.0, 8),
21
+ # (128, 256, 0.5, 16),
22
+ # (256, 128, 0.75, 32),
23
+ # (512, 512, .875, 128))
24
+ # class RandomOpsTest(parameterized.TestCase):
25
+ #
26
+ # def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
27
+ # mask = random.dense_mask(
28
+ # rows, cols, sparsity, blocking)
29
+ #
30
+ # # Validate the shape.
31
+ # self.assertEqual(mask.dim(), 2)
32
+ # self.assertEqual(mask.size()[0], rows)
33
+ # self.assertEqual(mask.size()[1], cols)
34
+ #
35
+ # # Validate the sparsity
36
+ # numblocks = rows // blocking * cols // blocking
37
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
38
+ # self.assertEqual(
39
+ # torch.count_nonzero(mask).item(),
40
+ # nnz)
41
+ #
42
+ # # Check values are zero or one.
43
+ # self.assertTrue(
44
+ # torch.all(torch.logical_or(
45
+ # torch.eq(mask, 0),
46
+ # torch.eq(mask, 1))))
47
+ #
48
+ # def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
49
+ # mask = random.mask(
50
+ # rows, cols, sparsity, blocking)
51
+ #
52
+ # # Validate the matrix.
53
+ # mask.validate()
54
+ #
55
+ # # Validate the shape.
56
+ # self.assertEqual(mask.dim(), 2)
57
+ # self.assertEqual(mask.size()[0], rows)
58
+ # self.assertEqual(mask.size()[1], cols)
59
+ #
60
+ # # Validate the sparsity.
61
+ # numblocks = rows // blocking * cols // blocking
62
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
63
+ # self.assertEqual(mask.nnz, nnz)
64
+ #
65
+ # # Check values are zero or one.
66
+ # self.assertTrue(
67
+ # torch.all(torch.logical_or(
68
+ # torch.eq(mask.data, 0),
69
+ # torch.eq(mask.data, 1))))
70
 
71
 
72
  if __name__ == '__main__':
build/torch210-cxx11-cu128-aarch64-linux/{_megablocks_cuda_6e04dec.abi3.so → _megablocks_cuda_a45325d.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:12705f4547b6a55442c52e081a303d4407202cdc26522f7269c983b627946ab9
3
  size 21088232
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4352757f386e2d948559c737927fa969b2cc9674ee6255487ba8f67fb3470199
3
  size 21088232
build/torch210-cxx11-cu128-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_6e04dec
3
- ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_6e04dec::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_a45325d
3
+ ops = torch.ops._megablocks_cuda_a45325d
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_a45325d::{op_name}"
build/torch210-cxx11-cu128-aarch64-linux/megablocks/__init__.py CHANGED
@@ -1,10 +1,10 @@
1
  import ctypes
 
2
  import sys
3
-
4
- import importlib
5
  from pathlib import Path
6
  from types import ModuleType
7
 
 
8
  def _import_from_path(file_path: Path) -> ModuleType:
9
  # We cannot use the module name as-is, after adding it to `sys.modules`,
10
  # it would also be used for other imports. So, we make a module name that
 
1
  import ctypes
2
+ import importlib.util
3
  import sys
 
 
4
  from pathlib import Path
5
  from types import ModuleType
6
 
7
+
8
  def _import_from_path(file_path: Path) -> ModuleType:
9
  # We cannot use the module name as-is, after adding it to `sys.modules`,
10
  # it would also be used for other imports. So, we make a module name that
build/torch210-cxx11-cu128-aarch64-linux/ops/histogram_benchmark.py CHANGED
@@ -5,7 +5,7 @@ import unittest
5
 
6
  import numpy as np
7
  import torch
8
- from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
@@ -47,31 +47,31 @@ def log_benchmark(arguments, mean_t, std_t):
47
  print('=' * 60)
48
 
49
 
50
- class HistogramBenchmark(parameterized.TestCase):
51
-
52
- @parameterized.parameters(*_HISTOGRAM_TESTS)
53
- def testHistogram(self, n, dtype, max_val):
54
- x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
-
56
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
- arguments = {
58
- 'n': n,
59
- 'dtype': dtype,
60
- 'max_val': max_val,
61
- }
62
- log_benchmark(arguments, mean_t, std_t)
63
-
64
- @parameterized.parameters(*_HISTOGRAM_TESTS)
65
- def testTorchHistogram(self, n, dtype, max_val):
66
- x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
-
68
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
- arguments = {
70
- 'n': n,
71
- 'dtype': dtype,
72
- 'max_val': max_val,
73
- }
74
- log_benchmark(arguments, mean_t, std_t)
75
 
76
 
77
  if __name__ == '__main__':
 
5
 
6
  import numpy as np
7
  import torch
8
+ # from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
 
47
  print('=' * 60)
48
 
49
 
50
+ # class HistogramBenchmark(parameterized.TestCase):
51
+ #
52
+ # @parameterized.parameters(*_HISTOGRAM_TESTS)
53
+ # def testHistogram(self, n, dtype, max_val):
54
+ # x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
+ #
56
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
+ # arguments = {
58
+ # 'n': n,
59
+ # 'dtype': dtype,
60
+ # 'max_val': max_val,
61
+ # }
62
+ # log_benchmark(arguments, mean_t, std_t)
63
+ #
64
+ # @parameterized.parameters(*_HISTOGRAM_TESTS)
65
+ # def testTorchHistogram(self, n, dtype, max_val):
66
+ # x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
+ #
68
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
+ # arguments = {
70
+ # 'n': n,
71
+ # 'dtype': dtype,
72
+ # 'max_val': max_val,
73
+ # }
74
+ # log_benchmark(arguments, mean_t, std_t)
75
 
76
 
77
  if __name__ == '__main__':
build/torch210-cxx11-cu128-aarch64-linux/ops/matmul_benchmark.py CHANGED
@@ -17,7 +17,7 @@ import unittest
17
  from .. import stk
18
 
19
  import torch
20
- from absl.testing import parameterized
21
 
22
  from .. import benchmark_util, ops
23
 
@@ -48,367 +48,367 @@ def log_benchmark(name, arguments, time, std, flops):
48
  print('=' * 60)
49
 
50
 
51
- class MatmulBenchmark(parameterized.TestCase):
52
-
53
- def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
- blocking = 128
55
- padded_tokens, _ = x.size()
56
- assert padded_tokens % blocking == 0
57
- assert fhs % blocking == 0
58
-
59
- # Offsets for the sparse matrix. All rows have the
60
- # same number of nonzero blocks dictated by the
61
- # dimensionality of a single expert.
62
- block_rows = padded_tokens // blocking
63
- blocks_per_row = fhs // blocking
64
- offsets = torch.arange(
65
- 0,
66
- block_rows * blocks_per_row + 1,
67
- blocks_per_row,
68
- dtype=torch.int32,
69
- device=x.device,
70
- )
71
-
72
- # Indices for the sparse matrix. The indices for
73
- # the intermediate matrix are dynamic depending
74
- # on the mapping of tokens to experts.
75
- column_indices = ops.topology(
76
- padded_bins,
77
- blocking,
78
- block_rows,
79
- blocks_per_row,
80
- )
81
- data = torch.empty(
82
- column_indices.numel(),
83
- blocking,
84
- blocking,
85
- dtype=torch.float16,
86
- device=x.device,
87
- )
88
- shape = (padded_tokens, fhs * ne)
89
- row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
- return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
-
92
- def build_input_matrix(self, sl, hs, ne):
93
- x = torch.randn((sl, hs)).cuda().half()
94
-
95
- # Assign tokens to experts uniformly.
96
- top_expert = torch.arange(0, sl).cuda().int() % ne
97
-
98
- bin_ids, indices = ops.sort(top_expert)
99
- tokens_per_expert = ops.histogram(top_expert, ne)
100
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
- out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
- return out, padded_bins
105
-
106
- def build_weight_matrix(self, ne, hs, fhs):
107
- return torch.randn((hs, ne * fhs)).cuda().half()
108
-
109
- @parameterized.parameters(*_MATMUL_TESTS)
110
- def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
- w = transpose_view(w)
115
-
116
- def benchmark():
117
- return stk.ops.sdd(x, w, topo)
118
-
119
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
- arguments = {
121
- 'sequence_length': sl,
122
- 'hidden_size': hs,
123
- 'ffn_hidden_size': fhs,
124
- 'num_experts': ne,
125
- }
126
- log_benchmark(
127
- '0::Fwd::SDD::NT',
128
- arguments,
129
- mean_t,
130
- std_t,
131
- x.numel() * fhs * 2,
132
- )
133
-
134
- @parameterized.parameters(*_MATMUL_TESTS)
135
- def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
-
140
- def benchmark():
141
- return stk.ops.dsd(topo, w)
142
-
143
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
- arguments = {
145
- 'sequence_length': sl,
146
- 'hidden_size': hs,
147
- 'ffn_hidden_size': fhs,
148
- 'num_experts': ne,
149
- }
150
- log_benchmark(
151
- '0::GradX::DSD::NN',
152
- arguments,
153
- mean_t,
154
- std_t,
155
- x.numel() * fhs * 2,
156
- )
157
-
158
- @parameterized.parameters(*_MATMUL_TESTS)
159
- def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
- topo = topo.t()
163
-
164
- def benchmark():
165
- return stk.ops.dsd(topo, x)
166
-
167
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
- arguments = {
169
- 'sequence_length': sl,
170
- 'hidden_size': hs,
171
- 'ffn_hidden_size': fhs,
172
- 'num_experts': ne,
173
- }
174
- log_benchmark(
175
- '0::GradW::DSD::TN',
176
- arguments,
177
- mean_t,
178
- std_t,
179
- x.numel() * fhs * 2,
180
- )
181
-
182
- @parameterized.parameters(*_MATMUL_TESTS)
183
- def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
-
188
- def benchmark():
189
- return stk.ops.dsd(x, w)
190
-
191
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
- arguments = {
193
- 'sequence_length': sl,
194
- 'hidden_size': hs,
195
- 'ffn_hidden_size': fhs,
196
- 'num_experts': ne,
197
- }
198
- log_benchmark(
199
- '1::Fwd::DSD::NN',
200
- arguments,
201
- mean_t,
202
- std_t,
203
- x.nnz * hs * 2,
204
- )
205
-
206
- @parameterized.parameters(*_MATMUL_TESTS)
207
- def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
- out = stk.ops.dsd(x, w)
212
- w = transpose_view(w)
213
-
214
- def benchmark():
215
- return stk.ops.sdd(out, w, x)
216
-
217
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
- arguments = {
219
- 'sequence_length': sl,
220
- 'hidden_size': hs,
221
- 'ffn_hidden_size': fhs,
222
- 'num_experts': ne,
223
- }
224
- log_benchmark(
225
- '1::GradX::SDD::NT',
226
- arguments,
227
- mean_t,
228
- std_t,
229
- x.nnz * hs * 2,
230
- )
231
-
232
- @parameterized.parameters(*_MATMUL_TESTS)
233
- def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
- out = stk.ops.dsd(x, w)
238
- x = x.t()
239
-
240
- def benchmark():
241
- return stk.ops.dsd(x, out)
242
-
243
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
- arguments = {
245
- 'sequence_length': sl,
246
- 'hidden_size': hs,
247
- 'ffn_hidden_size': fhs,
248
- 'num_experts': ne,
249
- }
250
- log_benchmark(
251
- '1::GradW::DSD::TN',
252
- arguments,
253
- mean_t,
254
- std_t,
255
- x.nnz * hs * 2,
256
- )
257
-
258
- @parameterized.parameters(*_MATMUL_TESTS)
259
- def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
- assert (sl % ne) == 0
261
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
- w = torch.randn((ne, hs, fhs)).cuda().half()
263
-
264
- w = w.transpose(1, 2).contiguous()
265
- w = w.transpose(1, 2)
266
-
267
- def benchmark():
268
- return torch.bmm(x, w)
269
-
270
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
- arguments = {
272
- 'sequence_length': sl,
273
- 'hidden_size': hs,
274
- 'ffn_hidden_size': fhs,
275
- 'num_experts': ne,
276
- }
277
- log_benchmark(
278
- '0::Fwd:DDD::NT',
279
- arguments,
280
- mean_t,
281
- std_t,
282
- x.numel() * fhs * 2,
283
- )
284
-
285
- @parameterized.parameters(*_MATMUL_TESTS)
286
- def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
- assert (sl % ne) == 0
288
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
- w = torch.randn((ne, hs, fhs)).cuda().half()
290
- out = torch.bmm(x, w)
291
- w = w.transpose(1, 2).contiguous()
292
-
293
- def benchmark():
294
- return torch.bmm(out, w)
295
-
296
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
- arguments = {
298
- 'sequence_length': sl,
299
- 'hidden_size': hs,
300
- 'ffn_hidden_size': fhs,
301
- 'num_experts': ne,
302
- }
303
- log_benchmark(
304
- '0:GradX:DDD::NN',
305
- arguments,
306
- mean_t,
307
- std_t,
308
- x.numel() * fhs * 2,
309
- )
310
-
311
- @parameterized.parameters(*_MATMUL_TESTS)
312
- def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
- assert (sl % ne) == 0
314
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
- w = torch.randn((ne, hs, fhs)).cuda().half()
316
- out = torch.bmm(x, w)
317
- out = out.transpose(1, 2)
318
-
319
- def benchmark():
320
- return torch.bmm(out, x)
321
-
322
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
- arguments = {
324
- 'sequence_length': sl,
325
- 'hidden_size': hs,
326
- 'ffn_hidden_size': fhs,
327
- 'num_experts': ne,
328
- }
329
- log_benchmark(
330
- '0:GradW:DDD::TN',
331
- arguments,
332
- mean_t,
333
- std_t,
334
- x.numel() * fhs * 2,
335
- )
336
-
337
- @parameterized.parameters(*_MATMUL_TESTS)
338
- def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
- assert (sl % ne) == 0
340
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
- w = torch.randn((ne, fhs, hs)).cuda().half()
342
-
343
- def benchmark():
344
- return torch.bmm(x, w)
345
-
346
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
- arguments = {
348
- 'sequence_length': sl,
349
- 'hidden_size': hs,
350
- 'ffn_hidden_size': fhs,
351
- 'num_experts': ne,
352
- }
353
- log_benchmark(
354
- '1::Fwd::DDD::NN',
355
- arguments,
356
- mean_t,
357
- std_t,
358
- x.numel() * hs * 2,
359
- )
360
-
361
- @parameterized.parameters(*_MATMUL_TESTS)
362
- def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
- assert (sl % ne) == 0
364
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
- w = torch.randn((ne, fhs, hs)).cuda().half()
366
- out = torch.bmm(x, w)
367
- w = torch.transpose(w, 1, 2)
368
-
369
- def benchmark():
370
- return torch.bmm(out, w)
371
-
372
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
- arguments = {
374
- 'sequence_length': sl,
375
- 'hidden_size': hs,
376
- 'ffn_hidden_size': fhs,
377
- 'num_experts': ne,
378
- }
379
- log_benchmark(
380
- '1::GradX::DDD::NT',
381
- arguments,
382
- mean_t,
383
- std_t,
384
- x.numel() * hs * 2,
385
- )
386
-
387
- @parameterized.parameters(*_MATMUL_TESTS)
388
- def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
- assert (sl % ne) == 0
390
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
- w = torch.randn((ne, fhs, hs)).cuda().half()
392
- out = torch.bmm(x, w)
393
- x = torch.transpose(x, 1, 2)
394
-
395
- def benchmark():
396
- return torch.bmm(x, out)
397
-
398
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
- arguments = {
400
- 'sequence_length': sl,
401
- 'hidden_size': hs,
402
- 'ffn_hidden_size': fhs,
403
- 'num_experts': ne,
404
- }
405
- log_benchmark(
406
- '1::GradW::DDD::TN',
407
- arguments,
408
- mean_t,
409
- std_t,
410
- x.numel() * hs * 2,
411
- )
412
 
413
 
414
  if __name__ == '__main__':
 
17
  from .. import stk
18
 
19
  import torch
20
+ # from absl.testing import parameterized
21
 
22
  from .. import benchmark_util, ops
23
 
 
48
  print('=' * 60)
49
 
50
 
51
+ # class MatmulBenchmark(parameterized.TestCase):
52
+ #
53
+ # def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
+ # blocking = 128
55
+ # padded_tokens, _ = x.size()
56
+ # assert padded_tokens % blocking == 0
57
+ # assert fhs % blocking == 0
58
+ #
59
+ # # Offsets for the sparse matrix. All rows have the
60
+ # # same number of nonzero blocks dictated by the
61
+ # # dimensionality of a single expert.
62
+ # block_rows = padded_tokens // blocking
63
+ # blocks_per_row = fhs // blocking
64
+ # offsets = torch.arange(
65
+ # 0,
66
+ # block_rows * blocks_per_row + 1,
67
+ # blocks_per_row,
68
+ # dtype=torch.int32,
69
+ # device=x.device,
70
+ # )
71
+ #
72
+ # # Indices for the sparse matrix. The indices for
73
+ # # the intermediate matrix are dynamic depending
74
+ # # on the mapping of tokens to experts.
75
+ # column_indices = ops.topology(
76
+ # padded_bins,
77
+ # blocking,
78
+ # block_rows,
79
+ # blocks_per_row,
80
+ # )
81
+ # data = torch.empty(
82
+ # column_indices.numel(),
83
+ # blocking,
84
+ # blocking,
85
+ # dtype=torch.float16,
86
+ # device=x.device,
87
+ # )
88
+ # shape = (padded_tokens, fhs * ne)
89
+ # row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
+ # return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
+ #
92
+ # def build_input_matrix(self, sl, hs, ne):
93
+ # x = torch.randn((sl, hs)).cuda().half()
94
+ #
95
+ # # Assign tokens to experts uniformly.
96
+ # top_expert = torch.arange(0, sl).cuda().int() % ne
97
+ #
98
+ # bin_ids, indices = ops.sort(top_expert)
99
+ # tokens_per_expert = ops.histogram(top_expert, ne)
100
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
+ # out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
+ # return out, padded_bins
105
+ #
106
+ # def build_weight_matrix(self, ne, hs, fhs):
107
+ # return torch.randn((hs, ne * fhs)).cuda().half()
108
+ #
109
+ # @parameterized.parameters(*_MATMUL_TESTS)
110
+ # def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
+ # w = transpose_view(w)
115
+ #
116
+ # def benchmark():
117
+ # return stk.ops.sdd(x, w, topo)
118
+ #
119
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ # arguments = {
121
+ # 'sequence_length': sl,
122
+ # 'hidden_size': hs,
123
+ # 'ffn_hidden_size': fhs,
124
+ # 'num_experts': ne,
125
+ # }
126
+ # log_benchmark(
127
+ # '0::Fwd::SDD::NT',
128
+ # arguments,
129
+ # mean_t,
130
+ # std_t,
131
+ # x.numel() * fhs * 2,
132
+ # )
133
+ #
134
+ # @parameterized.parameters(*_MATMUL_TESTS)
135
+ # def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
+ #
140
+ # def benchmark():
141
+ # return stk.ops.dsd(topo, w)
142
+ #
143
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
+ # arguments = {
145
+ # 'sequence_length': sl,
146
+ # 'hidden_size': hs,
147
+ # 'ffn_hidden_size': fhs,
148
+ # 'num_experts': ne,
149
+ # }
150
+ # log_benchmark(
151
+ # '0::GradX::DSD::NN',
152
+ # arguments,
153
+ # mean_t,
154
+ # std_t,
155
+ # x.numel() * fhs * 2,
156
+ # )
157
+ #
158
+ # @parameterized.parameters(*_MATMUL_TESTS)
159
+ # def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
+ # topo = topo.t()
163
+ #
164
+ # def benchmark():
165
+ # return stk.ops.dsd(topo, x)
166
+ #
167
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
+ # arguments = {
169
+ # 'sequence_length': sl,
170
+ # 'hidden_size': hs,
171
+ # 'ffn_hidden_size': fhs,
172
+ # 'num_experts': ne,
173
+ # }
174
+ # log_benchmark(
175
+ # '0::GradW::DSD::TN',
176
+ # arguments,
177
+ # mean_t,
178
+ # std_t,
179
+ # x.numel() * fhs * 2,
180
+ # )
181
+ #
182
+ # @parameterized.parameters(*_MATMUL_TESTS)
183
+ # def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
+ #
188
+ # def benchmark():
189
+ # return stk.ops.dsd(x, w)
190
+ #
191
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
+ # arguments = {
193
+ # 'sequence_length': sl,
194
+ # 'hidden_size': hs,
195
+ # 'ffn_hidden_size': fhs,
196
+ # 'num_experts': ne,
197
+ # }
198
+ # log_benchmark(
199
+ # '1::Fwd::DSD::NN',
200
+ # arguments,
201
+ # mean_t,
202
+ # std_t,
203
+ # x.nnz * hs * 2,
204
+ # )
205
+ #
206
+ # @parameterized.parameters(*_MATMUL_TESTS)
207
+ # def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
+ # out = stk.ops.dsd(x, w)
212
+ # w = transpose_view(w)
213
+ #
214
+ # def benchmark():
215
+ # return stk.ops.sdd(out, w, x)
216
+ #
217
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
+ # arguments = {
219
+ # 'sequence_length': sl,
220
+ # 'hidden_size': hs,
221
+ # 'ffn_hidden_size': fhs,
222
+ # 'num_experts': ne,
223
+ # }
224
+ # log_benchmark(
225
+ # '1::GradX::SDD::NT',
226
+ # arguments,
227
+ # mean_t,
228
+ # std_t,
229
+ # x.nnz * hs * 2,
230
+ # )
231
+ #
232
+ # @parameterized.parameters(*_MATMUL_TESTS)
233
+ # def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
+ # out = stk.ops.dsd(x, w)
238
+ # x = x.t()
239
+ #
240
+ # def benchmark():
241
+ # return stk.ops.dsd(x, out)
242
+ #
243
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
+ # arguments = {
245
+ # 'sequence_length': sl,
246
+ # 'hidden_size': hs,
247
+ # 'ffn_hidden_size': fhs,
248
+ # 'num_experts': ne,
249
+ # }
250
+ # log_benchmark(
251
+ # '1::GradW::DSD::TN',
252
+ # arguments,
253
+ # mean_t,
254
+ # std_t,
255
+ # x.nnz * hs * 2,
256
+ # )
257
+ #
258
+ # @parameterized.parameters(*_MATMUL_TESTS)
259
+ # def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
+ # assert (sl % ne) == 0
261
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
263
+ #
264
+ # w = w.transpose(1, 2).contiguous()
265
+ # w = w.transpose(1, 2)
266
+ #
267
+ # def benchmark():
268
+ # return torch.bmm(x, w)
269
+ #
270
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
+ # arguments = {
272
+ # 'sequence_length': sl,
273
+ # 'hidden_size': hs,
274
+ # 'ffn_hidden_size': fhs,
275
+ # 'num_experts': ne,
276
+ # }
277
+ # log_benchmark(
278
+ # '0::Fwd:DDD::NT',
279
+ # arguments,
280
+ # mean_t,
281
+ # std_t,
282
+ # x.numel() * fhs * 2,
283
+ # )
284
+ #
285
+ # @parameterized.parameters(*_MATMUL_TESTS)
286
+ # def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
+ # assert (sl % ne) == 0
288
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
290
+ # out = torch.bmm(x, w)
291
+ # w = w.transpose(1, 2).contiguous()
292
+ #
293
+ # def benchmark():
294
+ # return torch.bmm(out, w)
295
+ #
296
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
+ # arguments = {
298
+ # 'sequence_length': sl,
299
+ # 'hidden_size': hs,
300
+ # 'ffn_hidden_size': fhs,
301
+ # 'num_experts': ne,
302
+ # }
303
+ # log_benchmark(
304
+ # '0:GradX:DDD::NN',
305
+ # arguments,
306
+ # mean_t,
307
+ # std_t,
308
+ # x.numel() * fhs * 2,
309
+ # )
310
+ #
311
+ # @parameterized.parameters(*_MATMUL_TESTS)
312
+ # def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
+ # assert (sl % ne) == 0
314
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
316
+ # out = torch.bmm(x, w)
317
+ # out = out.transpose(1, 2)
318
+ #
319
+ # def benchmark():
320
+ # return torch.bmm(out, x)
321
+ #
322
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
+ # arguments = {
324
+ # 'sequence_length': sl,
325
+ # 'hidden_size': hs,
326
+ # 'ffn_hidden_size': fhs,
327
+ # 'num_experts': ne,
328
+ # }
329
+ # log_benchmark(
330
+ # '0:GradW:DDD::TN',
331
+ # arguments,
332
+ # mean_t,
333
+ # std_t,
334
+ # x.numel() * fhs * 2,
335
+ # )
336
+ #
337
+ # @parameterized.parameters(*_MATMUL_TESTS)
338
+ # def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
+ # assert (sl % ne) == 0
340
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
342
+ #
343
+ # def benchmark():
344
+ # return torch.bmm(x, w)
345
+ #
346
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
+ # arguments = {
348
+ # 'sequence_length': sl,
349
+ # 'hidden_size': hs,
350
+ # 'ffn_hidden_size': fhs,
351
+ # 'num_experts': ne,
352
+ # }
353
+ # log_benchmark(
354
+ # '1::Fwd::DDD::NN',
355
+ # arguments,
356
+ # mean_t,
357
+ # std_t,
358
+ # x.numel() * hs * 2,
359
+ # )
360
+ #
361
+ # @parameterized.parameters(*_MATMUL_TESTS)
362
+ # def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
+ # assert (sl % ne) == 0
364
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
366
+ # out = torch.bmm(x, w)
367
+ # w = torch.transpose(w, 1, 2)
368
+ #
369
+ # def benchmark():
370
+ # return torch.bmm(out, w)
371
+ #
372
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
+ # arguments = {
374
+ # 'sequence_length': sl,
375
+ # 'hidden_size': hs,
376
+ # 'ffn_hidden_size': fhs,
377
+ # 'num_experts': ne,
378
+ # }
379
+ # log_benchmark(
380
+ # '1::GradX::DDD::NT',
381
+ # arguments,
382
+ # mean_t,
383
+ # std_t,
384
+ # x.numel() * hs * 2,
385
+ # )
386
+ #
387
+ # @parameterized.parameters(*_MATMUL_TESTS)
388
+ # def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
+ # assert (sl % ne) == 0
390
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
392
+ # out = torch.bmm(x, w)
393
+ # x = torch.transpose(x, 1, 2)
394
+ #
395
+ # def benchmark():
396
+ # return torch.bmm(x, out)
397
+ #
398
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
+ # arguments = {
400
+ # 'sequence_length': sl,
401
+ # 'hidden_size': hs,
402
+ # 'ffn_hidden_size': fhs,
403
+ # 'num_experts': ne,
404
+ # }
405
+ # log_benchmark(
406
+ # '1::GradW::DDD::TN',
407
+ # arguments,
408
+ # mean_t,
409
+ # std_t,
410
+ # x.numel() * hs * 2,
411
+ # )
412
 
413
 
414
  if __name__ == '__main__':
build/torch210-cxx11-cu128-aarch64-linux/ops/padded_scatter_benchmark.py CHANGED
@@ -4,7 +4,7 @@
4
  import unittest
5
 
6
  import torch
7
- from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
@@ -16,50 +16,50 @@ _PADDED_SCATTER_BENCHMARK = (
16
  )
17
 
18
 
19
- class PaddedScatterTest(parameterized.TestCase):
20
-
21
- @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
- def testPaddedScatter(self, sl, hs, ne, top_k):
23
- # Create the data and indices.
24
- x = torch.randn((sl, hs)).cuda().half()
25
-
26
- # Randomly assign tokens to experts.
27
- top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
- bin_ids, indices = ops.sort(top_expert)
29
- tokens_per_expert = ops.histogram(top_expert, ne)
30
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
-
34
- # Sample weights for the scatter reduce.
35
- weights = torch.rand((sl * top_k,)).cuda().half()
36
-
37
- # Gather the data to prepare for backwards.
38
- x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
-
40
- def benchmark():
41
- return ops.padded_scatter(
42
- x,
43
- indices,
44
- bin_ids,
45
- weights,
46
- bins,
47
- padded_bins,
48
- top_k,
49
- )
50
-
51
- time, std = benchmark_util.benchmark_function(benchmark)
52
- benchmark_util.log_benchmark(
53
- 'Padded Scatter',
54
- {
55
- 'sequence_length': sl,
56
- 'hidden_size': hs,
57
- 'num_experts': ne,
58
- 'top_k': top_k,
59
- },
60
- time,
61
- std,
62
- )
63
 
64
 
65
  if __name__ == '__main__':
 
4
  import unittest
5
 
6
  import torch
7
+ # from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
 
16
  )
17
 
18
 
19
+ # class PaddedScatterTest(parameterized.TestCase):
20
+ #
21
+ # @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
+ # def testPaddedScatter(self, sl, hs, ne, top_k):
23
+ # # Create the data and indices.
24
+ # x = torch.randn((sl, hs)).cuda().half()
25
+ #
26
+ # # Randomly assign tokens to experts.
27
+ # top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
+ # bin_ids, indices = ops.sort(top_expert)
29
+ # tokens_per_expert = ops.histogram(top_expert, ne)
30
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
+ #
34
+ # # Sample weights for the scatter reduce.
35
+ # weights = torch.rand((sl * top_k,)).cuda().half()
36
+ #
37
+ # # Gather the data to prepare for backwards.
38
+ # x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
+ #
40
+ # def benchmark():
41
+ # return ops.padded_scatter(
42
+ # x,
43
+ # indices,
44
+ # bin_ids,
45
+ # weights,
46
+ # bins,
47
+ # padded_bins,
48
+ # top_k,
49
+ # )
50
+ #
51
+ # time, std = benchmark_util.benchmark_function(benchmark)
52
+ # benchmark_util.log_benchmark(
53
+ # 'Padded Scatter',
54
+ # {
55
+ # 'sequence_length': sl,
56
+ # 'hidden_size': hs,
57
+ # 'num_experts': ne,
58
+ # 'top_k': top_k,
59
+ # },
60
+ # time,
61
+ # std,
62
+ # )
63
 
64
 
65
  if __name__ == '__main__':
build/torch210-cxx11-cu128-aarch64-linux/ops/permute_benchmark.py CHANGED
@@ -4,7 +4,7 @@
4
  import unittest
5
 
6
  import torch
7
- from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
@@ -26,123 +26,123 @@ _PERMUTE_TESTS = (
26
  )
27
 
28
 
29
- class PermuteBenchmark(parameterized.TestCase):
30
-
31
- @parameterized.parameters(*_PERMUTE_TESTS)
32
- def testBinnedGather(self, sl, hs, ne):
33
- # NOTE: Capacity factor == 1.
34
- ec = sl // ne
35
-
36
- # Create the data and indices.
37
- x = torch.randn((sl, hs)).cuda().half()
38
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
- bin_ids, indices = ops.sort(top_expert)
40
- tokens_per_expert = ops.histogram(indices, ne)
41
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
-
43
- def benchmark():
44
- return ops.binned_gather(x, indices, bins, ec)
45
-
46
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
- arguments = {
48
- 'sequence_length': sl,
49
- 'hidden_size': hs,
50
- 'num_experts': ne,
51
- }
52
- benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
-
54
- @parameterized.parameters(*_PERMUTE_TESTS)
55
- def testBinnedScatter(self, sl, hs, ne):
56
- # NOTE: Capacity factor == 1.
57
- ec = sl // ne
58
-
59
- # Create the data and indices.
60
- x = torch.randn((sl, hs)).cuda().half()
61
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
- bin_ids, indices = ops.sort(top_expert)
63
- tokens_per_expert = ops.histogram(indices, ne)
64
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
- x = ops.binned_gather(x, indices, bins, ec)
66
-
67
- def benchmark():
68
- return ops.binned_scatter(x, indices, bins)
69
-
70
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
- arguments = {
72
- 'sequence_length': sl,
73
- 'hidden_size': hs,
74
- 'num_experts': ne,
75
- }
76
- benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
-
78
- @parameterized.parameters(*_PERMUTE_TESTS)
79
- def testPaddedGather(self, sl, hs, ne):
80
- # Create the data and indices.
81
- x = torch.randn((sl, hs)).cuda().half()
82
-
83
- # Randomly assign tokens to experts.
84
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
- bin_ids, indices = ops.sort(top_expert)
86
- tokens_per_expert = ops.histogram(top_expert, ne)
87
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
-
91
- def benchmark():
92
- return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
-
94
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
- arguments = {
96
- 'sequence_length': sl,
97
- 'hidden_size': hs,
98
- 'num_experts': ne,
99
- }
100
- benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
-
102
- @parameterized.parameters(*_PERMUTE_TESTS)
103
- def testPaddedScatter(self, sl, hs, ne):
104
- # Create the data and indices.
105
- x = torch.randn((sl, hs)).cuda().half()
106
-
107
- # Randomly assign tokens to experts.
108
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
- bin_ids, indices = ops.sort(top_expert)
110
- tokens_per_expert = ops.histogram(top_expert, ne)
111
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
- x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
-
116
- def benchmark():
117
- return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
-
119
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
- arguments = {
121
- 'sequence_length': sl,
122
- 'hidden_size': hs,
123
- 'num_experts': ne,
124
- }
125
- benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
-
127
- @parameterized.parameters(*_PERMUTE_TESTS)
128
- def testCopy(self, sl, hs, ne):
129
- # NOTE: Capacity factor == 1.
130
- # ec = sl // ne
131
-
132
- # Create the data and indices.
133
- x = torch.randn((sl, hs)).cuda().half()
134
- y = x.clone()
135
-
136
- def benchmark():
137
- return y.copy_(x)
138
-
139
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
- arguments = {
141
- 'sequence_length': sl,
142
- 'hidden_size': hs,
143
- 'num_experts': ne,
144
- }
145
- benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
 
147
 
148
  if __name__ == '__main__':
 
4
  import unittest
5
 
6
  import torch
7
+ # from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
 
26
  )
27
 
28
 
29
+ # class PermuteBenchmark(parameterized.TestCase):
30
+ #
31
+ # @parameterized.parameters(*_PERMUTE_TESTS)
32
+ # def testBinnedGather(self, sl, hs, ne):
33
+ # # NOTE: Capacity factor == 1.
34
+ # ec = sl // ne
35
+ #
36
+ # # Create the data and indices.
37
+ # x = torch.randn((sl, hs)).cuda().half()
38
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
+ # bin_ids, indices = ops.sort(top_expert)
40
+ # tokens_per_expert = ops.histogram(indices, ne)
41
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
+ #
43
+ # def benchmark():
44
+ # return ops.binned_gather(x, indices, bins, ec)
45
+ #
46
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
+ # arguments = {
48
+ # 'sequence_length': sl,
49
+ # 'hidden_size': hs,
50
+ # 'num_experts': ne,
51
+ # }
52
+ # benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
+ #
54
+ # @parameterized.parameters(*_PERMUTE_TESTS)
55
+ # def testBinnedScatter(self, sl, hs, ne):
56
+ # # NOTE: Capacity factor == 1.
57
+ # ec = sl // ne
58
+ #
59
+ # # Create the data and indices.
60
+ # x = torch.randn((sl, hs)).cuda().half()
61
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
+ # bin_ids, indices = ops.sort(top_expert)
63
+ # tokens_per_expert = ops.histogram(indices, ne)
64
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
+ # x = ops.binned_gather(x, indices, bins, ec)
66
+ #
67
+ # def benchmark():
68
+ # return ops.binned_scatter(x, indices, bins)
69
+ #
70
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
+ # arguments = {
72
+ # 'sequence_length': sl,
73
+ # 'hidden_size': hs,
74
+ # 'num_experts': ne,
75
+ # }
76
+ # benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
+ #
78
+ # @parameterized.parameters(*_PERMUTE_TESTS)
79
+ # def testPaddedGather(self, sl, hs, ne):
80
+ # # Create the data and indices.
81
+ # x = torch.randn((sl, hs)).cuda().half()
82
+ #
83
+ # # Randomly assign tokens to experts.
84
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
+ # bin_ids, indices = ops.sort(top_expert)
86
+ # tokens_per_expert = ops.histogram(top_expert, ne)
87
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
+ #
91
+ # def benchmark():
92
+ # return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
+ #
94
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
+ # arguments = {
96
+ # 'sequence_length': sl,
97
+ # 'hidden_size': hs,
98
+ # 'num_experts': ne,
99
+ # }
100
+ # benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
+ #
102
+ # @parameterized.parameters(*_PERMUTE_TESTS)
103
+ # def testPaddedScatter(self, sl, hs, ne):
104
+ # # Create the data and indices.
105
+ # x = torch.randn((sl, hs)).cuda().half()
106
+ #
107
+ # # Randomly assign tokens to experts.
108
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
+ # bin_ids, indices = ops.sort(top_expert)
110
+ # tokens_per_expert = ops.histogram(top_expert, ne)
111
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
+ # x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
+ #
116
+ # def benchmark():
117
+ # return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
+ #
119
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ # arguments = {
121
+ # 'sequence_length': sl,
122
+ # 'hidden_size': hs,
123
+ # 'num_experts': ne,
124
+ # }
125
+ # benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
+ #
127
+ # @parameterized.parameters(*_PERMUTE_TESTS)
128
+ # def testCopy(self, sl, hs, ne):
129
+ # # NOTE: Capacity factor == 1.
130
+ # # ec = sl // ne
131
+ #
132
+ # # Create the data and indices.
133
+ # x = torch.randn((sl, hs)).cuda().half()
134
+ # y = x.clone()
135
+ #
136
+ # def benchmark():
137
+ # return y.copy_(x)
138
+ #
139
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
+ # arguments = {
141
+ # 'sequence_length': sl,
142
+ # 'hidden_size': hs,
143
+ # 'num_experts': ne,
144
+ # }
145
+ # benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
 
147
 
148
  if __name__ == '__main__':
build/torch210-cxx11-cu128-aarch64-linux/ops/sort_benchmark.py CHANGED
@@ -5,7 +5,7 @@ import unittest
5
 
6
  import numpy as np
7
  import torch
8
- from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
@@ -53,32 +53,32 @@ def log_benchmark(arguments, mean_t, std_t):
53
  print('=' * 60)
54
 
55
 
56
- class SortBenchmark(parameterized.TestCase):
57
-
58
- @parameterized.parameters(*_SORT_TESTS)
59
- def testSort(self, n, dtype, max_val):
60
- if max_val is None:
61
- max_val = np.iinfo(numpy_dtype(dtype)).max
62
- end_bit = int(np.ceil(np.log2(max_val)))
63
- x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
-
65
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
- arguments = {
67
- 'n': n,
68
- 'dtype': dtype,
69
- 'max_val': max_val,
70
- }
71
- log_benchmark(arguments, mean_t, std_t)
72
-
73
- @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
- def testTorchSort(self, n):
75
- x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
-
77
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
- arguments = {
79
- 'n': n,
80
- }
81
- log_benchmark(arguments, mean_t, std_t)
82
 
83
 
84
  if __name__ == '__main__':
 
5
 
6
  import numpy as np
7
  import torch
8
+ # from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
 
53
  print('=' * 60)
54
 
55
 
56
+ # class SortBenchmark(parameterized.TestCase):
57
+ #
58
+ # @parameterized.parameters(*_SORT_TESTS)
59
+ # def testSort(self, n, dtype, max_val):
60
+ # if max_val is None:
61
+ # max_val = np.iinfo(numpy_dtype(dtype)).max
62
+ # end_bit = int(np.ceil(np.log2(max_val)))
63
+ # x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
+ #
65
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
+ # arguments = {
67
+ # 'n': n,
68
+ # 'dtype': dtype,
69
+ # 'max_val': max_val,
70
+ # }
71
+ # log_benchmark(arguments, mean_t, std_t)
72
+ #
73
+ # @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
+ # def testTorchSort(self, n):
75
+ # x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
+ #
77
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
+ # arguments = {
79
+ # 'n': n,
80
+ # }
81
+ # log_benchmark(arguments, mean_t, std_t)
82
 
83
 
84
  if __name__ == '__main__':
build/torch210-cxx11-cu128-aarch64-linux/stk/ops/eltwise_ops_test.py CHANGED
@@ -1,7 +1,7 @@
1
  import unittest
2
  import itertools
3
  import torch
4
- from absl.testing import parameterized
5
 
6
  import stk
7
  from stk.ops.linear_ops_test import allclose, _dense_and_sparse
@@ -47,40 +47,40 @@ def _dense_and_sparse_like(x, std=0.1):
47
  return (dense.requires_grad_(True),
48
  sparse.requires_grad_(True))
49
 
50
- @parameterized.parameters(_ELTWISE_OP_TESTS)
51
- class EltwiseOpsTest(parameterized.TestCase):
52
-
53
- def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
-
55
- a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
- b_dense, b = _dense_and_sparse_like(a)
57
-
58
- out = stk.ops.mul(a, b)
59
- expected_out = torch.mul(a_dense, b_dense)
60
-
61
- # Compute the gradients w.r.t. the inputs.
62
- expected_out.sum().backward()
63
- stk.ops.sum(out).backward()
64
-
65
- # Validate the results.
66
- out = stk.ops.to_dense(out)
67
- self.assertEqual(out.dim(), 2)
68
- self.assertEqual(expected_out.size(), out.size())
69
- self.assertTrue(allclose(out, expected_out))
70
-
71
- # LHS gradient.
72
- grad = stk.ops.to_dense(a.grad)
73
- expected_grad = a_dense.grad
74
- self.assertEqual(grad.dim(), 2)
75
- self.assertEqual(expected_grad.size(), grad.size())
76
- self.assertTrue(allclose(grad, expected_grad))
77
-
78
- # RHS gradient.
79
- grad = stk.ops.to_dense(b.grad)
80
- expected_grad = b_dense.grad
81
- self.assertEqual(grad.dim(), 2)
82
- self.assertEqual(expected_grad.size(), grad.size())
83
- self.assertTrue(allclose(grad, expected_grad))
84
 
85
  if __name__ == '__main__':
86
  unittest.main()
 
1
  import unittest
2
  import itertools
3
  import torch
4
+ # from absl.testing import parameterized
5
 
6
  import stk
7
  from stk.ops.linear_ops_test import allclose, _dense_and_sparse
 
47
  return (dense.requires_grad_(True),
48
  sparse.requires_grad_(True))
49
 
50
+ # @parameterized.parameters(_ELTWISE_OP_TESTS)
51
+ # class EltwiseOpsTest(parameterized.TestCase):
52
+ #
53
+ # def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
+ #
55
+ # a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
+ # b_dense, b = _dense_and_sparse_like(a)
57
+ #
58
+ # out = stk.ops.mul(a, b)
59
+ # expected_out = torch.mul(a_dense, b_dense)
60
+ #
61
+ # # Compute the gradients w.r.t. the inputs.
62
+ # expected_out.sum().backward()
63
+ # stk.ops.sum(out).backward()
64
+ #
65
+ # # Validate the results.
66
+ # out = stk.ops.to_dense(out)
67
+ # self.assertEqual(out.dim(), 2)
68
+ # self.assertEqual(expected_out.size(), out.size())
69
+ # self.assertTrue(allclose(out, expected_out))
70
+ #
71
+ # # LHS gradient.
72
+ # grad = stk.ops.to_dense(a.grad)
73
+ # expected_grad = a_dense.grad
74
+ # self.assertEqual(grad.dim(), 2)
75
+ # self.assertEqual(expected_grad.size(), grad.size())
76
+ # self.assertTrue(allclose(grad, expected_grad))
77
+ #
78
+ # # RHS gradient.
79
+ # grad = stk.ops.to_dense(b.grad)
80
+ # expected_grad = b_dense.grad
81
+ # self.assertEqual(grad.dim(), 2)
82
+ # self.assertEqual(expected_grad.size(), grad.size())
83
+ # self.assertTrue(allclose(grad, expected_grad))
84
 
85
  if __name__ == '__main__':
86
  unittest.main()
build/torch210-cxx11-cu128-aarch64-linux/stk/ops/linear_ops_test.py CHANGED
@@ -2,7 +2,7 @@ import unittest
2
  import itertools
3
  import numpy as np
4
  import torch
5
- from absl.testing import parameterized
6
 
7
  import stk
8
 
@@ -96,121 +96,121 @@ def _mask(x, mask):
96
  return x * mask
97
 
98
 
99
- @parameterized.parameters(*_LINEAR_OP_TESTS)
100
- class LinearOpsTest(parameterized.TestCase):
101
-
102
- def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
103
- # Construct the operands.
104
- a_shape = (k, m) if trans_a else (m, k)
105
- a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
106
- b_shape = (n, k) if trans_b else (k, n)
107
- b, bcp = _dense_2x(*b_shape, dtype)
108
-
109
- # Execute the matmul.
110
- out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
111
- expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
112
-
113
- # Compute the gradients w.r.t. the inputs.
114
- expected_out.sum().backward()
115
- out.sum().backward()
116
-
117
- # Validate the results.
118
- self.assertEqual(out.dim(), 2)
119
- self.assertEqual(expected_out.size()[0], out.size()[0])
120
- self.assertEqual(expected_out.size()[1], out.size()[1])
121
- self.assertTrue(allclose(out, expected_out))
122
-
123
- # LHS gradient.
124
- grad = stk.ops.to_dense(a.grad)
125
- expected_grad = _mask(a_dense.grad, a.grad)
126
- self.assertEqual(grad.dim(), 2)
127
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
128
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
129
- self.assertTrue(allclose(grad, expected_grad))
130
-
131
- # RHS gradient.
132
- grad = b.grad
133
- expected_grad = bcp.grad
134
- self.assertEqual(grad.dim(), 2)
135
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
136
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
137
- self.assertTrue(allclose(grad, expected_grad))
138
-
139
- def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
140
- # Construct the operands.
141
- a_shape = (k, m) if trans_a else (m, k)
142
- a, acp = _dense_2x(*a_shape, dtype)
143
- b_shape = (n, k) if trans_b else (k, n)
144
- b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
145
-
146
- # Execute the matmul.
147
- out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
148
- expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
149
-
150
- # Compute the gradients w.r.t. the inputs.
151
- expected_out.sum().backward()
152
- out.sum().backward()
153
-
154
- # Validate the results.
155
- self.assertEqual(out.dim(), 2)
156
- self.assertEqual(expected_out.size()[0], out.size()[0])
157
- self.assertEqual(expected_out.size()[1], out.size()[1])
158
- self.assertTrue(allclose(out, expected_out))
159
-
160
- # LHS gradient.
161
- grad = a.grad
162
- expected_grad = acp.grad
163
- self.assertEqual(grad.dim(), 2)
164
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
165
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
166
- self.assertTrue(allclose(grad, expected_grad))
167
-
168
- # RHS gradient.
169
- grad = stk.ops.to_dense(b.grad)
170
- expected_grad = _mask(b_dense.grad, b.grad)
171
- self.assertEqual(grad.dim(), 2)
172
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
173
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
174
- self.assertTrue(allclose(grad, expected_grad))
175
-
176
- def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
177
- # Construct the operands.
178
- a_shape = (k, m) if trans_a else (m, k)
179
- a, acp = _dense_2x(*a_shape, dtype)
180
- b_shape = (n, k) if trans_b else (k, n)
181
- b, bcp = _dense_2x(*b_shape, dtype)
182
- _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
183
-
184
- # Execute the matmul.
185
- out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
186
- expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
187
-
188
- # Compute the gradients w.r.t. the inputs.
189
- expected_out.sum().backward()
190
- stk.ops.sum(out).backward()
191
-
192
- # Validate the results.
193
- out = stk.ops.to_dense(out)
194
- self.assertEqual(out.dim(), 2)
195
- self.assertEqual(expected_out.size()[0], out.size()[0])
196
- self.assertEqual(expected_out.size()[1], out.size()[1])
197
- self.assertTrue(allclose(out, expected_out))
198
-
199
- # LHS gradient.
200
- grad = a.grad
201
- expected_grad = acp.grad
202
- self.assertEqual(grad.dim(), 2)
203
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
204
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
205
- self.assertTrue(allclose(grad, expected_grad))
206
-
207
- # RHS gradient.
208
- grad = b.grad
209
- expected_grad = bcp.grad
210
- self.assertEqual(grad.dim(), 2)
211
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
212
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
213
- self.assertTrue(allclose(grad, expected_grad))
214
 
215
  if __name__ == '__main__':
216
  unittest.main()
 
2
  import itertools
3
  import numpy as np
4
  import torch
5
+ # from absl.testing import parameterized
6
 
7
  import stk
8
 
 
96
  return x * mask
97
 
98
 
99
+ # @parameterized.parameters(*_LINEAR_OP_TESTS)
100
+ # class LinearOpsTest(parameterized.TestCase):
101
+ #
102
+ # def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
103
+ # # Construct the operands.
104
+ # a_shape = (k, m) if trans_a else (m, k)
105
+ # a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
106
+ # b_shape = (n, k) if trans_b else (k, n)
107
+ # b, bcp = _dense_2x(*b_shape, dtype)
108
+ #
109
+ # # Execute the matmul.
110
+ # out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
111
+ # expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
112
+ #
113
+ # # Compute the gradients w.r.t. the inputs.
114
+ # expected_out.sum().backward()
115
+ # out.sum().backward()
116
+ #
117
+ # # Validate the results.
118
+ # self.assertEqual(out.dim(), 2)
119
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
120
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
121
+ # self.assertTrue(allclose(out, expected_out))
122
+ #
123
+ # # LHS gradient.
124
+ # grad = stk.ops.to_dense(a.grad)
125
+ # expected_grad = _mask(a_dense.grad, a.grad)
126
+ # self.assertEqual(grad.dim(), 2)
127
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
128
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
129
+ # self.assertTrue(allclose(grad, expected_grad))
130
+ #
131
+ # # RHS gradient.
132
+ # grad = b.grad
133
+ # expected_grad = bcp.grad
134
+ # self.assertEqual(grad.dim(), 2)
135
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
136
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
137
+ # self.assertTrue(allclose(grad, expected_grad))
138
+ #
139
+ # def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
140
+ # # Construct the operands.
141
+ # a_shape = (k, m) if trans_a else (m, k)
142
+ # a, acp = _dense_2x(*a_shape, dtype)
143
+ # b_shape = (n, k) if trans_b else (k, n)
144
+ # b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
145
+ #
146
+ # # Execute the matmul.
147
+ # out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
148
+ # expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
149
+ #
150
+ # # Compute the gradients w.r.t. the inputs.
151
+ # expected_out.sum().backward()
152
+ # out.sum().backward()
153
+ #
154
+ # # Validate the results.
155
+ # self.assertEqual(out.dim(), 2)
156
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
157
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
158
+ # self.assertTrue(allclose(out, expected_out))
159
+ #
160
+ # # LHS gradient.
161
+ # grad = a.grad
162
+ # expected_grad = acp.grad
163
+ # self.assertEqual(grad.dim(), 2)
164
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
165
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
166
+ # self.assertTrue(allclose(grad, expected_grad))
167
+ #
168
+ # # RHS gradient.
169
+ # grad = stk.ops.to_dense(b.grad)
170
+ # expected_grad = _mask(b_dense.grad, b.grad)
171
+ # self.assertEqual(grad.dim(), 2)
172
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
173
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
174
+ # self.assertTrue(allclose(grad, expected_grad))
175
+ #
176
+ # def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
177
+ # # Construct the operands.
178
+ # a_shape = (k, m) if trans_a else (m, k)
179
+ # a, acp = _dense_2x(*a_shape, dtype)
180
+ # b_shape = (n, k) if trans_b else (k, n)
181
+ # b, bcp = _dense_2x(*b_shape, dtype)
182
+ # _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
183
+ #
184
+ # # Execute the matmul.
185
+ # out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
186
+ # expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
187
+ #
188
+ # # Compute the gradients w.r.t. the inputs.
189
+ # expected_out.sum().backward()
190
+ # stk.ops.sum(out).backward()
191
+ #
192
+ # # Validate the results.
193
+ # out = stk.ops.to_dense(out)
194
+ # self.assertEqual(out.dim(), 2)
195
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
196
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
197
+ # self.assertTrue(allclose(out, expected_out))
198
+ #
199
+ # # LHS gradient.
200
+ # grad = a.grad
201
+ # expected_grad = acp.grad
202
+ # self.assertEqual(grad.dim(), 2)
203
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
204
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
205
+ # self.assertTrue(allclose(grad, expected_grad))
206
+ #
207
+ # # RHS gradient.
208
+ # grad = b.grad
209
+ # expected_grad = bcp.grad
210
+ # self.assertEqual(grad.dim(), 2)
211
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
212
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
213
+ # self.assertTrue(allclose(grad, expected_grad))
214
 
215
  if __name__ == '__main__':
216
  unittest.main()
build/torch210-cxx11-cu128-aarch64-linux/stk/ops/matrix_ops_test.py CHANGED
@@ -1,61 +1,61 @@
1
  import unittest
2
 
3
- from absl.testing import parameterized
4
  import stk
5
  import torch
6
 
7
 
8
- @parameterized.parameters(
9
- (8, 16, 0.0, 1),
10
- (8, 16, 0.5, 1),
11
- (8, 16, .95, 1),
12
- (16, 8, 0.0, 1),
13
- (16, 8, 0.5, 1),
14
- (16, 8, .95, 1),
15
- (8, 16, 0.0, 8),
16
- (8, 16, 0.5, 8),
17
- (8, 16, 1.0, 8),
18
- (16, 8, 0.0, 8),
19
- (16, 8, 0.5, 8),
20
- (16, 8, 1.0, 8),
21
- (128, 256, 0.5, 16),
22
- (256, 128, 0.75, 32),
23
- (512, 512, .875, 128))
24
- class MatrixOpsTest(parameterized.TestCase):
25
-
26
- def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
27
- mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
28
- x = (torch.randn(rows, cols) * mask).type(torch.float16)
29
-
30
- # Convert the matrix to sparse format.
31
- sparse_x = stk.ops.to_sparse(x, blocking)
32
-
33
- # Validate the matrix.
34
- sparse_x.validate()
35
-
36
- # Validate the shape.
37
- self.assertEqual(sparse_x.dim(), 2)
38
- self.assertEqual(sparse_x.size()[0], rows)
39
- self.assertEqual(sparse_x.size()[1], cols)
40
-
41
- # Validate the sparsity.
42
- numblocks = rows // blocking * cols // blocking
43
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
44
- self.assertEqual(sparse_x.nnz, nnz)
45
-
46
- # Convert back to dense format.
47
- dense_x = stk.ops.to_dense(sparse_x)
48
-
49
- # Validate the shape.
50
- self.assertEqual(dense_x.dim(), 2)
51
- self.assertEqual(dense_x.size()[0], rows)
52
- self.assertEqual(dense_x.size()[1], cols)
53
-
54
- # Validate the sparsity
55
- self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
56
-
57
- # Validate the output.
58
- self.assertTrue(torch.all(torch.eq(x, dense_x)))
59
 
60
 
61
  if __name__ == '__main__':
 
1
  import unittest
2
 
3
+ # from absl.testing import parameterized
4
  import stk
5
  import torch
6
 
7
 
8
+ # @parameterized.parameters(
9
+ # (8, 16, 0.0, 1),
10
+ # (8, 16, 0.5, 1),
11
+ # (8, 16, .95, 1),
12
+ # (16, 8, 0.0, 1),
13
+ # (16, 8, 0.5, 1),
14
+ # (16, 8, .95, 1),
15
+ # (8, 16, 0.0, 8),
16
+ # (8, 16, 0.5, 8),
17
+ # (8, 16, 1.0, 8),
18
+ # (16, 8, 0.0, 8),
19
+ # (16, 8, 0.5, 8),
20
+ # (16, 8, 1.0, 8),
21
+ # (128, 256, 0.5, 16),
22
+ # (256, 128, 0.75, 32),
23
+ # (512, 512, .875, 128))
24
+ # class MatrixOpsTest(parameterized.TestCase):
25
+ #
26
+ # def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
27
+ # mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
28
+ # x = (torch.randn(rows, cols) * mask).type(torch.float16)
29
+ #
30
+ # # Convert the matrix to sparse format.
31
+ # sparse_x = stk.ops.to_sparse(x, blocking)
32
+ #
33
+ # # Validate the matrix.
34
+ # sparse_x.validate()
35
+ #
36
+ # # Validate the shape.
37
+ # self.assertEqual(sparse_x.dim(), 2)
38
+ # self.assertEqual(sparse_x.size()[0], rows)
39
+ # self.assertEqual(sparse_x.size()[1], cols)
40
+ #
41
+ # # Validate the sparsity.
42
+ # numblocks = rows // blocking * cols // blocking
43
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
44
+ # self.assertEqual(sparse_x.nnz, nnz)
45
+ #
46
+ # # Convert back to dense format.
47
+ # dense_x = stk.ops.to_dense(sparse_x)
48
+ #
49
+ # # Validate the shape.
50
+ # self.assertEqual(dense_x.dim(), 2)
51
+ # self.assertEqual(dense_x.size()[0], rows)
52
+ # self.assertEqual(dense_x.size()[1], cols)
53
+ #
54
+ # # Validate the sparsity
55
+ # self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
56
+ #
57
+ # # Validate the output.
58
+ # self.assertTrue(torch.all(torch.eq(x, dense_x)))
59
 
60
 
61
  if __name__ == '__main__':
build/torch210-cxx11-cu128-aarch64-linux/stk/random/random_ops_test.py CHANGED
@@ -1,72 +1,72 @@
1
  import unittest
2
 
3
- from absl.testing import parameterized
4
  from . import random
5
  import torch
6
 
7
 
8
- @parameterized.parameters(
9
- (8, 16, 0.0, 1),
10
- (8, 16, 0.5, 1),
11
- (8, 16, .95, 1),
12
- (16, 8, 0.0, 1),
13
- (16, 8, 0.5, 1),
14
- (16, 8, .95, 1),
15
- (8, 16, 0.0, 8),
16
- (8, 16, 0.5, 8),
17
- (8, 16, 1.0, 8),
18
- (16, 8, 0.0, 8),
19
- (16, 8, 0.5, 8),
20
- (16, 8, 1.0, 8),
21
- (128, 256, 0.5, 16),
22
- (256, 128, 0.75, 32),
23
- (512, 512, .875, 128))
24
- class RandomOpsTest(parameterized.TestCase):
25
-
26
- def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
27
- mask = random.dense_mask(
28
- rows, cols, sparsity, blocking)
29
-
30
- # Validate the shape.
31
- self.assertEqual(mask.dim(), 2)
32
- self.assertEqual(mask.size()[0], rows)
33
- self.assertEqual(mask.size()[1], cols)
34
-
35
- # Validate the sparsity
36
- numblocks = rows // blocking * cols // blocking
37
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
38
- self.assertEqual(
39
- torch.count_nonzero(mask).item(),
40
- nnz)
41
-
42
- # Check values are zero or one.
43
- self.assertTrue(
44
- torch.all(torch.logical_or(
45
- torch.eq(mask, 0),
46
- torch.eq(mask, 1))))
47
-
48
- def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
49
- mask = random.mask(
50
- rows, cols, sparsity, blocking)
51
-
52
- # Validate the matrix.
53
- mask.validate()
54
-
55
- # Validate the shape.
56
- self.assertEqual(mask.dim(), 2)
57
- self.assertEqual(mask.size()[0], rows)
58
- self.assertEqual(mask.size()[1], cols)
59
-
60
- # Validate the sparsity.
61
- numblocks = rows // blocking * cols // blocking
62
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
63
- self.assertEqual(mask.nnz, nnz)
64
-
65
- # Check values are zero or one.
66
- self.assertTrue(
67
- torch.all(torch.logical_or(
68
- torch.eq(mask.data, 0),
69
- torch.eq(mask.data, 1))))
70
 
71
 
72
  if __name__ == '__main__':
 
1
  import unittest
2
 
3
+ # from absl.testing import parameterized
4
  from . import random
5
  import torch
6
 
7
 
8
+ # @parameterized.parameters(
9
+ # (8, 16, 0.0, 1),
10
+ # (8, 16, 0.5, 1),
11
+ # (8, 16, .95, 1),
12
+ # (16, 8, 0.0, 1),
13
+ # (16, 8, 0.5, 1),
14
+ # (16, 8, .95, 1),
15
+ # (8, 16, 0.0, 8),
16
+ # (8, 16, 0.5, 8),
17
+ # (8, 16, 1.0, 8),
18
+ # (16, 8, 0.0, 8),
19
+ # (16, 8, 0.5, 8),
20
+ # (16, 8, 1.0, 8),
21
+ # (128, 256, 0.5, 16),
22
+ # (256, 128, 0.75, 32),
23
+ # (512, 512, .875, 128))
24
+ # class RandomOpsTest(parameterized.TestCase):
25
+ #
26
+ # def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
27
+ # mask = random.dense_mask(
28
+ # rows, cols, sparsity, blocking)
29
+ #
30
+ # # Validate the shape.
31
+ # self.assertEqual(mask.dim(), 2)
32
+ # self.assertEqual(mask.size()[0], rows)
33
+ # self.assertEqual(mask.size()[1], cols)
34
+ #
35
+ # # Validate the sparsity
36
+ # numblocks = rows // blocking * cols // blocking
37
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
38
+ # self.assertEqual(
39
+ # torch.count_nonzero(mask).item(),
40
+ # nnz)
41
+ #
42
+ # # Check values are zero or one.
43
+ # self.assertTrue(
44
+ # torch.all(torch.logical_or(
45
+ # torch.eq(mask, 0),
46
+ # torch.eq(mask, 1))))
47
+ #
48
+ # def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
49
+ # mask = random.mask(
50
+ # rows, cols, sparsity, blocking)
51
+ #
52
+ # # Validate the matrix.
53
+ # mask.validate()
54
+ #
55
+ # # Validate the shape.
56
+ # self.assertEqual(mask.dim(), 2)
57
+ # self.assertEqual(mask.size()[0], rows)
58
+ # self.assertEqual(mask.size()[1], cols)
59
+ #
60
+ # # Validate the sparsity.
61
+ # numblocks = rows // blocking * cols // blocking
62
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
63
+ # self.assertEqual(mask.nnz, nnz)
64
+ #
65
+ # # Check values are zero or one.
66
+ # self.assertTrue(
67
+ # torch.all(torch.logical_or(
68
+ # torch.eq(mask.data, 0),
69
+ # torch.eq(mask.data, 1))))
70
 
71
 
72
  if __name__ == '__main__':
build/torch210-cxx11-cu130-aarch64-linux/{_megablocks_cuda_6e04dec.abi3.so → _megablocks_cuda_a45325d.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ca7f2de93adbb930ffecaea6953cb94c870333295d05eade3c9c17296aa766a0
3
  size 12073200
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afae994771a9eb3f5ce01fbee1d46fc21b83dbcb3c556cd331cb3fbfde0ff604
3
  size 12073200
build/torch210-cxx11-cu130-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_6e04dec
3
- ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_6e04dec::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_a45325d
3
+ ops = torch.ops._megablocks_cuda_a45325d
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_a45325d::{op_name}"
build/torch210-cxx11-cu130-aarch64-linux/megablocks/__init__.py CHANGED
@@ -1,10 +1,10 @@
1
  import ctypes
 
2
  import sys
3
-
4
- import importlib
5
  from pathlib import Path
6
  from types import ModuleType
7
 
 
8
  def _import_from_path(file_path: Path) -> ModuleType:
9
  # We cannot use the module name as-is, after adding it to `sys.modules`,
10
  # it would also be used for other imports. So, we make a module name that
 
1
  import ctypes
2
+ import importlib.util
3
  import sys
 
 
4
  from pathlib import Path
5
  from types import ModuleType
6
 
7
+
8
  def _import_from_path(file_path: Path) -> ModuleType:
9
  # We cannot use the module name as-is, after adding it to `sys.modules`,
10
  # it would also be used for other imports. So, we make a module name that
build/torch210-cxx11-cu130-aarch64-linux/ops/histogram_benchmark.py CHANGED
@@ -5,7 +5,7 @@ import unittest
5
 
6
  import numpy as np
7
  import torch
8
- from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
@@ -47,31 +47,31 @@ def log_benchmark(arguments, mean_t, std_t):
47
  print('=' * 60)
48
 
49
 
50
- class HistogramBenchmark(parameterized.TestCase):
51
-
52
- @parameterized.parameters(*_HISTOGRAM_TESTS)
53
- def testHistogram(self, n, dtype, max_val):
54
- x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
-
56
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
- arguments = {
58
- 'n': n,
59
- 'dtype': dtype,
60
- 'max_val': max_val,
61
- }
62
- log_benchmark(arguments, mean_t, std_t)
63
-
64
- @parameterized.parameters(*_HISTOGRAM_TESTS)
65
- def testTorchHistogram(self, n, dtype, max_val):
66
- x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
-
68
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
- arguments = {
70
- 'n': n,
71
- 'dtype': dtype,
72
- 'max_val': max_val,
73
- }
74
- log_benchmark(arguments, mean_t, std_t)
75
 
76
 
77
  if __name__ == '__main__':
 
5
 
6
  import numpy as np
7
  import torch
8
+ # from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
 
47
  print('=' * 60)
48
 
49
 
50
+ # class HistogramBenchmark(parameterized.TestCase):
51
+ #
52
+ # @parameterized.parameters(*_HISTOGRAM_TESTS)
53
+ # def testHistogram(self, n, dtype, max_val):
54
+ # x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
+ #
56
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
+ # arguments = {
58
+ # 'n': n,
59
+ # 'dtype': dtype,
60
+ # 'max_val': max_val,
61
+ # }
62
+ # log_benchmark(arguments, mean_t, std_t)
63
+ #
64
+ # @parameterized.parameters(*_HISTOGRAM_TESTS)
65
+ # def testTorchHistogram(self, n, dtype, max_val):
66
+ # x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
+ #
68
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
+ # arguments = {
70
+ # 'n': n,
71
+ # 'dtype': dtype,
72
+ # 'max_val': max_val,
73
+ # }
74
+ # log_benchmark(arguments, mean_t, std_t)
75
 
76
 
77
  if __name__ == '__main__':
build/torch210-cxx11-cu130-aarch64-linux/ops/matmul_benchmark.py CHANGED
@@ -17,7 +17,7 @@ import unittest
17
  from .. import stk
18
 
19
  import torch
20
- from absl.testing import parameterized
21
 
22
  from .. import benchmark_util, ops
23
 
@@ -48,367 +48,367 @@ def log_benchmark(name, arguments, time, std, flops):
48
  print('=' * 60)
49
 
50
 
51
- class MatmulBenchmark(parameterized.TestCase):
52
-
53
- def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
- blocking = 128
55
- padded_tokens, _ = x.size()
56
- assert padded_tokens % blocking == 0
57
- assert fhs % blocking == 0
58
-
59
- # Offsets for the sparse matrix. All rows have the
60
- # same number of nonzero blocks dictated by the
61
- # dimensionality of a single expert.
62
- block_rows = padded_tokens // blocking
63
- blocks_per_row = fhs // blocking
64
- offsets = torch.arange(
65
- 0,
66
- block_rows * blocks_per_row + 1,
67
- blocks_per_row,
68
- dtype=torch.int32,
69
- device=x.device,
70
- )
71
-
72
- # Indices for the sparse matrix. The indices for
73
- # the intermediate matrix are dynamic depending
74
- # on the mapping of tokens to experts.
75
- column_indices = ops.topology(
76
- padded_bins,
77
- blocking,
78
- block_rows,
79
- blocks_per_row,
80
- )
81
- data = torch.empty(
82
- column_indices.numel(),
83
- blocking,
84
- blocking,
85
- dtype=torch.float16,
86
- device=x.device,
87
- )
88
- shape = (padded_tokens, fhs * ne)
89
- row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
- return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
-
92
- def build_input_matrix(self, sl, hs, ne):
93
- x = torch.randn((sl, hs)).cuda().half()
94
-
95
- # Assign tokens to experts uniformly.
96
- top_expert = torch.arange(0, sl).cuda().int() % ne
97
-
98
- bin_ids, indices = ops.sort(top_expert)
99
- tokens_per_expert = ops.histogram(top_expert, ne)
100
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
- out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
- return out, padded_bins
105
-
106
- def build_weight_matrix(self, ne, hs, fhs):
107
- return torch.randn((hs, ne * fhs)).cuda().half()
108
-
109
- @parameterized.parameters(*_MATMUL_TESTS)
110
- def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
- w = transpose_view(w)
115
-
116
- def benchmark():
117
- return stk.ops.sdd(x, w, topo)
118
-
119
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
- arguments = {
121
- 'sequence_length': sl,
122
- 'hidden_size': hs,
123
- 'ffn_hidden_size': fhs,
124
- 'num_experts': ne,
125
- }
126
- log_benchmark(
127
- '0::Fwd::SDD::NT',
128
- arguments,
129
- mean_t,
130
- std_t,
131
- x.numel() * fhs * 2,
132
- )
133
-
134
- @parameterized.parameters(*_MATMUL_TESTS)
135
- def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
-
140
- def benchmark():
141
- return stk.ops.dsd(topo, w)
142
-
143
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
- arguments = {
145
- 'sequence_length': sl,
146
- 'hidden_size': hs,
147
- 'ffn_hidden_size': fhs,
148
- 'num_experts': ne,
149
- }
150
- log_benchmark(
151
- '0::GradX::DSD::NN',
152
- arguments,
153
- mean_t,
154
- std_t,
155
- x.numel() * fhs * 2,
156
- )
157
-
158
- @parameterized.parameters(*_MATMUL_TESTS)
159
- def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
- topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
- topo = topo.t()
163
-
164
- def benchmark():
165
- return stk.ops.dsd(topo, x)
166
-
167
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
- arguments = {
169
- 'sequence_length': sl,
170
- 'hidden_size': hs,
171
- 'ffn_hidden_size': fhs,
172
- 'num_experts': ne,
173
- }
174
- log_benchmark(
175
- '0::GradW::DSD::TN',
176
- arguments,
177
- mean_t,
178
- std_t,
179
- x.numel() * fhs * 2,
180
- )
181
-
182
- @parameterized.parameters(*_MATMUL_TESTS)
183
- def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
-
188
- def benchmark():
189
- return stk.ops.dsd(x, w)
190
-
191
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
- arguments = {
193
- 'sequence_length': sl,
194
- 'hidden_size': hs,
195
- 'ffn_hidden_size': fhs,
196
- 'num_experts': ne,
197
- }
198
- log_benchmark(
199
- '1::Fwd::DSD::NN',
200
- arguments,
201
- mean_t,
202
- std_t,
203
- x.nnz * hs * 2,
204
- )
205
-
206
- @parameterized.parameters(*_MATMUL_TESTS)
207
- def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
- out = stk.ops.dsd(x, w)
212
- w = transpose_view(w)
213
-
214
- def benchmark():
215
- return stk.ops.sdd(out, w, x)
216
-
217
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
- arguments = {
219
- 'sequence_length': sl,
220
- 'hidden_size': hs,
221
- 'ffn_hidden_size': fhs,
222
- 'num_experts': ne,
223
- }
224
- log_benchmark(
225
- '1::GradX::SDD::NT',
226
- arguments,
227
- mean_t,
228
- std_t,
229
- x.nnz * hs * 2,
230
- )
231
-
232
- @parameterized.parameters(*_MATMUL_TESTS)
233
- def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
- x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
- w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
- x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
- out = stk.ops.dsd(x, w)
238
- x = x.t()
239
-
240
- def benchmark():
241
- return stk.ops.dsd(x, out)
242
-
243
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
- arguments = {
245
- 'sequence_length': sl,
246
- 'hidden_size': hs,
247
- 'ffn_hidden_size': fhs,
248
- 'num_experts': ne,
249
- }
250
- log_benchmark(
251
- '1::GradW::DSD::TN',
252
- arguments,
253
- mean_t,
254
- std_t,
255
- x.nnz * hs * 2,
256
- )
257
-
258
- @parameterized.parameters(*_MATMUL_TESTS)
259
- def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
- assert (sl % ne) == 0
261
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
- w = torch.randn((ne, hs, fhs)).cuda().half()
263
-
264
- w = w.transpose(1, 2).contiguous()
265
- w = w.transpose(1, 2)
266
-
267
- def benchmark():
268
- return torch.bmm(x, w)
269
-
270
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
- arguments = {
272
- 'sequence_length': sl,
273
- 'hidden_size': hs,
274
- 'ffn_hidden_size': fhs,
275
- 'num_experts': ne,
276
- }
277
- log_benchmark(
278
- '0::Fwd:DDD::NT',
279
- arguments,
280
- mean_t,
281
- std_t,
282
- x.numel() * fhs * 2,
283
- )
284
-
285
- @parameterized.parameters(*_MATMUL_TESTS)
286
- def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
- assert (sl % ne) == 0
288
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
- w = torch.randn((ne, hs, fhs)).cuda().half()
290
- out = torch.bmm(x, w)
291
- w = w.transpose(1, 2).contiguous()
292
-
293
- def benchmark():
294
- return torch.bmm(out, w)
295
-
296
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
- arguments = {
298
- 'sequence_length': sl,
299
- 'hidden_size': hs,
300
- 'ffn_hidden_size': fhs,
301
- 'num_experts': ne,
302
- }
303
- log_benchmark(
304
- '0:GradX:DDD::NN',
305
- arguments,
306
- mean_t,
307
- std_t,
308
- x.numel() * fhs * 2,
309
- )
310
-
311
- @parameterized.parameters(*_MATMUL_TESTS)
312
- def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
- assert (sl % ne) == 0
314
- x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
- w = torch.randn((ne, hs, fhs)).cuda().half()
316
- out = torch.bmm(x, w)
317
- out = out.transpose(1, 2)
318
-
319
- def benchmark():
320
- return torch.bmm(out, x)
321
-
322
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
- arguments = {
324
- 'sequence_length': sl,
325
- 'hidden_size': hs,
326
- 'ffn_hidden_size': fhs,
327
- 'num_experts': ne,
328
- }
329
- log_benchmark(
330
- '0:GradW:DDD::TN',
331
- arguments,
332
- mean_t,
333
- std_t,
334
- x.numel() * fhs * 2,
335
- )
336
-
337
- @parameterized.parameters(*_MATMUL_TESTS)
338
- def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
- assert (sl % ne) == 0
340
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
- w = torch.randn((ne, fhs, hs)).cuda().half()
342
-
343
- def benchmark():
344
- return torch.bmm(x, w)
345
-
346
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
- arguments = {
348
- 'sequence_length': sl,
349
- 'hidden_size': hs,
350
- 'ffn_hidden_size': fhs,
351
- 'num_experts': ne,
352
- }
353
- log_benchmark(
354
- '1::Fwd::DDD::NN',
355
- arguments,
356
- mean_t,
357
- std_t,
358
- x.numel() * hs * 2,
359
- )
360
-
361
- @parameterized.parameters(*_MATMUL_TESTS)
362
- def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
- assert (sl % ne) == 0
364
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
- w = torch.randn((ne, fhs, hs)).cuda().half()
366
- out = torch.bmm(x, w)
367
- w = torch.transpose(w, 1, 2)
368
-
369
- def benchmark():
370
- return torch.bmm(out, w)
371
-
372
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
- arguments = {
374
- 'sequence_length': sl,
375
- 'hidden_size': hs,
376
- 'ffn_hidden_size': fhs,
377
- 'num_experts': ne,
378
- }
379
- log_benchmark(
380
- '1::GradX::DDD::NT',
381
- arguments,
382
- mean_t,
383
- std_t,
384
- x.numel() * hs * 2,
385
- )
386
-
387
- @parameterized.parameters(*_MATMUL_TESTS)
388
- def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
- assert (sl % ne) == 0
390
- x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
- w = torch.randn((ne, fhs, hs)).cuda().half()
392
- out = torch.bmm(x, w)
393
- x = torch.transpose(x, 1, 2)
394
-
395
- def benchmark():
396
- return torch.bmm(x, out)
397
-
398
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
- arguments = {
400
- 'sequence_length': sl,
401
- 'hidden_size': hs,
402
- 'ffn_hidden_size': fhs,
403
- 'num_experts': ne,
404
- }
405
- log_benchmark(
406
- '1::GradW::DDD::TN',
407
- arguments,
408
- mean_t,
409
- std_t,
410
- x.numel() * hs * 2,
411
- )
412
 
413
 
414
  if __name__ == '__main__':
 
17
  from .. import stk
18
 
19
  import torch
20
+ # from absl.testing import parameterized
21
 
22
  from .. import benchmark_util, ops
23
 
 
48
  print('=' * 60)
49
 
50
 
51
+ # class MatmulBenchmark(parameterized.TestCase):
52
+ #
53
+ # def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
+ # blocking = 128
55
+ # padded_tokens, _ = x.size()
56
+ # assert padded_tokens % blocking == 0
57
+ # assert fhs % blocking == 0
58
+ #
59
+ # # Offsets for the sparse matrix. All rows have the
60
+ # # same number of nonzero blocks dictated by the
61
+ # # dimensionality of a single expert.
62
+ # block_rows = padded_tokens // blocking
63
+ # blocks_per_row = fhs // blocking
64
+ # offsets = torch.arange(
65
+ # 0,
66
+ # block_rows * blocks_per_row + 1,
67
+ # blocks_per_row,
68
+ # dtype=torch.int32,
69
+ # device=x.device,
70
+ # )
71
+ #
72
+ # # Indices for the sparse matrix. The indices for
73
+ # # the intermediate matrix are dynamic depending
74
+ # # on the mapping of tokens to experts.
75
+ # column_indices = ops.topology(
76
+ # padded_bins,
77
+ # blocking,
78
+ # block_rows,
79
+ # blocks_per_row,
80
+ # )
81
+ # data = torch.empty(
82
+ # column_indices.numel(),
83
+ # blocking,
84
+ # blocking,
85
+ # dtype=torch.float16,
86
+ # device=x.device,
87
+ # )
88
+ # shape = (padded_tokens, fhs * ne)
89
+ # row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
+ # return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
+ #
92
+ # def build_input_matrix(self, sl, hs, ne):
93
+ # x = torch.randn((sl, hs)).cuda().half()
94
+ #
95
+ # # Assign tokens to experts uniformly.
96
+ # top_expert = torch.arange(0, sl).cuda().int() % ne
97
+ #
98
+ # bin_ids, indices = ops.sort(top_expert)
99
+ # tokens_per_expert = ops.histogram(top_expert, ne)
100
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
+ # out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
+ # return out, padded_bins
105
+ #
106
+ # def build_weight_matrix(self, ne, hs, fhs):
107
+ # return torch.randn((hs, ne * fhs)).cuda().half()
108
+ #
109
+ # @parameterized.parameters(*_MATMUL_TESTS)
110
+ # def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
+ # w = transpose_view(w)
115
+ #
116
+ # def benchmark():
117
+ # return stk.ops.sdd(x, w, topo)
118
+ #
119
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ # arguments = {
121
+ # 'sequence_length': sl,
122
+ # 'hidden_size': hs,
123
+ # 'ffn_hidden_size': fhs,
124
+ # 'num_experts': ne,
125
+ # }
126
+ # log_benchmark(
127
+ # '0::Fwd::SDD::NT',
128
+ # arguments,
129
+ # mean_t,
130
+ # std_t,
131
+ # x.numel() * fhs * 2,
132
+ # )
133
+ #
134
+ # @parameterized.parameters(*_MATMUL_TESTS)
135
+ # def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
+ #
140
+ # def benchmark():
141
+ # return stk.ops.dsd(topo, w)
142
+ #
143
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
+ # arguments = {
145
+ # 'sequence_length': sl,
146
+ # 'hidden_size': hs,
147
+ # 'ffn_hidden_size': fhs,
148
+ # 'num_experts': ne,
149
+ # }
150
+ # log_benchmark(
151
+ # '0::GradX::DSD::NN',
152
+ # arguments,
153
+ # mean_t,
154
+ # std_t,
155
+ # x.numel() * fhs * 2,
156
+ # )
157
+ #
158
+ # @parameterized.parameters(*_MATMUL_TESTS)
159
+ # def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
+ # topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
+ # topo = topo.t()
163
+ #
164
+ # def benchmark():
165
+ # return stk.ops.dsd(topo, x)
166
+ #
167
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
+ # arguments = {
169
+ # 'sequence_length': sl,
170
+ # 'hidden_size': hs,
171
+ # 'ffn_hidden_size': fhs,
172
+ # 'num_experts': ne,
173
+ # }
174
+ # log_benchmark(
175
+ # '0::GradW::DSD::TN',
176
+ # arguments,
177
+ # mean_t,
178
+ # std_t,
179
+ # x.numel() * fhs * 2,
180
+ # )
181
+ #
182
+ # @parameterized.parameters(*_MATMUL_TESTS)
183
+ # def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
+ #
188
+ # def benchmark():
189
+ # return stk.ops.dsd(x, w)
190
+ #
191
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
+ # arguments = {
193
+ # 'sequence_length': sl,
194
+ # 'hidden_size': hs,
195
+ # 'ffn_hidden_size': fhs,
196
+ # 'num_experts': ne,
197
+ # }
198
+ # log_benchmark(
199
+ # '1::Fwd::DSD::NN',
200
+ # arguments,
201
+ # mean_t,
202
+ # std_t,
203
+ # x.nnz * hs * 2,
204
+ # )
205
+ #
206
+ # @parameterized.parameters(*_MATMUL_TESTS)
207
+ # def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
+ # out = stk.ops.dsd(x, w)
212
+ # w = transpose_view(w)
213
+ #
214
+ # def benchmark():
215
+ # return stk.ops.sdd(out, w, x)
216
+ #
217
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
+ # arguments = {
219
+ # 'sequence_length': sl,
220
+ # 'hidden_size': hs,
221
+ # 'ffn_hidden_size': fhs,
222
+ # 'num_experts': ne,
223
+ # }
224
+ # log_benchmark(
225
+ # '1::GradX::SDD::NT',
226
+ # arguments,
227
+ # mean_t,
228
+ # std_t,
229
+ # x.nnz * hs * 2,
230
+ # )
231
+ #
232
+ # @parameterized.parameters(*_MATMUL_TESTS)
233
+ # def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
+ # x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
+ # w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
+ # x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
+ # out = stk.ops.dsd(x, w)
238
+ # x = x.t()
239
+ #
240
+ # def benchmark():
241
+ # return stk.ops.dsd(x, out)
242
+ #
243
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
+ # arguments = {
245
+ # 'sequence_length': sl,
246
+ # 'hidden_size': hs,
247
+ # 'ffn_hidden_size': fhs,
248
+ # 'num_experts': ne,
249
+ # }
250
+ # log_benchmark(
251
+ # '1::GradW::DSD::TN',
252
+ # arguments,
253
+ # mean_t,
254
+ # std_t,
255
+ # x.nnz * hs * 2,
256
+ # )
257
+ #
258
+ # @parameterized.parameters(*_MATMUL_TESTS)
259
+ # def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
+ # assert (sl % ne) == 0
261
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
263
+ #
264
+ # w = w.transpose(1, 2).contiguous()
265
+ # w = w.transpose(1, 2)
266
+ #
267
+ # def benchmark():
268
+ # return torch.bmm(x, w)
269
+ #
270
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
+ # arguments = {
272
+ # 'sequence_length': sl,
273
+ # 'hidden_size': hs,
274
+ # 'ffn_hidden_size': fhs,
275
+ # 'num_experts': ne,
276
+ # }
277
+ # log_benchmark(
278
+ # '0::Fwd:DDD::NT',
279
+ # arguments,
280
+ # mean_t,
281
+ # std_t,
282
+ # x.numel() * fhs * 2,
283
+ # )
284
+ #
285
+ # @parameterized.parameters(*_MATMUL_TESTS)
286
+ # def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
+ # assert (sl % ne) == 0
288
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
290
+ # out = torch.bmm(x, w)
291
+ # w = w.transpose(1, 2).contiguous()
292
+ #
293
+ # def benchmark():
294
+ # return torch.bmm(out, w)
295
+ #
296
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
+ # arguments = {
298
+ # 'sequence_length': sl,
299
+ # 'hidden_size': hs,
300
+ # 'ffn_hidden_size': fhs,
301
+ # 'num_experts': ne,
302
+ # }
303
+ # log_benchmark(
304
+ # '0:GradX:DDD::NN',
305
+ # arguments,
306
+ # mean_t,
307
+ # std_t,
308
+ # x.numel() * fhs * 2,
309
+ # )
310
+ #
311
+ # @parameterized.parameters(*_MATMUL_TESTS)
312
+ # def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
+ # assert (sl % ne) == 0
314
+ # x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
+ # w = torch.randn((ne, hs, fhs)).cuda().half()
316
+ # out = torch.bmm(x, w)
317
+ # out = out.transpose(1, 2)
318
+ #
319
+ # def benchmark():
320
+ # return torch.bmm(out, x)
321
+ #
322
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
+ # arguments = {
324
+ # 'sequence_length': sl,
325
+ # 'hidden_size': hs,
326
+ # 'ffn_hidden_size': fhs,
327
+ # 'num_experts': ne,
328
+ # }
329
+ # log_benchmark(
330
+ # '0:GradW:DDD::TN',
331
+ # arguments,
332
+ # mean_t,
333
+ # std_t,
334
+ # x.numel() * fhs * 2,
335
+ # )
336
+ #
337
+ # @parameterized.parameters(*_MATMUL_TESTS)
338
+ # def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
+ # assert (sl % ne) == 0
340
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
342
+ #
343
+ # def benchmark():
344
+ # return torch.bmm(x, w)
345
+ #
346
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
+ # arguments = {
348
+ # 'sequence_length': sl,
349
+ # 'hidden_size': hs,
350
+ # 'ffn_hidden_size': fhs,
351
+ # 'num_experts': ne,
352
+ # }
353
+ # log_benchmark(
354
+ # '1::Fwd::DDD::NN',
355
+ # arguments,
356
+ # mean_t,
357
+ # std_t,
358
+ # x.numel() * hs * 2,
359
+ # )
360
+ #
361
+ # @parameterized.parameters(*_MATMUL_TESTS)
362
+ # def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
+ # assert (sl % ne) == 0
364
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
366
+ # out = torch.bmm(x, w)
367
+ # w = torch.transpose(w, 1, 2)
368
+ #
369
+ # def benchmark():
370
+ # return torch.bmm(out, w)
371
+ #
372
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
+ # arguments = {
374
+ # 'sequence_length': sl,
375
+ # 'hidden_size': hs,
376
+ # 'ffn_hidden_size': fhs,
377
+ # 'num_experts': ne,
378
+ # }
379
+ # log_benchmark(
380
+ # '1::GradX::DDD::NT',
381
+ # arguments,
382
+ # mean_t,
383
+ # std_t,
384
+ # x.numel() * hs * 2,
385
+ # )
386
+ #
387
+ # @parameterized.parameters(*_MATMUL_TESTS)
388
+ # def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
+ # assert (sl % ne) == 0
390
+ # x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
+ # w = torch.randn((ne, fhs, hs)).cuda().half()
392
+ # out = torch.bmm(x, w)
393
+ # x = torch.transpose(x, 1, 2)
394
+ #
395
+ # def benchmark():
396
+ # return torch.bmm(x, out)
397
+ #
398
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
+ # arguments = {
400
+ # 'sequence_length': sl,
401
+ # 'hidden_size': hs,
402
+ # 'ffn_hidden_size': fhs,
403
+ # 'num_experts': ne,
404
+ # }
405
+ # log_benchmark(
406
+ # '1::GradW::DDD::TN',
407
+ # arguments,
408
+ # mean_t,
409
+ # std_t,
410
+ # x.numel() * hs * 2,
411
+ # )
412
 
413
 
414
  if __name__ == '__main__':
build/torch210-cxx11-cu130-aarch64-linux/ops/padded_scatter_benchmark.py CHANGED
@@ -4,7 +4,7 @@
4
  import unittest
5
 
6
  import torch
7
- from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
@@ -16,50 +16,50 @@ _PADDED_SCATTER_BENCHMARK = (
16
  )
17
 
18
 
19
- class PaddedScatterTest(parameterized.TestCase):
20
-
21
- @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
- def testPaddedScatter(self, sl, hs, ne, top_k):
23
- # Create the data and indices.
24
- x = torch.randn((sl, hs)).cuda().half()
25
-
26
- # Randomly assign tokens to experts.
27
- top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
- bin_ids, indices = ops.sort(top_expert)
29
- tokens_per_expert = ops.histogram(top_expert, ne)
30
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
-
34
- # Sample weights for the scatter reduce.
35
- weights = torch.rand((sl * top_k,)).cuda().half()
36
-
37
- # Gather the data to prepare for backwards.
38
- x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
-
40
- def benchmark():
41
- return ops.padded_scatter(
42
- x,
43
- indices,
44
- bin_ids,
45
- weights,
46
- bins,
47
- padded_bins,
48
- top_k,
49
- )
50
-
51
- time, std = benchmark_util.benchmark_function(benchmark)
52
- benchmark_util.log_benchmark(
53
- 'Padded Scatter',
54
- {
55
- 'sequence_length': sl,
56
- 'hidden_size': hs,
57
- 'num_experts': ne,
58
- 'top_k': top_k,
59
- },
60
- time,
61
- std,
62
- )
63
 
64
 
65
  if __name__ == '__main__':
 
4
  import unittest
5
 
6
  import torch
7
+ # from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
 
16
  )
17
 
18
 
19
+ # class PaddedScatterTest(parameterized.TestCase):
20
+ #
21
+ # @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
+ # def testPaddedScatter(self, sl, hs, ne, top_k):
23
+ # # Create the data and indices.
24
+ # x = torch.randn((sl, hs)).cuda().half()
25
+ #
26
+ # # Randomly assign tokens to experts.
27
+ # top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
+ # bin_ids, indices = ops.sort(top_expert)
29
+ # tokens_per_expert = ops.histogram(top_expert, ne)
30
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
+ #
34
+ # # Sample weights for the scatter reduce.
35
+ # weights = torch.rand((sl * top_k,)).cuda().half()
36
+ #
37
+ # # Gather the data to prepare for backwards.
38
+ # x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
+ #
40
+ # def benchmark():
41
+ # return ops.padded_scatter(
42
+ # x,
43
+ # indices,
44
+ # bin_ids,
45
+ # weights,
46
+ # bins,
47
+ # padded_bins,
48
+ # top_k,
49
+ # )
50
+ #
51
+ # time, std = benchmark_util.benchmark_function(benchmark)
52
+ # benchmark_util.log_benchmark(
53
+ # 'Padded Scatter',
54
+ # {
55
+ # 'sequence_length': sl,
56
+ # 'hidden_size': hs,
57
+ # 'num_experts': ne,
58
+ # 'top_k': top_k,
59
+ # },
60
+ # time,
61
+ # std,
62
+ # )
63
 
64
 
65
  if __name__ == '__main__':
build/torch210-cxx11-cu130-aarch64-linux/ops/permute_benchmark.py CHANGED
@@ -4,7 +4,7 @@
4
  import unittest
5
 
6
  import torch
7
- from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
@@ -26,123 +26,123 @@ _PERMUTE_TESTS = (
26
  )
27
 
28
 
29
- class PermuteBenchmark(parameterized.TestCase):
30
-
31
- @parameterized.parameters(*_PERMUTE_TESTS)
32
- def testBinnedGather(self, sl, hs, ne):
33
- # NOTE: Capacity factor == 1.
34
- ec = sl // ne
35
-
36
- # Create the data and indices.
37
- x = torch.randn((sl, hs)).cuda().half()
38
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
- bin_ids, indices = ops.sort(top_expert)
40
- tokens_per_expert = ops.histogram(indices, ne)
41
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
-
43
- def benchmark():
44
- return ops.binned_gather(x, indices, bins, ec)
45
-
46
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
- arguments = {
48
- 'sequence_length': sl,
49
- 'hidden_size': hs,
50
- 'num_experts': ne,
51
- }
52
- benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
-
54
- @parameterized.parameters(*_PERMUTE_TESTS)
55
- def testBinnedScatter(self, sl, hs, ne):
56
- # NOTE: Capacity factor == 1.
57
- ec = sl // ne
58
-
59
- # Create the data and indices.
60
- x = torch.randn((sl, hs)).cuda().half()
61
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
- bin_ids, indices = ops.sort(top_expert)
63
- tokens_per_expert = ops.histogram(indices, ne)
64
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
- x = ops.binned_gather(x, indices, bins, ec)
66
-
67
- def benchmark():
68
- return ops.binned_scatter(x, indices, bins)
69
-
70
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
- arguments = {
72
- 'sequence_length': sl,
73
- 'hidden_size': hs,
74
- 'num_experts': ne,
75
- }
76
- benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
-
78
- @parameterized.parameters(*_PERMUTE_TESTS)
79
- def testPaddedGather(self, sl, hs, ne):
80
- # Create the data and indices.
81
- x = torch.randn((sl, hs)).cuda().half()
82
-
83
- # Randomly assign tokens to experts.
84
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
- bin_ids, indices = ops.sort(top_expert)
86
- tokens_per_expert = ops.histogram(top_expert, ne)
87
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
-
91
- def benchmark():
92
- return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
-
94
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
- arguments = {
96
- 'sequence_length': sl,
97
- 'hidden_size': hs,
98
- 'num_experts': ne,
99
- }
100
- benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
-
102
- @parameterized.parameters(*_PERMUTE_TESTS)
103
- def testPaddedScatter(self, sl, hs, ne):
104
- # Create the data and indices.
105
- x = torch.randn((sl, hs)).cuda().half()
106
-
107
- # Randomly assign tokens to experts.
108
- top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
- bin_ids, indices = ops.sort(top_expert)
110
- tokens_per_expert = ops.histogram(top_expert, ne)
111
- padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
- x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
-
116
- def benchmark():
117
- return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
-
119
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
- arguments = {
121
- 'sequence_length': sl,
122
- 'hidden_size': hs,
123
- 'num_experts': ne,
124
- }
125
- benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
-
127
- @parameterized.parameters(*_PERMUTE_TESTS)
128
- def testCopy(self, sl, hs, ne):
129
- # NOTE: Capacity factor == 1.
130
- # ec = sl // ne
131
-
132
- # Create the data and indices.
133
- x = torch.randn((sl, hs)).cuda().half()
134
- y = x.clone()
135
-
136
- def benchmark():
137
- return y.copy_(x)
138
-
139
- mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
- arguments = {
141
- 'sequence_length': sl,
142
- 'hidden_size': hs,
143
- 'num_experts': ne,
144
- }
145
- benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
 
147
 
148
  if __name__ == '__main__':
 
4
  import unittest
5
 
6
  import torch
7
+ # from absl.testing import parameterized
8
 
9
  from .. import benchmark_util, ops
10
 
 
26
  )
27
 
28
 
29
+ # class PermuteBenchmark(parameterized.TestCase):
30
+ #
31
+ # @parameterized.parameters(*_PERMUTE_TESTS)
32
+ # def testBinnedGather(self, sl, hs, ne):
33
+ # # NOTE: Capacity factor == 1.
34
+ # ec = sl // ne
35
+ #
36
+ # # Create the data and indices.
37
+ # x = torch.randn((sl, hs)).cuda().half()
38
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
+ # bin_ids, indices = ops.sort(top_expert)
40
+ # tokens_per_expert = ops.histogram(indices, ne)
41
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
+ #
43
+ # def benchmark():
44
+ # return ops.binned_gather(x, indices, bins, ec)
45
+ #
46
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
+ # arguments = {
48
+ # 'sequence_length': sl,
49
+ # 'hidden_size': hs,
50
+ # 'num_experts': ne,
51
+ # }
52
+ # benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
+ #
54
+ # @parameterized.parameters(*_PERMUTE_TESTS)
55
+ # def testBinnedScatter(self, sl, hs, ne):
56
+ # # NOTE: Capacity factor == 1.
57
+ # ec = sl // ne
58
+ #
59
+ # # Create the data and indices.
60
+ # x = torch.randn((sl, hs)).cuda().half()
61
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
+ # bin_ids, indices = ops.sort(top_expert)
63
+ # tokens_per_expert = ops.histogram(indices, ne)
64
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
+ # x = ops.binned_gather(x, indices, bins, ec)
66
+ #
67
+ # def benchmark():
68
+ # return ops.binned_scatter(x, indices, bins)
69
+ #
70
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
+ # arguments = {
72
+ # 'sequence_length': sl,
73
+ # 'hidden_size': hs,
74
+ # 'num_experts': ne,
75
+ # }
76
+ # benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
+ #
78
+ # @parameterized.parameters(*_PERMUTE_TESTS)
79
+ # def testPaddedGather(self, sl, hs, ne):
80
+ # # Create the data and indices.
81
+ # x = torch.randn((sl, hs)).cuda().half()
82
+ #
83
+ # # Randomly assign tokens to experts.
84
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
+ # bin_ids, indices = ops.sort(top_expert)
86
+ # tokens_per_expert = ops.histogram(top_expert, ne)
87
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
+ #
91
+ # def benchmark():
92
+ # return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
+ #
94
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
+ # arguments = {
96
+ # 'sequence_length': sl,
97
+ # 'hidden_size': hs,
98
+ # 'num_experts': ne,
99
+ # }
100
+ # benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
+ #
102
+ # @parameterized.parameters(*_PERMUTE_TESTS)
103
+ # def testPaddedScatter(self, sl, hs, ne):
104
+ # # Create the data and indices.
105
+ # x = torch.randn((sl, hs)).cuda().half()
106
+ #
107
+ # # Randomly assign tokens to experts.
108
+ # top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
+ # bin_ids, indices = ops.sort(top_expert)
110
+ # tokens_per_expert = ops.histogram(top_expert, ne)
111
+ # padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
+ # padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
+ # x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
+ #
116
+ # def benchmark():
117
+ # return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
+ #
119
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ # arguments = {
121
+ # 'sequence_length': sl,
122
+ # 'hidden_size': hs,
123
+ # 'num_experts': ne,
124
+ # }
125
+ # benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
+ #
127
+ # @parameterized.parameters(*_PERMUTE_TESTS)
128
+ # def testCopy(self, sl, hs, ne):
129
+ # # NOTE: Capacity factor == 1.
130
+ # # ec = sl // ne
131
+ #
132
+ # # Create the data and indices.
133
+ # x = torch.randn((sl, hs)).cuda().half()
134
+ # y = x.clone()
135
+ #
136
+ # def benchmark():
137
+ # return y.copy_(x)
138
+ #
139
+ # mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
+ # arguments = {
141
+ # 'sequence_length': sl,
142
+ # 'hidden_size': hs,
143
+ # 'num_experts': ne,
144
+ # }
145
+ # benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
 
147
 
148
  if __name__ == '__main__':
build/torch210-cxx11-cu130-aarch64-linux/ops/sort_benchmark.py CHANGED
@@ -5,7 +5,7 @@ import unittest
5
 
6
  import numpy as np
7
  import torch
8
- from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
@@ -53,32 +53,32 @@ def log_benchmark(arguments, mean_t, std_t):
53
  print('=' * 60)
54
 
55
 
56
- class SortBenchmark(parameterized.TestCase):
57
-
58
- @parameterized.parameters(*_SORT_TESTS)
59
- def testSort(self, n, dtype, max_val):
60
- if max_val is None:
61
- max_val = np.iinfo(numpy_dtype(dtype)).max
62
- end_bit = int(np.ceil(np.log2(max_val)))
63
- x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
-
65
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
- arguments = {
67
- 'n': n,
68
- 'dtype': dtype,
69
- 'max_val': max_val,
70
- }
71
- log_benchmark(arguments, mean_t, std_t)
72
-
73
- @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
- def testTorchSort(self, n):
75
- x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
-
77
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
- arguments = {
79
- 'n': n,
80
- }
81
- log_benchmark(arguments, mean_t, std_t)
82
 
83
 
84
  if __name__ == '__main__':
 
5
 
6
  import numpy as np
7
  import torch
8
+ # from absl.testing import parameterized
9
 
10
  from .. import ops
11
 
 
53
  print('=' * 60)
54
 
55
 
56
+ # class SortBenchmark(parameterized.TestCase):
57
+ #
58
+ # @parameterized.parameters(*_SORT_TESTS)
59
+ # def testSort(self, n, dtype, max_val):
60
+ # if max_val is None:
61
+ # max_val = np.iinfo(numpy_dtype(dtype)).max
62
+ # end_bit = int(np.ceil(np.log2(max_val)))
63
+ # x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
+ #
65
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
+ # arguments = {
67
+ # 'n': n,
68
+ # 'dtype': dtype,
69
+ # 'max_val': max_val,
70
+ # }
71
+ # log_benchmark(arguments, mean_t, std_t)
72
+ #
73
+ # @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
+ # def testTorchSort(self, n):
75
+ # x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
+ #
77
+ # mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
+ # arguments = {
79
+ # 'n': n,
80
+ # }
81
+ # log_benchmark(arguments, mean_t, std_t)
82
 
83
 
84
  if __name__ == '__main__':
build/torch210-cxx11-cu130-aarch64-linux/stk/ops/eltwise_ops_test.py CHANGED
@@ -1,7 +1,7 @@
1
  import unittest
2
  import itertools
3
  import torch
4
- from absl.testing import parameterized
5
 
6
  import stk
7
  from stk.ops.linear_ops_test import allclose, _dense_and_sparse
@@ -47,40 +47,40 @@ def _dense_and_sparse_like(x, std=0.1):
47
  return (dense.requires_grad_(True),
48
  sparse.requires_grad_(True))
49
 
50
- @parameterized.parameters(_ELTWISE_OP_TESTS)
51
- class EltwiseOpsTest(parameterized.TestCase):
52
-
53
- def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
-
55
- a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
- b_dense, b = _dense_and_sparse_like(a)
57
-
58
- out = stk.ops.mul(a, b)
59
- expected_out = torch.mul(a_dense, b_dense)
60
-
61
- # Compute the gradients w.r.t. the inputs.
62
- expected_out.sum().backward()
63
- stk.ops.sum(out).backward()
64
-
65
- # Validate the results.
66
- out = stk.ops.to_dense(out)
67
- self.assertEqual(out.dim(), 2)
68
- self.assertEqual(expected_out.size(), out.size())
69
- self.assertTrue(allclose(out, expected_out))
70
-
71
- # LHS gradient.
72
- grad = stk.ops.to_dense(a.grad)
73
- expected_grad = a_dense.grad
74
- self.assertEqual(grad.dim(), 2)
75
- self.assertEqual(expected_grad.size(), grad.size())
76
- self.assertTrue(allclose(grad, expected_grad))
77
-
78
- # RHS gradient.
79
- grad = stk.ops.to_dense(b.grad)
80
- expected_grad = b_dense.grad
81
- self.assertEqual(grad.dim(), 2)
82
- self.assertEqual(expected_grad.size(), grad.size())
83
- self.assertTrue(allclose(grad, expected_grad))
84
 
85
  if __name__ == '__main__':
86
  unittest.main()
 
1
  import unittest
2
  import itertools
3
  import torch
4
+ # from absl.testing import parameterized
5
 
6
  import stk
7
  from stk.ops.linear_ops_test import allclose, _dense_and_sparse
 
47
  return (dense.requires_grad_(True),
48
  sparse.requires_grad_(True))
49
 
50
+ # @parameterized.parameters(_ELTWISE_OP_TESTS)
51
+ # class EltwiseOpsTest(parameterized.TestCase):
52
+ #
53
+ # def testEltwiseMul(self, m, n, sparsity, blocking, dtype):
54
+ #
55
+ # a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
56
+ # b_dense, b = _dense_and_sparse_like(a)
57
+ #
58
+ # out = stk.ops.mul(a, b)
59
+ # expected_out = torch.mul(a_dense, b_dense)
60
+ #
61
+ # # Compute the gradients w.r.t. the inputs.
62
+ # expected_out.sum().backward()
63
+ # stk.ops.sum(out).backward()
64
+ #
65
+ # # Validate the results.
66
+ # out = stk.ops.to_dense(out)
67
+ # self.assertEqual(out.dim(), 2)
68
+ # self.assertEqual(expected_out.size(), out.size())
69
+ # self.assertTrue(allclose(out, expected_out))
70
+ #
71
+ # # LHS gradient.
72
+ # grad = stk.ops.to_dense(a.grad)
73
+ # expected_grad = a_dense.grad
74
+ # self.assertEqual(grad.dim(), 2)
75
+ # self.assertEqual(expected_grad.size(), grad.size())
76
+ # self.assertTrue(allclose(grad, expected_grad))
77
+ #
78
+ # # RHS gradient.
79
+ # grad = stk.ops.to_dense(b.grad)
80
+ # expected_grad = b_dense.grad
81
+ # self.assertEqual(grad.dim(), 2)
82
+ # self.assertEqual(expected_grad.size(), grad.size())
83
+ # self.assertTrue(allclose(grad, expected_grad))
84
 
85
  if __name__ == '__main__':
86
  unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/stk/ops/linear_ops_test.py CHANGED
@@ -2,7 +2,7 @@ import unittest
2
  import itertools
3
  import numpy as np
4
  import torch
5
- from absl.testing import parameterized
6
 
7
  import stk
8
 
@@ -96,121 +96,121 @@ def _mask(x, mask):
96
  return x * mask
97
 
98
 
99
- @parameterized.parameters(*_LINEAR_OP_TESTS)
100
- class LinearOpsTest(parameterized.TestCase):
101
-
102
- def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
103
- # Construct the operands.
104
- a_shape = (k, m) if trans_a else (m, k)
105
- a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
106
- b_shape = (n, k) if trans_b else (k, n)
107
- b, bcp = _dense_2x(*b_shape, dtype)
108
-
109
- # Execute the matmul.
110
- out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
111
- expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
112
-
113
- # Compute the gradients w.r.t. the inputs.
114
- expected_out.sum().backward()
115
- out.sum().backward()
116
-
117
- # Validate the results.
118
- self.assertEqual(out.dim(), 2)
119
- self.assertEqual(expected_out.size()[0], out.size()[0])
120
- self.assertEqual(expected_out.size()[1], out.size()[1])
121
- self.assertTrue(allclose(out, expected_out))
122
-
123
- # LHS gradient.
124
- grad = stk.ops.to_dense(a.grad)
125
- expected_grad = _mask(a_dense.grad, a.grad)
126
- self.assertEqual(grad.dim(), 2)
127
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
128
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
129
- self.assertTrue(allclose(grad, expected_grad))
130
-
131
- # RHS gradient.
132
- grad = b.grad
133
- expected_grad = bcp.grad
134
- self.assertEqual(grad.dim(), 2)
135
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
136
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
137
- self.assertTrue(allclose(grad, expected_grad))
138
-
139
- def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
140
- # Construct the operands.
141
- a_shape = (k, m) if trans_a else (m, k)
142
- a, acp = _dense_2x(*a_shape, dtype)
143
- b_shape = (n, k) if trans_b else (k, n)
144
- b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
145
-
146
- # Execute the matmul.
147
- out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
148
- expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
149
-
150
- # Compute the gradients w.r.t. the inputs.
151
- expected_out.sum().backward()
152
- out.sum().backward()
153
-
154
- # Validate the results.
155
- self.assertEqual(out.dim(), 2)
156
- self.assertEqual(expected_out.size()[0], out.size()[0])
157
- self.assertEqual(expected_out.size()[1], out.size()[1])
158
- self.assertTrue(allclose(out, expected_out))
159
-
160
- # LHS gradient.
161
- grad = a.grad
162
- expected_grad = acp.grad
163
- self.assertEqual(grad.dim(), 2)
164
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
165
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
166
- self.assertTrue(allclose(grad, expected_grad))
167
-
168
- # RHS gradient.
169
- grad = stk.ops.to_dense(b.grad)
170
- expected_grad = _mask(b_dense.grad, b.grad)
171
- self.assertEqual(grad.dim(), 2)
172
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
173
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
174
- self.assertTrue(allclose(grad, expected_grad))
175
-
176
- def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
177
- # Construct the operands.
178
- a_shape = (k, m) if trans_a else (m, k)
179
- a, acp = _dense_2x(*a_shape, dtype)
180
- b_shape = (n, k) if trans_b else (k, n)
181
- b, bcp = _dense_2x(*b_shape, dtype)
182
- _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
183
-
184
- # Execute the matmul.
185
- out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
186
- expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
187
-
188
- # Compute the gradients w.r.t. the inputs.
189
- expected_out.sum().backward()
190
- stk.ops.sum(out).backward()
191
-
192
- # Validate the results.
193
- out = stk.ops.to_dense(out)
194
- self.assertEqual(out.dim(), 2)
195
- self.assertEqual(expected_out.size()[0], out.size()[0])
196
- self.assertEqual(expected_out.size()[1], out.size()[1])
197
- self.assertTrue(allclose(out, expected_out))
198
-
199
- # LHS gradient.
200
- grad = a.grad
201
- expected_grad = acp.grad
202
- self.assertEqual(grad.dim(), 2)
203
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
204
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
205
- self.assertTrue(allclose(grad, expected_grad))
206
-
207
- # RHS gradient.
208
- grad = b.grad
209
- expected_grad = bcp.grad
210
- self.assertEqual(grad.dim(), 2)
211
- self.assertEqual(expected_grad.size()[0], grad.size()[0])
212
- self.assertEqual(expected_grad.size()[1], grad.size()[1])
213
- self.assertTrue(allclose(grad, expected_grad))
214
 
215
  if __name__ == '__main__':
216
  unittest.main()
 
2
  import itertools
3
  import numpy as np
4
  import torch
5
+ # from absl.testing import parameterized
6
 
7
  import stk
8
 
 
96
  return x * mask
97
 
98
 
99
+ # @parameterized.parameters(*_LINEAR_OP_TESTS)
100
+ # class LinearOpsTest(parameterized.TestCase):
101
+ #
102
+ # def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
103
+ # # Construct the operands.
104
+ # a_shape = (k, m) if trans_a else (m, k)
105
+ # a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
106
+ # b_shape = (n, k) if trans_b else (k, n)
107
+ # b, bcp = _dense_2x(*b_shape, dtype)
108
+ #
109
+ # # Execute the matmul.
110
+ # out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
111
+ # expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)
112
+ #
113
+ # # Compute the gradients w.r.t. the inputs.
114
+ # expected_out.sum().backward()
115
+ # out.sum().backward()
116
+ #
117
+ # # Validate the results.
118
+ # self.assertEqual(out.dim(), 2)
119
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
120
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
121
+ # self.assertTrue(allclose(out, expected_out))
122
+ #
123
+ # # LHS gradient.
124
+ # grad = stk.ops.to_dense(a.grad)
125
+ # expected_grad = _mask(a_dense.grad, a.grad)
126
+ # self.assertEqual(grad.dim(), 2)
127
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
128
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
129
+ # self.assertTrue(allclose(grad, expected_grad))
130
+ #
131
+ # # RHS gradient.
132
+ # grad = b.grad
133
+ # expected_grad = bcp.grad
134
+ # self.assertEqual(grad.dim(), 2)
135
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
136
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
137
+ # self.assertTrue(allclose(grad, expected_grad))
138
+ #
139
+ # def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
140
+ # # Construct the operands.
141
+ # a_shape = (k, m) if trans_a else (m, k)
142
+ # a, acp = _dense_2x(*a_shape, dtype)
143
+ # b_shape = (n, k) if trans_b else (k, n)
144
+ # b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)
145
+ #
146
+ # # Execute the matmul.
147
+ # out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
148
+ # expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)
149
+ #
150
+ # # Compute the gradients w.r.t. the inputs.
151
+ # expected_out.sum().backward()
152
+ # out.sum().backward()
153
+ #
154
+ # # Validate the results.
155
+ # self.assertEqual(out.dim(), 2)
156
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
157
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
158
+ # self.assertTrue(allclose(out, expected_out))
159
+ #
160
+ # # LHS gradient.
161
+ # grad = a.grad
162
+ # expected_grad = acp.grad
163
+ # self.assertEqual(grad.dim(), 2)
164
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
165
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
166
+ # self.assertTrue(allclose(grad, expected_grad))
167
+ #
168
+ # # RHS gradient.
169
+ # grad = stk.ops.to_dense(b.grad)
170
+ # expected_grad = _mask(b_dense.grad, b.grad)
171
+ # self.assertEqual(grad.dim(), 2)
172
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
173
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
174
+ # self.assertTrue(allclose(grad, expected_grad))
175
+ #
176
+ # def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
177
+ # # Construct the operands.
178
+ # a_shape = (k, m) if trans_a else (m, k)
179
+ # a, acp = _dense_2x(*a_shape, dtype)
180
+ # b_shape = (n, k) if trans_b else (k, n)
181
+ # b, bcp = _dense_2x(*b_shape, dtype)
182
+ # _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)
183
+ #
184
+ # # Execute the matmul.
185
+ # out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
186
+ # expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)
187
+ #
188
+ # # Compute the gradients w.r.t. the inputs.
189
+ # expected_out.sum().backward()
190
+ # stk.ops.sum(out).backward()
191
+ #
192
+ # # Validate the results.
193
+ # out = stk.ops.to_dense(out)
194
+ # self.assertEqual(out.dim(), 2)
195
+ # self.assertEqual(expected_out.size()[0], out.size()[0])
196
+ # self.assertEqual(expected_out.size()[1], out.size()[1])
197
+ # self.assertTrue(allclose(out, expected_out))
198
+ #
199
+ # # LHS gradient.
200
+ # grad = a.grad
201
+ # expected_grad = acp.grad
202
+ # self.assertEqual(grad.dim(), 2)
203
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
204
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
205
+ # self.assertTrue(allclose(grad, expected_grad))
206
+ #
207
+ # # RHS gradient.
208
+ # grad = b.grad
209
+ # expected_grad = bcp.grad
210
+ # self.assertEqual(grad.dim(), 2)
211
+ # self.assertEqual(expected_grad.size()[0], grad.size()[0])
212
+ # self.assertEqual(expected_grad.size()[1], grad.size()[1])
213
+ # self.assertTrue(allclose(grad, expected_grad))
214
 
215
  if __name__ == '__main__':
216
  unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/stk/ops/matrix_ops_test.py CHANGED
@@ -1,61 +1,61 @@
1
  import unittest
2
 
3
- from absl.testing import parameterized
4
  import stk
5
  import torch
6
 
7
 
8
- @parameterized.parameters(
9
- (8, 16, 0.0, 1),
10
- (8, 16, 0.5, 1),
11
- (8, 16, .95, 1),
12
- (16, 8, 0.0, 1),
13
- (16, 8, 0.5, 1),
14
- (16, 8, .95, 1),
15
- (8, 16, 0.0, 8),
16
- (8, 16, 0.5, 8),
17
- (8, 16, 1.0, 8),
18
- (16, 8, 0.0, 8),
19
- (16, 8, 0.5, 8),
20
- (16, 8, 1.0, 8),
21
- (128, 256, 0.5, 16),
22
- (256, 128, 0.75, 32),
23
- (512, 512, .875, 128))
24
- class MatrixOpsTest(parameterized.TestCase):
25
-
26
- def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
27
- mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
28
- x = (torch.randn(rows, cols) * mask).type(torch.float16)
29
-
30
- # Convert the matrix to sparse format.
31
- sparse_x = stk.ops.to_sparse(x, blocking)
32
-
33
- # Validate the matrix.
34
- sparse_x.validate()
35
-
36
- # Validate the shape.
37
- self.assertEqual(sparse_x.dim(), 2)
38
- self.assertEqual(sparse_x.size()[0], rows)
39
- self.assertEqual(sparse_x.size()[1], cols)
40
-
41
- # Validate the sparsity.
42
- numblocks = rows // blocking * cols // blocking
43
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
44
- self.assertEqual(sparse_x.nnz, nnz)
45
-
46
- # Convert back to dense format.
47
- dense_x = stk.ops.to_dense(sparse_x)
48
-
49
- # Validate the shape.
50
- self.assertEqual(dense_x.dim(), 2)
51
- self.assertEqual(dense_x.size()[0], rows)
52
- self.assertEqual(dense_x.size()[1], cols)
53
-
54
- # Validate the sparsity
55
- self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
56
-
57
- # Validate the output.
58
- self.assertTrue(torch.all(torch.eq(x, dense_x)))
59
 
60
 
61
  if __name__ == '__main__':
 
1
  import unittest
2
 
3
+ # from absl.testing import parameterized
4
  import stk
5
  import torch
6
 
7
 
8
+ # @parameterized.parameters(
9
+ # (8, 16, 0.0, 1),
10
+ # (8, 16, 0.5, 1),
11
+ # (8, 16, .95, 1),
12
+ # (16, 8, 0.0, 1),
13
+ # (16, 8, 0.5, 1),
14
+ # (16, 8, .95, 1),
15
+ # (8, 16, 0.0, 8),
16
+ # (8, 16, 0.5, 8),
17
+ # (8, 16, 1.0, 8),
18
+ # (16, 8, 0.0, 8),
19
+ # (16, 8, 0.5, 8),
20
+ # (16, 8, 1.0, 8),
21
+ # (128, 256, 0.5, 16),
22
+ # (256, 128, 0.75, 32),
23
+ # (512, 512, .875, 128))
24
+ # class MatrixOpsTest(parameterized.TestCase):
25
+ #
26
+ # def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
27
+ # mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
28
+ # x = (torch.randn(rows, cols) * mask).type(torch.float16)
29
+ #
30
+ # # Convert the matrix to sparse format.
31
+ # sparse_x = stk.ops.to_sparse(x, blocking)
32
+ #
33
+ # # Validate the matrix.
34
+ # sparse_x.validate()
35
+ #
36
+ # # Validate the shape.
37
+ # self.assertEqual(sparse_x.dim(), 2)
38
+ # self.assertEqual(sparse_x.size()[0], rows)
39
+ # self.assertEqual(sparse_x.size()[1], cols)
40
+ #
41
+ # # Validate the sparsity.
42
+ # numblocks = rows // blocking * cols // blocking
43
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
44
+ # self.assertEqual(sparse_x.nnz, nnz)
45
+ #
46
+ # # Convert back to dense format.
47
+ # dense_x = stk.ops.to_dense(sparse_x)
48
+ #
49
+ # # Validate the shape.
50
+ # self.assertEqual(dense_x.dim(), 2)
51
+ # self.assertEqual(dense_x.size()[0], rows)
52
+ # self.assertEqual(dense_x.size()[1], cols)
53
+ #
54
+ # # Validate the sparsity
55
+ # self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)
56
+ #
57
+ # # Validate the output.
58
+ # self.assertTrue(torch.all(torch.eq(x, dense_x)))
59
 
60
 
61
  if __name__ == '__main__':
build/torch210-cxx11-cu130-aarch64-linux/stk/random/random_ops_test.py CHANGED
@@ -1,72 +1,72 @@
1
  import unittest
2
 
3
- from absl.testing import parameterized
4
  from . import random
5
  import torch
6
 
7
 
8
- @parameterized.parameters(
9
- (8, 16, 0.0, 1),
10
- (8, 16, 0.5, 1),
11
- (8, 16, .95, 1),
12
- (16, 8, 0.0, 1),
13
- (16, 8, 0.5, 1),
14
- (16, 8, .95, 1),
15
- (8, 16, 0.0, 8),
16
- (8, 16, 0.5, 8),
17
- (8, 16, 1.0, 8),
18
- (16, 8, 0.0, 8),
19
- (16, 8, 0.5, 8),
20
- (16, 8, 1.0, 8),
21
- (128, 256, 0.5, 16),
22
- (256, 128, 0.75, 32),
23
- (512, 512, .875, 128))
24
- class RandomOpsTest(parameterized.TestCase):
25
-
26
- def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
27
- mask = random.dense_mask(
28
- rows, cols, sparsity, blocking)
29
-
30
- # Validate the shape.
31
- self.assertEqual(mask.dim(), 2)
32
- self.assertEqual(mask.size()[0], rows)
33
- self.assertEqual(mask.size()[1], cols)
34
-
35
- # Validate the sparsity
36
- numblocks = rows // blocking * cols // blocking
37
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
38
- self.assertEqual(
39
- torch.count_nonzero(mask).item(),
40
- nnz)
41
-
42
- # Check values are zero or one.
43
- self.assertTrue(
44
- torch.all(torch.logical_or(
45
- torch.eq(mask, 0),
46
- torch.eq(mask, 1))))
47
-
48
- def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
49
- mask = random.mask(
50
- rows, cols, sparsity, blocking)
51
-
52
- # Validate the matrix.
53
- mask.validate()
54
-
55
- # Validate the shape.
56
- self.assertEqual(mask.dim(), 2)
57
- self.assertEqual(mask.size()[0], rows)
58
- self.assertEqual(mask.size()[1], cols)
59
-
60
- # Validate the sparsity.
61
- numblocks = rows // blocking * cols // blocking
62
- nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
63
- self.assertEqual(mask.nnz, nnz)
64
-
65
- # Check values are zero or one.
66
- self.assertTrue(
67
- torch.all(torch.logical_or(
68
- torch.eq(mask.data, 0),
69
- torch.eq(mask.data, 1))))
70
 
71
 
72
  if __name__ == '__main__':
 
1
  import unittest
2
 
3
+ # from absl.testing import parameterized
4
  from . import random
5
  import torch
6
 
7
 
8
+ # @parameterized.parameters(
9
+ # (8, 16, 0.0, 1),
10
+ # (8, 16, 0.5, 1),
11
+ # (8, 16, .95, 1),
12
+ # (16, 8, 0.0, 1),
13
+ # (16, 8, 0.5, 1),
14
+ # (16, 8, .95, 1),
15
+ # (8, 16, 0.0, 8),
16
+ # (8, 16, 0.5, 8),
17
+ # (8, 16, 1.0, 8),
18
+ # (16, 8, 0.0, 8),
19
+ # (16, 8, 0.5, 8),
20
+ # (16, 8, 1.0, 8),
21
+ # (128, 256, 0.5, 16),
22
+ # (256, 128, 0.75, 32),
23
+ # (512, 512, .875, 128))
24
+ # class RandomOpsTest(parameterized.TestCase):
25
+ #
26
+ # def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
27
+ # mask = random.dense_mask(
28
+ # rows, cols, sparsity, blocking)
29
+ #
30
+ # # Validate the shape.
31
+ # self.assertEqual(mask.dim(), 2)
32
+ # self.assertEqual(mask.size()[0], rows)
33
+ # self.assertEqual(mask.size()[1], cols)
34
+ #
35
+ # # Validate the sparsity
36
+ # numblocks = rows // blocking * cols // blocking
37
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
38
+ # self.assertEqual(
39
+ # torch.count_nonzero(mask).item(),
40
+ # nnz)
41
+ #
42
+ # # Check values are zero or one.
43
+ # self.assertTrue(
44
+ # torch.all(torch.logical_or(
45
+ # torch.eq(mask, 0),
46
+ # torch.eq(mask, 1))))
47
+ #
48
+ # def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
49
+ # mask = random.mask(
50
+ # rows, cols, sparsity, blocking)
51
+ #
52
+ # # Validate the matrix.
53
+ # mask.validate()
54
+ #
55
+ # # Validate the shape.
56
+ # self.assertEqual(mask.dim(), 2)
57
+ # self.assertEqual(mask.size()[0], rows)
58
+ # self.assertEqual(mask.size()[1], cols)
59
+ #
60
+ # # Validate the sparsity.
61
+ # numblocks = rows // blocking * cols // blocking
62
+ # nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
63
+ # self.assertEqual(mask.nnz, nnz)
64
+ #
65
+ # # Check values are zero or one.
66
+ # self.assertTrue(
67
+ # torch.all(torch.logical_or(
68
+ # torch.eq(mask.data, 0),
69
+ # torch.eq(mask.data, 1))))
70
 
71
 
72
  if __name__ == '__main__':
build/torch211-cxx11-cu126-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ from ._ops import ops
7
+
8
+ from .grouped_gemm import backend as gg_backend
9
+ from .grouped_gemm import ops as gg_ops
10
+
11
+
12
+ from ._layers.arguments import Arguments
13
+ from ._layers.dmoe import ParallelDroplessMLP, dMoE
14
+ from ._layers.glu import SparseGLU
15
+ from ._layers.mlp import MLP, SparseMLP
16
+ from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
17
+
18
+ from . import layers
19
+
20
+ # This section contains the direct kernel exports (not included in the original code)
21
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """Run the ``ops.exclusive_cumsum`` kernel and store its result in ``out``.

    Args:
        x: Input tensor.
        dim: Dimension along which the kernel accumulates.
        out: Caller-provided destination; overwritten via ``copy_``.

    Returns:
        ``out`` itself, after the kernel result has been copied into it.
    """
    out.copy_(ops.exclusive_cumsum(x, dim))
    return out
36
+
37
+
38
def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """Run the ``ops.inclusive_cumsum`` kernel and store its result in ``out``.

    Args:
        x: Input tensor.
        dim: Dimension along which the kernel accumulates.
        out: Caller-provided destination; overwritten via ``copy_``.

    Returns:
        ``out`` itself, after the kernel result has been copied into it.
    """
    out.copy_(ops.inclusive_cumsum(x, dim))
    return out
53
+
54
+
55
def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """Forward to the ``ops.histogram`` kernel.

    Args:
        x: Input tensor whose values are binned.
        num_bins: Number of histogram bins requested from the kernel.

    Returns:
        Whatever ``ops.histogram`` returns (per-bin counts, per the
        kernel's contract).
    """
    counts = ops.histogram(x, num_bins)
    return counts
67
+
68
+
69
def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
) -> torch.Tensor:
    """Forward to the ``ops.indices`` kernel.

    Args:
        padded_bins: Tensor of bin boundaries.
        block_size: Size of each block.
        output_block_rows: Number of block rows in the output.
        output_block_columns: Number of block columns in the output.

    Returns:
        The tensor produced by ``ops.indices`` for these arguments.
    """
    constructed = ops.indices(
        padded_bins,
        block_size,
        output_block_rows,
        output_block_columns,
    )
    return constructed
88
+
89
+
90
def replicate_forward(
    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """Forward pass of the replicate op; delegates to ``ops.replicate_forward``.

    Args:
        x: Tensor holding the values to replicate.
        bins: Tensor describing the bins.
        out: Output tensor handed to the kernel.

    Returns:
        The value returned by ``ops.replicate_forward``.
    """
    replicated = ops.replicate_forward(x, bins, out)
    return replicated
105
+
106
+
107
def replicate_backward(
    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """Backward pass of the replicate op; delegates to ``ops.replicate_backward``.

    Args:
        grad: Gradient tensor to reduce.
        bins: Tensor describing the bins.
        out: Output tensor handed to the kernel.

    Returns:
        The value returned by ``ops.replicate_backward``.
    """
    reduced = ops.replicate_backward(grad, bins, out)
    return reduced
122
+
123
+
124
def sort(
    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
    """Invoke the ``ops.sort`` kernel with caller-provided output buffers.

    Args:
        x: Tensor to sort.
        end_bit: Number of bits the kernel considers while sorting.
        x_out: Buffer for the sorted values.
        iota_out: Buffer for the sorted indices.

    Returns:
        The value returned by ``ops.sort``.
    """
    sorted_result = ops.sort(x, end_bit, x_out, iota_out)
    return sorted_result
140
+
141
+
142
+ # Convenience functions for common use cases
143
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
    """Cumulative sum that allocates its own output buffer.

    Args:
        x: Input tensor.
        dim: Dimension to accumulate over (defaults to the last one).
        exclusive: Selects ``exclusive_cumsum`` when truthy, otherwise
            ``inclusive_cumsum``.

    Returns:
        A freshly allocated tensor (``torch.empty_like(x)``) filled by the
        selected cumsum wrapper.
    """
    kernel = exclusive_cumsum if exclusive else inclusive_cumsum
    return kernel(x, dim, torch.empty_like(x))
160
+
161
+
162
def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
    """Sort ``x`` and return both output buffers filled by ``sort``.

    Both buffers are allocated with ``torch.empty_like(x)``, so they share
    the shape, dtype and device of the input.

    Args:
        x: Tensor to sort.
        end_bit: Number of bits considered while sorting.

    Returns:
        Tuple of (sorted_values, sorted_indices).
    """
    sorted_values, sorted_indices = torch.empty_like(x), torch.empty_like(x)
    sort(x, end_bit, sorted_values, sorted_indices)
    return sorted_values, sorted_indices
177
+
178
+
179
+ # Export public API
180
__all__ = [
    # NOTE: every entry must name something bound in this module; a stale
    # entry makes ``from <package> import *`` fail with AttributeError.
    # The former "MyReplacementLayer" entry was removed for that reason:
    # nothing by that name is defined or imported here.
    # Direct kernel exports
    "exclusive_cumsum",
    "inclusive_cumsum",
    "histogram",
    "indices",
    "replicate_forward",
    "replicate_backward",
    "sort",
    "cumsum",
    "argsort",
    # Original exports
    "Arguments",
    "ParallelDroplessMLP",
    "dMoE",
    "SparseGLU",
    "MLP",
    "SparseMLP",
    "MoE",
    "ParallelMLP",
    "get_load_balancing_loss",
]
build/torch211-cxx11-cu126-aarch64-linux/_layers/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # from megablocks.layers.dmoe import dMoE
5
+ from .moe import MoE
6
+
7
+ __all__ = [
8
+ 'MoE',
9
+ # 'dMoE',
10
+ ]
build/torch211-cxx11-cu126-aarch64-linux/_layers/activation_fn.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Callable, Union
5
+
6
+ import torch
7
+ from ..stk import Matrix
8
+
9
+
10
def act_fn(
    x: Matrix,
    function: Callable,
    return_grad_fn: bool = False,
    **kwargs,
) -> Union[tuple[Matrix, Any], Matrix]:
    """Apply ``function`` to the values of a sparse matrix, reusing its topology.

    Args:
        x: Sparse ``Matrix`` whose ``data`` is passed to ``function``.
        function: Callable invoked as ``function(x.data, **kwargs)``.
        return_grad_fn: When true, gradient tracking is forced on for the
            call, ``x.data.requires_grad`` is set to ``True`` (a side effect
            on the input), and the output's ``backward`` bound method is
            returned alongside the result.
        **kwargs: Extra keyword arguments forwarded to ``function``.

    Returns:
        A new ``Matrix`` with ``x``'s size and index metadata wrapping the
        function output, or ``(matrix, out.backward)`` when
        ``return_grad_fn`` is true.
    """
    assert isinstance(x, Matrix)
    # Keep the ambient grad mode unless the caller explicitly wants a grad fn.
    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
        if return_grad_fn:
            x.data.requires_grad = True
        out = function(x.data, **kwargs)
        y = Matrix(
            x.size(),
            out,
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
        if return_grad_fn:
            return y, out.backward
        return y
build/torch211-cxx11-cu126-aarch64-linux/_layers/all_to_all.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+
8
class AllToAllOp(torch.autograd.Function):
    """Autograd wrapper around ``dist.all_to_all_single``.

    The backward pass runs the same collective with the input and output
    split sizes swapped.
    """

    @staticmethod
    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
        # The receive buffer keeps x's trailing dims; its leading dim is the
        # total of the output splits.
        recv_shape = (sum(output_split_sizes),) + x.shape[1:]
        received = torch.empty(recv_shape, device=x.device, dtype=x.dtype)

        # Saved for the reverse exchange in backward().
        ctx.input_shape = x.shape
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.group = group

        work = dist.all_to_all_single(
            received,
            x,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
        return received, work

    @staticmethod
    def backward(ctx, grad, _):
        # One gradient slot per forward() argument; only ``x`` can need one.
        if not ctx.needs_input_grad[0]:
            return None, None, None, None, None

        grad_input = torch.empty(
            ctx.input_shape,
            device=grad.device,
            dtype=grad.dtype,
        )
        dist.all_to_all_single(
            grad_input,
            grad,
            output_split_sizes=ctx.input_split_sizes,
            input_split_sizes=ctx.output_split_sizes,
            group=ctx.group,
        )
        return grad_input, None, None, None, None
45
+
46
+
47
def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
    """Differentiable all-to-all; thin functional entry point for ``AllToAllOp``.

    Returns whatever ``AllToAllOp.apply`` returns for these arguments
    (``AllToAllOp.forward`` yields the receive buffer and the handle from
    ``dist.all_to_all_single``).
    """
    op_args = (x, output_split_sizes, input_split_sizes, group, async_op)
    return AllToAllOp.apply(*op_args)
build/torch211-cxx11-cu126-aarch64-linux/_layers/arguments.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import dataclasses
5
+ from functools import partial
6
+ from typing import Any, Callable, Optional, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ import torch.nn.functional as F
11
+
12
+ # import megablocks.grouped_gemm_util as grouped_gemm
13
+ from .. import grouped_gemm_util as grouped_gemm
14
+
15
+ # Type annotation for in-place Tensor initialization function.
16
+ InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
17
+
18
+ _ALLOWED_BITWIDTHS = (-1, 4, 8)
19
+
20
+ DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
21
+
22
+
23
@dataclasses.dataclass
class Arguments:
    """Configuration for the MoE / dMoE layers.

    ``__post_init__`` validates the selected ``mlp_impl`` backend and fills
    in ``shared_expert_hidden_size`` from ``ffn_hidden_size`` when unset.
    """

    # Model arguments.
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096
    num_layers: int = 1
    bias: bool = True
    return_bias: bool = True
    activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN

    # MoE arguments.
    moe_num_experts: int = 1
    moe_top_k: int = 1
    moe_capacity_factor: int = 1
    moe_normalize_expert_weights: Optional[Union[int, float]] = None
    moe_loss_weight: float = 0.1
    moe_jitter_eps: Optional[float] = None
    moe_lbl_in_fp32: bool = False

    # Parallelism arguments.
    moe_expert_model_parallelism: bool = False
    expert_parallel_group: Optional[dist.ProcessGroup] = None
    pipeline_model_parallel_size: int = 1
    num_layers_per_virtual_pipeline_stage: Optional[int] = None

    # Compute arguments.
    memory_optimized_mlp: bool = False
    mlp_type: str = 'mlp'
    mlp_impl: str = 'sparse'

    # Initialization arguments.
    fp16: bool = True
    bf16: bool = False
    device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
    init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
    output_layer_init_method: InitFn = init_method

    # Benchmarking arguments.
    uniform_expert_assignment: bool = False

    # Shared expert arguments.
    shared_expert: bool = False  # enable using shared expert
    # Class of the fully connected layer in the shared expert (allows a
    # custom FC layer, e.g. te.Linear for FP8).
    fc_cls: Any = torch.nn.Linear
    fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)  # kwargs for custom fc layers
    remat_act_fn: bool = True  # enable act fn to be rematerialized instead of stored
    # Hidden size of the shared expert, if it should differ from the default
    # (``ffn_hidden_size`` is substituted in __post_init__ when this is None).
    shared_expert_hidden_size: Optional[int] = None
    # Enable weighted sum for shared expert output (weighted by number of experts used).
    shared_expert_weighted_sum: bool = False

    # Router Z-loss arguments.
    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
    moe_zloss_in_fp32: bool = False

    @staticmethod
    def _release_tuple(version: str) -> tuple:
        """Return the leading numeric components of ``version`` as a 3-tuple of ints."""
        import re

        # Drop any local version suffix ("+cu126") before extracting numbers.
        numbers = re.findall(r'\d+', version.split('+')[0])[:3]
        release = [int(number) for number in numbers]
        release += [0] * (3 - len(release))
        return tuple(release)

    def __post_init__(self):
        # Sparse MLP is not supported with triton >=3.2.0
        # TODO: Remove this once sparse is supported with triton >=3.2.0
        if self.mlp_impl == 'sparse':
            try:
                import triton
            except ImportError as err:
                raise ImportError('Triton is required for sparse MLP implementation') from err
            # Compare numerically: as strings, '3.10.0' sorts before '3.2.0'
            # and would wrongly pass the check.
            if self._release_tuple(triton.__version__) >= (3, 2, 0):
                raise ValueError(
                    'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
                )

        if self.mlp_impl == 'grouped':
            grouped_gemm.assert_grouped_gemm_is_available()

        if self.shared_expert_hidden_size is None:
            self.shared_expert_hidden_size = self.ffn_hidden_size
94
+
95
+
96
def from_megatron(megatron_args: Any):
    """Build an ``Arguments`` from an object carrying Megatron-style attributes.

    Every ``Arguments`` field that exists as an attribute on
    ``megatron_args`` is copied; the remaining fields keep their defaults.

    The overrides are passed to the constructor instead of being assigned
    afterwards, so ``__post_init__`` validates the *requested*
    configuration (not the default ``mlp_impl``) and derives
    ``shared_expert_hidden_size`` from the overridden ``ffn_hidden_size``.

    Args:
        megatron_args: Any object; only attributes whose names match
            ``Arguments`` fields are read.

    Returns:
        The populated ``Arguments`` instance.
    """
    overrides = {
        field.name: getattr(megatron_args, field.name)
        for field in dataclasses.fields(Arguments)
        if field.init and hasattr(megatron_args, field.name)
    }
    return Arguments(**overrides)
build/torch211-cxx11-cu126-aarch64-linux/_layers/common.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ from .arguments import Arguments
7
+
8
+
9
def dtype(args: Arguments):
    """Map the precision flags on ``args`` to a torch dtype.

    ``fp16`` takes precedence over ``bf16``; ``None`` is returned when
    neither flag is set.
    """
    return torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else None)
15
+
16
+
17
def cast_if_autocast_enabled(tensor):
    """Cast ``tensor`` to the active autocast dtype for its device.

    Returns ``tensor`` unchanged when ``torch.is_autocast_enabled()`` is
    false. Raises ``NotImplementedError`` for devices other than CUDA/CPU.
    """
    if not torch.is_autocast_enabled():
        return tensor

    device_type = tensor.device.type
    if device_type == 'cuda':
        target_dtype = torch.get_autocast_gpu_dtype()
    elif device_type == 'cpu':
        target_dtype = torch.get_autocast_cpu_dtype()
    else:
        raise NotImplementedError()
    return tensor.to(dtype=target_dtype)
build/torch211-cxx11-cu126-aarch64-linux/_layers/dmlp_registry.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Union
5
+
6
+ from . import glu, mlp
7
+ from .arguments import Arguments
8
+
9
+ MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
10
+
11
+ _REGISTRY = {
12
+ 'mlp': {
13
+ 'grouped': mlp.GroupedMLP,
14
+ 'sparse': mlp.SparseMLP,
15
+ },
16
+ 'glu': {
17
+ 'grouped': glu.GroupedGLU,
18
+ 'sparse': glu.SparseGLU,
19
+ },
20
+ }
21
+
22
+
23
def get(args: Arguments) -> MlpType:
    """Instantiate the MLP used inside a dMoE layer.

    The class is looked up in ``_REGISTRY`` by ``args.mlp_type`` and then
    ``args.mlp_impl``; this registry only holds MLPs for the dropless
    (dMoE) variants.

    Args:
        args: propagated Arguments dataclass.

    Returns:
        The selected MLP class instantiated with ``args``.

    Raises:
        ValueError: if the mlp type is unknown or the requested backend is
            not registered for that type.
    """
    backends = _REGISTRY.get(args.mlp_type)
    if backends is None:
        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')

    mlp_cls = backends.get(args.mlp_impl)
    if mlp_cls is None:
        raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.')

    return mlp_cls(args)
build/torch211-cxx11-cu126-aarch64-linux/_layers/dmoe.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ # try:
8
+ # import stk.ops
9
+ # except ImportError:
10
+ # import warnings
11
+ # warnings.warn(
12
+ # 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
13
+ # )
14
+
15
+ # import megablocks.ops as ops
16
+ # # from megablocks.ops import ops
17
+ # from megablocks.layers import common, dmlp_registry, moe, mpu
18
+ # from megablocks.layers.arguments import Arguments
19
+
20
+ from .. import stk
21
+ from .. import ops
22
+ from . import common, dmlp_registry, moe, mpu
23
+ from .arguments import Arguments
24
+
25
def promote_scalar(x):
    """Return 0-d tensors as a one-element 1-d view; pass anything else through."""
    if len(x.size()) == 0:
        return x.view(1)
    return x
27
+
28
+
29
class ParallelDroplessMLP(moe.ParallelMLP):
    """Expert computation for dMoE.

    Tokens are gathered per expert, run through the MLP chosen by
    ``dmlp_registry.get`` and scattered back. ``args.mlp_impl`` selects the
    block-sparse ('sparse') path or the grouped path. Attributes such as
    ``num_experts``, ``top_k``, ``sort_end_bit`` and ``args`` and the method
    ``indices_and_bins`` are expected from the base class (not visible here).
    """

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128
        self.mlp = dmlp_registry.get(args)

        # Number of bits needed to represent the column indices of the
        # intermediate sparse matrix (never fewer than one).
        num_block_columns = (self.ffn_hidden_size * self.num_experts) // self.blocking
        index_bits = int(np.ceil(np.log2(num_block_columns)))
        self.transpose_sort_end_bit = max(index_bits, 1)

    def sparse_transpose(self, size, row_indices, column_indices, offsets):
        """Compute the transposed-topology metadata for a block-sparse matrix.

        Returns ``(column_indices_t, offsets_t, block_offsets_t)``.
        """
        block_columns = size[1] // self.blocking

        # Sort row indices by column indices to get the transposed matrix's
        # column indices.
        #
        # NOTE: Our sort operation uses the same width indices as the input
        # values. To avoid overflow when we have large activation matrices
        # we cast to 32-bit before sorting.
        _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit)

        # There are a constant number of blocks in every row of the sparse
        # matrix. A block's offset is:
        #
        #   row_index * blocks_per_row + column_index % blocks_per_row
        #
        # Once we have the block offsets ordered for transposition we can
        # divide by blocks_per_row to get the transposed column indices.
        column_indices_t = row_indices.gather(0, gather_indices.long())
        block_offsets_t = gather_indices.int()

        leading_zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
        column_ends = ops.inclusive_cumsum(ops.histogram(column_indices, block_columns), 0)
        if column_ends.dim() == 0:
            # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
            column_ends = column_ends.unsqueeze(0)
        offsets_t = torch.cat([leading_zero, column_ends])
        return column_indices_t, offsets_t, block_offsets_t

    def topology(self, x, padded_bins):
        """Build the ``stk.Matrix`` describing the sparse intermediate activations."""
        padded_tokens, _hidden = x.size()
        assert padded_tokens % self.blocking == 0
        if self.ffn_hidden_size % self.blocking != 0:
            raise ValueError(
                f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by '
                f'the block size {self.blocking}. Please update your configuration.',
            )

        # Offsets for the sparse matrix. All rows have the same number of
        # nonzero blocks dictated by the dimensionality of a single expert.
        block_rows = padded_tokens // self.blocking
        blocks_per_row = self.ffn_hidden_size // self.blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device,
        )

        # Indices for the sparse matrix. The indices for the intermediate
        # matrix are dynamic depending on the mapping of tokens to experts.
        column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row)

        # TODO(tgale): This is unused. Remove the need for this in stk.
        # For now, use meta init to save the device memory.
        data = torch.empty(
            column_indices.numel(),
            self.blocking,
            self.blocking,
            dtype=common.dtype(self.args),
            device='meta',
        )
        shape = (padded_tokens, self.ffn_hidden_size * mpu.experts_per_rank(self.args))
        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
            shape, row_indices, column_indices, offsets,
        )
        return stk.Matrix(
            shape,
            data,
            row_indices,
            column_indices,
            offsets,
            column_indices_t,
            offsets_t,
            block_offsets_t,
        )

    def indices_and_padded_bins(self, top_experts):
        """Return ``(indices, bin_ids, bins, padded_bins, tokens_per_expert)``."""
        # Sort the expert ids to produce the scatter/gather indices for the
        # permutation.
        top_experts = top_experts.int()
        bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)

        # Histogram the expert ids to identify the number of tokens routed
        # to each expert.
        tokens_per_expert = ops.histogram(top_experts, self.num_experts)

        # Round the token counts up to the block size used in the matrix
        # multiplications. Calculate the starting position of each bin.
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
        padded_bins = promote_scalar(ops.inclusive_cumsum(padded_tokens_per_expert, 0))

        # Calculate the bin bounds for the sorted tokens.
        bins = promote_scalar(ops.inclusive_cumsum(tokens_per_expert, 0))
        return indices, bin_ids, bins, padded_bins, tokens_per_expert

    def sparse_forward_once(self, x, expert_weights, top_experts):
        """Single-rank forward through the block-sparse path.

        Returns ``(output, tokens_per_expert)``.
        """
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            routing = self.indices_and_padded_bins(top_experts)
        indices, bin_ids, bins, padded_bins, tokens_per_expert = routing

        # Route the tokens for MoE computation.
        tokens = x.view(-1, x.shape[-1])
        gathered = ops.padded_gather(tokens, indices, bin_ids, bins, padded_bins, self.top_k)

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(gathered, padded_bins)

        # Perform the expert computation.
        expert_out = self.mlp(gathered, topo)

        # Un-route the data for the MoE output.
        output = ops.padded_scatter(
            expert_out,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            self.top_k,
        )
        return output, tokens_per_expert

    # For use in the base-class parallel_forward_once.
    def sparse_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):
        """Gather, run the sparse MLP and scatter for pre-computed routing data."""
        # Round the token counts up to the block size used in the matrix
        # multiplication. Calculate the starting position of each bin.
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
        padded_bins = promote_scalar(ops.inclusive_cumsum(padded_tokens_per_expert, 0))

        # Route the tokens for MoE computation.
        tokens = x.view(-1, x.shape[-1])
        gathered = ops.padded_gather(tokens, indices, bin_ids, bins, padded_bins, top_k)

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(gathered, padded_bins)

        # Perform the expert computation.
        expert_out = self.mlp(gathered, topo)

        # Un-route the data for the MoE output.
        return ops.padded_scatter(
            expert_out,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            top_k,
        )

    def grouped_forward_once(self, x, expert_weights, top_experts):
        """Single-rank forward through the grouped path.

        Returns ``(output, tokens_per_expert)``.
        """
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            routing = self.indices_and_bins(top_experts)
        indices, bin_ids, bins, tokens_per_expert = routing

        output = self.grouped_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            -1,  # unused
            self.args.moe_top_k,
        )
        return output, tokens_per_expert

    def grouped_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):
        """Gather, run the grouped MLP and scatter for pre-computed routing data."""
        # Route the tokens for MoE computation.
        tokens = x.view(-1, x.shape[-1])
        gathered = ops.gather(tokens, indices, bin_ids, bins, top_k)

        # Perform the expert computation.
        expert_out = self.mlp(gathered, tokens_per_expert)

        # Un-route the data for the MoE output.
        return ops.scatter(expert_out, indices, bin_ids, expert_weights, bins, top_k)

    def forward_once(self, x, expert_weights, top_experts):
        """Dispatch to the sparse or grouped forward according to ``args.mlp_impl``."""
        run = self.sparse_forward_once if self.args.mlp_impl == 'sparse' else self.grouped_forward_once
        return run(x, expert_weights, top_experts)

    def permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,
        top_k,
    ):
        """Dispatch to the sparse or grouped permute-and-compute implementation."""
        if self.args.mlp_impl == 'sparse':
            run = self.sparse_permute_and_compute
        else:
            run = self.grouped_permute_and_compute
        return run(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capactiy,
            top_k,
        )
332
+
333
+
334
class dMoE(moe.MoE):
    """``moe.MoE`` variant whose expert MLP is a ``ParallelDroplessMLP``."""

    def _init_experts_mlp(self, args: Arguments):
        """Return the expert MLP instance built from ``args``."""
        experts = ParallelDroplessMLP(args)
        return experts
build/torch211-cxx11-cu126-aarch64-linux/_layers/gelu.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # try:
5
+ # import stk
6
+ # except ImportError:
7
+ # import warnings
8
+ # warnings.warn(
9
+ # 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
10
+ # )
11
+
12
+ from .. import stk
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ @torch.jit.script
19
+ def _gelu_backward_inplace(g, x):
20
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
21
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
22
+ return g.mul_(ff)
23
+
24
+
25
def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
    """In-place GELU backward for sparse matrices, with a dense fallback.

    When both arguments are ``stk.Matrix`` the gradient data is scaled in
    place and wrapped in a new matrix that reuses ``x``'s index metadata;
    otherwise both arguments go straight to ``_gelu_backward_inplace``.
    """
    both_sparse = isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix)
    if not both_sparse:
        return _gelu_backward_inplace(grad, x)

    # NOTE: The two sparse matrices must have the same topology.
    scaled = _gelu_backward_inplace(grad.data, x.data)
    return stk.Matrix(
        x.size(),
        scaled,
        x.row_indices,
        x.column_indices,
        x.offsets,
        x.column_indices_t,
        x.offsets_t,
        x.block_offsets_t,
    )
39
+
40
+
41
def gelu(x: stk.Matrix):
    """Apply tanh-approximated GELU to the values of a sparse matrix.

    Returns a new ``stk.Matrix`` that shares ``x``'s size and index
    metadata and wraps the activated data.
    """
    assert isinstance(x, stk.Matrix)
    activated = F.gelu(x.data, approximate='tanh')
    return stk.Matrix(
        x.size(),
        activated,
        x.row_indices,
        x.column_indices,
        x.offsets,
        x.column_indices_t,
        x.offsets_t,
        x.block_offsets_t,
    )
build/torch211-cxx11-cu126-aarch64-linux/_layers/glu.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # import stk.ops
5
+ # try:
6
+ # import stk.ops
7
+ # except ImportError:
8
+ # import warnings
9
+ # warnings.warn(
10
+ # 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
11
+ # )
12
+
13
+ from .. import stk
14
+
15
+ import torch
16
+
17
+ # from megablocks import grouped_gemm_util as gg
18
+ # from megablocks.layers import common, mpu
19
+ # from megablocks.layers.activation_fn import act_fn
20
+ # from megablocks.layers.arguments import Arguments
21
+ # from megablocks.layers.mlp import (
22
+ # SharedMLP,
23
+ # SparseMLP,
24
+ # create_dmoe_expert_weights,
25
+ # resolve_dtensor,
26
+ # )
27
+
28
+ from .. import grouped_gemm_util as gg
29
+ from . import common, mpu
30
+ from .activation_fn import act_fn
31
+ from .arguments import Arguments
32
+ from .mlp import (
33
+ SharedMLP,
34
+ SparseMLP,
35
+ create_dmoe_expert_weights,
36
+ resolve_dtensor,
37
+ )
38
+
39
+
40
class SparseGLU(SparseMLP):
    """``SparseMLP`` extended with a second input projection ``v1`` (GLU)."""

    def __init__(self, args: Arguments):
        """Create ``v1`` with the same shape/dtype/device recipe as the parent's weights."""
        super().__init__(args)

        # Allocate the parameter before building the initial values.
        v1_storage = torch.empty(
            self._num_rows_per_rank,
            args.hidden_size,
            device=args.device,
            dtype=common.dtype(args),
        )
        self.v1 = torch.nn.Parameter(v1_storage)

        with torch.no_grad():
            initial_v1 = create_dmoe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.init_method,
            )
            self.v1.copy_(initial_v1)

        mpu.set_expert_model_parallel_attributes(
            self.v1,
            self._should_set_parallelism_attribute,
        )

    def forward(self, x, topo):
        """Compute ``dsd(act(sdd(x, w1^T)) * sdd(x, v1^T), w2)`` on topology ``topo``.

        Raises:
            NotImplementedError: if ``args.memory_optimized_mlp`` is set.
        """
        if self.args.memory_optimized_mlp:
            raise NotImplementedError(
                'Memory optimized implementation not yet supported with GLU with sparse kernels.',
            )

        # Apply gradient scaling to every weight first, then unwrap DTensors.
        scaled = [self.scale_grad(p) for p in (self.w1, self.v1, self.w2)]
        w1, v1, w2 = [resolve_dtensor(w) for w in scaled]

        # Compute the GLU: two projections of x that share one topology.
        gate_in = stk.ops.sdd(x, w1.t(), topo)
        linear_in = stk.ops.sdd(x, v1.t(), topo)

        gate_out = act_fn(gate_in, self.args.activation_fn)
        hidden = stk.ops.mul(gate_out, linear_in)

        return stk.ops.dsd(hidden, w2)
85
+
86
+
87
class MemoryOptimizedGroupedGLU(torch.autograd.Function):
    """Grouped GLU with manually scheduled memory reuse.

    The forward pass saves only the layer input and the two pre-activation
    projections; the backward pass recomputes the activation and writes
    intermediate gradients into buffers that already exist.
    """

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
        """Return ``gmm(activation_fn(gmm(x, w1^T)) * gmm(x, v1^T), w2)``.

        Raises:
            ValueError: if any of ``x``, ``w1``, ``v1``, ``w2`` is not contiguous.
        """
        # Cast inputs using ctx dtype from AMP.
        if ctx._fwd_used_autocast:
            x, w1, v1, w2 = (t.to(ctx._dtype) for t in (x, w1, v1, w2))

        # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
        if not all(t.is_contiguous() for t in (x, w1, v1, w2)):
            raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")

        # Layer 0: x @ w1.t() and x @ v1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
        v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)

        # Gated activation.
        activation_fn_out = activation_fn(sdd_out) * v1_out

        # Layer 1: activation output @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Only the layer input and the activation inputs are saved.
        # The activation forward is re-run in backward so that its output
        # does not have to be kept alive as another intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        """Return gradients ``(dx, dw1, dv1, dw2, None, None)``.

        Raises:
            ValueError: if any of the first three forward inputs does not need grad.
        """
        # NOTE(review): only inputs 0..2 (x, w1, v1) are checked here; w2 is
        # input 3 and is not covered by this guard - confirm this is intended.
        if not all(ctx.needs_input_grad[:3]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack in the order used by save_for_backward.
        w1, v1, w2, batch_sizes, x, sdd_out, v1_out = ctx.saved_tensors

        # Rematerialize activation_fn output with autograd enabled so that
        # its .backward can later produce d(sdd_out) and d(v1_out).
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            v1_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out) * v1_out
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: The activation_fn_out allocation is reused as the output
        # buffer (passed as ``c``).
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out and dv1_out by back-propagating through the
        # recomputed activation; the results land in the .grad fields.
        assert activation_grad_fn is not None
        activation_grad_fn(dactivation_fn_out)
        dsdd_out = sdd_out.grad
        dv1_out = v1_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dv1.
        dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        dx = ddsd_out
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
        dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
        return dx, dw1, dv1, dw2, None, None
189
+
190
+
191
+ memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
192
+
193
+
194
class GroupedGLU(SparseGLU):
    """GLU expert layer evaluated with grouped GEMMs instead of sparse kernels."""

    def forward(self, x, tokens_per_expert):
        """Run the GLU over ``x`` with per-expert group sizes ``tokens_per_expert``."""
        # Group sizes are moved to the CPU as int64 before being handed to gmm.
        batch_sizes = tokens_per_expert.cpu().to(torch.long)

        scaled = [self.scale_grad(p) for p in (self.w1, self.v1, self.w2)]
        local = [resolve_dtensor(w) for w in scaled]

        # Re-shape the weights for the grouped GEMMs: one leading slice per
        # local expert.
        ne = mpu.experts_per_rank(self.args)
        w1, v1, w2 = [w.view(ne, -1, self.args.hidden_size) for w in local]

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_glu(
                x,
                w1,
                v1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the GLU through the autograd-aware gmm op.
        assert gg.ops is not None
        gate_in = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        linear_in = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
        hidden = self.args.activation_fn(gate_in) * linear_in
        return gg.ops.gmm(hidden, w2, batch_sizes)
227
+
228
+
229
class SharedGLU(SharedMLP):
    """GLU for shared expert.

    Adds a ``gate_proj`` layer next to the parent's ``up_proj``/``down_proj``.

    Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
    """

    def __init__(self, args: Arguments):
        super().__init__(args)
        # Same constructor and keyword arguments as the parent's projections.
        make_layer = args.fc_cls
        self.gate_proj = make_layer(
            args.hidden_size,
            self.args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(act(gate_proj(x)) * up_proj(x))``."""
        gated = self.act(self.gate_proj(x))
        projected = self.up_proj(x)
        return self.down_proj(gated * projected)
build/torch211-cxx11-cu126-aarch64-linux/_layers/memory_test.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import gc
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # from megablocks.layers import arguments, dmoe
10
+ from . import arguments, dmoe
11
+
12
+ _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
13
+
14
+
15
def get_tensors():
    """Return the contiguous tensors known to the garbage collector.

    Tensors are de-duplicated by ``data_ptr()``: only the first tensor seen
    for a given address is kept. Non-contiguous tensors are skipped.
    """
    seen_addresses = set()
    unique = []
    for candidate in gc.get_objects():
        if not torch.is_tensor(candidate):
            continue
        if not candidate.is_contiguous():
            continue
        address = candidate.data_ptr()
        if address in seen_addresses:
            continue
        unique.append(candidate)
        seen_addresses.add(address)
    return unique
25
+
26
+
27
def test_memory(
    group,
    batch_size,
    sequence_length,
    hidden_size,
    ffn_hidden_size,
    num_experts,
    top_k,
):
    """Run one forward/backward of a ``dMoE`` layer and print memory statistics.

    Args:
        group: process group passed as ``expert_parallel_group``.
        batch_size, sequence_length, hidden_size: shape of the random input.
        ffn_hidden_size, num_experts, top_k: forwarded to ``Arguments``.

    Prints peak allocated/reserved CUDA memory, the memory taken by the
    router/expert weights and their gradients, and a per-tensor listing
    obtained from the garbage collector.

    NOTE: the printed figures are divided by 1e6 although the labels say
    ``MiB``; the labels are kept unchanged.
    """

    def nbytes(t):
        # Real storage size per element instead of assuming 2-byte values,
        # so non-16-bit tensors (e.g. integer indices) are counted correctly.
        return t.numel() * t.element_size()

    def grad_nbytes(param):
        return nbytes(param.grad) if param.grad is not None else 0

    args = arguments.Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_expert_model_parallelism=True,
        expert_parallel_group=group,
        fp16=False,
        bf16=True,
        device=torch.cuda.current_device(),
    )
    layer = dmoe.dMoE(args).cuda()

    x = torch.randn(
        (batch_size, sequence_length, hidden_size),
        device=torch.cuda.current_device(),
        dtype=torch.bfloat16,
    ).requires_grad_(True)
    torch.cuda.empty_cache()

    # Run forward + backward.
    # with torch.autograd.detect_anomaly():
    out, _ = layer(x)
    out.mean().backward()

    # Report peak memory.
    mem = torch.cuda.max_memory_allocated()
    print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
    print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6))

    # Calculate weight and gradient memory usage.
    params = (
        layer.router.layer.weight,
        layer.experts.mlp.w1,
        layer.experts.mlp.w2,
    )
    weight_memory = sum(nbytes(p) for p in params)
    weight_memory += sum(grad_nbytes(p) for p in params)

    print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
    print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6))

    # Manually calculate memory usage from the tensors the garbage
    # collector can see, largest (by element count) first.
    gc.collect()
    total = 0
    tensors = sorted(get_tensors(), key=lambda t: -t.numel())
    for i, t in enumerate(tensors):
        size = nbytes(t)
        total += size
        print(f'{i}: {t.shape}, {size}')
    del tensors

    print('Total Bytes Found = {:0.0f}MiB'.format(total / 1e6))
94
+
95
+
96
if __name__ == '__main__':
    # Script entry point: set up the NCCL process group, bind this process
    # to a CUDA device and run every configuration listed in _TESTS.
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    # NOTE(review): the rank returned here is used directly as the CUDA
    # device index - presumably a single-node launch; confirm.
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for config in _TESTS:
        test_memory(group, *config)
build/torch211-cxx11-cu126-aarch64-linux/_layers/mlp.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # try:
7
+ # import stk
8
+ # import stk.backend.triton_kernels
9
+ # import stk.ops
10
+ # except ImportError:
11
+ # import warnings
12
+ # warnings.warn(
13
+ # 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
14
+ # )
15
+
16
+ from .. import stk
17
+
18
+ import torch
19
+ from packaging import version
20
+
21
+ # from megablocks import grouped_gemm_util as gg
22
+ # from megablocks.layers import common, gelu, mpu
23
+ # from megablocks.layers.activation_fn import act_fn
24
+ # from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
25
+
26
+ from .. import grouped_gemm_util as gg
27
+ from . import common, gelu, mpu
28
+ from .activation_fn import act_fn
29
+ from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
30
+
31
+ class ScaleGradient(torch.autograd.Function):
32
+
33
+ @staticmethod
34
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
35
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
36
+ ctx.scale = scale
37
+ return x
38
+
39
+ @staticmethod
40
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
41
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
42
+ return grad * ctx.scale, None
43
+
44
+
45
+ scale_gradient = ScaleGradient.apply
46
+
47
+
48
def resolve_dtensor(weight: torch.Tensor):
    """Return ``weight.to_local()`` if ``weight`` is a ``DTensor``, else ``weight``.

    The ``DTensor`` check is only attempted on torch >= 2.0.0; on older
    versions the input is returned untouched.
    """
    if version.parse(torch.__version__) < version.parse('2.0.0'):
        return weight

    # Imported lazily so that older torch versions never hit this import.
    from torch.distributed._tensor import DTensor
    return weight.to_local() if isinstance(weight, DTensor) else weight
54
+
55
+
56
def create_moe_expert_weights(
    args: Arguments,
    num_experts: int,
    ffn_hidden_size: int,
    hidden_size: int,
    init_method: InitFn,
):
    """Create initialized expert weights and return this rank's portion.

    A full ``(num_experts, ffn_hidden_size, hidden_size)`` tensor is always
    allocated and initialized, so that the sampled weights do not vary
    between data parallelism and expert model parallelism for the same
    random seed. Without expert model parallelism the full tensor is
    returned; otherwise a slice over the expert and row dimensions is
    returned (a view into the full tensor).
    """
    full_weights = torch.empty(
        num_experts,
        ffn_hidden_size,
        hidden_size,
        device=args.device,
        dtype=common.dtype(args),
    )
    init_method(full_weights)

    if not args.moe_expert_model_parallelism:
        return full_weights

    # Amount of sharding in each dimension.
    expert_shards = mpu.expert_sharding_degree(args)
    hidden_shards = mpu.hidden_sharding_degree(args)

    # NOTE: Ranks are assigned to be expert parallel before going tensor
    # parallel, so the expert index cycles fastest with the rank.
    rank = mpu.get_expert_parallel_rank(args)

    experts_here = num_experts // expert_shards
    first_expert = (rank % expert_shards) * experts_here

    rows_here = ffn_hidden_size // hidden_shards
    first_row = (rank // expert_shards) * rows_here

    # Slice the weight matrix to get the chunk for this rank.
    with torch.no_grad():
        return full_weights[
            first_expert:first_expert + experts_here,
            first_row:first_row + rows_here,
        ]
102
+
103
+
104
class MLP(torch.nn.Module):
    """Per-expert two-layer MLP evaluated with batched matmuls (``torch.bmm``)."""

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
        experts_per_rank = mpu.experts_per_rank(args)

        def _empty_param(*shape):
            # Uninitialized parameter on the configured device/dtype.
            return torch.nn.Parameter(
                torch.empty(
                    *shape,
                    device=args.device,
                    dtype=common.dtype(args),
                ),
            )

        self.w1 = _empty_param(
            experts_per_rank,
            args.hidden_size,
            mpu.features_per_rank(args),
        )
        self.w2 = _empty_param(
            experts_per_rank,
            mpu.features_per_rank(args),
            args.hidden_size,
        )
        for weight in (self.w1, self.w2):
            mpu.set_expert_model_parallel_attributes(
                weight,
                args.moe_expert_model_parallelism,
            )

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that the weight tensors are created before
        # the master weights are built and this rank's piece is sliced out.
        # If the master weights are created first, the PyTorch caching
        # allocator appears to use the same memory block for these and the
        # slice, which causes large increases in peak memory usage.
        with torch.no_grad():
            initial_w1 = create_moe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.init_method,
            )
            # w1 is stored with its last two dimensions swapped.
            self.w1.copy_(initial_w1.transpose(1, 2).contiguous())

            initial_w2 = create_moe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.output_layer_init_method,
            )
            self.w2.copy_(initial_w2)

        # Gradients are scaled by 1 / world size only under expert parallelism.
        self.gradient_scale = (
            1 / mpu.get_expert_parallel_world_size(self.args)
            if self.args.moe_expert_model_parallelism else None
        )

    def scale_grad(self, w):
        """Wrap ``w`` so its gradient is scaled by ``gradient_scale`` (no-op if unset)."""
        return w if self.gradient_scale is None else scale_gradient(w, self.gradient_scale)

    def forward(self, x):
        """Return ``bmm(activation_fn(bmm(x, w1)), w2)``."""
        scaled_w1 = self.scale_grad(self.w1)
        scaled_w2 = self.scale_grad(self.w2)
        w1 = resolve_dtensor(scaled_w1)
        w2 = resolve_dtensor(scaled_w2)
        hidden = self.args.activation_fn(torch.bmm(x, w1))
        return torch.bmm(hidden, w2)
181
+
182
+
183
def create_dmoe_expert_weights(
    args: Arguments,
    num_experts: int,
    rows: int,
    columns: int,
    init_method: InitFn,
):
    """Create expert weights and flatten them to a 2-D ``(-1, columns)`` view.

    Delegates to ``create_moe_expert_weights`` with ``rows`` and ``columns``
    as the per-expert matrix dimensions.
    """
    per_expert = create_moe_expert_weights(
        args,
        num_experts,
        rows,
        columns,
        init_method,
    )
    flattened = per_expert.view([-1, columns])
    return flattened
198
+
199
+
200
class MemoryOptimizedMLP(torch.autograd.Function):
    """Sparse MLP with manually scheduled memory reuse.

    The forward pass saves the layer input and the activation input only;
    the backward pass recomputes the activation and writes intermediate
    gradients into buffers that already exist.
    """

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, topo, activation_fn):
        """Return ``dsd(act_fn(sdd(x, w1^T, topo)), w2)``.

        Raises:
            ValueError: if ``x``, ``w1`` or ``w2`` is not contiguous.
        """
        # Cast inputs using ctx dtype from AMP.
        if ctx._fwd_used_autocast:
            x, w1, w2 = (t.to(ctx._dtype) for t in (x, w1, w2))

        # x: [m, k], w1: [n, k], w2: [n, k]
        if not all(t.is_contiguous() for t in (x, w1, w2)):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        # Index/offset tensors of the topology, in the positional order
        # that stk.Matrix is constructed with in backward.
        topo_tensors = (
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
            topo.column_indices_t,
            topo.offsets_t,
            topo.block_offsets_t,
        )

        # Layer 0: x @ w1.t().
        sdd_out = stk.ops.sdd(x, w1.t(), topo)

        # Activation.
        activation_fn_out = act_fn(sdd_out, activation_fn)

        # Layer 1: activation output @ w2.
        dsd_out = stk.ops.dsd(activation_fn_out, w2)

        # NOTE: Only the layer input and the activation input are saved.
        # The activation forward is re-run in backward so that its output
        # does not have to be kept alive as another intermediate.
        ctx.shape = topo.shape
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.data.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        """Return gradients ``(dx, dw1, dw2, None, None)``.

        Raises:
            ValueError: if ``x``, ``w1`` or ``w2`` does not need grad.
        """
        if not all(ctx.needs_input_grad[:3]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack in the order used by save_for_backward:
        # two weights, six topology tensors, input, activation input.
        w1, w2, *topo_tensors, x, sdd_out_data = ctx.saved_tensors

        # Rematerialize the activation function output.
        activation_fn = ctx.activation_fn
        sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
        activation_fn_out, activation_grad_fn = act_fn(
            sdd_out,
            activation_fn,
            return_grad_fn=True,
        )

        # Compute dw2 with recomputed activation_fn output.
        dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)

        # Compute dactivation_fn_out.
        #
        # NOTE: The activation_fn_out allocation is reused: its data
        # tensor is passed to the kernel as the output.
        dactivation_fn_out = activation_fn_out
        stk.backend.triton_kernels.sdd(
            ddsd_out,
            w2.t(),
            dactivation_fn_out.shape,
            dactivation_fn_out.data,
            dactivation_fn_out.offsets,
            dactivation_fn_out.row_indices,
            dactivation_fn_out.column_indices,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            # Default activation: closed-form in-place backward.
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            # Any other activation: back-propagate through the recomputed
            # activation and read the result from .grad.
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out.data)
            dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)

        # Compute dw1.
        dw1 = stk.ops.dsd(dsdd_out.t(), x)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation as the kernel output.
        stk.backend.triton_kernels.dsd(
            dsdd_out.shape,
            dsdd_out.data,
            dsdd_out.offsets,
            dsdd_out.row_indices,
            dsdd_out.column_indices,
            dsdd_out.offsets_t,
            dsdd_out.column_indices_t,
            dsdd_out.block_offsets_t,
            False,
            w1,
            ddsd_out,
        )
        dx = ddsd_out
        return dx, dw1, dw2, None, None
316
+
317
+
318
+ memory_optimized_mlp = MemoryOptimizedMLP.apply
319
+
320
+
321
class SparseMLP(torch.nn.Module):
    """Expert MLP whose weights are stored as 2-D matrices and applied with stk ops."""

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

        def _empty_weight():
            # Uninitialized (rows_per_rank, hidden_size) parameter.
            return torch.nn.Parameter(
                torch.empty(
                    self._num_rows_per_rank,
                    args.hidden_size,
                    device=args.device,
                    dtype=common.dtype(args),
                ),
            )

        self.w1 = _empty_weight()
        self.w2 = _empty_weight()

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that the weight tensors are created before
        # the master weights are built and this rank's piece is sliced out.
        # If the master weights are created first, the PyTorch caching
        # allocator appears to use the same memory block for these and the
        # slice, which causes large increases in peak memory usage.
        with torch.no_grad():
            initial_w1 = create_dmoe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.init_method,
            )
            self.w1.copy_(initial_w1)

            initial_w2 = create_dmoe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.output_layer_init_method,
            )
            self.w2.copy_(initial_w2)

        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
        for weight in (self.w1, self.w2):
            mpu.set_expert_model_parallel_attributes(
                weight,
                self._should_set_parallelism_attribute,
            )

        # Gradients are scaled by 1 / world size only under expert parallelism.
        self.gradient_scale = (
            1 / mpu.get_expert_parallel_world_size(self.args)
            if self.args.moe_expert_model_parallelism else None
        )

    def scale_grad(self, w):
        """Wrap ``w`` so its gradient is scaled by ``gradient_scale`` (no-op if unset)."""
        return w if self.gradient_scale is None else scale_gradient(w, self.gradient_scale)

    def forward(self, x, topo):
        """Return ``dsd(act_fn(sdd(x, w1^T, topo)), w2)``.

        Uses the memory-optimized autograd function when
        ``args.memory_optimized_mlp`` is set.
        """
        scaled_w1 = self.scale_grad(self.w1)
        scaled_w2 = self.scale_grad(self.w2)
        w1 = resolve_dtensor(scaled_w1)
        w2 = resolve_dtensor(scaled_w2)

        if self.args.memory_optimized_mlp:
            return memory_optimized_mlp(
                x,
                w1,
                w2,
                topo,
                self.args.activation_fn,
            )

        # Compute the MLP.
        hidden = stk.ops.sdd(x, w1.t(), topo)
        activation_fn_out = act_fn(hidden, self.args.activation_fn)
        return stk.ops.dsd(activation_fn_out, w2)
408
+
409
+
410
class MemoryOptimizedGroupedMLP(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse.

    The forward pass saves the layer input and the activation input only;
    the backward pass recomputes the activation and writes intermediate
    gradients into buffers that already exist.
    """

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
        """Return ``gmm(activation_fn(gmm(x, w1^T)), w2)``.

        Raises:
            ValueError: if ``x``, ``w1`` or ``w2`` is not contiguous.
        """
        # Cast inputs using ctx dtype from AMP.
        if ctx._fwd_used_autocast:
            x, w1, w2 = (t.to(ctx._dtype) for t in (x, w1, w2))

        # x: [m, k], w1: [n, k], w2: [n, k]
        if not all(t.is_contiguous() for t in (x, w1, w2)):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)

        # Activation.
        activation_fn_out = activation_fn(sdd_out)

        # Layer 1: activation output @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Only the layer input and the activation input are saved.
        # The activation forward is re-run in backward so that its output
        # does not have to be kept alive as another intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: Any, ddsd_out: torch.Tensor):
        """Return gradients ``(dx, dw1, dw2, None, None)``.

        Raises:
            ValueError: if ``x``, ``w1`` or ``w2`` does not need grad.
        """
        if not all(ctx.needs_input_grad[:3]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack in the order used by save_for_backward.
        w1, w2, batch_sizes, x, sdd_out = ctx.saved_tensors

        # Rematerialize activation_fn output with autograd enabled so that
        # its .backward is available for non-default activations.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out)
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: The activation_fn_out allocation is reused as the output
        # buffer (passed as ``c``).
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            # Default activation: closed-form in-place backward.
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            # Any other activation: back-propagate through the recomputed
            # activation and read the result from .grad.
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out)
            dsdd_out = sdd_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
        dx = ddsd_out
        return dx, dw1, dw2, None, None
507
+
508
+
509
+ memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
510
+
511
+
512
class GroupedMLP(SparseMLP):
    """Expert MLP evaluated with grouped GEMMs instead of sparse kernels."""

    def forward(self, x, tokens_per_expert):
        """Run the MLP over ``x`` with per-expert group sizes ``tokens_per_expert``."""
        # Group sizes are moved to the CPU as int64 before being handed to gmm.
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        scaled_w1 = self.scale_grad(self.w1)
        scaled_w2 = self.scale_grad(self.w2)

        # Re-shape the weights for the grouped GEMMs: one leading slice per
        # local expert.
        ne = mpu.experts_per_rank(self.args)
        w1 = resolve_dtensor(scaled_w1).view(ne, -1, self.args.hidden_size)
        w2 = resolve_dtensor(scaled_w2).view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_mlp(
                x,
                w1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the MLP through the autograd-aware gmm op.
        assert gg.ops is not None
        hidden = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        hidden = self.args.activation_fn(hidden)
        return gg.ops.gmm(hidden, w2, batch_sizes)
537
+
538
+
539
class SharedMLP(torch.nn.Module):
    """MLP for shared expert.

    Builds ``up_proj`` and ``down_proj`` with ``args.fc_cls`` and applies
    ``args.activation_fn`` between them.

    Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
    """

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args

        # Keyword arguments shared by both projections; entries from
        # args.fc_kwargs override the two defaults below.
        self.fc_kwargs: dict[str, Any] = {
            'bias': args.bias,
            'device': args.device,
        }
        self.fc_kwargs.update(args.fc_kwargs)

        make_layer = args.fc_cls
        self.up_proj = make_layer(
            args.hidden_size,
            args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )
        self.act = args.activation_fn
        self.down_proj = make_layer(
            args.shared_expert_hidden_size,
            args.hidden_size,
            **self.fc_kwargs,
        )
        self.down_proj._is_residual = True  # a flag for llm-foundry init

    def add_experts_sharedexpert(
        self,
        shared_expert_out: torch.Tensor,
        expert_out: torch.Tensor,
    ) -> torch.Tensor:
        """Combine the shared-expert output with the routed-expert output.

        With ``args.shared_expert_weighted_sum`` the result is
        ``shared / (top_k + 1) + expert * top_k / (top_k + 1)``;
        otherwise it is the plain sum.
        """
        if not self.args.shared_expert_weighted_sum:
            return shared_expert_out + expert_out

        # Weighted sum, weighted by the number of experts used
        # (top_k routed experts plus the shared one).
        t_experts = self.args.moe_top_k + 1
        expert_weight = self.args.moe_top_k / t_experts
        sh_mlp_out = shared_expert_out / t_experts
        return sh_mlp_out.add(expert_out, alpha=expert_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(act(up_proj(x)))``."""
        hidden = self.up_proj(x)
        hidden = self.act(hidden)
        return self.down_proj(hidden)
build/torch211-cxx11-cu126-aarch64-linux/_layers/moe.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # import megablocks.ops as ops
10
+ # from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
11
+ # from megablocks.layers.all_to_all import all_to_all
12
+ # from megablocks.layers.arguments import Arguments
13
+
14
+ from ..ops import (
15
+ sort,
16
+ histogram,
17
+ inclusive_cumsum,
18
+ exclusive_cumsum,
19
+ binned_gather,
20
+ binned_scatter,
21
+ gather,
22
+ scatter,
23
+ repeat,
24
+ replicate,
25
+ )
26
+
27
+ from . import common, mlp, mpu, router, sharedexpert_registry
28
+ from .arguments import Arguments
29
+ from .all_to_all import all_to_all
30
+
31
# Module-level accumulator of the entries recorded via
# `save_load_balancing_loss`. The list object itself is never rebound, only
# mutated in place, so references handed out by `get_load_balancing_loss`
# stay valid.
_LOAD_BALANCING_LOSS = []


def save_load_balancing_loss(loss):
    """Append ``loss`` to the module-level load-balancing accumulator."""
    _LOAD_BALANCING_LOSS.append(loss)


def get_load_balancing_loss():
    """Return the accumulator list itself (not a copy)."""
    return _LOAD_BALANCING_LOSS


def clear_load_balancing_loss():
    """Remove every recorded entry, emptying the accumulator in place."""
    _LOAD_BALANCING_LOSS.clear()
47
+
48
+
49
def batched_load_balancing_loss(args: Arguments):
    """Compute the load-balancing loss over every saved layer contribution.

    Reads the ``(tokens_per_expert, expert_scores)`` pairs returned by
    ``get_load_balancing_loss()`` and combines them into one scalar:

        loss_weight * num_experts / (num_layers * tokens * top_k)
            * dot(tokens_per_expert, mean_over_tokens(expert_scores))

    Args:
        args: Configuration. Reads ``moe_loss_weight``, ``moe_num_experts``,
            ``moe_top_k``, ``moe_lbl_in_fp32``, ``num_layers``,
            ``pipeline_model_parallel_size`` and
            ``num_layers_per_virtual_pipeline_stage``.

    Returns:
        ``0.0`` when ``args.moe_loss_weight`` is zero, otherwise the tensor
        produced by the scaled ``torch.dot``.

    Raises:
        ValueError: If the number of saved entries differs from the number
            of layers expected for this pipeline stage.
    """
    if args.moe_loss_weight == 0:
        return 0.0

    # tokens_per_expert[i].shape = (num_experts)
    # expert_scores[i].shape = (tokens, num_experts)
    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
    num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
    if args.num_layers_per_virtual_pipeline_stage is not None:
        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage

    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
        raise ValueError(
            f'Expected {num_layers_per_pipeline_stage} token_per_experts '
            f'but found {len(tokens_per_expert)}.\nnum_layers = '
            f'{args.num_layers}\npipeline_model_parallel_size = '
            f'{args.pipeline_model_parallel_size}\n'
            'num_layers_per_virtual_pipeline_stage'
            f' = {args.num_layers_per_virtual_pipeline_stage}',
        )
    if len(expert_scores) != num_layers_per_pipeline_stage:
        # Report the length of the sequence actually being checked.
        raise ValueError(
            f'Expected {num_layers_per_pipeline_stage} expert_scores '
            f'but found {len(expert_scores)}.\nnum_layers = '
            f'{args.num_layers}\npipeline_model_parallel_size = '
            f'{args.pipeline_model_parallel_size}\n'
            'num_layers_per_virtual_pipeline_stage'
            f' = {args.num_layers_per_virtual_pipeline_stage}',
        )

    # Verify the shape of the tokens_per_expert and expert_scores tensors.
    assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))

    tokens = expert_scores[0].shape[0]
    assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))

    # Concatenate the contributions of each layer and convert to
    # the correct types and formats for the dot product.
    expert_scores = torch.cat(expert_scores, dim=1)
    if args.moe_lbl_in_fp32:
        expert_scores = expert_scores.float()
    if tokens != 0:
        expert_scores = expert_scores.mean(dim=0)
    else:
        # A mean over zero rows would be NaN; the sum is all zeros instead.
        expert_scores = expert_scores.sum(dim=0)
    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)

    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
    assert tokens_per_expert.numel() == expected_values
    assert expert_scores.numel() == expected_values

    # Calculate the total scale across all factors.
    #
    # loss_weight * num_experts / (num_layers * tokens * top_k)
    #
    # With zero tokens the scores were summed to zeros above, so the dot
    # product is zero. Clamp the token count so the scale stays finite
    # instead of raising ZeroDivisionError.
    scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
    scale_denominator = (args.num_layers * max(tokens, 1) * args.moe_top_k)
    scale = scale_numerator / scale_denominator
    return scale * torch.dot(tokens_per_expert, expert_scores)
107
+
108
+
109
+ # NOTE: This class defines MoE expert computation, including expert model parallel
110
+ # communication. When using FSDP on top of MegaBlocks this is the module that should
111
+ # be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
112
+ # parallel all2all.
113
class ParallelMLP(torch.nn.Module):
    """Expert computation for an MoE layer.

    Routes tokens to the expert MLP, either locally (``forward_once``) or
    across the expert parallel group (``parallel_forward_once``), depending on
    ``args.moe_expert_model_parallelism``.
    """

    def __init__(self, args: Arguments):
        """Build the expert MLP, the optional output bias and pick the forward path."""
        super(ParallelMLP, self).__init__()
        self.args = args

        # Total number of experts and the number of experts each token is
        # routed to.
        self.num_experts = args.moe_num_experts
        self.top_k = self.args.moe_top_k

        # Calculate the number of bits needed to represent the expert indices
        # so that we can pass it to radix sort.
        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)

        # Expert MLP.
        self.mlp = mlp.MLP(args)

        self.bias: Optional[torch.Tensor]
        if self.args.bias:
            # Note that the output bias is not parallelized with expert
            # model parallelism.
            self.bias = torch.nn.Parameter(
                torch.empty(
                    args.hidden_size,
                    device=args.device,
                    dtype=common.dtype(args),
                ),
            )
            torch.nn.init.zeros_(self.bias)
        else:
            self.register_parameter('bias', None)

        # Select the forward function for the operating mode.
        self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)

    def expert_capacity(self, tokens: int) -> int:
        """Return ``int(moe_capacity_factor * top_k * tokens * world_size / num_experts)``.

        Callers treat a result of zero as "no fixed capacity" and substitute
        the largest observed per-expert token count.
        """
        world_size = mpu.get_expert_parallel_world_size(self.args)
        tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
        return int(self.args.moe_capacity_factor * tokens_per_expert)

    def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
        """Calculate the load balancing loss contribution.

        ``expert_scores`` must be 2-D with ``num_experts`` columns and
        ``tokens_per_expert`` must be 1-D with ``num_experts`` entries; both
        are asserted. The result is
        ``num_experts / (tokens * top_k) * dot(tokens_per_expert, mean(expert_scores, dim=0))``.
        """
        assert len(expert_scores.size()) == 2
        tokens, num_experts = expert_scores.size()
        assert num_experts == self.num_experts
        assert len(tokens_per_expert.size()) == 1
        num_experts, = tokens_per_expert.size()
        assert num_experts == self.num_experts
        scale = self.num_experts / (tokens * self.top_k)
        return scale * torch.dot(
            tokens_per_expert.to(expert_scores.dtype),
            expert_scores.mean(dim=0),
        )

    def indices_and_bins(self,
                         top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(indices, bin_ids, bins, tokens_per_expert)`` for ``top_expert``.

        ``bin_ids`` and ``indices`` are the two outputs of ``sort``,
        ``tokens_per_expert`` is the ``histogram`` of the expert ids and
        ``bins`` is its inclusive cumulative sum.
        """
        # Sort the expert ids to produce the scatter/gather
        # indices for the permutation.
        #
        # TODO(tgale): Is it worth doing this conversion to 32-bit
        # prior? Could we place the `torch.max` operation to return
        # 32-bit expert indices?
        top_expert = top_expert.int()
        output = sort(top_expert, self.sort_end_bit)
        assert output is not None
        bin_ids, indices = output

        # Histogram the expert ids to identify the number of
        # tokens routed to each expert.
        #
        # TODO(tgale): Does the sorted data produce a more favorable
        # data distribution for histogram? Or is the op parallelism
        # worth more?
        tokens_per_expert = histogram(top_expert, self.num_experts)

        # Calculate the bin bounds for the sorted tokens.
        bins = inclusive_cumsum(tokens_per_expert, 0)
        assert bins is not None
        # Promote a 0-dim result to a 1-element tensor.
        bins = bins.view(1) if not len(bins.size()) else bins

        assert isinstance(indices, torch.Tensor)
        assert isinstance(bin_ids, torch.Tensor)
        assert isinstance(bins, torch.Tensor)
        assert isinstance(tokens_per_expert, torch.Tensor)

        return indices, bin_ids, bins, tokens_per_expert

    def permute_and_compute(
        self,
        x: torch.Tensor,
        tokens_per_expert: torch.Tensor,  # unused
        indices: torch.Tensor,
        bin_ids: torch.Tensor,  # unused
        expert_weights: Optional[torch.Tensor],
        bins: torch.Tensor,
        expert_capacity: int,
        top_k: int,
    ):
        """Gather tokens into per-expert bins, run the expert MLP and scatter back.

        ``tokens_per_expert`` and ``bin_ids`` are accepted but not read here.
        ``expert_weights`` is forwarded to ``binned_scatter`` unchanged; the
        expert-parallel path passes ``None``.
        """
        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        output = binned_gather(x, indices, bins, expert_capacity, top_k)
        assert output is not None
        x = output

        # Perform the expert computation. Note that we don't
        # use biases for these linear operations.
        x = self.mlp(x)

        # Un-route the data for the MoE output.
        return binned_scatter(x, indices, expert_weights, bins, top_k)

    def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
        """Single-device expert computation; returns ``(output, tokens_per_expert)``."""
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))

            # If expert_capacity is set to zero, set the number of tokens
            # per expert to the maximum we need to avoid dropping tokens.
            sl, bs, _ = x.size()
            expert_capacity = self.expert_capacity(sl * bs)
            if expert_capacity == 0:
                expert_capacity = torch.max(tokens_per_expert).item()

        x = self.permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capacity,
            self.top_k,
        )
        return x, tokens_per_expert

    def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
        """Expert-parallel expert computation; returns ``(output, tokens_per_expert.flatten())``."""
        # NOTE: This function implements the same computation as forward_once
        # but with expert model parallelism.
        #
        # 1. Permute the tokens locally so that they are grouped by their
        # expert assignments. This allows us to transfer all of the tokens
        # for a remote device in one communication primitive.
        #
        # 2. Permute the tokens across the expert parallel devices. After
        # this is completed each device has all of the tokens assigned to
        # its set of experts in its local HBM.
        #
        # 3. Permute the tokens locally so that they are grouped by their
        # expert assignment. After the distributed permutation the tokens
        # are grouped by which device they came from. We re-order them
        # locally to allow for efficient computation.
        #
        # After this series of permutations we compute the linear layers
        # and then repeat these three steps in reverse to produce the final
        # output.
        #
        # Compute the mapping of local tokens to experts.
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))

            # If we're sharding the experts along the hidden dimension
            # multiple devices own parts of the same sets of experts.
            # Replicate the token counts so every device gets the counts.
            repeated_tokens_per_expert = repeat(
                tokens_per_expert,
                (mpu.hidden_sharding_degree(self.args),),
            )

            # Pass token count information to the device on which the
            # target expert resides. The exchange is asynchronous; the
            # handle is waited on below, after the local gather is issued.
            parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
            tpe_handle = dist.all_to_all_single(
                parallel_tokens_per_expert,
                repeated_tokens_per_expert,
                group=self.args.expert_parallel_group,
                async_op=True,
            )

        # Permute locally and without any padding so that tokens for each
        # parallel device are stored contiguously.
        #
        # This view updates the shape of the tensor from [sl, bs, hs] to
        # [sl * bs, hs] prior to the permutation.
        x = x.view(-1, x.shape[-1])
        output = gather(x, indices, bin_ids, bins, self.top_k)
        assert output is not None
        x = output

        # Compute the number of tokens that will be received from each
        # device and permute the input data across the devices.
        with torch.no_grad():
            tpe_handle.wait()
            experts_per_rank = mpu.experts_per_rank(self.args)

            # Reshape to [world_size, num_experts_per_rank].
            world_size = mpu.get_expert_parallel_world_size(self.args)
            repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
            parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))

            # TODO(tgale): It might be faster to do this on the GPU and
            # then communicate the results back to the host.
            send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
            parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
            recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)

            # Convert the send/recv counts to lists.
            send_counts = send_counts.tolist()
            recv_counts = recv_counts.tolist()
            tokens_received = sum(recv_counts)

        # If we're sharding the experts along the hidden dimension
        # multiple devices own parts of the same sets of experts.
        # Replicate the token counts so devices that share experts
        # get all of the tokens assigned to them.
        #
        # TODO(tgale): Fuse this into the prior, local permutation.
        x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))

        # Start the cross-device permutation asynchronously so we can
        # overlap communication with computation.
        parallel_x, parallel_x_handle = all_to_all(
            x,
            recv_counts,
            send_counts,
            self.args.expert_parallel_group,
            async_op=True,
        )

        with torch.no_grad():
            # After we do the cross-device permutation we have the tokens on the
            # correct device but not yet grouped by expert because we received
            # tokens from each device as contiguous chunks. To group the tokens
            # for expert computation we'll do one more local permutation. The
            # rest of this torch.no_grad() scope sets up the indices and bins
            # for this permutation.
            replicate_bins = inclusive_cumsum(
                parallel_tokens_per_expert.flatten(),
                0,
            )
            replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)

            # Construct the expert indices for the permuted tokens.
            parallel_top_expert = torch.remainder(
                torch.arange(
                    self.num_experts * mpu.hidden_sharding_degree(self.args),
                    dtype=torch.int32,
                    device=indices.device,
                ),
                mpu.experts_per_rank(self.args),
            )
            parallel_top_expert = replicate(
                parallel_top_expert.unsqueeze(dim=0),
                replicate_bins,
                tokens_received,
            ).flatten()

            # TODO(tgale): The sort_end_bit here can be reduced.
            parallel_bin_ids, parallel_indices = sort(
                parallel_top_expert,
                self.sort_end_bit,
            )

            # Calculate the bins boundaries from the token counts.
            parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
                dim=0,
                dtype=torch.int,
            )
            parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
            parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)

            # If expert_capacity is set to zero, set the number of tokens
            # per expert to the maximum we need to avoid dropping tokens.
            tokens, _ = x.size()
            expert_capacity = self.expert_capacity(tokens)
            if expert_capacity == 0:
                expert_capacity = torch.max(parallel_tokens_per_expert).item()

        # Locally permute the tokens and perform the expert computation.
        # Block to make sure that the cross-device permutation is complete.
        if self.args.mlp_impl == 'grouped':
            # GroupedMLP requires counts on CPU. We can use the tensor already
            # moved to CPU for the prior all_to_all, which avoids an extra
            # device synchronization.
            parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
                dim=0,
                dtype=torch.int,
            )
        parallel_x_handle.wait()
        parallel_x = self.permute_and_compute(
            parallel_x,
            parallel_tokens_per_expert,
            parallel_indices,
            parallel_bin_ids,
            None,  # expert_weights
            parallel_bins,
            expert_capacity,
            top_k=1,
        )

        # Un-permute the tokens across the devices. The send/recv counts are
        # swapped relative to the forward permutation.
        x, _ = all_to_all(
            parallel_x,
            send_counts,
            recv_counts,
            self.args.expert_parallel_group,
        )

        # Reduce along the hidden sharding to get the final outputs.
        #
        # TODO(tgale): Fuse this into the following local permutation.
        shape = (
            mpu.hidden_sharding_degree(self.args),
            -1,
            self.args.hidden_size,
        )
        x = x.view(shape).sum(dim=0)

        # Un-permute locally to setup for the next series of operations.
        x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
        return x, tokens_per_expert.flatten()

    def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
        """Run the selected forward path and restore the input shape.

        While training with ``moe_loss_weight > 0`` the pair
        ``(tokens_per_expert, scores)`` is recorded via
        ``save_load_balancing_loss``. When a bias exists it is either
        returned alongside the output (``args.return_bias``) or added to it.
        """
        in_shape = x.size()

        # Compute the experts.
        x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
        if self.training and self.args.moe_loss_weight > 0:
            save_load_balancing_loss((tokens_per_expert, scores))
        x = x.view(in_shape)
        if self.bias is not None:
            if self.args.return_bias:
                return x, self.bias
            return x + self.bias
        return x
470
+
471
+
472
class MoE(torch.nn.Module):
    """Mixture-of-experts layer: a learned router, the experts and an optional shared expert."""

    def __init__(self, args: Arguments):
        super().__init__()

        # Token router, built first so submodule creation order is unchanged.
        self.router = router.LearnedRouter(args)

        # Expert computation helper.
        self.experts = self._init_experts_mlp(args)

        # Shared expert helper; stays None unless enabled in the arguments.
        self.shared_expert = sharedexpert_registry.get(args) if args.shared_expert else None

    def _init_experts_mlp(self, args: Arguments):
        """Create the expert computation module."""
        return ParallelMLP(args)

    def forward(self, x: torch.Tensor):
        """Route ``x`` through the experts and merge in the shared expert output, if any."""
        # NOTE: Cast the activations to lower precision (when autocast is
        # enabled) before the tokens are permuted, to save bandwidth.
        x = common.cast_if_autocast_enabled(x)

        # Expert scores and assignments from the router.
        scores, expert_weights, top_experts = self.router(x)

        routed_out = self.experts(x, scores, expert_weights, top_experts)
        if self.shared_expert is None:
            return routed_out

        shared_out = self.shared_expert(x)
        return self.shared_expert.add_experts_sharedexpert(shared_out, routed_out)
build/torch211-cxx11-cu126-aarch64-linux/_layers/mpu.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # from megablocks.layers.arguments import Arguments
10
+ from .arguments import Arguments
11
+
12
+
13
class MoeParam(torch.Tensor):
    """Tensor type used to annotate parameters that may carry ``expert_model_parallel``."""

    def __init__(self):
        # `super().__init__` is already bound to this instance. Passing
        # `self` again forwarded a surplus positional argument to the base
        # initializer, so instantiation raised TypeError.
        super().__init__()
        # Declaration only: the attribute is attached later, not here.
        self.expert_model_parallel: bool
18
+
19
+
20
def is_moe_param(tensor: torch.Tensor) -> bool:
    """Return True when ``tensor`` has an ``expert_model_parallel`` attribute, whatever its value."""
    marker = 'expert_model_parallel'
    return hasattr(tensor, marker)
22
+
23
+
24
def get_expert_parallel_world_size(args: Arguments) -> int:
    """Return the size of the expert parallel group, or 1 when expert parallelism is off."""
    if not args.moe_expert_model_parallelism:
        return 1
    return dist.get_world_size(args.expert_parallel_group)
26
+
27
+
28
def get_expert_parallel_rank(args: Arguments) -> int:
    """Return this process's rank in the expert parallel group, or 0 when expert parallelism is off."""
    if not args.moe_expert_model_parallelism:
        return 0
    return dist.get_rank(args.expert_parallel_group)
30
+
31
+
32
def set_expert_model_parallel_attributes(
    tensor: torch.Tensor,
    is_parallel: bool,
):
    """Attach ``expert_model_parallel = is_parallel`` to ``tensor``.

    The attribute must not already exist; this is checked with an assert.
    """
    attr_name = 'expert_model_parallel'
    assert not hasattr(tensor, attr_name)
    tensor.expert_model_parallel = is_parallel
38
+
39
+
40
def param_is_expert_model_parallel(param: MoeParam) -> bool:
    """Return the ``expert_model_parallel`` value of ``param``, or False when it has none."""
    return getattr(param, 'expert_model_parallel', False)
42
+
43
+
44
def copy_expert_model_parallel_attributes(
    destination_tensor: torch.Tensor,
    source_tensor: torch.Tensor,
):
    """Copy ``expert_model_parallel`` from source to destination when the source has it.

    The destination is left untouched if the source lacks the attribute.
    """
    if not hasattr(source_tensor, 'expert_model_parallel'):
        return
    destination_tensor.expert_model_parallel = source_tensor.expert_model_parallel
54
+
55
+
56
def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
    """Print ``x`` from each rank of ``group`` in turn, with a barrier before every turn."""
    num_ranks = dist.get_world_size(group)
    my_rank = dist.get_rank(group)
    for turn in range(num_ranks):
        # Every rank reaches the barrier on every turn; only the rank whose
        # turn it is prints afterwards.
        dist.barrier(group)
        if turn != my_rank:
            continue
        print(f'rank = {my_rank}', *x)
63
+
64
+
65
+ # Helpers for expert/tensor sharding.
66
def expert_sharding_degree(args: Arguments) -> int:
    """Return how many ways the experts are sharded: ``min(world_size, moe_num_experts)``.

    Raises:
        ValueError: If ``moe_num_experts`` is not divisible by that degree.
    """
    degree = min(get_expert_parallel_world_size(args), args.moe_num_experts)
    remainder = args.moe_num_experts % degree
    if remainder != 0:
        raise ValueError(f'Cannot shard {args.moe_num_experts} experts {degree} ways.',)
    return degree
73
+
74
+
75
def hidden_sharding_degree(args: Arguments) -> int:
    """Return how many ways the FFN hidden dimension is sharded.

    Computed as ``world_size // expert_sharding_degree(args)``.

    Raises:
        ValueError: If ``ffn_hidden_size`` is not divisible by the result, or
            if the expert and hidden degrees do not multiply to the world size.
    """
    ranks = get_expert_parallel_world_size(args)
    expert_degree = expert_sharding_degree(args)
    hidden_degree = ranks // expert_degree

    # Same check order as before: feature divisibility first.
    if (args.ffn_hidden_size % hidden_degree) != 0:
        raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hidden_degree} ways.',)
    if (expert_degree * hidden_degree) != ranks:
        raise ValueError(
            f"Invalid sharding. 'expert_sharding_degree' ({expert_degree}) * hidden_sharding_degree ({hidden_degree}) != world_size ({ranks}).",
        )
    return hidden_degree
87
+
88
+
89
def experts_per_rank(args: Arguments) -> int:
    """Return ``moe_num_experts`` floor-divided by the expert sharding degree."""
    degree = expert_sharding_degree(args)
    return args.moe_num_experts // degree
91
+
92
+
93
def features_per_rank(args: Arguments) -> int:
    """Return ``ffn_hidden_size`` floor-divided by the hidden sharding degree."""
    degree = hidden_sharding_degree(args)
    return args.ffn_hidden_size // degree