Spaces:
Build error
Build error
Commit ·
eb615ca
1
Parent(s): f972c61
Add k-fold cross-validation to prediction pipeline
Browse filesIntroduces a --kfold argument to main.py to enable 3-fold cross-validation. Implements run_kfold_cv in pipeline.py, using event-based splits and MLflow for experiment tracking and model registration. Refactors imports and typing for consistency, and moves configuration constants to config.py for better modularity.
- src/predict/main.py +13 -2
- src/predict/models.py +3 -2
- src/predict/pipeline.py +69 -0
- src/predict/preprocess.py +8 -9
- src/predict/utils.py +2 -8
src/predict/main.py
CHANGED
|
@@ -55,6 +55,11 @@ def main():
|
|
| 55 |
default=False,
|
| 56 |
help="Force retrain all models even if no new data is available."
|
| 57 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
args = parser.parse_args()
|
| 59 |
|
| 60 |
# Handle conflicting arguments
|
|
@@ -75,9 +80,15 @@ def main():
|
|
| 75 |
use_existing_models=use_existing_models,
|
| 76 |
force_retrain=force_retrain
|
| 77 |
)
|
| 78 |
-
|
| 79 |
try:
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
except FileNotFoundError as e:
|
| 82 |
print(f"Error: {e}")
|
| 83 |
print("Please ensure the required data files exist. You may need to run the scraping and ELO analysis first.")
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
default=False,
|
| 56 |
help="Force retrain all models even if no new data is available."
|
| 57 |
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
'--kfold',
|
| 60 |
+
action='store_true',
|
| 61 |
+
help='Run 3-fold CV instead of standard split.'
|
| 62 |
+
)
|
| 63 |
args = parser.parse_args()
|
| 64 |
|
| 65 |
# Handle conflicting arguments
|
|
|
|
| 80 |
use_existing_models=use_existing_models,
|
| 81 |
force_retrain=force_retrain
|
| 82 |
)
|
|
|
|
| 83 |
try:
|
| 84 |
+
if args.kfold:
|
| 85 |
+
cv_results = pipeline.run_kfold_cv(k=3, holdout_events=1)
|
| 86 |
+
print(cv_results)
|
| 87 |
+
else:
|
| 88 |
+
pipeline.run(detailed_report=(args.report == 'detailed'))
|
| 89 |
except FileNotFoundError as e:
|
| 90 |
print(f"Error: {e}")
|
| 91 |
print("Please ensure the required data files exist. You may need to run the scraping and ELO analysis first.")
|
| 92 |
+
|
| 93 |
+
if __name__ == '__main__':
|
| 94 |
+
main()
|
src/predict/models.py
CHANGED
|
@@ -12,7 +12,8 @@ from lightgbm import LGBMClassifier
|
|
| 12 |
from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
|
| 13 |
from ..config import FIGHTERS_CSV_PATH
|
| 14 |
from .preprocess import preprocess_for_ml, _get_fighter_history_stats
|
| 15 |
-
from .utils import calculate_age, prepare_fighters_data
|
|
|
|
| 16 |
|
| 17 |
class BaseModel(ABC):
|
| 18 |
"""
|
|
@@ -87,7 +88,7 @@ class BaseMLModel(BaseModel):
|
|
| 87 |
self.fighters_df = None
|
| 88 |
self.fighter_histories = {}
|
| 89 |
|
| 90 |
-
def train(self, train_fights:
|
| 91 |
"""
|
| 92 |
Trains the machine learning model. This involves loading fighter data,
|
| 93 |
pre-calculating histories, and fitting the model on the preprocessed data.
|
|
|
|
| 12 |
from ..analysis.elo import process_fights_for_elo, INITIAL_ELO
|
| 13 |
from ..config import FIGHTERS_CSV_PATH
|
| 14 |
from .preprocess import preprocess_for_ml, _get_fighter_history_stats
|
| 15 |
+
from .utils import calculate_age, prepare_fighters_data
|
| 16 |
+
from .config import DEFAULT_ELO
|
| 17 |
|
| 18 |
class BaseModel(ABC):
|
| 19 |
"""
|
|
|
|
| 88 |
self.fighters_df = None
|
| 89 |
self.fighter_histories = {}
|
| 90 |
|
| 91 |
+
def train(self, train_fights: list[dict[str, any]]) -> None:
|
| 92 |
"""
|
| 93 |
Trains the machine learning model. This involves loading fighter data,
|
| 94 |
pre-calculating histories, and fitting the model on the preprocessed data.
|
src/predict/pipeline.py
CHANGED
|
@@ -25,6 +25,9 @@ import json
|
|
| 25 |
import joblib
|
| 26 |
from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH, MODELS_DIR, LAST_EVENT_JSON_PATH
|
| 27 |
from .models import BaseModel
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
class PredictionPipeline:
|
| 30 |
"""
|
|
@@ -248,6 +251,72 @@ class PredictionPipeline:
|
|
| 248 |
if should_retrain:
|
| 249 |
self._train_and_save_models()
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
def update_models_if_new_data(self):
|
| 252 |
"""
|
| 253 |
Checks for new data and retrains/saves all models on the full dataset if needed.
|
|
|
|
| 25 |
import joblib
|
| 26 |
from ..config import FIGHTS_CSV_PATH, MODEL_RESULTS_PATH, MODELS_DIR, LAST_EVENT_JSON_PATH
|
| 27 |
from .models import BaseModel
|
| 28 |
+
from sklearn.model_selection import KFold
|
| 29 |
+
import mlflow
|
| 30 |
+
import mlflow.sklearn
|
| 31 |
|
| 32 |
class PredictionPipeline:
|
| 33 |
"""
|
|
|
|
| 251 |
if should_retrain:
|
| 252 |
self._train_and_save_models()
|
| 253 |
|
| 254 |
+
def run_kfold_cv(self, k: int = 3, holdout_events: int = 1):
|
| 255 |
+
"""Performs k-fold cross-validation where each fold is a set of events.
|
| 256 |
+
Within each fold, we keep the last *holdout_events* for testing."""
|
| 257 |
+
fights = self._load_fights()
|
| 258 |
+
|
| 259 |
+
# Build an ordered list of unique events
|
| 260 |
+
event_list = list(OrderedDict.fromkeys(f['event_name'] for f in fights))
|
| 261 |
+
|
| 262 |
+
# Initialize KFold splitter on events
|
| 263 |
+
kf = KFold(n_splits=k, shuffle=True, random_state=42)
|
| 264 |
+
|
| 265 |
+
all_fold_metrics = []
|
| 266 |
+
for fold_idx, (train_event_idx, test_event_idx) in enumerate(kf.split(event_list), start=1):
|
| 267 |
+
train_events = [event_list[i] for i in train_event_idx]
|
| 268 |
+
|
| 269 |
+
# Collect fights that belong to the training events
|
| 270 |
+
fold_fights = [f for f in fights if f['event_name'] in train_events]
|
| 271 |
+
|
| 272 |
+
# Inside this fold, reserve the last `holdout_events` events for testing
|
| 273 |
+
fold_events_ordered = list(OrderedDict.fromkeys(f['event_name'] for f in fold_fights))
|
| 274 |
+
test_events = fold_events_ordered[-holdout_events:]
|
| 275 |
+
|
| 276 |
+
train_set = [f for f in fold_fights if f['event_name'] not in test_events]
|
| 277 |
+
test_set = [f for f in fold_fights if f['event_name'] in test_events]
|
| 278 |
+
|
| 279 |
+
# Start an MLflow run for the current fold
|
| 280 |
+
mlflow.set_experiment("UFC_KFold_CV")
|
| 281 |
+
with mlflow.start_run(run_name=f"fold_{fold_idx}"):
|
| 282 |
+
# Log meta information about the fold
|
| 283 |
+
mlflow.log_param("fold", fold_idx)
|
| 284 |
+
mlflow.log_param("train_events", len(train_events))
|
| 285 |
+
mlflow.log_param("test_events", holdout_events)
|
| 286 |
+
|
| 287 |
+
fold_results = {}
|
| 288 |
+
for model in self.models:
|
| 289 |
+
model_name = model.__class__.__name__
|
| 290 |
+
|
| 291 |
+
# Train and evaluate
|
| 292 |
+
model.train(train_set)
|
| 293 |
+
correct = 0
|
| 294 |
+
total_fights = 0
|
| 295 |
+
for fight in test_set:
|
| 296 |
+
if fight['winner'] not in ["Draw", "NC", ""]:
|
| 297 |
+
prediction = model.predict(fight)
|
| 298 |
+
if prediction.get('winner') == fight['winner']:
|
| 299 |
+
correct += 1
|
| 300 |
+
total_fights += 1
|
| 301 |
+
|
| 302 |
+
acc = correct / total_fights if total_fights > 0 else 0.0
|
| 303 |
+
fold_results[model_name] = acc
|
| 304 |
+
|
| 305 |
+
# Log metrics and register model to appear in MLflow Models tab
|
| 306 |
+
mlflow.log_metric(f"accuracy_{model_name}", acc)
|
| 307 |
+
mlflow.log_metric(f"total_fights_{model_name}", total_fights)
|
| 308 |
+
|
| 309 |
+
# Register the model with MLflow to appear in Models tab
|
| 310 |
+
mlflow.sklearn.log_model(
|
| 311 |
+
model,
|
| 312 |
+
f"model_{model_name}",
|
| 313 |
+
registered_model_name=f"{model_name}_UFC_Model"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
all_fold_metrics.append(fold_results)
|
| 317 |
+
|
| 318 |
+
return all_fold_metrics
|
| 319 |
+
|
| 320 |
def update_models_if_new_data(self):
|
| 321 |
"""
|
| 322 |
Checks for new data and retrains/saves all models on the full dataset if needed.
|
src/predict/preprocess.py
CHANGED
|
@@ -1,22 +1,21 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
import os
|
| 3 |
from datetime import datetime
|
| 4 |
-
from typing import Dict, List, Tuple, Any, Optional
|
| 5 |
-
from ..config import FIGHTERS_CSV_PATH
|
| 6 |
from .utils import (
|
| 7 |
parse_round_time_to_seconds, parse_striking_stats, to_int_safe,
|
| 8 |
-
calculate_age, prepare_fighters_data
|
| 9 |
)
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
def _get_fighter_history_stats(
|
| 14 |
fighter_name: str,
|
| 15 |
current_fight_date: datetime,
|
| 16 |
-
fighter_history:
|
| 17 |
fighters_df: pd.DataFrame,
|
| 18 |
n: int = N_FIGHTS_HISTORY
|
| 19 |
-
) ->
|
| 20 |
"""
|
| 21 |
Calculates performance statistics for a fighter based on their last n fights.
|
| 22 |
"""
|
|
@@ -82,9 +81,9 @@ def _get_fighter_history_stats(
|
|
| 82 |
}
|
| 83 |
|
| 84 |
def preprocess_for_ml(
|
| 85 |
-
fights_to_process:
|
| 86 |
fighters_csv_path: str
|
| 87 |
-
) ->
|
| 88 |
"""
|
| 89 |
Transforms raw fight and fighter data into a feature matrix (X) and target vector (y)
|
| 90 |
suitable for a binary classification machine learning model.
|
|
@@ -135,8 +134,8 @@ def preprocess_for_ml(
|
|
| 135 |
if isinstance(f2_stats, pd.DataFrame): f2_stats = f2_stats.iloc[0]
|
| 136 |
|
| 137 |
# Calculate ages for both fighters
|
| 138 |
-
f1_age =
|
| 139 |
-
f2_age =
|
| 140 |
|
| 141 |
# Get historical stats for both fighters
|
| 142 |
f1_hist_stats = _get_fighter_history_stats(f1_name, fight['date_obj'], fighter_histories.get(f1_name, []), fighters_prepared)
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import os
|
| 3 |
from datetime import datetime
|
|
|
|
|
|
|
| 4 |
from .utils import (
|
| 5 |
parse_round_time_to_seconds, parse_striking_stats, to_int_safe,
|
| 6 |
+
calculate_age, prepare_fighters_data
|
| 7 |
)
|
| 8 |
+
from .config import DEFAULT_ELO, N_FIGHTS_HISTORY
|
| 9 |
|
| 10 |
|
| 11 |
|
| 12 |
def _get_fighter_history_stats(
|
| 13 |
fighter_name: str,
|
| 14 |
current_fight_date: datetime,
|
| 15 |
+
fighter_history: list[dict[str, any]],
|
| 16 |
fighters_df: pd.DataFrame,
|
| 17 |
n: int = N_FIGHTS_HISTORY
|
| 18 |
+
) -> dict[str, float]:
|
| 19 |
"""
|
| 20 |
Calculates performance statistics for a fighter based on their last n fights.
|
| 21 |
"""
|
|
|
|
| 81 |
}
|
| 82 |
|
| 83 |
def preprocess_for_ml(
|
| 84 |
+
fights_to_process: list[dict[str, any]],
|
| 85 |
fighters_csv_path: str
|
| 86 |
+
) -> tuple[pd.DataFrame, pd.Series, pd.DataFrame]:
|
| 87 |
"""
|
| 88 |
Transforms raw fight and fighter data into a feature matrix (X) and target vector (y)
|
| 89 |
suitable for a binary classification machine learning model.
|
|
|
|
| 134 |
if isinstance(f2_stats, pd.DataFrame): f2_stats = f2_stats.iloc[0]
|
| 135 |
|
| 136 |
# Calculate ages for both fighters
|
| 137 |
+
f1_age = calculate_age(f1_stats.get('dob'), fight['event_date'])
|
| 138 |
+
f2_age = calculate_age(f2_stats.get('dob'), fight['event_date'])
|
| 139 |
|
| 140 |
# Get historical stats for both fighters
|
| 141 |
f1_hist_stats = _get_fighter_history_stats(f1_name, fight['date_obj'], fighter_histories.get(f1_name, []), fighters_prepared)
|
src/predict/utils.py
CHANGED
|
@@ -1,14 +1,8 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
-
import os
|
| 3 |
from datetime import datetime
|
| 4 |
-
from typing import Optional,
|
| 5 |
|
| 6 |
-
|
| 7 |
-
DEFAULT_ELO = 1500
|
| 8 |
-
DEFAULT_AGE = 0
|
| 9 |
-
DEFAULT_FIGHT_TIME = 0
|
| 10 |
-
DEFAULT_ROUNDS_DURATION = 5 * 60 # 5 minutes per round
|
| 11 |
-
N_FIGHTS_HISTORY = 5
|
| 12 |
|
| 13 |
def clean_numeric_column(series: pd.Series) -> pd.Series:
|
| 14 |
"""A helper to clean string columns into numbers, handling errors."""
|
|
|
|
| 1 |
import pandas as pd
|
|
|
|
| 2 |
from datetime import datetime
|
| 3 |
+
from typing import Optional, Any
|
| 4 |
|
| 5 |
+
from .config import DEFAULT_ROUNDS_DURATION
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def clean_numeric_column(series: pd.Series) -> pd.Series:
|
| 8 |
"""A helper to clean string columns into numbers, handling errors."""
|