| | import time |
| |
|
| | import mlx.core as mx |
| | import mlx.nn |
| | import mlx.optimizers as opt |
| | import torch |
| |
|
| |
|
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    """Time ``steps`` Adam training steps of a small 3-D conv autoencoder in MLX.

    The benchmark runs on the MLX CPU device so that it is comparable with the
    CPU-only PyTorch benchmark. The caller's previous default MLX device is
    restored before returning.

    Args:
        steps: Number of training steps to run and time. Each step's time is
            printed as it completes.
        shape: Shape of the random input batch in MLX's channels-last layout
            ``(N, D, H, W, C)``. The network's input/output channel count is
            taken from ``shape[-1]`` (3 for the default shape).

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """

    class BenchNetMLX(mlx.nn.Module):
        """Conv3d/ConvTranspose3d autoencoder whose output shape equals its input.

        Every layer uses ``kernel_size=3, padding=1`` with the default stride,
        so spatial dimensions are preserved, and the last layer maps back to
        ``in_channels`` channels.
        """

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    # Switch to CPU for a like-for-like comparison with the PyTorch CPU run,
    # but do not leave the global default device changed for the caller.
    previous_device = mx.default_device()
    mx.set_default_device(mx.cpu)
    try:
        # Channels-last layout: the channel count is the last dimension.
        bench_net = BenchNetMLX(shape[-1])
        mx.eval(bench_net.parameters())
        optim = opt.Adam(learning_rate=1e-3)

        inputs = mx.random.normal(shape)

        optim.init(bench_net.parameters())

        # Model and optimizer state are what each step mutates; evaluating
        # them forces MLX's lazy graph (forward, backward, update) to run.
        state = [bench_net.state, optim.state]

        def loss_fn(params, image):
            # Load ``params`` into the model so the forward pass (and the
            # gradient taken w.r.t. this first argument) uses them.
            bench_net.update(params)
            pred_image = bench_net(image)
            return (pred_image - image).abs().mean()

        # Build the gradient transform once instead of on every step.
        loss_and_grad_fn = mx.value_and_grad(loss_fn)

        def step(params, image):
            loss, grads = loss_and_grad_fn(params, image)
            optim.update(bench_net, grads)
            return loss

        total_time = 0.0
        print("MLX:")
        for i in range(steps):
            start_time = time.perf_counter()

            step(bench_net.parameters(), inputs)
            mx.eval(state)

            elapsed_ms = (time.perf_counter() - start_time) * 1000
            print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
            total_time += elapsed_ms

        return total_time
    finally:
        mx.set_default_device(previous_device)
| |
|
| |
|
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    """Time ``steps`` Adam training steps of a small 3-D conv autoencoder in PyTorch.

    The model and data live on the CPU device.

    Args:
        steps: Number of training steps to run and time. Each step's time is
            printed as it completes.
        shape: Shape of the random input batch in PyTorch's channels-first
            layout ``(N, C, D, H, W)``. The network's input/output channel
            count is taken from ``shape[1]`` (3 for the default shape).

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        """Conv3d/ConvTranspose3d autoencoder whose output shape equals its input.

        Every layer uses ``kernel_size=3, padding=1`` with the default stride,
        so spatial dimensions are preserved, and the last layer maps back to
        ``in_channels`` channels.
        """

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    # Channels-first layout: the channel count is the second dimension.
    bench_net = BenchNetTorch(shape[1]).to(device)
    optim = torch.optim.Adam(bench_net.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = bench_net(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        # CPU execution is eager, so the work is done when optim.step() returns.
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        total_time += elapsed_ms

    return total_time
| |
|
| |
|
def main(steps: int = 10):
    """Run the MLX and PyTorch benchmarks for ``steps`` steps each and print a summary.

    Args:
        steps: Number of training steps each benchmark runs (default 10).

    Raises:
        ValueError: If ``steps`` is less than 1; the averages below would
            otherwise divide by zero.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    # Relative difference of PyTorch's total time w.r.t. MLX's: a positive
    # value means PyTorch took longer, a negative value means it was faster.
    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")
| |
|
| |
|
if __name__ == "__main__":
    # Executed as a script rather than imported: run both benchmarks.
    main()
| |
|