File size: 5,431 Bytes
fb84cfa
 
 
 
 
 
bca76eb
fb84cfa
 
 
 
bca76eb
fb84cfa
 
 
 
 
 
 
bca76eb
fb84cfa
 
 
bca76eb
fb84cfa
 
 
 
 
 
 
bca76eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb84cfa
bca76eb
 
 
 
 
 
 
 
 
fb84cfa
bca76eb
 
 
 
fb84cfa
bca76eb
 
fb84cfa
 
 
 
bca76eb
 
 
 
 
 
fb84cfa
 
bca76eb
 
fb84cfa
bca76eb
fb84cfa
bca76eb
fb84cfa
bca76eb
fb84cfa
 
 
bca76eb
fb84cfa
 
 
 
 
 
bca76eb
 
fb84cfa
bca76eb
 
fb84cfa
 
 
 
bca76eb
 
 
 
 
 
 
 
 
 
 
 
 
fb84cfa
 
 
bca76eb
fb84cfa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict

class TextFieldAttention(nn.Module):
    """Attention pooling over a fixed set of text-field embeddings.

    Each field vector is scored by a shared, bias-free linear projection;
    the softmax of those scores across the field axis weights the sum.
    Returns both the pooled vector and the per-field weights.
    """

    def __init__(self, num_fields: int, field_dim: int):
        super().__init__()
        self.num_fields = num_fields
        self.attn = nn.Linear(field_dim, 1, bias=False)

    def forward(self, fields: torch.Tensor):
        # fields: assumed (batch, num_fields, field_dim) — TODO confirm with caller
        raw_scores = self.attn(fields)                    # (batch, num_fields, 1)
        attn_weights = F.softmax(raw_scores, dim=1)
        pooled = torch.sum(fields * attn_weights, dim=1)  # (batch, field_dim)
        return pooled, attn_weights.squeeze(-1)

class GenreSelfAttention(nn.Module):
    """Masked attention pooling over a variable-length set of genre embeddings.

    Each genre embedding is scored by a bias-free linear layer; padding
    positions (mask == 0) are excluded from the softmax, and the result is
    the attention-weighted sum of the genre embeddings.
    """

    def __init__(self, genre_dim: int):
        super().__init__()
        self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)

    def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
        """Pool genre embeddings with masked softmax attention.

        Args:
            genre_embeds: assumed (batch, num_genres, genre_dim) — TODO confirm.
            mask: broadcastable to the score shape (batch, num_genres, 1);
                zero marks padding positions to be ignored.

        Returns:
            The attention-weighted sum, shape (batch, genre_dim).
        """
        scores = self.attn_scorer(genre_embeds)
        # Fix: was `scores.masked_fill_(mask == 0, -1e9)`. Out-of-place masking
        # avoids mutating an autograd intermediate in place, and the dtype's
        # own minimum replaces -1e9, which overflows to -inf under fp16/AMP.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=1)
        weighted_sum = (genre_embeds * weights).sum(dim=1)
        return weighted_sum

class ModalityAttention(nn.Module):
    """Learned soft fusion of per-modality vectors.

    A shared bias-free linear layer assigns each modality vector a scalar
    score; the softmax of those scores across the modality axis weights the
    sum, letting the model decide at run time how much each modality
    (text, genres, etc.) contributes. Returns the fused vector and the
    per-modality weights.
    """

    def __init__(self, num_modalities: int, modality_dim: int):
        super().__init__()
        self.num_modalities = num_modalities
        self.attn_scorer = nn.Linear(modality_dim, 1, bias=False)

    def forward(self, modalities: torch.Tensor):
        # modalities: assumed (batch, num_modalities, modality_dim) — TODO confirm
        logits = self.attn_scorer(modalities)        # (batch, num_modalities, 1)
        mix = F.softmax(logits, dim=1)
        fused = torch.sum(modalities * mix, dim=1)   # (batch, modality_dim)
        return fused, mix.squeeze(-1)

class AnimeEmbeddingModel(nn.Module):
    """Main model v13: fuses text, genre, and other-metadata modalities into a
    single L2-normalized anime embedding.

    Pipeline:
      1. Attention-pool six precomputed text-field embeddings.
      2. Attention-pool genre embeddings (with padding mask).
      3. Concatenate studio/type/numerical features into an "other" vector.
      4. Project each modality to ``final_embedding_dim``, normalize, apply a
         learnable per-modality scale, and fuse via modality attention.
      5. Encode the fused vector and L2-normalize the result.
    """

    def __init__(self,
                 vocab_sizes: Dict[str, int],
                 embedding_dims: Dict[str, int] = None,
                 dropout_rate: float = 0.3,
                 text_embedding_size: int = 384,
                 final_embedding_dim: int = 512,
                 num_numerical_features: int = 6):
        """Build the model.

        Args:
            vocab_sizes: vocabulary sizes; requires keys 'genre', 'studio', 'type'.
            embedding_dims: per-feature embedding widths; requires keys
                'genre', 'studio', 'type', 'numerical'. Defaults applied when None.
            dropout_rate: dropout probability inside the encoder MLP.
            text_embedding_size: dimension of each precomputed text embedding.
            final_embedding_dim: dimension of the output embedding.
            num_numerical_features: width of the 'numerical' input vector
                (generalized from the previously hard-coded 6; default keeps
                the old behavior).
        """
        super().__init__()

        # Default is created per-call — a dict default argument would be
        # shared mutable state across instances.
        if embedding_dims is None:
            embedding_dims = {'genre': 128, 'studio': 64, 'type': 16, 'numerical': 32}

        self.embedding_dims = embedding_dims
        self.final_embedding_dim = final_embedding_dim

        # Categorical embeddings. Genre index 0 is padding (zero vector).
        self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
        self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
        self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
        self.numerical_layer = nn.Linear(num_numerical_features, embedding_dims['numerical'])

        # Project each modality into the shared fusion space.
        self.text_projector = nn.Linear(text_embedding_size, final_embedding_dim)
        self.genre_projector = nn.Linear(embedding_dims['genre'], final_embedding_dim)
        other_dim = embedding_dims['studio'] + embedding_dims['type'] + embedding_dims['numerical']
        self.other_projector = nn.Linear(other_dim, final_embedding_dim)

        # num_fields=6 matches the six text keys stacked in forward().
        self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
        self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
        self.modality_attention = ModalityAttention(num_modalities=3, modality_dim=final_embedding_dim)

        self.encoder = nn.Sequential(
            nn.Linear(final_embedding_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
            nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
            nn.Linear(768, final_embedding_dim),
        )

        # Learnable per-modality gains applied after L2 normalization.
        self.text_scale = nn.Parameter(torch.tensor(1.0))
        self.genre_scale = nn.Parameter(torch.tensor(1.0))
        self.other_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute L2-normalized embeddings for a batch.

        Args:
            batch: expects keys 'precomputed_ua_desc', 'precomputed_en_desc',
                'precomputed_ua_title', 'precomputed_en_title',
                'precomputed_original_title', 'precomputed_alternate_names'
                (each (batch, text_embedding_size)), plus 'genres',
                'genres_mask', 'studio', 'type', and 'numerical'.

        Returns:
            Unit-norm embeddings of shape (batch, final_embedding_dim).
        """
        # Text modality: attention-pool the six precomputed field embeddings.
        text_fields = torch.stack([
            batch['precomputed_ua_desc'], batch['precomputed_en_desc'],
            batch['precomputed_ua_title'], batch['precomputed_en_title'],
            batch['precomputed_original_title'], batch['precomputed_alternate_names'],
        ], dim=1)
        text_vector_raw, _ = self.text_field_attention(text_fields)

        # Genre modality: masked attention pooling over genre embeddings.
        genre_embeds_raw = self.genre_embedding(batch['genres'])
        genre_vector_raw = self.genre_attention(genre_embeds_raw, batch['genres_mask'].unsqueeze(-1))

        # "Other" modality: studio + type + numerical features, concatenated.
        studio_emb = self.studio_embedding(batch['studio'])
        type_emb = self.type_embedding(batch['type'])
        numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
        other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)

        text_vector_proj = self.text_projector(text_vector_raw)
        genre_vector_proj = self.genre_projector(genre_vector_raw)
        other_vector_proj = self.other_projector(other_vector_parts)

        # Normalize each modality before fusing so the learnable scales —
        # not raw magnitudes — control relative influence.
        modalities = torch.stack([
            F.normalize(text_vector_proj, p=2, dim=1) * self.text_scale,
            F.normalize(genre_vector_proj, p=2, dim=1) * self.genre_scale,
            F.normalize(other_vector_proj, p=2, dim=1) * self.other_scale,
        ], dim=1)

        combined, _ = self.modality_attention(modalities)

        # tanh bounds the encoder output before the final L2 normalization.
        embedding_logits = self.encoder(combined)
        embedding = torch.tanh(embedding_logits)
        embedding = F.normalize(embedding, p=2, dim=1)

        return embedding