| # Copyright © 2023-2024 Apple Inc. | |
| import time | |
| import mlx.core as mx | |
def time_fn(fn, *args, **kwargs):
    """Print and return the mean time per call of ``fn(*args, **kwargs)``.

    The call is run 5 times as warm-up, then timed over 100 iterations.
    Each result is passed to ``mx.eval`` so the timed work includes
    evaluating whatever ``fn`` returns.

    Args:
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``. The special key
            ``msg`` is removed before forwarding and, when truthy, is printed
            as the label instead of ``fn``'s name.

    Returns:
        The mean time per call in milliseconds (also printed).
    """
    msg = kwargs.pop("msg", None)
    # functools.partial objects and some other callables have no __name__.
    label = msg if msg else getattr(fn, "__name__", repr(fn))
    # Flush so the label is visible while the (possibly long) benchmark runs.
    print(f"Timing {label} ...", end=" ", flush=True)

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()
    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")
    return msec
def measure_runtime(fn, *args, **kwargs):
    """Return the mean time per call of ``fn(*args, **kwargs)`` in milliseconds.

    ``fn`` is called 5 times as warm-up and then timed over 100 calls.
    Return values of ``fn`` are discarded and NOT passed to ``mx.eval``.
    NOTE(review): if ``fn`` builds lazy MLX computations, it presumably must
    force evaluation itself (e.g. call ``mx.eval``) for this timing to be
    meaningful — confirm with callers.

    Args:
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Mean milliseconds per call.
    """
    # Warmup
    for _ in range(5):
        fn(*args, **kwargs)

    iters = 100
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # and matches the clock used by time_fn.
    tic = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    return (time.perf_counter() - tic) * 1000 / iters