File size: 552 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch.utils.benchmark as benchmark


def benchmark_torch_function(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` using ``torch.utils.benchmark``.

    Args:
        f: The callable to benchmark.
        *args: Positional arguments forwarded to ``f`` on every timed call.
        **kwargs: Keyword arguments forwarded to ``f`` on every timed call.

    Returns:
        The mean runtime of a single call, in milliseconds, as reported by
        ``Timer.blocked_autorange()``.
    """
    # The timed statement resolves ``f``, ``args`` and ``kwargs`` from this namespace.
    namespace = {"f": f, "args": args, "kwargs": kwargs}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    measurement = timer.blocked_autorange()
    seconds_per_call = measurement.mean
    return seconds_per_call * 1e3  # seconds -> milliseconds


def benchmark_torch_function_with_result(f, *args, **kwargs):
    """Call ``f`` once to capture its output, then time it with ``torch.utils.benchmark``.

    Args:
        f: The callable to run and benchmark.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        A ``(time_in_ms, result)`` tuple. ``time_in_ms`` is the mean runtime of
        a single call in milliseconds, from ``Timer.blocked_autorange()``.
        ``result`` is the value returned by the initial, untimed call.
    """
    # Untimed call whose return value is handed back to the caller.
    output = f(*args, **kwargs)

    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals=dict(f=f, args=args, kwargs=kwargs),
    )
    mean_ms = 1e3 * timer.blocked_autorange().mean  # seconds -> milliseconds
    return mean_ms, output