# vehicle-damage-classifier/src/step8_evaluate.py
# Repository page header: efnanaladagg - commit 6f6eb85 ("Clean push")
import json
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
from datasets import load_from_disk
# Directory holding the saved dataset splits; main() passes it to load_from_disk
# and reads the "val" split from the result.
SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
# Checkpoint file; main() loads it and reads its "model_state_dict" entry.
CKPT_PATH = Path("artifacts/model.pt")
# JSON file with the class names; its length determines the number of classes.
LABELS_PATH = Path("artifacts/label_names.json")
# Images are resized to IMG_SIZE x IMG_SIZE before being fed to the model.
IMG_SIZE = 224
# Settings handed to the validation DataLoader in main().
BATCH_SIZE = 32
NUM_WORKERS = 0
def confusion_matrix_torch(y_true, y_pred, num_classes):
    """Build a confusion matrix of counts from true and predicted class indices.

    Args:
        y_true: Tensor or sequence of integer true class indices.
        y_pred: Tensor or sequence of integer predicted class indices, with
            the same number of elements as ``y_true``.
        num_classes: Number of classes; valid indices are
            ``0 .. num_classes - 1``.

    Returns:
        An int64 CPU tensor of shape ``(num_classes, num_classes)`` whose
        entry ``[t, p]`` is the number of samples with true class ``t`` and
        predicted class ``p`` (rows: true, columns: predicted).

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
        IndexError: If any index lies outside ``[0, num_classes)``.
    """
    true_idx = torch.as_tensor(y_true).reshape(-1).to(device="cpu", dtype=torch.int64)
    pred_idx = torch.as_tensor(y_pred).reshape(-1).to(device="cpu", dtype=torch.int64)

    # zip() would silently drop the unmatched tail; a length mismatch means the
    # inputs are misaligned, so fail loudly instead.
    if true_idx.numel() != pred_idx.numel():
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {true_idx.numel()} and {pred_idx.numel()}"
        )

    # Negative indices would otherwise wrap around to the last classes and
    # corrupt the counts, so reject everything outside [0, num_classes).
    for arg_name, idx in (("y_true", true_idx), ("y_pred", pred_idx)):
        if idx.numel() and (idx.min().item() < 0 or idx.max().item() >= num_classes):
            raise IndexError(
                f"{arg_name} contains class indices outside [0, {num_classes})"
            )

    # Encode each (true, pred) pair as one flat cell index and count all cells
    # in a single call instead of a per-sample Python loop.
    flat_cells = true_idx * num_classes + pred_idx
    counts = torch.bincount(flat_cells, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
def precision_recall_f1(cm):
    """Compute per-class precision, recall and F1 from a confusion matrix.

    Args:
        cm: 2-D tensor of counts with rows indexed by true class and columns
            by predicted class.

    Returns:
        A list with one ``(precision, recall, f1)`` tuple per row of ``cm``.
        Any ratio whose denominator is zero is reported as ``0.0``.
    """

    def ratio(numerator, denominator):
        # Undefined ratios (nothing predicted / nothing present) count as 0.0.
        return numerator / denominator if denominator else 0.0

    hits = cm.diagonal().tolist()          # true positives per class
    predicted_totals = cm.sum(dim=0).tolist()  # column sums: times predicted
    actual_totals = cm.sum(dim=1).tolist()     # row sums: times truly present

    results = []
    for cls in range(cm.size(0)):
        true_pos = hits[cls]
        false_pos = predicted_totals[cls] - true_pos
        false_neg = actual_totals[cls] - true_pos
        precision = ratio(true_pos, true_pos + false_pos)
        recall = ratio(true_pos, true_pos + false_neg)
        f1_score = ratio(2 * precision * recall, precision + recall)
        results.append((precision, recall, f1_score))
    return results
def main():
    """Evaluate the saved ResNet-18 checkpoint on the "val" split.

    Reads the class names from LABELS_PATH, the dataset from SPLIT_DIR and the
    weights from CKPT_PATH; prints accuracy, the confusion matrix and the
    per-class precision/recall/F1; then saves the confusion matrix, label
    names and accuracy to artifacts/confusion_matrix.pt.
    """
    # Class names drive both the classifier head size and the report labels.
    with LABELS_PATH.open(encoding="utf-8") as labels_file:
        label_names = json.load(labels_file)
    num_classes = len(label_names)

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print("Device:", device)
    print("Classes:", label_names)

    # Only the validation split of the stored dataset is evaluated.
    val_ds = load_from_disk(SPLIT_DIR)["val"]

    # No random augmentation here, so repeated runs see identical inputs.
    preprocess = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    def collate_val(batch):
        # Turn a list of dataset rows into one image batch and one label batch.
        pixel_values = torch.stack(
            [preprocess(row["image"].convert("RGB")) for row in batch]
        )
        targets = torch.tensor([row["label"] for row in batch], dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": targets}

    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=collate_val,
    )

    # Recreate the architecture (ResNet-18 with a resized final layer), then
    # restore the trained weights from the checkpoint.
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    checkpoint = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    # Collect targets and predictions batch by batch, on the CPU.
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)
            predicted = model(inputs).argmax(dim=1)
            true_chunks.append(targets.cpu())
            pred_chunks.append(predicted.cpu())

    y_true = torch.cat(true_chunks)
    y_pred = torch.cat(pred_chunks)

    acc = (y_true == y_pred).float().mean().item()
    print(f"\nVAL Accuracy: {acc:.4f}")

    cm = confusion_matrix_torch(y_true, y_pred, num_classes)
    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(cm)

    print("\nPer-class metrics:")
    for i, scores in enumerate(precision_recall_f1(cm)):
        prec, rec, f1 = scores
        print(f"- {label_names[i]:<10} | P {prec:.3f} | R {rec:.3f} | F1 {f1:.3f}")

    # Persist the results so later reporting steps can reuse them.
    out_path = Path("artifacts/confusion_matrix.pt")
    payload = {"confusion_matrix": cm, "label_names": label_names, "val_acc": acc}
    torch.save(payload, out_path)
    print(f"\nSaved confusion matrix to: {out_path}")
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
# This script evaluates a trained ResNet18 model on the validation split of the
# "comprehensive-car-damage" dataset, computes accuracy, confusion matrix,
# precision, recall, and F1-score for each class, and saves the confusion matrix to disk.