Lorg0n's picture
Upload model.py with huggingface_hub
bca76eb verified
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class TextFieldAttention(nn.Module):
    """Attention pooling over a stack of text-field vectors.

    A single bias-free linear layer assigns one score to every field; the
    scores are softmax-normalized across dimension 1 and used as weights for
    a weighted sum of the fields.

    Args:
        num_fields: Number of fields expected along dimension 1. It is kept on
            the module for reference only and is not read by ``forward``.
        field_dim: Size of the last dimension of each field vector.
    """

    def __init__(self, num_fields: int, field_dim: int):
        super().__init__()
        self.num_fields = num_fields
        # Parameter name ``attn`` is part of the state_dict key; keep it.
        self.attn = nn.Linear(field_dim, 1, bias=False)

    def forward(self, fields: torch.Tensor):
        """Pool ``fields`` along dimension 1.

        ``fields`` is presumably ``(batch, num_fields, field_dim)`` (TODO
        confirm with callers). Returns a tuple ``(pooled, weights)``:
        ``pooled`` is the weighted sum over dimension 1, and ``weights`` are
        the softmax weights with the trailing singleton score dimension
        squeezed away.
        """
        logits = self.attn(fields)
        attn_weights = torch.softmax(logits, dim=1)
        pooled = torch.sum(attn_weights * fields, dim=1)
        return pooled, attn_weights.squeeze(-1)
class GenreSelfAttention(nn.Module):
    """Masked attention pooling of genre embeddings.

    Each genre embedding is scored by a bias-free linear layer, i.e. from the
    genre itself only (no external query). Positions whose mask entry equals 0
    are excluded from the softmax, and the remaining weights are used for a
    weighted sum of the embeddings.

    Args:
        genre_dim: Size of the last dimension of each genre embedding.
    """

    def __init__(self, genre_dim: int):
        super().__init__()
        # Parameter name ``attn_scorer`` is part of the state_dict key; keep it.
        self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)

    def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
        """Return the masked, attention-weighted sum of ``genre_embeds``.

        Args:
            genre_embeds: Genre embeddings; presumably ``(batch, num_genres,
                genre_dim)`` (TODO confirm). Pooling is over dimension 1.
            mask: Must broadcast against the scores ``(..., num_genres, 1)``;
                entries equal to 0 mark positions to ignore.

        Returns:
            The weighted sum over dimension 1. If every position of a row is
            masked, the softmax becomes uniform over all positions of that row.
        """
        scores = self.attn_scorer(genre_embeds)
        # Use the most negative finite value of the score dtype rather than a
        # fixed -1e9: -1e9 is not representable in float16, so masked_fill
        # raises under autocast / half precision. In float32 the resulting
        # softmax weights are identical (exp underflows to exactly 0 either way).
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
        weights = F.softmax(scores, dim=1)
        return (genre_embeds * weights).sum(dim=1)
class ModalityAttention(nn.Module):
    """Attention-based fusion of per-modality vectors.

    Every modality vector (text, genres, other metadata, ...) is scored by a
    bias-free linear layer; the softmax of those scores across dimension 1
    decides how much each modality contributes to the fused vector, so the
    relative importance is learned rather than fixed.

    Args:
        num_modalities: Number of modalities expected along dimension 1.
            Stored for reference only; ``forward`` does not read it.
        modality_dim: Size of the last dimension of each modality vector.
    """

    def __init__(self, num_modalities: int, modality_dim: int):
        super().__init__()
        self.num_modalities = num_modalities
        # Parameter name ``attn_scorer`` is part of the state_dict key; keep it.
        self.attn_scorer = nn.Linear(modality_dim, 1, bias=False)

    def forward(self, modalities: torch.Tensor):
        """Fuse ``modalities`` along dimension 1.

        ``modalities`` is presumably ``(batch, num_modalities, modality_dim)``
        (TODO confirm). Returns ``(fused, weights)`` where ``fused`` is the
        weighted sum over dimension 1 and ``weights`` are the softmax weights
        with the trailing singleton dimension removed.
        """
        modality_weights = self.attn_scorer(modalities).softmax(dim=1)
        fused = (modality_weights * modalities).sum(dim=1)
        return fused, modality_weights.squeeze(-1)
class AnimeEmbeddingModel(nn.Module):
    """
    Main model v13: fuses text, genre and other metadata into one embedding.

    Pipeline (see ``forward``):

    1. text   -- six precomputed text-field vectors pooled by
       ``TextFieldAttention`` and projected to ``final_embedding_dim``;
    2. genres -- genre ids embedded, pooled by masked ``GenreSelfAttention``
       and projected to ``final_embedding_dim``;
    3. other  -- studio and type embeddings concatenated with a ReLU-activated
       linear map of six numerical features, projected to
       ``final_embedding_dim``.

    Each modality vector is L2-normalized, multiplied by its own learnable
    scale, fused by ``ModalityAttention`` and passed through an MLP encoder;
    the encoder output goes through ``tanh`` and a final L2 normalization.

    Args:
        vocab_sizes: Vocabulary size per categorical feature; must contain
            ``'genre'``, ``'studio'`` and ``'type'``.
        embedding_dims: Per-feature embedding sizes. Any of ``'genre'``,
            ``'studio'``, ``'type'``, ``'numerical'`` that is not given falls
            back to its default (128, 64, 16 and 32 respectively).
        dropout_rate: Dropout probability inside the encoder MLP.
        text_embedding_size: Size of each precomputed text vector.
        final_embedding_dim: Size of the projected modalities and of the
            returned embedding.
    """

    # Batch keys of the precomputed text vectors, in the order they are stacked.
    _TEXT_FIELD_KEYS = (
        'precomputed_ua_desc', 'precomputed_en_desc',
        'precomputed_ua_title', 'precomputed_en_title',
        'precomputed_original_title', 'precomputed_alternate_names',
    )
    # Width of the last dimension of batch['numerical'].
    _NUM_NUMERICAL_FEATURES = 6

    def __init__(self,
                 vocab_sizes: Dict[str, int],
                 embedding_dims: Optional[Dict[str, int]] = None,
                 dropout_rate: float = 0.3,
                 text_embedding_size: int = 384,
                 final_embedding_dim: int = 512):
        super().__init__()
        # Merge caller overrides onto the defaults so a partial dict works, and
        # build a fresh dict so later mutation of the caller's dict cannot
        # change this attribute behind the model's back.
        embedding_dims = {
            'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32,
            **(embedding_dims or {}),
        }
        self.embedding_dims = embedding_dims
        self.final_embedding_dim = final_embedding_dim

        # NOTE: submodule/parameter names and construction order below are
        # part of the state_dict layout (and of seeded initialization); keep
        # them unchanged so saved checkpoints still load.

        # Per-feature encoders. Genre id 0 is padding: its embedding is zero
        # and receives no gradient (nn.Embedding padding_idx semantics).
        self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
        self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
        self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
        self.numerical_layer = nn.Linear(self._NUM_NUMERICAL_FEATURES, embedding_dims['numerical'])

        # Per-modality projections into the shared final_embedding_dim space.
        self.text_projector = nn.Linear(text_embedding_size, final_embedding_dim)
        self.genre_projector = nn.Linear(embedding_dims['genre'], final_embedding_dim)
        other_dim = embedding_dims['studio'] + embedding_dims['type'] + embedding_dims['numerical']
        self.other_projector = nn.Linear(other_dim, final_embedding_dim)

        self.text_field_attention = TextFieldAttention(
            num_fields=len(self._TEXT_FIELD_KEYS), field_dim=text_embedding_size)
        self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
        self.modality_attention = ModalityAttention(num_modalities=3, modality_dim=final_embedding_dim)

        # Sequential indices (Linear at 0/4/8, LayerNorm at 3/7) appear in the
        # state_dict keys, so the layer list must not be reordered.
        self.encoder = nn.Sequential(
            nn.Linear(final_embedding_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
            nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
            nn.Linear(768, final_embedding_dim),
        )

        # Learnable per-modality scales applied after L2 normalization.
        self.text_scale = nn.Parameter(torch.tensor(1.0))
        self.genre_scale = nn.Parameter(torch.tensor(1.0))
        self.other_scale = nn.Parameter(torch.tensor(1.0))

    def _encode_text(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Stack the six text-field vectors along dim 1 and attention-pool them."""
        text_fields = torch.stack([batch[key] for key in self._TEXT_FIELD_KEYS], dim=1)
        text_vector, _ = self.text_field_attention(text_fields)
        return text_vector

    def _encode_genres(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Embed the genre ids and pool them with masked self-attention."""
        genre_embeds = self.genre_embedding(batch['genres'])
        # unsqueeze(-1) lets the mask broadcast against the (..., 1) scores.
        return self.genre_attention(genre_embeds, batch['genres_mask'].unsqueeze(-1))

    def _encode_other(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate studio, type and numerical-feature encodings along dim 1."""
        studio_emb = self.studio_embedding(batch['studio'])
        type_emb = self.type_embedding(batch['type'])
        numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
        return torch.cat([studio_emb, type_emb, numerical_emb], dim=1)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Embed a batch of anime records.

        Expected keys (shapes are assumptions to confirm against the data
        pipeline): the six ``precomputed_*`` text vectors, presumably
        ``(batch, text_embedding_size)`` each; ``genres`` (genre ids, 0 =
        padding) with a matching ``genres_mask`` (0 = ignore); ``studio`` and
        ``type`` ids, presumably ``(batch,)``; and ``numerical`` whose last
        dimension is 6.

        Returns:
            Embeddings of width ``final_embedding_dim``, L2-normalized along
            dimension 1.
        """
        text_vector_proj = self.text_projector(self._encode_text(batch))
        genre_vector_proj = self.genre_projector(self._encode_genres(batch))
        other_vector_proj = self.other_projector(self._encode_other(batch))
        # Normalize each modality to unit length before scaling so that only
        # the learnable scales (and the attention) set their relative weight.
        modalities = torch.stack([
            F.normalize(text_vector_proj, p=2, dim=1) * self.text_scale,
            F.normalize(genre_vector_proj, p=2, dim=1) * self.genre_scale,
            F.normalize(other_vector_proj, p=2, dim=1) * self.other_scale,
        ], dim=1)
        combined, _ = self.modality_attention(modalities)
        embedding = torch.tanh(self.encoder(combined))
        return F.normalize(embedding, p=2, dim=1)