# two-tower-recommender / user_tower.py
# Uploaded via huggingface_hub by user "swirl" (commit d820920, verified)
"""
Isengard - User Tower
Neural network that encodes a user's wine preferences from their reviewed wines.
Uses attention-weighted aggregation of wine embeddings based on user ratings.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from .config import (
EMBEDDING_DIM,
USER_VECTOR_DIM,
HIDDEN_DIM,
)
class UserTower(nn.Module):
    """
    Isengard: Encodes user preferences from their reviewed wines.

    Architecture:
        1. Rating-weighted attention over wine embeddings
        2. MLP: embedding_dim -> hidden_dim -> output_dim (e.g. 768 -> 256 -> 128)
        3. L2 normalization to the unit sphere

    Input:
        wine_embeddings: (batch, num_wines, embedding_dim) - embeddings of reviewed wines
        ratings: (batch, num_wines) - user ratings for each wine
        mask: (batch, num_wines) - optional mask for padding
    Output:
        user_vector: (batch, output_dim) - normalized user embedding
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = USER_VECTOR_DIM,
        dropout: float = 0.1,
    ):
        """
        Args:
            embedding_dim: Dimensionality of the input wine embeddings.
            hidden_dim: Width of the hidden MLP layer.
            output_dim: Dimensionality of the output user vector.
            dropout: Dropout probability applied after the hidden layer.
                Defaults to 0.1, preserving the previously hard-coded value.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        # Two-layer MLP projecting the aggregated wine embedding down to
        # the shared user/item vector space.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Dropout for regularization (active only in training mode).
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        wine_embeddings: torch.Tensor,
        ratings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the user tower.

        Args:
            wine_embeddings: (batch, num_wines, embedding_dim)
            ratings: (batch, num_wines) - raw ratings (1-5 scale)
            mask: (batch, num_wines) - 1 for valid wines, 0 for padding

        Returns:
            user_vector: (batch, output_dim) - L2 normalized
        """
        # Convert ratings to attention logits: higher ratings get more
        # attention. Maps the 1-5 scale into [-0.6, 1.0] (1 -> -0.6, 5 -> 1.0).
        attention_weights = (ratings - 2.5) / 2.5
        attention_weights = F.softmax(attention_weights, dim=-1)
        if mask is not None:
            # Zero out padded positions, then renormalize. Multiplying the
            # softmax output by the mask and renormalizing is algebraically
            # equivalent to a softmax taken over only the valid positions,
            # and — unlike additive -inf masking — degrades to near-zero
            # weights rather than NaN when an entire row is masked out.
            attention_weights = attention_weights * mask
            attention_weights = attention_weights / (
                attention_weights.sum(dim=-1, keepdim=True) + 1e-8
            )
        # Weighted aggregation:
        # (batch, 1, num_wines) @ (batch, num_wines, embed_dim) -> (batch, embed_dim)
        aggregated = torch.bmm(
            attention_weights.unsqueeze(1),
            wine_embeddings,
        ).squeeze(1)
        # MLP projection with dropout on the hidden activation.
        x = F.relu(self.fc1(aggregated))
        x = self.dropout(x)
        user_vector = self.fc2(x)
        # L2 normalize so user vectors live on the unit sphere (cosine
        # similarity reduces to a dot product downstream).
        user_vector = F.normalize(user_vector, p=2, dim=-1)
        return user_vector