"""Classification and Embedding Heads for Card Authentication.

Three head types:
- PokemonClassifierHead (Head A): Pokemon vs Non-Pokemon binary classifier
- BackAuthHead (Head B): Genuine vs counterfeit back pattern classifier
- EmbeddingHead (Head C): Deep SVDD embedding for front anomaly detection

Six SVDD embedding heads provide component_scores for backward compatibility.
"""

import torch
import torch.nn as nn
from typing import Dict


def _make_binary_classifier(in_dim: int) -> nn.Sequential:
    """Build the binary-classifier trunk shared by Heads A and B.

    Architecture:
        Linear(in_dim -> 256) -> BatchNorm -> ReLU -> Dropout(0.3) ->
        Linear(256 -> 1) -> Sigmoid

    NOTE: BatchNorm1d requires batch size >= 2 in training mode; single-sample
    inference must run under ``model.eval()``.

    Args:
        in_dim: Dimension of the input feature vector.

    Returns:
        nn.Sequential producing a probability in [0, 1], shape (B, 1).
    """
    return nn.Sequential(
        nn.Linear(in_dim, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(inplace=True),
        nn.Dropout(0.3),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    )


class PokemonClassifierHead(nn.Module):
    """Head A: Pokemon vs Non-Pokemon binary classifier.

    Architecture:
        Linear(in_dim -> 256) -> BatchNorm -> ReLU -> Dropout(0.3) ->
        Linear(256 -> 1) -> Sigmoid
    """

    def __init__(self, in_dim: int, name: str = "pokemon_head"):
        """
        Args:
            in_dim: Backbone feature dimension fed into this head.
            name: Identifier used to label this head.
        """
        super().__init__()
        self.name = name
        # Shared trunk — identical architecture to BackAuthHead by design.
        self.classifier = _make_binary_classifier(in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass. Returns P(pokemon) in [0, 1], shape (B, 1)."""
        return self.classifier(x)


class BackAuthHead(nn.Module):
    """Head B: Genuine vs counterfeit back pattern classifier.

    Architecture:
        Linear(in_dim -> 256) -> BatchNorm -> ReLU -> Dropout(0.3) ->
        Linear(256 -> 1) -> Sigmoid
    """

    def __init__(self, in_dim: int, name: str = "back_auth_head"):
        """
        Args:
            in_dim: Backbone feature dimension fed into this head.
            name: Identifier used to label this head.
        """
        super().__init__()
        self.name = name
        # Shared trunk — identical architecture to PokemonClassifierHead by design.
        self.classifier = _make_binary_classifier(in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass. Returns P(genuine_back) in [0, 1], shape (B, 1)."""
        return self.classifier(x)


class EmbeddingHead(nn.Module):
    """Head C: Deep SVDD embedding head for front anomaly detection.

    No sigmoid. No bias in final layer (Ruff et al. 2018) — a bias term would
    let the network trivially collapse all embeddings onto the SVDD center.

    Architecture:
        Linear(in_dim -> 512) -> BatchNorm -> ReLU ->
        Linear(512 -> 128) -> BatchNorm -> ReLU ->
        Linear(128 -> embed_dim, bias=False)
    """

    def __init__(self, in_dim: int, embed_dim: int = 128, name: str = "embedding"):
        """
        Args:
            in_dim: Backbone feature dimension fed into this head.
            embed_dim: Dimension of the output SVDD embedding.
            name: Identifier used to label this head.
        """
        super().__init__()
        self.name = name
        self.embed_dim = embed_dim
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            # bias=False per Deep SVDD to avoid the trivial-collapse solution.
            nn.Linear(128, embed_dim, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass. Returns embedding (B, embed_dim)."""
        return self.encoder(x)


# SVDD head configuration: name -> (weight, backbone_source)
# These provide component_scores for backward compatibility with 6-head UI.
# Weights sum to 1.0 and determine each head's share of the final score.
SVDD_HEAD_CONFIG: Dict[str, Dict] = {
    "primary": {
        "weight": 0.25,
        "backbone": "resnet50",
        "description": "Overall authenticity assessment",
    },
    "print_quality": {
        "weight": 0.25,
        "backbone": "efficientnet_b7",
        "description": "Print patterns, color consistency",
    },
    "edge_inspector": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Edge cutting, border quality",
    },
    "texture": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Surface texture, micro-patterns",
    },
    "hologram": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Hologram/foil patterns (limited by data)",
    },
    "historical": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Similarity patterns (limited by data)",
    },
}

# Backward-compatible alias
HEAD_CONFIG = SVDD_HEAD_CONFIG


def create_svdd_heads(
    resnet_dim: int = 2048,
    efficientnet_dim: int = 2560,
    embed_dim: int = 128,
) -> nn.ModuleDict:
    """Create all 6 SVDD embedding heads.

    Args:
        resnet_dim: ResNet50 output dimension.
        efficientnet_dim: EfficientNet-B7 output dimension.
        embed_dim: SVDD embedding dimension.

    Returns:
        ModuleDict with all embedding heads, keyed by SVDD_HEAD_CONFIG names.
    """
    heads = nn.ModuleDict()
    for name, cfg in SVDD_HEAD_CONFIG.items():
        # Each head's input width matches the backbone it consumes.
        in_dim = efficientnet_dim if cfg["backbone"] == "efficientnet_b7" else resnet_dim
        heads[name] = EmbeddingHead(in_dim=in_dim, embed_dim=embed_dim, name=name)
    return heads


def get_head_weights() -> Dict[str, float]:
    """Get the weighting for each SVDD head in the final prediction."""
    return {name: cfg["weight"] for name, cfg in SVDD_HEAD_CONFIG.items()}