File size: 2,607 Bytes
68aba54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Two-tower neural recommender.

User tower:  user_id ──► Embedding ──► MLP ──► L2-normalized user_vec
Track tower: (track_id, artist_id) ──► Embeddings ──► MLP ──► L2-normalized track_vec

Score = dot(user_vec, track_vec). Trained with InfoNCE / in-batch negatives.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


def mlp(in_dim: int, hidden: int, out_dim: int, dropout: float = 0.1) -> nn.Sequential:
    """Build a two-layer perceptron: Linear -> GELU -> Dropout -> Linear.

    Args:
        in_dim: Width of the input features.
        hidden: Width of the single hidden layer.
        out_dim: Width of the output features.
        dropout: Drop probability applied after the activation.

    Returns:
        An ``nn.Sequential`` holding the four layers in the order above.
    """
    # Layers are instantiated in forward order so the Linear modules stay at
    # positions 0 and 3 of the container.
    layers = [
        nn.Linear(in_features=in_dim, out_features=hidden),
        nn.GELU(),
        nn.Dropout(p=dropout),
        nn.Linear(in_features=hidden, out_features=out_dim),
    ]
    return nn.Sequential(*layers)


class UserTower(nn.Module):
    """Map user indices to unit-length vectors.

    An embedding lookup feeds the ``mlp`` projection, and the result is
    L2-normalized along the last dimension.

    Args:
        n_users: Number of rows in the user embedding table.
        embed_dim: Width of each user embedding.
        out_dim: Width of the vector returned by ``forward``.
    """

    def __init__(self, n_users: int, embed_dim: int = 64, out_dim: int = 128):
        super().__init__()
        # Attribute names double as state_dict prefixes, so they are kept as-is.
        self.user_emb = nn.Embedding(num_embeddings=n_users, embedding_dim=embed_dim)
        self.proj = mlp(in_dim=embed_dim, hidden=256, out_dim=out_dim)

    def forward(self, user_idx: torch.Tensor) -> torch.Tensor:
        """Return one L2-normalized vector per entry of ``user_idx``."""
        projected = self.proj(self.user_emb(user_idx))
        return F.normalize(projected, p=2.0, dim=-1)


class TrackTower(nn.Module):
    """Map (track index, artist index) pairs to unit-length vectors.

    The track and artist embeddings are concatenated, passed through the
    ``mlp`` projection, and L2-normalized along the last dimension.

    Args:
        n_tracks: Number of rows in the track embedding table.
        n_artists: Number of rows in the artist embedding table.
        embed_dim: Width of each of the two embeddings.
        out_dim: Width of the vector returned by ``forward``.
    """

    def __init__(self, n_tracks: int, n_artists: int, embed_dim: int = 64, out_dim: int = 128):
        super().__init__()
        self.track_emb = nn.Embedding(num_embeddings=n_tracks, embedding_dim=embed_dim)
        self.artist_emb = nn.Embedding(num_embeddings=n_artists, embedding_dim=embed_dim)
        # Two embeddings are concatenated, so the projection input is twice as wide.
        self.proj = mlp(in_dim=2 * embed_dim, hidden=256, out_dim=out_dim)

    def forward(self, track_idx: torch.Tensor, artist_idx: torch.Tensor) -> torch.Tensor:
        """Return one L2-normalized vector per (track, artist) index pair."""
        track_vec = self.track_emb(track_idx)
        artist_vec = self.artist_emb(artist_idx)
        fused = torch.cat((track_vec, artist_vec), dim=-1)
        return F.normalize(self.proj(fused), p=2.0, dim=-1)


class TwoTower(nn.Module):
    """Two-tower model producing user-vs-track logits for a batch.

    Args:
        n_users: Number of distinct user indices (forwarded to ``UserTower``).
        n_tracks: Number of distinct track indices (forwarded to ``TrackTower``).
        n_artists: Number of distinct artist indices (forwarded to ``TrackTower``).
        out_dim: Width of the vectors produced by both towers.
        embed_dim: Keyword-only. Width of the id embeddings inside both
            towers; the default of 64 matches the towers' own default.
        init_temperature: Keyword-only. Initial softmax temperature. The
            default of 1.0 gives ``log_temp == 0``, the previous fixed
            initialization; CLIP, for reference, starts at 0.07.

    Raises:
        ValueError: If ``init_temperature`` is not strictly positive.
    """

    def __init__(
        self,
        n_users: int,
        n_tracks: int,
        n_artists: int,
        out_dim: int = 128,
        *,
        embed_dim: int = 64,
        init_temperature: float = 1.0,
    ):
        super().__init__()
        # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
        if not init_temperature > 0:
            raise ValueError(f"init_temperature must be > 0, got {init_temperature!r}")
        self.user_tower = UserTower(n_users, embed_dim=embed_dim, out_dim=out_dim)
        self.track_tower = TrackTower(n_tracks, n_artists, embed_dim=embed_dim, out_dim=out_dim)
        # Learnable temperature, like CLIP. Despite the name, this stores
        # log(1 / temperature): forward() multiplies the similarities by its exp.
        self.log_temp = nn.Parameter(torch.tensor(1.0 / init_temperature).log())

    def forward(
        self, user_idx: torch.Tensor, track_idx: torch.Tensor, artist_idx: torch.Tensor
    ) -> torch.Tensor:
        """Return the scaled similarity matrix between users and tracks.

        Entry ``[i, j]`` is ``dot(user_vec_i, track_vec_j) * exp(log_temp)``.
        """
        u = self.user_tower(user_idx)  # (B, D)
        t = self.track_tower(track_idx, artist_idx)  # (B, D)
        logits = (u @ t.T) * self.log_temp.exp()  # (B, B)
        return logits


def info_nce_loss(logits: torch.Tensor) -> torch.Tensor:
    """Symmetric InfoNCE — diagonal is the positive pair, off-diagonal are in-batch negatives.

    Args:
        logits: Square ``(B, B)`` score matrix in which entry ``[i, i]`` is
            the score of the i-th positive pair.

    Returns:
        Scalar tensor: the average of the row-wise and column-wise
        cross-entropy losses against the diagonal targets.

    Raises:
        ValueError: If ``logits`` is not a square 2-D matrix.
    """
    # Validate up front: a non-square input otherwise fails inside
    # cross_entropy with an out-of-range-target or batch-size error that does
    # not point back to the real cause.
    if logits.dim() != 2 or logits.size(0) != logits.size(1):
        raise ValueError(
            f"info_nce_loss expects square (B, B) logits, got shape {tuple(logits.shape)}"
        )
    targets = torch.arange(logits.size(0), device=logits.device)
    row_loss = F.cross_entropy(logits, targets)  # rows pick their matching column
    col_loss = F.cross_entropy(logits.T, targets)  # columns pick their matching row
    return 0.5 * (row_loss + col_loss)