"""Evaluate a trained ResNet18 car-damage classifier on its validation split.

Loads the checkpoint from ``artifacts/model.pt``, runs deterministic
inference over the "val" split stored at ``SPLIT_DIR``, and reports
overall accuracy, a confusion matrix, and per-class precision / recall /
F1. The confusion matrix (plus labels and accuracy) is saved to
``artifacts/confusion_matrix.pt`` for later reporting.
"""

import json
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
from datasets import load_from_disk

# Paths and evaluation hyperparameters.
SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
CKPT_PATH = Path("artifacts/model.pt")
LABELS_PATH = Path("artifacts/label_names.json")
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 0


def confusion_matrix_torch(y_true, y_pred, num_classes):
    """Build an int64 ``(num_classes, num_classes)`` confusion matrix.

    Rows index the true label, columns the predicted label. Labels are
    assumed to lie in ``[0, num_classes)``.

    Args:
        y_true: 1-D integer tensor of ground-truth labels.
        y_pred: 1-D integer tensor of predicted labels (same length).
        num_classes: number of distinct classes.

    Returns:
        ``torch.int64`` tensor of shape ``(num_classes, num_classes)``.
    """
    # Encode each (true, pred) pair as a single flat index and histogram it
    # in one C-level bincount call instead of a Python loop over samples.
    flat = y_true.to(torch.int64) * num_classes + y_pred.to(torch.int64)
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def precision_recall_f1(cm):
    """Compute per-class (precision, recall, F1) from a confusion matrix.

    Args:
        cm: square integer tensor; rows = true labels, cols = predictions.

    Returns:
        List of ``(precision, recall, f1)`` float tuples, one per class.
        Classes with zero predictions or zero support yield 0.0 rather
        than NaN.
    """
    num_classes = cm.size(0)
    metrics = []
    for i in range(num_classes):
        tp = cm[i, i].item()
        fp = cm[:, i].sum().item() - tp  # predicted i, actually something else
        fn = cm[i, :].sum().item() - tp  # actually i, predicted something else
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = (2 * prec * rec / (prec + rec)) if (prec + rec) else 0.0
        metrics.append((prec, rec, f1))
    return metrics


def _build_val_loader(val_ds):
    """Wrap the HF dataset split in a DataLoader with deterministic transforms."""
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # ImageNet normalization statistics (match training preprocessing).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def collate_val(batch):
        # Rows carry a PIL image and an integer label; force RGB so
        # grayscale/RGBA images still produce 3-channel tensors.
        imgs = [val_tf(row["image"].convert("RGB")) for row in batch]
        labels = torch.tensor([row["label"] for row in batch], dtype=torch.long)
        return {"pixel_values": torch.stack(imgs), "labels": labels}

    # shuffle=False: evaluation order must be deterministic.
    return DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                      num_workers=NUM_WORKERS, collate_fn=collate_val)


def _load_model(num_classes, device):
    """Rebuild the ResNet18 architecture, load checkpoint weights, eval mode."""
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    return model.to(device).eval()


@torch.no_grad()
def _predict(model, loader, device):
    """Run inference over *loader*; return (y_true, y_pred) CPU tensors."""
    y_true_all = []
    y_pred_all = []
    for batch in loader:
        x = batch["pixel_values"].to(device)
        y = batch["labels"].to(device)
        logits = model(x)
        preds = logits.argmax(dim=1)
        y_true_all.append(y.cpu())
        y_pred_all.append(preds.cpu())
    return torch.cat(y_true_all), torch.cat(y_pred_all)


def main():
    """Evaluate the checkpointed model on the val split and report metrics."""
    # Label names are the source of truth for readable reporting.
    label_names = json.loads(LABELS_PATH.read_text(encoding="utf-8"))
    num_classes = len(label_names)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    print("Classes:", label_names)

    # Load the pre-materialized val split from disk.
    splits = load_from_disk(SPLIT_DIR)
    val_loader = _build_val_loader(splits["val"])

    model = _load_model(num_classes, device)
    y_true, y_pred = _predict(model, val_loader, device)

    acc = (y_true == y_pred).float().mean().item()
    print(f"\nVAL Accuracy: {acc:.4f}")

    cm = confusion_matrix_torch(y_true, y_pred, num_classes)
    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(cm)

    metrics = precision_recall_f1(cm)
    print("\nPer-class metrics:")
    for i, (prec, rec, f1) in enumerate(metrics):
        print(f"- {label_names[i]:<10} | P {prec:.3f} | R {rec:.3f} | F1 {f1:.3f}")

    # Save the confusion matrix for later reporting; ensure the output
    # directory exists so a fresh checkout doesn't crash on save.
    out_path = Path("artifacts/confusion_matrix.pt")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"confusion_matrix": cm, "label_names": label_names,
                "val_acc": acc}, out_path)
    print(f"\nSaved confusion matrix to: {out_path}")


if __name__ == "__main__":
    main()