"""
Mordor - Wine Tower

Neural network that encodes wine characteristics from embedding + categorical
features.
"""

from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import (
    EMBEDDING_DIM,
    WINE_VECTOR_DIM,
    HIDDEN_DIM,
    CATEGORICAL_ENCODING_DIM,
)

# Lower-cased category value -> one-hot index, keyed by feature name.
# Built once at import time; encode_categorical_features only reads it.
# "rosé" and "rose" intentionally share an index so both spellings encode
# identically.
# NOTE(review): these vocabularies span 5 + 4 + 7 + 4 + 4 + 4 = 28 indices,
# while the docstrings below cite a 31-wide encoding. The encoded width is
# driven by CATEGORICAL_VOCAB_SIZES in .config, which presumably reserves
# extra slots - confirm the two stay in sync.
_CATEGORICAL_VOCAB_MAPS: Dict[str, Dict[str, int]] = {
    "color": {
        "red": 0,
        "white": 1,
        "rosé": 2,
        "rose": 2,
        "orange": 3,
        "sparkling": 4,
    },
    "type": {"still": 0, "sparkling": 1, "fortified": 2, "dessert": 3},
    "style": {
        "natural": 0,
        "organic": 1,
        "biodynamic": 2,
        "conventional": 3,
        "sustainable": 4,
        "vegan": 5,
        "other": 6,
    },
    "climate_type": {"cool": 0, "moderate": 1, "warm": 2, "hot": 3},
    "climate_band": {"cool": 0, "moderate": 1, "warm": 2, "hot": 3},
    "vintage_band": {"young": 0, "developing": 1, "mature": 2, "non_vintage": 3},
}


class WineTower(nn.Module):
    """
    Mordor: Encodes wine characteristics from embedding and metadata.

    Architecture:
        1. Concatenate wine embedding + categorical one-hot encoding
        2. MLP: (embedding_dim + categorical_dim) → hidden_dim → output_dim
           (documented as (768 + 31) → 256 → 128 for the defaults taken
           from .config)
        3. L2 normalization to unit sphere

    Input:
        wine_embedding: (batch, embedding_dim) - text embedding vector
            (documented as a 768-d google-text-embedding-004 vector)
        categorical_features: (batch, categorical_dim) - one-hot encoded
            categoricals

    Output:
        wine_vector: (batch, output_dim) - normalized wine embedding
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        categorical_dim: int = CATEGORICAL_ENCODING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = WINE_VECTOR_DIM,
        dropout: float = 0.1,
    ):
        """
        Build the two-layer projection MLP.

        Args:
            embedding_dim: Width of the wine text embedding.
            categorical_dim: Width of the concatenated one-hot encoding.
            hidden_dim: Width of the hidden layer.
            output_dim: Width of the produced wine vector.
            dropout: Dropout probability applied after the hidden
                activation. Defaults to 0.1, the previously fixed value.
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.categorical_dim = categorical_dim
        self.output_dim = output_dim

        # Input dimension: embedding + categorical
        input_dim = embedding_dim + categorical_dim

        # MLP layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Dropout for regularization (only active in training mode)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        wine_embedding: torch.Tensor,
        categorical_features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass through the wine tower.

        Args:
            wine_embedding: (batch, embedding_dim)
            categorical_features: (batch, categorical_dim) - one-hot encoded

        Returns:
            wine_vector: (batch, output_dim) - L2 normalized
        """
        # Concatenate embedding and categorical features along the last axis
        x = torch.cat([wine_embedding, categorical_features], dim=-1)

        # MLP projection
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        wine_vector = self.fc2(x)

        # L2 normalize to unit sphere
        wine_vector = F.normalize(wine_vector, p=2, dim=-1)

        return wine_vector


def encode_categorical_features(wine_data: Dict) -> torch.Tensor:
    """
    Convert wine metadata dict to one-hot encoded tensor.

    Each feature listed in CATEGORICAL_FEATURES contributes a segment of
    CATEGORICAL_VOCAB_SIZES[feature] entries. A segment stays all-zero when
    the value is missing or falsy, when the feature has no vocabulary, or
    when the lower-cased value is not in that vocabulary.

    Args:
        wine_data: Dict with keys: color, type, style, climate_type,
            climate_band, vintage_band

    Returns:
        1-D tensor formed by concatenating the per-feature one-hot segments
        (presumably of length categorical_dim - depends on .config).
    """
    # Imported at call time so the current values in .config are used.
    from .config import CATEGORICAL_VOCAB_SIZES, CATEGORICAL_FEATURES

    encoded = []
    for feature in CATEGORICAL_FEATURES:
        one_hot = torch.zeros(CATEGORICAL_VOCAB_SIZES[feature])

        value = wine_data.get(feature)
        if value:
            vocab = _CATEGORICAL_VOCAB_MAPS.get(feature, {})
            idx = vocab.get(str(value).lower())
            # Index 0 is a valid slot, so test against None, not truthiness.
            if idx is not None:
                one_hot[idx] = 1.0

        encoded.append(one_hot)

    return torch.cat(encoded, dim=0)