# github-actions[bot]
# deploy: backend bundle from 9c864b98f64c05462a27b71841ae97fb4451e449
# c61ba70
"""
CardAuthModel - Multi-Purpose Deep Learning Model.
Three head groups:
- Head A (PokemonClassifierHead): Pokemon vs Non-Pokemon (ResNet50)
- Head B (BackAuthHead): Genuine vs counterfeit back (ResNet50)
- Head C (EmbeddingHead x6): Deep SVDD for front anomaly detection
SVDD centers stored as register_buffer (saved in state_dict, not gradient-trained).
"""
import torch
import torch.nn as nn
from typing import Dict, Optional, List
from .backbone import ResNet50Backbone
from .efficientnet import EfficientNetB7Backbone
from .heads import (
PokemonClassifierHead,
BackAuthHead,
EmbeddingHead,
SVDD_HEAD_CONFIG,
create_svdd_heads,
get_head_weights,
)
from ..utils.logger import get_logger
from ..utils.config import config
logger = get_logger(__name__)
class CardAuthModel(nn.Module):
    """
    Multi-purpose deep learning model for card authentication.
    Architecture:
        Input image (B, 3, 224, 224)
        |-- ResNet50 -> 2048-dim
        |   |-- pokemon_head (Head A) -> P(pokemon)
        |   |-- back_auth_head (Head B) -> P(genuine_back)
        |   |-- SVDD heads (Head C): primary, edge_inspector,
        |   |   texture, hologram, historical -> 128-dim each
        |-- EfficientNet-B7 -> 2560-dim
        |-- SVDD head: print_quality -> 128-dim
    SVDD output: weighted 1/(1+dist) scores
    """

    def __init__(
        self,
        pretrained: bool = True,
        freeze_early: bool = True,
        head_weights: Optional[Dict[str, float]] = None,
        embed_dim: Optional[int] = None,
    ):
        """
        Args:
            pretrained: Load pretrained backbone weights.
            freeze_early: Freeze the early backbone layers (only late layers train).
            head_weights: Per-head weights for the final SVDD prediction;
                defaults to get_head_weights() when None.
            embed_dim: SVDD embedding dimension; defaults to
                config.DL_SVDD_EMBEDDING_DIM when None.
        """
        super().__init__()
        if embed_dim is None:
            embed_dim = config.DL_SVDD_EMBEDDING_DIM
        # Backbones
        self.resnet = ResNet50Backbone(pretrained=pretrained, freeze_early=freeze_early)
        self.efficientnet = EfficientNetB7Backbone(pretrained=pretrained, freeze_early=freeze_early)
        # Head A: Pokemon classifier
        self.pokemon_head = PokemonClassifierHead(in_dim=self.resnet.output_dim)
        # Head B: Back authenticator
        self.back_auth_head = BackAuthHead(in_dim=self.resnet.output_dim)
        # Head C: SVDD embedding heads (6 heads for component_scores)
        self.svdd_heads = create_svdd_heads(
            resnet_dim=self.resnet.output_dim,
            efficientnet_dim=self.efficientnet.output_dim,
            embed_dim=embed_dim,
        )
        # Head weights for final SVDD prediction
        self.head_weights = head_weights or get_head_weights()
        self.embed_dim = embed_dim
        # Register SVDD centers as buffers: saved in state_dict and moved with
        # .to(device) alongside the model, but never updated by the optimizer.
        for name in SVDD_HEAD_CONFIG:
            self.register_buffer(
                f"center_{name}",
                torch.zeros(embed_dim),
            )
        # Track whether centers have been initialized
        self.register_buffer("centers_initialized", torch.tensor(False))
        resnet_params = self.resnet.get_trainable_params()
        efn_params = self.efficientnet.get_trainable_params()
        logger.info(
            f"CardAuthModel initialized: "
            f"ResNet50 ({resnet_params['trainable']:,} trainable), "
            f"EfficientNet-B7 ({efn_params['trainable']:,} trainable), "
            f"Head A (pokemon), Head B (back_auth), 6 SVDD heads"
        )

    def get_center(self, name: str) -> torch.Tensor:
        """Get the SVDD center buffer for a named head."""
        return getattr(self, f"center_{name}")

    def set_center(self, name: str, center: torch.Tensor):
        """Set the SVDD center for a named head (in-place buffer copy)."""
        getattr(self, f"center_{name}").copy_(center)

    @torch.no_grad()
    def initialize_centers(self, dataloader, device: Optional[torch.device] = None):
        """
        Initialize SVDD centers by computing mean embeddings on authentic front data.
        Only uses samples where is_authentic=1 AND is_back=0 (authentic fronts).
        Counterfeits and back images are excluded to avoid polluting centers.
        Args:
            dataloader: DataLoader yielding (images, metadata) pairs or bare
                image batches (metadata is then treated as absent).
            device: Compute device; defaults to the model's own device.
        """
        if device is None:
            device = next(self.parameters()).device
        # Remember the caller's mode so we can restore it afterwards instead
        # of unconditionally leaving the model in eval mode (bug fix).
        was_training = self.training
        self.eval()
        try:
            embeddings_accum = {name: [] for name in SVDD_HEAD_CONFIG}
            total_samples = 0
            for batch in dataloader:
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    images, metadata = batch
                else:
                    images = batch
                    metadata = None
                images = images.to(device)
                # Filter to authentic front images only
                if metadata is not None:
                    is_authentic = metadata.get("is_authentic", torch.ones(images.size(0)))
                    is_back = metadata.get("is_back", torch.zeros(images.size(0)))
                    # Metadata tensors typically live on CPU; move the mask to
                    # the image device so boolean indexing cannot raise a
                    # cross-device error (bug fix).
                    mask = ((is_authentic == 1) & (is_back == 0)).to(images.device)
                    if not mask.any():
                        continue
                    images = images[mask]
                resnet_features = self.resnet(images)
                efficientnet_features = self.efficientnet(images)
                for name, head in self.svdd_heads.items():
                    # Route each head to the backbone it was configured for.
                    if SVDD_HEAD_CONFIG[name]["backbone"] == "efficientnet_b7":
                        emb = head(efficientnet_features)
                    else:
                        emb = head(resnet_features)
                    # Accumulate on CPU to avoid holding all embeddings on GPU.
                    embeddings_accum[name].append(emb.cpu())
                total_samples += images.size(0)
            for name in SVDD_HEAD_CONFIG:
                if len(embeddings_accum[name]) == 0:
                    logger.warning(
                        f"No authentic front embeddings for head '{name}', keeping zero center"
                    )
                    continue
                all_emb = torch.cat(embeddings_accum[name], dim=0)
                center = all_emb.mean(dim=0)
                self.set_center(name, center.to(device))
            self.centers_initialized.fill_(True)
            logger.info(
                f"SVDD centers initialized from {total_samples} authentic front samples "
                f"({len(embeddings_accum['primary'])} batches)"
            )
        finally:
            # Restore the caller's train/eval mode.
            self.train(was_training)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through all heads.
        Args:
            x: Input tensor (B, 3, 224, 224)
        Returns:
            Dict with:
                'pokemon_score': P(pokemon) (B, 1)
                'back_score': P(genuine_back) (B, 1)
                'embeddings': Dict[name, (B, embed_dim)]
                'distances': Dict[name, (B,)] - ||f(x) - c||^2
                'svdd_scores': Dict[name, (B,)] - 1/(1+dist) normalized [0,1]
                'prediction': Weighted SVDD score (B, 1)
                'head_outputs': Alias for svdd_scores as (B, 1) tensors
        """
        resnet_features = self.resnet(x)
        efficientnet_features = self.efficientnet(x)
        # Head A: Pokemon classifier
        pokemon_score = self.pokemon_head(resnet_features)
        # Head B: Back authenticator
        back_score = self.back_auth_head(resnet_features)
        # Head C: SVDD embeddings
        embeddings = {}
        distances = {}
        svdd_scores = {}
        for name, head in self.svdd_heads.items():
            # Each head reads from the backbone it was configured for.
            if SVDD_HEAD_CONFIG[name]["backbone"] == "efficientnet_b7":
                emb = head(efficientnet_features)
            else:
                emb = head(resnet_features)
            embeddings[name] = emb
            center = self.get_center(name)
            # Squared Euclidean distance to the head's SVDD center.
            dist = torch.sum((emb - center.unsqueeze(0)) ** 2, dim=1)
            distances[name] = dist
            # Map distance to a [0, 1] score: 1 at the center, -> 0 far away.
            score = 1.0 / (1.0 + dist)
            svdd_scores[name] = score
        # Weighted SVDD prediction
        batch_size = x.size(0)
        weighted_sum = torch.zeros(batch_size, device=x.device)
        for name, score in svdd_scores.items():
            weighted_sum = weighted_sum + self.head_weights[name] * score
        # head_outputs: backward-compatible dict of (B, 1) tensors
        head_outputs = {
            name: score.unsqueeze(1) for name, score in svdd_scores.items()
        }
        return {
            "pokemon_score": pokemon_score,
            "back_score": back_score,
            "embeddings": embeddings,
            "distances": distances,
            "svdd_scores": svdd_scores,
            "prediction": weighted_sum.unsqueeze(1),
            "head_outputs": head_outputs,
        }

    def get_total_params(self) -> Dict[str, int]:
        """Get total parameter counts (trainable / frozen / total)."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        return {
            "trainable": trainable,
            "frozen": total - trainable,
            "total": total,
        }

    def get_param_groups(self, backbone_lr: float = 1e-4, head_lr: float = 1e-3):
        """
        Get parameter groups with discriminative (layer-wise) learning rates.
        3 groups:
            - Early trainable backbone layers (layer3/block6): backbone_lr * 0.1
            - Late trainable backbone layers (layer4/block7+): backbone_lr
            - Head parameters: head_lr
        Args:
            backbone_lr: Learning rate for late backbone layers
            head_lr: Learning rate for head parameters
        Returns:
            List of parameter group dicts for optimizer
        """
        resnet_groups = self.resnet.get_layer_groups()  # [layer3, layer4]
        efn_groups = self.efficientnet.get_layer_groups()  # [block6, block7+]
        early_backbone_params = resnet_groups[0] + efn_groups[0]
        late_backbone_params = resnet_groups[1] + efn_groups[1]
        head_params = (
            list(self.pokemon_head.parameters())
            + list(self.back_auth_head.parameters())
            + list(self.svdd_heads.parameters())
        )
        groups = []
        # Skip empty groups so the optimizer never receives a zero-param group.
        if early_backbone_params:
            groups.append({"params": early_backbone_params, "lr": backbone_lr * 0.1})
        if late_backbone_params:
            groups.append({"params": late_backbone_params, "lr": backbone_lr})
        groups.append({"params": head_params, "lr": head_lr})
        return groups
# Backward-compatible alias: older callers import CardAuthDLModel.
CardAuthDLModel = CardAuthModel  # keep until all call sites migrate to CardAuthModel