Spaces:
Runtime error
Runtime error
"""
Classification and embedding heads for card authentication.

Three head types:
- PokemonClassifierHead (Head A): Pokemon vs. non-Pokemon binary classifier
- BackAuthHead (Head B): genuine vs. counterfeit card-back pattern classifier
- EmbeddingHead (Head C): Deep SVDD embedding for front anomaly detection

Six SVDD embedding heads provide component_scores for backward compatibility
with the six-head UI.
"""
from typing import Dict, Optional

import torch
import torch.nn as nn
class PokemonClassifierHead(nn.Module):
    """
    Head A: Pokemon vs Non-Pokemon binary classifier.

    Architecture:
        Linear(in_dim -> 256) -> BatchNorm -> ReLU -> Dropout(0.3)
        -> Linear(256 -> 1) -> Sigmoid

    Args:
        in_dim: Size of the input feature vector.
        name: Identifier stored on the module as ``self.name``.

    Note:
        ``BatchNorm1d`` needs more than one sample per batch in training
        mode; a batch of size 1 raises while the module is in ``train()``
        mode (switch to ``eval()`` for single-sample inference).
    """

    def __init__(self, in_dim: int, name: str = "pokemon_head"):
        super().__init__()
        self.name = name
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            # Must remain the last layer: forward(return_logits=True) drops it.
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """
        Score a batch of feature vectors.

        Args:
            x: Input features, shape (B, in_dim).
            return_logits: If True, return the pre-sigmoid logits instead of
                probabilities. Use these with ``nn.BCEWithLogitsLoss``, which
                is numerically more stable than ``nn.BCELoss`` applied to the
                sigmoid output.

        Returns:
            P(pokemon) in [0, 1] (or raw logits when ``return_logits``),
            shape (B, 1).
        """
        if return_logits:
            # Slicing a Sequential reuses the same submodules and weights, so
            # the state_dict layout is unchanged.
            return self.classifier[:-1](x)
        return self.classifier(x)
class BackAuthHead(nn.Module):
    """
    Head B: Genuine vs counterfeit back pattern classifier.

    Architecture:
        Linear(in_dim -> 256) -> BatchNorm -> ReLU -> Dropout(0.3)
        -> Linear(256 -> 1) -> Sigmoid

    Args:
        in_dim: Size of the input feature vector.
        name: Identifier stored on the module as ``self.name``.

    Note:
        ``BatchNorm1d`` needs more than one sample per batch in training
        mode; a batch of size 1 raises while the module is in ``train()``
        mode (switch to ``eval()`` for single-sample inference).
    """

    def __init__(self, in_dim: int, name: str = "back_auth_head"):
        super().__init__()
        self.name = name
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            # Must remain the last layer: forward(return_logits=True) drops it.
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """
        Score a batch of back-side feature vectors.

        Args:
            x: Input features, shape (B, in_dim).
            return_logits: If True, return the pre-sigmoid logits instead of
                probabilities. Use these with ``nn.BCEWithLogitsLoss``, which
                is numerically more stable than ``nn.BCELoss`` applied to the
                sigmoid output.

        Returns:
            P(genuine_back) in [0, 1] (or raw logits when ``return_logits``),
            shape (B, 1).
        """
        if return_logits:
            # Slicing a Sequential reuses the same submodules and weights, so
            # the state_dict layout is unchanged.
            return self.classifier[:-1](x)
        return self.classifier(x)
class EmbeddingHead(nn.Module):
    """
    Head C: Deep SVDD embedding head for front anomaly detection.

    Projects backbone features to an unbounded ``embed_dim``-dimensional
    embedding: there is no output sigmoid, and the final projection has no
    bias term (Ruff et al. 2018).

    Architecture:
        Linear(in_dim -> 512) -> BatchNorm -> ReLU
        -> Linear(512 -> 128) -> BatchNorm -> ReLU
        -> Linear(128 -> embed_dim, bias=False)

    Args:
        in_dim: Size of the input feature vector.
        embed_dim: Size of the output embedding (also stored as
            ``self.embed_dim``).
        name: Identifier stored on the module as ``self.name``.

    NOTE(review): only the final layer is bias-free; the hidden Linear layers
    keep their biases and the BatchNorm layers keep their affine shift. Ruff
    et al. argue against bias terms in every layer (they admit a trivial
    constant solution). Confirm this is intentional; changing it would alter
    the state_dict layout of existing checkpoints.
    """

    def __init__(self, in_dim: int, embed_dim: int = 128, name: str = "embedding"):
        super().__init__()
        self.name = name
        self.embed_dim = embed_dim

        wide, narrow = 512, 128
        # Layer order fixes the state_dict keys (encoder.0 ... encoder.6).
        stack = [
            nn.Linear(in_dim, wide),
            nn.BatchNorm1d(wide),
            nn.ReLU(inplace=True),
            nn.Linear(wide, narrow),
            nn.BatchNorm1d(narrow),
            nn.ReLU(inplace=True),
            nn.Linear(narrow, embed_dim, bias=False),
        ]
        self.encoder = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features of shape (B, in_dim) to embeddings of shape (B, embed_dim)."""
        embedding = self.encoder(x)
        return embedding
# SVDD head configuration: name -> {"weight", "backbone", "description"}.
#   weight      - this head's share of the final prediction; the six weights
#                 sum to 1.0
#   backbone    - which backbone's features feed the head; create_svdd_heads
#                 uses it to pick the head's input dimension
#   description - human-readable summary of what the head looks at
# These provide component_scores for backward compatibility with the 6-head
# UI. Entry order matters: it is the head order produced by create_svdd_heads.
SVDD_HEAD_CONFIG: Dict[str, Dict] = {
    head_name: {"weight": share, "backbone": source, "description": blurb}
    for head_name, share, source, blurb in (
        ("primary", 0.25, "resnet50", "Overall authenticity assessment"),
        ("print_quality", 0.25, "efficientnet_b7", "Print patterns, color consistency"),
        ("edge_inspector", 0.15, "resnet50", "Edge cutting, border quality"),
        ("texture", 0.15, "resnet50", "Surface texture, micro-patterns"),
        ("hologram", 0.10, "resnet50", "Hologram/foil patterns (limited by data)"),
        ("historical", 0.10, "resnet50", "Similarity patterns (limited by data)"),
    )
}

# Backward-compatible alias: both names refer to the same dict object.
HEAD_CONFIG = SVDD_HEAD_CONFIG
def create_svdd_heads(
    resnet_dim: int = 2048,
    efficientnet_dim: int = 2560,
    embed_dim: int = 128,
    config: Optional[Dict[str, Dict]] = None,
) -> nn.ModuleDict:
    """
    Create one SVDD embedding head per entry of the head configuration.

    Each head's input size comes from its ``"backbone"`` entry:
    ``"efficientnet_b7"`` heads take ``efficientnet_dim`` features and
    ``"resnet50"`` heads take ``resnet_dim`` features.

    Args:
        resnet_dim: ResNet50 output dimension.
        efficientnet_dim: EfficientNet-B7 output dimension.
        embed_dim: SVDD embedding dimension shared by all heads.
        config: Head configuration mapping ``name -> {"backbone": ..., ...}``.
            Defaults to ``SVDD_HEAD_CONFIG`` (the six standard heads).

    Returns:
        ModuleDict mapping each head name to its ``EmbeddingHead``, in the
        configuration's iteration order.

    Raises:
        ValueError: If a head names a backbone other than ``"resnet50"`` or
            ``"efficientnet_b7"``. Failing here, rather than silently using
            ``resnet_dim``, keeps a config typo from surfacing later as an
            opaque shape mismatch at forward time.
    """
    if config is None:
        config = SVDD_HEAD_CONFIG
    backbone_dims = {"resnet50": resnet_dim, "efficientnet_b7": efficientnet_dim}

    heads = nn.ModuleDict()
    for name, cfg in config.items():
        backbone = cfg["backbone"]
        if backbone not in backbone_dims:
            raise ValueError(
                f"SVDD head {name!r} uses unsupported backbone {backbone!r}; "
                f"expected one of {sorted(backbone_dims)}"
            )
        heads[name] = EmbeddingHead(
            in_dim=backbone_dims[backbone], embed_dim=embed_dim, name=name
        )
    return heads
def get_head_weights(config: Optional[Dict[str, Dict]] = None) -> Dict[str, float]:
    """
    Get the weighting for each SVDD head in the final prediction.

    Args:
        config: Head configuration mapping ``name -> {"weight": ..., ...}``.
            Defaults to ``SVDD_HEAD_CONFIG``.

    Returns:
        Mapping from head name to its weight as a float, in configuration
        order.
    """
    if config is None:
        config = SVDD_HEAD_CONFIG
    return {name: float(cfg["weight"]) for name, cfg in config.items()}