swirl commited on
Commit
b9b6890
·
verified ·
1 Parent(s): d820920

Upload wine_tower.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. wine_tower.py +142 -0
wine_tower.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mordor - Wine Tower
3
+
4
+ Neural network that encodes wine characteristics from embedding + categorical features.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Dict
11
+
12
+ from .config import (
13
+ EMBEDDING_DIM,
14
+ WINE_VECTOR_DIM,
15
+ HIDDEN_DIM,
16
+ CATEGORICAL_ENCODING_DIM,
17
+ )
18
+
19
+
20
+ class WineTower(nn.Module):
21
+ """
22
+ Mordor: Encodes wine characteristics from embedding and metadata.
23
+
24
+ Architecture:
25
+ 1. Concatenate wine embedding + categorical one-hot encoding
26
+ 2. MLP: (768 + 31) → 256 → 128
27
+ 3. L2 normalization to unit sphere
28
+
29
+ Input:
30
+ wine_embedding: (batch, 768) - google-text-embedding-004 vector
31
+ categorical_features: (batch, 31) - one-hot encoded categoricals
32
+
33
+ Output:
34
+ wine_vector: (batch, 128) - normalized wine embedding
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int = EMBEDDING_DIM,
40
+ categorical_dim: int = CATEGORICAL_ENCODING_DIM,
41
+ hidden_dim: int = HIDDEN_DIM,
42
+ output_dim: int = WINE_VECTOR_DIM,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.embedding_dim = embedding_dim
47
+ self.categorical_dim = categorical_dim
48
+ self.output_dim = output_dim
49
+
50
+ # Input dimension: embedding + categorical
51
+ input_dim = embedding_dim + categorical_dim
52
+
53
+ # MLP layers
54
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
55
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
56
+
57
+ # Dropout for regularization
58
+ self.dropout = nn.Dropout(0.1)
59
+
60
+ def forward(
61
+ self,
62
+ wine_embedding: torch.Tensor,
63
+ categorical_features: torch.Tensor,
64
+ ) -> torch.Tensor:
65
+ """
66
+ Forward pass through the wine tower.
67
+
68
+ Args:
69
+ wine_embedding: (batch, embedding_dim)
70
+ categorical_features: (batch, categorical_dim) - one-hot encoded
71
+
72
+ Returns:
73
+ wine_vector: (batch, output_dim) - L2 normalized
74
+ """
75
+ # Concatenate embedding and categorical features
76
+ x = torch.cat([wine_embedding, categorical_features], dim=-1)
77
+
78
+ # MLP projection
79
+ x = F.relu(self.fc1(x))
80
+ x = self.dropout(x)
81
+ wine_vector = self.fc2(x)
82
+
83
+ # L2 normalize to unit sphere
84
+ wine_vector = F.normalize(wine_vector, p=2, dim=-1)
85
+
86
+ return wine_vector
87
+
88
+
89
+ def encode_categorical_features(wine_data: Dict) -> torch.Tensor:
90
+ """
91
+ Convert wine metadata dict to one-hot encoded tensor.
92
+
93
+ Args:
94
+ wine_data: Dict with keys: color, type, style, climate_type,
95
+ climate_band, vintage_band
96
+
97
+ Returns:
98
+ Tensor of shape (categorical_dim,) with one-hot encoding
99
+ """
100
+ from .config import CATEGORICAL_VOCAB_SIZES, CATEGORICAL_FEATURES
101
+
102
+ # Vocabulary mappings (could be loaded from config)
103
+ vocab_maps = {
104
+ "color": {
105
+ "red": 0,
106
+ "white": 1,
107
+ "rosé": 2,
108
+ "rose": 2,
109
+ "orange": 3,
110
+ "sparkling": 4,
111
+ },
112
+ "type": {"still": 0, "sparkling": 1, "fortified": 2, "dessert": 3},
113
+ "style": {
114
+ "natural": 0,
115
+ "organic": 1,
116
+ "biodynamic": 2,
117
+ "conventional": 3,
118
+ "sustainable": 4,
119
+ "vegan": 5,
120
+ "other": 6,
121
+ },
122
+ "climate_type": {"cool": 0, "moderate": 1, "warm": 2, "hot": 3},
123
+ "climate_band": {"cool": 0, "moderate": 1, "warm": 2, "hot": 3},
124
+ "vintage_band": {"young": 0, "developing": 1, "mature": 2, "non_vintage": 3},
125
+ }
126
+
127
+ encoded = []
128
+
129
+ for feature in CATEGORICAL_FEATURES:
130
+ vocab_size = CATEGORICAL_VOCAB_SIZES[feature]
131
+ one_hot = torch.zeros(vocab_size)
132
+
133
+ value = wine_data.get(feature)
134
+ if value and feature in vocab_maps:
135
+ value_lower = str(value).lower()
136
+ if value_lower in vocab_maps[feature]:
137
+ idx = vocab_maps[feature][value_lower]
138
+ one_hot[idx] = 1.0
139
+
140
+ encoded.append(one_hot)
141
+
142
+ return torch.cat(encoded, dim=0)