|
|
|
|
|
|
|
|
import argparse |
|
|
|
|
|
import mlx.core as mx |
|
|
import torch |
|
|
from time_utils import measure_runtime |
|
|
|
|
|
|
|
|
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
    """Time the indexed assignment ``dst[idx] = x`` in MLX and print it.

    Random float32 arrays are created for the destination and the update
    values, together with one random integer index array per entry of
    ``idx_shapes``. The assignment is then handed to ``measure_runtime``
    and the returned value is printed with an ``MLX:`` prefix.

    Args:
        dst_shape: Shape of the destination array that is written into.
        x_shape: Shape of the update values assigned into the destination.
        idx_shapes: Sequence of shapes, one per index array. The i-th index
            array addresses axis ``i`` of the destination, so there must
            not be more index arrays than destination axes.
    """

    def scatter(dst, x, idx):
        dst[tuple(idx)] = x
        # MLX builds computations lazily; evaluating here makes the timed
        # call include the scatter itself.
        mx.eval(dst)

    # ``randint`` samples from the half-open interval [low, high), so the
    # axis length itself is the right upper bound: every valid position can
    # be drawn, and an axis of length 1 still has a non-empty range.
    idx = [
        mx.random.randint(0, dst_shape[axis], idx_shape)
        for axis, idx_shape in enumerate(idx_shapes)
    ]
    x = mx.random.normal(x_shape).astype(mx.float32)
    dst = mx.random.normal(dst_shape).astype(mx.float32)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")
|
|
|
|
|
|
|
|
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    """Time the indexed assignment ``dst[idx] = x`` in PyTorch and print it.

    Random float32 tensors are created for the destination and the update
    values, together with one random integer index tensor per entry of
    ``idx_shapes``; all of them are moved to ``device``. The assignment is
    then handed to ``measure_runtime`` and the returned value is printed
    with a ``PyTorch:`` prefix.

    Args:
        dst_shape: Shape of the destination tensor that is written into.
        x_shape: Shape of the update values assigned into the destination.
        idx_shapes: Sequence of shapes, one per index tensor. The i-th index
            tensor addresses axis ``i`` of the destination, so there must
            not be more index tensors than destination axes.
        device: ``torch.device`` on which the tensors live.
    """

    def scatter(dst, x, idx, device):
        dst[tuple(idx)] = x
        # MPS work is queued asynchronously; wait for it so the timed call
        # covers the assignment rather than just its submission.
        if device == torch.device("mps"):
            torch.mps.synchronize()

    # ``torch.randint`` excludes ``high``, so the axis length itself is the
    # right upper bound: every valid position can be drawn, and an axis of
    # length 1 no longer produces an empty (rejected) range.
    idx = [
        torch.randint(0, dst_shape[axis], idx_shape).to(device)
        for axis, idx_shape in enumerate(idx_shapes)
    ]
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog``; the text
    # belongs in ``description`` so the usage line shows the script name.
    parser = argparse.ArgumentParser(description="Scatter benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        # MLX is left on its default device; PyTorch is pointed at MPS.
        device = torch.device("mps")

    # One entry per benchmark: (destination shape, update shape, index shapes).
    # Keeping the three shapes of a case together prevents them from
    # drifting out of alignment.
    cases = [
        ((10, 64), (1_000_000, 64), [(1_000_000,)]),
        ((100_000, 64), (1_000_000, 64), [(1_000_000,)]),
        ((1_000_000, 64), (100_000, 64), [(100_000,)]),
        ((100_000,), (1_000_000,), [(1_000_000,)]),
        ((200_000,), (20_000_000,), [(20_000_000,)]),
        ((20_000_000,), (20_000_000,), [(20_000_000,)]),
        ((10_000, 64), (1_000_000, 64), [(1_000_000,)]),
        ((100, 64), (10_000_000, 64), [(10_000_000,)]),
        ((100, 10_000, 64), (1_000, 10_000, 64), [(1_000,)]),
        ((10, 100, 100, 21), (10_000, 100, 100, 21), [(10_000,)]),
        ((1_000, 1_000, 10), (1_000, 10), [(1_000,), (1_000,)]),
    ]

    for dst_shape, x_shape, idx_shape in cases:
        print("=" * 20)
        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
|
|
|