# Uploaded by budijuarto: src/egg_damage/compare_models.py (commit 094ac5e, verified)
from __future__ import annotations
from pathlib import Path
from typing import Any
import pandas as pd
from .paths import ensure_dir
from .utils import model_file_size_mb, save_json
# Metric columns that rank_models() coerces to numeric before sorting; values
# that are missing or cannot be parsed as numbers are replaced with 0.0 there.
RANK_COLUMNS = ["f1", "roc_auc", "balanced_accuracy"]
def rank_models(metrics_df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Rank evaluated models and write the leaderboard artifacts.

    Rows whose ``split`` is ``"test"`` are ranked; if there are none, rows whose
    ``split`` is ``"val"`` are used instead. Models are ordered by the metrics in
    ``RANK_COLUMNS`` (higher is better), with ties broken by ``model_size_mb``
    and then ``avg_inference_ms`` (lower is better).

    Args:
        metrics_df: One row per (model, split). Must contain ``split`` and every
            column named in ``RANK_COLUMNS``. ``model_size_mb`` and
            ``avg_inference_ms`` are optional tie-breakers; when
            ``model_size_mb`` is absent it is computed from ``model_path`` if
            that column exists.
        config: Project configuration; ``config["paths"]["output_dir"]`` is
            where the leaderboard files are written.

    Returns:
        The leaderboard with a leading 1-based ``rank`` column, or an empty
        ``DataFrame`` (nothing written) when ``metrics_df`` is empty.

    Raises:
        KeyError: If ``split`` or one of the ``RANK_COLUMNS`` is missing.

    Side effects:
        Writes ``leaderboard.csv`` and ``leaderboard.json`` into the output
        directory, plus ``best_model.json`` (the top-ranked row) when the
        leaderboard has at least one row.
    """
    if metrics_df.empty:
        return pd.DataFrame()

    # Prefer held-out test metrics; fall back to validation rows only when no
    # test rows are present.
    candidates = metrics_df[metrics_df["split"] == "test"].copy()
    if candidates.empty:
        candidates = metrics_df[metrics_df["split"] == "val"].copy()

    # A frame that was already ranked (e.g. a reloaded leaderboard) carries a
    # "rank" column, which would make insert() below raise ValueError.
    candidates = candidates.drop(columns=["rank"], errors="ignore")

    # Missing or unparseable metric values count as the worst score.
    for col in RANK_COLUMNS:
        candidates[col] = pd.to_numeric(candidates[col], errors="coerce").fillna(0.0)

    # Tie-breakers are optional. 1e9 is a "worst possible" sentinel so that rows
    # without a size/latency measurement sort after rows that have one.
    if "model_size_mb" not in candidates:
        if "model_path" in candidates:
            candidates["model_size_mb"] = candidates["model_path"].apply(model_file_size_mb)
        else:
            candidates["model_size_mb"] = 1e9
    candidates["model_size_mb"] = pd.to_numeric(
        candidates["model_size_mb"], errors="coerce"
    ).fillna(1e9)

    if "avg_inference_ms" not in candidates:
        candidates["avg_inference_ms"] = 1e9
    candidates["avg_inference_ms"] = pd.to_numeric(
        candidates["avg_inference_ms"], errors="coerce"
    ).fillna(1e9)

    # Derive the sort keys from RANK_COLUMNS so the coercion above and the
    # ordering here cannot drift apart.
    tie_breakers = ["model_size_mb", "avg_inference_ms"]
    leaderboard = candidates.sort_values(
        by=[*RANK_COLUMNS, *tie_breakers],
        ascending=[False] * len(RANK_COLUMNS) + [True] * len(tie_breakers),
    ).reset_index(drop=True)
    leaderboard.insert(0, "rank", range(1, len(leaderboard) + 1))

    # NOTE(review): ensure_dir is assumed to return a Path-like object (it is
    # combined with "/" below) — confirm against .paths.
    output_dir = ensure_dir(config["paths"]["output_dir"])
    leaderboard.to_csv(output_dir / "leaderboard.csv", index=False)
    save_json(leaderboard.to_dict(orient="records"), output_dir / "leaderboard.json")

    if not leaderboard.empty:
        save_json(leaderboard.iloc[0].to_dict(), output_dir / "best_model.json")
    return leaderboard
def load_best_model_record(config: dict[str, Any]) -> dict[str, Any]:
    """Load the saved best-model record, repairing a relocated model path.

    Reads ``best_model.json`` from ``config["paths"]["output_dir"]``. If the
    record's ``model_path`` does not exist on disk, a file with the same name is
    looked up under ``config["paths"]["model_dir"]``; when found, the record's
    ``model_path`` is replaced with that location.

    Args:
        config: Project configuration providing ``paths.output_dir`` and, when
            the stored model path is stale, ``paths.model_dir``.

    Returns:
        The record loaded from ``best_model.json``, possibly with an updated
        ``model_path``.

    Raises:
        FileNotFoundError: If ``best_model.json`` does not exist.
    """
    from .utils import load_json

    paths_cfg = config["paths"]
    record_file = Path(paths_cfg["output_dir"]) / "best_model.json"
    if not record_file.exists():
        raise FileNotFoundError(
            f"Best model record not found at {record_file}. Run training and evaluation first."
        )

    record = load_json(record_file)

    stored = record.get("model_path")
    if not stored:
        # No path recorded: nothing to repair.
        return record

    stored_path = Path(str(stored))
    if stored_path.exists():
        return record

    # The recorded path is stale; look for the same file name in the configured
    # model directory and point the record at it when present.
    relocated = Path(paths_cfg["model_dir"]) / stored_path.name
    if relocated.exists():
        record["model_path"] = str(relocated)
    return record