Kernels
wyldecat Claude Opus 4.6 (1M context) committed on
Commit
60615a0
·
1 Parent(s): 7d51e61

style: apply yapf + isort formatting

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

benchmarks/common/bench_framework.py CHANGED
@@ -4,8 +4,8 @@ import re
4
  from typing import Any, Dict, Sequence
5
 
6
  import torch
7
- from torch.profiler import ProfilerActivity, profile
8
  import triton
 
9
 
10
  from .diff_engine import DiffCase
11
 
@@ -42,8 +42,8 @@ def _compute_bytes(inputs, forward_fn, obj):
42
  if isinstance(output, torch.Tensor):
43
  output_bytes = output.nbytes
44
  elif isinstance(output, (tuple, list)):
45
- output_bytes = sum(
46
- o.nbytes for o in output if isinstance(o, torch.Tensor))
47
  else:
48
  output_bytes = 0
49
  return input_bytes + output_bytes
@@ -158,7 +158,9 @@ def make_fwd_benchmark_for_case(
158
  key = make_fwd_key(dim, batch_size, seq_len)
159
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
160
  if provider == "speedup":
161
- return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
 
 
162
  if provider.endswith("_bw"):
163
  base = provider[:-3]
164
  ms = timings_ms[base][key]
@@ -227,7 +229,8 @@ def make_fwd_benchmark_plot_for_case(
227
  ms = profile_bench(run, total_bytes=nbytes)
228
  timings_ms[provider][config] = ms
229
  if provider == "cuda":
230
- ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
 
231
  spdup_ratio.append(ratio)
232
  return round(ratio, 2)
233
  else:
@@ -267,7 +270,9 @@ def make_bwd_benchmark_for_case(
267
  key = make_bwd_key(dim, batch_size, seq_len)
268
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
269
  if provider == "speedup":
270
- return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
 
 
271
  if provider.endswith("_bw"):
272
  base = provider[:-3]
273
  ms = timings_ms[base][key]
@@ -360,7 +365,8 @@ def make_bwd_benchmark_plot_for_case(
360
  ms = profile_bench(run, total_bytes=nbytes)
361
  timings_ms[provider][config] = ms
362
  if provider == "cuda":
363
- ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
 
364
  spdup_ratio.append(ratio)
365
  return round(ratio, 2)
366
  else:
 
4
  from typing import Any, Dict, Sequence
5
 
6
  import torch
 
7
  import triton
8
+ from torch.profiler import ProfilerActivity, profile
9
 
10
  from .diff_engine import DiffCase
11
 
 
42
  if isinstance(output, torch.Tensor):
43
  output_bytes = output.nbytes
44
  elif isinstance(output, (tuple, list)):
45
+ output_bytes = sum(o.nbytes for o in output
46
+ if isinstance(o, torch.Tensor))
47
  else:
48
  output_bytes = 0
49
  return input_bytes + output_bytes
 
158
  key = make_fwd_key(dim, batch_size, seq_len)
159
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
160
  if provider == "speedup":
161
+ return round(
162
+ timings_ms["naive"][key] /
163
+ _get_best_cuda_timing(timings_ms, key), 2)
164
  if provider.endswith("_bw"):
165
  base = provider[:-3]
166
  ms = timings_ms[base][key]
 
229
  ms = profile_bench(run, total_bytes=nbytes)
230
  timings_ms[provider][config] = ms
231
  if provider == "cuda":
232
+ ratio = timings_ms["naive"][config] / _get_best_cuda_timing(
233
+ timings_ms, config)
234
  spdup_ratio.append(ratio)
235
  return round(ratio, 2)
236
  else:
 
270
  key = make_bwd_key(dim, batch_size, seq_len)
271
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
272
  if provider == "speedup":
273
+ return round(
274
+ timings_ms["naive"][key] /
275
+ _get_best_cuda_timing(timings_ms, key), 2)
276
  if provider.endswith("_bw"):
277
  base = provider[:-3]
278
  ms = timings_ms[base][key]
 
365
  ms = profile_bench(run, total_bytes=nbytes)
366
  timings_ms[provider][config] = ms
367
  if provider == "cuda":
368
+ ratio = timings_ms["naive"][config] / _get_best_cuda_timing(
369
+ timings_ms, config)
370
  spdup_ratio.append(ratio)
371
  return round(ratio, 2)
372
  else: