# NOTE: the three lines below were web-page residue from the Hugging Face Space
# listing (author avatar caption, commit message, commit hash) — kept as a comment
# so the module remains valid Python.
# Martinacap02's picture
# Init deploy branch for HF Space
# f7d11f7
from __future__ import annotations
import argparse
import json
from pathlib import Path
import dagshub
from imblearn.over_sampling import RandomOverSampler
import joblib
from loguru import logger
import mlflow
import pandas as pd
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from predicting_outcomes_in_heart_failure.config import (
CONFIG_DT,
CONFIG_LR,
CONFIG_RF,
DATASET_NAME,
EXPERIMENT_NAME,
MODELS_DIR,
N_SPLITS,
PROCESSED_DATA_DIR,
RANDOM_STATE,
REPO_NAME,
REPO_OWNER,
REPORTS_DIR,
SCORING,
TARGET_COL,
VALID_MODELS,
VALID_VARIANTS,
)
REFIT = "f1"
def load_split(path: Path) -> pd.DataFrame:
    """Read one train/test split CSV from *path*.

    Raises FileNotFoundError (after logging an error) when the file is
    absent, pointing the user at the upstream split_data.py step.
    """
    if path.exists():
        frame = pd.read_csv(path)
        logger.info(f"Loaded {path} (rows={len(frame)}, cols={frame.shape[1]})")
        return frame
    logger.error(f"Missing split file: {path}. Run split_data.py first.")
    raise FileNotFoundError(path)
def apply_random_oversampling(
    X: pd.DataFrame,
    y: pd.Series,
    model_name: str,
    variant: str,
):
    """Apply RandomOverSampler to balance classes in the training set.

    Returns the resampled (X, y) pair; class distributions before and
    after resampling are logged for traceability.
    """
    tag = f"[{variant} | {model_name}]"
    logger.info(f"{tag} Applying RandomOverSampler on training data...")
    logger.info(f"{tag} Original class distribution: {y.value_counts().to_dict()}")

    sampler = RandomOverSampler(random_state=RANDOM_STATE)
    X_balanced, y_balanced = sampler.fit_resample(X, y)

    logger.info(f"{tag} Resampled class distribution: {y_balanced.value_counts().to_dict()}")
    logger.success(f"{tag} RandomOverSampler applied successfully.")
    return X_balanced, y_balanced
def get_model_and_grid(model_name: str):
    """Return estimator and parameter grid for the selected model.

    Imports are kept local so only the chosen estimator's module is loaded.
    Raises ValueError for an unrecognized *model_name*.
    """
    if model_name == "decision_tree":
        from sklearn.tree import DecisionTreeClassifier

        return DecisionTreeClassifier(random_state=RANDOM_STATE), CONFIG_DT
    if model_name == "logreg":
        from sklearn.linear_model import LogisticRegression

        return LogisticRegression(max_iter=500, random_state=RANDOM_STATE), CONFIG_LR
    if model_name == "random_forest":
        from sklearn.ensemble import RandomForestClassifier

        return RandomForestClassifier(random_state=RANDOM_STATE), CONFIG_RF
    raise ValueError(f"Unknown model_name: {model_name}")
def run_grid_search(
    estimator,
    param_grid,
    X_train,
    y_train,
    model_name: str,
    variant: str,
    reports_dir: Path,
):
    """Run GridSearchCV for the specified model and log CV results.

    Uses stratified k-fold CV (N_SPLITS, seeded) with multi-metric scoring,
    refitting on REFIT. The full cv_results_ table is written to
    ``reports_dir/cv_results.csv`` and logged as an MLflow artifact.

    Returns (best_estimator, fitted_grid, best_params).
    """
    tag = f"[{variant} | {model_name}]"
    folds = StratifiedKFold(
        n_splits=N_SPLITS,
        shuffle=True,
        random_state=RANDOM_STATE,
    )
    search = GridSearchCV(
        estimator=estimator,
        param_grid=param_grid,
        scoring=SCORING,
        refit=REFIT,
        cv=folds,
        n_jobs=-1,
        verbose=1,
        return_train_score=True,
    )

    logger.info(f"{tag} Starting GridSearchCV …")
    search.fit(X_train, y_train)
    logger.success(f"{tag} GridSearchCV completed.")
    logger.info(f"{tag} Best params ({REFIT}): {search.best_params_}")
    logger.info(f"{tag} Best CV {REFIT}: {search.best_score_:.4f}")

    # Persist the complete CV table for offline inspection and MLflow.
    results_path = reports_dir / "cv_results.csv"
    pd.DataFrame(search.cv_results_).to_csv(results_path, index=False)
    mlflow.log_artifact(str(results_path))

    return search.best_estimator_, search, search.best_params_
def save_artifacts(
    model,
    grid,
    X_train,
    model_name: str,
    variant: str,
    model_dir: Path,
    reports_dir: Path,
) -> None:
    """Save model, parameters, and metadata to disk and MLflow.

    Writes ``model_dir/{model_name}.joblib`` and a JSON summary of the CV
    configuration/results to ``reports_dir/cv_parameters.json``; the JSON
    file is also logged as an MLflow artifact.
    """
    for directory in (model_dir, reports_dir):
        directory.mkdir(parents=True, exist_ok=True)

    model_path = model_dir / f"{model_name}.joblib"
    joblib.dump(model, model_path)
    logger.success(f"[{variant} | {model_name}] Saved model → {model_path}")

    # getattr with None defaults keeps this safe even for an unfitted grid.
    metadata = {
        "model_name": model_name,
        "data_variant": variant,
        "cv": {
            "refit": REFIT,
            "best_score": getattr(grid, "best_score_", None),
            "best_params": getattr(grid, "best_params_", None),
            "scoring": list(SCORING.keys()),
            "n_splits": N_SPLITS,
            "random_state": RANDOM_STATE,
        },
        "features": list(X_train.columns),
    }

    cv_params_path = reports_dir / "cv_parameters.json"
    with open(cv_params_path, "w", encoding="utf-8") as fh:
        json.dump(metadata, fh, indent=4)
    mlflow.log_artifact(str(cv_params_path))
    logger.success(f"[{variant} | {model_name}] Saved artifacts.")
def train(model_name: str, variant: str):
    """Train a model for a specific dataset variant and log results to MLflow.

    Pipeline: ensure the per-variant MLflow experiment exists → load the
    training split → oversample the minority class → grid-search the chosen
    estimator → log params and persist artifacts, all inside one MLflow run.
    """
    experiment_name = f"{EXPERIMENT_NAME}_{variant}"
    if mlflow.get_experiment_by_name(experiment_name) is None:
        mlflow.create_experiment(experiment_name)
    mlflow.set_experiment(experiment_name)

    logger.info(f"=== Training started (model={model_name}, variant={variant}) ===")
    with mlflow.start_run(run_name=f"{model_name}_{variant}"):
        train_df = load_split(PROCESSED_DATA_DIR / variant / "train.csv")

        # Register the training dataframe as an MLflow dataset input.
        dataset = mlflow.data.from_pandas(train_df, name=f"{DATASET_NAME}_{variant}")
        mlflow.log_input(dataset, context="training")

        features = train_df.drop(columns=[TARGET_COL])
        target = train_df[TARGET_COL].astype(int)
        features, target = apply_random_oversampling(
            features,
            target,
            model_name=model_name,
            variant=variant,
        )

        estimator, param_grid = get_model_and_grid(model_name)
        mlflow.set_tag("estimator_name", type(estimator).__name__)
        mlflow.set_tag("data_variant", variant)
        mlflow.log_param("data_variant", variant)

        reports_dir = REPORTS_DIR / variant / model_name
        reports_dir.mkdir(parents=True, exist_ok=True)

        best_model, grid, best_params = run_grid_search(
            estimator,
            param_grid,
            features,
            target,
            model_name=model_name,
            variant=variant,
            reports_dir=reports_dir,
        )
        mlflow.log_params(best_params)
        save_artifacts(
            best_model,
            grid,
            features,
            model_name=model_name,
            variant=variant,
            model_dir=MODELS_DIR / variant,
            reports_dir=reports_dir,
        )
    logger.success(f"=== Training completed (model={model_name}, variant={variant}) ===")
def main():
    """CLI entry point: parse --variant/--model, connect DagsHub-backed MLflow, train."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--variant",
        type=str,
        choices=VALID_VARIANTS,
        required=True,
        help="Data variant to use: all, female, male, or nosex.",
    )
    cli.add_argument(
        "--model",
        type=str,
        choices=VALID_MODELS,
        required=True,
        help="Model to train: logreg, random_forest, or decision_tree.",
    )
    args = cli.parse_args()

    # Point the MLflow tracking URI at the DagsHub repo before training.
    dagshub.init(repo_owner=REPO_OWNER, repo_name=REPO_NAME, mlflow=True)
    train(args.model, args.variant)


if __name__ == "__main__":
    main()