"""
Isengard - User Tower

Neural network that encodes a user's wine preferences from their reviewed wines.
Uses attention-weighted aggregation of wine embeddings based on user ratings.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .config import (
    EMBEDDING_DIM,
    USER_VECTOR_DIM,
    HIDDEN_DIM,
)


class UserTower(nn.Module):
    """
    Isengard: Encodes user preferences from their reviewed wines.

    Architecture:
        1. Rating-weighted attention over wine embeddings: a softmax over
           normalized ratings, restricted to the unmasked wines. Every valid
           wine gets a positive weight; lower ratings just get less of it.
        2. MLP: embedding_dim → hidden_dim → output_dim
           (documented as 768 → 256 → 128 with the config defaults).
        3. L2 normalization to the unit sphere.

    Input:
        wine_embeddings: (batch, num_wines, embedding_dim) - embeddings of reviewed wines
        ratings: (batch, num_wines) - user ratings for each wine
        mask: (batch, num_wines) - optional mask for padding

    Output:
        user_vector: (batch, output_dim) - L2-normalized user embedding
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = USER_VECTOR_DIM,
        dropout: float = 0.1,
        rating_center: float = 2.5,
        rating_scale: float = 2.5,
    ):
        """
        Build the tower.

        Args:
            embedding_dim: Size of each input wine embedding.
            hidden_dim: Width of the hidden MLP layer.
            output_dim: Size of the returned user vector.
            dropout: Dropout probability applied after the hidden layer.
            rating_center: Rating value that maps to attention logit 0.
            rating_scale: Divisor applied to (rating - rating_center).
                Larger values flatten the attention distribution.

        Raises:
            ValueError: If rating_scale is not positive. Zero would divide by
                zero, and a negative value would invert "higher rating = more
                attention".
        """
        super().__init__()
        if rating_scale <= 0:
            raise ValueError(f"rating_scale must be positive, got {rating_scale}")

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.rating_center = rating_center
        self.rating_scale = rating_scale

        # MLP layers. Attribute names are unchanged so existing state_dicts load.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        wine_embeddings: torch.Tensor,
        ratings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the user tower.

        Args:
            wine_embeddings: (batch, num_wines, embedding_dim)
            ratings: (batch, num_wines) - raw ratings (1-5 scale)
            mask: (batch, num_wines) - 1 for valid wines, 0 for padding

        Returns:
            user_vector: (batch, output_dim) - L2 normalized.
                A row whose mask is all zeros aggregates to the zero vector,
                so its output depends only on the MLP biases.
        """
        # Convert ratings to attention logits: higher rating, larger logit.
        # With the default center/scale of 2.5: 1→-0.6, 5→1.0.
        logits = (ratings - self.rating_center) / self.rating_scale
        attention_weights = F.softmax(logits, dim=-1)

        if mask is not None:
            # Zeroing padded positions and renormalizing is mathematically a
            # softmax over only the valid positions. The epsilon keeps a fully
            # masked row at all-zero weights instead of producing 0/0 = NaN.
            attention_weights = attention_weights * mask.to(attention_weights.dtype)
            attention_weights = attention_weights / (
                attention_weights.sum(dim=-1, keepdim=True) + 1e-8
            )

        # torch.bmm requires both operands to share a dtype. Ratings are often
        # float64 (e.g. from numpy), while embeddings may be float32 or half
        # precision, so align the weights with the embeddings.
        attention_weights = attention_weights.to(wine_embeddings.dtype)

        # Weighted aggregation:
        #   (batch, 1, num_wines) @ (batch, num_wines, embed_dim)
        #   -> (batch, 1, embed_dim) -> (batch, embed_dim)
        aggregated = torch.bmm(
            attention_weights.unsqueeze(1),
            wine_embeddings,
        ).squeeze(1)

        # MLP projection
        x = F.relu(self.fc1(aggregated))
        x = self.dropout(x)
        user_vector = self.fc2(x)

        # L2 normalize to unit sphere
        user_vector = F.normalize(user_vector, p=2, dim=-1)

        return user_vector