# src/evaluation/eval_confusion.py
import argparse
from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
# Reuse the same dataset + model loading logic as eval_accuracy.py
from src.evaluation.eval_accuracy import load_test_dataset, load_model_direct

def load_class_names(labels_path: str = "configs/labels.json"):
    """
    Try to load class names from labels.json.

    This is written to be robust to a few likely formats:
      - List: ["Abyssinian", "American Bulldog", ...]
      - Dict with string keys: {"0": "Abyssinian", "1": "American Bulldog", ...}
      - Dict with 'id_to_label': {"id_to_label": {"0": "Abyssinian", ...}}

    If anything goes wrong, returns None and we'll just use numeric class IDs
    on the axes.
    """
    try:
        with open(labels_path, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"[WARN] labels file not found at {labels_path}, using numeric IDs.")
        return None
    except json.JSONDecodeError:
        print(f"[WARN] Could not parse {labels_path}, using numeric IDs.")
        return None

    # Case 1: simple list
    if isinstance(data, list):
        return data

    # Case 2: dict with 'id_to_label'
    if isinstance(data, dict) and "id_to_label" in data:
        id_to_label = data["id_to_label"]
        # sort by integer key
        keys = sorted(id_to_label.keys(), key=lambda k: int(k))
        return [id_to_label[k] for k in keys]

    # Case 3: dict mapping "0" -> "Abyssinian"
    if isinstance(data, dict):
        try:
            keys = sorted(data.keys(), key=lambda k: int(k))
            return [data[k] for k in keys]
        except Exception:
            pass

    print("[WARN] Unrecognized labels.json format, using numeric IDs.")
    return None
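
# Illustrative usage (not part of the evaluation flow): the call below returns
# either a list of breed names ordered by integer class ID, or None, in which
# case the plots fall back to numeric tick labels.
#   class_names = load_class_names("configs/labels.json")
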

def collect_predictions(model_id: str, data_root: str):
    """
    Run the given model across the Oxford-IIIT Pet test split and collect:
      - y_true: ground-truth integer class indices
      - y_pred: top-1 predicted class indices

    Uses the same model API as eval_accuracy.py: model.predict(PIL, top_k=5)
    """
    print(f"\n=== Collecting predictions for model: {model_id} ===")
    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred = []
    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int

        # Same predict logic as eval_accuracy (support with/without top_k)
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            result = model.predict(img)

        pred_id = int(result.get("class_id"))
        y_true.append(int(target))
        y_pred.append(pred_id)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    print(f" Collected {len(y_true)} predictions.")
    return y_true, y_pred
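
# Note on the assumed predict() contract: based on how the result is consumed
# above (and in eval_accuracy.py), model.predict(...) is expected to return a
# dict containing at least an integer-castable "class_id", e.g. roughly:
#   {"class_id": 12, ...}   # any extra keys (probabilities, names) are ignored here
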

def plot_confusion_matrix(
    cm: np.ndarray,
    class_names,
    title: str,
    save_path: Path,
    normalize: bool = True,
):
    """
    Plot and save a confusion matrix.

    If normalize=True, each row (true class) is normalized to sum to 1.
    If class_names is None, we just use numeric indices on the axes.
    """
    if normalize:
        cm = cm.astype("float")
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    num_classes = cm.shape[0]
    plt.figure(figsize=(12, 10))
    im = plt.imshow(cm, interpolation="nearest", cmap="viridis")
    plt.title(title)
    plt.colorbar(im, fraction=0.046, pad=0.04)

    if class_names is not None and len(class_names) == num_classes:
        tick_labels = class_names
    else:
        tick_labels = list(range(num_classes))

    plt.xticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        rotation=90,
        fontsize=6,
    )
    plt.yticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        fontsize=6,
    )
    plt.xlabel("Predicted class")
    plt.ylabel("True class")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f" Saved confusion matrix plot to: {save_path}")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of the Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels.json (for axis names).",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/confusion_matrices",
        help="Directory to save confusion matrices and plots.",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Same set of models as eval_accuracy
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    class_names = load_class_names(args.labels_path)

    # y_true is identical for all models (same test split, same indexing),
    # but for clarity we recompute it per model; confusion_matrix only needs
    # a consistent label set (0..36), which we enforce below.
    for model_id in model_ids:
        y_true, y_pred = collect_predictions(model_id, args.data_root)

        # Define a fixed label ordering (0..max) to get a 37x37 matrix
        num_classes = int(y_true.max()) + 1
        labels = list(range(num_classes))
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        # Save the raw matrix for future analysis
        npy_path = out_dir / f"cm_{model_id}.npy"
        np.save(npy_path, cm)
        print(f" Saved raw confusion matrix to: {npy_path}")

        # Save a normalized plot
        png_path = out_dir / f"cm_{model_id}.png"
        title = f"Confusion Matrix ({model_id})"
        plot_confusion_matrix(cm, class_names, title, png_path, normalize=True)

if __name__ == "__main__":
    main()
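
# Example invocation (defaults shown; run from the repository root so the
# src.* imports resolve, and adjust paths for your setup):
#   python -m src.evaluation.eval_confusion \
#       --data-root data/oxford-iiit-pet \
#       --labels-path configs/labels.json \
#       --out-dir outputs/confusion_matrices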