Upload 37 files
- app.py +152 -0
- checkpoints/lr_model.joblib +3 -0
- checkpoints/resnet_pt_lr_head.joblib +3 -0
- checkpoints/resnet_pt_svm_head.joblib +3 -0
- checkpoints/svm_model.joblib +3 -0
- configs/labels.json +39 -0
- requirements.txt +7 -0
- src/__pycache__/registry.cpython-313.pyc +0 -0
- src/evaluation/__pycache__/eval_accuracy.cpython-313.pyc +0 -0
- src/evaluation/__pycache__/eval_confusion.cpython-313.pyc +0 -0
- src/evaluation/__pycache__/eval_tsne_umap.cpython-313.pyc +0 -0
- src/evaluation/eval_accuracy.py +184 -0
- src/evaluation/eval_confusion.py +206 -0
- src/evaluation/eval_tsne_umap.py +283 -0
- src/inference/__pycache__/lr_model.cpython-313.pyc +0 -0
- src/inference/__pycache__/resnet_pt_lr_model.cpython-313.pyc +0 -0
- src/inference/__pycache__/resnet_pt_svm_model.cpython-313.pyc +0 -0
- src/inference/__pycache__/svm_model.cpython-313.pyc +0 -0
- src/inference/__pycache__/test_resnet_pt_lr.cpython-313.pyc +0 -0
- src/inference/__pycache__/test_resnet_pt_svm.cpython-313.pyc +0 -0
- src/inference/base_model.py +30 -0
- src/inference/lr_model.py +63 -0
- src/inference/resnet_pt_lr_model.py +179 -0
- src/inference/resnet_pt_svm_model.py +174 -0
- src/inference/svm_model.py +115 -0
- src/inference/test_resnet_pt_lr.py +150 -0
- src/inference/test_resnet_pt_svm.py +143 -0
- src/registry.py +108 -0
- src/training/__pycache__/extract_resnet_features.cpython-313.pyc +0 -0
- src/training/__pycache__/train_resnet_pt_lr.cpython-313.pyc +0 -0
- src/training/__pycache__/train_resnet_pt_svm.cpython-313.pyc +0 -0
- src/training/__pycache__/train_svm.cpython-313.pyc +0 -0
- src/training/extract_resnet_features.py +183 -0
- src/training/train_lr.py +171 -0
- src/training/train_resnet_pt_lr.py +128 -0
- src/training/train_resnet_pt_svm.py +124 -0
- src/training/train_svm.py +177 -0
app.py
ADDED
@@ -0,0 +1,152 @@
# ui/app.py

import gradio as gr
from typing import Any, Dict, List

from src.registry import get_model_display_names, get_model

APP_TITLE = "PetRecog – Oxford-IIIT Pet Identification"
APP_DESC = (
    "Upload a pet image, choose a model, and compare predictions across "
    "classical (LR, SVM) and deep-feature (ResNet) models."
)
TOP_K_DEFAULT = 5


def format_topk_for_table(top_k: List[Dict[str, Any]]) -> List[List[Any]]:
    """
    Convert the model's top_k list of dicts into a 2D list suitable for gr.Dataframe.

    Expected each entry in top_k to look like:
        {'class_id': int, 'class_name': str, 'probability': float}
    """
    rows = []
    for rank, entry in enumerate(top_k, start=1):
        class_name = entry.get("class_name", f"id={entry.get('class_id', '?')}")
        prob = entry.get("probability", 0.0)
        rows.append([rank, class_name, round(float(prob) * 100.0, 2)])
    return rows


def run_inference(model_id: str, image) -> Dict[str, Any]:
    """
    Wrapper called by Gradio.

    Inputs:
    - model_id: key from the registry
    - image: PIL image object from gr.Image (type='pil')

    Outputs (as a dict mapped to Gradio components in the UI):
    - main_text: formatted prediction string
    - topk_table: 2D list for gr.Dataframe
    """
    if image is None:
        return {
            "main_text": "⚠️ Please upload an image first.",
            "topk_table": [],
        }

    # Get the model instance (lazy-loaded via registry)
    model = get_model(model_id)

    # All models follow the shared predict API:
    #   predict(PIL.Image, top_k=TOP_K_DEFAULT) -> {
    #       'class_id', 'class_name', 'probabilities', 'top_k'
    #   }
    result = model.predict(image, top_k=TOP_K_DEFAULT)

    class_name = result.get("class_name", "Unknown")
    class_id = result.get("class_id", "N/A")
    top_k = result.get("top_k", [])

    main_text = f"**Predicted Class:** {class_name} \n" f"**Class ID:** {class_id}"

    table = format_topk_for_table(top_k)

    return {
        "main_text": main_text,
        "topk_table": table,
    }


def build_demo() -> gr.Blocks:
    model_display_names = get_model_display_names()
    # Gradio dropdown will show pretty display_name, but we need to map back to ids.
    id_to_name = model_display_names
    name_to_id = {v: k for k, v in id_to_name.items()}

    default_display_name = next(iter(name_to_id.keys())) if name_to_id else None

    with gr.Blocks(css="""
    body { background: #fbead8; }
    .noble-header { text-align: center; margin-bottom: 1.0rem; }
    .noble-title { font-size: 2.0rem; font-weight: 800; color: #5b3b27; }
    .noble-subtitle { font-size: 0.95rem; color: #7a5b45; }
    """) as demo:
        # Header
        with gr.Row(elem_classes="noble-header"):
            gr.Markdown(
                f"### {APP_TITLE}\n{APP_DESC}",
                elem_classes="noble-title"
            )

        with gr.Row():
            # Left column: controls
            with gr.Column(scale=1):
                gr.Markdown("#### 1️⃣ Select Model & Upload Image")

                model_dropdown = gr.Dropdown(
                    choices=list(name_to_id.keys()),
                    value=default_display_name,
                    label="Select Model",
                )

                image_input = gr.Image(
                    type="pil",
                    label="Upload your pet image (JPEG/PNG)",
                )

                run_button = gr.Button("Run Identification")

            # Right column: output
            with gr.Column(scale=1):
                gr.Markdown("#### 2️⃣ Model Prediction")

                main_output = gr.Markdown(
                    value="Prediction will appear here.",
                    label="Prediction",
                )

                topk_output = gr.Dataframe(
                    headers=["Rank", "Class Name", "Probability (%)"],
                    datatype=["number", "str", "number"],
                    col_count=(3, "fixed"),
                    label=f"Top-{TOP_K_DEFAULT} Predictions",
                )

        # Wiring: button click -> inference
        def _gradio_infer(selected_display_name, img):
            if selected_display_name is None:
                return {
                    main_output: "⚠️ Please select a model.",
                    topk_output: [],
                }
            model_id = name_to_id[selected_display_name]
            result = run_inference(model_id, img)
            return {
                main_output: result["main_text"],
                topk_output: result["topk_table"],
            }

        run_button.click(
            fn=_gradio_infer,
            inputs=[model_dropdown, image_input],
            outputs=[main_output, topk_output],
        )

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
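A quick way to sanity-check the wiring above without launching the Gradio server is to call run_inference directly. The sketch below is illustrative only: it assumes the registry in src/registry.py exposes the same model ids used in src/evaluation/eval_accuracy.py (e.g. "resnet_pt_lr"), and it feeds a blank image instead of a real pet photo.

# smoke_test_app.py -- hypothetical helper, not part of this commit
from PIL import Image
from app import run_inference  # run from the repository root

# A blank RGB image stands in for an uploaded pet photo.
dummy = Image.new("RGB", (224, 224), color="white")

# "resnet_pt_lr" is assumed to be a valid registry id (see eval_accuracy.py).
out = run_inference("resnet_pt_lr", dummy)
print(out["main_text"])
for rank, name, prob in out["topk_table"]:
    print(rank, name, prob)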
checkpoints/lr_model.joblib
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4aa8d4f7586fbbaf893e0ff269fdde75123a12eb66ee4175beb0e4cc26d5e8a
size 607515
checkpoints/resnet_pt_lr_head.joblib
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a91416a9e7522ac0626b96f0ba1903482ddd8fe3181accdcb8833a3494c55d94
size 77209
checkpoints/resnet_pt_svm_head.joblib
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9f950c784e7196d427c1165cfc18a3520c3212d2a52d5d8b96006591f80da6c
size 153001
checkpoints/svm_model.joblib
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f6e50cb4316ab0d9e6634d150c18fdb1302d25d88067c1d52c7a5deffe17ec6
size 1213818
configs/labels.json
ADDED
@@ -0,0 +1,39 @@
{
  "0": "Abyssinian",
  "1": "American Bulldog",
  "2": "American Pit Bull Terrier",
  "3": "Basset Hound",
  "4": "Beagle",
  "5": "Bengal",
  "6": "Birman",
  "7": "Bombay",
  "8": "Boxer",
  "9": "British Shorthair",
  "10": "Chihuahua",
  "11": "Egyptian Mau",
  "12": "English Cocker Spaniel",
  "13": "English Setter",
  "14": "German Shorthaired",
  "15": "Great Pyrenees",
  "16": "Havanese",
  "17": "Japanese Chin",
  "18": "Keeshond",
  "19": "Leonberger",
  "20": "Maine Coon",
  "21": "Miniature Pinscher",
  "22": "Newfoundland",
  "23": "Persian",
  "24": "Pomeranian",
  "25": "Pug",
  "26": "Ragdoll",
  "27": "Russian Blue",
  "28": "Saint Bernard",
  "29": "Samoyed",
  "30": "Scottish Terrier",
  "31": "Shiba Inu",
  "32": "Siamese",
  "33": "Sphynx",
  "34": "Staffordshire Bull Terrier",
  "35": "Wheaten Terrier",
  "36": "Yorkshire Terrier"
}
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio>=4.0
torch>=2.0
torchvision>=0.15
numpy
scikit-learn
joblib
Pillow
src/__pycache__/registry.cpython-313.pyc
ADDED
Binary file (4.49 kB)
src/evaluation/__pycache__/eval_accuracy.cpython-313.pyc
ADDED
Binary file (6.02 kB)
src/evaluation/__pycache__/eval_confusion.cpython-313.pyc
ADDED
Binary file (8.15 kB)
src/evaluation/__pycache__/eval_tsne_umap.cpython-313.pyc
ADDED
Binary file (11.8 kB)
src/evaluation/eval_accuracy.py
ADDED
@@ -0,0 +1,184 @@
# src/evaluation/eval_accuracy.py

import argparse
from collections import defaultdict

import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report

from torchvision.datasets import OxfordIIITPet

from src.registry import get_model
import torch


def load_test_dataset(data_root: str):
    """
    Load Oxford-IIIT Pet test split without transforms, so we get PIL images.
    Targets will be integer class indices (0..36).
    """
    dataset = OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,  # we want raw PIL here
    )
    return dataset

def load_model_direct(model_id: str):
    """
    Workaround loader that bypasses the registry and constructs models
    using their actual existing constructor signatures.
    Modify only the paths here if needed.
    """
    if model_id == "lr_raw":
        from src.inference.lr_model import LRModel
        # Adjust to match your actual LRModel __init__
        return LRModel("checkpoints/lr_model.joblib", "configs/labels.json")

    elif model_id == "svm_raw":
        from src.inference.svm_model import SVMModel
        return SVMModel("checkpoints/svm_model.joblib", "configs/labels.json")

    elif model_id == "resnet_pt_lr":
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        # Whether these require a device or not, match your working constructor
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
        )

    elif model_id == "resnet_pt_svm":
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
        )

    else:
        raise ValueError(f"Unsupported model_id: {model_id}")

def evaluate_model_on_dataset(model_id: str, data_root: str):
    """
    Evaluate a single model (by id from registry) on the Oxford-IIIT Pet test split.
    Uses the model.predict(PIL.Image, top_k=5) API.

    Returns a dict with:
    - top1_acc
    - top5_acc
    - report_dict (per-class and aggregate metrics)
    """
    print(f"\n=== Evaluating model: {model_id} ===")

    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred_top1 = []
    top5_correct = 0

    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int

        # Try to call with top_k; if the model doesn't support it, fall back gracefully
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            # Older / simpler API: predict(img) without top_k
            result = model.predict(img)

        # Top-1 prediction (must exist)
        pred_id = int(result.get("class_id"))
        y_true.append(int(target))
        y_pred_top1.append(pred_id)

        # Try to get the top_k list; if not present, create a synthetic one using only top-1
        top_k = result.get("top_k")
        if not top_k:
            # Fallback: just treat the top-1 prediction as the only candidate.
            # This means Top-5 == Top-1 for such models, which is acceptable as a workaround.
            cname = result.get("class_name", "")
            top_k = [{
                "class_id": pred_id,
                "class_name": cname,
                "probability": 1.0
            }]

        # Top-5 correct? (GT in top_k list)
        if any(int(entry.get("class_id")) == int(target) for entry in top_k):
            top5_correct += 1


    y_true = np.array(y_true)
    y_pred_top1 = np.array(y_pred_top1)
    n = len(y_true)

    # Overall Top-1 accuracy
    top1_acc = accuracy_score(y_true, y_pred_top1)

    # Overall Top-5 accuracy
    top5_acc = top5_correct / float(n)

    # Detailed precision/recall/F1 per class + aggregate
    report = classification_report(
        y_true,
        y_pred_top1,
        digits=4,
        output_dict=True  # gives a nice dict we can log/inspect
    )

    print(f"Top-1 accuracy ({model_id}): {top1_acc:.4f}")
    print(f"Top-5 accuracy ({model_id}): {top5_acc:.4f}")
    print("\nMacro avg (from classification_report):")
    print(report["macro avg"])
    print("\nWeighted avg (from classification_report):")
    print(report["weighted avg"])

    return {
        "model_id": model_id,
        "top1_acc": top1_acc,
        "top5_acc": top5_acc,
        "report": report,
    }


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    args = parser.parse_args()

    # List all models you want to evaluate
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    all_results = []

    for mid in model_ids:
        res = evaluate_model_on_dataset(mid, args.data_root)
        all_results.append(res)

    # Print a compact summary table at the end
    print("\n===== Summary (Top-1 & Top-5) =====")
    print(f"{'Model':25s} {'Top-1':>8s} {'Top-5':>8s}")
    print("-" * 50)
    for res in all_results:
        name = res["model_id"]
        t1 = res["top1_acc"]
        t5 = res["top5_acc"]
        print(f"{name:25s} {t1:8.4f} {t5:8.4f}")


if __name__ == "__main__":
    # Make sure torch doesn't spawn too many threads on some systems
    torch.set_num_threads(4)
    main()
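Usage note (an assumption, not stated in the commit): with the Oxford-IIIT Pet data already downloaded under data/oxford-iiit-pet and the command run from the repository root so the src package is importable, the script above can be invoked as python -m src.evaluation.eval_accuracy --data-root data/oxford-iiit-pet; it prints per-model Top-1/Top-5 accuracy and a summary table. It also needs tqdm, which is not pinned in requirements.txt.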
src/evaluation/eval_confusion.py
ADDED
@@ -0,0 +1,206 @@
# src/evaluation/eval_confusion.py

import argparse
from pathlib import Path
import json

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

from tqdm import tqdm

# Reuse the same dataset + model loading logic as eval_accuracy.py
from src.evaluation.eval_accuracy import load_test_dataset, load_model_direct


def load_class_names(labels_path: str = "configs/labels.json"):
    """
    Try to load class names from labels.json.

    This is written to be robust to a few likely formats:
    - List: ["Abyssinian", "American Bulldog", ...]
    - Dict with string keys: {"0": "Abyssinian", "1": "American Bulldog", ...}
    - Dict with 'id_to_label': {"id_to_label": {"0": "Abyssinian", ...}}

    If anything goes wrong, returns None and we’ll just use numeric class IDs on the axes.
    """
    try:
        with open(labels_path, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"[WARN] labels file not found at {labels_path}, using numeric IDs.")
        return None
    except json.JSONDecodeError:
        print(f"[WARN] Could not parse {labels_path}, using numeric IDs.")
        return None

    # Case 1: simple list
    if isinstance(data, list):
        return data

    # Case 2: dict with 'id_to_label'
    if isinstance(data, dict) and "id_to_label" in data:
        id_to_label = data["id_to_label"]
        # sort by integer key
        keys = sorted(id_to_label.keys(), key=lambda k: int(k))
        return [id_to_label[k] for k in keys]

    # Case 3: dict mapping "0" -> "Abyssinian"
    if isinstance(data, dict):
        try:
            keys = sorted(data.keys(), key=lambda k: int(k))
            return [data[k] for k in keys]
        except Exception:
            pass

    print("[WARN] Unrecognized labels.json format, using numeric IDs.")
    return None


def collect_predictions(model_id: str, data_root: str):
    """
    Run the given model across the Oxford-IIIT Pet test split and collect:
    - y_true: ground-truth integer class indices
    - y_pred: top-1 predicted class indices

    Uses the same model API as eval_accuracy.py: model.predict(PIL, top_k=5)
    """
    print(f"\n=== Collecting predictions for model: {model_id} ===")

    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred = []

    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int

        # Same predict logic as eval_accuracy (support with/without top_k)
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            result = model.predict(img)

        pred_id = int(result.get("class_id"))
        y_true.append(int(target))
        y_pred.append(pred_id)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    print(f"  Collected {len(y_true)} predictions.")
    return y_true, y_pred


def plot_confusion_matrix(
    cm: np.ndarray,
    class_names,
    title: str,
    save_path: Path,
    normalize: bool = True,
):
    """
    Plot and save a confusion matrix.

    If normalize=True, each row (true class) is normalized to sum to 1.
    If class_names is None, we just use numeric indices on the axes.
    """
    if normalize:
        cm = cm.astype("float")
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    num_classes = cm.shape[0]

    plt.figure(figsize=(12, 10))
    im = plt.imshow(cm, interpolation="nearest", cmap="viridis")
    plt.title(title)
    plt.colorbar(im, fraction=0.046, pad=0.04)

    if class_names is not None and len(class_names) == num_classes:
        tick_labels = class_names
    else:
        tick_labels = list(range(num_classes))

    plt.xticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        rotation=90,
        fontsize=6,
    )
    plt.yticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        fontsize=6,
    )

    plt.xlabel("Predicted class")
    plt.ylabel("True class")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"  Saved confusion matrix plot to: {save_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels.json (for axis names).",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/confusion_matrices",
        help="Directory to save confusion matrices and plots.",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Same set of models as eval_accuracy
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    class_names = load_class_names(args.labels_path)

    # y_true is identical for all models (same test split, same indexing),
    # but for clarity we recompute per model; confusion_matrix only needs
    # consistent labels (0..36) which we enforce below.
    for model_id in model_ids:
        y_true, y_pred = collect_predictions(model_id, args.data_root)

        # Define a fixed label ordering (0..max) to get 37x37
        num_classes = int(y_true.max()) + 1
        labels = list(range(num_classes))

        cm = confusion_matrix(y_true, y_pred, labels=labels)

        # Save raw matrix for future analysis
        npy_path = out_dir / f"cm_{model_id}.npy"
        np.save(npy_path, cm)
        print(f"  Saved raw confusion matrix to: {npy_path}")

        # Save a normalized plot
        png_path = out_dir / f"cm_{model_id}.png"
        title = f"Confusion Matrix ({model_id})"
        plot_confusion_matrix(cm, class_names, title, png_path, normalize=True)


if __name__ == "__main__":
    main()
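Usage note (assumed invocation, mirroring the argparse defaults above): python -m src.evaluation.eval_confusion --data-root data/oxford-iiit-pet --out-dir outputs/confusion_matrices writes one cm_<model_id>.npy and one normalized cm_<model_id>.png per model; matplotlib must be available in the environment, as it is not listed in requirements.txt.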
src/evaluation/eval_tsne_umap.py
ADDED
@@ -0,0 +1,283 @@
# src/evaluation/eval_tsne_umap.py

import argparse
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms as T, models

from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Reuse your test dataset loader from eval_accuracy
from src.evaluation.eval_accuracy import load_test_dataset


# Optional UMAP support
try:
    import umap
    HAS_UMAP = True
except ImportError:
    HAS_UMAP = False
    print("[INFO] umap-learn not installed; will skip UMAP and only run t-SNE.")


class ResNetFeatureExtractor(nn.Module):
    """
    Wraps a torchvision ResNet18 pretrained on ImageNet and
    exposes a 512-d feature vector for each image.
    """

    def __init__(self, device="cuda"):
        super().__init__()
        # Use the modern weights API
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Remove the final FC layer: keep everything up to avgpool
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.feature_extractor.to(device)
        self.feature_extractor.eval()
        self.device = device

        # Standard ImageNet normalization
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @torch.no_grad()
    def forward(self, pil_img):
        """
        pil_img: a single PIL.Image
        returns: numpy array of shape (512,)
        """
        x = self.transform(pil_img).unsqueeze(0).to(self.device)  # (1, 3, 224, 224)
        feat = self.feature_extractor(x)  # (1, 512, 1, 1)
        feat = feat.view(1, -1)  # (1, 512)
        return feat.squeeze(0).cpu().numpy()


def extract_features(data_root: str, max_samples: int = 2000, seed: int = 42):
    """
    Extract:
    - Raw 64x64 grayscale flattened features (for LR/SVM-style space)
    - ResNet18 pretrained 512-d features

    Returns:
        X_raw   : (N, 4096)
        X_resnet: (N, 512)
        y       : (N,)
    """
    print(f"[INFO] Loading test dataset from {data_root}")
    dataset = load_test_dataset(data_root)
    total = len(dataset)

    # Optional subsampling for t-SNE / UMAP visualization
    rng = np.random.default_rng(seed)
    if max_samples is not None and max_samples < total:
        indices = rng.choice(total, size=max_samples, replace=False)
        indices = sorted(indices.tolist())
        print(f"[INFO] Subsampling {len(indices)} / {total} test samples for visualization.")
    else:
        indices = list(range(total))
        print(f"[INFO] Using all {total} test samples for visualization.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")

    # Raw feature pipeline: 64x64 grayscale + flatten
    raw_transform = T.Compose([
        T.Resize((64, 64)),
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),  # (1, 64, 64), values in [0,1]
    ])

    resnet_extractor = ResNetFeatureExtractor(device=device)

    X_raw_list = []
    X_resnet_list = []
    y_list = []

    for idx in tqdm(indices, desc="Extracting features"):
        img, target = dataset[idx]  # img: PIL.Image, target: int
        y_list.append(int(target))

        # Raw features
        raw_tensor = raw_transform(img)  # (1, 64, 64)
        X_raw_list.append(raw_tensor.view(-1).numpy())  # (4096,)

        # ResNet features
        resnet_feat = resnet_extractor(img)  # (512,)
        X_resnet_list.append(resnet_feat)

    X_raw = np.stack(X_raw_list, axis=0)  # (N, 4096)
    X_resnet = np.stack(X_resnet_list, axis=0)  # (N, 512)
    y = np.array(y_list, dtype=int)

    print(f"[INFO] X_raw shape: {X_raw.shape}")
    print(f"[INFO] X_resnet shape: {X_resnet.shape}")
    print(f"[INFO] y shape: {y.shape}")

    return X_raw, X_resnet, y


def run_tsne(X, y, out_path: Path, title: str, num_classes_to_label: int = 10):
    """
    Run t-SNE on feature matrix X and save a 2D scatter plot.
    Points are colored by class label.
    """
    print(f"[INFO] Running t-SNE for {title} with shape {X.shape}")
    tsne = TSNE(
        n_components=2,
        perplexity=30,
        learning_rate="auto",
        init="pca",
        random_state=42,
    )
    X_2d = tsne.fit_transform(X)

    # Plot
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_2d[:, 0],
        X_2d[:, 1],
        c=y,
        s=8,
        alpha=0.7,
        cmap="tab20",
    )
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

    # Optionally build a legend with a subset of classes to avoid clutter
    unique_classes = np.unique(y)
    if len(unique_classes) > num_classes_to_label:
        chosen = unique_classes[:num_classes_to_label]
    else:
        chosen = unique_classes

    # Create proxy artists for legend
    handles = []
    labels = []
    for cls in chosen:
        handles.append(plt.Line2D([], [], marker="o", linestyle="",
                                  color=scatter.cmap(scatter.norm(cls))))
        labels.append(f"Class {cls}")
    plt.legend(handles, labels, title="Example classes", fontsize=8, loc="best")

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[INFO] Saved t-SNE plot to {out_path}")


def run_umap(X, y, out_path: Path, title: str, num_classes_to_label: int = 10):
    """
    Run UMAP on feature matrix X and save a 2D scatter plot.
    Only runs if umap-learn is installed.
    """
    if not HAS_UMAP:
        print(f"[WARN] UMAP not available; skipping {title}")
        return

    print(f"[INFO] Running UMAP for {title} with shape {X.shape}")
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        random_state=42,
    )
    X_2d = reducer.fit_transform(X)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_2d[:, 0],
        X_2d[:, 1],
        c=y,
        s=8,
        alpha=0.7,
        cmap="tab20",
    )
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

    unique_classes = np.unique(y)
    if len(unique_classes) > num_classes_to_label:
        chosen = unique_classes[:num_classes_to_label]
    else:
        chosen = unique_classes

    handles = []
    labels = []
    for cls in chosen:
        handles.append(plt.Line2D([], [], marker="o", linestyle="",
                                  color=scatter.cmap(scatter.norm(cls))))
        labels.append(f"Class {cls}")
    plt.legend(handles, labels, title="Example classes", fontsize=8, loc="best")

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[INFO] Saved UMAP plot to {out_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/feature_viz",
        help="Directory to save t-SNE/UMAP plots.",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=2000,
        help="Max number of test samples to subsample for visualization (None = all).",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1) Extract features
    X_raw, X_resnet, y = extract_features(
        data_root=args.data_root,
        max_samples=args.max_samples,
        seed=42,
    )

    # 2) t-SNE on raw features
    tsne_raw_path = out_dir / "tsne_raw.png"
    run_tsne(X_raw, y, tsne_raw_path, title="t-SNE: Raw 64x64 Grayscale Features")

    # 3) t-SNE on ResNet features
    tsne_resnet_path = out_dir / "tsne_resnet.png"
    run_tsne(X_resnet, y, tsne_resnet_path, title="t-SNE: ResNet18 Pretrained Features")

    # 4) Optional UMAP (if available)
    umap_raw_path = out_dir / "umap_raw.png"
    run_umap(X_raw, y, umap_raw_path, title="UMAP: Raw 64x64 Grayscale Features")

    umap_resnet_path = out_dir / "umap_resnet.png"
    run_umap(X_resnet, y, umap_resnet_path, title="UMAP: ResNet18 Pretrained Features")


if __name__ == "__main__":
    # Keep torch threads manageable
    torch.set_num_threads(4)
    main()
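Usage note (assumed invocation based on the defaults above): python -m src.evaluation.eval_tsne_umap --max-samples 2000 writes tsne_raw.png and tsne_resnet.png under outputs/feature_viz, plus umap_raw.png and umap_resnet.png when umap-learn is installed.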
src/inference/__pycache__/lr_model.cpython-313.pyc
ADDED
Binary file (3.17 kB)
src/inference/__pycache__/resnet_pt_lr_model.cpython-313.pyc
ADDED
Binary file (8.63 kB)
src/inference/__pycache__/resnet_pt_svm_model.cpython-313.pyc
ADDED
Binary file (8.41 kB)
src/inference/__pycache__/svm_model.cpython-313.pyc
ADDED
Binary file (5.3 kB)
src/inference/__pycache__/test_resnet_pt_lr.cpython-313.pyc
ADDED
Binary file (5.2 kB)
src/inference/__pycache__/test_resnet_pt_svm.cpython-313.pyc
ADDED
Binary file (5.1 kB)
src/inference/base_model.py
ADDED
@@ -0,0 +1,30 @@
# src/inference/base_model.py
from abc import ABC, abstractmethod
from typing import Dict, Any
from PIL import Image

class BaseModel(ABC):
    """Common interface for all pet recognition models."""

    def __init__(self, name: str, labels: Dict[int, str]):
        self.name = name
        self.labels = labels

    @abstractmethod
    def preprocess(self, image: Image.Image) -> Any:
        """Convert PIL image → model input (tensor / numpy / feature vector)."""
        pass

    @abstractmethod
    def predict(self, image: Image.Image) -> Dict[str, Any]:
        """
        Run full pipeline: preprocess → forward pass → postprocess.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probs": Dict[str, float],  # optional, top-k
            }
        """
        pass
src/inference/lr_model.py
ADDED
@@ -0,0 +1,63 @@
import json
import joblib
import numpy as np
from PIL import Image


class LRModel:
    """
    Inference pipeline for the Logistic Regression model
    trained on 64x64 grayscale flattened images.
    """

    def __init__(self, model_path: str, labels_path: str, image_size: int = 64):
        self.model = joblib.load(model_path)
        self.labels = self._load_labels(labels_path)
        self.image_size = image_size

    def _load_labels(self, labels_path):
        with open(labels_path, "r") as f:
            label_dict = json.load(f)

        # Ensure keys are integer indices, not strings
        label_dict = {int(k): v for k, v in label_dict.items()}
        return label_dict

    def preprocess(self, image: Image.Image) -> np.ndarray:
        """
        Preprocessing matching training:
        - Resize to 64x64
        - Grayscale
        - Normalize to [0,1]
        - Flatten to (1, D)
        """
        img = image.resize((self.image_size, self.image_size))
        img = img.convert("L")  # grayscale
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = arr.reshape(1, -1)  # shape: (1, D)
        return arr

    def predict(self, image: Image.Image):
        """
        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob, ...}
            }
        """
        x = self.preprocess(image)
        probs = self.model.predict_proba(x)[0]
        class_id = int(np.argmax(probs))
        class_name = self.labels[class_id]

        # Build probability dict (optional)
        prob_dict = {
            self.labels[i]: float(probs[i]) for i in range(len(probs))
        }

        return {
            "class_id": class_id,
            "class_name": class_name,
            "probabilities": prob_dict
        }
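For reference, eval_accuracy.py constructs this class directly with the checkpoint shipped in this commit; a minimal stand-alone sketch (illustrative only, run from the repository root, with test.jpg as an assumed sample image) looks like:

from PIL import Image
from src.inference.lr_model import LRModel

# Paths match the files added in this commit.
model = LRModel("checkpoints/lr_model.joblib", "configs/labels.json")
result = model.predict(Image.open("test.jpg"))
print(result["class_id"], result["class_name"])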
src/inference/resnet_pt_lr_model.py
ADDED
@@ -0,0 +1,179 @@
# src/inference/resnet_pt_lr_model.py

import os
import json
from typing import Dict, Any, List, Optional

import numpy as np
from PIL import Image

import torch
from torchvision.models import resnet18, ResNet18_Weights
import joblib


class ResNetPTLRModel:
    """
    End-to-end inference wrapper:
    - ResNet18 (pretrained on ImageNet) as frozen backbone
    - Logistic Regression head trained on extracted features
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        assert os.path.exists(ckpt_path), f"ResNet PT + LR checkpoint not found: {ckpt_path}"
        assert os.path.exists(labels_path), f"Labels mapping not found: {labels_path}"

        # Decide device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"[ResNetPTLRModel] Using device: {self.device}")

        # --- Load LR head ---
        print(f"[ResNetPTLRModel] Loading LR head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)

        # payload was saved as a dict in train_resnet_pt_lr.py
        if isinstance(payload, dict) and "model" in payload:
            self.lr_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Fallback if someone saved the raw model
            self.lr_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTLRModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)

        # ensure keys are ints
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess (same as in feature extraction) ---
        print("[ResNetPTLRModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)

        import torch.nn as nn
        model.fc = nn.Identity()

        model.to(self.device)
        model.eval()

        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional: check feature_dim consistency if available
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTLRModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                print(f"[ResNetPTLRModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply ImageNet-style transforms and return a (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # (1, 3, H, W)
        return t.to(self.device)

    @staticmethod
    def _to_probabilities_from_logits(logits: np.ndarray) -> np.ndarray:
        """
        Convert raw scores/logits to probabilities using softmax.
        """
        logits = logits - np.max(logits)
        exp = np.exp(logits)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run a PIL image through the backbone and get a (1, D) numpy feature vector.
        """
        x = self.preprocess(img)  # (1, 3, H, W)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        feats_np = feats.cpu().numpy()
        return feats_np  # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LR has predict_proba, use that directly
        if hasattr(self.lr_head, "predict_proba"):
            probs = self.lr_head.predict_proba(feats_np)[0]  # (C,)
        else:
            # Fallback: use decision_function and softmax
            scores = self.lr_head.decision_function(feats_np)
            if scores.ndim == 1:
                scores = scores[np.newaxis, :]
            probs = self._to_probabilities_from_logits(scores[0])

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        # Full distribution
        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }

        # Top-k sorted
        sorted_indices = np.argsort(probs)[::-1]
        top_k = min(top_k, len(sorted_indices))
        top_k_list: List[Dict[str, Any]] = []
        for i in range(top_k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }
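A minimal direct-use sketch for the class above (illustrative; the constructor defaults point at the checkpoint and labels file added in this commit, and sample.jpg is an assumed input):

from PIL import Image
from src.inference.resnet_pt_lr_model import ResNetPTLRModel

# Uses checkpoints/resnet_pt_lr_head.joblib and configs/labels.json by default.
model = ResNetPTLRModel()
result = model.predict(Image.open("sample.jpg"), top_k=5)
for entry in result["top_k"]:
    print(entry["class_name"], entry["probability"])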
src/inference/resnet_pt_svm_model.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# src/inference/resnet_pt_svm_model.py

import os
import json
from typing import Dict, Any, List, Optional

import numpy as np
from PIL import Image

import torch
from torchvision.models import resnet18, ResNet18_Weights
import joblib


class ResNetPTSVMModel:
    """
    ResNet18 (pretrained, frozen) + Linear SVM head.

    Pipeline:
    - PIL image
    - ImageNet transforms
    - ResNet18 backbone (fc -> Identity) -> feature vector
    - Linear SVM decision_function
    - Softmax over scores to get probabilities
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        assert os.path.exists(ckpt_path), f"ResNet PT + SVM checkpoint not found: {ckpt_path}"
        assert os.path.exists(labels_path), f"Labels mapping not found: {labels_path}"

        # Device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"[ResNetPTSVMModel] Using device: {self.device}")

        # --- Load SVM head ---
        print(f"[ResNetPTSVMModel] Loading SVM head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)

        if isinstance(payload, dict) and "model" in payload:
            self.svm_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            self.svm_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTSVMModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)

        # ensure keys are ints
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess ---
        print("[ResNetPTSVMModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)

        import torch.nn as nn
        model.fc = nn.Identity()

        model.to(self.device)
        model.eval()

        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional: sanity check feature_dim
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTSVMModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                print(f"[ResNetPTSVMModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply the ImageNet-style transforms and return a (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)
        return t.to(self.device)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        scores = scores - np.max(scores)
        exp = np.exp(scores)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run image through ResNet backbone to get (1, D) feature vector.
        """
        x = self.preprocess(img)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()    # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LinearSVC has no predict_proba -> use decision_function
        scores = self.svm_head.decision_function(feats_np)
        if scores.ndim == 1:
            scores = scores[np.newaxis, :]
        scores = scores[0]  # (C,)

        probs = self._softmax(scores)  # (C,)

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }

        sorted_indices = np.argsort(probs)[::-1]
        top_k = min(top_k, len(sorted_indices))
        top_k_list: List[Dict[str, Any]] = []
        for i in range(top_k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }
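A minimal usage sketch of this wrapper, assuming the default checkpoint/label paths exist and using a hypothetical input image:

# usage sketch (not part of the file above); "some_pet.jpg" is a placeholder path
from PIL import Image
from src.inference.resnet_pt_svm_model import ResNetPTSVMModel

model = ResNetPTSVMModel()                       # loads backbone + SVM head once
img = Image.open("some_pet.jpg").convert("RGB")  # hypothetical input image
out = model.predict(img, top_k=3)
print(out["class_name"], out["top_k"])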
src/inference/svm_model.py
ADDED
@@ -0,0 +1,115 @@
# src/inference/svm_model.py

import os
import json
from typing import Dict, Any, List

import numpy as np
from PIL import Image
from torchvision import transforms
import joblib


class SVMModel:
    """
    Inference wrapper for the Linear SVM trained on raw 64x64 grayscale pixels.
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/svm_model.joblib",
        labels_path: str = "configs/labels.json",
    ):
        assert os.path.exists(ckpt_path), f"SVM checkpoint not found: {ckpt_path}"
        assert os.path.exists(labels_path), f"Labels mapping not found: {labels_path}"

        print(f"[SVMModel] Loading checkpoint from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)

        # You might have saved a dict with more keys, so handle both cases.
        if isinstance(payload, dict) and "model" in payload:
            self.model = payload["model"]
        else:
            self.model = payload

        print(f"[SVMModel] Loading labels from {labels_path} ...")
        with open(labels_path, "r") as f:
            self.id_to_name = json.load(f)

        # Ensure keys are integers
        self.id_to_name = {int(k): v for k, v in self.id_to_name.items()}

        self.preprocess_tf = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),  # (1, 64, 64) in [0, 1]
        ])

    def preprocess(self, img: Image.Image) -> np.ndarray:
        """
        Convert PIL image to flattened grayscale vector (1, 4096).
        """
        t = self.preprocess_tf(img)    # (1, 64, 64) tensor
        arr = t.view(-1).numpy()       # (4096,)
        return arr[np.newaxis, :]      # (1, 4096)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        # scores: (C,)
        scores = scores - np.max(scores)  # for numerical stability
        exp = np.exp(scores)
        return exp / np.sum(exp)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict the class of a single image.

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},  # full distribution
                "top_k": List[{"class_id": int, "class_name": str, "probability": float}]
            }
        """
        x = self.preprocess(img)  # (1, 4096)

        # LinearSVC doesn't have predict_proba, but decision_function gives scores
        scores = self.model.decision_function(x)  # (1, C) or (C,) if binary
        if scores.ndim == 1:
            scores = scores[np.newaxis, :]
        scores = scores[0]  # (C,)

        probs = self._softmax(scores)  # (C,)

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        # Build dict of {class_name: prob}
        prob_dict = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }

        # Build sorted top-k
        sorted_indices = np.argsort(probs)[::-1]
        top_k = min(top_k, len(sorted_indices))
        top_k_list: List[Dict[str, Any]] = []
        for i in range(top_k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }
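The decision_function-to-softmax step above is the trick both SVM wrappers rely on; here is a standalone sketch with made-up scores. Note these are convenience scores for ranking and display, not calibrated probabilities.

# standalone sketch, illustrative scores only
import numpy as np

scores = np.array([2.0, -1.0, 0.5])   # per-class margins from LinearSVC.decision_function
scores = scores - scores.max()        # shift for numerical stability
probs = np.exp(scores) / np.exp(scores).sum()
print(probs, probs.sum())             # pseudo-probabilities summing to 1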
src/inference/test_resnet_pt_lr.py
ADDED
@@ -0,0 +1,150 @@
# src/inference/test_resnet_pt_lr.py

import os
import argparse
import random

from PIL import Image

import torch
from torchvision import datasets

from src.inference.resnet_pt_lr_model import ResNetPTLRModel


def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    assert os.path.exists(image_path), f"Image not found: {image_path}"
    img = Image.open(image_path).convert("RGB")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"  {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    """
    Pick a random sample from the Oxford-IIIT Pet test split and run inference.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")

    # transform=None -> returns PIL.Image
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
        download=True,
    )

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    idx = random.randint(0, len(test_ds) - 1)
    img, target = test_ds[idx]
    assert isinstance(img, Image.Image)

    # dataset has .categories giving names
    gt_name = test_ds.categories[target]

    print(f"[+] Random sample idx={idx}")
    print(f"    Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)

    print(f"    Predicted class_id : {out['class_id']}")
    print(f"    Predicted class_name: {out['class_name']}")
    print("    Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"      {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + LR inference on Oxford-IIIT Pet."
    )

    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Path to ResNet PT + LR checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run inference on this image instead of a random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top classes to print.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    if args.image_path is not None:
        test_single_image(
            image_path=args.image_path,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
    else:
        test_random_dataset_sample(
            data_root=args.data_root,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
src/inference/test_resnet_pt_svm.py
ADDED
@@ -0,0 +1,143 @@
# src/inference/test_resnet_pt_svm.py

import os
import argparse
import random

from PIL import Image
from torchvision import datasets

from src.inference.resnet_pt_svm_model import ResNetPTSVMModel


def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    assert os.path.exists(image_path), f"Image not found: {image_path}"
    img = Image.open(image_path).convert("RGB")

    model = ResNetPTSVMModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"  {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")

    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,  # return PIL.Image
        download=True,
    )

    model = ResNetPTSVMModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    idx = random.randint(0, len(test_ds) - 1)
    img, target = test_ds[idx]
    assert isinstance(img, Image.Image)

    gt_name = test_ds.categories[target]

    print(f"[+] Random sample idx={idx}")
    print(f"    Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)

    print(f"    Predicted class_id : {out['class_id']}")
    print(f"    Predicted class_name: {out['class_name']}")
    print("    Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"      {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + SVM inference on Oxford-IIIT Pet."
    )

    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_svm_head.joblib",
        help="Path to ResNet PT + SVM checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run on this image instead of a random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top classes to print.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    if args.image_path is not None:
        test_single_image(
            image_path=args.image_path,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
    else:
        test_random_dataset_sample(
            data_root=args.data_root,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
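Both test scripts import from the src package, so they are presumably run as modules from the repository root (an assumption, not stated in the scripts themselves); for example:

python -m src.inference.test_resnet_pt_svm --device cpu --top-k 3
python -m src.inference.test_resnet_pt_lr --image-path path/to/pet.jpg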
src/registry.py
ADDED
@@ -0,0 +1,108 @@
# src/registry.py

from dataclasses import dataclass, field
from typing import Callable, Dict, Any, Optional


@dataclass
class RegisteredModel:
    """Metadata + lazy loader for a single model."""
    id: str
    display_name: str
    loader: Callable[[], Any]
    _instance: Optional[Any] = field(default=None, init=False, repr=False)

    def get(self) -> Any:
        """Instantiate on first call, then cache."""
        if self._instance is None:
            self._instance = self.loader()
        return self._instance


def _build_registry(device: str = "cpu") -> Dict[str, RegisteredModel]:
    """
    Central place to register all models.
    Returns a dict: model_id -> RegisteredModel.
    """

    def make_lr_raw():
        from src.inference.lr_model import LRModel
        return LRModel(
            ckpt_path="checkpoints/lr_model.joblib",
            labels_path="configs/labels.json",
            device=device,
        )

    def make_svm_raw():
        from src.inference.svm_model import SVMModel
        # SVMModel (src/inference/svm_model.py) takes no device argument,
        # so only the checkpoint and labels paths are passed here.
        return SVMModel(
            ckpt_path="checkpoints/svm_model.joblib",
            labels_path="configs/labels.json",
        )

    def make_resnet_pt_lr():
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
            device=device,
        )

    def make_resnet_pt_svm():
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
            device=device,
        )

    return {
        "lr_raw": RegisteredModel(
            id="lr_raw",
            display_name="LR (raw 64×64 grayscale)",
            loader=make_lr_raw,
        ),
        "svm_raw": RegisteredModel(
            id="svm_raw",
            display_name="SVM (raw 64×64 grayscale)",
            loader=make_svm_raw,
        ),
        "resnet_pt_lr": RegisteredModel(
            id="resnet_pt_lr",
            display_name="ResNet(PT) + LR",
            loader=make_resnet_pt_lr,
        ),
        "resnet_pt_svm": RegisteredModel(
            id="resnet_pt_svm",
            display_name="ResNet(PT) + SVM",
            loader=make_resnet_pt_svm,
        ),
    }


# Build once at import; models themselves are loaded lazily.
_REGISTRY: Dict[str, RegisteredModel] = _build_registry()


def get_registry() -> Dict[str, RegisteredModel]:
    """Return the full registry (id -> RegisteredModel)."""
    return _REGISTRY


def get_models() -> Dict[str, Any]:
    """
    Eagerly instantiate all models and return id -> model_instance.
    Useful for simple scripts or for initializing everything at UI startup.
    """
    return {mid: entry.get() for mid, entry in _REGISTRY.items()}


def get_model(model_id: str) -> Any:
    """Get a single model instance by id (instantiates on first use)."""
    return _REGISTRY[model_id].get()


def get_model_display_names() -> Dict[str, str]:
    """Return mapping id -> human-readable name (for dropdown choices)."""
    return {mid: entry.display_name for mid, entry in _REGISTRY.items()}
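A quick sketch of how the registry is consumed; the wiring below is illustrative only (variable names and the dropdown shape are not taken from the UI code):

from src.registry import get_model_display_names, get_model

names = get_model_display_names()                          # {"lr_raw": "LR (raw 64×64 grayscale)", ...}
choices = [(label, mid) for mid, label in names.items()]   # e.g. (display label, model id) pairs for a dropdown

model = get_model("resnet_pt_svm")                         # lazily instantiated and cached on first call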
src/training/__pycache__/extract_resnet_features.cpython-313.pyc
ADDED
Binary file (6.8 kB)
src/training/__pycache__/train_resnet_pt_lr.cpython-313.pyc
ADDED
Binary file (5.58 kB)
src/training/__pycache__/train_resnet_pt_svm.cpython-313.pyc
ADDED
Binary file (5.51 kB)
src/training/__pycache__/train_svm.cpython-313.pyc
ADDED
Binary file (6.6 kB)
src/training/extract_resnet_features.py
ADDED
@@ -0,0 +1,183 @@
| 1 |
+
# src/training/extract_resnet_features.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from torchvision import datasets
|
| 10 |
+
from torchvision.models import resnet18, ResNet18_Weights
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_datasets(data_root: str, preprocess):
|
| 14 |
+
"""
|
| 15 |
+
Build Oxford-IIIT Pet train/test datasets with ResNet preprocessing.
|
| 16 |
+
"""
|
| 17 |
+
train_ds = datasets.OxfordIIITPet(
|
| 18 |
+
root=data_root,
|
| 19 |
+
split="trainval",
|
| 20 |
+
target_types="category",
|
| 21 |
+
transform=preprocess,
|
| 22 |
+
download=True,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
test_ds = datasets.OxfordIIITPet(
|
| 26 |
+
root=data_root,
|
| 27 |
+
split="test",
|
| 28 |
+
target_types="category",
|
| 29 |
+
transform=preprocess,
|
| 30 |
+
download=True,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
return train_ds, test_ds
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_dataloaders(train_ds, test_ds, batch_size: int = 64, num_workers: int = 2):
|
| 37 |
+
train_loader = DataLoader(
|
| 38 |
+
train_ds,
|
| 39 |
+
batch_size=batch_size,
|
| 40 |
+
shuffle=False, # don't shuffle, we just want deterministic feature arrays
|
| 41 |
+
num_workers=num_workers,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
test_loader = DataLoader(
|
| 45 |
+
test_ds,
|
| 46 |
+
batch_size=batch_size,
|
| 47 |
+
shuffle=False,
|
| 48 |
+
num_workers=num_workers,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
return train_loader, test_loader
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def build_resnet18_backbone(device: torch.device):
|
| 55 |
+
"""
|
| 56 |
+
Load ResNet18 pretrained on ImageNet, replace final fc with Identity.
|
| 57 |
+
Returns:
|
| 58 |
+
model (nn.Module), feature_dim (int), preprocess (transform)
|
| 59 |
+
"""
|
| 60 |
+
weights = ResNet18_Weights.DEFAULT
|
| 61 |
+
model = resnet18(weights=weights)
|
| 62 |
+
feature_dim = model.fc.in_features # 512
|
| 63 |
+
|
| 64 |
+
# Replace final classifier with identity to get penultimate features
|
| 65 |
+
import torch.nn as nn
|
| 66 |
+
model.fc = nn.Identity()
|
| 67 |
+
|
| 68 |
+
model.to(device)
|
| 69 |
+
model.eval()
|
| 70 |
+
|
| 71 |
+
# Official preprocessing pipeline for these weights (resize + crop + norm)
|
| 72 |
+
preprocess = weights.transforms()
|
| 73 |
+
|
| 74 |
+
return model, feature_dim, preprocess
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def extract_features(model, loader, device: torch.device):
|
| 78 |
+
"""
|
| 79 |
+
Run images through the model and collect features + labels.
|
| 80 |
+
Returns:
|
| 81 |
+
X: (N, feature_dim) numpy array
|
| 82 |
+
y: (N,) numpy array
|
| 83 |
+
"""
|
| 84 |
+
features_list = []
|
| 85 |
+
labels_list = []
|
| 86 |
+
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
for images, targets in loader:
|
| 89 |
+
images = images.to(device)
|
| 90 |
+
outputs = model(images) # (B, feature_dim)
|
| 91 |
+
features_list.append(outputs.cpu().numpy())
|
| 92 |
+
labels_list.append(targets.numpy())
|
| 93 |
+
|
| 94 |
+
X = np.concatenate(features_list, axis=0)
|
| 95 |
+
y = np.concatenate(labels_list, axis=0)
|
| 96 |
+
return X, y
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main(
|
| 100 |
+
data_root: str = "data/oxford-iiit-pet",
|
| 101 |
+
out_dir: str = "data/resnet18_features",
|
| 102 |
+
batch_size: int = 64,
|
| 103 |
+
num_workers: int = 2,
|
| 104 |
+
):
|
| 105 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 106 |
+
|
| 107 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 108 |
+
print(f"[+] Using device: {device}")
|
| 109 |
+
|
| 110 |
+
print("[+] Building ResNet18 backbone and preprocessing ...")
|
| 111 |
+
model, feature_dim, preprocess = build_resnet18_backbone(device)
|
| 112 |
+
print(f"[+] Feature dimension: {feature_dim}")
|
| 113 |
+
|
| 114 |
+
print(f"[+] Loading Oxford-IIIT Pet from {data_root} ...")
|
| 115 |
+
train_ds, test_ds = build_datasets(data_root, preprocess)
|
| 116 |
+
|
| 117 |
+
print("[+] Building dataloaders ...")
|
| 118 |
+
train_loader, test_loader = build_dataloaders(
|
| 119 |
+
train_ds, test_ds, batch_size=batch_size, num_workers=num_workers
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
print("[+] Extracting train features ...")
|
| 123 |
+
X_train, y_train = extract_features(model, train_loader, device)
|
| 124 |
+
print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
|
| 125 |
+
|
| 126 |
+
print("[+] Extracting test features ...")
|
| 127 |
+
X_test, y_test = extract_features(model, test_loader, device)
|
| 128 |
+
print(f" X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")
|
| 129 |
+
|
| 130 |
+
# Save to .npy
|
| 131 |
+
x_train_path = os.path.join(out_dir, "X_train_resnet18.npy")
|
| 132 |
+
y_train_path = os.path.join(out_dir, "y_train.npy")
|
| 133 |
+
x_test_path = os.path.join(out_dir, "X_test_resnet18.npy")
|
| 134 |
+
y_test_path = os.path.join(out_dir, "y_test.npy")
|
| 135 |
+
|
| 136 |
+
print(f"[+] Saving features to {out_dir} ...")
|
| 137 |
+
np.save(x_train_path, X_train)
|
| 138 |
+
np.save(y_train_path, y_train)
|
| 139 |
+
np.save(x_test_path, X_test)
|
| 140 |
+
np.save(y_test_path, y_test)
|
| 141 |
+
|
| 142 |
+
print("[+] Done extracting ResNet18 features.")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def parse_args():
|
| 146 |
+
parser = argparse.ArgumentParser(
|
| 147 |
+
description="Extract ResNet18 (pretrained) features for Oxford-IIIT Pet."
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--data-root",
|
| 151 |
+
type=str,
|
| 152 |
+
default="data/oxford-iiit-pet",
|
| 153 |
+
help="Root directory for Oxford-IIIT Pet dataset.",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--out-dir",
|
| 157 |
+
type=str,
|
| 158 |
+
default="data/resnet18_features",
|
| 159 |
+
help="Directory to save .npy feature files.",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--batch-size",
|
| 163 |
+
type=int,
|
| 164 |
+
default=64,
|
| 165 |
+
help="Batch size for feature extraction.",
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--num-workers",
|
| 169 |
+
type=int,
|
| 170 |
+
default=2,
|
| 171 |
+
help="Num workers for dataloader.",
|
| 172 |
+
)
|
| 173 |
+
return parser.parse_args()
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
args = parse_args()
|
| 178 |
+
main(
|
| 179 |
+
data_root=args.data_root,
|
| 180 |
+
out_dir=args.out_dir,
|
| 181 |
+
batch_size=args.batch_size,
|
| 182 |
+
num_workers=args.num_workers,
|
| 183 |
+
)
|
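Feature extraction is the first stage of the deep-feature pipeline: the train_resnet_pt_* scripts below consume the .npy files it writes. A typical run order, assuming commands are issued from the repository root:

python -m src.training.extract_resnet_features --out-dir data/resnet18_features
python -m src.training.train_resnet_pt_lr  --features-dir data/resnet18_features
python -m src.training.train_resnet_pt_svm --features-dir data/resnet18_features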
src/training/train_lr.py
ADDED
@@ -0,0 +1,171 @@
import os
import json
from typing import Tuple

import numpy as np
from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
import joblib


def get_datasets(data_root: str, image_size: int = 64) -> Tuple[torch.utils.data.Dataset,
                                                                 torch.utils.data.Dataset,
                                                                 dict]:
    """
    Load Oxford-IIIT Pet train/test splits with simple transforms.

    Returns:
        train_dataset, test_dataset, class_to_idx
    """
    # Simple transform: resize -> grayscale -> tensor in [0,1]
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # (1, H, W), float32 in [0,1]
    ])

    train_dataset = datasets.OxfordIIITPet(
        root=data_root,
        split="trainval",
        target_types="category",
        transform=transform,
        download=True,  # downloads to root/oxford-iiit-pet if not present
    )

    test_dataset = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=transform,
        download=True,
    )

    # class_to_idx mapping
    # Many torchvision datasets expose this attribute
    class_to_idx = train_dataset.class_to_idx

    return train_dataset, test_dataset, class_to_idx


def dataset_to_numpy(dataset: torch.utils.data.Dataset) -> Tuple[np.ndarray, np.ndarray]:
    """
    Convert a torchvision dataset (with tensor images) to numpy arrays
    suitable for scikit-learn.

    X: (N, D) flattened grayscale pixels
    y: (N,) int labels
    """
    X_list = []
    y_list = []

    for img, label in tqdm(dataset, desc="Converting to numpy"):
        # img: torch.Tensor, shape (1, H, W)
        arr = img.numpy()        # (1, H, W)
        arr = arr.reshape(-1)    # flatten to (D,)
        X_list.append(arr)
        y_list.append(label)

    X = np.stack(X_list, axis=0).astype(np.float32)  # (N, D)
    y = np.array(y_list, dtype=np.int64)             # (N,)

    return X, y


def save_labels(class_to_idx: dict, labels_path: str):
    """
    Save labels as id -> class_name in a JSON file for inference/UI.
    """
    # Invert mapping: idx -> class_name
    idx_to_class = {idx: cls_name for cls_name, idx in class_to_idx.items()}

    os.makedirs(os.path.dirname(labels_path), exist_ok=True)
    with open(labels_path, "w") as f:
        json.dump(idx_to_class, f, indent=2)
    print(f"[INFO] Saved labels to {labels_path}")


def train_logistic_regression(X_train: np.ndarray, y_train: np.ndarray) -> LogisticRegression:
    """
    Train multinomial Logistic Regression on given features.

    We use 'saga' because it supports multinomial loss and L1/L2,
    and works decently with high-dimensional sparse-ish data.
    """
    num_classes = len(np.unique(y_train))
    print(f"[INFO] Training Logistic Regression on {X_train.shape[0]} samples, "
          f"{X_train.shape[1]} features, {num_classes} classes")

    clf = LogisticRegression(
        penalty="l2",
        C=1.0,
        solver="saga",
        multi_class="multinomial",
        max_iter=1000,
        n_jobs=-1,
        verbose=1,
    )
    clf.fit(X_train, y_train)
    return clf


def evaluate_model(clf: LogisticRegression, X: np.ndarray, y: np.ndarray, split_name: str):
    """
    Print accuracy and basic classification report for a given split.
    """
    y_pred = clf.predict(X)
    acc = accuracy_score(y, y_pred)
    print(f"\n[{split_name}] Accuracy: {acc * 100:.2f}%")
    print(f"[{split_name}] Classification report (macro avg at bottom):")
    print(classification_report(y, y_pred, digits=3))


def main():
    # -------- configs (tweak paths as needed) --------
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    data_root = os.path.join(project_root, "data")
    checkpoints_dir = os.path.join(project_root, "checkpoints")
    configs_dir = os.path.join(project_root, "configs")

    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(configs_dir, exist_ok=True)

    labels_path = os.path.join(configs_dir, "labels.json")
    model_path = os.path.join(checkpoints_dir, "lr_model.joblib")

    image_size = 64  # 64x64 grayscale baseline
    # ------------------------------------------------

    print("[INFO] Loading datasets...")
    train_dataset, test_dataset, class_to_idx = get_datasets(data_root, image_size=image_size)

    print(f"[INFO] Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
    print(f"[INFO] Number of classes: {len(class_to_idx)}")

    print("[INFO] Converting train split to numpy...")
    X_train, y_train = dataset_to_numpy(train_dataset)

    print("[INFO] Converting test split to numpy...")
    X_test, y_test = dataset_to_numpy(test_dataset)

    # Save label mapping for later inference
    save_labels(class_to_idx, labels_path)

    # Train LR
    clf = train_logistic_regression(X_train, y_train)

    # Evaluate
    evaluate_model(clf, X_train, y_train, split_name="Train")
    evaluate_model(clf, X_test, y_test, split_name="Test")

    # Save model
    joblib.dump(clf, model_path)
    print(f"[INFO] Saved Logistic Regression model to {model_path}")


if __name__ == "__main__":
    main()
src/training/train_resnet_pt_lr.py
ADDED
@@ -0,0 +1,128 @@
# src/training/train_resnet_pt_lr.py

import os
import argparse
import json

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import joblib


def load_features(features_dir: str):
    x_train_path = os.path.join(features_dir, "X_train_resnet18.npy")
    y_train_path = os.path.join(features_dir, "y_train.npy")
    x_test_path = os.path.join(features_dir, "X_test_resnet18.npy")
    y_test_path = os.path.join(features_dir, "y_test.npy")

    assert os.path.exists(x_train_path), f"Missing: {x_train_path}"
    assert os.path.exists(y_train_path), f"Missing: {y_train_path}"
    assert os.path.exists(x_test_path), f"Missing: {x_test_path}"
    assert os.path.exists(y_test_path), f"Missing: {y_test_path}"

    X_train = np.load(x_train_path)
    y_train = np.load(y_train_path)
    X_test = np.load(x_test_path)
    y_test = np.load(y_test_path)

    return X_train, y_train, X_test, y_test


def main(
    features_dir: str = "data/resnet18_features",
    ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
    labels_path: str = "configs/labels.json",
):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    print(f"[+] Loading features from {features_dir} ...")
    X_train, y_train, X_test, y_test = load_features(features_dir)

    print(f"    X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f"    X_test shape : {X_test.shape}, y_test shape : {y_test.shape}")

    num_features = X_train.shape[1]
    print(f"[+] Feature dimension: {num_features}")

    # Labels mapping is not strictly needed for training, but we keep the path
    # around for inference later.
    if os.path.exists(labels_path):
        with open(labels_path, "r") as f:
            labels = json.load(f)
        num_classes = len(labels)
        print(f"[+] Loaded labels from {labels_path}, num_classes={num_classes}")
    else:
        print(f"[!] Warning: {labels_path} not found. Inference will need this later.")
        labels = None

    print("[+] Training Logistic Regression on ResNet18 features ...")
    clf = LogisticRegression(
        penalty="l2",
        C=1.0,
        solver="saga",
        multi_class="multinomial",
        max_iter=1000,
        n_jobs=-1,
        verbose=1,
    )

    clf.fit(X_train, y_train)

    print("[+] Evaluating ...")
    y_pred_train = clf.predict(X_train)
    y_pred_test = clf.predict(X_test)

    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)

    print(f"    Train accuracy: {train_acc:.4f}")
    print(f"    Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving LR head to {ckpt_path} ...")
    payload = {
        "model": clf,
        "backbone": "resnet18_imagenet",
        "feature_dim": int(num_features),
        "labels_path": labels_path,
        "train_acc": float(train_acc),
        "test_acc": float(test_acc),
    }
    joblib.dump(payload, ckpt_path)

    print("[+] Done training ResNet PT + LR.")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train Logistic Regression head on ResNet18 (pretrained) features."
    )
    parser.add_argument(
        "--features-dir",
        type=str,
        default="data/resnet18_features",
        help="Directory containing X_train_resnet18.npy etc.",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Where to save LR head checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(
        features_dir=args.features_dir,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )
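The saved checkpoint is a plain joblib dict, so it can be inspected without loading the ResNet backbone; a small sketch, using the default path from the script above:

import joblib

payload = joblib.load("checkpoints/resnet_pt_lr_head.joblib")
print(payload["feature_dim"], payload["test_acc"])  # feature size and the stored test accuracy
clf = payload["model"]                              # the fitted LogisticRegression head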
src/training/train_resnet_pt_svm.py
ADDED
@@ -0,0 +1,124 @@
# src/training/train_resnet_pt_svm.py

import os
import argparse
import json

import numpy as np
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
import joblib


def load_features(features_dir: str):
    x_train_path = os.path.join(features_dir, "X_train_resnet18.npy")
    y_train_path = os.path.join(features_dir, "y_train.npy")
    x_test_path = os.path.join(features_dir, "X_test_resnet18.npy")
    y_test_path = os.path.join(features_dir, "y_test.npy")

    assert os.path.exists(x_train_path), f"Missing: {x_train_path}"
    assert os.path.exists(y_train_path), f"Missing: {y_train_path}"
    assert os.path.exists(x_test_path), f"Missing: {x_test_path}"
    assert os.path.exists(y_test_path), f"Missing: {y_test_path}"

    X_train = np.load(x_train_path)
    y_train = np.load(y_train_path)
    X_test = np.load(x_test_path)
    y_test = np.load(y_test_path)

    return X_train, y_train, X_test, y_test


def main(
    features_dir: str = "data/resnet18_features",
    ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
    labels_path: str = "configs/labels.json",
):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    print(f"[+] Loading ResNet18 features from {features_dir} ...")
    X_train, y_train, X_test, y_test = load_features(features_dir)

    print(f"    X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f"    X_test shape : {X_test.shape}, y_test shape : {y_test.shape}")

    num_features = X_train.shape[1]
    print(f"[+] Feature dimension: {num_features}")

    # Labels mapping for logging / sanity
    if os.path.exists(labels_path):
        with open(labels_path, "r") as f:
            labels = json.load(f)
        num_classes = len(labels)
        print(f"[+] Loaded labels from {labels_path}, num_classes={num_classes}")
    else:
        print(f"[!] Warning: {labels_path} not found. Inference will need this later.")
        labels = None

    print("[+] Training Linear SVM on ResNet18 features ...")
    svm = LinearSVC(
        C=1.0,
        penalty="l2",
        loss="squared_hinge",
        max_iter=5000,  # give it some room
    )

    svm.fit(X_train, y_train)

    print("[+] Evaluating ...")
    y_pred_train = svm.predict(X_train)
    y_pred_test = svm.predict(X_test)

    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)

    print(f"    Train accuracy: {train_acc:.4f}")
    print(f"    Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving ResNet PT + SVM head to {ckpt_path} ...")
    payload = {
        "model": svm,
        "backbone": "resnet18_imagenet",
        "feature_dim": int(num_features),
        "labels_path": labels_path,
        "train_acc": float(train_acc),
        "test_acc": float(test_acc),
    }
    joblib.dump(payload, ckpt_path)

    print("[+] Done training ResNet PT + SVM.")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train Linear SVM head on ResNet18 (pretrained) features."
    )
    parser.add_argument(
        "--features-dir",
        type=str,
        default="data/resnet18_features",
        help="Directory containing X_train_resnet18.npy etc.",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_svm_head.joblib",
        help="Where to save SVM head checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(
        features_dir=args.features_dir,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )
src/training/train_svm.py
ADDED
@@ -0,0 +1,177 @@
# src/training/train_svm.py

import os
import json
import argparse

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

import numpy as np
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
import joblib


def get_transforms():
    return transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # (1, 64, 64) in [0, 1]
    ])


def build_datasets(data_root: str):
    tx = get_transforms()

    train_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="trainval",
        target_types="category",
        transform=tx,
        download=True,
    )

    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=tx,
        download=True,
    )

    return train_ds, test_ds


def dataset_to_numpy(dataset):
    """
    Convert a torchvision dataset to (X, y) numpy arrays.
    X: (N, 4096) flattened grayscale pixels
    y: (N,) integer labels
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)

    xs = []
    ys = []
    for images, targets in loader:
        # images: (B, 1, 64, 64)
        b = images.shape[0]
        images = images.view(b, -1)  # (B, 4096)
        xs.append(images.numpy())
        ys.append(targets.numpy())

    X = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    return X, y


def ensure_labels_json(train_ds, labels_path: str):
    os.makedirs(os.path.dirname(labels_path), exist_ok=True)

    if os.path.exists(labels_path):
        with open(labels_path, "r") as f:
            labels = json.load(f)
        # sanity: if it already exists, just return
        return labels

    # OxfordIIITPet: category targets are indices into .categories
    id_to_name = {i: name for i, name in enumerate(train_ds.categories)}

    with open(labels_path, "w") as f:
        json.dump(id_to_name, f, indent=2)

    return id_to_name


def train_svm(
    data_root: str = "data/oxford-iiit-pet",
    ckpt_path: str = "checkpoints/svm_model.joblib",
    labels_path: str = "configs/labels.json",
):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    print(f"[+] Loading datasets from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root)

    print("[+] Building labels.json (if missing) ...")
    labels = ensure_labels_json(train_ds, labels_path)
    num_classes = len(labels)
    print(f"[+] Num classes (from labels.json): {num_classes}")

    print("[+] Converting train dataset to numpy features ...")
    X_train, y_train = dataset_to_numpy(train_ds)
    print(f"    X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    print("[+] Converting test dataset to numpy features ...")
    X_test, y_test = dataset_to_numpy(test_ds)
    print(f"    X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    print("[+] Training Linear SVM on raw pixels ...")
    svm = LinearSVC(
        C=1.0,
        penalty="l2",
        loss="squared_hinge",
        max_iter=2000,
        # dual=True (the default) is appropriate here: with 64x64 grayscale
        # inputs, n_features (4096) is larger than n_samples (~3680 trainval images).
    )

    svm.fit(X_train, y_train)

    print("[+] Evaluating on train and test ...")
    y_pred_train = svm.predict(X_train)
    y_pred_test = svm.predict(X_test)

    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)

    print(f"    Train accuracy: {train_acc:.4f}")
    print(f"    Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving SVM model to {ckpt_path} ...")
    joblib.dump(
        {
            "model": svm,
            "labels_path": labels_path,
            "train_acc": float(train_acc),
            "test_acc": float(test_acc),
        },
        ckpt_path,
    )

    print("[+] Done.")


def parse_args():
    parser = argparse.ArgumentParser(description="Train Linear SVM on raw pixel features.")

    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/svm_model.joblib",
        help="Where to save the trained SVM model.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train_svm(
        data_root=args.data_root,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )
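The raw-pixel SVM baseline is self-contained (it writes labels.json itself if missing); a typical invocation, assuming the repository root as the working directory:

python -m src.training.train_svm --data-root data/oxford-iiit-pet --ckpt-path checkpoints/svm_model.joblib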