import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional


class TextFieldAttention(nn.Module):
    """Attention pooling over a fixed set of text-field embeddings.

    Each field is scored with a shared bias-free linear projection; the
    softmax-normalized scores weight the sum over fields.
    """

    def __init__(self, num_fields: int, field_dim: int):
        super().__init__()
        self.attn = nn.Linear(field_dim, 1, bias=False)
        # Kept for introspection/compatibility; not used in forward.
        self.num_fields = num_fields

    def forward(self, fields: torch.Tensor):
        """Pool ``fields``.

        Args:
            fields: tensor of shape (batch, num_fields, field_dim) —
                assumed from the single call site in this file; confirm
                against other callers.

        Returns:
            Tuple of the (batch, field_dim) weighted sum and the
            (batch, num_fields) attention weights.
        """
        scores = self.attn(fields)             # (B, F, 1)
        weights = F.softmax(scores, dim=1)     # normalize across fields
        weighted_sum = (fields * weights).sum(dim=1)
        return weighted_sum, weights.squeeze(-1)


class GenreSelfAttention(nn.Module):
    """Masked attention pooling over variable-length genre embeddings."""

    def __init__(self, genre_dim: int):
        super().__init__()
        self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)

    def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Pool genre embeddings, ignoring padding positions.

        Args:
            genre_embeds: (batch, num_genres, genre_dim) embeddings.
            mask: (batch, num_genres, 1) tensor where 0 marks padding.

        Returns:
            (batch, genre_dim) attention-weighted sum over real genres.
        """
        scores = self.attn_scorer(genre_embeds)
        # Fix: the original used in-place masked_fill_ on an autograd
        # intermediate; the out-of-place form is equivalent and safer.
        # -1e9 drives masked softmax weights to (numerically) zero.
        scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=1)
        return (genre_embeds * weights).sum(dim=1)


class ModalityAttention(nn.Module):
    """
    Calculates a weighted sum of vectors from different modalities
    (text, genres, etc.), allowing the model to dynamically determine
    their importance.
    """

    def __init__(self, num_modalities: int, modality_dim: int):
        super().__init__()
        self.attn_scorer = nn.Linear(modality_dim, 1, bias=False)
        # Kept for introspection/compatibility; not used in forward.
        self.num_modalities = num_modalities

    def forward(self, modalities: torch.Tensor):
        """Pool a (batch, num_modalities, modality_dim) stack.

        Returns:
            Tuple of the (batch, modality_dim) fused vector and the
            (batch, num_modalities) attention weights.
        """
        scores = self.attn_scorer(modalities)
        weights = F.softmax(scores, dim=1)
        weighted_sum = (modalities * weights).sum(dim=1)
        return weighted_sum, weights.squeeze(-1)


class AnimeEmbeddingModel(nn.Module):
    """Main model v13.

    Fuses precomputed text embeddings, genre embeddings, and categorical /
    numerical metadata into one L2-normalized anime embedding via
    per-modality attention pooling and a shared MLP encoder.
    """

    # Single source of truth for the text fields; order must match the
    # order the model was trained with.
    TEXT_FIELDS = (
        'precomputed_ua_desc',
        'precomputed_en_desc',
        'precomputed_ua_title',
        'precomputed_en_title',
        'precomputed_original_title',
        'precomputed_alternate_names',
    )

    def __init__(self,
                 vocab_sizes: Dict[str, int],
                 embedding_dims: Optional[Dict[str, int]] = None,
                 dropout_rate: float = 0.3,
                 text_embedding_size: int = 384,
                 final_embedding_dim: int = 512,
                 num_numerical_features: int = 6):
        """Build the model.

        Args:
            vocab_sizes: sizes for the 'genre', 'studio' and 'type' vocabularies.
            embedding_dims: per-feature embedding widths; defaults to the
                v13 configuration when None. (Fix: annotation was
                ``Dict[str, int] = None``; it is Optional.)
            dropout_rate: dropout probability inside the encoder MLP.
            text_embedding_size: width of each precomputed text vector.
            final_embedding_dim: width of the fused/output embedding.
            num_numerical_features: number of numerical inputs.
                (Fix: previously hard-coded to 6; the default preserves
                the original behavior.)
        """
        super().__init__()
        if embedding_dims is None:
            embedding_dims = {'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32}
        self.embedding_dims = embedding_dims
        self.final_embedding_dim = final_embedding_dim

        # Categorical embeddings; genre index 0 is padding.
        self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
        self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
        self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
        self.numerical_layer = nn.Linear(num_numerical_features, embedding_dims['numerical'])

        # Project each modality into the shared fusion space.
        self.text_projector = nn.Linear(text_embedding_size, final_embedding_dim)
        self.genre_projector = nn.Linear(embedding_dims['genre'], final_embedding_dim)
        other_dim = embedding_dims['studio'] + embedding_dims['type'] + embedding_dims['numerical']
        self.other_projector = nn.Linear(other_dim, final_embedding_dim)

        # Attention modules for intra- and inter-modality pooling.
        self.text_field_attention = TextFieldAttention(
            num_fields=len(self.TEXT_FIELDS), field_dim=text_embedding_size)
        self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
        self.modality_attention = ModalityAttention(num_modalities=3, modality_dim=final_embedding_dim)

        self.encoder = nn.Sequential(
            nn.Linear(final_embedding_dim, 1024),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.LayerNorm(1024),
            nn.Linear(1024, 768),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.LayerNorm(768),
            nn.Linear(768, final_embedding_dim),
        )

        # Learnable per-modality scales, applied after L2 normalization so
        # the model can rebalance modality magnitudes during training.
        self.text_scale = nn.Parameter(torch.tensor(1.0))
        self.genre_scale = nn.Parameter(torch.tensor(1.0))
        self.other_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute L2-normalized embeddings for a batch.

        Args:
            batch: dict containing the six ``TEXT_FIELDS`` tensors,
                'genres', 'genres_mask', 'studio', 'type', 'numerical'.

        Returns:
            (batch, final_embedding_dim) unit-norm embedding tensor.
        """
        # 1) Pool the precomputed text-field vectors with attention.
        text_fields = torch.stack([batch[name] for name in self.TEXT_FIELDS], dim=1)
        text_vector_raw, _ = self.text_field_attention(text_fields)

        # 2) Masked attention pooling over genre embeddings.
        genre_embeds_raw = self.genre_embedding(batch['genres'])
        genre_vector_raw = self.genre_attention(
            genre_embeds_raw, batch['genres_mask'].unsqueeze(-1))

        # 3) Concatenate remaining metadata: studio, type, numerical.
        studio_emb = self.studio_embedding(batch['studio'])
        type_emb = self.type_embedding(batch['type'])
        numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
        other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)

        # Project each modality into the shared fusion space.
        text_vector_proj = self.text_projector(text_vector_raw)
        genre_vector_proj = self.genre_projector(genre_vector_raw)
        other_vector_proj = self.other_projector(other_vector_parts)

        # 4) Normalize, apply learnable scales, fuse with modality attention.
        modalities = torch.stack([
            F.normalize(text_vector_proj, p=2, dim=1) * self.text_scale,
            F.normalize(genre_vector_proj, p=2, dim=1) * self.genre_scale,
            F.normalize(other_vector_proj, p=2, dim=1) * self.other_scale,
        ], dim=1)
        combined, _ = self.modality_attention(modalities)

        # 5) Encode, squash with tanh, and L2-normalize the final output.
        embedding = torch.tanh(self.encoder(combined))
        return F.normalize(embedding, p=2, dim=1)