"""
VGG19 in MLX with endpoints for common DeepDream layers.
Loads weights from a torchvision-exported npz (see export_vgg19_npz.py).
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
def _conv(in_ch, out_ch, kernel_size=3, padding=1):
    """Build a biased 2D convolution (defaults: 3x3 kernel, padding 1)."""
    conv_kwargs = {"kernel_size": kernel_size, "padding": padding, "bias": True}
    return nn.Conv2d(in_ch, out_ch, **conv_kwargs)
class VGG19(nn.Module):
    """VGG19 feature extractor with named ReLU endpoints for DeepDream.

    Mirrors the ``torchvision.models.vgg19().features`` sequential layout so
    that exported weight keys (``features.<idx>.weight``/``.bias``) map
    directly onto ``self.layers``. Calling the model returns a dict of
    endpoint activations; ``forward_with_endpoints`` also returns the final
    feature map.
    """

    def __init__(self):
        super().__init__()
        # Index comments match torchvision's sequential indices; load_npz
        # relies on this ordering to address conv layers by index.
        self.layers = [
            _conv(3, 64),     # 0  conv1_1
            nn.ReLU(),
            _conv(64, 64),    # 2  conv1_2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(64, 128),   # 5  conv2_1
            nn.ReLU(),
            _conv(128, 128),  # 7  conv2_2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(128, 256),  # 10 conv3_1
            nn.ReLU(),
            _conv(256, 256),  # 12 conv3_2
            nn.ReLU(),
            _conv(256, 256),  # 14 conv3_3
            nn.ReLU(),
            _conv(256, 256),  # 16 conv3_4
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(256, 512),  # 19 conv4_1
            nn.ReLU(),
            _conv(512, 512),  # 21 conv4_2
            nn.ReLU(),
            _conv(512, 512),  # 23 conv4_3
            nn.ReLU(),
            _conv(512, 512),  # 25 conv4_4
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            _conv(512, 512),  # 28 conv5_1
            nn.ReLU(),
            _conv(512, 512),  # 30 conv5_2
            nn.ReLU(),
            _conv(512, 512),  # 32 conv5_3
            nn.ReLU(),
            _conv(512, 512),  # 34 conv5_4
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        # Endpoint name -> index into self.layers (the ReLU following each
        # named conv).
        self.endpoint_indices = {
            "relu1_2": 3,
            "relu2_2": 8,
            "relu3_2": 13,
            "relu3_3": 15,
            "relu3_4": 17,
            "relu4_1": 20,
            "relu4_2": 22,
            "relu4_3": 24,
            "relu4_4": 26,
            "relu5_1": 29,
            "relu5_2": 31,
            "relu5_3": 33,
            "relu5_4": 35,
        }

    def forward_with_endpoints(self, x):
        """Run the full stack on ``x``, collecting endpoint activations.

        Args:
            x: input image batch (MLX array; NHWC layout assumed by the
               weight transpose in ``load_npz`` — TODO confirm at call site).

        Returns:
            ``(final_output, endpoints)`` where ``endpoints`` maps endpoint
            names to their activations.
        """
        # Invert the endpoint map once so per-layer lookup is O(1) instead of
        # scanning every endpoint name at every layer.
        name_by_idx = {i: name for name, i in self.endpoint_indices.items()}
        endpoints = {}
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            name = name_by_idx.get(idx)
            if name is not None:
                endpoints[name] = x
        return x, endpoints

    def __call__(self, x):
        """Return only the endpoint activations for input ``x``."""
        _, endpoints = self.forward_with_endpoints(x)
        return endpoints

    def load_npz(self, path: str):
        """Load conv weights and biases from a torchvision-exported npz.

        Supports plain float arrays as well as int8-quantized pairs
        (``<key>_int8`` + ``<key>_scale``), which are dequantized on load.
        4-D conv weights are transposed from torch's (O, I, H, W) layout to
        MLX's (O, H, W, I).

        Raises:
            ValueError: if a required key is missing from the npz.
        """
        # NpzFile keeps the underlying zip file handle open; the context
        # manager guarantees it is closed (the original leaked it).
        with np.load(path) as data:

            def load_weight(key, transpose=False):
                if key in data:
                    w = data[key]
                elif f"{key}_int8" in data:
                    # Dequantize: int8 values times the stored scale.
                    w_int8 = data[f"{key}_int8"]
                    scale = data[f"{key}_scale"]
                    w = w_int8.astype(scale.dtype) * scale
                else:
                    raise ValueError(f"Missing key {key} in npz")
                if transpose and w.ndim == 4:
                    # torch (O, I, H, W) -> MLX (O, H, W, I)
                    w = np.transpose(w, (0, 2, 3, 1))
                return mx.array(w)

            # Sequential indices of the conv layers in vgg19.features.
            conv_indices = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
            for idx in conv_indices:
                conv = self.layers[idx]
                conv.weight = load_weight(f"features.{idx}.weight", transpose=True)
                conv.bias = load_weight(f"features.{idx}.bias")