# src/evaluation/eval_accuracy.py
import argparse

import numpy as np
import torch
from sklearn.metrics import accuracy_score, classification_report
from torchvision.datasets import OxfordIIITPet
from tqdm import tqdm


def load_test_dataset(data_root: str):
    """
    Load the Oxford-IIIT Pet test split without transforms, so we get PIL images.
    Targets will be integer class indices (0..36).
    """
    dataset = OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,  # we want raw PIL here
    )
    return dataset
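
# Minimal usage sketch (assumes the dataset has already been downloaded to
# `data_root`; otherwise pass download=True to OxfordIIITPet above):
#   ds = load_test_dataset("data/oxford-iiit-pet")
#   img, target = ds[0]  # img: PIL.Image.Image, target: int in [0, 36]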


def load_model_direct(model_id: str):
    """
    Workaround loader that bypasses the registry and constructs models
    using their actual constructor signatures.
    Modify only the paths here if needed.
    """
    if model_id == "lr_raw":
        from src.inference.lr_model import LRModel
        # Adjust to match your actual LRModel __init__
        return LRModel("checkpoints/lr_model.joblib", "configs/labels.json")
    elif model_id == "svm_raw":
        from src.inference.svm_model import SVMModel
        return SVMModel("checkpoints/svm_model.joblib", "configs/labels.json")
    elif model_id == "resnet_pt_lr":
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        # Match the constructor signature of your working setup
        # (e.g., whether a device argument is required).
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
        )
    elif model_id == "resnet_pt_svm":
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
        )
    else:
        raise ValueError(f"Unsupported model_id: {model_id}")
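
# The evaluation loop below assumes each model object exposes roughly this
# interface (an assumption inferred from how predict() is called, not a
# documented contract; adjust to your actual model classes):
#   predict(img: PIL.Image.Image, top_k: int = 5) -> {
#       "class_id": int,
#       "class_name": str,
#       "top_k": [  # optional; absent for models without ranked outputs
#           {"class_id": int, "class_name": str, "probability": float},
#           ...
#       ],
#   }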


def evaluate_model_on_dataset(model_id: str, data_root: str):
    """
    Evaluate a single model (by model_id, via load_model_direct above)
    on the Oxford-IIIT Pet test split.
    Uses the model.predict(PIL.Image, top_k=5) API.

    Returns a dict with:
      - top1_acc
      - top5_acc
      - report (per-class and aggregate metrics)
    """
    print(f"\n=== Evaluating model: {model_id} ===")
    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred_top1 = []
    top5_correct = 0

    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int

        # Try to call with top_k; if the model doesn't support it, fall back gracefully.
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            # Older / simpler API: predict(img) without top_k
            result = model.predict(img)

        # Top-1 prediction (must exist)
        pred_id = int(result["class_id"])
        y_true.append(int(target))
        y_pred_top1.append(pred_id)

        # Try to get the top_k list; if absent, synthesize one from the top-1 prediction.
        top_k = result.get("top_k")
        if not top_k:
            # Fallback: treat the top-1 prediction as the only candidate.
            # Top-5 then equals Top-1 for such models, which is acceptable as a workaround.
            cname = result.get("class_name", "")
            top_k = [{
                "class_id": pred_id,
                "class_name": cname,
                "probability": 1.0,
            }]

        # Top-5 correct? (ground truth appears anywhere in the top_k list)
        if any(int(entry.get("class_id")) == int(target) for entry in top_k):
            top5_correct += 1

    y_true = np.array(y_true)
    y_pred_top1 = np.array(y_pred_top1)
    n = len(y_true)

    # Overall Top-1 accuracy
    top1_acc = accuracy_score(y_true, y_pred_top1)
    # Overall Top-5 accuracy
    top5_acc = top5_correct / float(n)

    # Detailed precision/recall/F1 per class, plus aggregates
    report = classification_report(
        y_true,
        y_pred_top1,
        digits=4,
        output_dict=True,  # returns a dict we can log/inspect
    )

    print(f"Top-1 accuracy ({model_id}): {top1_acc:.4f}")
    print(f"Top-5 accuracy ({model_id}): {top5_acc:.4f}")
    print("\nMacro avg (from classification_report):")
    print(report["macro avg"])
    print("\nWeighted avg (from classification_report):")
    print(report["weighted avg"])

    return {
        "model_id": model_id,
        "top1_acc": top1_acc,
        "top5_acc": top5_acc,
        "report": report,
    }
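

# Optional, illustrative sketch (not part of the original pipeline): persist a
# result dict from evaluate_model_on_dataset to JSON so the per-class report
# can be inspected later. The "results" output directory is an assumed location.
def save_report(result: dict, out_dir: str = "results") -> None:
    import json
    import os

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{result['model_id']}_report.json")
    with open(path, "w") as f:
        # default=float coerces NumPy scalars (e.g. accuracy values) to plain floats
        json.dump(result, f, indent=2, default=float)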


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of the Oxford-IIIT Pet dataset.",
    )
    args = parser.parse_args()

    # List all models you want to evaluate
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    all_results = []
    for mid in model_ids:
        res = evaluate_model_on_dataset(mid, args.data_root)
        all_results.append(res)

    # Print a compact summary table at the end
    print("\n===== Summary (Top-1 & Top-5) =====")
    print(f"{'Model':25s} {'Top-1':>8s} {'Top-5':>8s}")
    print("-" * 50)
    for res in all_results:
        name = res["model_id"]
        t1 = res["top1_acc"]
        t5 = res["top5_acc"]
        print(f"{name:25s} {t1:8.4f} {t5:8.4f}")


if __name__ == "__main__":
    # Make sure torch doesn't spawn too many threads on some systems
    torch.set_num_threads(4)
    main()
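
# Example invocation (assuming the package layout implied by the imports above):
#   python -m src.evaluation.eval_accuracy --data-root data/oxford-iiit-pet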