Spaces:
Runtime error
Runtime error
| """Two-tower neural recommender. | |
| User tower: user_id βββΊ Embedding βββΊ MLP βββΊ L2-normalized user_vec | |
| Track tower: (track_id, artist_id) βββΊ Embeddings βββΊ MLP βββΊ L2-normalized track_vec | |
| Score = dot(user_vec, track_vec). Trained with InfoNCE / in-batch negatives. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| def mlp(in_dim: int, hidden: int, out_dim: int, dropout: float = 0.1) -> nn.Sequential: | |
| return nn.Sequential( | |
| nn.Linear(in_dim, hidden), | |
| nn.GELU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(hidden, out_dim), | |
| ) | |
| class UserTower(nn.Module): | |
| def __init__(self, n_users: int, embed_dim: int = 64, out_dim: int = 128): | |
| super().__init__() | |
| self.user_emb = nn.Embedding(n_users, embed_dim) | |
| self.proj = mlp(embed_dim, 256, out_dim) | |
| def forward(self, user_idx: torch.Tensor) -> torch.Tensor: | |
| x = self.user_emb(user_idx) | |
| return F.normalize(self.proj(x), dim=-1) | |
| class TrackTower(nn.Module): | |
| def __init__(self, n_tracks: int, n_artists: int, embed_dim: int = 64, out_dim: int = 128): | |
| super().__init__() | |
| self.track_emb = nn.Embedding(n_tracks, embed_dim) | |
| self.artist_emb = nn.Embedding(n_artists, embed_dim) | |
| self.proj = mlp(embed_dim * 2, 256, out_dim) | |
| def forward(self, track_idx: torch.Tensor, artist_idx: torch.Tensor) -> torch.Tensor: | |
| x = torch.cat([self.track_emb(track_idx), self.artist_emb(artist_idx)], dim=-1) | |
| return F.normalize(self.proj(x), dim=-1) | |
| class TwoTower(nn.Module): | |
| def __init__(self, n_users: int, n_tracks: int, n_artists: int, out_dim: int = 128): | |
| super().__init__() | |
| self.user_tower = UserTower(n_users, out_dim=out_dim) | |
| self.track_tower = TrackTower(n_tracks, n_artists, out_dim=out_dim) | |
| self.log_temp = nn.Parameter(torch.tensor(0.0)) # learnable temperature, like CLIP | |
| def forward( | |
| self, user_idx: torch.Tensor, track_idx: torch.Tensor, artist_idx: torch.Tensor | |
| ) -> torch.Tensor: | |
| u = self.user_tower(user_idx) # (B, D) | |
| t = self.track_tower(track_idx, artist_idx) # (B, D) | |
| logits = (u @ t.T) * self.log_temp.exp() # (B, B) | |
| return logits | |
| def info_nce_loss(logits: torch.Tensor) -> torch.Tensor: | |
| """Symmetric InfoNCE β diagonal is the positive pair, off-diagonal are in-batch negatives.""" | |
| targets = torch.arange(logits.size(0), device=logits.device) | |
| return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) | |