# src/evaluation/eval_accuracy.py

import argparse

import numpy as np
import torch
from sklearn.metrics import accuracy_score, classification_report
from torchvision.datasets import OxfordIIITPet
from tqdm import tqdm


def load_test_dataset(data_root: str):
    """

    Load Oxford-IIIT Pet test split without transforms, so we get PIL images.

    Targets will be integer class indices (0..36).

    """
    dataset = OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,      # we want raw PIL here
    )
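    # For reference: this split should contain 3,669 images across 37 breeds.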
    return dataset


def load_model_direct(model_id: str):
    """Construct models directly, bypassing the registry.

    Workaround loader that builds each model with its actual, existing
    constructor signature. Modify only the checkpoint/label paths here
    if needed.
    """
    if model_id == "lr_raw":
        from src.inference.lr_model import LRModel
        # Adjust to match your actual LRModel __init__
        return LRModel("checkpoints/lr_model.joblib", "configs/labels.json")

    elif model_id == "svm_raw":
        from src.inference.svm_model import SVMModel
        return SVMModel("checkpoints/svm_model.joblib", "configs/labels.json")

    elif model_id == "resnet_pt_lr":
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        # If the constructor takes a device argument, match your working setup
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
        )

    elif model_id == "resnet_pt_svm":
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
        )

    else:
        raise ValueError(f"Unsupported model_id: {model_id}")


def evaluate_model_on_dataset(model_id: str, data_root: str):
    """

    Evaluate a single model (by id from registry) on the Oxford-IIIT Pet test split.

    Uses model.predict(PIL.Image, top_k=5) API.



    Returns a dict with:

        - top1_acc

        - top5_acc

        - report_dict (per-class and aggregate metrics)

    """
    print(f"\n=== Evaluating model: {model_id} ===")

    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred_top1 = []
    top5_correct = 0

    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int
    
        # Try to call with top_k; if the model doesn't support it, fall back gracefully
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            # Older / simpler API: predict(img) without top_k
            result = model.predict(img)
    
        # Top-1 prediction (required key)
        pred_id = int(result["class_id"])
        y_true.append(int(target))
        y_pred_top1.append(pred_id)
    
        # Try to get top_k list; if not present, create a synthetic one using only top-1
        top_k = result.get("top_k")
        if not top_k:
            # Fallback: just treat the top-1 prediction as the only candidate.
            # This means Top-5 == Top-1 for such models, which is acceptable as a workaround.
            cname = result.get("class_name", "")
            top_k = [{
                "class_id": pred_id,
                "class_name": cname,
                "probability": 1.0
            }]

        # Top-5 correct? (ground truth anywhere in the top_k list)
        if any(int(entry.get("class_id")) == int(target) for entry in top_k):
            top5_correct += 1

    y_true = np.array(y_true)
    y_pred_top1 = np.array(y_pred_top1)
    n = len(y_true)

    # Overall Top-1 accuracy
    top1_acc = accuracy_score(y_true, y_pred_top1)

    # Overall Top-5 accuracy
    top5_acc = top5_correct / float(n)

    # Detailed precision/recall/F1 per class + aggregate
    report = classification_report(
        y_true,
        y_pred_top1,
        digits=4,
        output_dict=True  # gives a nice dict we can log/inspect
    )
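    # With output_dict=True, report is keyed by stringified class ids plus
    # "accuracy", "macro avg" and "weighted avg" aggregate entries.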

    print(f"Top-1 accuracy ({model_id}): {top1_acc:.4f}")
    print(f"Top-5 accuracy ({model_id}): {top5_acc:.4f}")
    print("\nMacro avg (from classification_report):")
    print(report["macro avg"])
    print("\nWeighted avg (from classification_report):")
    print(report["weighted avg"])

    return {
        "model_id": model_id,
        "top1_acc": top1_acc,
        "top5_acc": top5_acc,
        "report": report,
    }


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    args = parser.parse_args()

    # Model ids to evaluate; each must match a branch in load_model_direct()
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    all_results = []

    for mid in model_ids:
        res = evaluate_model_on_dataset(mid, args.data_root)
        all_results.append(res)

    # Print a compact summary table at the end
    print("\n===== Summary (Top-1 & Top-5) =====")
    print(f"{'Model':25s}  {'Top-1':>8s}  {'Top-5':>8s}")
    print("-" * 50)
    for res in all_results:
        name = res["model_id"]
        t1 = res["top1_acc"]
        t5 = res["top5_acc"]
        print(f"{name:25s}  {t1:8.4f}  {t5:8.4f}")


if __name__ == "__main__":
    # Cap torch's CPU thread count to avoid oversubscription on some systems
    torch.set_num_threads(4)
    main()
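
# Example invocation (uses the default data root; adjust to your layout):
#   python -m src.evaluation.eval_accuracy --data-root data/oxford-iiit-pet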