Spaces:
Sleeping
Sleeping
| # src/evaluation/eval_confusion.py | |
| import argparse | |
| from pathlib import Path | |
| import json | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import confusion_matrix | |
| from tqdm import tqdm | |
| # Reuse the same dataset + model loading logic as eval_accuracy.py | |
| from src.evaluation.eval_accuracy import load_test_dataset, load_model_direct | |
def load_class_names(labels_path: str = "configs/labels.json"):
    """
    Load class names from labels.json, robust to several likely formats:

    - List: ["Abyssinian", "American Bulldog", ...]
    - Dict with string keys: {"0": "Abyssinian", "1": "American Bulldog", ...}
    - Dict with 'id_to_label': {"id_to_label": {"0": "Abyssinian", ...}}

    Returns the list of names ordered by integer class ID, or None if the
    file is missing/unparseable/unrecognized (callers then fall back to
    numeric class IDs on the plot axes).
    """
    try:
        with open(labels_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"[WARN] labels file not found at {labels_path}, using numeric IDs.")
        return None
    except json.JSONDecodeError:
        print(f"[WARN] Could not parse {labels_path}, using numeric IDs.")
        return None

    # Case 1: simple list, already in index order.
    if isinstance(data, list):
        return data

    # Case 2: dict wrapped in an 'id_to_label' key; sort by integer key.
    if isinstance(data, dict) and "id_to_label" in data:
        id_to_label = data["id_to_label"]
        keys = sorted(id_to_label.keys(), key=lambda k: int(k))
        return [id_to_label[k] for k in keys]

    # Case 3: flat dict mapping "0" -> "Abyssinian".
    if isinstance(data, dict):
        try:
            # Narrowed from `except Exception`: only non-integer keys
            # (ValueError) or non-string keys (TypeError) should fall through.
            keys = sorted(data.keys(), key=lambda k: int(k))
        except (ValueError, TypeError):
            pass
        else:
            return [data[k] for k in keys]

    # Include the offending path in the warning (the original f-string had
    # no placeholder at all).
    print(f"[WARN] Unrecognized labels format in {labels_path}, using numeric IDs.")
    return None
def collect_predictions(model_id: str, data_root: str):
    """
    Evaluate one model over the Oxford-IIIT Pet test split.

    Returns a pair of int numpy arrays in dataset order:
    - y_true: ground-truth class indices
    - y_pred: top-1 predicted class indices

    Uses the same model API as eval_accuracy.py: model.predict(PIL, top_k=5).
    """
    print(f"\n=== Collecting predictions for model: {model_id} ===")
    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    true_ids, pred_ids = [], []
    for sample_idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        image, label = dataset[sample_idx]  # image: PIL.Image, label: int
        # Mirror eval_accuracy: some model wrappers accept top_k, some don't.
        try:
            result = model.predict(image, top_k=5)
        except TypeError:
            result = model.predict(image)
        top1 = int(result.get("class_id"))
        true_ids.append(int(label))
        pred_ids.append(top1)

    y_true, y_pred = np.array(true_ids), np.array(pred_ids)
    print(f" Collected {len(y_true)} predictions.")
    return y_true, y_pred
def plot_confusion_matrix(
    cm: np.ndarray,
    class_names,
    title: str,
    save_path: Path,
    normalize: bool = True,
):
    """
    Render a confusion matrix as a heatmap and save it to save_path.

    With normalize=True, each row (true class) is scaled to sum to 1;
    rows with zero samples stay all-zero rather than producing NaNs.
    When class_names is None or its length doesn't match the matrix,
    numeric indices are used on both axes.
    """
    if normalize:
        cm = cm.astype("float")
        row_totals = cm.sum(axis=1, keepdims=True)
        # `where=` guards against division by zero for empty rows.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)

    n = cm.shape[0]
    ticks = np.arange(n)
    if class_names is not None and len(class_names) == n:
        tick_labels = class_names
    else:
        tick_labels = list(range(n))

    plt.figure(figsize=(12, 10))
    heatmap = plt.imshow(cm, interpolation="nearest", cmap="viridis")
    plt.title(title)
    plt.colorbar(heatmap, fraction=0.046, pad=0.04)
    plt.xticks(ticks=ticks, labels=tick_labels, rotation=90, fontsize=6)
    plt.yticks(ticks=ticks, labels=tick_labels, fontsize=6)
    plt.xlabel("Predicted class")
    plt.ylabel("True class")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f" Saved confusion matrix plot to: {save_path}")
def main():
    """
    CLI entry point: run each model over the test split, then save a raw
    confusion matrix (.npy) and a row-normalized heatmap (.png) per model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels.json (for axis names).",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/confusion_matrices",
        help="Directory to save confusion matrices and plots.",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Same set of models as eval_accuracy
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    class_names = load_class_names(args.labels_path)

    for model_id in model_ids:
        y_true, y_pred = collect_predictions(model_id, args.data_root)

        # BUGFIX: the label range must cover BOTH true and predicted IDs.
        # Deriving it from y_true alone silently drops any prediction whose
        # class index exceeds the largest ground-truth index, because
        # confusion_matrix ignores pairs outside `labels`.
        num_classes = int(max(y_true.max(), y_pred.max())) + 1
        if class_names is not None:
            # Prefer the full label set so every matrix has the same shape
            # (e.g. 37x37) even if a class is absent from this split.
            num_classes = max(num_classes, len(class_names))
        labels = list(range(num_classes))
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        # Save raw counts for future analysis.
        npy_path = out_dir / f"cm_{model_id}.npy"
        np.save(npy_path, cm)
        print(f" Saved raw confusion matrix to: {npy_path}")

        # Save a row-normalized heatmap.
        png_path = out_dir / f"cm_{model_id}.png"
        title = f"Confusion Matrix ({model_id})"
        plot_confusion_matrix(cm, class_names, title, png_path, normalize=True)


if __name__ == "__main__":
    main()