import torch import torch.nn as nn from torchvision.models import mobilenet_v2 # The latent dimension must match the output of the MobileNetV2 features # before the final classifier, which is 1280. LATENT_DIM = 1280 # --- Helper Functions for Manual MobileNetV2 Reconstruction --- # Utility functions to build the inverted residual blocks manually def _make_divisible(v, divisor, min_value=None): if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that 'current value' is not less than 90% of 'new_v'. if new_v < 0.9 * v: new_v += divisor return new_v class ConvBNReLU(nn.Sequential): def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=nn.BatchNorm2d): padding = (kernel_size - 1) // 2 super().__init__( nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), norm_layer(out_planes), nn.ReLU6(inplace=True) ) class SqueezeExcitation(nn.Module): def __init__(self, input_channels, squeeze_factor=4): super().__init__() squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) self.avgpool = nn.AdaptiveAvgPool2d(1) # These keys match the checkpoint: 'spatial_encoder.blocks.0.0.se.conv_reduce.weight', etc. self.conv_reduce = nn.Conv2d(input_channels, squeeze_channels, 1, bias=True) self.conv_expand = nn.Conv2d(squeeze_channels, input_channels, 1, bias=True) def forward(self, x): scale = self.avgpool(x) scale = self.conv_reduce(scale) scale = nn.ReLU(inplace=True)(scale) scale = self.conv_expand(scale) scale = nn.Sigmoid()(scale) return x * scale class InvertedResidual(nn.Module): def __init__(self, in_chs, out_chs, stride, expand_ratio, se_layer=None): super().__init__() hidden_dim = in_chs * expand_ratio self.use_res_connect = stride == 1 and in_chs == out_chs norm_layer = nn.BatchNorm2d # Assume standard BatchNorm # Blocks are internally labeled to match the checkpoint keys: 'conv_pw', 'bn1', etc. # Checkpoint key example: 'spatial_encoder.blocks.1.0.conv_pw.weight' layers = [] if expand_ratio != 1: # Point-wise expansion layers.extend([ nn.Conv2d(in_chs, hidden_dim, 1, 1, 0, bias=False), # conv_pw norm_layer(hidden_dim), # bn1 nn.ReLU6(inplace=True), ]) self.conv_pw = nn.Sequential(*layers[:2]) # conv_pw and bn1 # Depth-wise convolution self.conv_dw = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), # conv_dw norm_layer(hidden_dim), # bn2 nn.ReLU6(inplace=True) ) # Squeeze-and-Excitation self.se = se_layer(hidden_dim) if se_layer else nn.Identity() # Point-wise linear projection self.conv_pwl = nn.Sequential( nn.Conv2d(hidden_dim, out_chs, 1, 1, 0, bias=False), norm_layer(out_chs) # bn3 ) def forward(self, x): if self.use_res_connect: # Residual connection return x + self.conv_pwl(self.se(self.conv_dw(self.conv_pw(x)))) else: return self.conv_pwl(self.se(self.conv_dw(self.conv_pw(x)))) # --- MAIN DEEPSVDD CLASS USING CUSTOM MOBILELNETV2 STRUCTURE --- class DeepSVDD(nn.Module): """ Deep SVDD model with manually reconstructed MobileNetV3-like structure to match the checkpoint's layer names (conv_stem, blocks.X.Y.conv_pw, etc.). """ def __init__(self, latent_dim=LATENT_DIM): super().__init__() norm_layer = nn.BatchNorm2d # MobileNetV2/V3 Configuration based on standard feature maps (inverted residual blocks) inverted_residual_setting = [ # t, c, n, s, se [1, 16, 1, 1, False], # Output 16x32x32 [6, 24, 2, 2, False], # Output 24x16x16, stride 2 [6, 32, 3, 2, False], # Output 32x8x8, stride 2 [6, 64, 4, 2, True], # Output 64x4x4, stride 2, SE included [6, 96, 3, 1, True], # Output 96x4x4, SE included [6, 160, 3, 2, True], # Output 160x2x2, stride 2, SE included [6, 320, 1, 1, True], # Output 320x2x2, SE included ] # First layer (Matches 'spatial_encoder.conv_stem.weight') input_channel = 32 self.conv_stem = nn.Conv2d(3, input_channel, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = norm_layer(input_channel) # Inverted Residual Blocks (Matches 'spatial_encoder.blocks...') blocks = nn.ModuleList() current_in_channels = input_channel for t, c, n, s, se in inverted_residual_setting: out_channel = _make_divisible(c * 1.0, 8) # Assume width multiplier 1.0 se_layer = SqueezeExcitation if se else None # First block in sequence can have stride > 1 blocks.append(InvertedResidual(current_in_channels, out_channel, s, t, se_layer)) current_in_channels = out_channel # Remaining n-1 blocks have stride 1 for i in range(n - 1): blocks.append(InvertedResidual(current_in_channels, out_channel, 1, t, se_layer)) current_in_channels = out_channel # Final Convolution before pooling (Matches 'spatial_encoder.conv_head.weight') output_channel = 1280 self.conv_head = nn.Conv2d(current_in_channels, output_channel, 1, 1, 0, bias=False) self.bn2 = norm_layer(output_channel) # Combine all parts into the spatial_encoder sequential module self.spatial_encoder = nn.Sequential( ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer), # conv_stem/bn1 *blocks, nn.Sequential( self.conv_head, self.bn2 ) ) # Final layers for SVDD self.avgpool = nn.AdaptiveAvgPool2d(1) def forward(self, x): # The sequential container has internal numeric indexing (0, 1, 2...) # but its internal components have the named keys (conv_stem, blocks...) # that match the checkpoint. x = self.spatial_encoder(x) x = self.avgpool(x) x = torch.flatten(x, 1) return x