File size: 6,641 Bytes
0d16dd2
 
 
 
 
 
 
 
7a7f907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d16dd2
 
7a7f907
 
0d16dd2
 
 
7a7f907
0d16dd2
7a7f907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d16dd2
7a7f907
 
 
0d16dd2
7a7f907
 
 
 
 
 
 
 
 
 
 
 
0d16dd2
7a7f907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d16dd2
 
7a7f907
 
 
0d16dd2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2

# The latent dimension must match the output of the MobileNetV2 features
# before the final classifier, which is 1280.
LATENT_DIM = 1280

# --- Helper Functions for Manual MobileNetV2 Reconstruction ---

# Utility functions to build the inverted residual blocks manually
def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that 'current value' is not less than 90% of 'new_v'.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class ConvBNReLU(nn.Sequential):
    """Conv2d (no bias) -> normalization -> ReLU6, with 'same'-style padding.

    Padding is derived from the kernel size so spatial dims shrink only by
    the stride. The norm layer is injectable; BatchNorm2d is the default.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=nn.BatchNorm2d):
        pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(
            in_planes, out_planes, kernel_size, stride, pad,
            groups=groups, bias=False,
        )
        super().__init__(conv, norm_layer(out_planes), nn.ReLU6(inplace=True))

class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation gate: globally pool, bottleneck, then rescale.

    The reduce/expand convs are named ``conv_reduce``/``conv_expand`` so the
    state-dict keys line up with the checkpoint (e.g.
    'spatial_encoder.blocks.0.0.se.conv_reduce.weight').
    """

    def __init__(self, input_channels, squeeze_factor=4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(input_channels, squeeze_channels, 1, bias=True)
        self.conv_expand = nn.Conv2d(squeeze_channels, input_channels, 1, bias=True)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned attention weights."""
        scale = self.avgpool(x)
        # FIX: the original constructed fresh nn.ReLU()/nn.Sigmoid() modules on
        # every forward call; the functional forms are equivalent and avoid the
        # per-call module allocation. (Neither carries parameters, so the
        # state dict is unchanged.)
        scale = torch.relu(self.conv_reduce(scale))
        scale = torch.sigmoid(self.conv_expand(scale))
        return x * scale

class InvertedResidual(nn.Module):
    """MobileNet-style inverted residual block: expand -> depthwise -> (SE) -> project.

    Sub-modules are named to match the checkpoint keys
    (e.g. 'spatial_encoder.blocks.1.0.conv_pw.weight'): ``conv_pw`` is the
    point-wise expansion, ``conv_dw`` the depthwise conv, ``se`` the optional
    squeeze-excitation, and ``conv_pwl`` the linear projection.

    Args:
        in_chs: input channel count.
        out_chs: output channel count.
        stride: stride of the depthwise conv (1 or 2).
        expand_ratio: channel expansion factor ``t``; 1 means no expansion.
        se_layer: optional callable ``se_layer(hidden_dim)`` building an SE gate.
    """

    def __init__(self, in_chs, out_chs, stride, expand_ratio, se_layer=None):
        super().__init__()
        hidden_dim = in_chs * expand_ratio
        # Residual shortcut only when spatial size and channels are preserved.
        self.use_res_connect = stride == 1 and in_chs == out_chs
        norm_layer = nn.BatchNorm2d  # assume standard BatchNorm

        if expand_ratio != 1:
            # Point-wise expansion (conv_pw + bn1).
            # NOTE(review): like the original (`layers[:2]`), this deliberately
            # omits a ReLU6 after the expansion BN — confirm against the
            # checkpoint's training code before adding an activation here.
            self.conv_pw = nn.Sequential(
                nn.Conv2d(in_chs, hidden_dim, 1, 1, 0, bias=False),  # conv_pw
                norm_layer(hidden_dim),  # bn1
            )
        else:
            # BUG FIX: the original never assigned self.conv_pw when
            # expand_ratio == 1, so forward() raised AttributeError for t=1
            # blocks (the first stage of DeepSVDD uses t=1). Identity keeps
            # the state dict free of extra keys and preserves the data flow.
            self.conv_pw = nn.Identity()

        # Depth-wise 3x3 convolution (conv_dw + bn2).
        self.conv_dw = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),  # conv_dw
            norm_layer(hidden_dim),  # bn2
            nn.ReLU6(inplace=True),
        )

        # Optional squeeze-and-excitation gate.
        self.se = se_layer(hidden_dim) if se_layer else nn.Identity()

        # Point-wise linear projection (conv_pwl + bn3) — no activation.
        self.conv_pwl = nn.Sequential(
            nn.Conv2d(hidden_dim, out_chs, 1, 1, 0, bias=False),
            norm_layer(out_chs),  # bn3
        )

    def forward(self, x):
        """Run the block; add the identity shortcut when shapes allow it."""
        out = self.conv_pwl(self.se(self.conv_dw(self.conv_pw(x))))
        return x + out if self.use_res_connect else out

# --- MAIN DEEPSVDD CLASS USING CUSTOM MOBILENETV2 STRUCTURE ---

class DeepSVDD(nn.Module):
    """
    Deep SVDD feature extractor with a manually reconstructed MobileNetV3-like
    encoder whose sub-module names (conv_stem, blocks.X.Y.conv_pw, conv_head,
    ...) are intended to match the checkpoint's state-dict keys.

    Args:
        latent_dim: size of the output embedding; defaults to LATENT_DIM (1280),
            which is what the architecture produces. The parameter is kept for
            interface compatibility — the head width is fixed at 1280.
    """

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        norm_layer = nn.BatchNorm2d

        # Inverted residual configuration: (expand t, channels c, repeats n,
        # stride s, use squeeze-excitation).
        inverted_residual_setting = [
            # t, c, n, s, se
            [1, 16, 1, 1, False],  # Output 16x32x32
            [6, 24, 2, 2, False],  # Output 24x16x16, stride 2
            [6, 32, 3, 2, False],  # Output 32x8x8, stride 2
            [6, 64, 4, 2, True],   # Output 64x4x4, stride 2, SE included
            [6, 96, 3, 1, True],   # Output 96x4x4, SE included
            [6, 160, 3, 2, True],  # Output 160x2x2, stride 2, SE included
            [6, 320, 1, 1, True],  # Output 320x2x2, SE included
        ]

        # Stem (matches 'spatial_encoder.conv_stem.weight' / bn1).
        input_channel = 32
        self.conv_stem = nn.Conv2d(3, input_channel, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = norm_layer(input_channel)

        # Inverted residual stages (matches 'spatial_encoder.blocks...').
        blocks = nn.ModuleList()
        current_in_channels = input_channel
        for t, c, n, s, se in inverted_residual_setting:
            out_channel = _make_divisible(c * 1.0, 8)  # width multiplier 1.0
            se_layer = SqueezeExcitation if se else None
            # First block of the stage carries the stage stride; the rest use 1.
            blocks.append(InvertedResidual(current_in_channels, out_channel, s, t, se_layer))
            current_in_channels = out_channel
            for _ in range(n - 1):
                blocks.append(InvertedResidual(current_in_channels, out_channel, 1, t, se_layer))

        # Head conv before pooling (matches 'spatial_encoder.conv_head.weight').
        output_channel = 1280
        self.conv_head = nn.Conv2d(current_in_channels, output_channel, 1, 1, 0, bias=False)
        self.bn2 = norm_layer(output_channel)

        # BUG FIX: the original built a *second*, independent stem
        # (ConvBNReLU(3, input_channel)) inside spatial_encoder while
        # self.conv_stem / self.bn1 were registered but never used in forward —
        # duplicating the stem parameters and defeating the named-key loading
        # the comments promise. Reuse the named stem modules here instead.
        self.spatial_encoder = nn.Sequential(
            nn.Sequential(self.conv_stem, self.bn1, nn.ReLU6(inplace=True)),
            *blocks,
            nn.Sequential(self.conv_head, self.bn2),
        )

        # Global pooling to collapse the spatial grid into the embedding.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return a (batch, 1280) embedding for a (batch, 3, H, W) image tensor."""
        x = self.spatial_encoder(x)
        x = self.avgpool(x)
        return torch.flatten(x, 1)