"""Microbenchmarks for the torch.fft module"""
from argparse import ArgumentParser
from collections import namedtuple
from collections.abc import Iterable

import torch
import torch.fft
from torch.utils import benchmark
from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer
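
# SpectralOpFuzzer generates randomized input tensors for spectral ops: sizes,
# strides, and contiguity are all fuzzed. `probability_regular` appears to bias
# the fuzzed sizes toward "regular" (small-prime-factor) values that FFT
# implementations handle on a fast path; this description is inferred from the
# fuzzer's name and its usage here.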


def _dim_options(ndim):
    """Return the `dim` arguments worth benchmarking for a rank-`ndim` input.

    `None` transforms over all dimensions; tuples cover multi-dim cases.
    """
    if ndim == 1:
        return [None]
    elif ndim == 2:
        return [0, 1, None]
    elif ndim == 3:
        return [0, 1, 2, (0, 1), (0, 2), None]
    raise ValueError(f"Expected ndim in range 1-3, got {ndim}")
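

# run_benchmark times each fuzzed sample for every valid `dim` option and, on
# CPU, at 1, 4, and 16 intra-op threads; CUDA runs use a single thread, since
# `num_threads` only affects CPU-side execution.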
def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int,
                  probability_regular: float):
    cuda = device == 'cuda'
    spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda,
                                       probability_regular=probability_regular)
    results = []
    for tensors, tensor_params, params in spectral_fuzzer.take(samples):
        shape = [params['k0'], params['k1'], params['k2']][:params['ndim']]
        str_shape = ' x '.join(["{:<4}".format(s) for s in shape])
        sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
        for dim in _dim_options(params['ndim']):
            for nthreads in (1, 4, 16) if not cuda else (1,):
                # blocked_autorange repeats the statement until at least
                # min_run_time seconds of measurements have been collected.
                measurement = benchmark.Timer(
                    stmt='func(x, dim=dim)',
                    globals={'func': function, 'x': tensors['x'], 'dim': dim},
                    label=f"{name}_{device}",
                    sub_label=sub_label,
                    description=f"dim={dim}",
                    num_threads=nthreads,
                ).blocked_autorange(min_run_time=1)

                # Attach metadata (including the fuzzed tensor's numel and
                # contiguity) for _output_csv to consume later.
                measurement.metadata = {
                    'name': name,
                    'device': device,
                    'dim': dim,
                    'shape': shape,
                }
                measurement.metadata.update(tensor_params['x'])
                results.append(measurement)
    return results
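

# A minimal sketch of calling run_benchmark directly, bypassing the CLI below
# (the values shown are the CLI defaults, purely for illustration):
#
#   results = run_benchmark(name='rfft', function=torch.fft.rfftn,
#                           dtype=torch.float32, seed=0, device='cpu',
#                           samples=10, probability_regular=1.0)
#   benchmark.Compare(results).print()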


Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype'])
BENCHMARKS = [
    Benchmark('fft_real', torch.fft.fftn, torch.float32),      # real input, promoted to complex
    Benchmark('fft_complex', torch.fft.fftn, torch.complex64),
    Benchmark('ifft', torch.fft.ifftn, torch.complex64),
    Benchmark('rfft', torch.fft.rfftn, torch.float32),         # real-to-complex transform
    Benchmark('irfft', torch.fft.irfftn, torch.complex64),     # complex-to-real inverse of rfftn
]
BENCHMARK_MAP = {b.name: b for b in BENCHMARKS}
BENCHMARK_NAMES = [b.name for b in BENCHMARKS]
DEVICE_NAMES = ['cpu', 'cuda']


def _output_csv(file, results):
    """Write one CSV row per measurement; all times are in microseconds."""
    file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n')
    for measurement in results:
        metadata = measurement.metadata
        device, dim, shape, name, numel, contiguous = (
            metadata['device'], metadata['dim'], metadata['shape'],
            metadata['name'], metadata['numel'], metadata['is_contiguous'])

        # Multi-dim transforms are joined with hyphens (dim=(0, 2) -> "0-2");
        # a transform over all dimensions is recorded as "None".
        if isinstance(dim, Iterable):
            dim_str = '-'.join(str(d) for d in dim)
        else:
            dim_str = str(dim)
        shape_str = 'x'.join(str(s) for s in shape)

        print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,
              measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
              sep=',', file=file)
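
# Each CSV row corresponds to one Measurement: a single (fuzzed input, dim,
# num_threads) combination, with mean/median/IQR converted to microseconds.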


if __name__ == '__main__':
    parser = ArgumentParser(description=__doc__)
    parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES)
    parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--samples', type=int, default=10)
    parser.add_argument('--probability_regular', type=float, default=1.0)
    parser.add_argument('-o', '--output', type=str)
    args = parser.parse_args()

    num_benchmarks = len(args.device) * len(args.bench)
    i = 0
    results = []
    for device in args.device:
        for bench in (BENCHMARK_MAP[b] for b in args.bench):
            results += run_benchmark(
                name=bench.name, function=bench.function, dtype=bench.dtype,
                seed=args.seed, device=device, samples=args.samples,
                probability_regular=args.probability_regular)
            i += 1
            print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})')

    if args.output is not None:
        with open(args.output, 'w') as f:
            _output_csv(f, results)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize()
    compare.print()
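
# Example invocation (flags as defined above; the values are illustrative):
#   python spectral_ops_fuzz_test.py --device cpu --bench rfft irfft \
#       --samples 5 --probability_regular 0.75 -o spectral.csv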