Spaces:
Runtime error
Runtime error
File size: 4,908 Bytes
0ba6002 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | """
Classification and Embedding Heads for Card Authentication.
Three head types:
- PokemonClassifierHead (Head A): Pokemon vs Non-Pokemon binary classifier
- BackAuthHead (Head B): Genuine vs counterfeit back pattern classifier
- EmbeddingHead (Head C): Deep SVDD embedding for front anomaly detection
Six SVDD embedding heads provide component_scores for backward compatibility.
"""
import torch
import torch.nn as nn
from typing import Dict
class PokemonClassifierHead(nn.Module):
    """
    Head A: binary Pokemon / non-Pokemon classifier.

    Layer stack held in ``self.classifier`` (the Sequential indices are part
    of the checkpoint format: ``classifier.0.*``, ``classifier.1.*``,
    ``classifier.4.*``):
        0: Linear(in_dim -> 256)
        1: BatchNorm1d(256)
        2: ReLU (in-place)
        3: Dropout(p=0.3)
        4: Linear(256 -> 1)
        5: Sigmoid

    Args:
        in_dim: Width of the incoming feature vector.
        name: Identifier stored on the module as ``self.name``.
    """

    def __init__(self, in_dim: int, name: str = "pokemon_head"):
        super().__init__()
        self.name = name
        hidden = 256
        # Layers are listed in the same order as before so parameter
        # initialisation (RNG draws) and state_dict keys stay identical.
        layers = [
            nn.Linear(in_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features ``x`` to P(pokemon): a (B, 1) tensor with values in [0, 1]."""
        prob = self.classifier(x)
        return prob
class BackAuthHead(nn.Module):
    """
    Head B: scores whether a card's back pattern is genuine or counterfeit.

    The network is a small MLP stored in ``self.classifier``:
        Linear(in_dim, 256) -> BatchNorm1d(256) -> ReLU -> Dropout(0.3)
        -> Linear(256, 1) -> Sigmoid
    Sequential indices are part of the checkpoint format
    (``classifier.0.*``, ``classifier.1.*``, ``classifier.4.*``).

    NOTE: like any BatchNorm1d model, training mode needs more than one
    sample per batch.

    Args:
        in_dim: Width of the incoming feature vector.
        name: Label kept on the instance as ``self.name``.
    """

    def __init__(self, in_dim: int, name: str = "back_auth_head"):
        super().__init__()
        self.name = name
        width = 256
        self.classifier = nn.Sequential(
            nn.Linear(in_features=in_dim, out_features=width),
            nn.BatchNorm1d(num_features=width),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=width, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return P(genuine_back) for each row of ``x`` as a (B, 1) tensor in [0, 1]."""
        scores = self.classifier(x)
        return scores
class EmbeddingHead(nn.Module):
    """
    Head C: Deep SVDD embedding head for front anomaly detection.

    The output is an unbounded embedding (no Sigmoid/Tanh on the output), and
    the final projection has no bias term.

    Architecture (Sequential indices fix the ``encoder.<i>.*`` checkpoint keys):
        0: Linear(in_dim -> 512)
        1: BatchNorm1d(512)
        2: ReLU
        3: Linear(512 -> 128)
        4: BatchNorm1d(128)
        5: ReLU
        6: Linear(128 -> embed_dim, bias=False)

    NOTE(review): Ruff et al. (2018, "Deep One-Class Classification") show that
    a bias term in *any* hidden layer lets a Deep SVDD network reach a trivial
    solution that maps every input to the center (hypersphere collapse). Only
    the final layer here is bias-free. The hidden Linear layers keep
    ``bias=True`` and the BatchNorm layers keep their default learnable shift,
    so collapse is not ruled out structurally. Making them bias-free would
    change ``state_dict`` keys and break existing checkpoints, so the layers
    are left as-is. Confirm this is intentional.

    Args:
        in_dim: Width of the incoming backbone feature vector.
        embed_dim: Output embedding dimension (default 128).
        name: Identifier stored on the module as ``self.name``.
    """

    def __init__(self, in_dim: int, embed_dim: int = 128, name: str = "embedding"):
        super().__init__()
        self.name = name
        self.embed_dim = embed_dim
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            # Final projection is bias-free (see class docstring note).
            nn.Linear(128, embed_dim, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode features ``x`` into a (B, embed_dim) embedding."""
        return self.encoder(x)
# SVDD head configuration: head name -> {"weight", "backbone", "description"}.
#   weight      -- weight of this head's score in the final prediction
#                  (the six weights below sum to 1.0)
#   backbone    -- name of the backbone whose features feed this head
#   description -- human-readable summary of what the head inspects
# These provide component_scores for backward compatibility with the 6-head UI.
SVDD_HEAD_CONFIG: Dict[str, Dict] = {
    "primary": {
        "weight": 0.25,
        "backbone": "resnet50",
        "description": "Overall authenticity assessment",
    },
    "print_quality": {
        "weight": 0.25,
        "backbone": "efficientnet_b7",
        "description": "Print patterns, color consistency",
    },
    "edge_inspector": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Edge cutting, border quality",
    },
    "texture": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Surface texture, micro-patterns",
    },
    "hologram": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Hologram/foil patterns (limited by data)",
    },
    "historical": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Similarity patterns (limited by data)",
    },
}
# Backward-compatible alias: same dict object, not a copy.
HEAD_CONFIG = SVDD_HEAD_CONFIG
def create_svdd_heads(
    resnet_dim: int = 2048,
    efficientnet_dim: int = 2560,
    embed_dim: int = 128,
) -> nn.ModuleDict:
    """
    Create one SVDD ``EmbeddingHead`` per entry in ``SVDD_HEAD_CONFIG``.

    Each head's input width is chosen from its configured ``backbone``.

    Args:
        resnet_dim: Feature width for heads whose backbone is "resnet50".
        efficientnet_dim: Feature width for heads whose backbone is
            "efficientnet_b7".
        embed_dim: SVDD embedding dimension shared by all heads.

    Returns:
        ModuleDict mapping head name -> EmbeddingHead, in config order.

    Raises:
        ValueError: if a config entry names a backbone other than
            "resnet50" or "efficientnet_b7".
    """
    # Explicit map instead of "efficientnet_b7 else resnet": an unknown or
    # misspelled backbone used to get resnet_dim silently, which would likely
    # only show up later as a shape mismatch in forward.
    backbone_dims = {
        "resnet50": resnet_dim,
        "efficientnet_b7": efficientnet_dim,
    }
    heads = nn.ModuleDict()
    for name, cfg in SVDD_HEAD_CONFIG.items():
        backbone = cfg["backbone"]
        try:
            in_dim = backbone_dims[backbone]
        except KeyError:
            raise ValueError(
                f"SVDD head {name!r} uses unsupported backbone {backbone!r}; "
                f"expected one of {sorted(backbone_dims)}"
            ) from None
        heads[name] = EmbeddingHead(in_dim=in_dim, embed_dim=embed_dim, name=name)
    return heads
def get_head_weights() -> Dict[str, float]:
    """Return a new dict mapping each SVDD head name to its weight in the final prediction, in config order."""
    weights: Dict[str, float] = {}
    for head_name in SVDD_HEAD_CONFIG:
        weights[head_name] = SVDD_HEAD_CONFIG[head_name]["weight"]
    return weights
|