budijuarto's picture
Upload src/egg_damage/evaluate.py
ceffaf1 verified
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import (
accuracy_score,
balanced_accuracy_score,
classification_report,
confusion_matrix,
precision_score,
recall_score,
roc_auc_score,
f1_score,
)
from .compare_models import rank_models
from .config import load_config
from .data_discovery import CANONICAL_LABELS, ID_TO_LABEL, prepare_data
from .paths import ensure_dir
from .reporting import (
plot_calibration,
plot_combined_roc,
plot_confusion_matrix,
plot_metric_bars,
plot_precision_recall_curve_single,
plot_roc_curve_single,
plot_sample_grid,
write_markdown_report,
)
from .utils import elapsed_ms, get_logger, model_file_size_mb, save_json, timer
# Module-level logger from the project's get_logger helper, keyed by this module's name;
# used below for progress messages and for logging models/steps that are skipped.
LOGGER = get_logger(__name__)
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: np.ndarray) -> dict[str, float | None]:
    """Compute binary classification metrics for one evaluation split.

    Label ``1`` is the positive class: the confusion matrix is built with
    ``labels=[0, 1]`` and the scikit-learn scorers use their default
    ``pos_label=1``.

    Args:
        y_true: Ground-truth label ids (0/1).
        y_pred: Predicted label ids (0/1).
        y_prob: Predicted probability of label 1; only used for ROC AUC.

    Returns:
        Mapping of metric name to value, in the order accuracy, precision,
        recall, f1, balanced_accuracy, roc_auc, specificity, sensitivity.
        ``roc_auc`` is ``None`` when ``y_true`` does not contain exactly two
        distinct labels or when scikit-learn raises ``ValueError``.
        Specificity/sensitivity fall back to 0.0 when their denominator is 0.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    actual_negatives = tn + fp
    actual_positives = tp + fn
    specificity = 0.0
    if actual_negatives:
        specificity = tn / actual_negatives
    sensitivity = 0.0
    if actual_positives:
        sensitivity = tp / actual_positives

    # ROC AUC is undefined for a single-class split; also tolerate sklearn rejecting the inputs.
    auc_value = None
    try:
        if np.unique(y_true).size == 2:
            auc_value = float(roc_auc_score(y_true, y_prob))
    except ValueError:
        auc_value = None

    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "balanced_accuracy": float(balanced_accuracy_score(y_true, y_pred)),
        "roc_auc": auc_value,
        "specificity": float(specificity),
        "sensitivity": float(sensitivity),
    }
def prediction_frame(
    split_df: pd.DataFrame,
    y_pred: np.ndarray,
    y_prob: np.ndarray,
    model_name: str,
    split: str,
) -> pd.DataFrame:
    """Build a per-sample prediction table for one model on one split.

    Args:
        split_df: Rows of the split; must contain ``filepath``, ``label``,
            ``label_id`` and ``split`` columns. Its index is discarded.
        y_pred: Predicted label ids, positionally aligned with ``split_df``.
        y_prob: Probability of label 1, positionally aligned with ``split_df``.
        model_name: Stored in the ``model_name`` column.
        split: Stored in the ``eval_split`` column.

    Returns:
        A new frame with the four source columns followed by ``model_name``,
        ``eval_split``, ``y_true``, ``y_pred``, ``prob_damaged``, ``pred_label``
        (``y_pred`` mapped through ``ID_TO_LABEL``), ``confidence`` (the
        probability of the predicted class) and ``is_correct``.
    """
    source_columns = ["filepath", "label", "label_id", "split"]
    # assign() adds columns in keyword order; the lambdas see columns assigned earlier in the call.
    return (
        split_df[source_columns]
        .copy()
        .reset_index(drop=True)
        .assign(
            model_name=model_name,
            eval_split=split,
            y_true=lambda frame: frame["label_id"].astype(int),
            y_pred=y_pred.astype(int),
            prob_damaged=y_prob.astype(float),
            pred_label=lambda frame: frame["y_pred"].map(ID_TO_LABEL),
            confidence=lambda frame: np.where(
                frame["y_pred"] == 1, frame["prob_damaged"], 1.0 - frame["prob_damaged"]
            ),
            is_correct=lambda frame: frame["y_true"] == frame["y_pred"],
        )
    )
def save_prediction_outputs(
    pred_df: pd.DataFrame,
    metrics: dict[str, Any],
    config: dict[str, Any],
    model_name: str,
    split: str,
) -> None:
    """Write every artifact for one model on one split.

    With ``tag = "<model_name with '/' replaced by '_'>_<split>"``, the files
    written under ``config["paths"]["output_dir"]`` are:

    * ``predictions/<tag>_predictions.csv`` - ``pred_df`` without its index.
    * ``plots/confusion_matrix_<tag>.png`` and ``plots/roc_<tag>.png``.
    * ``plots/precision_recall_<tag>.png`` unless
      ``config["evaluation"]["save_precision_recall_curve"]`` is falsy (default on).
    * ``plots/calibration_<tag>.png`` only when
      ``config["evaluation"]["save_calibration_plot"]`` is truthy (default off).
    * ``reports/classification_report_<tag>.json`` and ``.txt`` - the
      scikit-learn classification report for labels 0 and 1.
    * ``reports/metrics_<tag>.json`` - ``metrics`` written via ``save_json``.

    Args:
        pred_df: Prediction table with ``y_true``, ``y_pred`` and ``prob_damaged``
            columns (as built by ``prediction_frame``).
        metrics: Metrics row to persist as JSON.
        config: Loaded configuration; ``paths`` and ``evaluation`` sections are read.
        model_name: Used (sanitized) in file names and (verbatim) in plot titles.
        split: Split name used in file names and plot titles.
    """
    root = Path(config["paths"]["output_dir"])
    predictions_dir = ensure_dir(root / "predictions")
    plots_dir = ensure_dir(root / "plots")
    reports_dir = ensure_dir(root / "reports")

    # Replace "/" so names like "org/model" do not create nested paths.
    file_stem = model_name.replace("/", "_")
    tag = f"{file_stem}_{split}"
    title = f"{model_name} {split}"

    pred_df.to_csv(predictions_dir / f"{tag}_predictions.csv", index=False)

    y_true, y_pred, y_prob = (pred_df[column].to_numpy() for column in ("y_true", "y_pred", "prob_damaged"))

    plot_confusion_matrix(
        confusion_matrix(y_true, y_pred, labels=[0, 1]),
        plots_dir / f"confusion_matrix_{tag}.png",
        f"{title} Confusion Matrix",
        CANONICAL_LABELS,
    )
    plot_roc_curve_single(y_true, y_prob, plots_dir / f"roc_{tag}.png", f"{title} ROC")

    evaluation_cfg = config["evaluation"]
    if evaluation_cfg.get("save_precision_recall_curve", True):
        plot_precision_recall_curve_single(y_true, y_prob, plots_dir / f"precision_recall_{tag}.png", f"{title} PR")
    if evaluation_cfg.get("save_calibration_plot", False):
        plot_calibration(y_true, y_prob, plots_dir / f"calibration_{tag}.png", title)

    # Same options for the structured (JSON) and human-readable (text) report.
    report_options = {"labels": [0, 1], "target_names": list(CANONICAL_LABELS), "zero_division": 0}
    report_dict = classification_report(y_true, y_pred, output_dict=True, **report_options)
    (reports_dir / f"classification_report_{tag}.json").write_text(json.dumps(report_dict, indent=2), encoding="utf-8")
    report_text = classification_report(y_true, y_pred, **report_options)
    (reports_dir / f"classification_report_{tag}.txt").write_text(report_text, encoding="utf-8")

    save_json(metrics, reports_dir / f"metrics_{tag}.json")
def evaluate_classical(
    model_path: Path,
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> tuple[list[dict[str, Any]], list[pd.DataFrame]]:
    """Evaluate a joblib-saved classical pipeline on the ``val`` and ``test`` splits.

    The joblib file must hold a mapping with ``"pipeline"`` (an object with
    ``predict_proba``) and ``"metadata"`` (with ``"model_name"`` and
    ``"feature_type"``). Column 1 of ``predict_proba`` is the positive-class
    probability; predictions use ``>= config["evaluation"]["threshold"]``
    (default 0.5).

    NOTE(review): ``joblib.load`` unpickles the file, so only evaluate model
    files from a trusted source.

    Args:
        model_path: Path to the ``.joblib`` bundle.
        splits_df: All samples, with a ``split`` column.
        config: Loaded configuration.

    Returns:
        ``(rows, pred_frames)`` - one metrics row and one prediction frame per
        non-empty split, in ``val``, ``test`` order. Artifacts for each split are
        also written via ``save_prediction_outputs``.
    """
    from .classical_features import extract_feature_matrix

    bundle = joblib.load(model_path)
    pipeline = bundle["pipeline"]
    info = bundle["metadata"]
    name = info["model_name"]
    features = info["feature_type"]

    metric_rows: list[dict[str, Any]] = []
    frames: list[pd.DataFrame] = []
    for split_name in ["val", "test"]:
        subset = splits_df[splits_df["split"] == split_name].reset_index(drop=True)
        if subset.empty:
            continue

        # Timing spans feature extraction plus prediction.
        t0 = timer()
        # NOTE(review): labels come from the feature extractor while prediction_frame uses
        # subset's rows; this assumes extraction keeps every row in order — confirm.
        x, y_true, _ = extract_feature_matrix(subset, features, config, balance_train=False)
        y_prob = pipeline.predict_proba(x)[:, 1]
        cutoff = float(config["evaluation"].get("threshold", 0.5))
        y_pred = (y_prob >= cutoff).astype(int)
        avg_ms = elapsed_ms(t0, len(subset))

        frame = prediction_frame(subset, y_pred, y_prob, name, split_name)
        scores = compute_metrics(y_true, y_pred, y_prob)
        entry = {
            "model_name": name,
            "model_type": "classical",
            "feature_type": features,
            "split": split_name,
            "training_curves": "N/A",
            "model_path": str(model_path),
            "model_size_mb": model_file_size_mb(model_path),
            "avg_inference_ms": avg_ms,
            **scores,
        }
        save_prediction_outputs(frame, entry, config, name, split_name)
        metric_rows.append(entry)
        frames.append(frame)
    return metric_rows, frames
def evaluate_deep_learning(
    model_path: Path,
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> tuple[list[dict[str, Any]], list[pd.DataFrame]]:
    """Evaluate a saved PyTorch checkpoint on the ``val`` and ``test`` splits.

    The checkpoint is loaded on CPU and must provide ``"model_key"`` and
    ``"state_dict"``. ``"model_name"`` (defaults to the key) and ``"config"``
    (defaults to the runtime ``config``) are optional. Inference runs on CUDA
    when available, otherwise CPU. The positive-class probability is column 1 of
    the softmax over the model output. Predictions use
    ``>= config["evaluation"]["threshold"]`` (default 0.5).

    Args:
        model_path: Path to the ``.pt`` checkpoint.
        splits_df: All samples, with a ``split`` column.
        config: Loaded configuration; ``training`` and ``evaluation`` sections are read.

    Returns:
        ``(rows, pred_frames)``: one metrics row and one prediction frame per
        non-empty split, in ``val``, ``test`` order. Artifacts for each split
        are also written via ``save_prediction_outputs``.

    Raises:
        KeyError: if required checkpoint keys or config sections are missing.
    """
    import torch
    from torch.utils.data import DataLoader

    from .augmentations import build_eval_transform
    from .dataset import EggImageDataset
    from .dl_models import create_model, load_torch_checkpoint

    checkpoint = load_torch_checkpoint(model_path, map_location="cpu")
    model_key = checkpoint["model_key"]
    model_name = checkpoint.get("model_name", model_key)
    # The architecture is rebuilt from the config stored in the checkpoint (presumably so it
    # matches the trained weights).
    # NOTE(review): the eval transform below uses the *runtime* config instead; if training
    # used different preprocessing (e.g. image size) the two could disagree — confirm intended.
    model = create_model(model_key, checkpoint.get("config", config), pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    rows: list[dict[str, Any]] = []
    pred_frames: list[pd.DataFrame] = []
    training_cfg = config["training"]
    batch_size = int(training_cfg.get("batch_size", 16))
    num_workers = int(training_cfg.get("num_workers", 0))
    # Pinned host memory only benefits host->GPU copies, so enable it on CUDA only.
    pin_memory = bool(training_cfg.get("pin_memory", True) and device.type == "cuda")
    # Loop-invariant: read the decision threshold once instead of once per batch.
    threshold = float(config["evaluation"].get("threshold", 0.5))

    for split in ["val", "test"]:
        split_df = splits_df[splits_df["split"] == split].reset_index(drop=True)
        if split_df.empty:
            continue
        loader = DataLoader(
            EggImageDataset(split_df, transform=build_eval_transform(config)),
            batch_size=batch_size,
            # Keep loader order aligned with split_df rows, which prediction_frame relies on.
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        all_prob: list[float] = []
        all_pred: list[int] = []
        all_true: list[int] = []
        # Timing spans data loading plus forward passes; elapsed_ms receives the sample
        # count (presumably to report a per-sample average).
        start = timer()
        with torch.no_grad():
            for images, labels, _ in loader:
                images = images.to(device, non_blocking=True)
                logits = model(images)
                probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
                # Threshold the raw (un-upcast) probabilities, exactly as before.
                pred = (probs >= threshold).astype(int)
                all_prob.extend(probs.astype(float).tolist())
                all_pred.extend(pred.tolist())
                all_true.extend(labels.numpy().astype(int).tolist())
        avg_ms = elapsed_ms(start, len(split_df))

        y_true = np.asarray(all_true, dtype=int)
        y_pred = np.asarray(all_pred, dtype=int)
        y_prob = np.asarray(all_prob, dtype=float)
        pred_df = prediction_frame(split_df, y_pred, y_prob, model_name, split)
        metrics = compute_metrics(y_true, y_pred, y_prob)
        row = {
            "model_name": model_name,
            "model_type": "deep_learning",
            "model_key": model_key,
            "split": split,
            "training_curves": str(Path(config["paths"]["output_dir"]) / "histories" / f"{model_key}_history.csv"),
            "model_path": str(model_path),
            "model_size_mb": model_file_size_mb(model_path),
            "avg_inference_ms": avg_ms,
            **metrics,
        }
        save_prediction_outputs(pred_df, row, config, model_name, split)
        rows.append(row)
        pred_frames.append(pred_df)
    return rows, pred_frames
def find_model_files(config: dict[str, Any]) -> list[Path]:
    """List trained model files in ``config["paths"]["model_dir"]``.

    Returns every regular file matching ``*.joblib`` (sorted) followed by every
    regular file matching ``*.pt`` (sorted). Directories whose names match the
    patterns are excluded. The search is not recursive.
    """
    root = Path(config["paths"]["model_dir"])
    candidates: list[Path] = []
    for pattern in ("*.joblib", "*.pt"):
        candidates.extend(sorted(root.glob(pattern)))
    return [candidate for candidate in candidates if candidate.is_file()]
def _evaluate_model_files(
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> tuple[list[dict[str, Any]], list[pd.DataFrame]]:
    """Evaluate every discovered model file; a model that fails is logged and skipped."""
    all_metrics: list[dict[str, Any]] = []
    all_predictions: list[pd.DataFrame] = []
    for model_path in find_model_files(config):
        try:
            LOGGER.info("Evaluating model %s", model_path)
            if model_path.suffix == ".joblib":
                rows, preds = evaluate_classical(model_path, splits_df, config)
            elif model_path.suffix == ".pt":
                rows, preds = evaluate_deep_learning(model_path, splits_df, config)
            else:
                continue
        except Exception as exc:
            # Deliberate best-effort: one broken model file must not abort the whole run.
            LOGGER.exception("Skipping %s because evaluation failed: %s", model_path, exc)
            continue
        all_metrics.extend(rows)
        all_predictions.extend(preds)
    return all_metrics, all_predictions


def _save_best_model_sample_grids(
    leaderboard: pd.DataFrame,
    test_predictions: list[tuple[Any, pd.DataFrame]],
    output_dir: Path,
    config: dict[str, Any],
) -> None:
    """Plot correct and misclassified test samples of the top-ranked leaderboard model."""
    if leaderboard.empty:
        return
    best_name = leaderboard.iloc[0]["model_name"]
    best_preds = [df for name, df in test_predictions if name == best_name]
    if not best_preds:
        return
    # Sanitize for file names the same way save_prediction_outputs does; a "/" in the
    # model name could otherwise resolve to a sub-directory that was never created.
    safe_name = str(best_name).replace("/", "_")
    best_df = best_preds[0].sort_values("confidence", ascending=False)
    n = int(config["evaluation"].get("sample_grid_count", 12))
    plots_dir = output_dir / "plots"
    # Boolean masking keeps the confidence ordering, so the subsets need no re-sort.
    # NOTE(review): a subset may be empty (e.g. no misclassifications); confirm that
    # plot_sample_grid tolerates an empty frame, since this is not wrapped in a try.
    plot_sample_grid(
        best_df[best_df["is_correct"]],
        plots_dir / f"sample_predictions_correct_{safe_name}.png",
        f"{best_name}: Correct Test Predictions",
        max_images=n,
    )
    plot_sample_grid(
        best_df[~best_df["is_correct"]],
        plots_dir / f"sample_predictions_misclassified_{safe_name}.png",
        f"{best_name}: Misclassified Test Predictions",
        max_images=n,
    )


def evaluate_all(config: dict[str, Any]) -> pd.DataFrame:
    """Evaluate every trained model and write combined summaries and reports.

    Splits are read from ``config["paths"]["split_csv"]`` when it exists;
    otherwise they are produced by ``prepare_data(config)``. Under
    ``config["paths"]["output_dir"]`` this writes the following:

    * ``metrics_summary.csv`` and ``.json``.
    * The combined test ROC plot and the metric bar plot.
    * ``misclassified_samples.csv`` (test split only).
    * Sample-prediction grids for the best leaderboard model.
    * Optional Grad-CAM examples (best-effort).
    * ``reports/model_report.md``.

    Returns:
        One metrics row per evaluated (model, split) pair.

    Raises:
        RuntimeError: if no model could be evaluated.
    """
    split_csv = Path(config["paths"]["split_csv"])
    splits_df = pd.read_csv(split_csv) if split_csv.exists() else prepare_data(config)
    output_dir = ensure_dir(config["paths"]["output_dir"])

    all_metrics, all_predictions = _evaluate_model_files(splits_df, config)
    metrics_df = pd.DataFrame(all_metrics)
    if metrics_df.empty:
        raise RuntimeError("No trained models were evaluated. Train models before running evaluation.")
    metrics_df.to_csv(output_dir / "metrics_summary.csv", index=False)
    save_json(metrics_df.to_dict(orient="records"), output_dir / "metrics_summary.json")

    # Each prediction frame covers exactly one model and one split (see prediction_frame).
    test_predictions = [
        (df["model_name"].iloc[0], df) for df in all_predictions if df["eval_split"].iloc[0] == "test"
    ]
    if test_predictions:
        plot_combined_roc(test_predictions, output_dir / "plots" / "combined_roc_test.png")
    plot_metric_bars(metrics_df, output_dir / "plots" / "metrics_bar_comparison.png")

    misclassified = pd.concat(
        [df[(df["eval_split"] == "test") & (~df["is_correct"])] for df in all_predictions],
        ignore_index=True,
    ) if all_predictions else pd.DataFrame()
    misclassified.to_csv(output_dir / "misclassified_samples.csv", index=False)

    leaderboard = rank_models(metrics_df, config)
    _save_best_model_sample_grids(leaderboard, test_predictions, output_dir, config)

    try:
        if config.get("explainability", {}).get("enabled", True) and not leaderboard.empty:
            from .explainability import save_gradcam_examples_for_best

            save_gradcam_examples_for_best(config, splits_df, leaderboard)
    except Exception as exc:
        # Explainability is optional; failures are logged rather than aborting the run.
        LOGGER.warning("Explainability generation skipped: %s", exc)

    write_markdown_report(
        config,
        splits_df,
        metrics_df,
        leaderboard,
        misclassified,
        output_dir / "reports" / "model_report.md",
    )
    LOGGER.info("Saved metrics summary and report under %s", output_dir)
    return metrics_df
def main() -> None:
    """Command-line entry point: load the ``--config`` file and evaluate all trained models."""
    cli = argparse.ArgumentParser(description="Evaluate all trained egg damage models.")
    cli.add_argument("--config", default="configs/default.yaml")
    options = cli.parse_args()
    evaluate_all(load_config(options.config))


if __name__ == "__main__":
    main()