File size: 3,290 Bytes
3c8f058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdff6f4
 
 
 
 
 
 
 
 
 
 
 
 
3c8f058
 
 
 
 
 
bdff6f4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
VGG16 in MLX with endpoints for relu1_2, relu2_2, relu3_3, relu4_1, relu4_2,
relu4_3, relu5_1, relu5_2, relu5_3. Loads weights from a torchvision-exported
npz (see export_vgg16_npz.py).
"""

import mlx.core as mx
import mlx.nn as nn
import numpy as np


def _conv(in_ch, out_ch, kernel_size=3, padding=1):
    """Return an ``nn.Conv2d`` with a bias term.

    The defaults (3x3 kernel, padding 1) are the settings used by every
    convolution in the VGG16 trunk below.
    """
    options = {"kernel_size": kernel_size, "padding": padding, "bias": True}
    return nn.Conv2d(in_ch, out_ch, **options)


class VGG16(nn.Module):
    """VGG16 convolutional trunk (torchvision ``features`` layout) in MLX.

    ``self.layers`` is laid out index-for-index like torchvision's
    ``vgg16().features``, so layer ``i`` loads from the npz keys
    ``features.{i}.weight`` / ``features.{i}.bias``. Calling the module
    returns a dict of the named activations listed in
    ``self.endpoint_indices``.
    """

    def __init__(self):
        super().__init__()
        self.layers = [
            _conv(3, 64),  # 0 conv1_1
            nn.ReLU(),
            _conv(64, 64),  # 2 conv1_2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(64, 128),  # 5 conv2_1
            nn.ReLU(),
            _conv(128, 128),  # 7 conv2_2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(128, 256),  # 10 conv3_1
            nn.ReLU(),
            _conv(256, 256),  # 12 conv3_2
            nn.ReLU(),
            _conv(256, 256),  # 14 conv3_3
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(256, 512),  # 17 conv4_1
            nn.ReLU(),
            _conv(512, 512),  # 19 conv4_2
            nn.ReLU(),
            _conv(512, 512),  # 21 conv4_3
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(512, 512),  # 24 conv5_1
            nn.ReLU(),
            _conv(512, 512),  # 26 conv5_2
            nn.ReLU(),
            _conv(512, 512),  # 28 conv5_3
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

        # Layer indices in self.layers corresponding to named endpoints; each
        # points at the ReLU directly after the correspondingly named conv.
        self.endpoint_indices = {
            "relu1_2": 3,
            "relu2_2": 8,
            "relu3_3": 15,
            "relu4_1": 18,
            "relu4_2": 20,
            "relu4_3": 22,
            "relu5_1": 25,
            "relu5_2": 27,
            "relu5_3": 29,
        }

    def forward_with_endpoints(self, x):
        """Run every layer on *x*, capturing the named endpoint activations.

        Args:
            x: Input batch. MLX convolutions are channels-last, so this is
                presumably (N, H, W, 3) — TODO confirm against callers.

        Returns:
            ``(out, endpoints)``: ``out`` is the output of the last layer (the
            final max-pool), and ``endpoints`` maps each name in
            ``self.endpoint_indices`` whose index is reached to the output of
            that layer, inserted in layer order.
        """
        # Invert name -> index once per call so each layer costs one dict
        # lookup; several names may share an index, so map to a list.
        names_at = {}
        for name, i in self.endpoint_indices.items():
            names_at.setdefault(i, []).append(name)

        endpoints = {}
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            for name in names_at.get(idx, ()):
                endpoints[name] = x
        return x, endpoints

    def __call__(self, x):
        """Return only the endpoint dict from ``forward_with_endpoints``."""
        _, endpoints = self.forward_with_endpoints(x)
        return endpoints

    def load_npz(self, path: str):
        """Load every conv weight and bias from a torchvision-exported ``.npz``.

        Each array is looked up under its torchvision name
        (``features.{i}.weight`` / ``features.{i}.bias``). If the plain key is
        absent, an int8 pair ``{key}_int8`` / ``{key}_scale`` is dequantized
        as ``int8.astype(scale.dtype) * scale`` instead. Nothing is assigned
        until every array has been loaded and shape-checked, so a bad file
        leaves the model's current parameters untouched.

        Args:
            path: Filesystem path of the ``.npz`` archive.

        Raises:
            ValueError: if a key is missing entirely, or a loaded array's
                shape differs from the parameter it would replace.
            KeyError: if ``{key}_int8`` exists but ``{key}_scale`` does not.
        """
        updates = []

        # The context manager closes the archive's underlying file handle.
        with np.load(path) as data:

            def load_weight(key, transpose=False):
                """Fetch *key* (dequantizing an int8 pair if needed) as mx.array.

                With ``transpose=True`` a 4-D PyTorch conv weight laid out as
                (out, in, kh, kw) is reordered to MLX's (out, kh, kw, in).
                """
                if key in data:
                    w = data[key]
                elif f"{key}_int8" in data:
                    w_int8 = data[f"{key}_int8"]
                    scale = data[f"{key}_scale"]
                    w = w_int8.astype(scale.dtype) * scale
                else:
                    raise ValueError(f"Missing key {key} in npz")

                if transpose and w.ndim == 4:
                    w = np.transpose(w, (0, 2, 3, 1))
                return mx.array(w)

            # Convs are found structurally (not via a hard-coded index list),
            # so the loader cannot drift out of sync with __init__.
            for idx, layer in enumerate(self.layers):
                if not isinstance(layer, nn.Conv2d):
                    continue
                params = {}
                for pname, transpose in (("weight", True), ("bias", False)):
                    key = f"features.{idx}.{pname}"
                    value = load_weight(key, transpose=transpose)
                    expected = tuple(getattr(layer, pname).shape)
                    got = tuple(value.shape)
                    if got != expected:
                        raise ValueError(
                            f"Shape mismatch for {key}: npz has {got}, "
                            f"model expects {expected}"
                        )
                    params[pname] = value
                updates.append((layer, params))

        for layer, params in updates:
            layer.weight = params["weight"]
            layer.bias = params["bias"]