# two-tower-recommender / handler.py
# (HuggingFace Hub page header captured with the upload:
#  "swirl's picture" / "Upload handler.py with huggingface_hub" / commit ea9cf67 verified)
"""
HuggingFace Inference Endpoint Handler
Custom handler for the Two-Tower recommendation model.
This file is required for deploying to HuggingFace Inference Endpoints.
See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler
Input format:
{
"inputs": {
"user_wines": [
{"embedding": [768 floats], "rating": 4.5},
...
],
"candidate_wine": {
"embedding": [768 floats],
"color": "red",
"type": "still",
"style": "Classic",
"climate_type": "continental",
"climate_band": "cool",
"vintage_band": "medium"
}
}
}
OR for batch scoring:
{
"inputs": {
"user_wines": [...],
"candidate_wines": [...] # Multiple candidates
}
}
Output format:
{
"score": 75.5 # Single wine
}
OR
{
"scores": [75.5, 82.3, ...] # Batch
}
"""
import torch
from typing import Dict, List, Any
# Categorical feature vocabularies for one-hot encoding.
# NOTE: dict insertion order defines the layout of the concatenated one-hot
# vector produced by _encode_categorical, so this order (and the order
# within each vocab list) must match what the model was trained with.
CATEGORICAL_VOCABS = {
    "color": ["red", "white", "rosé", "orange", "sparkling"],
    "type": ["still", "sparkling", "fortified", "dessert"],
    "style": [
        "Classic",
        "Natural",
        "Organic",
        "Biodynamic",
        "Conventional",
        "Pet-Nat",
        "Orange",
        "Skin-Contact",
        "Amphora",
        "Traditional",
    ],
    "climate_type": ["cool", "moderate", "warm", "hot"],
    "climate_band": ["cool", "moderate", "warm", "hot"],
    "vintage_band": ["young", "developing", "mature", "non_vintage"],
}

# Precomputed value -> position maps so encoding a feature is a single dict
# lookup instead of two linear vocab scans (`in` + `.index`) per feature of
# every candidate wine — batch scoring iterates this per candidate.
_CATEGORICAL_INDEX = {
    feature: {value: idx for idx, value in enumerate(vocab)}
    for feature, vocab in CATEGORICAL_VOCABS.items()
}

# Dimensionality of the wine text embeddings (see module docstring); used as
# the zero-vector fallback when a request omits an "embedding" field.
EMBEDDING_DIM = 768


class EndpointHandler:
    """
    Custom handler for HuggingFace Inference Endpoints.

    Loads the Two-Tower model once at startup and answers scoring requests
    for either a single candidate wine or a batch of candidates.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler and load the model.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). Falls back to the hosted Hub repo when empty.
        """
        # Imported lazily: model.py ships alongside this file in the
        # endpoint image and may not be importable elsewhere.
        from model import TwoTowerModel

        # Load model from the checkpoint (local path or Hub repo id).
        if path:
            self.model = TwoTowerModel.from_pretrained(path)
        else:
            self.model = TwoTowerModel.from_pretrained("swirl/two-tower-recommender")
        self.model.eval()
        # Move to GPU if available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(f"Two-Tower model loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an inference request.

        Args:
            data: Request payload. The real payload lives under "inputs",
                but a bare payload (no "inputs" wrapper) is also accepted.

        Returns:
            {"score": float} for a single candidate,
            {"scores": [float, ...]} for a batch, or
            {"error": str} for a malformed request.
        """
        inputs = data.get("inputs", data)
        user_wines = inputs.get("user_wines", [])
        if not user_wines:
            return {"error": "No user_wines provided"}
        # A single-candidate request takes precedence over a batch request
        # when both keys are (incorrectly) present.
        if "candidate_wine" in inputs:
            return self._score_single(user_wines, inputs["candidate_wine"])
        elif "candidate_wines" in inputs:
            return self._score_batch(user_wines, inputs["candidate_wines"])
        else:
            return {"error": "No candidate_wine or candidate_wines provided"}

    def _score_single(
        self, user_wines: List[Dict[str, Any]], candidate_wine: Dict[str, Any]
    ) -> Dict[str, float]:
        """Score a single candidate wine via a full forward pass."""
        with torch.no_grad():
            user_embeddings, user_ratings, user_mask = self._prepare_user_data(
                user_wines
            )
            wine_embedding, wine_categorical = self._prepare_wine_data(candidate_wine)
            score = self.model(
                user_embeddings,
                user_ratings,
                wine_embedding,
                wine_categorical,
                user_mask,
            )
            return {"score": float(score.item())}

    def _score_batch(
        self, user_wines: List[Dict[str, Any]], candidate_wines: List[Dict[str, Any]]
    ) -> Dict[str, List[float]]:
        """
        Score multiple candidate wines.

        The user tower is evaluated once and its vector reused for every
        candidate, so per-request cost scales with the candidate count only.
        """
        with torch.no_grad():
            user_embeddings, user_ratings, user_mask = self._prepare_user_data(
                user_wines
            )
            user_vector = self.model.get_user_embedding(
                user_embeddings, user_ratings, user_mask
            )
            scores = []
            for wine in candidate_wines:
                wine_embedding, wine_categorical = self._prepare_wine_data(wine)
                wine_vector = self.model.get_wine_embedding(
                    wine_embedding, wine_categorical
                )
                score = self.model.score_from_embeddings(user_vector, wine_vector)
                scores.append(float(score.item()))
            return {"scores": scores}

    def _prepare_user_data(self, user_wines: List[Dict[str, Any]]) -> tuple:
        """
        Prepare user wine-history tensors for model input.

        A missing "embedding" defaults to a zero vector and a missing
        "rating" to a neutral 3.0, so partially specified wines still encode.

        Returns:
            user_embeddings: (1, num_wines, EMBEDDING_DIM) float32
            user_ratings: (1, num_wines) float32
            user_mask: (1, num_wines) float32 — all ones (no padding here)
        """
        embeddings = [w.get("embedding", [0.0] * EMBEDDING_DIM) for w in user_wines]
        ratings = [w.get("rating", 3.0) for w in user_wines]
        # Wrap in an extra list to add the batch dimension.
        user_embeddings = torch.tensor(
            [embeddings], dtype=torch.float32, device=self.device
        )
        user_ratings = torch.tensor([ratings], dtype=torch.float32, device=self.device)
        user_mask = torch.ones(
            1, len(user_wines), dtype=torch.float32, device=self.device
        )
        return user_embeddings, user_ratings, user_mask

    def _prepare_wine_data(self, wine: Dict[str, Any]) -> tuple:
        """
        Prepare candidate-wine tensors for model input.

        Returns:
            wine_embedding: (1, EMBEDDING_DIM) float32
            wine_categorical: (1, categorical_dim) float32 one-hot blocks
        """
        embedding = wine.get("embedding", [0.0] * EMBEDDING_DIM)
        wine_embedding = torch.tensor(
            [embedding], dtype=torch.float32, device=self.device
        )
        wine_categorical = torch.tensor(
            [self._encode_categorical(wine)], dtype=torch.float32, device=self.device
        )
        return wine_embedding, wine_categorical

    def _encode_categorical(self, wine: Dict[str, Any]) -> List[float]:
        """
        One-hot encode the wine's categorical features.

        Unknown or missing values encode as an all-zero block. Block layout
        follows CATEGORICAL_VOCABS / _CATEGORICAL_INDEX insertion order.

        Args:
            wine: Wine dict with optional categorical feature fields.

        Returns:
            Flat list of floats — the concatenated one-hot blocks.
        """
        encoding: List[float] = []
        for feature, index_map in _CATEGORICAL_INDEX.items():
            one_hot = [0.0] * len(index_map)
            value = wine.get(feature)
            # Only string values can match the vocab; non-strings (incl.
            # unhashable junk) fall through to the all-zero block, exactly
            # as the original membership test behaved.
            if isinstance(value, str):
                idx = index_map.get(value)
                if idx is not None:
                    one_hot[idx] = 1.0
            encoding.extend(one_hot)
        return encoding
# For local testing
if __name__ == "__main__":
    # Smoke-test the handler end to end with a mock request.
    handler = EndpointHandler()
    mock_history = [
        {"embedding": [0.1] * 768, "rating": 4.5},
        {"embedding": [0.2] * 768, "rating": 3.0},
    ]
    mock_candidate = {
        "embedding": [0.15] * 768,
        "color": "red",
        "type": "still",
    }
    payload = {
        "inputs": {
            "user_wines": mock_history,
            "candidate_wine": mock_candidate,
        }
    }
    result = handler(payload)
    print(f"Score: {result}")