File size: 3,557 Bytes
6424951
 
 
 
 
 
 
 
 
 
 
ec9c511
 
6424951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Machine learning model loading and prediction with caching."""

import logging
from pathlib import Path

import numpy as np
import streamlit as st
from tensorflow.keras.models import Model, load_model

logger = logging.getLogger("streamlit_nba")

# Default model path relative to the project root
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent.parent / "winner.keras"


class ModelLoadError(Exception):
    """Raised when model loading fails."""

    pass


@st.cache_resource
def get_winner_model(model_path: str | Path = DEFAULT_MODEL_PATH) -> Model:
    """Load and cache the winner prediction model.

    Uses Streamlit's cache_resource to ensure model is only loaded once
    per session, improving performance significantly.

    Args:
        model_path: Path to the Keras model file

    Returns:
        Loaded Keras model

    Raises:
        ModelLoadError: If model cannot be loaded
    """
    path = Path(model_path)
    if not path.exists():
        logger.error(f"Model file not found: {path}")
        raise ModelLoadError(f"Model file not found: {path}")

    try:
        logger.info(f"Loading model from {path}")
        model = load_model(str(path))
        logger.info("Model loaded successfully")
        return model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise ModelLoadError(f"Failed to load model: {e}") from e


def predict_winner(combined_stats: np.ndarray) -> tuple[float, int]:
    """Predict game winner from combined team stats.

    Args:
        combined_stats: Numpy array of shape (1, 100) containing
            home team stats followed by away team stats

    Returns:
        Tuple of (probability, prediction) where:
            - probability: Float between 0-1 (sigmoid output)
            - prediction: 0 (away wins) or 1 (home wins)

    Raises:
        ModelLoadError: If model cannot be loaded
        ValueError: If input shape is invalid
    """
    if combined_stats.shape != (1, 100):
        raise ValueError(
            f"Expected input shape (1, 100), got {combined_stats.shape}"
        )

    model = get_winner_model()
    sigmoid_output = model.predict(combined_stats, verbose=0)
    probability = float(sigmoid_output[0][0])
    prediction = int(np.round(probability))

    logger.info(f"Prediction: probability={probability:.4f}, winner={prediction}")
    return probability, prediction


def analyze_team_stats(
    home_stats: list[list[float]], away_stats: list[list[float]]
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Prepare team stats for model prediction.

    Flattens per-player stats into team-level arrays suitable for
    the prediction model.

    Args:
        home_stats: List of stat lists for each home player
        away_stats: List of stat lists for each away player

    Returns:
        Tuple of (home_array, away_array, combined_array) where:
            - home_array: Shape (1, 50) - home team flattened stats
            - away_array: Shape (1, 50) - away team flattened stats
            - combined_array: Shape (1, 100) - both teams for prediction
    """
    home_flat: list[float] = []
    away_flat: list[float] = []

    for player_stats in home_stats:
        home_flat.extend(player_stats)

    for player_stats in away_stats:
        away_flat.extend(player_stats)

    home_array = np.array(home_flat).reshape(1, -1)
    away_array = np.array(away_flat).reshape(1, -1)
    combined_array = np.array(home_flat + away_flat).reshape(1, -1)

    return home_array, away_array, combined_array