Shashwat98 commited on
Commit
52dd1ca
·
verified ·
1 Parent(s): fe0618f

Upload 37 files

Browse files
Files changed (37) hide show
  1. app.py +152 -0
  2. checkpoints/lr_model.joblib +3 -0
  3. checkpoints/resnet_pt_lr_head.joblib +3 -0
  4. checkpoints/resnet_pt_svm_head.joblib +3 -0
  5. checkpoints/svm_model.joblib +3 -0
  6. configs/labels.json +39 -0
  7. requirements.txt +7 -0
  8. src/__pycache__/registry.cpython-313.pyc +0 -0
  9. src/evaluation/__pycache__/eval_accuracy.cpython-313.pyc +0 -0
  10. src/evaluation/__pycache__/eval_confusion.cpython-313.pyc +0 -0
  11. src/evaluation/__pycache__/eval_tsne_umap.cpython-313.pyc +0 -0
  12. src/evaluation/eval_accuracy.py +184 -0
  13. src/evaluation/eval_confusion.py +206 -0
  14. src/evaluation/eval_tsne_umap.py +283 -0
  15. src/inference/__pycache__/lr_model.cpython-313.pyc +0 -0
  16. src/inference/__pycache__/resnet_pt_lr_model.cpython-313.pyc +0 -0
  17. src/inference/__pycache__/resnet_pt_svm_model.cpython-313.pyc +0 -0
  18. src/inference/__pycache__/svm_model.cpython-313.pyc +0 -0
  19. src/inference/__pycache__/test_resnet_pt_lr.cpython-313.pyc +0 -0
  20. src/inference/__pycache__/test_resnet_pt_svm.cpython-313.pyc +0 -0
  21. src/inference/base_model.py +30 -0
  22. src/inference/lr_model.py +63 -0
  23. src/inference/resnet_pt_lr_model.py +179 -0
  24. src/inference/resnet_pt_svm_model.py +174 -0
  25. src/inference/svm_model.py +115 -0
  26. src/inference/test_resnet_pt_lr.py +150 -0
  27. src/inference/test_resnet_pt_svm.py +143 -0
  28. src/registry.py +108 -0
  29. src/training/__pycache__/extract_resnet_features.cpython-313.pyc +0 -0
  30. src/training/__pycache__/train_resnet_pt_lr.cpython-313.pyc +0 -0
  31. src/training/__pycache__/train_resnet_pt_svm.cpython-313.pyc +0 -0
  32. src/training/__pycache__/train_svm.cpython-313.pyc +0 -0
  33. src/training/extract_resnet_features.py +183 -0
  34. src/training/train_lr.py +171 -0
  35. src/training/train_resnet_pt_lr.py +128 -0
  36. src/training/train_resnet_pt_svm.py +124 -0
  37. src/training/train_svm.py +177 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ui/app.py
2
+
3
+ import gradio as gr
4
+ from typing import Any, Dict, List
5
+
6
+ from src.registry import get_model_display_names, get_model
7
+
8
+ APP_TITLE = "PetRecog – Oxford-IIIT Pet Identification"
9
+ APP_DESC = (
10
+ "Upload a pet image, choose a model, and compare predictions across "
11
+ "classical (LR, SVM) and deep-feature (ResNet) models."
12
+ )
13
+ TOP_K_DEFAULT = 5
14
+
15
+
16
def format_topk_for_table(top_k: List[Dict[str, Any]]) -> List[List[Any]]:
    """
    Turn a model's ``top_k`` prediction list into rows for a ``gr.Dataframe``.

    Each input entry is expected to look like
    ``{'class_id': int, 'class_name': str, 'probability': float}``.
    Each output row is ``[rank, class_name, probability_in_percent]``,
    with the probability rounded to two decimals.
    """
    return [
        [
            position,
            candidate.get("class_name", f"id={candidate.get('class_id', '?')}"),
            round(float(candidate.get("probability", 0.0)) * 100.0, 2),
        ]
        for position, candidate in enumerate(top_k, start=1)
    ]
29
+
30
+
31
def run_inference(model_id: str, image) -> Dict[str, Any]:
    """
    Gradio-facing wrapper: run one model on one uploaded image.

    Parameters
    ----------
    model_id : registry key identifying the model to run.
    image    : PIL image from ``gr.Image(type='pil')``; ``None`` when the
               user has not uploaded anything yet.

    Returns
    -------
    dict with:
      - ``main_text``  : markdown prediction summary
      - ``topk_table`` : 2D list of rows for ``gr.Dataframe``
    """
    if image is None:
        return {
            "main_text": "⚠️ Please upload an image first.",
            "topk_table": [],
        }

    # The registry lazily instantiates (and caches) the model behind this id.
    model = get_model(model_id)

    # Shared predict API across all models:
    #   predict(PIL.Image, top_k=TOP_K_DEFAULT)
    #     -> {'class_id', 'class_name', 'probabilities', 'top_k'}
    prediction = model.predict(image, top_k=TOP_K_DEFAULT)

    predicted_name = prediction.get("class_name", "Unknown")
    predicted_id = prediction.get("class_id", "N/A")

    summary = f"**Predicted Class:** {predicted_name} \n**Class ID:** {predicted_id}"

    return {
        "main_text": summary,
        "topk_table": format_topk_for_table(prediction.get("top_k", [])),
    }
70
+
71
+
72
def build_demo() -> gr.Blocks:
    """
    Assemble the Gradio Blocks UI.

    Layout: left column holds the model dropdown, image upload, and run
    button; right column holds the markdown prediction and the top-k
    dataframe. Component creation order inside the ``with`` context
    managers determines the rendered layout, so the sequence below is
    load-bearing.
    """
    model_display_names = get_model_display_names()
    # Gradio dropdown will show pretty display_name, but we need to map back to ids.
    id_to_name = model_display_names
    name_to_id = {v: k for k, v in id_to_name.items()}

    # First registry entry becomes the pre-selected model (None if registry is empty).
    default_display_name = next(iter(name_to_id.keys())) if name_to_id else None

    with gr.Blocks(css="""
    body { background: #fbead8; }
    .noble-header { text-align: center; margin-bottom: 1.0rem; }
    .noble-title { font-size: 2.0rem; font-weight: 800; color: #5b3b27; }
    .noble-subtitle { font-size: 0.95rem; color: #7a5b45; }
    """) as demo:
        # Header
        with gr.Row(elem_classes="noble-header"):
            gr.Markdown(
                f"### {APP_TITLE}\n{APP_DESC}",
                elem_classes="noble-title"
            )

        with gr.Row():
            # Left column: controls
            with gr.Column(scale=1):
                gr.Markdown("#### 1️⃣ Select Model & Upload Image")

                model_dropdown = gr.Dropdown(
                    choices=list(name_to_id.keys()),
                    value=default_display_name,
                    label="Select Model",
                )

                image_input = gr.Image(
                    type="pil",
                    label="Upload your pet image (JPEG/PNG)",
                )

                run_button = gr.Button("Run Identification")

            # Right column: output
            with gr.Column(scale=1):
                gr.Markdown("#### 2️⃣ Model Prediction")

                main_output = gr.Markdown(
                    value="Prediction will appear here.",
                    label="Prediction",
                )

                topk_output = gr.Dataframe(
                    headers=["Rank", "Class Name", "Probability (%)"],
                    datatype=["number", "str", "number"],
                    col_count=(3, "fixed"),
                    label=f"Top-{TOP_K_DEFAULT} Predictions",
                )

        # Wiring: button click -> inference
        def _gradio_infer(selected_display_name, img):
            # Returning a {component: value} dict updates both outputs at once.
            if selected_display_name is None:
                return {
                    main_output: "⚠️ Please select a model.",
                    topk_output: [],
                }
            # Map the pretty display name back to the registry id before predicting.
            model_id = name_to_id[selected_display_name]
            result = run_inference(model_id, img)
            return {
                main_output: result["main_text"],
                topk_output: result["topk_table"],
            }

        run_button.click(
            fn=_gradio_infer,
            inputs=[model_dropdown, image_input],
            outputs=[main_output, topk_output],
        )

    return demo
148
+
149
+
150
if __name__ == "__main__":
    # Local entry point: build the Blocks app and serve it with Gradio defaults.
    demo = build_demo()
    demo.launch()
checkpoints/lr_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4aa8d4f7586fbbaf893e0ff269fdde75123a12eb66ee4175beb0e4cc26d5e8a
3
+ size 607515
checkpoints/resnet_pt_lr_head.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a91416a9e7522ac0626b96f0ba1903482ddd8fe3181accdcb8833a3494c55d94
3
+ size 77209
checkpoints/resnet_pt_svm_head.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9f950c784e7196d427c1165cfc18a3520c3212d2a52d5d8b96006591f80da6c
3
+ size 153001
checkpoints/svm_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f6e50cb4316ab0d9e6634d150c18fdb1302d25d88067c1d52c7a5deffe17ec6
3
+ size 1213818
configs/labels.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "Abyssinian",
3
+ "1": "American Bulldog",
4
+ "2": "American Pit Bull Terrier",
5
+ "3": "Basset Hound",
6
+ "4": "Beagle",
7
+ "5": "Bengal",
8
+ "6": "Birman",
9
+ "7": "Bombay",
10
+ "8": "Boxer",
11
+ "9": "British Shorthair",
12
+ "10": "Chihuahua",
13
+ "11": "Egyptian Mau",
14
+ "12": "English Cocker Spaniel",
15
+ "13": "English Setter",
16
+ "14": "German Shorthaired",
17
+ "15": "Great Pyrenees",
18
+ "16": "Havanese",
19
+ "17": "Japanese Chin",
20
+ "18": "Keeshond",
21
+ "19": "Leonberger",
22
+ "20": "Maine Coon",
23
+ "21": "Miniature Pinscher",
24
+ "22": "Newfoundland",
25
+ "23": "Persian",
26
+ "24": "Pomeranian",
27
+ "25": "Pug",
28
+ "26": "Ragdoll",
29
+ "27": "Russian Blue",
30
+ "28": "Saint Bernard",
31
+ "29": "Samoyed",
32
+ "30": "Scottish Terrier",
33
+ "31": "Shiba Inu",
34
+ "32": "Siamese",
35
+ "33": "Sphynx",
36
+ "34": "Staffordshire Bull Terrier",
37
+ "35": "Wheaten Terrier",
38
+ "36": "Yorkshire Terrier"
39
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0
2
+ torch>=2.0
3
+ torchvision>=0.15
4
+ numpy
5
+ scikit-learn
6
+ joblib
7
+ Pillow
src/__pycache__/registry.cpython-313.pyc ADDED
Binary file (4.49 kB). View file
 
src/evaluation/__pycache__/eval_accuracy.cpython-313.pyc ADDED
Binary file (6.02 kB). View file
 
src/evaluation/__pycache__/eval_confusion.cpython-313.pyc ADDED
Binary file (8.15 kB). View file
 
src/evaluation/__pycache__/eval_tsne_umap.cpython-313.pyc ADDED
Binary file (11.8 kB). View file
 
src/evaluation/eval_accuracy.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/evaluation/eval_accuracy.py
2
+
3
+ import argparse
4
+ from collections import defaultdict
5
+
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from sklearn.metrics import accuracy_score, classification_report
9
+
10
+ from torchvision.datasets import OxfordIIITPet
11
+
12
+ from src.registry import get_model
13
+ import torch
14
+
15
+
16
def load_test_dataset(data_root: str):
    """
    Build the Oxford-IIIT Pet *test* split with no transform attached,
    so indexing yields raw PIL images paired with integer category
    labels (0..36).
    """
    # transform=None keeps samples as PIL images; each model applies
    # its own preprocessing inside predict().
    return OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
    )
+ return dataset
28
+
29
def load_model_direct(model_id: str):
    """
    Construct an inference model directly, bypassing the registry, using
    each class's actual constructor signature and the checked-in
    checkpoint/label paths. Modify only the paths here if they move.

    Raises:
        ValueError: for an unknown ``model_id``.
    """
    # Imports stay inside each branch so only the requested model's
    # dependencies are loaded.
    if model_id == "lr_raw":
        from src.inference.lr_model import LRModel
        # Positional args must match LRModel.__init__.
        return LRModel("checkpoints/lr_model.joblib", "configs/labels.json")

    if model_id == "svm_raw":
        from src.inference.svm_model import SVMModel
        return SVMModel("checkpoints/svm_model.joblib", "configs/labels.json")

    if model_id == "resnet_pt_lr":
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
        )

    if model_id == "resnet_pt_svm":
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
        )

    raise ValueError(f"Unsupported model_id: {model_id}")
61
+
62
def evaluate_model_on_dataset(model_id: str, data_root: str):
    """
    Evaluate one model on the Oxford-IIIT Pet test split through its
    ``predict(PIL.Image, top_k=5)`` API.

    Returns:
        dict with keys:
          - ``model_id``
          - ``top1_acc``: overall top-1 accuracy
          - ``top5_acc``: fraction of samples whose ground truth appears
            in the model's top-k list
          - ``report``: sklearn classification_report as a dict
    """
    print(f"\n=== Evaluating model: {model_id} ===")

    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    gt_labels = []
    top1_preds = []
    hits_in_top5 = 0

    for sample_idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[sample_idx]  # PIL image + integer label

        # Prefer the top_k-aware API; older models accept only the image.
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            result = model.predict(img)

        # Top-1 prediction is mandatory in the shared API.
        best_id = int(result.get("class_id"))
        gt_labels.append(int(target))
        top1_preds.append(best_id)

        candidates = result.get("top_k")
        if not candidates:
            # No top-k list available: fall back to the single top-1 answer,
            # which makes Top-5 degenerate to Top-1 for such models.
            candidates = [{
                "class_id": best_id,
                "class_name": result.get("class_name", ""),
                "probability": 1.0,
            }]

        # Count a top-5 hit when the ground truth appears anywhere in the list.
        if any(int(entry.get("class_id")) == int(target) for entry in candidates):
            hits_in_top5 += 1

    y_true = np.array(gt_labels)
    y_pred_top1 = np.array(top1_preds)
    total = len(y_true)

    top1_acc = accuracy_score(y_true, y_pred_top1)
    top5_acc = hits_in_top5 / float(total)

    # Per-class precision/recall/F1 plus macro/weighted aggregates.
    report = classification_report(
        y_true,
        y_pred_top1,
        digits=4,
        output_dict=True,
    )

    print(f"Top-1 accuracy ({model_id}): {top1_acc:.4f}")
    print(f"Top-5 accuracy ({model_id}): {top5_acc:.4f}")
    print("\nMacro avg (from classification_report):")
    print(report["macro avg"])
    print("\nWeighted avg (from classification_report):")
    print(report["weighted avg"])

    return {
        "model_id": model_id,
        "top1_acc": top1_acc,
        "top5_acc": top5_acc,
        "report": report,
    }
+ }
144
+
145
+
146
def main():
    """CLI entry point: evaluate every model and print a summary table."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    args = parser.parse_args()

    # Every model id included in the comparison.
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    all_results = [
        evaluate_model_on_dataset(mid, args.data_root) for mid in model_ids
    ]

    # Compact summary table at the end.
    print("\n===== Summary (Top-1 & Top-5) =====")
    print(f"{'Model':25s} {'Top-1':>8s} {'Top-5':>8s}")
    print("-" * 50)
    for res in all_results:
        print(f"{res['model_id']:25s} {res['top1_acc']:8.4f} {res['top5_acc']:8.4f}")
179
+
180
+
181
if __name__ == "__main__":
    # Script entry point.
    # Make sure torch doesn't spawn too many threads on some systems
    torch.set_num_threads(4)
    main()
src/evaluation/eval_confusion.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/evaluation/eval_confusion.py
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ import json
6
+
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.metrics import confusion_matrix
10
+
11
+ from tqdm import tqdm
12
+
13
+ # Reuse the same dataset + model loading logic as eval_accuracy.py
14
+ from src.evaluation.eval_accuracy import load_test_dataset, load_model_direct
15
+
16
+
17
def load_class_names(labels_path: str = "configs/labels.json"):
    """
    Load ordered class names from ``labels.json`` for plot axis labels.

    Accepted formats:
      - list: ``["Abyssinian", "American Bulldog", ...]``
      - dict with string integer keys: ``{"0": "Abyssinian", ...}``
      - dict wrapping the mapping: ``{"id_to_label": {"0": ..., ...}}``

    Returns:
        list of class names sorted by integer id, or ``None`` when the
        file is missing, unparseable, or in an unrecognized format
        (callers then fall back to numeric tick labels).
    """
    try:
        # Explicit encoding avoids locale-dependent decoding of breed names.
        with open(labels_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"[WARN] labels file not found at {labels_path}, using numeric IDs.")
        return None
    except json.JSONDecodeError:
        print(f"[WARN] Could not parse {labels_path}, using numeric IDs.")
        return None

    # Case 1: already an ordered list of names.
    if isinstance(data, list):
        return data

    # Case 2: dict wrapping the mapping under 'id_to_label'.
    if isinstance(data, dict) and "id_to_label" in data:
        id_to_label = data["id_to_label"]
        # Sort by integer key so "10" comes after "9", not after "1".
        keys = sorted(id_to_label.keys(), key=lambda k: int(k))
        return [id_to_label[k] for k in keys]

    # Case 3: flat dict mapping "0" -> "Abyssinian".
    if isinstance(data, dict):
        try:
            keys = sorted(data.keys(), key=lambda k: int(k))
            return [data[k] for k in keys]
        except (TypeError, ValueError):
            # Keys are not all integer-like; fall through to the warning.
            # (Was a broad `except Exception`; narrowed to the conversion errors
            # int() can actually raise so real bugs are not silently swallowed.)
            pass

    # Was an f-string with no placeholders (lint F541); plain literal now.
    print("[WARN] Unrecognized labels.json format, using numeric IDs.")
    return None
59
+
60
+
61
def collect_predictions(model_id: str, data_root: str):
    """
    Run one model over the Oxford-IIIT Pet test split and gather its
    top-1 predictions.

    Returns:
        (y_true, y_pred): aligned numpy int arrays of ground-truth and
        predicted class indices, one entry per test sample.
    """
    print(f"\n=== Collecting predictions for model: {model_id} ===")

    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    truths = []
    predictions = []

    for sample_idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[sample_idx]  # PIL image + integer label

        # Same predict logic as eval_accuracy: prefer the top_k-aware API,
        # fall back to the bare signature.
        try:
            outcome = model.predict(img, top_k=5)
        except TypeError:
            outcome = model.predict(img)

        truths.append(int(target))
        predictions.append(int(outcome.get("class_id")))

    y_true = np.array(truths)
    y_pred = np.array(predictions)

    print(f" Collected {len(y_true)} predictions.")
    return y_true, y_pred
95
+
96
+
97
def plot_confusion_matrix(
    cm: np.ndarray,
    class_names,
    title: str,
    save_path: Path,
    normalize: bool = True,
):
    """
    Render a confusion matrix to ``save_path`` as a high-DPI PNG.

    With ``normalize=True`` each row (true class) is scaled to sum to 1;
    rows with zero support stay all-zero. Numeric indices replace the
    axis names when ``class_names`` is None or has the wrong length.
    """
    if normalize:
        cm = cm.astype("float")
        row_sums = cm.sum(axis=1, keepdims=True)
        # where= guards the division for classes that have no samples.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    num_classes = cm.shape[0]

    plt.figure(figsize=(12, 10))
    heatmap = plt.imshow(cm, interpolation="nearest", cmap="viridis")
    plt.title(title)
    plt.colorbar(heatmap, fraction=0.046, pad=0.04)

    # Fall back to numeric ticks unless the provided names line up exactly.
    names_usable = class_names is not None and len(class_names) == num_classes
    tick_labels = class_names if names_usable else list(range(num_classes))

    plt.xticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        rotation=90,
        fontsize=6,
    )
    plt.yticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        fontsize=6,
    )

    plt.xlabel("Predicted class")
    plt.ylabel("True class")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f" Saved confusion matrix plot to: {save_path}")
145
+
146
+
147
def main():
    """
    CLI entry point: for every model, run the test split, then save the
    raw confusion matrix (.npy) and a normalized heatmap (.png) into
    ``--out-dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels.json (for axis names).",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/confusion_matrices",
        help="Directory to save confusion matrices and plots.",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Same set of models as eval_accuracy
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    class_names = load_class_names(args.labels_path)

    for model_id in model_ids:
        y_true, y_pred = collect_predictions(model_id, args.data_root)

        # Fixed label ordering 0..max for a square matrix.
        # BUGFIX: include y_pred in the max — previously only y_true was
        # considered, so a prediction above the highest true label would
        # fall outside the label list and be silently dropped from the matrix.
        num_classes = int(max(y_true.max(), y_pred.max())) + 1
        labels = list(range(num_classes))

        cm = confusion_matrix(y_true, y_pred, labels=labels)

        # Save raw matrix for future analysis
        npy_path = out_dir / f"cm_{model_id}.npy"
        np.save(npy_path, cm)
        print(f" Saved raw confusion matrix to: {npy_path}")

        # Save a normalized plot
        png_path = out_dir / f"cm_{model_id}.png"
        title = f"Confusion Matrix ({model_id})"
        plot_confusion_matrix(cm, class_names, title, png_path, normalize=True)
203
+
204
+
205
if __name__ == "__main__":
    # Script entry point.
    main()
src/evaluation/eval_tsne_umap.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/evaluation/eval_tsne_umap.py
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import transforms as T, models
10
+
11
+ from tqdm import tqdm
12
+ import matplotlib.pyplot as plt
13
+ from sklearn.manifold import TSNE
14
+
15
+ # Reuse your test dataset loader from eval_accuracy
16
+ from src.evaluation.eval_accuracy import load_test_dataset
17
+
18
+
19
# Optional UMAP support
# umap-learn is an optional dependency: when it is missing, HAS_UMAP stays
# False and run_umap() below turns into a no-op that prints a warning.
try:
    import umap
    HAS_UMAP = True
except ImportError:
    HAS_UMAP = False
    print("[INFO] umap-learn not installed; will skip UMAP and only run t-SNE.")
26
+
27
+
28
class ResNetFeatureExtractor(nn.Module):
    """
    Wraps a torchvision ResNet18 pretrained on ImageNet and
    exposes a 512-d feature vector for each image.
    """

    def __init__(self, device="cuda"):
        # NOTE(review): the default device is "cuda"; on a CPU-only machine the
        # .to(device) call below would fail. The only caller in this file
        # (extract_features) passes an auto-detected device explicitly, so
        # the default is never exercised there — confirm before relying on it.
        super().__init__()
        # Use the modern weights API
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Remove the final FC layer: keep everything up to avgpool
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.feature_extractor.to(device)
        self.feature_extractor.eval()
        self.device = device

        # Standard ImageNet normalization
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @torch.no_grad()
    def forward(self, pil_img):
        """
        Embed one image.

        pil_img: a single PIL.Image
        returns: numpy array of shape (512,)
        """
        x = self.transform(pil_img).unsqueeze(0).to(self.device)  # (1, 3, 224, 224)
        feat = self.feature_extractor(x)  # (1, 512, 1, 1)
        feat = feat.view(1, -1)  # (1, 512)
        # Move back to CPU so callers get a plain numpy vector.
        return feat.squeeze(0).cpu().numpy()
64
+
65
+
66
def extract_features(data_root: str, max_samples: int = 2000, seed: int = 42):
    """
    Extract two parallel feature sets from the Oxford-IIIT Pet test split:
      - Raw 64x64 grayscale flattened features (for LR/SVM-style space)
      - ResNet18 pretrained 512-d features

    Args:
        data_root: dataset root passed to load_test_dataset.
        max_samples: cap on how many test samples to use (None = all);
            keeps t-SNE/UMAP runtimes reasonable.
        seed: RNG seed so the subsample is reproducible.

    Returns:
        X_raw : (N, 4096)
        X_resnet: (N, 512)
        y : (N,)
    """
    print(f"[INFO] Loading test dataset from {data_root}")
    dataset = load_test_dataset(data_root)
    total = len(dataset)

    # Optional subsampling for t-SNE / UMAP visualization
    rng = np.random.default_rng(seed)
    if max_samples is not None and max_samples < total:
        indices = rng.choice(total, size=max_samples, replace=False)
        # Sorted for sequential-ish dataset access.
        indices = sorted(indices.tolist())
        print(f"[INFO] Subsampling {len(indices)} / {total} test samples for visualization.")
    else:
        indices = list(range(total))
        print(f"[INFO] Using all {total} test samples for visualization.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")

    # Raw feature pipeline: 64x64 grayscale + flatten
    raw_transform = T.Compose([
        T.Resize((64, 64)),
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),  # (1, 64, 64), values in [0,1]
    ])

    resnet_extractor = ResNetFeatureExtractor(device=device)

    X_raw_list = []
    X_resnet_list = []
    y_list = []

    for idx in tqdm(indices, desc="Extracting features"):
        img, target = dataset[idx]  # img: PIL.Image, target: int
        y_list.append(int(target))

        # Raw features
        raw_tensor = raw_transform(img)  # (1, 64, 64)
        X_raw_list.append(raw_tensor.view(-1).numpy())  # (4096,)

        # ResNet features
        resnet_feat = resnet_extractor(img)  # (512,)
        X_resnet_list.append(resnet_feat)

    X_raw = np.stack(X_raw_list, axis=0)  # (N, 4096)
    X_resnet = np.stack(X_resnet_list, axis=0)  # (N, 512)
    y = np.array(y_list, dtype=int)

    print(f"[INFO] X_raw shape: {X_raw.shape}")
    print(f"[INFO] X_resnet shape: {X_resnet.shape}")
    print(f"[INFO] y shape: {y.shape}")

    return X_raw, X_resnet, y
128
+
129
+
130
def run_tsne(X, y, out_path: Path, title: str, num_classes_to_label: int = 10):
    """
    Run t-SNE on feature matrix X and save a 2D scatter plot.
    Points are colored by class label.

    Args:
        X: (N, D) feature matrix.
        y: (N,) integer class labels used to color the points.
        out_path: destination PNG path.
        title: plot title (also used in log messages).
        num_classes_to_label: cap on legend entries to avoid clutter.
    """
    print(f"[INFO] Running t-SNE for {title} with shape {X.shape}")
    tsne = TSNE(
        n_components=2,
        perplexity=30,
        learning_rate="auto",
        init="pca",
        random_state=42,  # fixed seed for reproducible embeddings
    )
    X_2d = tsne.fit_transform(X)

    # Plot
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_2d[:, 0],
        X_2d[:, 1],
        c=y,
        s=8,
        alpha=0.7,
        cmap="tab20",
    )
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

    # Optionally build a legend with a subset of classes to avoid clutter
    unique_classes = np.unique(y)
    if len(unique_classes) > num_classes_to_label:
        chosen = unique_classes[:num_classes_to_label]
    else:
        chosen = unique_classes

    # Create proxy artists for legend (one colored marker per chosen class)
    handles = []
    labels = []
    for cls in chosen:
        handles.append(plt.Line2D([], [], marker="o", linestyle="",
                                  color=scatter.cmap(scatter.norm(cls))))
        labels.append(f"Class {cls}")
    plt.legend(handles, labels, title="Example classes", fontsize=8, loc="best")

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[INFO] Saved t-SNE plot to {out_path}")
179
+
180
+
181
def run_umap(X, y, out_path: Path, title: str, num_classes_to_label: int = 10):
    """
    Run UMAP on feature matrix X and save a 2D scatter plot.
    Only runs if umap-learn is installed (module-level HAS_UMAP flag);
    otherwise prints a warning and returns without writing anything.

    Args:
        X: (N, D) feature matrix.
        y: (N,) integer class labels used to color the points.
        out_path: destination PNG path.
        title: plot title (also used in log messages).
        num_classes_to_label: cap on legend entries to avoid clutter.
    """
    if not HAS_UMAP:
        print(f"[WARN] UMAP not available; skipping {title}")
        return

    print(f"[INFO] Running UMAP for {title} with shape {X.shape}")
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        random_state=42,  # fixed seed for reproducible embeddings
    )
    X_2d = reducer.fit_transform(X)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_2d[:, 0],
        X_2d[:, 1],
        c=y,
        s=8,
        alpha=0.7,
        cmap="tab20",
    )
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

    # Legend shows at most num_classes_to_label classes to avoid clutter.
    unique_classes = np.unique(y)
    if len(unique_classes) > num_classes_to_label:
        chosen = unique_classes[:num_classes_to_label]
    else:
        chosen = unique_classes

    # Proxy artists: one colored marker per chosen class.
    handles = []
    labels = []
    for cls in chosen:
        handles.append(plt.Line2D([], [], marker="o", linestyle="",
                                  color=scatter.cmap(scatter.norm(cls))))
        labels.append(f"Class {cls}")
    plt.legend(handles, labels, title="Example classes", fontsize=8, loc="best")

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[INFO] Saved UMAP plot to {out_path}")
230
+
231
+
232
def main():
    """Extract features, then save t-SNE (and, if available, UMAP) plots."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/feature_viz",
        help="Directory to save t-SNE/UMAP plots.",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=2000,
        help="Max number of test samples to subsample for visualization (None = all).",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1) Extract both feature spaces (raw pixels + ResNet embeddings).
    X_raw, X_resnet, y = extract_features(
        data_root=args.data_root,
        max_samples=args.max_samples,
        seed=42,
    )

    # 2) t-SNE on raw features
    run_tsne(X_raw, y, out_dir / "tsne_raw.png",
             title="t-SNE: Raw 64x64 Grayscale Features")

    # 3) t-SNE on ResNet features
    run_tsne(X_resnet, y, out_dir / "tsne_resnet.png",
             title="t-SNE: ResNet18 Pretrained Features")

    # 4) Optional UMAP (no-op with a warning when umap-learn is absent)
    run_umap(X_raw, y, out_dir / "umap_raw.png",
             title="UMAP: Raw 64x64 Grayscale Features")
    run_umap(X_resnet, y, out_dir / "umap_resnet.png",
             title="UMAP: ResNet18 Pretrained Features")
278
+
279
+
280
if __name__ == "__main__":
    # Script entry point.
    # Keep torch threads manageable
    torch.set_num_threads(4)
    main()
src/inference/__pycache__/lr_model.cpython-313.pyc ADDED
Binary file (3.17 kB). View file
 
src/inference/__pycache__/resnet_pt_lr_model.cpython-313.pyc ADDED
Binary file (8.63 kB). View file
 
src/inference/__pycache__/resnet_pt_svm_model.cpython-313.pyc ADDED
Binary file (8.41 kB). View file
 
src/inference/__pycache__/svm_model.cpython-313.pyc ADDED
Binary file (5.3 kB). View file
 
src/inference/__pycache__/test_resnet_pt_lr.cpython-313.pyc ADDED
Binary file (5.2 kB). View file
 
src/inference/__pycache__/test_resnet_pt_svm.cpython-313.pyc ADDED
Binary file (5.1 kB). View file
 
src/inference/base_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/inference/base_model.py
2
+ from abc import ABC, abstractmethod
3
+ from typing import Dict, Any
4
+ from PIL import Image
5
+
6
class BaseModel(ABC):
    """Common interface for all pet recognition models.

    NOTE(review): the concrete models in this package (LRModel, SVMModel,
    ResNetPTLRModel, ResNetPTSVMModel) currently follow this contract by
    convention only and do not subclass this ABC — consider inheriting so
    the interface is actually enforced.
    """

    def __init__(self, name: str, labels: Dict[int, str]):
        # Human-readable model identifier (e.g. for a UI dropdown).
        self.name = name
        # Mapping class_id -> class_name used to decode predictions.
        self.labels = labels

    @abstractmethod
    def preprocess(self, image: Image.Image) -> Any:
        """Convert PIL image → model input (tensor / numpy / feature vector)."""
        pass

    @abstractmethod
    def predict(self, image: Image.Image) -> Dict[str, Any]:
        """
        Run full pipeline: preprocess → forward pass → postprocess.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probs": Dict[str, float],  # optional, top-k
            }
        """
        pass
src/inference/lr_model.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import joblib
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+
7
class LRModel:
    """
    Inference pipeline for a Logistic Regression model trained on
    64x64 grayscale flattened images.

    Fix: ``src/registry.py`` instantiates this class as
    ``LRModel(ckpt_path=..., labels_path=..., device=...)`` but the original
    constructor only accepted ``model_path``/``labels_path``/``image_size``,
    so loading through the registry raised ``TypeError``. The constructor now
    also accepts ``ckpt_path`` (an alias of ``model_path``) and ``device``
    (accepted for interface parity with the ResNet wrappers and otherwise
    unused — scikit-learn inference is CPU-only), while staying backward
    compatible with the original positional call
    ``LRModel(model_path, labels_path)``.
    """

    def __init__(
        self,
        model_path: str = None,
        labels_path: str = "configs/labels.json",
        image_size: int = 64,
        ckpt_path: str = None,
        device: str = None,
    ):
        """
        Args:
            model_path: Path to the joblib'd sklearn model (original keyword).
            labels_path: Path to the {class_id: class_name} JSON mapping.
            image_size: Side length the input image is resized to.
            ckpt_path: Alias for model_path (used by src/registry.py).
            device: Accepted and stored for interface parity; unused.

        Raises:
            ValueError: if neither model_path nor ckpt_path is given.
        """
        # Accept either keyword; exactly one source for the checkpoint path.
        path = model_path if model_path is not None else ckpt_path
        if path is None:
            raise ValueError("LRModel requires model_path (or ckpt_path).")

        self.model = joblib.load(path)
        self.labels = self._load_labels(labels_path)
        self.image_size = image_size
        # Stored only so callers that pass device= (see registry) keep working.
        self.device = device

    def _load_labels(self, labels_path):
        """Load the id->name mapping, coercing JSON string keys to int."""
        with open(labels_path, "r") as f:
            label_dict = json.load(f)

        # JSON object keys are always strings; model outputs integer ids.
        return {int(k): v for k, v in label_dict.items()}

    def preprocess(self, image: Image.Image) -> np.ndarray:
        """
        Preprocessing matching training:
        - Resize to (image_size, image_size)
        - Grayscale
        - Scale pixel values to [0, 1]
        - Flatten to shape (1, D)
        """
        img = image.resize((self.image_size, self.image_size))
        img = img.convert("L")  # grayscale
        arr = np.array(img, dtype=np.float32) / 255.0
        return arr.reshape(1, -1)  # shape: (1, D)

    def predict(self, image: Image.Image):
        """
        Classify a single PIL image.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob, ...}  # full distribution
            }
        """
        x = self.preprocess(image)
        probs = self.model.predict_proba(x)[0]
        class_id = int(np.argmax(probs))
        class_name = self.labels[class_id]

        # Full {name: probability} distribution for the UI.
        prob_dict = {
            self.labels[i]: float(probs[i]) for i in range(len(probs))
        }

        return {
            "class_id": class_id,
            "class_name": class_name,
            "probabilities": prob_dict
        }
src/inference/resnet_pt_lr_model.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/inference/resnet_pt_lr_model.py
2
+
3
+ import os
4
+ import json
5
+ from typing import Dict, Any, List, Optional
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ import torch
11
+ from torchvision.models import resnet18, ResNet18_Weights
12
+ import joblib
13
+
14
+
15
class ResNetPTLRModel:
    """
    End-to-end inference wrapper:
    - ResNet18 (pretrained on ImageNet) as frozen backbone
    - Logistic Regression head trained on extracted features
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        # device=None means auto-select: CUDA when available, else CPU.
        assert os.path.exists(ckpt_path), f"ResNet PT + LR checkpoint not found: {ckpt_path}"
        assert os.path.exists(labels_path), f"Labels mapping not found: {labels_path}"

        # Decide device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"[ResNetPTLRModel] Using device: {self.device}")

        # --- Load LR head ---
        print(f"[ResNetPTLRModel] Loading LR head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)

        # payload was saved as dict in train_resnet_pt_lr.py
        if isinstance(payload, dict) and "model" in payload:
            self.lr_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Fallback if someone saved the raw model
            self.lr_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        # Prefer the labels path recorded in the checkpoint, if it still exists.
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTLRModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)

        # ensure keys are ints (JSON object keys are always strings)
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess (same as in feature extraction) ---
        print("[ResNetPTLRModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)

        import torch.nn as nn
        # Drop the classification layer so forward() yields pooled features.
        model.fc = nn.Identity()

        model.to(self.device)
        model.eval()

        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional: check feature_dim consistency if available
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTLRModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                # Best-effort sanity check only; never block model construction.
                print(f"[ResNetPTLRModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply ImageNet-style transforms and return a (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # (1, 3, H, W)
        return t.to(self.device)

    @staticmethod
    def _to_probabilities_from_logits(logits: np.ndarray) -> np.ndarray:
        """
        Convert raw scores/logits to probabilities using softmax.

        Subtracting the max first keeps exp() numerically stable.
        Used only in the no-predict_proba fallback in predict().
        """
        logits = logits - np.max(logits)
        exp = np.exp(logits)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run a PIL image through the backbone and get a (1, D) numpy feature vector.
        """
        x = self.preprocess(img)  # (1, 3, H, W)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        feats_np = feats.cpu().numpy()
        return feats_np  # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }

        NOTE(review): assumes the head's class indices are 0..C-1 and aligned
        with the ids in labels.json — confirm against the training script.
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LR has predict_proba, use that directly
        if hasattr(self.lr_head, "predict_proba"):
            probs = self.lr_head.predict_proba(feats_np)[0]  # (C,)
        else:
            # Fallback: use decision_function and softmax
            scores = self.lr_head.decision_function(feats_np)
            if scores.ndim == 1:
                # decision_function may return 1-D; normalize to (1, C).
                scores = scores[np.newaxis, :]
            probs = self._to_probabilities_from_logits(scores[0])

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        # Full distribution
        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }

        # Top-k sorted
        sorted_indices = np.argsort(probs)[::-1]
        top_k = min(top_k, len(sorted_indices))
        top_k_list: List[Dict[str, Any]] = []
        for i in range(top_k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }
src/inference/resnet_pt_svm_model.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/inference/resnet_pt_svm_model.py
2
+
3
+ import os
4
+ import json
5
+ from typing import Dict, Any, List, Optional
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ import torch
11
+ from torchvision.models import resnet18, ResNet18_Weights
12
+ import joblib
13
+
14
+
15
class ResNetPTSVMModel:
    """
    ResNet18 (pretrained, frozen) + Linear SVM head.

    Pipeline:
    - PIL image
    - ImageNet transforms
    - ResNet18 backbone (fc -> Identity) -> feature vector
    - Linear SVM decision_function
    - Softmax over scores to get probabilities
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        # device=None means auto-select: CUDA when available, else CPU.
        assert os.path.exists(ckpt_path), f"ResNet PT + SVM checkpoint not found: {ckpt_path}"
        assert os.path.exists(labels_path), f"Labels mapping not found: {labels_path}"

        # Device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"[ResNetPTSVMModel] Using device: {self.device}")

        # --- Load SVM head ---
        print(f"[ResNetPTSVMModel] Loading SVM head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)

        # Checkpoint may be a metadata dict (as saved by training) or a bare model.
        if isinstance(payload, dict) and "model" in payload:
            self.svm_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            self.svm_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        # Prefer the labels path recorded in the checkpoint, if it still exists.
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTSVMModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)

        # ensure keys are ints (JSON object keys are always strings)
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess ---
        print("[ResNetPTSVMModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)

        import torch.nn as nn
        # Drop the classification layer so forward() yields pooled features.
        model.fc = nn.Identity()

        model.to(self.device)
        model.eval()

        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional: sanity check feature_dim
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTSVMModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                # Best-effort check only; never block model construction.
                print(f"[ResNetPTSVMModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply the ImageNet-style transforms and return (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)
        return t.to(self.device)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        # Subtract the max first so exp() is numerically stable.
        scores = scores - np.max(scores)
        exp = np.exp(scores)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run image through ResNet backbone to get (1, D) feature vector.
        """
        x = self.preprocess(img)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()  # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }

        NOTE(review): softmaxed SVM margins are not calibrated
        probabilities — treat them as confidence scores only.
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LinearSVC has no predict_proba -> use decision_function
        scores = self.svm_head.decision_function(feats_np)
        if scores.ndim == 1:
            # decision_function may return 1-D; normalize to (1, C).
            scores = scores[np.newaxis, :]
        scores = scores[0]  # (C,)

        probs = self._softmax(scores)  # (C,)

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }

        sorted_indices = np.argsort(probs)[::-1]
        top_k = min(top_k, len(sorted_indices))
        top_k_list: List[Dict[str, Any]] = []
        for i in range(top_k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }
src/inference/svm_model.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/inference/svm_model.py
2
+
3
+ import os
4
+ import json
5
+ from typing import Dict, Any, List
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ import joblib
11
+
12
+
13
class SVMModel:
    """
    Inference wrapper for the Linear SVM trained on raw 64x64 grayscale pixels.
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/svm_model.joblib",
        labels_path: str = "configs/labels.json",
    ):
        assert os.path.exists(ckpt_path), f"SVM checkpoint not found: {ckpt_path}"
        assert os.path.exists(labels_path), f"Labels mapping not found: {labels_path}"

        print(f"[SVMModel] Loading checkpoint from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)

        # The checkpoint may be either the bare estimator or a dict wrapping it.
        if isinstance(payload, dict) and "model" in payload:
            self.model = payload["model"]
        else:
            self.model = payload

        print(f"[SVMModel] Loading labels from {labels_path} ...")
        with open(labels_path, "r") as f:
            raw_labels = json.load(f)

        # JSON keys arrive as strings; convert to integer class ids.
        self.id_to_name = {int(cid): cname for cid, cname in raw_labels.items()}

        self.preprocess_tf = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),  # (1, 64, 64) in [0, 1]
        ])

    def preprocess(self, img: Image.Image) -> np.ndarray:
        """Flatten a PIL image into a (1, 4096) grayscale feature row."""
        tensor = self.preprocess_tf(img)        # (1, 64, 64)
        flat = tensor.view(-1).numpy()          # (4096,)
        return flat[np.newaxis, :]              # (1, 4096)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over a 1-D score vector."""
        shifted = scores - np.max(scores)
        exp_scores = np.exp(shifted)
        return exp_scores / np.sum(exp_scores)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Classify one image.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float}  # full distribution
                "top_k": List[{"class_id": int, "class_name": str, "probability": float}]
            }
        """
        features = self.preprocess(img)  # (1, 4096)

        # LinearSVC exposes no predict_proba; softmax the decision scores instead.
        raw = self.model.decision_function(features)  # (1, C) or (C,) if binary
        raw = np.atleast_2d(raw)[0]  # (C,)

        probs = self._softmax(raw)  # (C,)

        best = int(np.argmax(probs))

        # Full {class_name: probability} mapping.
        prob_dict = {
            self.id_to_name[idx]: float(p)
            for idx, p in enumerate(probs)
        }

        # Ranked top-k entries, highest probability first.
        order = np.argsort(probs)[::-1]
        k = min(top_k, len(order))
        ranked: List[Dict[str, Any]] = [
            {
                "class_id": int(order[j]),
                "class_name": self.id_to_name[int(order[j])],
                "probability": float(probs[int(order[j])]),
            }
            for j in range(k)
        ]

        return {
            "class_id": best,
            "class_name": self.id_to_name[best],
            "probabilities": prob_dict,
            "top_k": ranked,
        }
src/inference/test_resnet_pt_lr.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/inference/test_resnet_pt_lr.py
2
+
3
+ import os
4
+ import argparse
5
+ import random
6
+
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torchvision import datasets
11
+
12
+ from src.inference.resnet_pt_lr_model import ResNetPTLRModel
13
+
14
+
15
def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    """Run ResNet(PT)+LR inference on one image file and print the top-k.

    Args:
        image_path: Path to any PIL-readable image.
        ckpt_path: Path to the LR-head joblib checkpoint.
        labels_path: Path to the class-id -> name JSON mapping.
        device: Torch device string; None lets the model auto-select.
        top_k: Number of ranked predictions to print.
    """
    assert os.path.exists(image_path), f"Image not found: {image_path}"
    # Force RGB so grayscale/RGBA inputs match the 3-channel backbone input.
    img = Image.open(image_path).convert("RGB")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"  {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")
39
+
40
+
41
def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    """
    Pick a random sample from the Oxford-IIIT Pet test split and run inference.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")

    # transform=None -> returns PIL.Image
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
        download=True,  # fetches the archive on first run
    )

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    # Unseeded RNG: a different sample is drawn on every invocation.
    idx = random.randint(0, len(test_ds) - 1)
    img, target = test_ds[idx]
    assert isinstance(img, Image.Image)

    # dataset has .categories giving names
    gt_name = test_ds.categories[target]

    print(f"[+] Random sample idx={idx}")
    print(f"    Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)

    print(f"    Predicted class_id : {out['class_id']}")
    print(f"    Predicted class_name: {out['class_name']}")
    print("    Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"      {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")
85
+
86
+
87
def parse_args():
    """Build and parse the CLI options for the ResNet(PT)+LR smoke test."""
    p = argparse.ArgumentParser(
        description="Test ResNet(PT) + LR inference on Oxford-IIIT Pet."
    )

    p.add_argument(
        "--ckpt-path",
        default="checkpoints/resnet_pt_lr_head.joblib",
        type=str,
        help="Path to ResNet PT + LR checkpoint.",
    )
    p.add_argument(
        "--labels-path",
        default="configs/labels.json",
        type=str,
        help="Path to labels mapping JSON.",
    )
    p.add_argument(
        "--data-root",
        default="data/oxford-iiit-pet",
        type=str,
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    p.add_argument(
        "--image-path",
        default=None,
        type=str,
        help="If provided, run inference on this image instead of a random test sample.",
    )
    p.add_argument(
        "--device",
        default=None,
        type=str,
        help="Device to use (e.g., 'cpu', 'cuda'). If None, auto-select.",
    )
    p.add_argument(
        "--top-k",
        default=5,
        type=int,
        help="Number of top classes to print.",
    )

    return p.parse_args()
130
+
131
+
132
if __name__ == "__main__":
    args = parse_args()

    # An explicit --image-path wins; otherwise sample a random test image.
    if args.image_path is not None:
        test_single_image(
            image_path=args.image_path,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
    else:
        test_random_dataset_sample(
            data_root=args.data_root,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
src/inference/test_resnet_pt_svm.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/inference/test_resnet_pt_svm.py
2
+
3
+ import os
4
+ import argparse
5
+ import random
6
+
7
+ from PIL import Image
8
+ from torchvision import datasets
9
+
10
+ from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
11
+
12
+
13
def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    """Run ResNet(PT)+SVM inference on one image file and print the top-k.

    Args:
        image_path: Path to any PIL-readable image.
        ckpt_path: Path to the SVM-head joblib checkpoint.
        labels_path: Path to the class-id -> name JSON mapping.
        device: Torch device string; None lets the model auto-select.
        top_k: Number of ranked predictions to print.
    """
    assert os.path.exists(image_path), f"Image not found: {image_path}"
    # Force RGB so grayscale/RGBA inputs match the 3-channel backbone input.
    img = Image.open(image_path).convert("RGB")

    model = ResNetPTSVMModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"  {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")
37
+
38
+
39
def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    """Pick a random Oxford-IIIT Pet test sample and run SVM-head inference."""
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")

    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,  # return PIL.Image
        download=True,
    )

    model = ResNetPTSVMModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    # Unseeded RNG: a different sample is drawn on every invocation.
    idx = random.randint(0, len(test_ds) - 1)
    img, target = test_ds[idx]
    assert isinstance(img, Image.Image)

    # dataset exposes .categories with human-readable names
    gt_name = test_ds.categories[target]

    print(f"[+] Random sample idx={idx}")
    print(f"    Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)

    print(f"    Predicted class_id : {out['class_id']}")
    print(f"    Predicted class_name: {out['class_name']}")
    print("    Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"      {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")
78
+
79
+
80
def parse_args():
    """Build and parse the CLI options for the ResNet(PT)+SVM smoke test."""
    p = argparse.ArgumentParser(
        description="Test ResNet(PT) + SVM inference on Oxford-IIIT Pet."
    )

    p.add_argument(
        "--ckpt-path",
        default="checkpoints/resnet_pt_svm_head.joblib",
        type=str,
        help="Path to ResNet PT + SVM checkpoint.",
    )
    p.add_argument(
        "--labels-path",
        default="configs/labels.json",
        type=str,
        help="Path to labels mapping JSON.",
    )
    p.add_argument(
        "--data-root",
        default="data/oxford-iiit-pet",
        type=str,
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    p.add_argument(
        "--image-path",
        default=None,
        type=str,
        help="If provided, run on this image instead of random test sample.",
    )
    p.add_argument(
        "--device",
        default=None,
        type=str,
        help="Device to use (e.g. 'cpu', 'cuda'). If None, auto-select.",
    )
    p.add_argument(
        "--top-k",
        default=5,
        type=int,
        help="Number of top classes to print.",
    )

    return p.parse_args()
123
+
124
+
125
if __name__ == "__main__":
    args = parse_args()

    # An explicit --image-path wins; otherwise sample a random test image.
    if args.image_path is not None:
        test_single_image(
            image_path=args.image_path,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
    else:
        test_random_dataset_sample(
            data_root=args.data_root,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
src/registry.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/registry.py
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Callable, Dict, Any, Optional
5
+
6
+
7
@dataclass
class RegisteredModel:
    """Metadata plus a lazy, memoized loader for a single model."""
    id: str
    display_name: str
    loader: Callable[[], Any]
    _instance: Optional[Any] = field(default=None, init=False, repr=False)

    def get(self) -> Any:
        """Return the cached instance, building it via `loader` on first use."""
        instance = self._instance
        if instance is None:
            instance = self.loader()
            self._instance = instance
        return instance
20
+
21
+
22
def _build_registry(device: str = "cpu") -> Dict[str, RegisteredModel]:
    """
    Central place to register all models.

    Args:
        device: torch device string forwarded to the ResNet-backed models.
            The raw sklearn models run on CPU and take no device argument.

    Returns:
        dict: model_id -> RegisteredModel (models themselves load lazily).
    """

    def make_lr_raw():
        from src.inference.lr_model import LRModel
        # Bug fix: LRModel's constructor takes `model_path` (not `ckpt_path`)
        # and has no `device` parameter — the old call raised TypeError as
        # soon as the registry tried to load this model.
        return LRModel(
            model_path="checkpoints/lr_model.joblib",
            labels_path="configs/labels.json",
        )

    def make_svm_raw():
        from src.inference.svm_model import SVMModel
        # Bug fix: SVMModel's constructor has no `device` parameter either.
        return SVMModel(
            ckpt_path="checkpoints/svm_model.joblib",
            labels_path="configs/labels.json",
        )

    def make_resnet_pt_lr():
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
            device=device,
        )

    def make_resnet_pt_svm():
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
            device=device,
        )

    return {
        "lr_raw": RegisteredModel(
            id="lr_raw",
            display_name="LR (raw 64×64 grayscale)",
            loader=make_lr_raw,
        ),
        "svm_raw": RegisteredModel(
            id="svm_raw",
            display_name="SVM (raw 64×64 grayscale)",
            loader=make_svm_raw,
        ),
        "resnet_pt_lr": RegisteredModel(
            id="resnet_pt_lr",
            display_name="ResNet(PT) + LR",
            loader=make_resnet_pt_lr,
        ),
        "resnet_pt_svm": RegisteredModel(
            id="resnet_pt_svm",
            display_name="ResNet(PT) + SVM",
            loader=make_resnet_pt_svm,
        ),
    }
82
+
83
+
84
# Build once at import; models themselves are loaded lazily.
# NOTE(review): called with the default device ("cpu"), so the ResNet-backed
# models never pick CUDA through this module-level registry — confirm intended.
_REGISTRY: Dict[str, RegisteredModel] = _build_registry()


def get_registry() -> Dict[str, RegisteredModel]:
    """Return the full registry (id -> RegisteredModel)."""
    return _REGISTRY


def get_models() -> Dict[str, Any]:
    """
    Eagerly instantiate all models and return id -> model_instance.
    Useful for simple scripts or for initializing everything at UI startup.
    """
    return {mid: entry.get() for mid, entry in _REGISTRY.items()}


def get_model(model_id: str) -> Any:
    """Get a single model instance by id (instantiates on first use).

    Raises:
        KeyError: if model_id is not registered.
    """
    return _REGISTRY[model_id].get()


def get_model_display_names() -> Dict[str, str]:
    """Return mapping id -> human-readable name (for dropdown choices)."""
    return {mid: entry.display_name for mid, entry in _REGISTRY.items()}
src/training/__pycache__/extract_resnet_features.cpython-313.pyc ADDED
Binary file (6.8 kB). View file
 
src/training/__pycache__/train_resnet_pt_lr.cpython-313.pyc ADDED
Binary file (5.58 kB). View file
 
src/training/__pycache__/train_resnet_pt_svm.cpython-313.pyc ADDED
Binary file (5.51 kB). View file
 
src/training/__pycache__/train_svm.cpython-313.pyc ADDED
Binary file (6.6 kB). View file
 
src/training/extract_resnet_features.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/training/extract_resnet_features.py
2
+
3
+ import os
4
+ import argparse
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from torchvision import datasets
10
+ from torchvision.models import resnet18, ResNet18_Weights
11
+
12
+
13
def build_datasets(data_root: str, preprocess):
    """
    Create the Oxford-IIIT Pet train/test splits, each transformed with the
    supplied ResNet preprocessing pipeline.

    Returns:
        (train_ds, test_ds) torchvision datasets with category targets.
    """
    shared_kwargs = dict(
        root=data_root,
        target_types="category",
        transform=preprocess,
        download=True,
    )

    train_ds = datasets.OxfordIIITPet(split="trainval", **shared_kwargs)
    test_ds = datasets.OxfordIIITPet(split="test", **shared_kwargs)

    return train_ds, test_ds
34
+
35
+
36
def build_dataloaders(train_ds, test_ds, batch_size: int = 64, num_workers: int = 2):
    """
    Wrap the two datasets in DataLoaders.

    Shuffling is disabled on purpose: feature extraction should emit
    deterministic arrays aligned with dataset order.
    """
    def _make_loader(ds):
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

    return _make_loader(train_ds), _make_loader(test_ds)
52
+
53
+
54
def build_resnet18_backbone(device: torch.device):
    """
    Build an ImageNet-pretrained ResNet18 feature extractor.

    The final fully-connected layer is swapped for an Identity so the model
    emits the penultimate activations instead of class logits.

    Returns:
        model (nn.Module), feature_dim (int), preprocess (transform)
    """
    import torch.nn as nn

    weights = ResNet18_Weights.DEFAULT
    backbone = resnet18(weights=weights)
    feature_dim = backbone.fc.in_features  # 512 for ResNet18

    # Identity head -> forward() returns penultimate features.
    backbone.fc = nn.Identity()
    backbone.to(device)
    backbone.eval()

    # Official preprocessing for these weights (resize + crop + normalize).
    preprocess = weights.transforms()

    return backbone, feature_dim, preprocess
75
+
76
+
77
def extract_features(model, loader, device: torch.device):
    """
    Push every batch through `model` and gather outputs + labels.

    Returns:
        X: (N, feature_dim) numpy array of features
        y: (N,) numpy array of labels
    """
    feature_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_images, batch_targets in loader:
            batch_out = model(batch_images.to(device))  # (B, feature_dim)
            feature_chunks.append(batch_out.cpu().numpy())
            label_chunks.append(batch_targets.numpy())

    X = np.concatenate(feature_chunks, axis=0)
    y = np.concatenate(label_chunks, axis=0)
    return X, y
97
+
98
+
99
def main(
    data_root: str = "data/oxford-iiit-pet",
    out_dir: str = "data/resnet18_features",
    batch_size: int = 64,
    num_workers: int = 2,
):
    """
    End-to-end feature extraction: load the dataset, run it through a frozen
    pretrained ResNet18 backbone, and persist features/labels as .npy files.
    """
    os.makedirs(out_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[+] Using device: {device}")

    print("[+] Building ResNet18 backbone and preprocessing ...")
    model, feature_dim, preprocess = build_resnet18_backbone(device)
    print(f"[+] Feature dimension: {feature_dim}")

    print(f"[+] Loading Oxford-IIIT Pet from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root, preprocess)

    print("[+] Building dataloaders ...")
    train_loader, test_loader = build_dataloaders(
        train_ds, test_ds, batch_size=batch_size, num_workers=num_workers
    )

    print("[+] Extracting train features ...")
    X_train, y_train = extract_features(model, train_loader, device)
    print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    print("[+] Extracting test features ...")
    X_test, y_test = extract_features(model, test_loader, device)
    print(f" X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    print(f"[+] Saving features to {out_dir} ...")
    # Same filenames and save order as before: train X/y, then test X/y.
    for filename, array in (
        ("X_train_resnet18.npy", X_train),
        ("y_train.npy", y_train),
        ("X_test_resnet18.npy", X_test),
        ("y_test.npy", y_test),
    ):
        np.save(os.path.join(out_dir, filename), array)

    print("[+] Done extracting ResNet18 features.")
143
+
144
+
145
def parse_args():
    """CLI arguments for the feature-extraction script."""
    parser = argparse.ArgumentParser(
        description="Extract ResNet18 (pretrained) features for Oxford-IIIT Pet."
    )
    parser.add_argument("--data-root", type=str, default="data/oxford-iiit-pet",
                        help="Root directory for Oxford-IIIT Pet dataset.")
    parser.add_argument("--out-dir", type=str, default="data/resnet18_features",
                        help="Directory to save .npy feature files.")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="Batch size for feature extraction.")
    parser.add_argument("--num-workers", type=int, default=2,
                        help="Num workers for dataloader.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(
        data_root=args.data_root,
        out_dir=args.out_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
src/training/train_lr.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import torch
9
+ from torchvision import datasets, transforms
10
+
11
+ from sklearn.linear_model import LogisticRegression
12
+ from sklearn.metrics import accuracy_score, classification_report
13
+ import joblib
14
+
15
+
16
def get_datasets(data_root: str, image_size: int = 64) -> Tuple[torch.utils.data.Dataset,
                                                                torch.utils.data.Dataset,
                                                                dict]:
    """
    Load Oxford-IIIT Pet trainval/test splits as small grayscale tensors.

    Returns:
        train_dataset, test_dataset, class_to_idx
    """
    # resize -> grayscale -> float tensor in [0, 1]
    tx = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ])

    shared_kwargs = dict(
        root=data_root,
        target_types="category",
        transform=tx,
        download=True,  # fetched into root/oxford-iiit-pet when absent
    )
    train_dataset = datasets.OxfordIIITPet(split="trainval", **shared_kwargs)
    test_dataset = datasets.OxfordIIITPet(split="test", **shared_kwargs)

    # torchvision exposes the class-name -> index mapping on the dataset.
    return train_dataset, test_dataset, train_dataset.class_to_idx
53
+
54
+
55
def dataset_to_numpy(dataset: torch.utils.data.Dataset) -> Tuple[np.ndarray, np.ndarray]:
    """
    Flatten a torchvision dataset of tensor images into scikit-learn arrays.

    X: (N, D) float32 flattened grayscale pixels
    y: (N,) int64 labels
    """
    flat_images = []
    labels = []

    for image, label in tqdm(dataset, desc="Converting to numpy"):
        # image: torch.Tensor of shape (1, H, W); flatten to a 1-D vector.
        flat_images.append(image.numpy().reshape(-1))
        labels.append(label)

    X = np.stack(flat_images, axis=0).astype(np.float32)
    y = np.array(labels, dtype=np.int64)

    return X, y
77
+
78
+
79
def save_labels(class_to_idx: dict, labels_path: str):
    """
    Persist the label mapping as id -> class_name JSON for inference/UI.

    Args:
        class_to_idx: mapping class_name -> integer index (torchvision style).
        labels_path: destination JSON path; parent dirs are created as needed.

    Note: json.dump stringifies integer keys, so the file maps "0" -> name.
    """
    # Invert mapping: idx -> class_name
    idx_to_class = {idx: cls_name for cls_name, idx in class_to_idx.items()}

    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one (fixes bare-filename paths).
    parent = os.path.dirname(labels_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(labels_path, "w") as f:
        json.dump(idx_to_class, f, indent=2)
    print(f"[INFO] Saved labels to {labels_path}")
90
+
91
+
92
def train_logistic_regression(X_train: np.ndarray, y_train: np.ndarray) -> LogisticRegression:
    """
    Train a multinomial Logistic Regression classifier.

    The 'saga' solver supports the multinomial loss and copes acceptably
    with high-dimensional inputs. The former `multi_class="multinomial"`
    argument is deliberately omitted: it is deprecated in scikit-learn 1.5
    and removed in 1.7, and with a multiclass target the saga solver trains
    a multinomial model by default anyway.

    Args:
        X_train: (N, D) feature matrix.
        y_train: (N,) integer labels.

    Returns:
        The fitted LogisticRegression model.
    """
    num_classes = len(np.unique(y_train))
    print(f"[INFO] Training Logistic Regression on {X_train.shape[0]} samples, "
          f"{X_train.shape[1]} features, {num_classes} classes")

    clf = LogisticRegression(
        penalty="l2",
        C=1.0,
        solver="saga",
        max_iter=1000,
        n_jobs=-1,
        verbose=1,
    )
    clf.fit(X_train, y_train)
    return clf
114
+
115
+
116
def evaluate_model(clf: LogisticRegression, X: np.ndarray, y: np.ndarray, split_name: str):
    """
    Report accuracy plus a per-class classification report for one split.
    """
    predictions = clf.predict(X)
    accuracy = accuracy_score(y, predictions)
    print(f"\n[{split_name}] Accuracy: {accuracy * 100:.2f}%")
    print(f"[{split_name}] Classification report (macro avg at bottom):")
    print(classification_report(y, predictions, digits=3))
125
+
126
+
127
def main():
    """Train and evaluate the raw-pixel Logistic Regression baseline."""
    # -------- configs (tweak paths as needed) --------
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    data_root = os.path.join(project_root, "data")
    checkpoints_dir = os.path.join(project_root, "checkpoints")
    configs_dir = os.path.join(project_root, "configs")

    for directory in (checkpoints_dir, configs_dir):
        os.makedirs(directory, exist_ok=True)

    labels_path = os.path.join(configs_dir, "labels.json")
    model_path = os.path.join(checkpoints_dir, "lr_model.joblib")

    image_size = 64  # 64x64 grayscale baseline
    # ------------------------------------------------

    print("[INFO] Loading datasets...")
    train_dataset, test_dataset, class_to_idx = get_datasets(data_root, image_size=image_size)

    print(f"[INFO] Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
    print(f"[INFO] Number of classes: {len(class_to_idx)}")

    print("[INFO] Converting train split to numpy...")
    X_train, y_train = dataset_to_numpy(train_dataset)

    print("[INFO] Converting test split to numpy...")
    X_test, y_test = dataset_to_numpy(test_dataset)

    # Persist the label mapping for later inference.
    save_labels(class_to_idx, labels_path)

    # Fit the baseline classifier on raw pixels.
    clf = train_logistic_regression(X_train, y_train)

    # Report accuracy on both splits.
    evaluate_model(clf, X_train, y_train, split_name="Train")
    evaluate_model(clf, X_test, y_test, split_name="Test")

    # Persist the fitted model.
    joblib.dump(clf, model_path)
    print(f"[INFO] Saved Logistic Regression model to {model_path}")


if __name__ == "__main__":
    main()
src/training/train_resnet_pt_lr.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/training/train_resnet_pt_lr.py
2
+
3
+ import os
4
+ import argparse
5
+ import json
6
+
7
+ import numpy as np
8
+ from sklearn.linear_model import LogisticRegression
9
+ from sklearn.metrics import accuracy_score
10
+ import joblib
11
+
12
+
13
def load_features(features_dir: str):
    """
    Load pre-extracted ResNet18 feature arrays from `features_dir`.

    Expects the four files written by extract_resnet_features.py:
    X_train_resnet18.npy, y_train.npy, X_test_resnet18.npy, y_test.npy.

    Raises:
        FileNotFoundError: if any file is missing. (A real exception replaces
        the previous `assert` checks, which are silently stripped under -O.)

    Returns:
        X_train, y_train, X_test, y_test numpy arrays.
    """
    x_train_path = os.path.join(features_dir, "X_train_resnet18.npy")
    y_train_path = os.path.join(features_dir, "y_train.npy")
    x_test_path = os.path.join(features_dir, "X_test_resnet18.npy")
    y_test_path = os.path.join(features_dir, "y_test.npy")

    for path in (x_train_path, y_train_path, x_test_path, y_test_path):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing: {path}")

    X_train = np.load(x_train_path)
    y_train = np.load(y_train_path)
    X_test = np.load(x_test_path)
    y_test = np.load(y_test_path)

    return X_train, y_train, X_test, y_test
30
+
31
+
32
def main(
    features_dir: str = "data/resnet18_features",
    ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
    labels_path: str = "configs/labels.json",
):
    """
    Train a Logistic Regression head on frozen ResNet18 features.

    The fitted head is saved together with metadata (backbone name, feature
    dim, labels path, train/test accuracy) as a joblib payload at ckpt_path.
    """
    # Guard against bare filenames: os.makedirs("") raises.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    print(f"[+] Loading features from {features_dir} ...")
    X_train, y_train, X_test, y_test = load_features(features_dir)

    print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f" X_test shape : {X_test.shape}, y_test shape : {y_test.shape}")

    num_features = X_train.shape[1]
    print(f"[+] Feature dimension: {num_features}")

    # The labels mapping is not needed for training; load it (when present)
    # only for logging, and record its path in the payload for inference.
    if os.path.exists(labels_path):
        with open(labels_path, "r") as f:
            labels = json.load(f)
        num_classes = len(labels)
        print(f"[+] Loaded labels from {labels_path}, num_classes={num_classes}")
    else:
        print(f"[!] Warning: {labels_path} not found. Inference will need this later.")
        labels = None

    print("[+] Training Logistic Regression on ResNet18 features ...")
    # NOTE: `multi_class` is deprecated in scikit-learn 1.5 and removed in
    # 1.7; saga + a multiclass target already trains a multinomial model.
    clf = LogisticRegression(
        penalty="l2",
        C=1.0,
        solver="saga",
        max_iter=1000,
        n_jobs=-1,
        verbose=1,
    )
    clf.fit(X_train, y_train)

    print("[+] Evaluating ...")
    y_pred_train = clf.predict(X_train)
    y_pred_test = clf.predict(X_test)

    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)

    print(f" Train accuracy: {train_acc:.4f}")
    print(f" Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving LR head to {ckpt_path} ...")
    payload = {
        "model": clf,
        "backbone": "resnet18_imagenet",
        "feature_dim": int(num_features),
        "labels_path": labels_path,
        "train_acc": float(train_acc),
        "test_acc": float(test_acc),
    }
    joblib.dump(payload, ckpt_path)

    print("[+] Done training ResNet PT + LR.")
94
+
95
+
96
def parse_args():
    """CLI arguments for training the LR head on ResNet18 features."""
    parser = argparse.ArgumentParser(
        description="Train Logistic Regression head on ResNet18 (pretrained) features."
    )
    parser.add_argument("--features-dir", type=str, default="data/resnet18_features",
                        help="Directory containing X_train_resnet18.npy etc.")
    parser.add_argument("--ckpt-path", type=str, default="checkpoints/resnet_pt_lr_head.joblib",
                        help="Where to save LR head checkpoint.")
    parser.add_argument("--labels-path", type=str, default="configs/labels.json",
                        help="Path to labels mapping JSON.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(
        features_dir=args.features_dir,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )
src/training/train_resnet_pt_svm.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/training/train_resnet_pt_svm.py
2
+
3
+ import os
4
+ import argparse
5
+ import json
6
+
7
+ import numpy as np
8
+ from sklearn.svm import LinearSVC
9
+ from sklearn.metrics import accuracy_score
10
+ import joblib
11
+
12
+
13
def load_features(features_dir: str):
    """
    Load pre-extracted ResNet18 feature arrays from `features_dir`.

    Expects the four files written by extract_resnet_features.py:
    X_train_resnet18.npy, y_train.npy, X_test_resnet18.npy, y_test.npy.

    Raises:
        FileNotFoundError: if any file is missing. (A real exception replaces
        the previous `assert` checks, which are silently stripped under -O.)

    Returns:
        X_train, y_train, X_test, y_test numpy arrays.
    """
    x_train_path = os.path.join(features_dir, "X_train_resnet18.npy")
    y_train_path = os.path.join(features_dir, "y_train.npy")
    x_test_path = os.path.join(features_dir, "X_test_resnet18.npy")
    y_test_path = os.path.join(features_dir, "y_test.npy")

    for path in (x_train_path, y_train_path, x_test_path, y_test_path):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing: {path}")

    X_train = np.load(x_train_path)
    y_train = np.load(y_train_path)
    X_test = np.load(x_test_path)
    y_test = np.load(y_test_path)

    return X_train, y_train, X_test, y_test
30
+
31
+
32
def main(
    features_dir: str = "data/resnet18_features",
    ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
    labels_path: str = "configs/labels.json",
):
    """
    Train a Linear SVM head on frozen ResNet18 features and checkpoint it
    (model + metadata) as a joblib payload at ckpt_path.
    """
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    print(f"[+] Loading ResNet18 features from {features_dir} ...")
    X_train, y_train, X_test, y_test = load_features(features_dir)

    print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f" X_test shape : {X_test.shape}, y_test shape : {y_test.shape}")

    num_features = X_train.shape[1]
    print(f"[+] Feature dimension: {num_features}")

    # Labels are only needed for logging here; inference reads them later.
    if os.path.exists(labels_path):
        with open(labels_path, "r") as f:
            labels = json.load(f)
        num_classes = len(labels)
        print(f"[+] Loaded labels from {labels_path}, num_classes={num_classes}")
    else:
        print(f"[!] Warning: {labels_path} not found. Inference will need this later.")
        labels = None

    print("[+] Training Linear SVM on ResNet18 features ...")
    svm = LinearSVC(
        C=1.0,
        penalty="l2",
        loss="squared_hinge",
        max_iter=5000,  # give it some room
    )
    svm.fit(X_train, y_train)

    print("[+] Evaluating ...")
    train_acc = accuracy_score(y_train, svm.predict(X_train))
    test_acc = accuracy_score(y_test, svm.predict(X_test))

    print(f" Train accuracy: {train_acc:.4f}")
    print(f" Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving ResNet PT + SVM head to {ckpt_path} ...")
    joblib.dump(
        {
            "model": svm,
            "backbone": "resnet18_imagenet",
            "feature_dim": int(num_features),
            "labels_path": labels_path,
            "train_acc": float(train_acc),
            "test_acc": float(test_acc),
        },
        ckpt_path,
    )

    print("[+] Done training ResNet PT + SVM.")
90
+
91
+
92
def parse_args():
    """CLI arguments for training the SVM head on ResNet18 features."""
    parser = argparse.ArgumentParser(
        description="Train Linear SVM head on ResNet18 (pretrained) features."
    )
    parser.add_argument("--features-dir", type=str, default="data/resnet18_features",
                        help="Directory containing X_train_resnet18.npy etc.")
    parser.add_argument("--ckpt-path", type=str, default="checkpoints/resnet_pt_svm_head.joblib",
                        help="Where to save SVM head checkpoint.")
    parser.add_argument("--labels-path", type=str, default="configs/labels.json",
                        help="Path to labels mapping JSON.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(
        features_dir=args.features_dir,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )
src/training/train_svm.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/training/train_svm.py
2
+
3
+ import os
4
+ import json
5
+ import argparse
6
+
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from torchvision import transforms, datasets
10
+
11
+ import numpy as np
12
+ from sklearn.svm import LinearSVC
13
+ from sklearn.metrics import accuracy_score
14
+ import joblib
15
+
16
+
17
def get_transforms():
    """Preprocessing for the raw-pixel baseline: 64x64 grayscale in [0, 1]."""
    steps = [
        transforms.Resize((64, 64)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # (1, 64, 64) float tensor in [0, 1]
    ]
    return transforms.Compose(steps)
23
+
24
+
25
def build_datasets(data_root: str):
    """Return (trainval, test) Oxford-IIIT Pet datasets with baseline transforms."""
    tx = get_transforms()

    def _split(split_name):
        return datasets.OxfordIIITPet(
            root=data_root,
            split=split_name,
            target_types="category",
            transform=tx,
            download=True,
        )

    return _split("trainval"), _split("test")
45
+
46
+
47
def dataset_to_numpy(dataset, batch_size: int = 64, num_workers: int = 2):
    """
    Convert a torchvision dataset to (X, y) numpy arrays.

    Images are flattened per sample, e.g. a (1, 64, 64) grayscale tensor
    becomes a 4096-d row vector.

    Args:
        dataset: map-style dataset yielding (image_tensor, int_label).
        batch_size: DataLoader batch size (previously hard-coded to 64).
        num_workers: DataLoader worker count (previously hard-coded to 2).

    Returns:
        X: (N, D) numpy array of flattened pixels.
        y: (N,) numpy array of integer labels.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    xs = []
    ys = []
    for images, targets in loader:
        # images: (B, 1, H, W) -> (B, H*W)
        b = images.shape[0]
        images = images.view(b, -1)
        xs.append(images.numpy())
        ys.append(targets.numpy())

    X = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    return X, y
67
+
68
+
69
def ensure_labels_json(train_ds, labels_path: str):
    """
    Create labels.json (class id -> class name) if it does not exist yet.

    Keys are normalized to strings in BOTH branches (matching what
    json.dump/json.load produce), so callers always see {"0": name, ...}
    regardless of whether the file was just created or reloaded. The
    previous version returned int keys on creation but str keys on reload.

    Args:
        train_ds: dataset exposing a `.categories` list (OxfordIIITPet).
        labels_path: where the JSON mapping lives.

    Returns:
        dict mapping str(class_id) -> class name.
    """
    # os.makedirs("") raises for bare filenames; only create a real parent.
    parent = os.path.dirname(labels_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if os.path.exists(labels_path):
        with open(labels_path, "r") as f:
            return json.load(f)

    # OxfordIIITPet: category targets are indices into .categories.
    id_to_name = {str(i): name for i, name in enumerate(train_ds.categories)}

    with open(labels_path, "w") as f:
        json.dump(id_to_name, f, indent=2)

    return id_to_name
85
+
86
+
87
def train_svm(
    data_root: str = "data/oxford-iiit-pet",
    ckpt_path: str = "checkpoints/svm_model.joblib",
    labels_path: str = "configs/labels.json",
):
    """
    Train the raw-pixel Linear SVM baseline end to end and checkpoint it.
    """
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    print(f"[+] Loading datasets from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root)

    print("[+] Building labels.json (if missing) ...")
    labels = ensure_labels_json(train_ds, labels_path)
    num_classes = len(labels)
    print(f"[+] Num classes (from labels.json): {num_classes}")

    print("[+] Converting train dataset to numpy features ...")
    X_train, y_train = dataset_to_numpy(train_ds)
    print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    print("[+] Converting test dataset to numpy features ...")
    X_test, y_test = dataset_to_numpy(test_ds)
    print(f" X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    print("[+] Training Linear SVM on raw pixels ...")
    # dual=True (the default) is appropriate since n_samples > n_features.
    svm = LinearSVC(
        C=1.0,
        penalty="l2",
        loss="squared_hinge",
        max_iter=2000,
    )
    svm.fit(X_train, y_train)

    print("[+] Evaluating on train and test ...")
    y_pred_train = svm.predict(X_train)
    y_pred_test = svm.predict(X_test)

    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)

    print(f" Train accuracy: {train_acc:.4f}")
    print(f" Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving SVM model to {ckpt_path} ...")
    payload = {
        "model": svm,
        "labels_path": labels_path,
        "train_acc": float(train_acc),
        "test_acc": float(test_acc),
    }
    joblib.dump(payload, ckpt_path)

    print("[+] Done.")
144
+
145
+
146
def parse_args():
    """CLI arguments for the raw-pixel SVM baseline."""
    parser = argparse.ArgumentParser(description="Train Linear SVM on raw pixel features.")
    parser.add_argument("--data-root", type=str, default="data/oxford-iiit-pet",
                        help="Root directory for Oxford-IIIT Pet dataset.")
    parser.add_argument("--ckpt-path", type=str, default="checkpoints/svm_model.joblib",
                        help="Where to save the trained SVM model.")
    parser.add_argument("--labels-path", type=str, default="configs/labels.json",
                        help="Path to labels mapping JSON.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train_svm(
        data_root=args.data_root,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )