# kishore-9's picture
# Add road scene classifier app
# 9466fff
"""
src/evaluate.py
Three jobs:
1. Per-label metrics table (precision, recall, F1, AP) on a given split.
2. Per-label threshold tuning — find the threshold that maximises F1 for
each label individually on the *val* split, save as thresholds.json.
This replaces the naive global threshold=0.5 used during training.
3. Confusion image grids — for the 3 labels with worst F1, save 3x3 grids
of false positives and false negatives so failures are visually obvious.
Why per-label thresholds?
    A 0.5 cut-off is only a sensible default when a label is roughly balanced
    and precision and recall matter equally. Neither holds here: rare labels
    such as "foggy" or "tunnel" tend to receive low predicted probabilities,
    so the threshold that maximises their F1 is usually below 0.5.
Usage:
python -m src.evaluate --checkpoint experiments/checkpoints/baseline_best.pt
python -m src.evaluate --checkpoint <path> --split val --tune-thresholds
"""
import argparse
import json
import logging
from pathlib import Path
import matplotlib
matplotlib.use("Agg") # no GUI needed; must be set before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.config import DATA_PROCESSED, LABELS, NUM_LABELS, SEED
from src.dataset import BDDMultiLabelDataset, get_transforms
from src.model import build_model
from src.utils import get_device, set_seed
# Configure logging at import time: INFO level, "<LEVEL> <message>" lines.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
# Module-level logger used by every function in this file.
log = logging.getLogger(__name__)
# Directory that receives the false-positive / false-negative image grids.
CONFUSION_DIR = Path("experiments/confusion_grids")
# JSON file where tuned per-label thresholds are written and later re-read.
THRESHOLDS_PATH = DATA_PROCESSED / "thresholds.json"
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
@torch.no_grad()
def run_inference(model: torch.nn.Module, split: str, device: torch.device,
                  batch_size: int = 64) -> tuple[np.ndarray, np.ndarray]:
    """
    Score every sample of ``split`` with ``model``, in dataset order.

    The model is switched to eval mode and gradients are disabled for the
    whole call. Batches are read without shuffling, so row ``k`` of both
    returned arrays corresponds to the ``k``-th item of the dataset.

    Args:
        model: network whose outputs are treated as logits.
        split: split name handed to ``BDDMultiLabelDataset``.
        device: device the image batches are moved to before the forward pass.
        batch_size: number of samples per forward pass.

    Returns:
        ``(probs, targets)`` where ``probs`` is a float32 array of sigmoid
        outputs and ``targets`` is the dataset's label tensor cast to int,
        both stacked over all batches.
    """
    dataset = BDDMultiLabelDataset(split)
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    prob_chunks: list[np.ndarray] = []
    target_chunks: list[np.ndarray] = []

    model.eval()
    for images, batch_targets in tqdm(batches, desc=f" inference [{split}]", leave=False):
        batch_logits = model(images.to(device))
        # Move to CPU before converting: numpy cannot view device tensors.
        prob_chunks.append(torch.sigmoid(batch_logits).cpu().numpy())
        target_chunks.append(batch_targets.numpy())

    stacked_probs = np.vstack(prob_chunks).astype(np.float32)
    stacked_targets = np.vstack(target_chunks).astype(int)
    return stacked_probs, stacked_targets
# ---------------------------------------------------------------------------
# Threshold tuning
# ---------------------------------------------------------------------------
def tune_thresholds(probs: np.ndarray, targets: np.ndarray,
                    candidates: np.ndarray = None) -> dict[str, float]:
    """
    Pick, independently for each label, the candidate cut-off with the best F1.

    Args:
        probs: per-sample scores, one column per entry of ``LABELS``.
        targets: ground-truth array with the same column layout as ``probs``.
        candidates: thresholds to try; when ``None`` the sweep is
            ``np.arange(0.1, 0.91, 0.05)``.

    Returns:
        ``{label_name: threshold}`` with each threshold rounded to two
        decimals. A label keeps the default 0.5 when no candidate reaches an
        F1 above zero. On ties the earliest candidate in the sweep wins.
    """
    sweep = np.arange(0.1, 0.91, 0.05) if candidates is None else candidates

    tuned: dict[str, float] = {}
    for col, name in enumerate(LABELS):
        scores = probs[:, col]
        truth = targets[:, col]
        chosen, top_f1 = 0.5, 0.0
        for cutoff in sweep:
            candidate_f1 = f1_score(truth, (scores >= cutoff).astype(int), zero_division=0)
            # Strict comparison: only a genuinely better F1 replaces the incumbent.
            if candidate_f1 > top_f1:
                top_f1, chosen = candidate_f1, float(cutoff)
        tuned[name] = round(chosen, 2)
    return tuned
def load_thresholds(fallback: float = 0.5) -> dict[str, float]:
    """
    Return the thresholds stored at ``THRESHOLDS_PATH``.

    When that file does not exist, every label in ``LABELS`` is mapped to
    ``fallback`` instead. The file content is returned as parsed, without
    checking that it covers every label.
    """
    if not THRESHOLDS_PATH.exists():
        return dict.fromkeys(LABELS, fallback)
    with open(THRESHOLDS_PATH) as handle:
        return json.load(handle)
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
def compute_metrics(probs: np.ndarray, targets: np.ndarray,
                    thresholds: dict[str, float]) -> pd.DataFrame:
    """
    Build the per-label metrics table and log aggregate micro/macro F1.

    Each label's scores are binarised with its own threshold (0.5 when the
    label is missing from ``thresholds``). The logged micro/macro F1 use the
    same per-label decisions as the table, so the summary line and the rows
    describe the same predictions.

    Args:
        probs: per-sample scores, one column per entry of ``LABELS``.
        targets: ground-truth array with the same column layout as ``probs``.
        thresholds: ``{label_name: threshold}``.

    Returns:
        DataFrame with columns ``label``, ``threshold``, ``precision``,
        ``recall``, ``f1``, ``ap`` and ``n_positive``, sorted by F1 ascending
        (worst labels first). Metric values are rounded to four decimals.
    """
    # Start from the 0.5 default for every column, then overwrite each known
    # label's column with its thresholded decisions.
    all_preds = (probs >= 0.5).astype(int)

    rows = []
    for i, label in enumerate(LABELS):
        t = thresholds.get(label, 0.5)
        y_true = targets[:, i]
        y_score = probs[:, i]
        preds = (y_score >= t).astype(int)
        all_preds[:, i] = preds

        n_positive = int(y_true.sum())
        # Average precision is not computed for a label with no positives;
        # 0.0 is reported in that case.
        ap = average_precision_score(y_true, y_score) if n_positive > 0 else 0.0
        rows.append({
            "label": label,
            "threshold": t,
            "precision": round(precision_score(y_true, preds, zero_division=0), 4),
            "recall": round(recall_score(y_true, preds, zero_division=0), 4),
            "f1": round(f1_score(y_true, preds, zero_division=0), 4),
            "ap": round(ap, 4),
            "n_positive": n_positive,
        })
    df = pd.DataFrame(rows).sort_values("f1")

    micro_f1 = f1_score(targets, all_preds, average="micro", zero_division=0)
    macro_f1 = f1_score(targets, all_preds, average="macro", zero_division=0)
    log.info("Micro-F1: %.4f | Macro-F1: %.4f", micro_f1, macro_f1)
    return df
# ---------------------------------------------------------------------------
# Confusion image grids
# ---------------------------------------------------------------------------
def _load_thumb(path: str, size: int = 160) -> np.ndarray:
    """
    Load an image as a ``size`` x ``size`` RGB array.

    The image is resized directly to a square, so the aspect ratio is not
    preserved.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied out by convert(), rather than relying on Pillow's
    # lazy-load cleanup.
    with Image.open(path) as img:
        thumb = img.convert("RGB").resize((size, size))
    return np.array(thumb)
def save_confusion_grid(image_paths: list[str], title: str, out_path: Path,
                        grid: int = 3) -> None:
    """
    Save a ``grid`` x ``grid`` mosaic of thumbnails to ``out_path`` as PNG.

    At most ``grid * grid`` images are drawn, taken from the front of
    ``image_paths``; unused cells are left blank. Nothing is written when
    there are no images to show. Missing parent directories of ``out_path``
    are created.

    Args:
        image_paths: image files to show, in display order.
        title: figure title.
        out_path: destination file.
        grid: number of rows and columns.
    """
    n = min(grid * grid, len(image_paths))
    if n == 0:
        return

    # squeeze=False keeps `axes` a 2-D array even for grid=1, where
    # plt.subplots would otherwise return a bare Axes without `.flat`.
    fig, axes = plt.subplots(grid, grid, figsize=(grid * 2.5, grid * 2.5),
                             squeeze=False)
    try:
        fig.suptitle(title, fontsize=10, y=1.01)
        for idx, ax in enumerate(axes.flat):
            ax.axis("off")
            if idx < n:
                ax.imshow(_load_thumb(image_paths[idx]))
        # Operate on `fig` explicitly instead of pyplot's implicit current figure.
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, including when loading or saving fails.
        plt.close(fig)
    log.info("Saved confusion grid: %s", out_path)
def save_confusion_grids(probs: np.ndarray, targets: np.ndarray,
                         thresholds: dict[str, float], split: str,
                         n_worst: int = 3) -> None:
    """
    Write false-positive and false-negative image grids for the weakest labels.

    The ``n_worst`` labels with the lowest F1 (as ranked by
    ``compute_metrics``) each get two PNGs under ``CONFUSION_DIR``. Each grid
    shows up to nine images, most confident mistakes first: highest scores
    for false positives, lowest scores for false negatives.

    Row ``k`` of ``probs`` / ``targets`` is matched to the ``k``-th entry of the
    dataset's ``image_path`` column, so both must come from the same split in
    the same order.
    """
    ranked = compute_metrics(probs, targets, thresholds)
    weakest = ranked.head(n_worst)["label"].tolist()

    dataset = BDDMultiLabelDataset(split)
    all_paths = dataset.df["image_path"].tolist()

    for name in weakest:
        col = LABELS.index(name)
        cutoff = thresholds.get(name, 0.5)
        predicted = (probs[:, col] >= cutoff).astype(int)
        actual = targets[:, col]

        false_pos = np.where((predicted == 1) & (actual == 0))[0]
        false_neg = np.where((predicted == 0) & (actual == 1))[0]

        # Order by how wrong the model was: descending score for false
        # positives, ascending score for false negatives.
        false_pos = false_pos[np.argsort(probs[false_pos, col])[::-1]]
        false_neg = false_neg[np.argsort(probs[false_neg, col])]

        fp_paths = [all_paths[j] for j in false_pos[:9]]
        fn_paths = [all_paths[j] for j in false_neg[:9]]

        jobs = (
            (
                fp_paths,
                f"False Positives — {name} (predicted {name}, actually not)",
                CONFUSION_DIR / f"{name}_false_positives.png",
            ),
            (
                fn_paths,
                f"False Negatives — {name} (missed {name}, actually present)",
                CONFUSION_DIR / f"{name}_false_negatives.png",
            ),
        )
        for paths, heading, destination in jobs:
            save_confusion_grid(paths, heading, destination)
# ---------------------------------------------------------------------------
# Full evaluation pipeline
# ---------------------------------------------------------------------------
def evaluate(checkpoint: str, split: str = "test", tune: bool = False) -> pd.DataFrame:
    """
    Run the full evaluation pipeline for one checkpoint.

    Steps: load the weights, run inference on ``split``, obtain per-label
    thresholds (tuned on the val split, or loaded from ``THRESHOLDS_PATH``),
    compute the metrics table, and save confusion grids for the three
    worst labels.

    Side effects: may write ``THRESHOLDS_PATH``; writes
    ``experiments/metrics_<split>.csv`` and the confusion-grid PNGs; prints
    the metrics table to stdout.

    Args:
        checkpoint: path to a file holding the model's state dict.
        split: dataset split to report metrics on.
        tune: re-tune thresholds on the val split even when a thresholds
            file already exists.

    Returns:
        The per-label metrics DataFrame produced by ``compute_metrics``.
    """
    set_seed(SEED)
    device = get_device()
    model = build_model().to(device)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # passing weights_only=True where the installed torch version supports it.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    log.info("Loaded checkpoint: %s", checkpoint)

    # --- inference ---
    probs, targets = run_inference(model, split, device)

    # --- thresholds ---
    if tune or not THRESHOLDS_PATH.exists():
        log.info("Tuning per-label thresholds on val split...")
        if split == "val":
            # The val predictions were just computed above; reuse them rather
            # than paying for a second full inference pass. Metrics reported
            # on val are then measured on the data the thresholds were tuned
            # on, so they are optimistic.
            val_probs, val_targets = probs, targets
        else:
            val_probs, val_targets = run_inference(model, "val", device)
        thresholds = tune_thresholds(val_probs, val_targets)
        THRESHOLDS_PATH.parent.mkdir(parents=True, exist_ok=True)
        with open(THRESHOLDS_PATH, "w") as f:
            json.dump(thresholds, f, indent=2)
        log.info("Saved thresholds to %s", THRESHOLDS_PATH)
    else:
        thresholds = load_thresholds()

    # --- metrics ---
    metrics_df = compute_metrics(probs, targets, thresholds)
    print("\n" + metrics_df.to_string(index=False))
    out_csv = Path("experiments") / f"metrics_{split}.csv"
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    metrics_df.to_csv(out_csv, index=False)
    log.info("Saved metrics to %s", out_csv)

    # --- confusion grids for 3 worst labels ---
    save_confusion_grids(probs, targets, thresholds, split, n_worst=3)
    return metrics_df
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: parse the three options and run the pipeline.
    cli = argparse.ArgumentParser(description="Evaluate multi-label road scene model")
    cli.add_argument("--checkpoint", help="Path to .pt checkpoint file", required=True)
    cli.add_argument("--split", choices=["train", "val", "test"], default="test")
    cli.add_argument(
        "--tune-thresholds",
        help="Re-run threshold tuning on val split even if thresholds.json exists",
        action="store_true",
    )
    opts = cli.parse_args()
    evaluate(opts.checkpoint, opts.split, tune=opts.tune_thresholds)