Upload model.py with huggingface_hub
Browse files
model.py
CHANGED
|
@@ -3,14 +3,13 @@ import torch.nn as nn
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from typing import Dict
|
| 5 |
|
| 6 |
-
# ... Paste the full code for TextFieldAttention, GenreSelfAttention, and AnimeEmbeddingModel here ...
|
| 7 |
-
# (The same code as in Cell 2 of the notebook)
|
| 8 |
-
|
| 9 |
class TextFieldAttention(nn.Module):
|
|
|
|
| 10 |
def __init__(self, num_fields: int, field_dim: int):
|
| 11 |
super().__init__()
|
| 12 |
self.attn = nn.Linear(field_dim, 1, bias=False)
|
| 13 |
self.num_fields = num_fields
|
|
|
|
| 14 |
def forward(self, fields: torch.Tensor):
|
| 15 |
scores = self.attn(fields)
|
| 16 |
weights = F.softmax(scores, dim=1)
|
|
@@ -18,9 +17,11 @@ class TextFieldAttention(nn.Module):
|
|
| 18 |
return weighted_sum, weights.squeeze(-1)
|
| 19 |
|
| 20 |
class GenreSelfAttention(nn.Module):
|
|
|
|
| 21 |
def __init__(self, genre_dim: int):
|
| 22 |
super().__init__()
|
| 23 |
self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)
|
|
|
|
| 24 |
def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
|
| 25 |
scores = self.attn_scorer(genre_embeds)
|
| 26 |
scores.masked_fill_(mask == 0, -1e9)
|
|
@@ -28,46 +29,94 @@ class GenreSelfAttention(nn.Module):
|
|
| 28 |
weighted_sum = (genre_embeds * weights).sum(dim=1)
|
| 29 |
return weighted_sum
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
class AnimeEmbeddingModel(nn.Module):
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
self.embedding_dims = embedding_dims
|
|
|
|
|
|
|
| 35 |
self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
|
| 36 |
self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
|
| 37 |
self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
|
| 38 |
self.numerical_layer = nn.Linear(6, embedding_dims['numerical'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
|
| 40 |
self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
|
| 41 |
-
|
|
|
|
| 42 |
self.encoder = nn.Sequential(
|
| 43 |
-
nn.Linear(
|
| 44 |
nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
|
| 45 |
-
nn.Linear(768,
|
| 46 |
)
|
|
|
|
| 47 |
self.text_scale = nn.Parameter(torch.tensor(1.0))
|
| 48 |
self.genre_scale = nn.Parameter(torch.tensor(1.0))
|
| 49 |
self.other_scale = nn.Parameter(torch.tensor(1.0))
|
|
|
|
| 50 |
def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 51 |
text_fields = torch.stack([
|
| 52 |
batch['precomputed_ua_desc'], batch['precomputed_en_desc'],
|
| 53 |
batch['precomputed_ua_title'], batch['precomputed_en_title'],
|
| 54 |
batch['precomputed_original_title'], batch['precomputed_alternate_names'],
|
| 55 |
], dim=1)
|
| 56 |
-
|
|
|
|
| 57 |
genre_embeds_raw = self.genre_embedding(batch['genres'])
|
| 58 |
-
|
|
|
|
| 59 |
studio_emb = self.studio_embedding(batch['studio'])
|
| 60 |
type_emb = self.type_embedding(batch['type'])
|
| 61 |
numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
|
| 62 |
other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
embedding_logits = self.encoder(combined)
|
| 71 |
embedding = torch.tanh(embedding_logits)
|
| 72 |
embedding = F.normalize(embedding, p=2, dim=1)
|
|
|
|
| 73 |
return embedding
|
|
|
|
| 3 |
import torch.nn.functional as F
from typing import Dict, Optional
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
class TextFieldAttention(nn.Module):
|
| 7 |
+
"""Calculates a weighted sum of text field embeddings."""
|
| 8 |
def __init__(self, num_fields: int, field_dim: int):
|
| 9 |
super().__init__()
|
| 10 |
self.attn = nn.Linear(field_dim, 1, bias=False)
|
| 11 |
self.num_fields = num_fields
|
| 12 |
+
|
| 13 |
def forward(self, fields: torch.Tensor):
|
| 14 |
scores = self.attn(fields)
|
| 15 |
weights = F.softmax(scores, dim=1)
|
|
|
|
| 17 |
return weighted_sum, weights.squeeze(-1)
|
| 18 |
|
| 19 |
class GenreSelfAttention(nn.Module):
|
| 20 |
+
"""Calculates a weighted sum of genres based only on the genres themselves."""
|
| 21 |
def __init__(self, genre_dim: int):
|
| 22 |
super().__init__()
|
| 23 |
self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)
|
| 24 |
+
|
| 25 |
def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
|
| 26 |
scores = self.attn_scorer(genre_embeds)
|
| 27 |
scores.masked_fill_(mask == 0, -1e9)
|
|
|
|
| 29 |
weighted_sum = (genre_embeds * weights).sum(dim=1)
|
| 30 |
return weighted_sum
|
| 31 |
|
| 32 |
+
class ModalityAttention(nn.Module):
    """Attention pooling over per-modality vectors (text, genres, etc.).

    Scores every modality vector with a shared bias-free linear layer,
    softmax-normalizes the scores across the modality axis, and returns
    the attention-weighted sum so the model can dynamically decide how
    important each modality is.
    """

    def __init__(self, num_modalities: int, modality_dim: int):
        super().__init__()
        # One learned scoring vector shared by all modalities.
        self.attn_scorer = nn.Linear(modality_dim, 1, bias=False)
        self.num_modalities = num_modalities

    def forward(self, modalities: torch.Tensor):
        # assumes modalities is (batch, num_modalities, modality_dim) — confirm at call site
        raw_scores = self.attn_scorer(modalities)        # (batch, M, 1)
        attn = torch.softmax(raw_scores, dim=1)          # normalize over the modality axis
        pooled = torch.sum(modalities * attn, dim=1)     # (batch, modality_dim)
        return pooled, attn.squeeze(-1)
|
| 47 |
+
|
| 48 |
class AnimeEmbeddingModel(nn.Module):
    """Main model v13.

    Fuses three modalities — precomputed text-field embeddings, genre
    embeddings, and "other" features (studio, type, numerical) — into a
    single L2-normalized embedding of size ``final_embedding_dim``.

    Args:
        vocab_sizes: Vocabulary sizes for the 'genre', 'studio' and 'type'
            embedding tables.
        embedding_dims: Per-feature embedding sizes; defaults to
            {'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32}.
        dropout_rate: Dropout probability inside the encoder MLP.
        text_embedding_size: Dimensionality of each precomputed text field.
        final_embedding_dim: Dimensionality of the returned embedding.
    """

    def __init__(self,
                 vocab_sizes: Dict[str, int],
                 # PEP 484: a None default must be typed Optional, not Dict[...] = None.
                 embedding_dims: Optional[Dict[str, int]] = None,
                 dropout_rate: float = 0.3,
                 text_embedding_size: int = 384,
                 final_embedding_dim: int = 512):
        super().__init__()

        if embedding_dims is None:
            embedding_dims = {'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32}

        self.embedding_dims = embedding_dims
        self.final_embedding_dim = final_embedding_dim

        # Feature embeddings. Genre index 0 is padding and stays a zero vector.
        self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
        self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
        self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
        # Projects the 6 numerical input features into a dense representation.
        self.numerical_layer = nn.Linear(6, embedding_dims['numerical'])

        # Project each modality into the shared final_embedding_dim space so
        # that ModalityAttention can score them jointly.
        self.text_projector = nn.Linear(text_embedding_size, final_embedding_dim)
        self.genre_projector = nn.Linear(embedding_dims['genre'], final_embedding_dim)
        other_dim = embedding_dims['studio'] + embedding_dims['type'] + embedding_dims['numerical']
        self.other_projector = nn.Linear(other_dim, final_embedding_dim)

        # Attention modules: 6 text fields, variable-length genre lists,
        # and 3 modality vectors (text / genre / other).
        self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
        self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
        self.modality_attention = ModalityAttention(num_modalities=3, modality_dim=final_embedding_dim)

        self.encoder = nn.Sequential(
            nn.Linear(final_embedding_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
            nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
            nn.Linear(768, final_embedding_dim),
        )

        # Learnable per-modality magnitudes applied after L2-normalizing each
        # projected modality vector in forward().
        self.text_scale = nn.Parameter(torch.tensor(1.0))
        self.genre_scale = nn.Parameter(torch.tensor(1.0))
        self.other_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return a (batch, final_embedding_dim) L2-normalized embedding.

        ``batch`` must contain the six ``precomputed_*`` text tensors plus
        'genres', 'genres_mask', 'studio', 'type', and 'numerical'.
        """
        # --- text modality: attention-pool the 6 precomputed field vectors ---
        text_fields = torch.stack([
            batch['precomputed_ua_desc'], batch['precomputed_en_desc'],
            batch['precomputed_ua_title'], batch['precomputed_en_title'],
            batch['precomputed_original_title'], batch['precomputed_alternate_names'],
        ], dim=1)
        text_vector_raw, _ = self.text_field_attention(text_fields)

        # --- genre modality: mask padding positions before pooling ---
        genre_embeds_raw = self.genre_embedding(batch['genres'])
        genre_vector_raw = self.genre_attention(genre_embeds_raw, batch['genres_mask'].unsqueeze(-1))

        # --- "other" modality: studio + type + numerical, concatenated ---
        studio_emb = self.studio_embedding(batch['studio'])
        type_emb = self.type_embedding(batch['type'])
        numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
        other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)

        # Bring all three modalities into the common final_embedding_dim space.
        text_vector_proj = self.text_projector(text_vector_raw)
        genre_vector_proj = self.genre_projector(genre_vector_raw)
        other_vector_proj = self.other_projector(other_vector_parts)

        # Normalize each modality to unit length, then let the learned scales
        # set relative magnitude before modality attention.
        modalities = torch.stack([
            F.normalize(text_vector_proj, p=2, dim=1) * self.text_scale,
            F.normalize(genre_vector_proj, p=2, dim=1) * self.genre_scale,
            F.normalize(other_vector_proj, p=2, dim=1) * self.other_scale,
        ], dim=1)

        combined, _ = self.modality_attention(modalities)

        embedding_logits = self.encoder(combined)
        embedding = torch.tanh(embedding_logits)
        embedding = F.normalize(embedding, p=2, dim=1)

        return embedding
|