| """ |
| src/evaluate.py |
| |
| Three jobs: |
| 1. Per-label metrics table (precision, recall, F1, AP) on a given split. |
| 2. Per-label threshold tuning — find the threshold that maximises F1 for |
| each label individually on the *val* split, save as thresholds.json. |
| This replaces the naive global threshold=0.5 used during training. |
| 3. Confusion image grids — for the 3 labels with worst F1, save 3x3 grids |
| of false positives and false negatives so failures are visually obvious. |
| |
| Why per-label thresholds? |
| 0.5 is optimal only when the positive class is ~50% and precision/recall |
| matter equally. Neither is true here: rare labels like "foggy" or "tunnel" |
| will be predicted with low confidence, so their optimal threshold is lower. |
| |
| Usage: |
| python -m src.evaluate --checkpoint experiments/checkpoints/baseline_best.pt |
| python -m src.evaluate --checkpoint <path> --split val --tune-thresholds |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import torch |
| from PIL import Image |
| from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from src.config import DATA_PROCESSED, LABELS, NUM_LABELS, SEED |
| from src.dataset import BDDMultiLabelDataset, get_transforms |
| from src.model import build_model |
| from src.utils import get_device, set_seed |
|
|
# Configure the root logger for command-line runs: INFO and above, printed as
# "<LEVEL> <message>".
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)

# Where this module writes its artefacts: the tuned per-label thresholds and
# the false-positive / false-negative image mosaics.
THRESHOLDS_PATH = DATA_PROCESSED / "thresholds.json"
CONFUSION_DIR = Path("experiments/confusion_grids")
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def run_inference(model: torch.nn.Module, split: str, device: torch.device,
                  batch_size: int = 64) -> tuple[np.ndarray, np.ndarray]:
    """
    Push every sample of `split` through `model` and collect the outputs.

    The loader does not shuffle, so row k of both returned arrays corresponds
    to the k-th item served by the dataset. The model is switched to eval mode
    and gradients are disabled for the whole call.

    Returns:
        probs    float32 array, one row per sample — sigmoid of the model logits
        targets  int array, one row per sample — labels yielded by the dataset
    """
    dataset = BDDMultiLabelDataset(split)
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    prob_chunks: list[np.ndarray] = []
    target_chunks: list[np.ndarray] = []
    model.eval()
    for images, labels in tqdm(batches, desc=f" inference [{split}]", leave=False):
        logits = model(images.to(device))
        # Bring results back to host memory batch by batch.
        prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
        target_chunks.append(labels.numpy())

    probs = np.vstack(prob_chunks).astype(np.float32)
    targets = np.vstack(target_chunks).astype(int)
    return probs, targets
|
|
|
|
| |
| |
| |
|
|
def tune_thresholds(probs: np.ndarray, targets: np.ndarray,
                    candidates: np.ndarray = None) -> dict[str, float]:
    """
    For each label, sweep candidate thresholds and pick the one with highest F1.

    Every candidate is rounded to two decimals *before* it is scored, so the
    threshold that is returned is exactly the value whose F1 was measured.
    (Scoring the raw value and rounding afterwards would store e.g. 0.12 for a
    candidate of 0.125, or 0.15 for the 0.15000000000000002 that `np.arange`
    produces — a threshold that was never evaluated.)

    Args:
        probs: predicted probabilities; column i is scored against LABELS[i].
        targets: binary ground truth with the same column layout as `probs`.
        candidates: thresholds to try. Defaults to 0.10, 0.15, ..., 0.90.

    Returns a dict {label_name: best_threshold}. Among candidates with equal
    F1 the first one tried wins; a label whose F1 is 0 for every candidate
    keeps the neutral default of 0.5.
    """
    if candidates is None:
        candidates = np.arange(0.1, 0.91, 0.05)

    # Plain Python floats at two-decimal resolution. dict.fromkeys drops the
    # duplicates that rounding can create while preserving sweep order.
    grid = list(dict.fromkeys(round(float(t), 2) for t in candidates))

    thresholds = {}
    for i, label in enumerate(LABELS):
        y_true = targets[:, i]
        scores = probs[:, i]
        best_t, best_f1 = 0.5, 0.0
        for t in grid:
            preds = (scores >= t).astype(int)
            f1 = f1_score(y_true, preds, zero_division=0)
            # Strict '>' keeps the earliest candidate on ties.
            if f1 > best_f1:
                best_f1, best_t = f1, t
        thresholds[label] = best_t
    return thresholds
|
|
|
|
def load_thresholds(fallback: float = 0.5) -> dict[str, float]:
    """
    Load the saved per-label thresholds from THRESHOLDS_PATH.

    Every label in LABELS is guaranteed an entry in the result: values found
    in the file win, and any label the file does not mention (or all labels,
    when the file does not exist) gets `fallback`. Keys in the file that are
    not in LABELS are passed through unchanged.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    thresholds = {label: fallback for label in LABELS}
    try:
        # Open directly instead of checking exists() first: one filesystem
        # access, and no window in which the file can vanish in between.
        with open(THRESHOLDS_PATH, encoding="utf-8") as f:
            saved = json.load(f)
    except FileNotFoundError:
        return thresholds
    thresholds.update(saved)
    return thresholds
|
|
|
|
| |
| |
| |
|
|
def compute_metrics(probs: np.ndarray, targets: np.ndarray,
                    thresholds: dict[str, float]) -> pd.DataFrame:
    """
    Per-label precision, recall, F1, AP using per-label thresholds.

    Args:
        probs: predicted probabilities; column i belongs to LABELS[i].
        targets: binary ground truth with the same column layout as `probs`.
        thresholds: {label: threshold}; labels without an entry use 0.5.

    Returns a DataFrame sorted by F1 ascending (worst labels first) with the
    columns label, threshold, precision, recall, f1, ap, n_positive. AP is
    reported as 0.0 for a label with no positive samples.

    Side effect: logs Micro-F1 and Macro-F1. Both are computed from the same
    thresholded predictions as the per-label rows, so the aggregates agree
    with the table instead of silently falling back to a global 0.5 cut-off.
    """
    # Start from the 0.5 cut-off so any column without a LABELS entry still
    # has a prediction, then overwrite each labelled column below.
    all_preds = (probs >= 0.5).astype(int)

    rows = []
    for i, label in enumerate(LABELS):
        t = thresholds.get(label, 0.5)
        y_true = targets[:, i]
        y_pred = (probs[:, i] >= t).astype(int)
        all_preds[:, i] = y_pred

        n_positive = int(y_true.sum())
        # average_precision_score is not meaningful without positives.
        ap = average_precision_score(y_true, probs[:, i]) if n_positive > 0 else 0.0
        rows.append({
            "label": label,
            "threshold": t,
            "precision": round(precision_score(y_true, y_pred, zero_division=0), 4),
            "recall": round(recall_score(y_true, y_pred, zero_division=0), 4),
            "f1": round(f1_score(y_true, y_pred, zero_division=0), 4),
            "ap": round(ap, 4),
            "n_positive": n_positive,
        })

    df = pd.DataFrame(rows).sort_values("f1")
    micro_f1 = f1_score(targets, all_preds, average="micro", zero_division=0)
    macro_f1 = f1_score(targets, all_preds, average="macro", zero_division=0)
    log.info("Micro-F1: %.4f | Macro-F1: %.4f", micro_f1, macro_f1)
    return df
|
|
|
|
| |
| |
| |
|
|
def _load_thumb(path: str, size: int = 160) -> np.ndarray:
    """
    Read the image at `path` and return it as a square RGB thumbnail array.

    The image is converted to RGB and resized to exactly (size, size); the
    aspect ratio is not preserved.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been copied into the converted image.
    with Image.open(path) as img:
        thumb = img.convert("RGB").resize((size, size))
    return np.array(thumb)
|
|
|
|
def save_confusion_grid(image_paths: list[str], title: str, out_path: Path,
                        grid: int = 3) -> None:
    """
    Save a grid x grid mosaic of images to out_path as PNG.

    At most grid*grid images are drawn, in the order given; unused cells are
    left blank. Nothing is written when `image_paths` is empty. Parent
    directories of `out_path` are created as needed.
    """
    n = min(grid * grid, len(image_paths))
    if n == 0:
        return

    # squeeze=False always yields a 2-D array of Axes, so `.flat` also works
    # for grid=1 (where subplots would otherwise return a bare Axes object).
    fig, axes = plt.subplots(grid, grid, figsize=(grid * 2.5, grid * 2.5),
                             squeeze=False)
    try:
        fig.suptitle(title, fontsize=10, y=1.01)
        for idx, ax in enumerate(axes.flat):
            ax.axis("off")
            if idx < n:
                ax.imshow(_load_thumb(image_paths[idx]))
        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=100, bbox_inches="tight")
    finally:
        # Close even if an image fails to load or the save fails, so figures
        # do not accumulate in pyplot's registry.
        plt.close(fig)
    log.info("Saved confusion grid: %s", out_path)
|
|
|
|
def save_confusion_grids(probs: np.ndarray, targets: np.ndarray,
                         thresholds: dict[str, float], split: str,
                         n_worst: int = 3) -> None:
    """
    Write false-positive and false-negative mosaics for the weakest labels.

    The `n_worst` labels are the first rows of the table returned by
    compute_metrics (sorted by F1, lowest first). For each of them, up to nine
    images per error type are saved under experiments/confusion_grids/,
    ordered so the most confident mistakes come first: highest score for false
    positives, lowest score for false negatives.

    Image paths are read from the "image_path" column of the dataset built
    for `split`; row j of `probs`/`targets` is paired with the j-th path.
    """
    ranked = compute_metrics(probs, targets, thresholds)
    worst_labels = ranked.head(n_worst)["label"].tolist()

    image_paths = BDDMultiLabelDataset(split).df["image_path"].tolist()

    for label in worst_labels:
        col = LABELS.index(label)
        cutoff = thresholds.get(label, 0.5)
        predicted = (probs[:, col] >= cutoff).astype(int)
        actual = targets[:, col]

        false_pos = np.flatnonzero((predicted == 1) & (actual == 0))
        false_neg = np.flatnonzero((predicted == 0) & (actual == 1))

        # Most confident errors first: descending score for FPs, ascending for FNs.
        false_pos = false_pos[np.argsort(probs[false_pos, col])[::-1]]
        false_neg = false_neg[np.argsort(probs[false_neg, col])]

        save_confusion_grid(
            [image_paths[j] for j in false_pos[:9]],
            f"False Positives — {label} (predicted {label}, actually not)",
            CONFUSION_DIR / f"{label}_false_positives.png",
        )
        save_confusion_grid(
            [image_paths[j] for j in false_neg[:9]],
            f"False Negatives — {label} (missed {label}, actually present)",
            CONFUSION_DIR / f"{label}_false_negatives.png",
        )
|
|
|
|
| |
| |
| |
|
|
def evaluate(checkpoint: str, split: str = "test", tune: bool = False) -> pd.DataFrame:
    """
    Evaluate a checkpoint on `split` and write all evaluation artefacts.

    Steps:
      1. Load the state dict from `checkpoint` into a freshly built model.
      2. Run inference on `split`.
      3. Obtain per-label thresholds: tune them on the val split (and save
         them to THRESHOLDS_PATH) when `tune` is set or no thresholds file
         exists yet; otherwise load the saved ones.
      4. Print the per-label metrics table and save it to
         experiments/metrics_<split>.csv.
      5. Save confusion image grids for the 3 labels with the worst F1.

    Returns the per-label metrics DataFrame.
    """
    set_seed(SEED)
    device = get_device()

    model = build_model().to(device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source; consider passing
    # weights_only=True where the installed torch version supports it.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    log.info("Loaded checkpoint: %s", checkpoint)

    probs, targets = run_inference(model, split, device)

    if tune or not THRESHOLDS_PATH.exists():
        log.info("Tuning per-label thresholds on val split...")
        if split == "val":
            # The val outputs were just computed above; a second pass over
            # the same data would only repeat the work.
            val_probs, val_targets = probs, targets
        else:
            val_probs, val_targets = run_inference(model, "val", device)
        thresholds = tune_thresholds(val_probs, val_targets)
        THRESHOLDS_PATH.parent.mkdir(parents=True, exist_ok=True)
        with open(THRESHOLDS_PATH, "w", encoding="utf-8") as f:
            json.dump(thresholds, f, indent=2)
        log.info("Saved thresholds to %s", THRESHOLDS_PATH)
    else:
        thresholds = load_thresholds()

    metrics_df = compute_metrics(probs, targets, thresholds)
    print("\n" + metrics_df.to_string(index=False))

    out_csv = Path("experiments") / f"metrics_{split}.csv"
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    metrics_df.to_csv(out_csv, index=False)
    log.info("Saved metrics to %s", out_csv)

    save_confusion_grids(probs, targets, thresholds, split, n_worst=3)

    return metrics_df
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the flags and hand them to evaluate().
    cli = argparse.ArgumentParser(description="Evaluate multi-label road scene model")
    cli.add_argument("--checkpoint", required=True, help="Path to .pt checkpoint file")
    cli.add_argument("--split", choices=["train", "val", "test"], default="test")
    cli.add_argument(
        "--tune-thresholds",
        action="store_true",
        help="Re-run threshold tuning on val split even if thresholds.json exists",
    )
    opts = cli.parse_args()
    evaluate(checkpoint=opts.checkpoint, split=opts.split, tune=opts.tune_thresholds)
|
|