Lorg0n committed on
Commit
bca76eb
·
verified ·
1 Parent(s): 55bb496

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +65 -16
model.py CHANGED
@@ -3,14 +3,13 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
  from typing import Dict
5
 
6
- # ... Paste the full code for TextFieldAttention, GenreSelfAttention, and AnimeEmbeddingModel here ...
7
- # (The same code as in Cell 2 of the notebook)
8
-
9
  class TextFieldAttention(nn.Module):
 
10
  def __init__(self, num_fields: int, field_dim: int):
11
  super().__init__()
12
  self.attn = nn.Linear(field_dim, 1, bias=False)
13
  self.num_fields = num_fields
 
14
  def forward(self, fields: torch.Tensor):
15
  scores = self.attn(fields)
16
  weights = F.softmax(scores, dim=1)
@@ -18,9 +17,11 @@ class TextFieldAttention(nn.Module):
18
  return weighted_sum, weights.squeeze(-1)
19
 
20
  class GenreSelfAttention(nn.Module):
 
21
  def __init__(self, genre_dim: int):
22
  super().__init__()
23
  self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)
 
24
  def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
25
  scores = self.attn_scorer(genre_embeds)
26
  scores.masked_fill_(mask == 0, -1e9)
@@ -28,46 +29,94 @@ class GenreSelfAttention(nn.Module):
28
  weighted_sum = (genre_embeds * weights).sum(dim=1)
29
  return weighted_sum
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class AnimeEmbeddingModel(nn.Module):
32
- def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int], dropout_rate: float = 0.3, text_embedding_size: int = 384):
 
 
 
 
 
 
 
 
33
  super().__init__()
 
 
 
 
34
  self.embedding_dims = embedding_dims
 
 
35
  self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
36
  self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
37
  self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
38
  self.numerical_layer = nn.Linear(6, embedding_dims['numerical'])
 
 
 
 
 
 
39
  self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
40
  self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
41
- total_dim = sum(embedding_dims.values())
 
42
  self.encoder = nn.Sequential(
43
- nn.Linear(total_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
44
  nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
45
- nn.Linear(768, 512),
46
  )
 
47
  self.text_scale = nn.Parameter(torch.tensor(1.0))
48
  self.genre_scale = nn.Parameter(torch.tensor(1.0))
49
  self.other_scale = nn.Parameter(torch.tensor(1.0))
 
50
  def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
51
  text_fields = torch.stack([
52
  batch['precomputed_ua_desc'], batch['precomputed_en_desc'],
53
  batch['precomputed_ua_title'], batch['precomputed_en_title'],
54
  batch['precomputed_original_title'], batch['precomputed_alternate_names'],
55
  ], dim=1)
56
- text_vector, _ = self.text_field_attention(text_fields)
 
57
  genre_embeds_raw = self.genre_embedding(batch['genres'])
58
- genre_vector = self.genre_attention(genre_embeds_raw, batch['genres_mask'].unsqueeze(-1))
 
59
  studio_emb = self.studio_embedding(batch['studio'])
60
  type_emb = self.type_embedding(batch['type'])
61
  numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
62
  other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)
63
- text_vector_norm = F.normalize(text_vector, p=2, dim=1)
64
- genre_vector_norm = F.normalize(genre_vector, p=2, dim=1)
65
- other_vector_norm = F.normalize(other_vector_parts, p=2, dim=1)
66
- scaled_text = text_vector_norm * self.text_scale
67
- scaled_genre = genre_vector_norm * self.genre_scale
68
- scaled_other = other_vector_norm * self.other_scale
69
- combined = torch.cat([scaled_text, scaled_genre, scaled_other], dim=1)
 
 
 
 
 
 
70
  embedding_logits = self.encoder(combined)
71
  embedding = torch.tanh(embedding_logits)
72
  embedding = F.normalize(embedding, p=2, dim=1)
 
73
  return embedding
 
3
  import torch.nn.functional as F
4
  from typing import Dict
5
 
 
 
 
6
  class TextFieldAttention(nn.Module):
7
+ """Calculates a weighted sum of text field embeddings."""
8
  def __init__(self, num_fields: int, field_dim: int):
9
  super().__init__()
10
  self.attn = nn.Linear(field_dim, 1, bias=False)
11
  self.num_fields = num_fields
12
+
13
  def forward(self, fields: torch.Tensor):
14
  scores = self.attn(fields)
15
  weights = F.softmax(scores, dim=1)
 
17
  return weighted_sum, weights.squeeze(-1)
18
 
19
  class GenreSelfAttention(nn.Module):
20
+ """Calculates a weighted sum of genres based only on the genres themselves."""
21
  def __init__(self, genre_dim: int):
22
  super().__init__()
23
  self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)
24
+
25
  def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
26
  scores = self.attn_scorer(genre_embeds)
27
  scores.masked_fill_(mask == 0, -1e9)
 
29
  weighted_sum = (genre_embeds * weights).sum(dim=1)
30
  return weighted_sum
31
 
32
class ModalityAttention(nn.Module):
    """
    Attention pooling over a fixed set of modality vectors.

    Scores each modality vector with a single bias-free linear layer,
    softmax-normalizes the scores across the modality axis, and returns
    both the weighted sum and the attention weights, so the caller can
    inspect how much each modality contributed.
    """

    def __init__(self, num_modalities: int, modality_dim: int):
        super().__init__()
        # One scalar score per modality; bias-free, so scoring is purely a
        # learned direction in the shared modality space.
        self.attn_scorer = nn.Linear(modality_dim, 1, bias=False)
        self.num_modalities = num_modalities

    def forward(self, modalities: torch.Tensor):
        # (B, M, D) -> (B, M, 1): one raw score per modality vector.
        raw_scores = self.attn_scorer(modalities)
        # Normalize scores across the modality axis so weights sum to 1.
        attn = F.softmax(raw_scores, dim=1)
        # Weighted sum collapses the modality axis: (B, M, D) -> (B, D).
        pooled = (attn * modalities).sum(dim=1)
        return pooled, attn.squeeze(-1)
47
+
48
class AnimeEmbeddingModel(nn.Module):
    """
    Main model v13.

    Fuses three modalities — precomputed text-field embeddings, genre
    embeddings, and "other" categorical/numerical features — into a single
    L2-normalized anime embedding. Each modality is attention-pooled,
    projected to ``final_embedding_dim``, unit-normalized and scaled by a
    learned per-modality factor, fused via ModalityAttention, and refined
    by an MLP encoder.
    """

    def __init__(self,
                 vocab_sizes: Dict[str, int],
                 embedding_dims: Dict[str, int] = None,
                 dropout_rate: float = 0.3,
                 text_embedding_size: int = 384,
                 final_embedding_dim: int = 512):
        super().__init__()

        # NOTE(review): these defaults presumably mirror the training
        # configuration — confirm before relying on them.
        if embedding_dims is None:
            embedding_dims = {'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32}

        self.embedding_dims = embedding_dims
        self.final_embedding_dim = final_embedding_dim

        # Categorical / numerical feature encoders. Genre index 0 is padding.
        self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
        self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
        self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
        # Numerical branch: projects a 6-feature vector (per the in_features
        # below) to its modality width.
        self.numerical_layer = nn.Linear(6, embedding_dims['numerical'])

        # Per-modality projections into the shared fusion space.
        self.text_projector = nn.Linear(text_embedding_size, final_embedding_dim)
        self.genre_projector = nn.Linear(embedding_dims['genre'], final_embedding_dim)
        other_width = sum(embedding_dims[key] for key in ('studio', 'type', 'numerical'))
        self.other_projector = nn.Linear(other_width, final_embedding_dim)

        # Attention modules: pool the six text fields, pool genres with a
        # padding mask, then fuse the three projected modalities.
        self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
        self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
        self.modality_attention = ModalityAttention(num_modalities=3, modality_dim=final_embedding_dim)

        # MLP that refines the fused vector back down to final_embedding_dim.
        self.encoder = nn.Sequential(
            nn.Linear(final_embedding_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
            nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
            nn.Linear(768, final_embedding_dim),
        )

        # Learned per-modality importance scales applied before fusion.
        self.text_scale = nn.Parameter(torch.tensor(1.0))
        self.genre_scale = nn.Parameter(torch.tensor(1.0))
        self.other_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode a feature batch into (B, final_embedding_dim) unit vectors."""
        # Attention-pool the six precomputed text-field embeddings.
        field_keys = (
            'precomputed_ua_desc', 'precomputed_en_desc',
            'precomputed_ua_title', 'precomputed_en_title',
            'precomputed_original_title', 'precomputed_alternate_names',
        )
        stacked_fields = torch.stack([batch[key] for key in field_keys], dim=1)
        attended_text, _ = self.text_field_attention(stacked_fields)

        # Attention-pool genre embeddings, masking out padding positions.
        raw_genres = self.genre_embedding(batch['genres'])
        pooled_genres = self.genre_attention(raw_genres, batch['genres_mask'].unsqueeze(-1))

        # Remaining categorical + numerical features form the "other" modality.
        other_feats = torch.cat(
            [
                self.studio_embedding(batch['studio']),
                self.type_embedding(batch['type']),
                F.relu(self.numerical_layer(batch['numerical'])),
            ],
            dim=1,
        )

        # Project each modality into the shared space, unit-normalize it, and
        # apply its learned importance scale before fusing.
        modality_stack = torch.stack(
            [
                self.text_scale * F.normalize(self.text_projector(attended_text), p=2, dim=1),
                self.genre_scale * F.normalize(self.genre_projector(pooled_genres), p=2, dim=1),
                self.other_scale * F.normalize(self.other_projector(other_feats), p=2, dim=1),
            ],
            dim=1,
        )
        fused, _ = self.modality_attention(modality_stack)

        # Encode, squash to (-1, 1), and return unit-length embeddings.
        squashed = torch.tanh(self.encoder(fused))
        return F.normalize(squashed, p=2, dim=1)