import inspect
import time
from functools import cached_property, wraps
from itertools import chain
from statistics import median
from typing import Any, Callable
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
import torch
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.config import use_experimental_benchmarker

logger = torch._logging.getArtifactLogger(__name__, "benchmarking")
use_experimental_benchmarker = (
use_experimental_benchmarker and torch.cuda.is_available()
)
MILLISECONDS_PER_SECOND = 1000
P = ParamSpec("P")
T = TypeVar("T")


def time_and_count(
fn: Callable[Concatenate[Any, P], T],
) -> Callable[Concatenate[Any, P], T]:
"""Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
counters. It is expected that `fn` is a method of `Benchmarker` or one of its
subclasses; typing limitations prevent us from declaring this directly.
"""
@wraps(fn)
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}"
counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1
with dynamo_timed(fn_qual_name, log_pt2_compile_event=False):
return fn(self, *args, **kwargs)
return wrapper
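
# Hedged illustration (not part of the original module): because `fn_qual_name` is built
# from `self.__class__.__name__`, a decorated method called on a subclass instance is
# counted under the subclass's name. For a hypothetical subclass
#
#     class MyBenchmarker(Benchmarker):
#         @time_and_count
#         def benchmark_cpu(self, _callable, warmup=20, rep=100):
#             ...
#
# each call increments counters["inductor"]["benchmarking.MyBenchmarker.benchmark_cpu"]
# and runs under dynamo_timed("MyBenchmarker.benchmark_cpu", log_pt2_compile_event=False).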


class Benchmarker:
def __init__(self: Self) -> None:
pass
@time_and_count
def benchmark(
self: Self,
fn: Callable[..., Any],
fn_args: tuple[Any, ...],
fn_kwargs: dict[str, Any],
**kwargs: Any,
) -> float:
"""Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the
actual runtime calculation is dictated by the benchmarking implementation, but may be
one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around
device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
`ValueError(...)` if we can't safely infer the device type of `fn`; for example,
if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
types are found.
Arguments:
- fn: The function to benchmark.
- fn_args: The function's arguments.
- fn_kwargs: The function's kwargs.
Keyword Arguments:
- **kwargs: The benchmarking implementation's kwargs.
Returns:
- The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
"""
inferred_device = None
for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
if not isinstance(arg_or_kwarg, torch.Tensor):
continue
if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError(
"Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
)
if inferred_device is None:
raise ValueError(
"Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
)
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
if inferred_device == torch.device("cpu"):
return self.benchmark_cpu(_callable, **kwargs)
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
# implementation which was written specifically with CUDA devices in mind, we may want to
# explore alternate implementations for other device types.
return self.benchmark_gpu(_callable, **kwargs)
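
    # Hedged usage sketch (the tensors below are illustrative, not part of this module):
    # `benchmark` infers the device from the tensors in `fn_args`/`fn_kwargs` and
    # dispatches accordingly, e.g.
    #
    #     x = torch.randn(256, 256)                           # CPU tensor
    #     ms = benchmarker.benchmark(torch.mm, (x, x), {})    # routes to benchmark_cpu
    #
    # while mixing CPU and CUDA tensors in the same call raises ValueError.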
@time_and_count
def benchmark_cpu(
self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
) -> float:
"""Benchmark the CPU callable, `_callable`, and return the median runtime,
in milliseconds.
Arguments:
- _callable: The CPU callable to benchmark.
Keyword Arguments:
- warmup: Optionally, the duration, in milliseconds, to run `_callable`
before benchmarking starts.
- rep: Optionally, the duration, in milliseconds, to run `_callable`
during benchmarking.
Returns:
- The median runtime of `_callable`, in milliseconds.
"""
def run_for(ms: int) -> list[float]:
timings = []
run_start_t = time.perf_counter()
while True:
start_t = time.perf_counter()
_callable()
end_t = time.perf_counter()
timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND)
if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms:
break
return timings
run_for(warmup)
return median(run_for(rep))
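
    # Hedged usage note: `warmup` and `rep` are time budgets in milliseconds, not
    # iteration counts, e.g.
    #
    #     benchmarker.benchmark_cpu(lambda: sorted(range(100_000)), warmup=10, rep=50)
    #
    # runs the callable repeatedly for ~10ms, then for ~50ms, and returns the median
    # per-call time (in milliseconds) from the second phase.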
@time_and_count
def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
raise NotImplementedError


class TritonBenchmarker(Benchmarker):
@cached_property
def triton_do_bench(self: Self) -> Callable[..., Any]:
"""Lazily import Triton's `do_bench`."""
try:
from triton.testing import do_bench
except ImportError as e:
raise NotImplementedError("requires Triton") from e
return do_bench
@time_and_count
def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
"""Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.
Arguments:
- _callable: The GPU callable to benchmark.
Keyword Arguments:
- quantiles: Optionally, a tuple of floats denoting the requested quantiles.
- return_mode: Optionally, the requested return mode. Currently, Triton's
`do_bench` supports min, max, mean, and median return modes.
- **kwargs: Additional kwargs passed to Triton's `do_bench`.
Returns:
        - The runtime of `_callable`, in milliseconds. If `kwargs["quantiles"]` is specified,
this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified,
this is the requested return mode. Otherwise, this is the median.
"""
do_bench_params = inspect.signature(self.triton_do_bench).parameters
for kwarg in list(kwargs.keys()):
if kwarg not in do_bench_params:
del kwargs[kwarg]
if "quantiles" in kwargs:
return self.triton_do_bench(_callable, **kwargs)[0]
elif "return_mode" in kwargs:
return self.triton_do_bench(_callable, **kwargs)
return self.triton_do_bench(_callable, **kwargs, return_mode="median")
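
    # Hedged usage sketch (assumes Triton and a CUDA device; `a` and `b` are illustrative
    # CUDA tensors):
    #
    #     bench = TritonBenchmarker()
    #     bench.benchmark_gpu(lambda: torch.mm(a, b))                             # median, in ms
    #     bench.benchmark_gpu(lambda: torch.mm(a, b), return_mode="min")          # minimum, in ms
    #     bench.benchmark_gpu(lambda: torch.mm(a, b), quantiles=(0.5, 0.2, 0.8))  # first quantile only
    #
    # kwargs not accepted by Triton's `do_bench` are silently dropped before the call.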


class InductorBenchmarker(TritonBenchmarker):
@cached_property
def L2_cache_size(self: Self) -> int:
"""Get the L2 cache size, in bytes, of the current device."""
device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
return props.L2_cache_size
def get_event_pairs(
self: Self, iters: int
) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]:
"""Get `iters` pairs of CUDA events."""
return [
(
torch.cuda.Event(enable_timing=True),
torch.cuda.Event(enable_timing=True),
)
for _ in range(iters)
]
def get_event_pairs_min_timing(
self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]]
) -> float:
"""Get the minimum timing, in milliseconds, for a group of CUDA event pairs."""
return min(
[
start_event.elapsed_time(end_event)
for start_event, end_event in event_pairs
]
)
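
    # Hedged sketch of the event-pair timing pattern these helpers support: each
    # (start, end) pair brackets one launch of the workload, and `elapsed_time` is only
    # valid after a synchronize, e.g.
    #
    #     start = torch.cuda.Event(enable_timing=True)
    #     end = torch.cuda.Event(enable_timing=True)
    #     start.record()
    #     workload()            # `workload` is illustrative
    #     end.record()
    #     torch.cuda.synchronize()
    #     ms = start.elapsed_time(end)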
@time_and_count
def benchmark_gpu(
self: Self,
_callable: Callable[[], Any],
estimation_iters: int = 5,
memory_warmup_iters: int = 100,
benchmark_iters: int = 100,
max_benchmark_duration: int = 25,
**kwargs: Any,
) -> float:
"""Benchmark a GPU callable using a custom benchmarking implementation.
Arguments:
- _callable: The callable to benchmark.
Keyword Arguments:
- estimation_iters: Optionally, the number of iterations to run `_callable`
during runtime estimation.
- memory_warmup_iters: Optionally, the number of iterations to flush the L2
cache before starting benchmarking.
- benchmark_iters: Optionally, the number of iterations to run `_callable`
during the benchmarking.
        - max_benchmark_duration: Optionally, the maximum duration of the benchmarking,
        in milliseconds. After the runtime estimation phase, `benchmark_iters` is
        shrunk, if necessary, so that `benchmark_iters` multiplied by the estimated
        runtime of `_callable` fits within this allotted duration.
- **kwargs: Additional kwargs that may be passed to the fallback.
Returns:
- The minimum runtime of `_callable`, in milliseconds.
"""
# we don't want any outside errors propagating into benchmarking
torch.cuda.synchronize()
        # warm up `_callable` (and catch any failures in the process)
_callable()
torch.cuda.synchronize()
# see https://github.com/triton-lang/triton/pull/840 for why `dtype=torch.int`
buffer = torch.empty(self.L2_cache_size // 4, dtype=torch.int, device="cuda")
buffer.zero_()
# estimate the runtime of `_callable`
event_pairs = self.get_event_pairs(estimation_iters)
for start_event, end_event in event_pairs:
buffer.zero_()
start_event.record()
_callable()
end_event.record()
torch.cuda.synchronize()
estimated_timing = self.get_event_pairs_min_timing(event_pairs)
# adjust `benchmark_iters` to fit in the maximum benchmarking duration
benchmark_iters = max(
min(benchmark_iters, int(max_benchmark_duration // estimated_timing)), 1
)
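        # Hedged worked example (numbers are illustrative): with the default
        # max_benchmark_duration=25, an estimated_timing of 0.25ms caps the iteration
        # count at int(25 // 0.25) = 100, leaving benchmark_iters unchanged, while an
        # estimated_timing of 1.0ms caps it at 25, shrinking benchmark_iters from 100 to 25.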
# do the memory warmup
for _ in range(memory_warmup_iters):
buffer.zero_()
# benchmark `_callable`
event_pairs = self.get_event_pairs(benchmark_iters)
for start_event, end_event in event_pairs:
buffer.zero_()
start_event.record()
_callable()
end_event.record()
torch.cuda.synchronize()
benchmarked_timing = self.get_event_pairs_min_timing(event_pairs)
# explicitly delete the buffer, sometimes helps memory
# footprint metrics in OSS Inductor performance benchmarks
del buffer
# return the minimum of `estimated_timing` and `benchmarked_timing`,
# we just want the minimum timing overall so we might as well check both
return min(estimated_timing, benchmarked_timing)
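
    # Hedged usage sketch: with the defaults above, a direct call such as
    #
    #     InductorBenchmarker().benchmark_gpu(lambda: torch.mm(a, b))   # a, b are illustrative CUDA tensors
    #
    # estimates the runtime over 5 event-timed iterations, performs 100 L2-flushing
    # memory warmup iterations, measures at most 100 iterations (capped to roughly 25ms
    # of benchmarking), and returns the smaller of the estimation-phase and
    # benchmark-phase minimums.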


benchmarker = (
InductorBenchmarker() if use_experimental_benchmarker else TritonBenchmarker()
)
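
# Hedged usage note: callers are expected to go through the module-level singleton, e.g.
#
#     from torch._inductor.runtime.benchmarking import benchmarker  # assumed import path
#     ms = benchmarker.benchmark_gpu(some_cuda_callable)            # `some_cuda_callable` is illustrative
#
# Whether the singleton is an InductorBenchmarker or a TritonBenchmarker is decided once,
# at import time, by `use_experimental_benchmarker` and CUDA availability.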