File size: 3,380 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N,   D,   H,   W,   C), (  O, kD, kH, kW,   C),   dtype,    stride,      pads,  groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")