File size: 3,107 Bytes
35bad1b
851013c
35bad1b
 
3f27d87
35bad1b
 
 
 
 
 
 
 
 
 
851013c
 
35bad1b
a554e4b
35bad1b
 
 
 
 
 
 
 
 
851013c
 
 
35bad1b
 
851013c
 
3084e65
851013c
 
3f27d87
851013c
3084e65
e601cc1
3084e65
e601cc1
 
 
3084e65
 
 
 
 
 
 
 
 
 
00d2af5
 
3084e65
00d2af5
3084e65
00d2af5
 
 
3084e65
00d2af5
 
3f27d87
3084e65
 
00d2af5
3084e65
00d2af5
3f27d87
00d2af5
 
3f27d87
3084e65
 
00d2af5
3084e65
00d2af5
3f27d87
00d2af5
 
3f27d87
3084e65
 
00d2af5
3f27d87
3084e65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import tensorflow as tf
from typing import Dict, Any
from predictor import FightPredictor
from config import MODEL_PATH
from image_service import get_fighter_image_url
import os

class CustomLSTM(tf.keras.layers.LSTM):
    """LSTM layer that drops a ``time_major`` entry from its serialized config.

    Saved models may carry ``time_major`` in their LSTM layer config.
    Presumably the installed Keras ``LSTM`` no longer accepts that argument
    (TODO confirm against the deployed TF/Keras version), so it is stripped
    before delegating to the stock ``from_config``.
    """

    @classmethod
    def from_config(cls, config):
        """Build the layer from ``config``, ignoring any ``time_major`` key.

        Args:
            config: Layer configuration mapping produced by Keras serialization.

        Returns:
            A new ``CustomLSTM`` instance built by the parent ``from_config``.
        """
        # Shallow-copy so the deserializer's own config dict is not mutated.
        cleaned = dict(config)
        cleaned.pop('time_major', None)
        return super().from_config(cleaned)


class EndpointHandler:
    """Inference endpoint wrapping a Keras model and a ``FightPredictor``."""

    def __init__(self, model_path: str):
        """Load the model found under ``model_path`` and build the predictor.

        Args:
            model_path: Base directory; the model file is ``MODEL_PATH``
                (from ``config``) joined onto it.
        """
        # Register the custom layer globally so that any Keras load in this
        # process resolves 'LSTM' to CustomLSTM (strips ``time_major``).
        tf.keras.utils.get_custom_objects()['LSTM'] = CustomLSTM

        full_path = os.path.join(model_path, MODEL_PATH)
        self.model = tf.keras.models.load_model(full_path)

        self.predictor = FightPredictor(self.model)

    @staticmethod
    def _to_native(value: Any) -> Any:
        """Return a 0-d NumPy-style scalar as a plain Python value; else ``value`` unchanged.

        Needed so the response stays JSON-serializable if the predictor hands
        back NumPy scalars (e.g. ``numpy.float32``) rather than Python numbers.
        """
        if getattr(value, "shape", None) == () and callable(getattr(value, "item", None)):
            return value.item()
        return value

    @classmethod
    def _fighter_payload(cls, fighter: Dict[str, Any]) -> Dict[str, Any]:
        """Build the response entry for one fighter from the predictor's data dict."""
        return {
            "name": fighter['name'],
            "image_url": get_fighter_image_url(fighter['name']),
            "probability": cls._to_native(fighter['prob']),
            "details": {
                "form_score": float(fighter['form_score']),
                "total_fights": cls._to_native(fighter['total_fights']),
            },
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Custom inference handler for UFC predictions.

        Args:
            data: Request payload. Parameters are read from ``data["inputs"]``,
                  which must be a mapping with either:
                  - "ping": truthy -> health check, returns {"pong": True};
                  - "random_matchup": truthy -> the predictor is called with
                    no fighter names (presumably it picks a random pair);
                  - otherwise both "fighter1" and "fighter2" names.

        Returns:
            A dictionary with prediction results and fighter images, or
            ``{"error": <message>}`` when the request is invalid or the
            predictor returns no result.
        """
        # A present-but-null "inputs" is treated like a missing one.
        inputs = data.get("inputs") or {}
        if not isinstance(inputs, dict):
            return {"error": "'inputs' must be a JSON object"}

        if inputs.get("ping"):
            return {"pong": True}

        if inputs.get("random_matchup"):
            fighter1_name = None
            fighter2_name = None
        else:
            fighter1_name = inputs.get("fighter1")
            fighter2_name = inputs.get("fighter2")
            if not fighter1_name or not fighter2_name:
                return {"error": "Both 'fighter1' and 'fighter2' must be provided when not using random_matchup"}

        result = self.predictor.get_prediction(fighter1_name, fighter2_name, verbose=True)
        if result is None:
            return {"error": "Prediction failed"}

        f1_data, f2_data, details = result
        return {
            "fighter1": self._fighter_payload(f1_data),
            "fighter2": self._fighter_payload(f2_data),
            "details": {
                "age_difference": float(details['age_difference'])
            },
        }