File size: 3,107 Bytes
35bad1b 851013c 35bad1b 3f27d87 35bad1b 851013c 35bad1b a554e4b 35bad1b 851013c 35bad1b 851013c 3084e65 851013c 3f27d87 851013c 3084e65 e601cc1 3084e65 e601cc1 3084e65 00d2af5 3084e65 00d2af5 3084e65 00d2af5 3084e65 00d2af5 3f27d87 3084e65 00d2af5 3084e65 00d2af5 3f27d87 00d2af5 3f27d87 3084e65 00d2af5 3084e65 00d2af5 3f27d87 00d2af5 3f27d87 3084e65 00d2af5 3f27d87 3084e65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import tensorflow as tf
from typing import Dict, Any
from predictor import FightPredictor
from config import MODEL_PATH
from image_service import get_fighter_image_url
import os
class CustomLSTM(tf.keras.layers.LSTM):
    """LSTM layer whose ``from_config`` ignores a ``time_major`` entry.

    NOTE(review): presumably needed because the saved model's LSTM config
    carries ``time_major`` while the installed Keras LSTM constructor does
    not accept it -- confirm against the Keras version in use.
    """

    @classmethod
    def from_config(cls, config):
        """Build the layer from ``config`` with any ``time_major`` key dropped.

        Args:
            config: Layer configuration dictionary. It is not modified.

        Returns:
            The result of the parent ``from_config`` for the filtered config.
        """
        # Filter into a new dict rather than popping in place, so the
        # caller's config object is left exactly as it was passed in.
        filtered = {
            key: value for key, value in config.items() if key != 'time_major'
        }
        return super().from_config(filtered)
class EndpointHandler:
    """Inference handler that wraps a Keras model in a ``FightPredictor``."""

    def __init__(self, model_path: str):
        """Load the saved model and create the predictor.

        Args:
            model_path: Base directory; ``MODEL_PATH`` from ``config`` is
                joined onto it to locate the saved model.
        """
        # Register CustomLSTM under the name 'LSTM' in Keras' custom objects
        # before loading, so the registration is in place for load_model.
        tf.keras.utils.get_custom_objects()['LSTM'] = CustomLSTM
        full_path = os.path.join(model_path, MODEL_PATH)
        self.model = tf.keras.models.load_model(full_path)
        self.predictor = FightPredictor(self.model)

    @staticmethod
    def _fighter_payload(fighter: Dict[str, Any], image_url: Any) -> Dict[str, Any]:
        """Format one fighter's prediction data for the response."""
        return {
            "name": fighter['name'],
            "image_url": image_url,
            "probability": fighter['prob'],
            "details": {
                "form_score": float(fighter['form_score']),
                "total_fights": fighter['total_fights'],
            },
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run a prediction request.

        Args:
            data: Request payload. Options are read from ``data["inputs"]``,
                which must be a dictionary (a missing key is treated as
                empty). Recognised keys inside ``inputs``:

                * ``"ping"``: if truthy, return ``{"pong": True}`` at once.
                * ``"random_matchup"``: if truthy, the predictor is called
                  with both fighter names set to ``None``.
                * ``"fighter1"`` / ``"fighter2"``: fighter names; both are
                  required unless ``random_matchup`` is truthy.

        Returns:
            On success, a dictionary with ``"fighter1"`` and ``"fighter2"``
            entries (name, image URL, probability, form score, total
            fights) and a top-level ``"details"`` entry holding
            ``"age_difference"``. On invalid input or a failed prediction,
            a dictionary with a single ``"error"`` message.
        """
        inputs = data.get("inputs", {})
        # A non-dict payload (string, list, None) would otherwise fail with
        # AttributeError on .get(); report it like the other input errors.
        if not isinstance(inputs, dict):
            return {"error": "'inputs' must be a dictionary"}

        if inputs.get("ping"):
            return {"pong": True}

        if inputs.get("random_matchup"):
            fighter1_name = None
            fighter2_name = None
        else:
            fighter1_name = inputs.get("fighter1")
            fighter2_name = inputs.get("fighter2")
            if not fighter1_name or not fighter2_name:
                return {"error": "Both 'fighter1' and 'fighter2' must be provided when not using random_matchup"}

        result = self.predictor.get_prediction(fighter1_name, fighter2_name, verbose=True)
        if result is None:
            return {"error": "Prediction failed"}

        f1_data, f2_data, details = result

        # Look up both images before building the response, as before.
        f1_image = get_fighter_image_url(f1_data['name'])
        f2_image = get_fighter_image_url(f2_data['name'])

        return {
            "fighter1": self._fighter_payload(f1_data, f1_image),
            "fighter2": self._fighter_payload(f2_data, f2_image),
            "details": {
                "age_difference": float(details['age_difference']),
            },
        }
|