File size: 10,250 Bytes
0ba6002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61ba70
 
 
 
 
 
0ba6002
 
c61ba70
0ba6002
 
 
 
 
c61ba70
 
 
 
 
 
0ba6002
 
 
 
 
 
c61ba70
 
 
 
 
 
 
 
0ba6002
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""
CardAuthModel - Multi-Purpose Deep Learning Model.

Three head groups:
- Head A (PokemonClassifierHead): Pokemon vs Non-Pokemon (ResNet50)
- Head B (BackAuthHead): Genuine vs counterfeit back (ResNet50)
- Head C (EmbeddingHead x6): Deep SVDD for front anomaly detection

SVDD centers stored as register_buffer (saved in state_dict, not gradient-trained).
"""

import torch
import torch.nn as nn
from typing import Dict, Optional, List

from .backbone import ResNet50Backbone
from .efficientnet import EfficientNetB7Backbone
from .heads import (
    PokemonClassifierHead,
    BackAuthHead,
    EmbeddingHead,
    SVDD_HEAD_CONFIG,
    create_svdd_heads,
    get_head_weights,
)
from ..utils.logger import get_logger
from ..utils.config import config

logger = get_logger(__name__)


class CardAuthModel(nn.Module):
    """
    Multi-purpose deep learning model for card authentication.

    Architecture:
        Input image (B, 3, 224, 224)
            |-- ResNet50 -> 2048-dim
            |       |-- pokemon_head (Head A) -> P(pokemon)
            |       |-- back_auth_head (Head B) -> P(genuine_back)
            |       |-- SVDD heads (Head C): primary, edge_inspector,
            |       |       texture, hologram, historical -> 128-dim each
            |-- EfficientNet-B7 -> 2560-dim
                    |-- SVDD head: print_quality -> 128-dim

        SVDD output: weighted 1/(1+dist) scores

    SVDD centers are stored as registered buffers, so they travel with
    ``state_dict()`` / ``.to(device)`` but receive no gradient updates.
    """

    def __init__(
        self,
        pretrained: bool = True,
        freeze_early: bool = True,
        head_weights: Optional[Dict[str, float]] = None,
        embed_dim: Optional[int] = None,
    ):
        """
        Build backbones, classification heads, and SVDD embedding heads.

        Args:
            pretrained: Load pretrained weights into both backbones.
            freeze_early: Freeze early backbone layers (passed through to
                both backbone wrappers).
            head_weights: Per-head weights for the final SVDD prediction.
                Defaults to ``get_head_weights()``.
            embed_dim: SVDD embedding dimensionality. Defaults to
                ``config.DL_SVDD_EMBEDDING_DIM``.
        """
        super().__init__()

        # Fall back to the project-wide configured embedding dimension.
        if embed_dim is None:
            embed_dim = config.DL_SVDD_EMBEDDING_DIM

        # Backbones: ResNet50 feeds heads A/B and most SVDD heads;
        # EfficientNet-B7 feeds the print_quality SVDD head.
        self.resnet = ResNet50Backbone(pretrained=pretrained, freeze_early=freeze_early)
        self.efficientnet = EfficientNetB7Backbone(pretrained=pretrained, freeze_early=freeze_early)

        # Head A: Pokemon vs non-Pokemon classifier.
        self.pokemon_head = PokemonClassifierHead(in_dim=self.resnet.output_dim)

        # Head B: genuine vs counterfeit back authenticator.
        self.back_auth_head = BackAuthHead(in_dim=self.resnet.output_dim)

        # Head C: SVDD embedding heads (6 heads for component_scores).
        self.svdd_heads = create_svdd_heads(
            resnet_dim=self.resnet.output_dim,
            efficientnet_dim=self.efficientnet.output_dim,
            embed_dim=embed_dim,
        )

        # Weights used to combine per-head SVDD scores into `prediction`.
        self.head_weights = head_weights or get_head_weights()
        self.embed_dim = embed_dim

        # Register one zero-initialized SVDD center per head as a buffer:
        # saved in state_dict, moved with .to(device), never gradient-trained.
        for name in SVDD_HEAD_CONFIG:
            self.register_buffer(
                f"center_{name}",
                torch.zeros(embed_dim),
            )

        # Track whether initialize_centers() has been run (persists in checkpoints).
        self.register_buffer("centers_initialized", torch.tensor(False))

        resnet_params = self.resnet.get_trainable_params()
        efn_params = self.efficientnet.get_trainable_params()
        logger.info(
            f"CardAuthModel initialized: "
            f"ResNet50 ({resnet_params['trainable']:,} trainable), "
            f"EfficientNet-B7 ({efn_params['trainable']:,} trainable), "
            f"Head A (pokemon), Head B (back_auth), 6 SVDD heads"
        )

    def get_center(self, name: str) -> torch.Tensor:
        """Get SVDD center for a named head."""
        return getattr(self, f"center_{name}")

    def set_center(self, name: str, center: torch.Tensor):
        """Set SVDD center for a named head (in-place copy into the buffer)."""
        getattr(self, f"center_{name}").copy_(center)

    @torch.no_grad()
    def initialize_centers(self, dataloader, device: Optional[torch.device] = None):
        """
        Initialize SVDD centers by computing mean embeddings on authentic front data.

        Only uses samples where is_authentic=1 AND is_back=0 (authentic fronts).
        Counterfeits and back images are excluded to avoid polluting centers.

        The model's previous train/eval mode is restored on exit.

        Args:
            dataloader: DataLoader yielding (images, metadata) or bare image
                batches. `metadata` is assumed to be a dict of per-sample
                tensors keyed by "is_authentic" / "is_back".
            device: Compute device; defaults to the model's current device.
        """
        if device is None:
            device = next(self.parameters()).device

        # Switch to eval for deterministic BN/dropout, but remember the
        # previous mode so training is not silently left in eval afterwards.
        was_training = self.training
        self.eval()

        embeddings_accum = {name: [] for name in SVDD_HEAD_CONFIG}
        total_samples = 0
        num_batches = 0  # batches that contributed at least one authentic front

        try:
            for batch in dataloader:
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    images, metadata = batch
                else:
                    images = batch
                    metadata = None

                images = images.to(device)

                # Filter to authentic front images only.
                if metadata is not None:
                    is_authentic = metadata.get("is_authentic", torch.ones(images.size(0)))
                    is_back = metadata.get("is_back", torch.zeros(images.size(0)))
                    # Metadata tensors typically come from a CPU dataloader;
                    # a boolean index must live on the same device as `images`.
                    mask = ((is_authentic == 1) & (is_back == 0)).to(images.device)
                    if not mask.any():
                        continue
                    images = images[mask]

                resnet_features = self.resnet(images)
                efficientnet_features = self.efficientnet(images)

                # Route each head to its configured backbone's features.
                for name, head in self.svdd_heads.items():
                    if SVDD_HEAD_CONFIG[name]["backbone"] == "efficientnet_b7":
                        emb = head(efficientnet_features)
                    else:
                        emb = head(resnet_features)
                    # Accumulate on CPU to avoid holding every batch on GPU.
                    embeddings_accum[name].append(emb.cpu())

                total_samples += images.size(0)
                num_batches += 1

            for name in SVDD_HEAD_CONFIG:
                if len(embeddings_accum[name]) == 0:
                    logger.warning(
                        f"No authentic front embeddings for head '{name}', keeping zero center"
                    )
                    continue
                all_emb = torch.cat(embeddings_accum[name], dim=0)
                center = all_emb.mean(dim=0)
                self.set_center(name, center.to(device))

            self.centers_initialized.fill_(True)
            logger.info(
                f"SVDD centers initialized from {total_samples} authentic front samples "
                f"({num_batches} batches)"
            )
        finally:
            # Restore the caller's train/eval mode even on error.
            self.train(was_training)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through all heads.

        Args:
            x: Input tensor (B, 3, 224, 224)

        Returns:
            Dict with:
                'pokemon_score': P(pokemon) (B, 1)
                'back_score': P(genuine_back) (B, 1)
                'embeddings': Dict[name, (B, embed_dim)]
                'distances': Dict[name, (B,)] - ||f(x) - c||^2
                'svdd_scores': Dict[name, (B,)] - 1/(1+dist) normalized [0,1]
                'prediction': Weighted SVDD score (B, 1)
                'head_outputs': Alias for svdd_scores as (B, 1) tensors
        """
        resnet_features = self.resnet(x)
        efficientnet_features = self.efficientnet(x)

        # Head A: Pokemon classifier
        pokemon_score = self.pokemon_head(resnet_features)

        # Head B: Back authenticator
        back_score = self.back_auth_head(resnet_features)

        # Head C: SVDD embeddings
        embeddings = {}
        distances = {}
        svdd_scores = {}

        for name, head in self.svdd_heads.items():
            if SVDD_HEAD_CONFIG[name]["backbone"] == "efficientnet_b7":
                emb = head(efficientnet_features)
            else:
                emb = head(resnet_features)

            embeddings[name] = emb

            # Squared Euclidean distance to this head's SVDD center.
            center = self.get_center(name)
            dist = torch.sum((emb - center.unsqueeze(0)) ** 2, dim=1)
            distances[name] = dist

            # Map distance to a (0, 1] score: 1 at the center, -> 0 far away.
            score = 1.0 / (1.0 + dist)
            svdd_scores[name] = score

        # Weighted SVDD prediction (assumes head_weights has an entry per head).
        batch_size = x.size(0)
        weighted_sum = torch.zeros(batch_size, device=x.device)
        for name, score in svdd_scores.items():
            weighted_sum = weighted_sum + self.head_weights[name] * score

        # head_outputs: backward-compatible dict of (B, 1) tensors
        head_outputs = {
            name: score.unsqueeze(1) for name, score in svdd_scores.items()
        }

        return {
            "pokemon_score": pokemon_score,
            "back_score": back_score,
            "embeddings": embeddings,
            "distances": distances,
            "svdd_scores": svdd_scores,
            "prediction": weighted_sum.unsqueeze(1),
            "head_outputs": head_outputs,
        }

    def get_total_params(self) -> Dict[str, int]:
        """Get total parameter counts (trainable / frozen / total)."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {
            "trainable": trainable,
            "frozen": total - trainable,
            "total": total,
        }

    def get_param_groups(self, backbone_lr: float = 1e-4, head_lr: float = 1e-3):
        """
        Get parameter groups with discriminative (layer-wise) learning rates.

        3 groups:
        - Early trainable backbone layers (layer3/block6): backbone_lr * 0.1
        - Late trainable backbone layers (layer4/block7+): backbone_lr
        - Head parameters: head_lr

        Args:
            backbone_lr: Learning rate for late backbone layers
            head_lr: Learning rate for head parameters

        Returns:
            List of parameter group dicts for optimizer
        """
        resnet_groups = self.resnet.get_layer_groups()   # [layer3, layer4]
        efn_groups = self.efficientnet.get_layer_groups() # [block6, block7+]

        early_backbone_params = resnet_groups[0] + efn_groups[0]
        late_backbone_params = resnet_groups[1] + efn_groups[1]

        head_params = (
            list(self.pokemon_head.parameters())
            + list(self.back_auth_head.parameters())
            + list(self.svdd_heads.parameters())
        )

        # Skip empty groups (e.g. fully frozen early layers) — some optimizers
        # reject parameter groups with no parameters.
        groups = []
        if early_backbone_params:
            groups.append({"params": early_backbone_params, "lr": backbone_lr * 0.1})
        if late_backbone_params:
            groups.append({"params": late_backbone_params, "lr": backbone_lr})
        groups.append({"params": head_params, "lr": head_lr})

        return groups


# Backward-compatible alias: older code imports the model as CardAuthDLModel.
CardAuthDLModel = CardAuthModel