swirl committed on
Commit
ea9cf67
·
verified ·
1 Parent(s): eac3898

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +280 -0
handler.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Inference Endpoint Handler
3
+
4
+ Custom handler for the Two-Tower recommendation model.
5
+ This file is required for deploying to HuggingFace Inference Endpoints.
6
+
7
+ See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler
8
+
9
+ Input format:
10
+ {
11
+ "inputs": {
12
+ "user_wines": [
13
+ {"embedding": [768 floats], "rating": 4.5},
14
+ ...
15
+ ],
16
+ "candidate_wine": {
17
+ "embedding": [768 floats],
18
+ "color": "red",
19
+ "type": "still",
20
+ "style": "Classic",
21
+ "climate_type": "continental",
22
+ "climate_band": "cool",
23
+ "vintage_band": "medium"
24
+ }
25
+ }
26
+ }
27
+
28
+ OR for batch scoring:
29
+ {
30
+ "inputs": {
31
+ "user_wines": [...],
32
+ "candidate_wines": [...] # Multiple candidates
33
+ }
34
+ }
35
+
36
+ Output format:
37
+ {
38
+ "score": 75.5 # Single wine
39
+ }
40
+ OR
41
+ {
42
+ "scores": [75.5, 82.3, ...] # Batch
43
+ }
44
+ """
45
+
46
+ import torch
47
+ from typing import Dict, List, Any
48
+
49
+ # Categorical feature vocabularies for one-hot encoding
50
+ CATEGORICAL_VOCABS = {
51
+ "color": ["red", "white", "rosé", "orange", "sparkling"],
52
+ "type": ["still", "sparkling", "fortified", "dessert"],
53
+ "style": [
54
+ "Classic",
55
+ "Natural",
56
+ "Organic",
57
+ "Biodynamic",
58
+ "Conventional",
59
+ "Pet-Nat",
60
+ "Orange",
61
+ "Skin-Contact",
62
+ "Amphora",
63
+ "Traditional",
64
+ ],
65
+ "climate_type": ["cool", "moderate", "warm", "hot"],
66
+ "climate_band": ["cool", "moderate", "warm", "hot"],
67
+ "vintage_band": ["young", "developing", "mature", "non_vintage"],
68
+ }
69
+
70
+
71
class EndpointHandler:
    """Inference-endpoint entry point for the Two-Tower recommendation model.

    HuggingFace Inference Endpoints instantiate this class once at startup
    and then invoke it for every request.  The handler turns JSON payloads
    into tensors, runs the model, and returns plain-Python scores.
    """

    def __init__(self, path: str = ""):
        """Load the model and move it to the best available device.

        Args:
            path: Model directory supplied by HF Inference Endpoints.  When
                empty, the published hub checkpoint is downloaded instead.
        """
        from model import TwoTowerModel

        # An explicit path (the deployed repo snapshot) wins; otherwise fall
        # back to the public hub repository.
        checkpoint = path if path else "swirl/two-tower-recommender"
        self.model = TwoTowerModel.from_pretrained(checkpoint)
        self.model.eval()

        # Prefer GPU when the endpoint hardware provides one.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        print(f"Two-Tower model loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Route one inference request to single or batch scoring.

        Args:
            data: Request payload; model inputs live under "inputs" (falling
                back to the payload itself when that key is absent).

        Returns:
            {"score": float}, {"scores": [float, ...]}, or {"error": str}.
        """
        payload = data.get("inputs", data)

        history = payload.get("user_wines", [])
        if not history:
            return {"error": "No user_wines provided"}

        # A single candidate takes precedence over a batch when both appear.
        if "candidate_wine" in payload:
            return self._score_single(history, payload["candidate_wine"])
        if "candidate_wines" in payload:
            return self._score_batch(history, payload["candidate_wines"])
        return {"error": "No candidate_wine or candidate_wines provided"}

    def _score_single(
        self, user_wines: List[Dict[str, Any]], candidate_wine: Dict[str, Any]
    ) -> Dict[str, float]:
        """Run a full forward pass for one candidate wine."""
        with torch.no_grad():
            embs, ratings, mask = self._prepare_user_data(user_wines)
            wine_emb, wine_cat = self._prepare_wine_data(candidate_wine)
            raw = self.model(embs, ratings, wine_emb, wine_cat, mask)
        return {"score": float(raw.item())}

    def _score_batch(
        self, user_wines: List[Dict[str, Any]], candidate_wines: List[Dict[str, Any]]
    ) -> Dict[str, List[float]]:
        """Score many candidate wines, reusing a single user-tower pass."""
        with torch.no_grad():
            embs, ratings, mask = self._prepare_user_data(user_wines)
            # The user tower runs once; only the wine tower loops.
            user_vec = self.model.get_user_embedding(embs, ratings, mask)

            scores: List[float] = []
            for candidate in candidate_wines:
                wine_emb, wine_cat = self._prepare_wine_data(candidate)
                wine_vec = self.model.get_wine_embedding(wine_emb, wine_cat)
                raw = self.model.score_from_embeddings(user_vec, wine_vec)
                scores.append(float(raw.item()))
        return {"scores": scores}

    def _prepare_user_data(self, user_wines: List[Dict[str, Any]]) -> tuple:
        """Pack the user's rated wines into batched model tensors.

        A missing embedding defaults to a 768-dim zero vector and a missing
        rating to 3.0.

        Returns:
            (user_embeddings (1, n, 768), user_ratings (1, n), user_mask (1, n))
        """
        vectors = [wine.get("embedding", [0.0] * 768) for wine in user_wines]
        stars = [wine.get("rating", 3.0) for wine in user_wines]

        user_embeddings = torch.tensor(
            [vectors], dtype=torch.float32, device=self.device
        )
        user_ratings = torch.tensor([stars], dtype=torch.float32, device=self.device)
        # No padding is ever introduced here, so every position is valid.
        user_mask = torch.ones(
            1, len(user_wines), dtype=torch.float32, device=self.device
        )

        return user_embeddings, user_ratings, user_mask

    def _prepare_wine_data(self, wine: Dict[str, Any]) -> tuple:
        """Turn one wine dict into (embedding, categorical) tensors.

        Returns:
            wine_embedding: (1, 768)
            wine_categorical: (1, categorical_dim)
        """
        dense = wine.get("embedding", [0.0] * 768)
        one_hot = self._encode_categorical(wine)

        wine_embedding = torch.tensor(
            [dense], dtype=torch.float32, device=self.device
        )
        wine_categorical = torch.tensor(
            [one_hot], dtype=torch.float32, device=self.device
        )
        return wine_embedding, wine_categorical

    def _encode_categorical(self, wine: Dict[str, Any]) -> List[float]:
        """One-hot encode the wine's categorical features.

        Groups are concatenated in CATEGORICAL_VOCABS insertion order; a
        missing or out-of-vocabulary value leaves its group all-zero.

        Args:
            wine: Wine dict with categorical features.

        Returns:
            Flat list of 0.0/1.0 floats.
        """
        vector: List[float] = []
        for feature, vocab in CATEGORICAL_VOCABS.items():
            slot = [0.0] * len(vocab)
            value = wine.get(feature)
            if value and value in vocab:
                slot[vocab.index(value)] = 1.0
            vector += slot
        return vector
257
+
258
+
259
# For local testing
if __name__ == "__main__":
    handler = EndpointHandler()

    # Minimal smoke-test payload: two rated wines plus one candidate.
    request = {
        "inputs": {
            "user_wines": [
                {"embedding": [0.1] * 768, "rating": 4.5},
                {"embedding": [0.2] * 768, "rating": 3.0},
            ],
            "candidate_wine": {
                "embedding": [0.15] * 768,
                "color": "red",
                "type": "still",
            },
        }
    }

    result = handler(request)
    print(f"Score: {result}")