File size: 4,139 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20) -> float:
    """Time ``steps`` Adam training steps of a small conv net in MLX on the CPU.

    The network is trained to reconstruct its own random input under a mean
    absolute error loss. Every step is timed individually with
    ``time.perf_counter`` and the per-step time is printed to stdout.

    Args:
        steps: Number of training steps to run and time.

    Returns:
        The summed wall-clock time of all timed steps, in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    # Batch of 10 random 256x256 inputs; the last axis (3) matches in_channels.
    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    # Everything that a training step mutates; evaluating it forces the step.
    state = [benchNet.state, optim.state]

    # MLX builds its computation lazily. Materialize the inputs and the
    # freshly initialized optimizer state now, so their one-off creation cost
    # is not attributed to the first timed step (the PyTorch benchmark creates
    # its inputs eagerly, outside the timer).
    mx.eval(state, inputs)

    def loss_fn(params, image):
        # Load the candidate parameters into the model, then score the
        # reconstruction with mean absolute error.
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        # Force the lazily recorded update to actually run before stopping
        # the clock.
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20) -> float:
    """Run ``steps`` Adam updates of a small conv net in PyTorch on the CPU.

    Each iteration (zero_grad, forward pass, mean absolute error against the
    input, backward pass, optimizer step) is timed with ``time.perf_counter``
    and its duration is printed to stdout.

    Args:
        steps: Number of training iterations to run and time.

    Returns:
        Sum of the per-iteration wall-clock times, in milliseconds.
    """
    cpu = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        """Two convolutions followed by two transposed convolutions."""

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            wide_channels = 2 * hidden_channels
            conv_kwargs = {"kernel_size": 3, "padding": 1}
            layers = [
                torch.nn.Conv2d(in_channels, hidden_channels, **conv_kwargs),
                torch.nn.ReLU(),
                torch.nn.Conv2d(hidden_channels, wide_channels, **conv_kwargs),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(wide_channels, hidden_channels, **conv_kwargs),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(hidden_channels, in_channels, **conv_kwargs),
            ]
            self.net = torch.nn.Sequential(*layers)

        def forward(self, image):
            return self.net(image)

    model = BenchNetTorch(3).to(cpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Fixed random batch that serves as both input and reconstruction target.
    batch = torch.randn(10, 3, 256, 256, device=cpu)

    elapsed_total_ms = 0.0
    print("PyTorch:")
    for step_idx in range(steps):
        tic = time.perf_counter()

        optimizer.zero_grad()
        reconstruction = model(batch)
        l1_loss = (reconstruction - batch).abs().mean()
        l1_loss.backward()
        optimizer.step()

        toc = time.perf_counter()

        step_ms = (toc - tic) * 1000
        print(f"{step_idx:3d}, time={step_ms:7.2f} ms")
        elapsed_total_ms += step_ms

    return elapsed_total_ms


def main():
    """Benchmark MLX, then PyTorch, and print a timing summary.

    Both benchmarks run for the same fixed number of steps. The summary shows
    the average and total time per framework in milliseconds, followed by the
    relative difference of the PyTorch total against the MLX total.
    """
    steps = 20
    mlx_total_ms = bench_mlx(steps)
    torch_total_ms = bench_torch(steps)

    mlx_avg_ms = mlx_total_ms / steps
    torch_avg_ms = torch_total_ms / steps

    print(f"average time of MLX:     {mlx_avg_ms:9.2f} ms")
    print(f"total time of MLX:       {mlx_total_ms:9.2f} ms")
    print(f"average time of PyTorch: {torch_avg_ms:9.2f} ms")
    print(f"total time of PyTorch:   {torch_total_ms:9.2f} ms")

    # Positive means PyTorch took longer than MLX.
    relative_diff = torch_total_ms / mlx_total_ms - 1.0
    print(f"torch/mlx diff: {100. * relative_diff:+5.2f}%")


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()