|
|
""" |
|
|
VGG16 in MLX with endpoints for relu1_2, relu2_2, relu3_3, relu4_2, relu4_3, |
|
|
relu5_2, relu5_3. Loads weights from a torchvision-exported npz |
|
|
(see export_vgg16_npz.py). |
|
|
""" |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx.nn as nn |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def _conv(in_ch, out_ch, kernel_size=3, padding=1): |
|
|
return nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, bias=True) |
|
|
|
|
|
|
|
|
class VGG16(nn.Module): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.layers = [ |
|
|
_conv(3, 64), |
|
|
nn.ReLU(), |
|
|
_conv(64, 64), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(kernel_size=2, stride=2), |
|
|
_conv(64, 128), |
|
|
nn.ReLU(), |
|
|
_conv(128, 128), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(kernel_size=2, stride=2), |
|
|
_conv(128, 256), |
|
|
nn.ReLU(), |
|
|
_conv(256, 256), |
|
|
nn.ReLU(), |
|
|
_conv(256, 256), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(kernel_size=2, stride=2), |
|
|
_conv(256, 512), |
|
|
nn.ReLU(), |
|
|
_conv(512, 512), |
|
|
nn.ReLU(), |
|
|
_conv(512, 512), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(kernel_size=2, stride=2), |
|
|
_conv(512, 512), |
|
|
nn.ReLU(), |
|
|
_conv(512, 512), |
|
|
nn.ReLU(), |
|
|
_conv(512, 512), |
|
|
nn.ReLU(), |
|
|
nn.MaxPool2d(kernel_size=2, stride=2), |
|
|
] |
|
|
|
|
|
|
|
|
self.endpoint_indices = { |
|
|
"relu1_2": 3, |
|
|
"relu2_2": 8, |
|
|
"relu3_3": 15, |
|
|
"relu4_1": 18, |
|
|
"relu4_2": 20, |
|
|
"relu4_3": 22, |
|
|
"relu5_1": 25, |
|
|
"relu5_2": 27, |
|
|
"relu5_3": 29, |
|
|
} |
|
|
|
|
|
def forward_with_endpoints(self, x): |
|
|
endpoints = {} |
|
|
for idx, layer in enumerate(self.layers): |
|
|
x = layer(x) |
|
|
for name, i in self.endpoint_indices.items(): |
|
|
if idx == i: |
|
|
endpoints[name] = x |
|
|
return x, endpoints |
|
|
|
|
|
def __call__(self, x): |
|
|
_, endpoints = self.forward_with_endpoints(x) |
|
|
return endpoints |
|
|
|
|
|
def load_npz(self, path: str): |
|
|
data = np.load(path) |
|
|
|
|
|
def load_weight(key, transpose=False): |
|
|
if key in data: |
|
|
w = data[key] |
|
|
elif f"{key}_int8" in data: |
|
|
w_int8 = data[f"{key}_int8"] |
|
|
scale = data[f"{key}_scale"] |
|
|
w = w_int8.astype(scale.dtype) * scale |
|
|
else: |
|
|
raise ValueError(f"Missing key {key} in npz") |
|
|
|
|
|
if transpose and w.ndim == 4: |
|
|
w = np.transpose(w, (0, 2, 3, 1)) |
|
|
return mx.array(w) |
|
|
|
|
|
conv_indices = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28] |
|
|
for idx in conv_indices: |
|
|
conv = self.layers[idx] |
|
|
weight_key = f"features.{idx}.weight" |
|
|
bias_key = f"features.{idx}.bias" |
|
|
|
|
|
conv.weight = load_weight(weight_key, transpose=True) |
|
|
conv.bias = load_weight(bias_key) |
|
|
|