# src/evaluation/eval_accuracy.py
import argparse

import numpy as np
import torch
from sklearn.metrics import accuracy_score, classification_report
from torchvision.datasets import OxfordIIITPet
from tqdm import tqdm


def load_test_dataset(data_root: str):
    """
    Load the Oxford-IIIT Pet test split without transforms, so we get PIL images.
    Targets are integer class indices (0..36).
    """
    dataset = OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,  # we want raw PIL images here
    )
    return dataset


def load_model_direct(model_id: str):
    """
    Workaround loader that bypasses the registry and constructs models using
    their actual constructor signatures. Adjust only the paths here if needed.
    """
    if model_id == "lr_raw":
        from src.inference.lr_model import LRModel

        # Adjust to match your actual LRModel __init__
        return LRModel("checkpoints/lr_model.joblib", "configs/labels.json")
    elif model_id == "svm_raw":
        from src.inference.svm_model import SVMModel

        return SVMModel("checkpoints/svm_model.joblib", "configs/labels.json")
    elif model_id == "resnet_pt_lr":
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel

        # Match your working constructor (with or without a device argument)
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
        )
    elif model_id == "resnet_pt_svm":
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel

        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
        )
    else:
        raise ValueError(f"Unsupported model_id: {model_id}")


def evaluate_model_on_dataset(model_id: str, data_root: str):
    """
    Evaluate a single model (by id, via load_model_direct) on the Oxford-IIIT
    Pet test split. Uses the model.predict(PIL.Image, top_k=5) API.

    Returns a dict with:
      - model_id
      - top1_acc
      - top5_acc
      - report (per-class and aggregate metrics)
    """
    print(f"\n=== Evaluating model: {model_id} ===")

    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred_top1 = []
    top5_correct = 0

    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int

        # Try to call with top_k; if the model doesn't support it, fall back gracefully
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            # Older / simpler API: predict(img) without top_k
            result = model.predict(img)

        # Top-1 prediction (must exist)
        pred_id = int(result.get("class_id"))
        y_true.append(int(target))
        y_pred_top1.append(pred_id)

        # Try to get the top_k list; if absent, synthesize one from the top-1 prediction
        top_k = result.get("top_k")
        if not top_k:
            # Fallback: treat the top-1 prediction as the only candidate.
            # This means Top-5 == Top-1 for such models, which is acceptable as a workaround.
            cname = result.get("class_name", "")
            top_k = [{"class_id": pred_id, "class_name": cname, "probability": 1.0}]

        # Top-5 correct? (ground truth appears in the top_k list)
        if any(int(entry.get("class_id")) == int(target) for entry in top_k):
            top5_correct += 1

    y_true = np.array(y_true)
    y_pred_top1 = np.array(y_pred_top1)
    n = len(y_true)

    # Overall Top-1 accuracy
    top1_acc = accuracy_score(y_true, y_pred_top1)

    # Overall Top-5 accuracy
    top5_acc = top5_correct / float(n)

    # Detailed precision/recall/F1 per class + aggregates
    report = classification_report(
        y_true,
        y_pred_top1,
        digits=4,
        output_dict=True,  # gives a dict we can log/inspect
        zero_division=0,  # avoid warnings for classes that are never predicted
    )

    print(f"Top-1 accuracy ({model_id}): {top1_acc:.4f}")
    print(f"Top-5 accuracy ({model_id}): {top5_acc:.4f}")
    print("\nMacro avg (from classification_report):")
    print(report["macro avg"])
    print("\nWeighted avg (from classification_report):")
    print(report["weighted avg"])

    return {
        "model_id": model_id,
        "top1_acc": top1_acc,
        "top5_acc": top5_acc,
        "report": report,
    }


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of the Oxford-IIIT Pet dataset.",
    )
    args = parser.parse_args()

    # All models to evaluate
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    all_results = []
    for mid in model_ids:
        res = evaluate_model_on_dataset(mid, args.data_root)
        all_results.append(res)

    # Print a compact summary table at the end
    print("\n===== Summary (Top-1 & Top-5) =====")
    print(f"{'Model':25s} {'Top-1':>8s} {'Top-5':>8s}")
    print("-" * 50)
    for res in all_results:
        name = res["model_id"]
        t1 = res["top1_acc"]
        t5 = res["top5_acc"]
        print(f"{name:25s} {t1:8.4f} {t5:8.4f}")


if __name__ == "__main__":
    # Keep torch from spawning too many threads on some systems
    torch.set_num_threads(4)
    main()
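

# Example invocation (a sketch; assumes the Oxford-IIIT Pet test split has already
# been fetched under --data-root, e.g. once via
# OxfordIIITPet(root="data/oxford-iiit-pet", split="test", download=True)):
#
#   python -m src.evaluation.eval_accuracy --data-root data/oxford-iiit-pet
#
# Each model's predict() is assumed (not verified here) to return a dict shaped
# like the following, where the values shown are purely illustrative:
#
#   {
#       "class_id": 12,
#       "class_name": "<breed name>",
#       "top_k": [
#           {"class_id": 12, "class_name": "<breed name>", "probability": 0.83},
#           ...
#       ],
#   }
#
# "top_k" is optional; models that omit it fall back to the synthetic
# single-entry list built in evaluate_model_on_dataset, so their Top-5
# accuracy equals their Top-1 accuracy.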