Spaces:
Runtime error
Runtime error
File size: 2,607 Bytes
68aba54 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | """Two-tower neural recommender.
User tower: user_id ──► Embedding ──► MLP ──► L2-normalized user_vec
Track tower: (track_id, artist_id) ──► Embeddings ──► MLP ──► L2-normalized track_vec
Score = dot(user_vec, track_vec), multiplied by a learnable inverse temperature. Trained with InfoNCE / in-batch negatives.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
def mlp(in_dim: int, hidden: int, out_dim: int, dropout: float = 0.1) -> nn.Sequential:
    """Build a two-layer perceptron: Linear -> GELU -> Dropout -> Linear.

    Args:
        in_dim: Number of input features.
        hidden: Width of the single hidden layer.
        out_dim: Number of output features.
        dropout: Dropout probability applied after the activation.

    Returns:
        An ``nn.Sequential`` whose four children are indexed 0..3 in the
        order listed above (so parameters live under ``0.*`` and ``3.*``).
    """
    # Hidden stage first, then the output head: constructing the layers in
    # this order keeps parameter initialization identical under a fixed seed.
    hidden_stage = [nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout)]
    output_head = nn.Linear(hidden, out_dim)
    return nn.Sequential(*hidden_stage, output_head)
class UserTower(nn.Module):
    """Encode user indices as L2-normalized vectors.

    Pipeline: ``nn.Embedding`` lookup -> ``mlp`` (hidden width 256) ->
    L2 normalization over the last dimension.

    Args:
        n_users: Number of rows in the user embedding table.
        embed_dim: Width of the raw user embedding.
        out_dim: Width of the normalized output vector.
    """

    def __init__(self, n_users: int, embed_dim: int = 64, out_dim: int = 128):
        super().__init__()
        # Embedding is registered before the projection so the init order
        # (and therefore seeded weights / state_dict key order) is preserved.
        self.user_emb = nn.Embedding(n_users, embed_dim)
        self.proj = mlp(embed_dim, 256, out_dim)

    def forward(self, user_idx: torch.Tensor) -> torch.Tensor:
        """Return unit-norm (along the last dim) user vectors for ``user_idx``."""
        projected = self.proj(self.user_emb(user_idx))
        return F.normalize(projected, dim=-1)
class TrackTower(nn.Module):
    """Encode (track, artist) index pairs as L2-normalized vectors.

    The track and artist embeddings are concatenated along the last
    dimension, projected through ``mlp`` (hidden width 256), then
    L2-normalized over the last dimension.

    Args:
        n_tracks: Number of rows in the track embedding table.
        n_artists: Number of rows in the artist embedding table.
        embed_dim: Width of each raw embedding (the MLP input is twice this).
        out_dim: Width of the normalized output vector.
    """

    def __init__(self, n_tracks: int, n_artists: int, embed_dim: int = 64, out_dim: int = 128):
        super().__init__()
        self.track_emb = nn.Embedding(n_tracks, embed_dim)
        self.artist_emb = nn.Embedding(n_artists, embed_dim)
        # Input width doubles because both embeddings are concatenated.
        self.proj = mlp(2 * embed_dim, 256, out_dim)

    def forward(self, track_idx: torch.Tensor, artist_idx: torch.Tensor) -> torch.Tensor:
        """Return unit-norm track vectors for the given track/artist indices."""
        pieces = (self.track_emb(track_idx), self.artist_emb(artist_idx))
        joint = torch.cat(pieces, dim=-1)
        return F.normalize(self.proj(joint), dim=-1)
class TwoTower(nn.Module):
    """Two-tower model returning a matrix of user-vs-track similarity logits.

    Both towers emit L2-normalized vectors, so raw dot products lie in
    [-1, 1]. They are multiplied by a learnable logit scale
    ``exp(log_temp)``, i.e. an *inverse* temperature, before being returned.

    Args:
        n_users: Number of distinct user indices.
        n_tracks: Number of distinct track indices.
        n_artists: Number of distinct artist indices.
        out_dim: Dimension of the shared embedding space.
        init_log_temp: Initial value of ``log_temp``. The default 0.0
            (scale 1.0) matches the previous behaviour; CLIP initializes
            this to ``log(1 / 0.07)`` (about 2.659).
        max_logit_scale: Upper bound applied to ``exp(log_temp)`` in
            ``forward``. CLIP clips its learnable scale at 100 to prevent
            training instability; ``None`` disables the cap.
    """

    def __init__(
        self,
        n_users: int,
        n_tracks: int,
        n_artists: int,
        out_dim: int = 128,
        *,
        init_log_temp: float = 0.0,
        max_logit_scale: float | None = 100.0,
    ):
        super().__init__()
        self.user_tower = UserTower(n_users, out_dim=out_dim)
        self.track_tower = TrackTower(n_tracks, n_artists, out_dim=out_dim)
        # Log of the inverse temperature (CLIP's "logit_scale"); the attribute
        # name is kept as-is so existing checkpoints still load.
        self.log_temp = nn.Parameter(torch.tensor(float(init_log_temp)))
        # Plain attribute (not a buffer), so it does not appear in state_dict.
        self.max_logit_scale = max_logit_scale

    def forward(
        self, user_idx: torch.Tensor, track_idx: torch.Tensor, artist_idx: torch.Tensor
    ) -> torch.Tensor:
        """Score every user in the batch against every track in the batch.

        Returns:
            ``(u @ t.T) * scale`` — shape (B, B) when both index batches have
            length B (the layout ``info_nce_loss`` expects).
        """
        u = self.user_tower(user_idx)  # (B, D)
        t = self.track_tower(track_idx, artist_idx)  # (B, D)
        scale = self.log_temp.exp()
        if self.max_logit_scale is not None:
            # Keep the learnable scale from growing without bound; unit-norm
            # embeddings otherwise let the loss push it ever higher.
            scale = scale.clamp(max=self.max_logit_scale)
        logits = (u @ t.T) * scale  # (B, B)
        return logits
def info_nce_loss(logits: torch.Tensor) -> torch.Tensor:
    """Symmetric InfoNCE — diagonal is the positive pair, off-diagonal are in-batch negatives.

    Row ``i`` holds query ``i``'s scores against every candidate in the batch
    (for ``TwoTower`` output: rows are users, columns are tracks). The loss is
    the mean of the row-wise and column-wise cross-entropies, each with
    target ``i`` for row/column ``i``.

    Args:
        logits: Square 2-D tensor of shape (B, B).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``logits`` is not a square 2-D matrix. Without this
            check the failure surfaces as an opaque target-out-of-bounds or
            batch-size-mismatch error inside ``cross_entropy``.
    """
    if logits.dim() != 2 or logits.size(0) != logits.size(1):
        raise ValueError(
            f"info_nce_loss expects square (B, B) logits, got shape {tuple(logits.shape)}"
        )
    targets = torch.arange(logits.size(0), device=logits.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
|