# ufc-predictor / handler.py
# Last commit: e601cc1 by zjpiazza — "Added healthcheck to handler"
import tensorflow as tf
from typing import Dict, Any
from predictor import FightPredictor
from config import MODEL_PATH
from image_service import get_fighter_image_url
import os
class CustomLSTM(tf.keras.layers.LSTM):
    """LSTM layer whose ``from_config`` ignores a ``time_major`` config entry.

    When a model is deserialized, Keras rebuilds each layer via
    ``from_config``. This subclass drops any ``time_major`` key before
    delegating to the stock ``LSTM.from_config``. (Presumably the saved model
    was produced by a TF/Keras version whose LSTM accepted ``time_major``
    while the installed one does not — NOTE(review): confirm.)
    """

    @classmethod
    def from_config(cls, config):
        """Build the layer from *config* with any ``time_major`` key removed.

        Args:
            config: Layer config mapping supplied by Keras deserialization.
                It is NOT modified; a filtered shallow copy is forwarded.

        Returns:
            Whatever the parent ``LSTM.from_config`` builds from the
            filtered config.
        """
        # Copy before removing the key so the caller's dict is left intact
        # (the original popped from the argument in place).
        cleaned = dict(config)
        cleaned.pop('time_major', None)
        return super().from_config(cleaned)
class EndpointHandler:
    """Inference-endpoint entry point wrapping a Keras model in a FightPredictor."""

    def __init__(self, model_path: str):
        """Load the Keras model and build the predictor around it.

        Args:
            model_path: Directory that ``MODEL_PATH`` (from ``config``) is
                joined onto to locate the saved model.
        """
        # Map the 'LSTM' name in Keras' global custom-object registry to
        # CustomLSTM, so layers saved as 'LSTM' are rebuilt through its
        # time_major-stripping from_config. This registration is global,
        # not scoped to this single load.
        tf.keras.utils.get_custom_objects()['LSTM'] = CustomLSTM
        full_path = os.path.join(model_path, MODEL_PATH)
        self.model = tf.keras.models.load_model(full_path)
        self.predictor = FightPredictor(self.model)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one UFC prediction request.

        Request fields are read from ``data["inputs"]`` (not from ``data``
        itself):

        * ``ping`` — if truthy, ``{"pong": True}`` is returned immediately
          (health check); nothing else is evaluated.
        * ``random_matchup`` — if truthy, the predictor is called with both
          fighter names set to ``None``.
        * ``fighter1`` / ``fighter2`` — both required (non-empty) when
          ``random_matchup`` is not truthy.

        Args:
            data: Request payload. A missing ``"inputs"`` key is treated as
                an empty dict; a non-dict ``"inputs"`` value is rejected.

        Returns:
            On success, a dict with ``"fighter1"`` and ``"fighter2"`` entries
            (name, image URL, probability, form score, total fights) plus a
            top-level ``"details"`` entry holding the age difference.
            On failure, a dict with a single ``"error"`` message.
        """
        inputs = data.get("inputs", {})
        # A null or bare-string payload would otherwise crash on .get() with
        # AttributeError; report it as a request error instead.
        if not isinstance(inputs, dict):
            return {"error": "'inputs' must be a JSON object"}

        if inputs.get("ping"):
            return {"pong": True}

        if inputs.get("random_matchup"):
            # Both names None — presumably tells FightPredictor to pick the
            # fighters itself; TODO confirm against FightPredictor.
            fighter1_name = None
            fighter2_name = None
        else:
            fighter1_name = inputs.get("fighter1")
            fighter2_name = inputs.get("fighter2")
            if not fighter1_name or not fighter2_name:
                return {"error": "Both 'fighter1' and 'fighter2' must be provided when not using random_matchup"}

        result = self.predictor.get_prediction(fighter1_name, fighter2_name, verbose=True)
        if result is None:
            return {"error": "Prediction failed"}

        f1_data, f2_data, details = result
        return {
            "fighter1": self._fighter_payload(f1_data),
            "fighter2": self._fighter_payload(f2_data),
            "details": {
                "age_difference": float(details['age_difference']),
            },
        }

    @staticmethod
    def _fighter_payload(fighter: Dict[str, Any]) -> Dict[str, Any]:
        """Build the response entry for one fighter, including its image URL."""
        return {
            "name": fighter['name'],
            "image_url": get_fighter_image_url(fighter['name']),
            "probability": fighter['prob'],
            "details": {
                # float() as in the original — NOTE(review): presumably to turn
                # numpy scalars into JSON-serializable floats; confirm.
                "form_score": float(fighter['form_score']),
                "total_fights": fighter['total_fights'],
            },
        }