# Source: Hugging Face model repo (user ash12321), file "model.py", commit 7a7f907 (verified).
# The three web-page artifact lines from the repo viewer were converted into this
# comment so the module is valid Python.
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
# The latent dimension must match the output of the MobileNetV2 features
# before the final classifier, which is 1280.
LATENT_DIM = 1280
# --- Helper Functions for Manual MobileNetV2 Reconstruction ---
# Utility functions to build the inverted residual blocks manually
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that 'current value' is not less than 90% of 'new_v'.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
    """Conv2d (no bias) -> normalization -> ReLU6, as one sequential unit.

    Padding is ``(kernel_size - 1) // 2``, i.e. "same" spatial size for odd
    kernels at stride 1.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1,
                 groups=1, norm_layer=nn.BatchNorm2d):
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,  # "same" padding for odd kernels
            groups=groups,
            bias=False,  # bias is redundant before the norm layer
        )
        super().__init__(conv, norm_layer(out_planes), nn.ReLU6(inplace=True))
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Globally average-pools the input, passes it through a reduce/expand pair
    of 1x1 convolutions, and rescales the input channel-wise by the resulting
    sigmoid gate.

    Args:
        input_channels: number of channels of the input feature map.
        squeeze_factor: reduction ratio for the bottleneck (default 4); the
            squeezed width is rounded to a multiple of 8 via _make_divisible.
    """

    def __init__(self, input_channels, squeeze_factor=4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # These keys match the checkpoint: 'spatial_encoder.blocks.0.0.se.conv_reduce.weight', etc.
        self.conv_reduce = nn.Conv2d(input_channels, squeeze_channels, 1, bias=True)
        self.conv_expand = nn.Conv2d(squeeze_channels, input_channels, 1, bias=True)

    def forward(self, x):
        # FIX: use functional activations instead of constructing new
        # nn.ReLU / nn.Sigmoid module objects on every forward call.
        # The math is identical; only the per-call allocations are gone.
        scale = self.avgpool(x)
        scale = torch.relu(self.conv_reduce(scale))
        scale = torch.sigmoid(self.conv_expand(scale))
        return x * scale
class InvertedResidual(nn.Module):
    """MobileNetV2-style inverted residual block (expand -> depthwise -> project).

    Sub-modules are named to match the checkpoint keys ('conv_pw', 'conv_dw',
    'se', 'conv_pwl'), e.g. 'spatial_encoder.blocks.1.0.conv_pw.weight'.

    Args:
        in_chs: input channel count.
        out_chs: output channel count.
        stride: stride of the depthwise 3x3 convolution (1 or 2).
        expand_ratio: hidden width multiplier; 1 skips the expansion stage.
        se_layer: optional callable (e.g. SqueezeExcitation) applied to the
            depthwise output; None inserts an identity.
    """

    def __init__(self, in_chs, out_chs, stride, expand_ratio, se_layer=None):
        super().__init__()
        hidden_dim = in_chs * expand_ratio
        # Residual shortcut only when spatial size and channel count are preserved.
        self.use_res_connect = stride == 1 and in_chs == out_chs
        norm_layer = nn.BatchNorm2d  # Assume standard BatchNorm

        if expand_ratio != 1:
            # Point-wise expansion.
            # BUGFIX: the original built a ReLU6 for this stage but then sliced
            # it off with `layers[:2]`, so the expansion ran with no
            # non-linearity. It is restored here; ReLU6 has no parameters, so
            # checkpoint state_dict keys ('conv_pw.0...', 'conv_pw.1...') are
            # unchanged.
            self.conv_pw = nn.Sequential(
                nn.Conv2d(in_chs, hidden_dim, 1, 1, 0, bias=False),  # conv_pw
                norm_layer(hidden_dim),                              # bn1
                nn.ReLU6(inplace=True),
            )
        else:
            # No expansion: empty Sequential acts as an identity, as before.
            self.conv_pw = nn.Sequential()

        # Depth-wise 3x3 convolution (groups == channels).
        self.conv_dw = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),  # conv_dw
            norm_layer(hidden_dim),  # bn2
            nn.ReLU6(inplace=True),
        )
        # Optional Squeeze-and-Excitation on the depthwise output.
        self.se = se_layer(hidden_dim) if se_layer else nn.Identity()
        # Point-wise linear projection (no activation, per MobileNetV2).
        self.conv_pwl = nn.Sequential(
            nn.Conv2d(hidden_dim, out_chs, 1, 1, 0, bias=False),
            norm_layer(out_chs),  # bn3
        )

    def forward(self, x):
        out = self.conv_pwl(self.se(self.conv_dw(self.conv_pw(x))))
        # Residual connection only when shapes allow it.
        return x + out if self.use_res_connect else out
# --- MAIN DEEPSVDD CLASS USING CUSTOM MOBILELNETV2 STRUCTURE ---
class DeepSVDD(nn.Module):
    """Deep SVDD feature extractor built from a manually reconstructed
    MobileNetV2/V3-like backbone.

    The forward pass runs the input through ``spatial_encoder`` (stem,
    inverted residual blocks, 1x1 head conv + norm), global-average-pools
    to 1x1, and returns the flattened (batch, 1280) embedding.

    NOTE(review): several aspects look inconsistent with the stated goal of
    matching named checkpoint keys — flagged inline below; confirm against
    the actual checkpoint before relying on them.
    """

    def __init__(self, latent_dim=LATENT_DIM):
        # NOTE(review): `latent_dim` is accepted but never used; the head
        # width is hard-coded to 1280 below. Confirm whether callers pass
        # other values.
        super().__init__()
        norm_layer = nn.BatchNorm2d
        # MobileNetV2/V3 Configuration based on standard feature maps (inverted residual blocks)
        # Columns: t = expansion ratio, c = output channels, n = repeats,
        # s = stride of the first block in the group, se = use Squeeze-Excitation.
        inverted_residual_setting = [
            # t, c, n, s, se
            [1, 16, 1, 1, False],   # Output 16x32x32
            [6, 24, 2, 2, False],   # Output 24x16x16, stride 2
            [6, 32, 3, 2, False],   # Output 32x8x8, stride 2
            [6, 64, 4, 2, True],    # Output 64x4x4, stride 2, SE included
            [6, 96, 3, 1, True],    # Output 96x4x4, SE included
            [6, 160, 3, 2, True],   # Output 160x2x2, stride 2, SE included
            [6, 320, 1, 1, True],   # Output 320x2x2, SE included
        ]
        # First layer (Matches 'spatial_encoder.conv_stem.weight')
        # NOTE(review): `self.conv_stem`/`self.bn1` are registered as direct
        # attributes (state_dict keys 'conv_stem.*', 'bn1.*') but are NOT used
        # in forward() — the stem actually executed is the separate ConvBNReLU
        # at spatial_encoder[0], which has its own independent weights.
        # This duplicates parameters and likely means one of the two stems
        # never receives checkpoint weights. TODO: confirm which one the
        # checkpoint populates.
        input_channel = 32
        self.conv_stem = nn.Conv2d(3, input_channel, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = norm_layer(input_channel)
        # Inverted Residual Blocks (Matches 'spatial_encoder.blocks...')
        blocks = nn.ModuleList()
        current_in_channels = input_channel
        for t, c, n, s, se in inverted_residual_setting:
            out_channel = _make_divisible(c * 1.0, 8)  # Assume width multiplier 1.0
            se_layer = SqueezeExcitation if se else None
            # First block in sequence can have stride > 1
            blocks.append(InvertedResidual(current_in_channels, out_channel, s, t, se_layer))
            current_in_channels = out_channel
            # Remaining n-1 blocks have stride 1
            for i in range(n - 1):
                blocks.append(InvertedResidual(current_in_channels, out_channel, 1, t, se_layer))
                current_in_channels = out_channel
        # Final Convolution before pooling (Matches 'spatial_encoder.conv_head.weight')
        output_channel = 1280
        self.conv_head = nn.Conv2d(current_in_channels, output_channel, 1, 1, 0, bias=False)
        self.bn2 = norm_layer(output_channel)
        # Combine all parts into the spatial_encoder sequential module
        # NOTE(review): nn.Sequential registers children under NUMERIC keys
        # ('spatial_encoder.0.*', 'spatial_encoder.1.*', ...), not the named
        # keys ('spatial_encoder.conv_stem.*', 'spatial_encoder.blocks.X.*')
        # that the comments in this file claim. Also, conv_head/bn2 appear
        # both as top-level attributes and inside this Sequential, producing
        # duplicate state_dict entries for the same tensors. TODO: verify
        # state_dict() keys against the checkpoint.
        self.spatial_encoder = nn.Sequential(
            ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer),  # conv_stem/bn1
            *blocks,
            nn.Sequential(
                self.conv_head,
                self.bn2
            )
        )
        # Final layers for SVDD
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # The sequential container has internal numeric indexing (0, 1, 2...)
        # but its internal components have the named keys (conv_stem, blocks...)
        # that match the checkpoint.
        # NOTE(review): no activation is applied after the head conv's norm
        # (standard MobileNet applies one) — presumably intentional for the
        # SVDD embedding; confirm against the training code.
        x = self.spatial_encoder(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # -> (batch, 1280) embedding
        return x