import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional
class TextFieldAttention(nn.Module):
    """Calculates a weighted sum of text field embeddings using learned attention scores."""

    def __init__(self, num_fields: int, field_dim: int):
        super().__init__()
        self.attn = nn.Linear(field_dim, 1, bias=False)
        self.num_fields = num_fields

    def forward(self, fields: torch.Tensor):
        # fields: (batch, num_fields, field_dim)
        scores = self.attn(fields)                    # (batch, num_fields, 1)
        weights = F.softmax(scores, dim=1)
        weighted_sum = (fields * weights).sum(dim=1)  # (batch, field_dim)
        return weighted_sum, weights.squeeze(-1)
class GenreSelfAttention(nn.Module):
    """Calculates a weighted sum of genres based only on the genres themselves."""

    def __init__(self, genre_dim: int):
        super().__init__()
        self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)

    def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
        # genre_embeds: (batch, num_genres, genre_dim); mask: (batch, num_genres, 1)
        scores = self.attn_scorer(genre_embeds)             # (batch, num_genres, 1)
        scores = scores.masked_fill(mask == 0, -1e9)        # exclude padded genre slots from the softmax
        weights = F.softmax(scores, dim=1)
        weighted_sum = (genre_embeds * weights).sum(dim=1)  # (batch, genre_dim)
        return weighted_sum
class ModalityAttention(nn.Module):
    """
    Calculates a weighted sum of vectors from different modalities (text, genres, etc.),
    allowing the model to dynamically determine their importance.
    """

    def __init__(self, num_modalities: int, modality_dim: int):
        super().__init__()
        self.attn_scorer = nn.Linear(modality_dim, 1, bias=False)
        self.num_modalities = num_modalities

    def forward(self, modalities: torch.Tensor):
        # modalities: (batch, num_modalities, modality_dim)
        scores = self.attn_scorer(modalities)             # (batch, num_modalities, 1)
        weights = F.softmax(scores, dim=1)
        weighted_sum = (modalities * weights).sum(dim=1)  # (batch, modality_dim)
        return weighted_sum, weights.squeeze(-1)
class AnimeEmbeddingModel(nn.Module):
    """
    Main model (v13): fuses precomputed text embeddings, genre embeddings, and
    categorical/numerical features into a single L2-normalized anime embedding,
    using attention at the text-field, genre, and modality levels.
    """
    def __init__(self,
                 vocab_sizes: Dict[str, int],
                 embedding_dims: Optional[Dict[str, int]] = None,
                 dropout_rate: float = 0.3,
                 text_embedding_size: int = 384,
                 final_embedding_dim: int = 512):
        super().__init__()
        if embedding_dims is None:
            embedding_dims = {'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32}
        self.embedding_dims = embedding_dims
        self.final_embedding_dim = final_embedding_dim

        # Embeddings for categorical features and a linear encoder for the 6 numerical features.
        self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
        self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
        self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
        self.numerical_layer = nn.Linear(6, embedding_dims['numerical'])

        # Projections that map each modality into the shared final_embedding_dim space.
        self.text_projector = nn.Linear(text_embedding_size, final_embedding_dim)
        self.genre_projector = nn.Linear(embedding_dims['genre'], final_embedding_dim)
        other_dim = embedding_dims['studio'] + embedding_dims['type'] + embedding_dims['numerical']
        self.other_projector = nn.Linear(other_dim, final_embedding_dim)

        # Attention over the six text fields, over genres, and over the three modalities.
        self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
        self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
        self.modality_attention = ModalityAttention(num_modalities=3, modality_dim=final_embedding_dim)

        self.encoder = nn.Sequential(
            nn.Linear(final_embedding_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
            nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
            nn.Linear(768, final_embedding_dim),
        )

        # Learnable per-modality scales applied to the L2-normalized modality vectors.
        self.text_scale = nn.Parameter(torch.tensor(1.0))
        self.genre_scale = nn.Parameter(torch.tensor(1.0))
        self.other_scale = nn.Parameter(torch.tensor(1.0))
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Text modality: attend over the six precomputed text-field embeddings.
        text_fields = torch.stack([
            batch['precomputed_ua_desc'], batch['precomputed_en_desc'],
            batch['precomputed_ua_title'], batch['precomputed_en_title'],
            batch['precomputed_original_title'], batch['precomputed_alternate_names'],
        ], dim=1)
        text_vector_raw, _ = self.text_field_attention(text_fields)

        # Genre modality: attend over genre embeddings, masking padded genre slots.
        genre_embeds_raw = self.genre_embedding(batch['genres'])
        genre_vector_raw = self.genre_attention(genre_embeds_raw, batch['genres_mask'].unsqueeze(-1))

        # "Other" modality: studio, type, and numerical features.
        studio_emb = self.studio_embedding(batch['studio'])
        type_emb = self.type_embedding(batch['type'])
        numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
        other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)

        # Project each modality to the shared dimension.
        text_vector_proj = self.text_projector(text_vector_raw)
        genre_vector_proj = self.genre_projector(genre_vector_raw)
        other_vector_proj = self.other_projector(other_vector_parts)

        # Normalize each modality, apply its learnable scale, and fuse via modality attention.
        modalities = torch.stack([
            F.normalize(text_vector_proj, p=2, dim=1) * self.text_scale,
            F.normalize(genre_vector_proj, p=2, dim=1) * self.genre_scale,
            F.normalize(other_vector_proj, p=2, dim=1) * self.other_scale,
        ], dim=1)
        combined, _ = self.modality_attention(modalities)

        # Encode, squash with tanh, and L2-normalize the final embedding.
        embedding_logits = self.encoder(combined)
        embedding = torch.tanh(embedding_logits)
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding
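

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the training pipeline). The batch keys
# match forward() above; the shapes, dtypes, and vocabulary sizes below are
# illustrative assumptions: each precomputed_* text field is a
# (batch, text_embedding_size) float tensor, 'genres' is padded with index 0,
# and 'numerical' carries 6 features.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, max_genres, text_dim = 4, 10, 384
    vocab_sizes = {'genre': 50, 'studio': 200, 'type': 8}  # assumed vocab sizes
    model = AnimeEmbeddingModel(vocab_sizes, text_embedding_size=text_dim)
    model.eval()  # disable dropout for a deterministic sanity check

    genres = torch.randint(1, vocab_sizes['genre'], (batch_size, max_genres))
    genres[:, 5:] = 0  # simulate padding in the tail slots
    dummy_batch = {
        'genres': genres,
        'genres_mask': (genres != 0).long(),
        'studio': torch.randint(0, vocab_sizes['studio'], (batch_size,)),
        'type': torch.randint(0, vocab_sizes['type'], (batch_size,)),
        'numerical': torch.randn(batch_size, 6),
        **{key: torch.randn(batch_size, text_dim) for key in [
            'precomputed_ua_desc', 'precomputed_en_desc',
            'precomputed_ua_title', 'precomputed_en_title',
            'precomputed_original_title', 'precomputed_alternate_names',
        ]},
    }

    with torch.no_grad():
        embedding = model(dummy_batch)
    # Each row should be an L2-normalized vector of size final_embedding_dim (512 by default).
    print(embedding.shape, embedding.norm(dim=1))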