File size: 2,731 Bytes
2dd52ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
AlexNet in MLX with endpoints for relu1, relu2, relu3, relu4, relu5.
Loads weights from a torchvision-exported npz.
"""

import mlx.core as mx
import mlx.nn as nn
import numpy as np


def _conv(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Create an ``nn.Conv2d`` layer that always carries a bias term.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride (default 1).
        padding: Zero padding applied to the input (default 0).

    Returns:
        A freshly constructed ``nn.Conv2d`` with ``bias=True``.
    """
    layer_options = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )
    return nn.Conv2d(in_ch, out_ch, **layer_options)


class AlexNet(nn.Module):
    """AlexNet convolutional trunk in MLX that exposes intermediate ReLU outputs.

    Only the convolution/ReLU/max-pool stack is built (no classifier head).
    The index of each entry in ``self.layers`` matches the ``features.{i}``
    keys of a torchvision-exported state dict, which is how ``load_npz``
    finds each convolution's weights.
    """

    def __init__(self):
        super().__init__()
        self.layers = [
            _conv(3, 64, kernel_size=11, stride=4, padding=2),  # 0
            nn.ReLU(),  # 1 (relu1)
            nn.MaxPool2d(kernel_size=3, stride=2),  # 2
            _conv(64, 192, kernel_size=5, padding=2),  # 3
            nn.ReLU(),  # 4 (relu2)
            nn.MaxPool2d(kernel_size=3, stride=2),  # 5
            _conv(192, 384, kernel_size=3, padding=1),  # 6
            nn.ReLU(),  # 7 (relu3)
            _conv(384, 256, kernel_size=3, padding=1),  # 8
            nn.ReLU(),  # 9 (relu4)
            _conv(256, 256, kernel_size=3, padding=1),  # 10
            nn.ReLU(),  # 11 (relu5)
            nn.MaxPool2d(kernel_size=3, stride=2),  # 12
        ]

        # Endpoint name -> index into self.layers whose output is captured.
        self.endpoint_indices = {
            "relu1": 1,
            "relu2": 4,
            "relu3": 7,
            "relu4": 9,
            "relu5": 11,
        }

    def forward_with_endpoints(self, x):
        """Run all layers in order and capture the requested activations.

        Args:
            x: Input batch. MLX convolutions are channels-last, so this is
                presumably ``(N, H, W, 3)`` -- confirm with callers.

        Returns:
            A tuple ``(out, endpoints)``. ``out`` is the output of the last
            layer (the final max-pool). ``endpoints`` maps each name in
            ``self.endpoint_indices`` to the output of the layer at that
            index; names whose index is out of range are absent.
        """
        # Invert name -> index once per call so each layer needs one dict
        # lookup instead of a scan over every endpoint. A list per index
        # keeps every name if several endpoints share the same layer.
        names_by_index = {}
        for name, layer_idx in self.endpoint_indices.items():
            names_by_index.setdefault(layer_idx, []).append(name)

        endpoints = {}
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            for name in names_by_index.get(idx, ()):
                endpoints[name] = x
        return x, endpoints

    def __call__(self, x):
        """Return only the endpoint activations produced by ``forward_with_endpoints``."""
        _, endpoints = self.forward_with_endpoints(x)
        return endpoints

    def load_npz(self, path: str):
        """Load convolution weights from a torchvision-exported ``.npz`` archive.

        For every ``nn.Conv2d`` at index ``idx`` in ``self.layers`` this reads
        ``features.{idx}.weight`` and ``features.{idx}.bias``. Each tensor may
        be stored directly under that key or as an int8 pair
        ``<key>_int8`` / ``<key>_scale``. The pair is dequantized as
        ``int8.astype(scale.dtype) * scale``, using NumPy broadcasting
        between the two arrays.

        Args:
            path: Path to the ``.npz`` file.

        Raises:
            ValueError: If neither ``<key>`` nor ``<key>_int8`` is present, or
                if a loaded tensor's shape does not match the layer parameter
                it replaces.
            KeyError: If ``<key>_int8`` is present but ``<key>_scale`` is not
                (raised by the archive lookup).
        """
        # Derive the conv positions from the layer list itself so they cannot
        # drift out of sync with __init__; these indices double as the
        # torchvision ``features.X`` key indices.
        conv_indices = [
            idx
            for idx, layer in enumerate(self.layers)
            if isinstance(layer, nn.Conv2d)
        ]

        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the context manager guarantees it is closed, even on error.
        with np.load(path) as data:

            def load_weight(key, expected_shape, transpose=False):
                if key in data:
                    w = data[key]
                elif f"{key}_int8" in data:
                    w_int8 = data[f"{key}_int8"]
                    scale = data[f"{key}_scale"]
                    w = w_int8.astype(scale.dtype) * scale
                else:
                    raise ValueError(f"Missing key {key} in npz")

                if transpose and w.ndim == 4:
                    # PyTorch stores conv kernels as (out, in, kh, kw); MLX
                    # Conv2d expects (out, kh, kw, in).
                    w = np.transpose(w, (0, 2, 3, 1))

                # Fail at load time rather than deep inside a later forward pass.
                if tuple(w.shape) != tuple(expected_shape):
                    raise ValueError(
                        f"Shape mismatch for {key}: got {tuple(w.shape)}, "
                        f"expected {tuple(expected_shape)}"
                    )
                return mx.array(w)

            for idx in conv_indices:
                conv = self.layers[idx]
                conv.weight = load_weight(
                    f"features.{idx}.weight", conv.weight.shape, transpose=True
                )
                conv.bias = load_weight(f"features.{idx}.bias", conv.bias.shape)