|
|
""" |
|
|
Artist Style Embedding - Model Architecture |
|
|
EVA02-Large based Multi-branch Style Encoder |
|
|
""" |
|
|
from typing import Dict, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import timm |
|
|
|
|
|
|
|
|
class EVA02Encoder(nn.Module):
    """EVA02-Large (CLIP) backbone wrapped as a feature encoder.

    The backbone is built through ``timm.create_model`` with
    ``num_classes=0``. Its output width is taken to be 1024 (hard-coded in
    ``feature_dim``). When the requested ``output_dim`` differs from that
    width, a Linear -> LayerNorm -> GELU projection is applied on top;
    otherwise the features are returned unchanged.

    Args:
        pretrained: Forwarded to ``timm.create_model``.
        output_dim: Width of the tensor returned by ``forward``.
    """

    def __init__(self, pretrained: bool = True, output_dim: int = 1024):
        super().__init__()

        self.backbone = timm.create_model(
            "eva02_large_patch14_clip_224",
            pretrained=pretrained,
            num_classes=0,
        )

        # Width assumed for the backbone output; not read from the model.
        self.feature_dim = 1024
        self.output_dim = output_dim

        if output_dim == self.feature_dim:
            # Widths already agree, so no extra parameters are introduced.
            self.proj = nn.Identity()
        else:
            self.proj = nn.Sequential(
                nn.Linear(self.feature_dim, output_dim),
                nn.LayerNorm(output_dim),
                nn.GELU(),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the backbone, then apply the output projection."""
        return self.proj(self.backbone(x))
|
|
|
|
|
|
|
|
class GatedFusion(nn.Module):
    """Gated attention fusion for multi-branch features.

    A small MLP looks at the concatenation of all branch features and emits
    one softmax weight per branch; the fused feature is the weighted sum of
    the branches.

    Args:
        input_dim: Feature width ``D`` of every branch.
        num_branches: Number of branches ``N`` that are fused.
    """

    def __init__(self, input_dim: int, num_branches: int = 3):
        super().__init__()

        self.num_branches = num_branches

        self.gate = nn.Sequential(
            nn.Linear(input_dim * num_branches, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, num_branches),
            nn.Softmax(dim=-1),
        )

    def forward(
        self,
        features: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fuse branch features into a single vector per sample.

        Args:
            features: Tensor of shape ``(B, N, D)``.
            mask: Optional ``(B, N)`` tensor; zero / ``False`` entries remove
                the corresponding branch from the weighted sum.

        Returns:
            Tensor of shape ``(B, D)``.
        """
        B, N, D = features.shape

        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted instead of raising a RuntimeError.
        concat_features = features.reshape(B, -1)
        gates = self.gate(concat_features)

        if mask is not None:
            # Drop masked branches, then renormalise the remaining weights so
            # they sum to one again; the epsilon guards fully masked rows.
            gates = gates * mask.float()
            gates = gates / (gates.sum(dim=-1, keepdim=True) + 1e-8)

        # (B, N) -> (B, N, 1) so the weights broadcast over the feature axis.
        gates = gates.unsqueeze(-1)
        fused = (features * gates).sum(dim=1)

        return fused
|
|
|
|
|
|
|
|
class StyleEmbeddingHead(nn.Module):
    """Projection head that maps fused features to a unit-length embedding.

    Two hidden stages (Linear -> LayerNorm -> GELU -> Dropout) are followed
    by a linear map to ``embedding_dim``, a final LayerNorm and an L2
    normalisation along the last axis.

    Args:
        input_dim: Width of the incoming features.
        embedding_dim: Width of the produced embedding.
        hidden_dim: Width of both hidden stages.
        dropout: Dropout probability used in both hidden stages.
    """

    def __init__(
        self,
        input_dim: int,
        embedding_dim: int = 512,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()

        def hidden_stage(in_features: int) -> list:
            # One Linear -> LayerNorm -> GELU -> Dropout stage.
            return [
                nn.Linear(in_features, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]

        # Layer order fixes the Sequential indices (0..8) that appear in
        # state_dict keys, so it must stay exactly as laid out here.
        layers = [
            *hidden_stage(input_dim),
            *hidden_stage(hidden_dim),
            nn.Linear(hidden_dim, embedding_dim),
        ]
        self.mlp = nn.Sequential(*layers)

        self.final_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised embedding for ``x``."""
        projected = self.final_norm(self.mlp(x))
        return F.normalize(projected, p=2, dim=-1)
|
|
|
|
|
|
|
|
class MultiBranchStyleEncoder(nn.Module):
    """Multi-branch style encoder built on one shared EVA02-Large backbone.

    Three views of an image are encoded by the *same* ``EVA02Encoder``
    instance (the weights are shared, not duplicated):

    - Full image branch
    - Face crop branch
    - Eye crop branch

    The branch features are combined by ``GatedFusion`` and projected to the
    final embedding by ``StyleEmbeddingHead``.

    Args:
        embedding_dim: Width of the returned embedding.
        hidden_dim: Width of the per-branch and fused features.
        dropout: Dropout probability of the embedding head.
        pretrained: Passed to ``EVA02Encoder``. Set to ``False`` to build the
            architecture without pretrained backbone weights, e.g. right
            before restoring a checkpoint.
    """

    def __init__(
        self,
        embedding_dim: int = 512,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
        pretrained: bool = True,
    ):
        super().__init__()

        # Remembers freeze_backbone() so that train() can keep the backbone
        # in eval mode while it is frozen.
        self._backbone_frozen = False

        self.shared_backbone = EVA02Encoder(
            pretrained=pretrained, output_dim=hidden_dim
        )

        self.fusion = GatedFusion(hidden_dim, num_branches=3)

        self.embedding_head = StyleEmbeddingHead(
            hidden_dim, embedding_dim, hidden_dim, dropout
        )

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

    @staticmethod
    def _presence_mask(
        flags: torch.Tensor, batch_size: int, device: torch.device
    ) -> torch.Tensor:
        """Return ``flags`` as a ``(batch_size,)`` bool tensor on ``device``."""
        return flags.to(device=device, dtype=torch.bool).reshape(batch_size)

    def forward(
        self,
        full: torch.Tensor,
        face: torch.Tensor,
        eye: torch.Tensor,
        has_face: torch.Tensor,
        has_eye: torch.Tensor,
    ) -> torch.Tensor:
        """Encode the three views and return the fused style embedding.

        Args:
            full: Batch of full images.
            face: Batch of face crops (content is ignored where
                ``has_face`` is false).
            eye: Batch of eye crops (content is ignored where ``has_eye``
                is false).
            has_face: One flag per sample marking a valid face crop; any
                dtype is accepted, non-zero means present.
            has_eye: One flag per sample marking a valid eye crop.

        Returns:
            The output of the embedding head for the fused features.
        """
        B = full.shape[0]
        device = full.device

        # Normalise the flags once so that feature masking and the fusion
        # mask always agree on dtype, device and shape.
        face_present = self._presence_mask(has_face, B, device)
        eye_present = self._presence_mask(has_eye, B, device)

        full_features = self.shared_backbone(full)
        # Zero the features of missing crops so they cannot leak into the gate.
        face_features = self.shared_backbone(face) * face_present.unsqueeze(-1)
        eye_features = self.shared_backbone(eye) * eye_present.unsqueeze(-1)

        stacked = torch.stack([full_features, face_features, eye_features], dim=1)
        mask = torch.stack(
            [
                # The full-image branch is always available.
                torch.ones(B, device=device, dtype=torch.bool),
                face_present,
                eye_present,
            ],
            dim=1,
        )

        fused = self.fusion(stacked, mask)

        embeddings = self.embedding_head(fused)

        return embeddings

    def train(self, mode: bool = True):
        """Set the training mode, keeping a frozen backbone in eval mode.

        Without this override a plain ``model.train()`` issued after
        ``freeze_backbone()`` would switch the frozen backbone back to
        training mode.
        """
        super().train(mode)
        # getattr keeps this safe for instances that never ran __init__
        # from this version of the class.
        if getattr(self, "_backbone_frozen", False):
            self.shared_backbone.eval()
        return self

    def get_backbone_params(self):
        """Return the parameters of the single shared backbone as a list.

        NOTE(review): this covers every parameter of ``shared_backbone``,
        which includes its ``proj`` layers when ``hidden_dim`` differs from
        the backbone width — confirm that is the intended grouping.
        """
        # A list (not a generator) so the result can be iterated repeatedly,
        # matching get_head_params().
        return list(self.shared_backbone.parameters())

    def get_head_params(self):
        """Return the parameters of the fusion module and embedding head."""
        return [
            *self.fusion.parameters(),
            *self.embedding_head.parameters(),
        ]

    def freeze_backbone(self):
        """Freeze the shared backbone and keep it in eval mode."""
        for param in self.get_backbone_params():
            param.requires_grad = False
        self._backbone_frozen = True
        self.shared_backbone.eval()

    def unfreeze_backbone(self):
        """Unfreeze the shared backbone and resync its mode with the model."""
        for param in self.get_backbone_params():
            param.requires_grad = True
        self._backbone_frozen = False
        # Follow the parent's current mode instead of forcing train mode.
        self.shared_backbone.train(self.training)
|
|
|
|
|
|
|
|
class ArtistStyleModel(nn.Module):
    """Complete model: multi-branch encoder plus ArcFace class weights.

    ``forward`` returns the style embeddings together with their cosine
    similarity to every (L2-normalised) class weight vector. No angular
    margin or scale is applied in this module; presumably the loss takes
    care of that — verify against the training code.

    Args:
        num_classes: Number of classes, i.e. rows of ``arcface_weight``.
        embedding_dim: Width of the style embedding.
        hidden_dim: Hidden width passed to the encoder.
        dropout: Dropout probability passed to the encoder.
    """

    def __init__(
        self,
        num_classes: int,
        embedding_dim: int = 512,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embedding_dim = embedding_dim

        self.encoder = MultiBranchStyleEncoder(
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        # torch.empty replaces the legacy torch.FloatTensor(size) constructor
        # and honours the default dtype; the values are set by the Xavier
        # initialisation right below.
        self.arcface_weight = nn.Parameter(
            torch.empty(num_classes, embedding_dim)
        )
        nn.init.xavier_uniform_(self.arcface_weight)

    def forward(
        self,
        full: torch.Tensor,
        face: torch.Tensor,
        eye: torch.Tensor,
        has_face: torch.Tensor,
        has_eye: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Compute embeddings and their cosine logits.

        Returns:
            A dict with ``'embeddings'`` (the encoder output) and
            ``'cosine'`` (embeddings multiplied with the row-normalised
            class weights, one column per class).
        """
        embeddings = self.get_embeddings(full, face, eye, has_face, has_eye)

        # Normalise each class vector so the product below is a cosine
        # similarity whenever the embeddings are unit length.
        normalized_weights = F.normalize(self.arcface_weight, p=2, dim=1)
        cosine = F.linear(embeddings, normalized_weights)

        return {
            'embeddings': embeddings,
            'cosine': cosine,
        }

    def get_embeddings(
        self,
        full: torch.Tensor,
        face: torch.Tensor,
        eye: torch.Tensor,
        has_face: torch.Tensor,
        has_eye: torch.Tensor,
    ) -> torch.Tensor:
        """Return only the encoder embeddings, skipping the class weights."""
        return self.encoder(full, face, eye, has_face, has_eye)
|
|
|
|
|
|
|
|
def create_model(config, num_classes: int) -> ArtistStyleModel:
    """Build an ``ArtistStyleModel`` from the ``config.model`` settings.

    Args:
        config: Object whose ``model`` attribute exposes ``embedding_dim``,
            ``hidden_dim`` and ``dropout``.
        num_classes: Number of classes for the model's class weights.

    Returns:
        The newly constructed model.
    """
    model_cfg = config.model
    hyperparams = {
        "embedding_dim": model_cfg.embedding_dim,
        "hidden_dim": model_cfg.hidden_dim,
        "dropout": model_cfg.dropout,
    }
    return ArtistStyleModel(num_classes=num_classes, **hyperparams)
|
|
|