# src/evaluation/eval_confusion.py
import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# Reuse the same dataset + model loading logic as eval_accuracy.py
from src.evaluation.eval_accuracy import load_test_dataset, load_model_direct
def load_class_names(labels_path: str = "configs/labels.json"):
    """
    Try to load class names from labels.json.

    This is written to be robust to a few likely formats:
    - List: ["Abyssinian", "American Bulldog", ...]
    - Dict with string keys: {"0": "Abyssinian", "1": "American Bulldog", ...}
    - Dict with 'id_to_label': {"id_to_label": {"0": "Abyssinian", ...}}

    If anything goes wrong, returns None and we'll just use numeric class IDs
    on the axes.
    """
    try:
        with open(labels_path, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"[WARN] labels file not found at {labels_path}, using numeric IDs.")
        return None
    except json.JSONDecodeError:
        print(f"[WARN] Could not parse {labels_path}, using numeric IDs.")
        return None

    # Case 1: simple list
    if isinstance(data, list):
        return data

    # Case 2: dict with 'id_to_label'
    if isinstance(data, dict) and "id_to_label" in data:
        id_to_label = data["id_to_label"]
        # sort by integer key
        keys = sorted(id_to_label.keys(), key=lambda k: int(k))
        return [id_to_label[k] for k in keys]

    # Case 3: dict mapping "0" -> "Abyssinian"
    if isinstance(data, dict):
        try:
            keys = sorted(data.keys(), key=lambda k: int(k))
            return [data[k] for k in keys]
        except ValueError:
            # Keys weren't all integer-like strings; fall through to the warning.
            pass

    print(f"[WARN] Unrecognized format in {labels_path}, using numeric IDs.")
    return None
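
# For reference: on success, load_class_names returns a flat list like
# ["Abyssinian", "American Bulldog", ...] (37 entries for Oxford-IIIT Pet);
# on any failure it returns None and the plots fall back to numeric IDs.
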
def collect_predictions(model_id: str, data_root: str):
    """
    Run the given model across the Oxford-IIIT Pet test split and collect:
    - y_true: ground-truth integer class indices
    - y_pred: top-1 predicted class indices

    Uses the same model API as eval_accuracy.py: model.predict(PIL, top_k=5)
    """
    print(f"\n=== Collecting predictions for model: {model_id} ===")
    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred = []
    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int
        # Same predict logic as eval_accuracy (support with/without top_k)
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            result = model.predict(img)
        pred_id = int(result.get("class_id"))
        y_true.append(int(target))
        y_pred.append(pred_id)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    print(f" Collected {len(y_true)} predictions.")
    return y_true, y_pred
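
# Assumed single-prediction payload, mirroring the dict access above. Only
# "class_id" matters here; any other keys (e.g. a hypothetical "class_name"
# or per-class scores from top_k) are ignored:
#   result = {"class_id": 12, ...}
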
def plot_confusion_matrix(
    cm: np.ndarray,
    class_names,
    title: str,
    save_path: Path,
    normalize: bool = True,
):
    """
    Plot and save a confusion matrix.

    If normalize=True, each row (true class) is normalized to sum to 1.
    If class_names is None, we just use numeric indices on the axes.
    """
    if normalize:
        cm = cm.astype("float")
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    num_classes = cm.shape[0]
    plt.figure(figsize=(12, 10))
    im = plt.imshow(cm, interpolation="nearest", cmap="viridis")
    plt.title(title)
    plt.colorbar(im, fraction=0.046, pad=0.04)

    if class_names is not None and len(class_names) == num_classes:
        tick_labels = class_names
    else:
        tick_labels = list(range(num_classes))
    plt.xticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        rotation=90,
        fontsize=6,
    )
    plt.yticks(
        ticks=np.arange(num_classes),
        labels=tick_labels,
        fontsize=6,
    )
    plt.xlabel("Predicted class")
    plt.ylabel("True class")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f" Saved confusion matrix plot to: {save_path}")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of the Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels.json (for axis names).",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/confusion_matrices",
        help="Directory to save confusion matrices and plots.",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Same set of models as eval_accuracy
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    class_names = load_class_names(args.labels_path)

    # y_true is identical for all models (same test split, same indexing),
    # but we recompute it per model for simplicity; confusion_matrix only
    # needs a consistent label set (0..36), which we enforce below.
    for model_id in model_ids:
        y_true, y_pred = collect_predictions(model_id, args.data_root)

        # Fix the label ordering (0..max) so every model gets a 37x37 matrix
        num_classes = int(y_true.max()) + 1
        labels = list(range(num_classes))
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        # Save the raw count matrix for future analysis
        npy_path = out_dir / f"cm_{model_id}.npy"
        np.save(npy_path, cm)
        print(f" Saved raw confusion matrix to: {npy_path}")
        # Save a normalized plot
        png_path = out_dir / f"cm_{model_id}.png"
        title = f"Confusion Matrix ({model_id})"
        plot_confusion_matrix(cm, class_names, title, png_path, normalize=True)
if __name__ == "__main__":
    main()
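
# Example invocation from the repo root (flags mirror the argparse defaults):
#   python -m src.evaluation.eval_confusion \
#       --data-root data/oxford-iiit-pet \
#       --labels-path configs/labels.json \
#       --out-dir outputs/confusion_matrices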