import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional
class TextFieldAttention(nn.Module):
"""Calculates a weighted sum of text field embeddings."""
def __init__(self, num_fields: int, field_dim: int):
super().__init__()
self.attn = nn.Linear(field_dim, 1, bias=False)
self.num_fields = num_fields
def forward(self, fields: torch.Tensor):
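        # fields: (batch, num_fields, field_dim); attention is taken over the field axis.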
scores = self.attn(fields)
weights = F.softmax(scores, dim=1)
weighted_sum = (fields * weights).sum(dim=1)
return weighted_sum, weights.squeeze(-1)
class GenreSelfAttention(nn.Module):
"""Calculates a weighted sum of genres based only on the genres themselves."""
def __init__(self, genre_dim: int):
super().__init__()
self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)
    def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
        # genre_embeds: (batch, num_genres, genre_dim); mask: (batch, num_genres, 1), 0 at padding.
        scores = self.attn_scorer(genre_embeds)
        # Out-of-place masked_fill avoids mutating an autograd tensor in place;
        # padded genres get a large negative score and therefore ~zero attention weight.
        scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=1)
        weighted_sum = (genre_embeds * weights).sum(dim=1)
        return weighted_sum
class ModalityAttention(nn.Module):
"""
Calculates a weighted sum of vectors from different modalities (text, genres, etc.),
allowing the model to dynamically determine their importance.
"""
def __init__(self, num_modalities: int, modality_dim: int):
super().__init__()
self.attn_scorer = nn.Linear(modality_dim, 1, bias=False)
self.num_modalities = num_modalities
def forward(self, modalities: torch.Tensor):
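        # modalities: (batch, num_modalities, modality_dim); softmax over the modality axis.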
scores = self.attn_scorer(modalities)
weights = F.softmax(scores, dim=1)
weighted_sum = (modalities * weights).sum(dim=1)
return weighted_sum, weights.squeeze(-1)
class AnimeEmbeddingModel(nn.Module):
"""
Main model v13.
"""
def __init__(self,
vocab_sizes: Dict[str, int],
embedding_dims: Optional[Dict[str, int]] = None,
dropout_rate: float = 0.3,
text_embedding_size: int = 384,
final_embedding_dim: int = 512):
super().__init__()
if embedding_dims is None:
embedding_dims = {'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32}
self.embedding_dims = embedding_dims
self.final_embedding_dim = final_embedding_dim
self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
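        # The numerical head expects exactly six numerical features per item.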
self.numerical_layer = nn.Linear(6, embedding_dims['numerical'])
self.text_projector = nn.Linear(text_embedding_size, final_embedding_dim)
self.genre_projector = nn.Linear(embedding_dims['genre'], final_embedding_dim)
other_dim = embedding_dims['studio'] + embedding_dims['type'] + embedding_dims['numerical']
self.other_projector = nn.Linear(other_dim, final_embedding_dim)
self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
self.modality_attention = ModalityAttention(num_modalities=3, modality_dim=final_embedding_dim)
self.encoder = nn.Sequential(
nn.Linear(final_embedding_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
nn.Linear(768, final_embedding_dim),
)
self.text_scale = nn.Parameter(torch.tensor(1.0))
self.genre_scale = nn.Parameter(torch.tensor(1.0))
self.other_scale = nn.Parameter(torch.tensor(1.0))
def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
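        # Fuse the six precomputed text-field embeddings into a single text vector.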
text_fields = torch.stack([
batch['precomputed_ua_desc'], batch['precomputed_en_desc'],
batch['precomputed_ua_title'], batch['precomputed_en_title'],
batch['precomputed_original_title'], batch['precomputed_alternate_names'],
], dim=1)
text_vector_raw, _ = self.text_field_attention(text_fields)
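        # Pool the variable-length genre list with self-attention; padding is masked out.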
genre_embeds_raw = self.genre_embedding(batch['genres'])
genre_vector_raw = self.genre_attention(genre_embeds_raw, batch['genres_mask'].unsqueeze(-1))
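        # Embed the remaining categorical and numerical features into one "other" vector.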
studio_emb = self.studio_embedding(batch['studio'])
type_emb = self.type_embedding(batch['type'])
numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)
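        # Project each modality into the shared final_embedding_dim space.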
text_vector_proj = self.text_projector(text_vector_raw)
genre_vector_proj = self.genre_projector(genre_vector_raw)
other_vector_proj = self.other_projector(other_vector_parts)
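        # L2-normalize each modality, then let a learned scale set its relative
        # magnitude before modality attention weighs the three vectors.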
modalities = torch.stack([
F.normalize(text_vector_proj, p=2, dim=1) * self.text_scale,
F.normalize(genre_vector_proj, p=2, dim=1) * self.genre_scale,
F.normalize(other_vector_proj, p=2, dim=1) * self.other_scale,
], dim=1)
combined, _ = self.modality_attention(modalities)
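        # Encode, bound with tanh, and re-normalize so embeddings lie on the unit sphere.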
embedding_logits = self.encoder(combined)
embedding = torch.tanh(embedding_logits)
embedding = F.normalize(embedding, p=2, dim=1)
        return embedding
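
# A minimal usage sketch: illustrative vocab sizes and random tensors only. The
# batch keys mirror exactly what forward() reads, but every concrete size below
# (50 genres, 100 studios, 8 types, batch of 4) is an assumption, not a value
# taken from the real dataset.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = AnimeEmbeddingModel(vocab_sizes={'genre': 50, 'studio': 100, 'type': 8})
    batch_size, num_genres = 4, 5
    text_keys = [
        'precomputed_ua_desc', 'precomputed_en_desc',
        'precomputed_ua_title', 'precomputed_en_title',
        'precomputed_original_title', 'precomputed_alternate_names',
    ]
    # Each text field is a 384-dim vector, matching text_embedding_size.
    batch = {key: torch.randn(batch_size, 384) for key in text_keys}
    batch.update({
        'genres': torch.randint(1, 50, (batch_size, num_genres)),
        'genres_mask': torch.ones(batch_size, num_genres),
        'studio': torch.randint(0, 100, (batch_size,)),
        'type': torch.randint(0, 8, (batch_size,)),
        'numerical': torch.randn(batch_size, 6),
    })
    embedding = model(batch)
    # Output rows are unit-norm, so cosine similarity reduces to a dot product.
    print(embedding.shape)          # torch.Size([4, 512])
    print(embedding.norm(dim=1))    # ~1.0 for every row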