"""
CardAuthModel - Multi-Purpose Deep Learning Model.

Three head groups:
- Head A (PokemonClassifierHead): Pokemon vs Non-Pokemon (ResNet50)
- Head B (BackAuthHead): Genuine vs counterfeit back (ResNet50)
- Head C (EmbeddingHead x6): Deep SVDD for front anomaly detection

SVDD centers stored as register_buffer (saved in state_dict, not gradient-trained).
"""

import torch
import torch.nn as nn
from typing import Dict, Optional, List

from .backbone import ResNet50Backbone
from .efficientnet import EfficientNetB7Backbone
from .heads import (
    PokemonClassifierHead,
    BackAuthHead,
    EmbeddingHead,
    SVDD_HEAD_CONFIG,
    create_svdd_heads,
    get_head_weights,
)
from ..utils.logger import get_logger
from ..utils.config import config

logger = get_logger(__name__)


class CardAuthModel(nn.Module):
    """
    Multi-purpose deep learning model for card authentication.

    Architecture:
        Input image (B, 3, 224, 224)
        |-- ResNet50 -> 2048-dim
        |   |-- pokemon_head (Head A) -> P(pokemon)
        |   |-- back_auth_head (Head B) -> P(genuine_back)
        |   |-- SVDD heads (Head C): primary, edge_inspector,
        |   |       texture, hologram, historical -> 128-dim each
        |-- EfficientNet-B7 -> 2560-dim
            |-- SVDD head: print_quality -> 128-dim

    SVDD output: weighted 1/(1+dist) scores
    """

    def __init__(
        self,
        pretrained: bool = True,
        freeze_early: bool = True,
        head_weights: Optional[Dict[str, float]] = None,
        embed_dim: Optional[int] = None,
    ):
        """
        Build both backbones, the two classifier heads and the SVDD heads.

        Args:
            pretrained: Forwarded to both backbone constructors.
            freeze_early: Forwarded to both backbone constructors.
            head_weights: Per-head weights used by ``forward`` to combine the
                SVDD scores. ``forward`` looks up a weight for every SVDD
                head, so the mapping needs one entry per head. When None
                (or empty), ``get_head_weights()`` is used.
            embed_dim: Size of each SVDD embedding / center. When None,
                ``config.DL_SVDD_EMBEDDING_DIM`` is used.
        """
        super().__init__()

        if embed_dim is None:
            embed_dim = config.DL_SVDD_EMBEDDING_DIM

        # Backbones
        self.resnet = ResNet50Backbone(pretrained=pretrained, freeze_early=freeze_early)
        self.efficientnet = EfficientNetB7Backbone(
            pretrained=pretrained, freeze_early=freeze_early
        )

        # Head A: Pokemon classifier
        self.pokemon_head = PokemonClassifierHead(in_dim=self.resnet.output_dim)

        # Head B: Back authenticator
        self.back_auth_head = BackAuthHead(in_dim=self.resnet.output_dim)

        # Head C: SVDD embedding heads (6 heads for component_scores)
        self.svdd_heads = create_svdd_heads(
            resnet_dim=self.resnet.output_dim,
            efficientnet_dim=self.efficientnet.output_dim,
            embed_dim=embed_dim,
        )

        # Head weights for final SVDD prediction
        self.head_weights = head_weights or get_head_weights()
        self.embed_dim = embed_dim

        # Register SVDD centers as buffers (not trained by gradient)
        for name in SVDD_HEAD_CONFIG:
            self.register_buffer(f"center_{name}", torch.zeros(embed_dim))

        # Track whether centers have been initialized
        self.register_buffer("centers_initialized", torch.tensor(False))

        resnet_params = self.resnet.get_trainable_params()
        efn_params = self.efficientnet.get_trainable_params()
        logger.info(
            f"CardAuthModel initialized: "
            f"ResNet50 ({resnet_params['trainable']:,} trainable), "
            f"EfficientNet-B7 ({efn_params['trainable']:,} trainable), "
            f"Head A (pokemon), Head B (back_auth), 6 SVDD heads"
        )

    def get_center(self, name: str) -> torch.Tensor:
        """Get SVDD center for a named head."""
        return getattr(self, f"center_{name}")

    def set_center(self, name: str, center: torch.Tensor) -> None:
        """
        Set SVDD center for a named head.

        The value is copied in place into the registered buffer, so the
        buffer keeps its own device and dtype.

        Args:
            name: Head name; the buffer ``center_<name>`` must exist.
            center: New center value.
        """
        # Copy without autograd tracking: a center that requires grad must not
        # attach a graph to a buffer that is never gradient-trained.
        with torch.no_grad():
            getattr(self, f"center_{name}").copy_(center)

    def _embed_all_heads(
        self,
        resnet_features: torch.Tensor,
        efficientnet_features: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Run every SVDD head on the features of the backbone it is configured for."""
        embeddings = {}
        for name, head in self.svdd_heads.items():
            if SVDD_HEAD_CONFIG[name]["backbone"] == "efficientnet_b7":
                embeddings[name] = head(efficientnet_features)
            else:
                embeddings[name] = head(resnet_features)
        return embeddings

    @torch.no_grad()
    def initialize_centers(self, dataloader, device: Optional[torch.device] = None):
        """
        Initialize SVDD centers by computing mean embeddings on authentic front data.

        Only uses samples where is_authentic=1 AND is_back=0 (authentic fronts).
        Counterfeits and back images are excluded to avoid polluting centers.
        Batches that carry no metadata are used unfiltered.

        Side effects:
            - Switches the module to eval mode and does not switch it back.
            - Sets ``centers_initialized`` to True only when at least one
              center was actually computed; otherwise a warning is logged
              and the flag is left untouched.

        Args:
            dataloader: DataLoader yielding (images, metadata) or just images
            device: Compute device (defaults to the device of the model parameters)
        """
        if device is None:
            device = next(self.parameters()).device

        self.eval()

        # Running per-head sums keep memory bounded regardless of dataset size;
        # float64 on CPU keeps the mean accurate over many batches.
        embedding_sums = {name: None for name in SVDD_HEAD_CONFIG}
        total_samples = 0
        used_batches = 0

        for batch in dataloader:
            if isinstance(batch, (list, tuple)) and len(batch) == 2:
                images, metadata = batch
            else:
                images, metadata = batch, None

            images = images.to(device)

            # Filter to authentic front images only
            if metadata is not None:
                batch_size = images.size(0)
                # Normalise both flags onto the image device so the comparison
                # and the indexing below never mix devices.
                is_authentic = torch.as_tensor(
                    metadata.get("is_authentic", torch.ones(batch_size)),
                    device=images.device,
                )
                is_back = torch.as_tensor(
                    metadata.get("is_back", torch.zeros(batch_size)),
                    device=images.device,
                )
                mask = (is_authentic == 1) & (is_back == 0)
                if not mask.any():
                    continue
                images = images[mask]

            batch_embeddings = self._embed_all_heads(
                self.resnet(images), self.efficientnet(images)
            )
            for name, emb in batch_embeddings.items():
                batch_sum = emb.cpu().double().sum(dim=0)
                previous = embedding_sums[name]
                embedding_sums[name] = (
                    batch_sum if previous is None else previous + batch_sum
                )

            total_samples += images.size(0)
            used_batches += 1

        updated_heads = 0
        for name, summed in embedding_sums.items():
            if summed is None:
                logger.warning(
                    f"No authentic front embeddings for head '{name}', keeping zero center"
                )
                continue
            buffer_dtype = self.get_center(name).dtype
            # Cast on CPU first; the in-place copy then handles the device move.
            self.set_center(name, (summed / total_samples).to(buffer_dtype))
            updated_heads += 1

        if updated_heads == 0:
            # Do not claim initialization when nothing was computed.
            logger.warning(
                "SVDD centers were not initialized: no authentic front samples found"
            )
            return

        self.centers_initialized.fill_(True)
        logger.info(
            f"SVDD centers initialized from {total_samples} authentic front samples "
            f"({used_batches} batches)"
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through all heads.

        Args:
            x: Input tensor (B, 3, 224, 224)

        Returns:
            Dict with:
                'pokemon_score': P(pokemon) (B, 1)
                'back_score': P(genuine_back) (B, 1)
                'embeddings': Dict[name, (B, embed_dim)]
                'distances': Dict[name, (B,)] - ||f(x) - c||^2
                'svdd_scores': Dict[name, (B,)] - 1/(1+dist) normalized [0,1]
                'prediction': Weighted SVDD score (B, 1)
                'head_outputs': Alias for svdd_scores as (B, 1) tensors
        """
        resnet_features = self.resnet(x)
        efficientnet_features = self.efficientnet(x)

        # Head A: Pokemon classifier
        pokemon_score = self.pokemon_head(resnet_features)

        # Head B: Back authenticator
        back_score = self.back_auth_head(resnet_features)

        # Head C: SVDD embeddings
        embeddings = self._embed_all_heads(resnet_features, efficientnet_features)
        distances = {}
        svdd_scores = {}
        for name, emb in embeddings.items():
            center = self.get_center(name)
            # Squared Euclidean distance of every embedding to the head center.
            dist = torch.sum((emb - center.unsqueeze(0)) ** 2, dim=1)
            distances[name] = dist
            svdd_scores[name] = 1.0 / (1.0 + dist)

        # Weighted SVDD prediction
        weighted_sum = torch.zeros(x.size(0), device=x.device)
        for name, score in svdd_scores.items():
            weighted_sum = weighted_sum + self.head_weights[name] * score

        # head_outputs: backward-compatible dict of (B, 1) tensors
        head_outputs = {name: score.unsqueeze(1) for name, score in svdd_scores.items()}

        return {
            "pokemon_score": pokemon_score,
            "back_score": back_score,
            "embeddings": embeddings,
            "distances": distances,
            "svdd_scores": svdd_scores,
            "prediction": weighted_sum.unsqueeze(1),
            "head_outputs": head_outputs,
        }

    def get_total_params(self) -> Dict[str, int]:
        """Get total parameter counts (keys: 'trainable', 'frozen', 'total')."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {
            "trainable": trainable,
            "frozen": total - trainable,
            "total": total,
        }

    def get_param_groups(
        self, backbone_lr: float = 1e-4, head_lr: float = 1e-3
    ) -> List[dict]:
        """
        Get parameter groups with discriminative (layer-wise) learning rates.

        3 groups:
        - Early trainable backbone layers (layer3/block6): backbone_lr * 0.1
        - Late trainable backbone layers (layer4/block7+): backbone_lr
        - Head parameters: head_lr

        A backbone group is omitted when it has no parameters; the head group
        is always present.

        Args:
            backbone_lr: Learning rate for late backbone layers
            head_lr: Learning rate for head parameters

        Returns:
            List of parameter group dicts for optimizer
        """
        resnet_groups = self.resnet.get_layer_groups()  # [layer3, layer4]
        efn_groups = self.efficientnet.get_layer_groups()  # [block6, block7+]

        early_backbone_params = resnet_groups[0] + efn_groups[0]
        late_backbone_params = resnet_groups[1] + efn_groups[1]

        head_params = (
            list(self.pokemon_head.parameters())
            + list(self.back_auth_head.parameters())
            + list(self.svdd_heads.parameters())
        )

        groups = []
        if early_backbone_params:
            groups.append({"params": early_backbone_params, "lr": backbone_lr * 0.1})
        if late_backbone_params:
            groups.append({"params": late_backbone_params, "lr": backbone_lr})
        groups.append({"params": head_params, "lr": head_lr})

        return groups


# Backward-compatible alias
CardAuthDLModel = CardAuthModel