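# Benchmark mx.conv_transpose3d against torch.conv_transpose3d on the CPU,
# checking numerical agreement and reporting the relative runtime difference.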
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1  # untimed warmup calls per benchmark
N_iter_bench = 10  # timed calls per benchmark
N_iter_func = 5  # convolutions per timed call
mx.set_default_device(mx.cpu)  # this benchmark targets the CPU backend


def bench(f, a, b):
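    # Total wall-clock seconds for N_iter_bench calls of f(a, b),
    # measured after N_warmup untimed warmup calls.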
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    # Build a closure that runs N_iter_func transposed convolutions and
    # forces evaluation, since MLX computes lazily until mx.eval is called.
    def mx_conv_transpose_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_3D


def make_pt_conv_transpose_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    # PyTorch equivalent; torch.no_grad avoids autograd bookkeeping overhead.
    @torch.no_grad()
    def pt_conv_transpose_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_transpose_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, C // groups)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    # MLX uses channels-last (NDHWC) arrays; permute to the channels-first
    # layout PyTorch expects for its inputs and weights.
    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_transpose_3D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Sanity-check one call against PyTorch, permuting the PyTorch output
    # back to channels-last before comparing.
    out_mx = mx.conv_transpose3d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose3d(
        a_pt, b_pt, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4  # looser tolerance below fp32

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} "
            f"[strides = {strides}, padding = {padding}, groups = {groups}] "
            f"with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv_transpose3d benchmarks")
    parser.parse_args()

    dtypes = ("float32",)
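    # Each entry: (N, D, H, W, C, kD, kH, kW, O, strides, padding, groups)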
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N,   D,   H,   W,   C), (  O, kD, kH, kW,   C),   dtype,    stride,      pads,  groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0  # positive means MLX is faster

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag shapes where MLX is at least 2x slower than PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")