|
|
""" |
|
|
Isengard - User Tower |
|
|
|
|
|
Neural network that encodes a user's wine preferences from their reviewed wines. |
|
|
Uses attention-weighted aggregation of wine embeddings based on user ratings. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Optional |
|
|
|
|
|
from .config import ( |
|
|
EMBEDDING_DIM, |
|
|
USER_VECTOR_DIM, |
|
|
HIDDEN_DIM, |
|
|
) |
|
|
|
|
|
|
|
|
class UserTower(nn.Module):
    """
    Isengard: Encodes user preferences from their reviewed wines.

    Architecture:
        1. Rating-weighted attention over wine embeddings
        2. MLP: embedding_dim -> hidden_dim -> output_dim
           (768 -> 256 -> 128 with the default config values)
        3. L2 normalization to the unit sphere

    Input:
        wine_embeddings: (batch, num_wines, embedding_dim) - embeddings of reviewed wines
        ratings: (batch, num_wines) - user ratings for each wine
        mask: (batch, num_wines) - optional mask for padding

    Output:
        user_vector: (batch, output_dim) - normalized user embedding
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = USER_VECTOR_DIM,
        dropout: float = 0.1,
    ):
        """
        Args:
            embedding_dim: Dimensionality of the input wine embeddings.
            hidden_dim: Width of the hidden MLP layer.
            output_dim: Dimensionality of the output user vector.
            dropout: Dropout probability applied after the hidden layer.
                Defaults to 0.1, matching the previously hard-coded value.
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim

        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        wine_embeddings: torch.Tensor,
        ratings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the user tower.

        Args:
            wine_embeddings: (batch, num_wines, embedding_dim)
            ratings: (batch, num_wines) - raw ratings (1-5 scale)
            mask: (batch, num_wines) - 1 for valid wines, 0 for padding

        Returns:
            user_vector: (batch, output_dim) - L2 normalized
        """
        # Center ratings around the 1-5 scale midpoint and rescale. The shift
        # does not change the softmax output (softmax is shift-invariant), but
        # dividing by 2.5 acts as a temperature on the attention distribution.
        attention_weights = (ratings - 2.5) / 2.5
        attention_weights = F.softmax(attention_weights, dim=-1)

        # Zero out padded positions. Masking AFTER the softmax and then
        # renormalizing below yields the same distribution as masking the
        # logits with -inf before the softmax, while staying well-defined
        # (all-zero weights) when an entire row is masked.
        if mask is not None:
            attention_weights = attention_weights * mask

        # Renormalize so the weights sum to 1 over the valid wines; the
        # epsilon guards against division by zero for fully-masked rows.
        attention_weights = attention_weights / (
            attention_weights.sum(dim=-1, keepdim=True) + 1e-8
        )

        # Weighted sum of wine embeddings:
        # (batch, 1, num_wines) @ (batch, num_wines, embedding_dim)
        # -> (batch, 1, embedding_dim) -> squeeze -> (batch, embedding_dim).
        aggregated = torch.bmm(
            attention_weights.unsqueeze(1),
            wine_embeddings,
        ).squeeze(1)

        x = F.relu(self.fc1(aggregated))
        x = self.dropout(x)
        user_vector = self.fc2(x)

        # Project onto the unit sphere so downstream similarity is cosine.
        user_vector = F.normalize(user_vector, p=2, dim=-1)

        return user_vector
|
|
|