| """ |
| Minimal GoogLeNet (Inception V1) in MLX, up to and including inception5b (no classifier head). |
| Loads weights from a torchvision-exported npz (see export_googlenet_npz.py). |
| """ |
|
|
| import mlx.core as mx |
| import mlx.nn as nn |
| import numpy as np |
|
|
|
|
def _conv_bn(in_ch, out_ch, kernel_size, stride=1, padding=0):
    """Build a Conv2d -> BatchNorm -> ReLU stack as an ``nn.Sequential``.

    The convolution is created without a bias term; the batch-norm layer
    uses ``eps=1e-3`` and ``momentum=0.1``.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels (also the batch-norm feature count).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.

    Returns:
        An ``nn.Sequential`` whose layers are, in order, the convolution,
        the batch norm and the ReLU.
    """
    convolution = nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    normalization = nn.BatchNorm(out_ch, eps=1e-3, momentum=0.1)
    activation = nn.ReLU()
    return nn.Sequential(convolution, normalization, activation)
|
|
|
|
class Inception(nn.Module):
    """Four-branch Inception block whose branch outputs are concatenated.

    Branches, in concatenation order:
        1. 1x1 conv
        2. 1x1 reduction followed by a 3x3 conv
        3. 1x1 reduction followed by a 3x3 conv
        4. 3x3 max pool (stride 1, padding 1) followed by a 1x1 projection

    Args:
        in_ch: Channels of the incoming feature map.
        ch1: Output channels of branch 1.
        ch3r: Reduction channels of branch 2.
        ch3: Output channels of branch 2.
        ch5r: Reduction channels of branch 3.
        ch5: Output channels of branch 3.
        pool_proj: Output channels of the pooling branch.
    """

    def __init__(self, in_ch, ch1, ch3r, ch3, ch5r, ch5, pool_proj):
        super().__init__()

        # Branch 1: plain 1x1 convolution.
        self.branch1 = _conv_bn(in_ch, ch1, kernel_size=1)

        # Branch 2: 1x1 reduction, then 3x3 convolution.
        self.branch2_1 = _conv_bn(in_ch, ch3r, kernel_size=1)
        self.branch2_2 = _conv_bn(ch3r, ch3, kernel_size=3, padding=1)

        # Branch 3: 1x1 reduction, then another 3x3 convolution.
        # NOTE(review): the "ch5" naming suggests a 5x5 kernel, but a 3x3 one
        # is used -- presumably to match the torchvision-exported weights
        # this file loads; confirm before changing.
        self.branch3_1 = _conv_bn(in_ch, ch5r, kernel_size=1)
        self.branch3_2 = _conv_bn(ch5r, ch5, kernel_size=3, padding=1)

        # Branch 4: size-preserving max pool, then 1x1 projection.
        self.branch4_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.branch4_2 = _conv_bn(in_ch, pool_proj, kernel_size=1)

    def __call__(self, x):
        """Run all four branches on ``x`` and join them along the last axis."""
        branch_outputs = [
            self.branch1(x),
            self.branch2_2(self.branch2_1(x)),
            self.branch3_2(self.branch3_1(x)),
            self.branch4_2(self.branch4_pool(x)),
        ]
        return mx.concatenate(branch_outputs, axis=-1)
|
|
|
|
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception V1) feature extractor: stem plus inception3a-5b.

    The model has no classifier head. Calling it returns a dictionary with
    the output of every Inception block, keyed by the block's name.

    NOTE(review): torchvision's reference GoogLeNet uses ``ceil_mode=True``
    max pools and a 2x2 ``maxpool4``; the pools below are 3x3 without ceil
    mode, so spatial sizes may differ from the exporting model -- confirm
    this is intended.
    """

    def __init__(self):
        super().__init__()

        # Stem.
        self.conv1 = _conv_bn(3, 64, 7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        self.conv2 = _conv_bn(64, 64, 1)
        self.conv3 = _conv_bn(64, 192, 3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        # Stage 3 (outputs 256 and 480 channels).
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        # Stage 4 (outputs 512, 512, 512, 528 and 832 channels).
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 5 (outputs 832 and 1024 channels).
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

    def forward_with_endpoints(self, x):
        """Run the network and collect every Inception block's output.

        Args:
            x: Input batch; ``conv1`` expects 3 input channels.
               (Assumes MLX's channels-last layout -- TODO confirm at caller.)

        Returns:
            A ``(final, endpoints)`` tuple where ``final`` is the output of
            ``inception5b`` and ``endpoints`` maps each Inception block name
            (``"inception3a"`` ... ``"inception5b"``) to its output, in
            execution order.
        """
        endpoints = {}
        x = self.conv1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        endpoints["inception3a"] = x
        x = self.inception3b(x)
        endpoints["inception3b"] = x
        x = self.maxpool3(x)

        x = self.inception4a(x)
        endpoints["inception4a"] = x
        x = self.inception4b(x)
        endpoints["inception4b"] = x
        x = self.inception4c(x)
        endpoints["inception4c"] = x
        x = self.inception4d(x)
        endpoints["inception4d"] = x
        x = self.inception4e(x)
        endpoints["inception4e"] = x
        x = self.maxpool4(x)

        x = self.inception5a(x)
        endpoints["inception5a"] = x
        x = self.inception5b(x)
        endpoints["inception5b"] = x
        return x, endpoints

    def __call__(self, x):
        """Return only the endpoints dictionary of ``forward_with_endpoints``."""
        _, endpoints = self.forward_with_endpoints(x)
        return endpoints

    def load_npz(self, path: str):
        """Load conv and batch-norm parameters from an ``.npz`` archive.

        For every conv/BN pair the archive must provide
        ``<prefix>.conv.weight``, ``<prefix>.bn.weight``, ``<prefix>.bn.bias``,
        ``<prefix>.bn.running_mean`` and ``<prefix>.bn.running_var``. Any of
        these may instead be stored quantized as ``<key>_int8`` together with
        ``<key>_scale``; it is then dequantized as ``int8 * scale``.

        Four-dimensional conv weights are transposed from (O, I, H, W) to
        (O, H, W, I) before being assigned.

        NOTE(review): this method does not switch the module to eval mode;
        callers presumably need ``model.eval()`` for BatchNorm to use the
        loaded running statistics -- confirm at the call site.

        Args:
            path: Path of the ``.npz`` file to read.

        Raises:
            ValueError: If neither ``<key>`` nor ``<key>_int8`` exists in the
                archive for a required parameter.
        """
        # The context manager closes the archive's file handle even when a
        # key is missing; mx.array() copies the data, so nothing references
        # the archive once the block exits.
        with np.load(path) as archive:

            def load_weight(key, target_module, param_name="weight", transpose=False):
                if key in archive:
                    w = archive[key]
                elif f"{key}_int8" in archive:
                    w_int8 = archive[f"{key}_int8"]
                    scale = archive[f"{key}_scale"]
                    # Dequantize in the scale's dtype.
                    w = w_int8.astype(scale.dtype) * scale
                else:
                    raise ValueError(f"Missing key {key} (or {key}_int8) in npz")

                # Move the input-channel axis last for 4-D conv kernels.
                if transpose and w.ndim == 4:
                    w = np.transpose(w, (0, 2, 3, 1))

                target_module[param_name] = mx.array(w)

            def load_conv_bn(prefix, seq_mod: nn.Sequential):
                # _conv_bn builds [Conv2d, BatchNorm, ReLU]; ReLU has no params.
                conv = seq_mod.layers[0]
                bn = seq_mod.layers[1]

                load_weight(f"{prefix}.conv.weight", conv, transpose=True)

                load_weight(f"{prefix}.bn.weight", bn)
                load_weight(f"{prefix}.bn.bias", bn, param_name="bias")
                load_weight(f"{prefix}.bn.running_mean", bn, param_name="running_mean")
                load_weight(f"{prefix}.bn.running_var", bn, param_name="running_var")

            def load_inception(prefix, module: Inception):
                load_conv_bn(f"{prefix}.branch1", module.branch1)
                load_conv_bn(f"{prefix}.branch2.0", module.branch2_1)
                load_conv_bn(f"{prefix}.branch2.1", module.branch2_2)
                load_conv_bn(f"{prefix}.branch3.0", module.branch3_1)
                load_conv_bn(f"{prefix}.branch3.1", module.branch3_2)
                # Index 0 of branch4 in the archive naming is skipped; only
                # the projection conv (index 1) is loaded here.
                load_conv_bn(f"{prefix}.branch4.1", module.branch4_2)

            load_conv_bn("conv1", self.conv1)
            load_conv_bn("conv2", self.conv2)
            load_conv_bn("conv3", self.conv3)

            load_inception("inception3a", self.inception3a)
            load_inception("inception3b", self.inception3b)
            load_inception("inception4a", self.inception4a)
            load_inception("inception4b", self.inception4b)
            load_inception("inception4c", self.inception4c)
            load_inception("inception4d", self.inception4d)
            load_inception("inception4e", self.inception4e)
            load_inception("inception5a", self.inception5a)
            load_inception("inception5b", self.inception5b)
|
|