import os
from typing import Any, Dict

import tensorflow as tf

from config import MODEL_PATH
from image_service import get_fighter_image_url
from predictor import FightPredictor


class CustomLSTM(tf.keras.layers.LSTM):
    """LSTM layer that drops ``time_major`` from a serialized config."""

    @classmethod
    def from_config(cls, config: Dict[str, Any]):
        """Build the layer from ``config``, ignoring any ``time_major`` entry.

        NOTE(review): presumably the saved model was exported by a Keras
        version that serialized ``time_major`` while the installed LSTM no
        longer accepts it -- confirm against the deployed TF/Keras version.

        Args:
            config: Layer configuration as produced by ``get_config``.

        Returns:
            The layer instance created by the parent ``from_config``.
        """
        # Work on a shallow copy so the caller's dict is not mutated.
        config = dict(config)
        config.pop('time_major', None)
        return super().from_config(config)


def _fighter_payload(fighter_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build the response entry for one fighter, including its image URL."""
    return {
        "name": fighter_data['name'],
        "image_url": get_fighter_image_url(fighter_data['name']),
        "probability": fighter_data['prob'],
        "details": {
            "form_score": float(fighter_data['form_score']),
            "total_fights": fighter_data['total_fights'],
        },
    }


class EndpointHandler:
    """Inference handler that loads the model and serves fight predictions."""

    def __init__(self, model_path: str):
        """Load the Keras model and wrap it in a ``FightPredictor``.

        Args:
            model_path: Base directory; ``MODEL_PATH`` from ``config`` is
                joined onto it to locate the saved model.
        """
        # Register the custom layer so that 'LSTM' entries in the saved model
        # are deserialized through CustomLSTM.
        tf.keras.utils.get_custom_objects()['LSTM'] = CustomLSTM

        # Load model using path from config
        full_path = os.path.join(model_path, MODEL_PATH)
        self.model = tf.keras.models.load_model(str(full_path))

        # Initialize predictor
        self.predictor = FightPredictor(self.model)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Custom inference handler for UFC predictions.

        Args:
            data: Request payload. Its ``"inputs"`` entry must be a dictionary
                holding one of:

                * ``"ping"`` (truthy) -- health check;
                * ``"random_matchup"`` (truthy) -- the predictor is called
                  with both fighter names set to ``None``;
                * ``"fighter1"`` and ``"fighter2"`` -- names to predict for.

        Returns:
            ``{"pong": True}`` for a ping; ``{"error": <message>}`` when the
            request is invalid or the predictor returns ``None``; otherwise a
            dictionary with ``"fighter1"``, ``"fighter2"`` and ``"details"``
            entries.
        """
        inputs = data.get("inputs", {})
        # Report a malformed payload the same way as every other failure here
        # (an error dict) instead of crashing with AttributeError on `.get`.
        if not isinstance(inputs, dict):
            return {"error": "'inputs' must be a JSON object"}

        ping = inputs.get("ping")
        random_matchup = inputs.get("random_matchup")

        if ping:
            return {"pong": True}

        if not random_matchup:
            fighter1_name = inputs.get("fighter1")
            fighter2_name = inputs.get("fighter2")
            if not fighter1_name or not fighter2_name:
                return {"error": "Both 'fighter1' and 'fighter2' must be provided when not using random_matchup"}
        else:
            fighter1_name = None
            fighter2_name = None

        # Get prediction using FightPredictor
        result = self.predictor.get_prediction(fighter1_name, fighter2_name, verbose=True)
        if result is None:
            return {"error": "Prediction failed"}

        f1_data, f2_data, details = result

        return {
            "fighter1": _fighter_payload(f1_data),
            "fighter2": _fighter_payload(f2_data),
            "details": {
                "age_difference": float(details['age_difference']),
            },
        }