|
|
import tensorflow as tf |
|
|
from typing import Dict, Any |
|
|
from predictor import FightPredictor |
|
|
from config import MODEL_PATH |
|
|
from image_service import get_fighter_image_url |
|
|
import os |
|
|
|
|
|
class CustomLSTM(tf.keras.layers.LSTM):
    """LSTM layer that drops ``time_major`` from a config before rebuilding.

    NOTE(review): presumably the saved model's LSTM config carries a
    ``time_major`` entry that the installed Keras LSTM constructor does not
    accept -- confirm against the Keras version used to train the model.
    """

    @classmethod
    def from_config(cls, config):
        """Build the layer from ``config`` with any ``time_major`` key removed.

        Args:
            config: Layer configuration dictionary. It is not modified.

        Returns:
            Whatever the parent class ``from_config`` returns for the
            filtered configuration.
        """
        # Work on a shallow copy so the caller's dictionary is left untouched;
        # the original implementation popped the key from the shared object.
        config = dict(config)
        config.pop('time_major', None)
        return super().from_config(config)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Callable inference handler that wraps a Keras model in a FightPredictor.

    Calling an instance with a request dictionary returns either a prediction
    payload or a dictionary with a single ``"error"`` key.
    """

    def __init__(self, model_path: str):
        """Load the model found under ``model_path`` and build the predictor.

        Args:
            model_path: Directory that is joined with ``MODEL_PATH`` to locate
                the saved model.
        """
        # Register the patched layer under the name 'LSTM' so that model
        # deserialization resolves that name to CustomLSTM. This mutates
        # Keras' process-wide custom object registry.
        tf.keras.utils.get_custom_objects()['LSTM'] = CustomLSTM

        full_path = os.path.join(model_path, MODEL_PATH)
        self.model = tf.keras.models.load_model(str(full_path))

        self.predictor = FightPredictor(self.model)

    @staticmethod
    def _fighter_payload(fighter: Dict[str, Any]) -> Dict[str, Any]:
        """Build the response entry for one fighter, including its image URL."""
        return {
            "name": fighter['name'],
            "image_url": get_fighter_image_url(fighter['name']),
            "probability": fighter['prob'],
            "details": {
                "form_score": float(fighter['form_score']),
                "total_fights": fighter['total_fights'],
            },
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Custom inference handler for UFC predictions.

        Args:
            data: A dictionary containing the request. The payload is read
                from ``data["inputs"]``, which must be a dictionary.
                Recognised keys inside "inputs":
                    "ping": if truthy, short-circuits with {"pong": True}.
                    "random_matchup": if truthy, the predictor is called
                        with both fighter names set to None.
                    "fighter1", "fighter2": fighter names, both required
                        when "random_matchup" is not truthy.

        Returns:
            A dictionary with prediction results and fighter images, or a
            dictionary of the form {"error": <message>} when the request is
            invalid or the predictor returns None.
        """
        # A missing or null "inputs" is treated as an empty payload; anything
        # else that is not a dictionary is rejected instead of crashing on
        # the .get() calls below.
        inputs = data.get("inputs")
        if inputs is None:
            inputs = {}
        if not isinstance(inputs, dict):
            return {"error": "'inputs' must be a JSON object"}

        if inputs.get("ping"):
            return {"pong": True}

        if inputs.get("random_matchup"):
            # NOTE(review): presumably None names make the predictor choose
            # the fighters itself -- confirm in FightPredictor.get_prediction.
            fighter1_name = None
            fighter2_name = None
        else:
            fighter1_name = inputs.get("fighter1")
            fighter2_name = inputs.get("fighter2")
            if not fighter1_name or not fighter2_name:
                return {"error": "Both 'fighter1' and 'fighter2' must be provided when not using random_matchup"}

        result = self.predictor.get_prediction(fighter1_name, fighter2_name, verbose=True)
        if result is None:
            return {"error": "Prediction failed"}

        f1_data, f2_data, details = result

        return {
            "fighter1": self._fighter_payload(f1_data),
            "fighter2": self._fighter_payload(f2_data),
            "details": {
                "age_difference": float(details['age_difference']),
            },
        }
|
|
|