import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
# Latent dimension of the encoder output. MobileNetV2's feature width after
# conv_head (before the classifier) is 1280; forward() global-pools and
# flattens to a (batch, 1280) vector, so this must stay in sync with conv_head.
LATENT_DIM = 1280
# --- Helper Functions for Manual MobileNetV2 Reconstruction ---
# Utility functions to build the inverted residual blocks manually
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that 'current value' is not less than 90% of 'new_v'.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
    """Conv2d -> normalization -> ReLU6 stack with 'same'-style padding
    for odd kernel sizes."""

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=nn.BatchNorm2d):
        same_pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            same_pad,
            groups=groups,
            bias=False,  # norm layer supplies the bias
        )
        super().__init__(conv, norm_layer(out_planes), nn.ReLU6(inplace=True))
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation gate: global-pool, bottleneck, then rescale channels.

    Parameter names ('conv_reduce', 'conv_expand') deliberately mirror the
    checkpoint keys, e.g. 'spatial_encoder.blocks.0.0.se.conv_reduce.weight'.
    """

    def __init__(self, input_channels, squeeze_factor=4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(input_channels, squeeze_channels, 1, bias=True)
        self.conv_expand = nn.Conv2d(squeeze_channels, input_channels, 1, bias=True)
        # Parameter-free activations held as modules (no state-dict impact).
        # The original rebuilt nn.ReLU / nn.Sigmoid objects on every forward
        # call, allocating a throwaway module per invocation.
        self.act = nn.ReLU(inplace=True)
        self.gate = nn.Sigmoid()

    def forward(self, x):
        # Squeeze to (B, C, 1, 1), excite through the bottleneck, gate to (0, 1).
        scale = self.avgpool(x)
        scale = self.act(self.conv_reduce(scale))
        scale = self.gate(self.conv_expand(scale))
        return x * scale
class InvertedResidual(nn.Module):
    """MobileNetV2-style inverted residual block: expand -> depthwise -> SE -> project.

    Submodule names ('conv_pw', 'conv_dw', 'conv_pwl') mirror the checkpoint
    keys, e.g. 'spatial_encoder.blocks.1.0.conv_pw.weight'.
    """

    def __init__(self, in_chs, out_chs, stride, expand_ratio, se_layer=None):
        super().__init__()
        hidden_dim = in_chs * expand_ratio
        # The skip connection is only valid when spatial size and channel
        # count are both preserved.
        self.use_res_connect = stride == 1 and in_chs == out_chs
        norm_layer = nn.BatchNorm2d  # assume standard BatchNorm

        if expand_ratio != 1:
            # Point-wise expansion. BUG FIX: the original built
            # [conv, bn, ReLU6] but kept only layers[:2], silently dropping
            # the activation and leaving the expansion linear. ReLU6 carries
            # no parameters, so restoring it does not change checkpoint keys.
            self.conv_pw = nn.Sequential(
                nn.Conv2d(in_chs, hidden_dim, 1, 1, 0, bias=False),  # conv_pw
                norm_layer(hidden_dim),                              # bn1
                nn.ReLU6(inplace=True),
            )
        else:
            # No expansion at ratio 1: empty Sequential acts as identity,
            # matching the original behavior.
            self.conv_pw = nn.Sequential()

        # Depth-wise 3x3 convolution (groups == channels).
        self.conv_dw = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),  # conv_dw
            norm_layer(hidden_dim),  # bn2
            nn.ReLU6(inplace=True),
        )
        # Optional squeeze-and-excitation gate.
        self.se = se_layer(hidden_dim) if se_layer else nn.Identity()
        # Point-wise linear projection — no activation, per MobileNetV2.
        self.conv_pwl = nn.Sequential(
            nn.Conv2d(hidden_dim, out_chs, 1, 1, 0, bias=False),
            norm_layer(out_chs),  # bn3
        )

    def forward(self, x):
        out = self.conv_pwl(self.se(self.conv_dw(self.conv_pw(x))))
        # Residual connection only when shapes line up (see __init__).
        return x + out if self.use_res_connect else out
# --- MAIN DEEPSVDD CLASS USING CUSTOM MOBILENETV2 STRUCTURE ---
class DeepSVDD(nn.Module):
    """Deep SVDD encoder with a manually reconstructed MobileNetV2/V3-like
    backbone whose submodule names (conv_stem, bn1, blocks.X.Y.conv_pw, ...)
    mirror the checkpoint keys under the 'spatial_encoder.' prefix.

    forward() returns a (batch, 1280) latent vector: backbone features,
    globally average-pooled and flattened.
    """

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        norm_layer = nn.BatchNorm2d

        # Inverted-residual configuration: t = expansion ratio, c = output
        # channels, n = block repeats, s = stride of the first repeat,
        # se = whether the stage uses squeeze-and-excitation.
        inverted_residual_setting = [
            [1, 16, 1, 1, False],
            [6, 24, 2, 2, False],
            [6, 32, 3, 2, False],
            [6, 64, 4, 2, True],
            [6, 96, 3, 1, True],
            [6, 160, 3, 2, True],
            [6, 320, 1, 1, True],
        ]

        # Stem (matches 'spatial_encoder.conv_stem.weight' / '...bn1.*').
        input_channel = 32
        self.conv_stem = nn.Conv2d(3, input_channel, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = norm_layer(input_channel)

        # Inverted residual blocks (match 'spatial_encoder.blocks...').
        blocks = nn.ModuleList()
        current_in_channels = input_channel
        for t, c, n, s, se in inverted_residual_setting:
            out_channel = _make_divisible(c * 1.0, 8)  # width multiplier 1.0
            se_layer = SqueezeExcitation if se else None
            for block_idx in range(n):
                # Only the first repeat in a stage may downsample.
                stride = s if block_idx == 0 else 1
                blocks.append(InvertedResidual(current_in_channels, out_channel, stride, t, se_layer))
                current_in_channels = out_channel

        # Head convolution (matches 'spatial_encoder.conv_head.weight').
        output_channel = 1280
        self.conv_head = nn.Conv2d(current_in_channels, output_channel, 1, 1, 0, bias=False)
        self.bn2 = norm_layer(output_channel)

        # BUG FIX: the original put a *separate*, freshly initialised
        # ConvBNReLU stem inside spatial_encoder while self.conv_stem and
        # self.bn1 were registered but never reached by forward() — so
        # checkpoint weights loaded into conv_stem/bn1 never affected the
        # output. Reuse the named stem modules here (ReLU6 matches the
        # activation the ConvBNReLU stem applied).
        self.spatial_encoder = nn.Sequential(
            nn.Sequential(self.conv_stem, self.bn1, nn.ReLU6(inplace=True)),
            *blocks,
            nn.Sequential(self.conv_head, self.bn2),
        )

        # Global pooling reduces the feature map to the latent vector.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # (B, 3, H, W) -> (B, 1280, h, w) -> (B, 1280)
        x = self.spatial_encoder(x)
        x = self.avgpool(x)
        return torch.flatten(x, 1)