Spaces:
Sleeping
Sleeping
"""
CardAuthModel - Multi-Purpose Deep Learning Model.

Three head groups:
    - Head A (PokemonClassifierHead): Pokemon vs Non-Pokemon (ResNet50)
    - Head B (BackAuthHead): Genuine vs counterfeit back (ResNet50)
    - Head C (EmbeddingHead x6): Deep SVDD for front anomaly detection

SVDD centers stored as register_buffer (saved in state_dict, not gradient-trained).
"""
import torch
import torch.nn as nn
from typing import Dict, Optional, List

from .backbone import ResNet50Backbone
from .efficientnet import EfficientNetB7Backbone
from .heads import (
    PokemonClassifierHead,
    BackAuthHead,
    EmbeddingHead,
    SVDD_HEAD_CONFIG,
    create_svdd_heads,
    get_head_weights,
)
from ..utils.logger import get_logger
from ..utils.config import config

# Module-level logger keyed on this module's import path.
logger = get_logger(__name__)
class CardAuthModel(nn.Module):
    """
    Multi-purpose deep learning model for card authentication.

    Architecture:
        Input image (B, 3, 224, 224)
        |-- ResNet50 -> 2048-dim
        |     |-- pokemon_head (Head A) -> P(pokemon)
        |     |-- back_auth_head (Head B) -> P(genuine_back)
        |     |-- SVDD heads (Head C): primary, edge_inspector,
        |     |     texture, hologram, historical -> 128-dim each
        |-- EfficientNet-B7 -> 2560-dim
              |-- SVDD head: print_quality -> 128-dim

    SVDD output: weighted 1/(1+dist) scores.
    """

    def __init__(
        self,
        pretrained: bool = True,
        freeze_early: bool = True,
        head_weights: Optional[Dict[str, float]] = None,
        embed_dim: Optional[int] = None,
    ):
        """
        Args:
            pretrained: Load pretrained weights into both backbones.
            freeze_early: Freeze early backbone layers (forwarded to both backbones).
            head_weights: Per-head weights used for the final weighted SVDD
                prediction; defaults to get_head_weights().
            embed_dim: SVDD embedding dimension; defaults to
                config.DL_SVDD_EMBEDDING_DIM.
        """
        super().__init__()
        if embed_dim is None:
            embed_dim = config.DL_SVDD_EMBEDDING_DIM

        # Backbones
        self.resnet = ResNet50Backbone(pretrained=pretrained, freeze_early=freeze_early)
        self.efficientnet = EfficientNetB7Backbone(pretrained=pretrained, freeze_early=freeze_early)

        # Head A: Pokemon classifier
        self.pokemon_head = PokemonClassifierHead(in_dim=self.resnet.output_dim)
        # Head B: Back authenticator
        self.back_auth_head = BackAuthHead(in_dim=self.resnet.output_dim)
        # Head C: SVDD embedding heads (6 heads for component_scores)
        self.svdd_heads = create_svdd_heads(
            resnet_dim=self.resnet.output_dim,
            efficientnet_dim=self.efficientnet.output_dim,
            embed_dim=embed_dim,
        )

        # Head weights for final SVDD prediction
        self.head_weights = head_weights or get_head_weights()
        self.embed_dim = embed_dim

        # Register SVDD centers as buffers: persisted in state_dict and moved
        # with .to(device)/.cuda(), but never updated by the optimizer.
        for name in SVDD_HEAD_CONFIG:
            self.register_buffer(
                f"center_{name}",
                torch.zeros(embed_dim),
            )
        # Track whether centers have been initialized (also persisted).
        self.register_buffer("centers_initialized", torch.tensor(False))

        resnet_params = self.resnet.get_trainable_params()
        efn_params = self.efficientnet.get_trainable_params()
        logger.info(
            f"CardAuthModel initialized: "
            f"ResNet50 ({resnet_params['trainable']:,} trainable), "
            f"EfficientNet-B7 ({efn_params['trainable']:,} trainable), "
            f"Head A (pokemon), Head B (back_auth), 6 SVDD heads"
        )

    def get_center(self, name: str) -> torch.Tensor:
        """Get SVDD center buffer for a named head."""
        return getattr(self, f"center_{name}")

    def set_center(self, name: str, center: torch.Tensor):
        """Set SVDD center for a named head (in-place copy into the buffer)."""
        getattr(self, f"center_{name}").copy_(center)

    @torch.no_grad()
    def initialize_centers(self, dataloader, device: Optional[torch.device] = None):
        """
        Initialize SVDD centers by computing mean embeddings on authentic front data.

        Only uses samples where is_authentic=1 AND is_back=0 (authentic fronts).
        Counterfeits and back images are excluded to avoid polluting centers.

        Runs under torch.no_grad() so accumulated embeddings carry no autograd
        graphs (plain .cpu() would otherwise retain them and blow up memory),
        and restores the model's previous train/eval mode when finished.

        Args:
            dataloader: DataLoader yielding (images, metadata)
            device: Compute device; defaults to the model's current device.
        """
        if device is None:
            device = next(self.parameters()).device

        was_training = self.training
        self.eval()
        try:
            embeddings_accum = {name: [] for name in SVDD_HEAD_CONFIG}
            total_samples = 0
            for batch in dataloader:
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    images, metadata = batch
                else:
                    images = batch
                    metadata = None

                # Filter to authentic front images only. Filter BEFORE the
                # device transfer: metadata tensors typically live on CPU, and
                # indexing a CUDA tensor with a CPU bool mask fails; this also
                # avoids copying samples that would be discarded anyway.
                if metadata is not None:
                    n = images.size(0)
                    is_authentic = metadata.get("is_authentic", torch.ones(n))
                    is_back = metadata.get("is_back", torch.zeros(n))
                    mask = (is_authentic == 1) & (is_back == 0)
                    if not mask.any():
                        continue
                    images = images[mask.to(images.device)]

                images = images.to(device)
                resnet_features = self.resnet(images)
                efficientnet_features = self.efficientnet(images)
                for name, head in self.svdd_heads.items():
                    # Route each head to the backbone its config declares.
                    if SVDD_HEAD_CONFIG[name]["backbone"] == "efficientnet_b7":
                        emb = head(efficientnet_features)
                    else:
                        emb = head(resnet_features)
                    embeddings_accum[name].append(emb.cpu())
                total_samples += images.size(0)

            for name in SVDD_HEAD_CONFIG:
                if len(embeddings_accum[name]) == 0:
                    logger.warning(
                        f"No authentic front embeddings for head '{name}', keeping zero center"
                    )
                    continue
                all_emb = torch.cat(embeddings_accum[name], dim=0)
                center = all_emb.mean(dim=0)
                self.set_center(name, center.to(device))

            self.centers_initialized.fill_(True)
            logger.info(
                f"SVDD centers initialized from {total_samples} authentic front samples "
                f"({len(embeddings_accum['primary'])} batches)"
            )
        finally:
            # Restore whatever mode the model was in before center init.
            self.train(was_training)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through all heads.

        Args:
            x: Input tensor (B, 3, 224, 224)

        Returns:
            Dict with:
                'pokemon_score': P(pokemon) (B, 1)
                'back_score': P(genuine_back) (B, 1)
                'embeddings': Dict[name, (B, embed_dim)]
                'distances': Dict[name, (B,)] - ||f(x) - c||^2
                'svdd_scores': Dict[name, (B,)] - 1/(1+dist) normalized [0,1]
                'prediction': Weighted SVDD score (B, 1)
                'head_outputs': Alias for svdd_scores as (B, 1) tensors
        """
        resnet_features = self.resnet(x)
        efficientnet_features = self.efficientnet(x)

        # Head A: Pokemon classifier
        pokemon_score = self.pokemon_head(resnet_features)
        # Head B: Back authenticator
        back_score = self.back_auth_head(resnet_features)

        # Head C: SVDD embeddings, squared distances to centers, and scores.
        embeddings = {}
        distances = {}
        svdd_scores = {}
        for name, head in self.svdd_heads.items():
            if SVDD_HEAD_CONFIG[name]["backbone"] == "efficientnet_b7":
                emb = head(efficientnet_features)
            else:
                emb = head(resnet_features)
            embeddings[name] = emb
            center = self.get_center(name)
            dist = torch.sum((emb - center.unsqueeze(0)) ** 2, dim=1)
            distances[name] = dist
            # 1/(1+dist): monotonically maps distance to a (0, 1] score.
            svdd_scores[name] = 1.0 / (1.0 + dist)

        # Weighted SVDD prediction across all heads.
        batch_size = x.size(0)
        weighted_sum = torch.zeros(batch_size, device=x.device)
        for name, score in svdd_scores.items():
            weighted_sum = weighted_sum + self.head_weights[name] * score

        # head_outputs: backward-compatible dict of (B, 1) tensors
        head_outputs = {
            name: score.unsqueeze(1) for name, score in svdd_scores.items()
        }
        return {
            "pokemon_score": pokemon_score,
            "back_score": back_score,
            "embeddings": embeddings,
            "distances": distances,
            "svdd_scores": svdd_scores,
            "prediction": weighted_sum.unsqueeze(1),
            "head_outputs": head_outputs,
        }

    def get_total_params(self) -> Dict[str, int]:
        """Get total parameter counts (trainable / frozen / total)."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {
            "trainable": trainable,
            "frozen": total - trainable,
            "total": total,
        }

    def get_param_groups(self, backbone_lr: float = 1e-4, head_lr: float = 1e-3):
        """
        Get parameter groups with discriminative (layer-wise) learning rates.

        3 groups:
            - Early trainable backbone layers (layer3/block6): backbone_lr * 0.1
            - Late trainable backbone layers (layer4/block7+): backbone_lr
            - Head parameters: head_lr

        Args:
            backbone_lr: Learning rate for late backbone layers
            head_lr: Learning rate for head parameters

        Returns:
            List of parameter group dicts for optimizer
        """
        resnet_groups = self.resnet.get_layer_groups()  # [layer3, layer4]
        efn_groups = self.efficientnet.get_layer_groups()  # [block6, block7+]

        early_backbone_params = resnet_groups[0] + efn_groups[0]
        late_backbone_params = resnet_groups[1] + efn_groups[1]
        head_params = (
            list(self.pokemon_head.parameters())
            + list(self.back_auth_head.parameters())
            + list(self.svdd_heads.parameters())
        )

        groups = []
        # Skip empty groups so the optimizer never receives a zero-param group.
        if early_backbone_params:
            groups.append({"params": early_backbone_params, "lr": backbone_lr * 0.1})
        if late_backbone_params:
            groups.append({"params": late_backbone_params, "lr": backbone_lr})
        groups.append({"params": head_params, "lr": head_lr})
        return groups
# Backward-compatible alias: older code imports the model as CardAuthDLModel.
CardAuthDLModel = CardAuthModel