# src/evaluation/eval_confusion.py
import argparse
from pathlib import Path
import json

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# Reuse the same dataset + model loading logic as eval_accuracy.py
from src.evaluation.eval_accuracy import load_test_dataset, load_model_direct


def load_class_names(labels_path: str = "configs/labels.json"):
    """
    Try to load class names from labels.json.

    This is written to be robust to a few likely formats:
      - List: ["Abyssinian", "American Bulldog", ...]
      - Dict with string keys: {"0": "Abyssinian", "1": "American Bulldog", ...}
      - Dict with 'id_to_label': {"id_to_label": {"0": "Abyssinian", ...}}

    If anything goes wrong, returns None and we'll just use numeric class IDs
    on the axes.
    """
    try:
        with open(labels_path, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"[WARN] labels file not found at {labels_path}, using numeric IDs.")
        return None
    except json.JSONDecodeError:
        print(f"[WARN] Could not parse {labels_path}, using numeric IDs.")
        return None

    # Case 1: simple list
    if isinstance(data, list):
        return data

    # Case 2: dict with 'id_to_label'
    if isinstance(data, dict) and "id_to_label" in data:
        id_to_label = data["id_to_label"]
        # sort by integer key
        keys = sorted(id_to_label.keys(), key=lambda k: int(k))
        return [id_to_label[k] for k in keys]

    # Case 3: dict mapping "0" -> "Abyssinian"
    if isinstance(data, dict):
        try:
            keys = sorted(data.keys(), key=lambda k: int(k))
            return [data[k] for k in keys]
        except Exception:
            pass

    print(f"[WARN] Unrecognized labels.json format at {labels_path}, using numeric IDs.")
    return None


def collect_predictions(model_id: str, data_root: str):
    """
    Run the given model across the Oxford-IIIT Pet test split and collect:
      - y_true: ground-truth integer class indices
      - y_pred: top-1 predicted class indices

    Uses the same model API as eval_accuracy.py: model.predict(PIL, top_k=5)
    """
    print(f"\n=== Collecting predictions for model: {model_id} ===")
    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred = []

    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int

        # Same predict logic as eval_accuracy (support with/without top_k)
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            result = model.predict(img)

        pred_id = int(result.get("class_id"))
        y_true.append(int(target))
        y_pred.append(pred_id)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    print(f" Collected {len(y_true)} predictions.")
    return y_true, y_pred


def plot_confusion_matrix(
    cm: np.ndarray,
    class_names,
    title: str,
    save_path: Path,
    normalize: bool = True,
):
    """
    Plot and save a confusion matrix.

    If normalize=True, each row (true class) is normalized to sum to 1.
    If class_names is None, we just use numeric indices on the axes.
""" if normalize: cm = cm.astype("float") row_sums = cm.sum(axis=1, keepdims=True) cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0) num_classes = cm.shape[0] plt.figure(figsize=(12, 10)) im = plt.imshow(cm, interpolation="nearest", cmap="viridis") plt.title(title) plt.colorbar(im, fraction=0.046, pad=0.04) if class_names is not None and len(class_names) == num_classes: tick_labels = class_names else: tick_labels = list(range(num_classes)) plt.xticks( ticks=np.arange(num_classes), labels=tick_labels, rotation=90, fontsize=6, ) plt.yticks( ticks=np.arange(num_classes), labels=tick_labels, fontsize=6, ) plt.xlabel("Predicted class") plt.ylabel("True class") plt.tight_layout() plt.savefig(save_path, dpi=300) plt.close() print(f" Saved confusion matrix plot to: {save_path}") def main(): parser = argparse.ArgumentParser() parser.add_argument( "--data-root", type=str, default="data/oxford-iiit-pet", help="Root directory of Oxford-IIIT Pet dataset.", ) parser.add_argument( "--labels-path", type=str, default="configs/labels.json", help="Path to labels.json (for axis names).", ) parser.add_argument( "--out-dir", type=str, default="outputs/confusion_matrices", help="Directory to save confusion matrices and plots.", ) args = parser.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) # Same set of models as eval_accuracy model_ids = [ "lr_raw", "svm_raw", "resnet_pt_lr", "resnet_pt_svm", ] class_names = load_class_names(args.labels_path) # y_true is identical for all models (same test split, same indexing), # but for clarity we recompute per model; confusion_matrix only needs # consistent labels (0..36) which we enforce below. for model_id in model_ids: y_true, y_pred = collect_predictions(model_id, args.data_root) # Define a fixed label ordering (0..max) to get 37x37 num_classes = int(y_true.max()) + 1 labels = list(range(num_classes)) cm = confusion_matrix(y_true, y_pred, labels=labels) # Save raw matrix for future analysis npy_path = out_dir / f"cm_{model_id}.npy" np.save(npy_path, cm) print(f" Saved raw confusion matrix to: {npy_path}") # Save a normalized plot png_path = out_dir / f"cm_{model_id}.png" title = f"Confusion Matrix ({model_id})" plot_confusion_matrix(cm, class_names, title, png_path, normalize=True) if __name__ == "__main__": main()