""" src/evaluate.py Three jobs: 1. Per-label metrics table (precision, recall, F1, AP) on a given split. 2. Per-label threshold tuning — find the threshold that maximises F1 for each label individually on the *val* split, save as thresholds.json. This replaces the naive global threshold=0.5 used during training. 3. Confusion image grids — for the 3 labels with worst F1, save 3x3 grids of false positives and false negatives so failures are visually obvious. Why per-label thresholds? 0.5 is optimal only when the positive class is ~50% and precision/recall matter equally. Neither is true here: rare labels like "foggy" or "tunnel" will be predicted with low confidence, so their optimal threshold is lower. Usage: python -m src.evaluate --checkpoint experiments/checkpoints/baseline_best.pt python -m src.evaluate --checkpoint --split val --tune-thresholds """ import argparse import json import logging from pathlib import Path import matplotlib matplotlib.use("Agg") # no GUI needed; must be set before importing pyplot import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch from PIL import Image from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score from torch.utils.data import DataLoader from tqdm import tqdm from src.config import DATA_PROCESSED, LABELS, NUM_LABELS, SEED from src.dataset import BDDMultiLabelDataset, get_transforms from src.model import build_model from src.utils import get_device, set_seed logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") log = logging.getLogger(__name__) CONFUSION_DIR = Path("experiments/confusion_grids") THRESHOLDS_PATH = DATA_PROCESSED / "thresholds.json" # --------------------------------------------------------------------------- # Inference # --------------------------------------------------------------------------- @torch.no_grad() def run_inference(model: torch.nn.Module, split: str, device: torch.device, batch_size: int = 64) -> tuple[np.ndarray, np.ndarray]: """ Run model on a full split. Returns: probs float32 array (N, NUM_LABELS) — post-sigmoid probabilities targets int array (N, NUM_LABELS) — ground truth binary labels """ ds = BDDMultiLabelDataset(split) loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0) all_probs, all_targets = [], [] model.eval() for imgs, labels in tqdm(loader, desc=f" inference [{split}]", leave=False): imgs = imgs.to(device) logits = model(imgs) probs = torch.sigmoid(logits).cpu().numpy() all_probs.append(probs) all_targets.append(labels.numpy()) return np.vstack(all_probs).astype(np.float32), np.vstack(all_targets).astype(int) # --------------------------------------------------------------------------- # Threshold tuning # --------------------------------------------------------------------------- def tune_thresholds(probs: np.ndarray, targets: np.ndarray, candidates: np.ndarray = None) -> dict[str, float]: """ For each label, sweep candidate thresholds and pick the one with highest F1. Returns a dict {label_name: best_threshold}. """ if candidates is None: candidates = np.arange(0.1, 0.91, 0.05) thresholds = {} for i, label in enumerate(LABELS): best_t, best_f1 = 0.5, 0.0 for t in candidates: preds = (probs[:, i] >= t).astype(int) f1 = f1_score(targets[:, i], preds, zero_division=0) if f1 > best_f1: best_f1, best_t = f1, float(t) thresholds[label] = round(best_t, 2) return thresholds def load_thresholds(fallback: float = 0.5) -> dict[str, float]: """Load saved thresholds, or return a dict of fallback=0.5 for all labels.""" if THRESHOLDS_PATH.exists(): with open(THRESHOLDS_PATH) as f: return json.load(f) return {label: fallback for label in LABELS} # --------------------------------------------------------------------------- # Metrics # --------------------------------------------------------------------------- def compute_metrics(probs: np.ndarray, targets: np.ndarray, thresholds: dict[str, float]) -> pd.DataFrame: """ Per-label precision, recall, F1, AP using per-label thresholds. Returns a DataFrame sorted by F1 ascending (worst labels first). """ rows = [] for i, label in enumerate(LABELS): t = thresholds.get(label, 0.5) preds = (probs[:, i] >= t).astype(int) rows.append({ "label": label, "threshold": t, "precision": round(precision_score(targets[:, i], preds, zero_division=0), 4), "recall": round(recall_score(targets[:, i], preds, zero_division=0), 4), "f1": round(f1_score(targets[:, i], preds, zero_division=0), 4), "ap": round(average_precision_score(targets[:, i], probs[:, i]) if targets[:, i].sum() > 0 else 0.0, 4), "n_positive": int(targets[:, i].sum()), }) df = pd.DataFrame(rows).sort_values("f1") micro_f1 = f1_score(targets, (probs >= 0.5).astype(int), average="micro", zero_division=0) macro_f1 = f1_score(targets, (probs >= 0.5).astype(int), average="macro", zero_division=0) log.info("Micro-F1: %.4f | Macro-F1: %.4f", micro_f1, macro_f1) return df # --------------------------------------------------------------------------- # Confusion image grids # --------------------------------------------------------------------------- def _load_thumb(path: str, size: int = 160) -> np.ndarray: img = Image.open(path).convert("RGB").resize((size, size)) return np.array(img) def save_confusion_grid(image_paths: list[str], title: str, out_path: Path, grid: int = 3) -> None: """Save a grid x grid mosaic of images to out_path as PNG.""" n = min(grid * grid, len(image_paths)) if n == 0: return fig, axes = plt.subplots(grid, grid, figsize=(grid * 2.5, grid * 2.5)) fig.suptitle(title, fontsize=10, y=1.01) for idx, ax in enumerate(axes.flat): ax.axis("off") if idx < n: ax.imshow(_load_thumb(image_paths[idx])) plt.tight_layout() out_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(out_path, dpi=100, bbox_inches="tight") plt.close(fig) log.info("Saved confusion grid: %s", out_path) def save_confusion_grids(probs: np.ndarray, targets: np.ndarray, thresholds: dict[str, float], split: str, n_worst: int = 3) -> None: """ For the `n_worst` labels by F1, save false-positive and false-negative image grids to experiments/confusion_grids/. """ metrics_df = compute_metrics(probs, targets, thresholds) worst_labels = metrics_df.head(n_worst)["label"].tolist() ds = BDDMultiLabelDataset(split) image_paths = ds.df["image_path"].tolist() for label in worst_labels: i = LABELS.index(label) t = thresholds.get(label, 0.5) pred = (probs[:, i] >= t).astype(int) true = targets[:, i] fp_idx = np.where((pred == 1) & (true == 0))[0] fn_idx = np.where((pred == 0) & (true == 1))[0] # sort by confidence so the most confident errors are shown first fp_idx = fp_idx[np.argsort(probs[fp_idx, i])[::-1]] fn_idx = fn_idx[np.argsort(probs[fn_idx, i])] fp_paths = [image_paths[j] for j in fp_idx[:9]] fn_paths = [image_paths[j] for j in fn_idx[:9]] save_confusion_grid( fp_paths, f"False Positives — {label} (predicted {label}, actually not)", CONFUSION_DIR / f"{label}_false_positives.png", ) save_confusion_grid( fn_paths, f"False Negatives — {label} (missed {label}, actually present)", CONFUSION_DIR / f"{label}_false_negatives.png", ) # --------------------------------------------------------------------------- # Full evaluation pipeline # --------------------------------------------------------------------------- def evaluate(checkpoint: str, split: str = "test", tune: bool = False) -> pd.DataFrame: set_seed(SEED) device = get_device() model = build_model().to(device) model.load_state_dict(torch.load(checkpoint, map_location=device)) log.info("Loaded checkpoint: %s", checkpoint) # --- inference --- probs, targets = run_inference(model, split, device) # --- thresholds --- if tune or not THRESHOLDS_PATH.exists(): log.info("Tuning per-label thresholds on val split...") val_probs, val_targets = run_inference(model, "val", device) thresholds = tune_thresholds(val_probs, val_targets) THRESHOLDS_PATH.parent.mkdir(parents=True, exist_ok=True) with open(THRESHOLDS_PATH, "w") as f: json.dump(thresholds, f, indent=2) log.info("Saved thresholds to %s", THRESHOLDS_PATH) else: thresholds = load_thresholds() # --- metrics --- metrics_df = compute_metrics(probs, targets, thresholds) print("\n" + metrics_df.to_string(index=False)) out_csv = Path("experiments") / f"metrics_{split}.csv" out_csv.parent.mkdir(parents=True, exist_ok=True) metrics_df.to_csv(out_csv, index=False) log.info("Saved metrics to %s", out_csv) # --- confusion grids for 3 worst labels --- save_confusion_grids(probs, targets, thresholds, split, n_worst=3) return metrics_df # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate multi-label road scene model") parser.add_argument("--checkpoint", required=True, help="Path to .pt checkpoint file") parser.add_argument("--split", default="test", choices=["train", "val", "test"]) parser.add_argument( "--tune-thresholds", action="store_true", help="Re-run threshold tuning on val split even if thresholds.json exists", ) args = parser.parse_args() evaluate(args.checkpoint, args.split, tune=args.tune_thresholds)