# NOTE: The lines below were captured from the Hugging Face commit page this
# file was exported from; kept as a comment so the module stays importable.
#   author: stevelohwc
#   commit: hfspace: add backend API runtime bundle for Space deployment (0ba6002)
"""
Classification and Embedding Heads for Card Authentication.
Three head types:
- PokemonClassifierHead (Head A): Pokemon vs Non-Pokemon binary classifier
- BackAuthHead (Head B): Genuine vs counterfeit back pattern classifier
- EmbeddingHead (Head C): Deep SVDD embedding for front anomaly detection
Six SVDD embedding heads provide component_scores for backward compatibility.
"""
import torch
import torch.nn as nn
from typing import Dict
class PokemonClassifierHead(nn.Module):
    """Head A: binary classifier separating Pokemon cards from non-Pokemon cards.

    Topology: Linear(in_dim -> 256) -> BatchNorm1d -> ReLU -> Dropout(0.3)
    -> Linear(256 -> 1) -> Sigmoid, so the output is a probability in [0, 1].
    """

    def __init__(self, in_dim: int, name: str = "pokemon_head"):
        super().__init__()
        self.name = name
        # Layers are kept in a flat Sequential, in this exact order, so that
        # parameter names (classifier.0.*, classifier.1.*, ...) match any
        # previously saved checkpoints.
        layers = [
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return P(pokemon) as a (B, 1) tensor of values in [0, 1]."""
        return self.classifier(x)
class BackAuthHead(nn.Module):
    """Head B: classifies whether a card-back pattern is genuine or counterfeit.

    Same topology as the other binary head: a 256-unit bottleneck with batch
    norm, ReLU and 30% dropout, followed by a single sigmoid output unit.
    """

    def __init__(self, in_dim: int, name: str = "back_auth_head"):
        super().__init__()
        self.name = name
        # Flat Sequential in this order keeps checkpoint keys stable
        # (classifier.0.weight, classifier.1.running_mean, ...).
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return P(genuine_back) as a (B, 1) tensor of values in [0, 1]."""
        out = self.classifier(x)
        return out
class EmbeddingHead(nn.Module):
    """Head C: Deep SVDD embedding head for front-side anomaly detection.

    Emits a raw (unsquashed) embedding — no sigmoid — and the final projection
    carries no bias term, following Ruff et al. 2018.

    Layout: in_dim -> 512 -> 128 -> embed_dim, with BatchNorm1d + ReLU after
    each hidden layer.
    """

    def __init__(self, in_dim: int, embed_dim: int = 128, name: str = "embedding"):
        super().__init__()
        self.name = name
        self.embed_dim = embed_dim
        # Built as a flat, positionally-indexed Sequential so parameter names
        # (encoder.0.*, encoder.1.*, ...) match existing checkpoints.
        stages = [
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, embed_dim, bias=False),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode features x of shape (B, in_dim) into a (B, embed_dim) embedding."""
        return self.encoder(x)
# SVDD head configuration: name -> (weight, backbone_source)
# These provide component_scores for backward compatibility with 6-head UI
# NOTE: the "weight" values below sum to 1.0, so a weighted combination of
# per-head scores is a convex blend.
# Each entry holds:
#   weight      - this head's share of the final combined score
#   backbone    - feature extractor whose output dim the head consumes
#                 (resolved to an input dim by create_svdd_heads)
#   description - human-readable label surfaced in the UI
SVDD_HEAD_CONFIG: Dict[str, Dict] = {
    "primary": {
        "weight": 0.25,
        "backbone": "resnet50",
        "description": "Overall authenticity assessment",
    },
    "print_quality": {
        "weight": 0.25,
        "backbone": "efficientnet_b7",
        "description": "Print patterns, color consistency",
    },
    "edge_inspector": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Edge cutting, border quality",
    },
    "texture": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Surface texture, micro-patterns",
    },
    "hologram": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Hologram/foil patterns (limited by data)",
    },
    "historical": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Similarity patterns (limited by data)",
    },
}
# Backward-compatible alias: older code imports HEAD_CONFIG; it points at the
# same dict object, so mutations through either name are shared.
HEAD_CONFIG = SVDD_HEAD_CONFIG
def create_svdd_heads(
    resnet_dim: int = 2048,
    efficientnet_dim: int = 2560,
    embed_dim: int = 128,
) -> nn.ModuleDict:
    """Build all 6 SVDD embedding heads declared in SVDD_HEAD_CONFIG.

    Args:
        resnet_dim: ResNet50 output dimension
        efficientnet_dim: EfficientNet-B7 output dimension
        embed_dim: SVDD embedding dimension

    Returns:
        ModuleDict mapping head name -> EmbeddingHead
    """
    # Backbone-name -> input-dim lookup; anything other than efficientnet_b7
    # falls back to the ResNet dimension, matching the original behavior.
    backbone_dims = {"efficientnet_b7": efficientnet_dim, "resnet50": resnet_dim}
    return nn.ModuleDict(
        {
            head_name: EmbeddingHead(
                in_dim=backbone_dims.get(cfg["backbone"], resnet_dim),
                embed_dim=embed_dim,
                name=head_name,
            )
            for head_name, cfg in SVDD_HEAD_CONFIG.items()
        }
    )
def get_head_weights() -> Dict[str, float]:
    """Return a name -> weight mapping for combining the SVDD head scores."""
    weights: Dict[str, float] = {}
    for head_name, head_cfg in SVDD_HEAD_CONFIG.items():
        weights[head_name] = head_cfg["weight"]
    return weights