loopback / model.py
DanielRegaladoCardoso's picture
Upload model.py with huggingface_hub
68aba54 verified
Raw
History Blame Contribute Delete
2.61 kB
"""Two-tower neural recommender.
User tower: user_id ──► Embedding ──► MLP ──► L2-normalized user_vec
Track tower: (track_id, artist_id) ──► Embeddings ──► MLP ──► L2-normalized track_vec
Score = dot(user_vec, track_vec). Trained with InfoNCE / in-batch negatives.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
def mlp(in_dim: int, hidden: int, out_dim: int, dropout: float = 0.1) -> nn.Sequential:
    """Build a small feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Args:
        in_dim: Number of input features.
        hidden: Width of the single hidden layer.
        out_dim: Number of output features.
        dropout: Dropout probability applied right after the activation.

    Returns:
        An ``nn.Sequential`` containing the four layers in application order.
    """
    # Layers are instantiated in the same order they are applied so that
    # parameter initialisation consumes the RNG in a fixed sequence.
    expand = nn.Linear(in_dim, hidden)
    activation = nn.GELU()
    regularizer = nn.Dropout(dropout)
    contract = nn.Linear(hidden, out_dim)
    return nn.Sequential(expand, activation, regularizer, contract)
class UserTower(nn.Module):
    """User side of the two-tower model.

    Looks up a learned embedding per user index, projects it through an MLP
    with a hidden width of 256, and L2-normalizes the result along the last
    dimension.

    Args:
        n_users: Number of rows in the user embedding table.
        embed_dim: Size of each user embedding.
        out_dim: Size of the output user vector.
    """

    def __init__(self, n_users: int, embed_dim: int = 64, out_dim: int = 128):
        super().__init__()
        self.user_emb = nn.Embedding(num_embeddings=n_users, embedding_dim=embed_dim)
        self.proj = mlp(in_dim=embed_dim, hidden=256, out_dim=out_dim)

    def forward(self, user_idx: torch.Tensor) -> torch.Tensor:
        """Return L2-normalized user vectors for the given user indices.

        Args:
            user_idx: Integer index tensor into the user embedding table
                (presumably shape ``(B,)`` -- confirm against the caller).

        Returns:
            Tensor with a trailing dimension of ``out_dim`` whose vectors are
            normalized to unit L2 norm along that dimension.
        """
        projected = self.proj(self.user_emb(user_idx))
        return F.normalize(projected, p=2.0, dim=-1)
class TrackTower(nn.Module):
    """Track side of the two-tower model.

    Embeds the track index and the artist index separately, concatenates the
    two embeddings along the last dimension, projects the result through an
    MLP with a hidden width of 256, and L2-normalizes the output.

    Args:
        n_tracks: Number of rows in the track embedding table.
        n_artists: Number of rows in the artist embedding table.
        embed_dim: Size of each track embedding and each artist embedding.
        out_dim: Size of the output track vector.
    """

    def __init__(self, n_tracks: int, n_artists: int, embed_dim: int = 64, out_dim: int = 128):
        super().__init__()
        self.track_emb = nn.Embedding(num_embeddings=n_tracks, embedding_dim=embed_dim)
        self.artist_emb = nn.Embedding(num_embeddings=n_artists, embedding_dim=embed_dim)
        # The MLP input is the concatenation of both embeddings, hence 2 * embed_dim.
        self.proj = mlp(in_dim=2 * embed_dim, hidden=256, out_dim=out_dim)

    def forward(self, track_idx: torch.Tensor, artist_idx: torch.Tensor) -> torch.Tensor:
        """Return L2-normalized track vectors.

        Args:
            track_idx: Integer index tensor into the track embedding table.
            artist_idx: Integer index tensor into the artist embedding table;
                must have the same shape as ``track_idx`` so the embeddings
                can be concatenated along the last dimension.

        Returns:
            Tensor with a trailing dimension of ``out_dim`` whose vectors are
            normalized to unit L2 norm along that dimension.
        """
        track_vec = self.track_emb(track_idx)
        artist_vec = self.artist_emb(artist_idx)
        features = torch.cat((track_vec, artist_vec), dim=-1)
        return F.normalize(self.proj(features), p=2.0, dim=-1)
class TwoTower(nn.Module):
    """Two-tower recommender producing a scaled user-by-track similarity matrix.

    Args:
        n_users: Number of distinct user indices.
        n_tracks: Number of distinct track indices.
        n_artists: Number of distinct artist indices.
        out_dim: Shared output dimension of both towers.
    """

    def __init__(self, n_users: int, n_tracks: int, n_artists: int, out_dim: int = 128):
        super().__init__()
        self.user_tower = UserTower(n_users, out_dim=out_dim)
        self.track_tower = TrackTower(n_tracks, n_artists, out_dim=out_dim)
        # Learnable scalar, CLIP-style. The similarities are *multiplied* by
        # exp(log_temp), so this parameter behaves as a log logit scale
        # (log of the inverse temperature); it starts at 0.0, i.e. scale 1.
        self.log_temp = nn.Parameter(torch.tensor(0.0))

    def forward(
        self, user_idx: torch.Tensor, track_idx: torch.Tensor, artist_idx: torch.Tensor
    ) -> torch.Tensor:
        """Score every user in the batch against every track in the batch.

        Args:
            user_idx: User indices for the batch.
            track_idx: Track indices for the batch.
            artist_idx: Artist indices matching ``track_idx``.

        Returns:
            ``(B, B)`` logits where entry ``(i, j)`` is the scaled dot product
            of user ``i`` and track ``j``.
        """
        user_vecs = self.user_tower(user_idx)  # (B, D)
        track_vecs = self.track_tower(track_idx, artist_idx)  # (B, D)
        similarity = torch.matmul(user_vecs, track_vecs.T)  # (B, B)
        scale = torch.exp(self.log_temp)
        return similarity * scale
def info_nce_loss(logits: torch.Tensor) -> torch.Tensor:
    """Compute the symmetric InfoNCE loss for a square logits matrix.

    The diagonal entry ``(i, i)`` is the positive pair; every off-diagonal
    entry in the same row or column is an in-batch negative. The loss is the
    average of the row-wise and column-wise cross-entropies.

    Args:
        logits: ``(B, B)`` matrix of pairwise scores.

    Returns:
        Scalar tensor holding the averaged two-direction cross-entropy.

    Raises:
        ValueError: If ``logits`` is not a square 2-D matrix.
    """
    # Validate up front: a non-square matrix would otherwise surface as an
    # opaque out-of-range-target or batch-size-mismatch error raised from
    # inside cross_entropy.
    if logits.dim() != 2 or logits.size(0) != logits.size(1):
        raise ValueError(
            "info_nce_loss expects a square (B, B) logits matrix, "
            f"got shape {tuple(logits.shape)}"
        )
    targets = torch.arange(logits.size(0), device=logits.device)
    row_loss = F.cross_entropy(logits, targets)
    col_loss = F.cross_entropy(logits.T, targets)
    return 0.5 * (row_loss + col_loss)