File size: 4,908 Bytes
0ba6002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
Classification and Embedding Heads for Card Authentication.

Three head types:
- PokemonClassifierHead (Head A): Pokemon vs Non-Pokemon binary classifier
- BackAuthHead (Head B): Genuine vs counterfeit back pattern classifier
- EmbeddingHead (Head C): Deep SVDD embedding for front anomaly detection

Six SVDD embedding heads provide component_scores for backward compatibility.
"""

import torch
import torch.nn as nn
from typing import Dict, Optional


class PokemonClassifierHead(nn.Module):
    """
    Head A: binary classifier separating Pokemon cards from non-Pokemon inputs.

    The head is a small MLP ending in a sigmoid:
        Linear(in_dim -> 256) -> BatchNorm -> ReLU -> Dropout(0.3)
        -> Linear(256 -> 1) -> Sigmoid
    """

    def __init__(self, in_dim: int, name: str = "pokemon_head"):
        """
        Args:
            in_dim: Dimension of the incoming feature vector.
            name: Identifier stored on the module for bookkeeping.
        """
        super().__init__()
        self.name = name
        # Positional Sequential keeps state_dict keys stable (classifier.0, classifier.1, ...).
        layers = [
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features (B, in_dim) to P(pokemon) in [0, 1] with shape (B, 1)."""
        return self.classifier(x)


class BackAuthHead(nn.Module):
    """
    Head B: binary classifier scoring whether a card back looks genuine.

    MLP with a sigmoid output:
        Linear(in_dim -> 256) -> BatchNorm -> ReLU -> Dropout(0.3)
        -> Linear(256 -> 1) -> Sigmoid
    """

    def __init__(self, in_dim: int, name: str = "back_auth_head"):
        """
        Args:
            in_dim: Dimension of the incoming feature vector.
            name: Identifier stored on the module for bookkeeping.
        """
        super().__init__()
        self.name = name
        # Same positional Sequential layout as the other heads so checkpoint
        # keys (classifier.0.weight, ...) stay consistent.
        layers = [
            nn.Linear(in_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features (B, in_dim) to P(genuine_back) in [0, 1] with shape (B, 1)."""
        return self.classifier(x)


class EmbeddingHead(nn.Module):
    """
    Head C: Deep SVDD embedding head for front-side anomaly detection.

    Produces a raw embedding (no sigmoid); the final projection has no bias,
    following Ruff et al. 2018, to avoid the trivial hypersphere-collapse
    solution in Deep SVDD training.

    Architecture:
        Linear(in_dim -> 512) -> BatchNorm -> ReLU
        -> Linear(512 -> 128) -> BatchNorm -> ReLU
        -> Linear(128 -> embed_dim, bias=False)
    """

    def __init__(self, in_dim: int, embed_dim: int = 128, name: str = "embedding"):
        """
        Args:
            in_dim: Dimension of the incoming feature vector.
            embed_dim: Dimension of the output embedding.
            name: Identifier stored on the module for bookkeeping.
        """
        super().__init__()
        self.name = name
        self.embed_dim = embed_dim
        # Positional Sequential preserves state_dict keys (encoder.0, encoder.1, ...).
        stages = [
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, embed_dim, bias=False),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features (B, in_dim) to an embedding of shape (B, embed_dim)."""
        return self.encoder(x)


# SVDD head configuration: name -> {"weight", "backbone", "description"}.
# These provide component_scores for backward compatibility with the 6-head UI.
# Weights sum to 1.0 (0.25 + 0.25 + 0.15 + 0.15 + 0.10 + 0.10); "backbone"
# selects the input feature dimension in create_svdd_heads().
SVDD_HEAD_CONFIG: Dict[str, Dict] = {
    "primary": {
        "weight": 0.25,
        "backbone": "resnet50",
        "description": "Overall authenticity assessment",
    },
    "print_quality": {
        "weight": 0.25,
        # Only head fed by the EfficientNet-B7 backbone (wider feature vector).
        "backbone": "efficientnet_b7",
        "description": "Print patterns, color consistency",
    },
    "edge_inspector": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Edge cutting, border quality",
    },
    "texture": {
        "weight": 0.15,
        "backbone": "resnet50",
        "description": "Surface texture, micro-patterns",
    },
    "hologram": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Hologram/foil patterns (limited by data)",
    },
    "historical": {
        "weight": 0.10,
        "backbone": "resnet50",
        "description": "Similarity patterns (limited by data)",
    },
}

# Backward-compatible alias kept for callers that still import HEAD_CONFIG.
HEAD_CONFIG = SVDD_HEAD_CONFIG


def create_svdd_heads(
    resnet_dim: int = 2048,
    efficientnet_dim: int = 2560,
    embed_dim: int = 128,
    config: Optional[Dict[str, Dict]] = None,
) -> nn.ModuleDict:
    """
    Create the SVDD embedding heads (the 6 standard heads by default).

    Args:
        resnet_dim: ResNet50 output dimension (input size for resnet50-backed heads)
        efficientnet_dim: EfficientNet-B7 output dimension
        embed_dim: SVDD embedding dimension for every head
        config: Optional head configuration mapping name -> {"backbone": ...}.
            Defaults to SVDD_HEAD_CONFIG, preserving the original behavior;
            pass a custom mapping to build a different head set.

    Returns:
        ModuleDict mapping head name -> EmbeddingHead
    """
    if config is None:
        config = SVDD_HEAD_CONFIG
    heads = nn.ModuleDict()
    for name, cfg in config.items():
        # The efficientnet_b7-backed head consumes the wider feature vector;
        # every other head takes ResNet50 features.
        in_dim = efficientnet_dim if cfg["backbone"] == "efficientnet_b7" else resnet_dim
        heads[name] = EmbeddingHead(in_dim=in_dim, embed_dim=embed_dim, name=name)
    return heads


def get_head_weights() -> Dict[str, float]:
    """Return the weight each SVDD head contributes to the final prediction."""
    weights: Dict[str, float] = {}
    for head_name, head_cfg in SVDD_HEAD_CONFIG.items():
        weights[head_name] = head_cfg["weight"]
    return weights