|
|
"""Example of Timer and Compare APIs: |
|
|
|
|
|
$ python -m examples.compare |
|
|
""" |
|
|
|
|
|
import pickle |
|
|
import sys |
|
|
import time |
|
|
|
|
|
import torch |
|
|
|
|
|
import torch.utils.benchmark as benchmark_utils |
|
|
|
|
|
|
|
|
class FauxTorch:
    """Emulate different versions of pytorch.

    In normal circumstances this would be done with multiple processes
    writing serialized measurements, but this simplifies that model to
    make the example clearer.

    Only ``add``, ``mul``, ``cat`` and ``matmul`` are exposed. Each one
    forwards to the wrapped module and then passes the result through
    :meth:`extra_overhead`, which sleeps to simulate a slower build.

    Args:
        real_torch: Module whose ``add``/``mul``/``cat``/``matmul`` calls are
            forwarded (``torch`` in this example).
        extra_ns_per_element: Simulated extra cost, in nanoseconds per output
            element. ``None`` or ``0`` disables the simulated overhead.
        min_numel: Outputs with at most this many elements incur no overhead,
            so only large inputs appear "regressed". Defaults to 5000, the
            threshold this example has always used.
    """

    def __init__(self, real_torch, extra_ns_per_element, min_numel=5000):
        self._real_torch = real_torch
        self._extra_ns_per_element = extra_ns_per_element
        self._min_numel = min_numel

    def extra_overhead(self, result):
        """Sleep proportionally to ``result.numel()``, then return ``result``.

        The sleep lasts ``numel * extra_ns_per_element`` nanoseconds. It only
        happens when ``numel > min_numel`` and a non-zero per-element cost
        was configured. ``result`` itself is returned unchanged.
        """
        if not self._extra_ns_per_element:
            # None (the value paired with the unmodified "master" branch) or 0:
            # there is no overhead to simulate. Without this guard, None would
            # raise TypeError in the multiplication below.
            return result

        numel = int(result.numel())
        if numel > self._min_numel:
            time.sleep(numel * self._extra_ns_per_element * 1e-9)
        return result

    def add(self, *args, **kwargs):
        """``real_torch.add`` plus the simulated per-element overhead."""
        return self.extra_overhead(self._real_torch.add(*args, **kwargs))

    def mul(self, *args, **kwargs):
        """``real_torch.mul`` plus the simulated per-element overhead."""
        return self.extra_overhead(self._real_torch.mul(*args, **kwargs))

    def cat(self, *args, **kwargs):
        """``real_torch.cat`` plus the simulated per-element overhead."""
        return self.extra_overhead(self._real_torch.cat(*args, **kwargs))

    def matmul(self, *args, **kwargs):
        """``real_torch.matmul`` plus the simulated per-element overhead."""
        return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
|
|
|
|
|
|
|
|
def main():
    """Benchmark ``torch.add`` across simulated branches and compare results.

    For each (branch, task, size, num_threads) combination a
    ``benchmark_utils.Timer`` is built. The "master" branch uses real
    ``torch``; every other branch wraps it in ``FauxTorch`` with its own
    per-element overhead. Every timer is run ``repeats`` times via
    ``blocked_autorange(min_run_time=0.05)``. Each measurement is pickled and
    unpickled before being handed to ``benchmark_utils.Compare``. The
    comparison table is printed twice: as-is, then after
    ``trim_significant_figures()`` and ``colorize()``.
    """
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]
    branches = [("master", None), ("my_branch", 1), ("severe_regression", 5)]
    sizes = [1, 10, 100, 1000, 10000, 50000]
    thread_counts = [1, 4]
    repeats = 2

    # Nesting order (branch -> task -> size -> threads) determines run order.
    timers = []
    for env_name, overhead in branches:
        for label, sub_label, stmt in tasks:
            for size in sizes:
                for num_threads in thread_counts:
                    torch_impl = (
                        torch if env_name == "master"
                        else FauxTorch(torch, overhead)
                    )
                    timers.append(
                        benchmark_utils.Timer(
                            stmt=stmt,
                            globals={
                                "torch": torch_impl,
                                "x": torch.ones((size, 4)),
                                "y": torch.ones((1, 4)),
                                "zero": torch.zeros(()),
                            },
                            label=label,
                            sub_label=sub_label,
                            description=f"size: {size}",
                            env=env_name,
                            num_threads=num_threads,
                        )
                    )

    # Run every timer once, then the whole list again (`repeats` passes).
    # Results are pickled to mimic measurements serialized by separate
    # processes (see the FauxTorch docstring).
    total_runs = len(timers) * repeats
    serialized_results = []
    for run_number, timer in enumerate(timers * repeats, start=1):
        measurement = timer.blocked_autorange(min_run_time=0.05)
        serialized_results.append(pickle.dumps(measurement))
        print(f"\r{run_number} / {total_runs}", end="", flush=True)
    print()

    measurements = [pickle.loads(blob) for blob in serialized_results]
    comparison = benchmark_utils.Compare(measurements)

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
|
|
|