File size: 3,888 Bytes
6f6eb85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
from datasets import load_from_disk

# Directory handed to datasets.load_from_disk(); it must contain a "val" split.
SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"

# Artifacts read by this script: model checkpoint and the ordered class names.
CKPT_PATH = Path("artifacts") / "model.pt"
LABELS_PATH = Path("artifacts") / "label_names.json"

# Evaluation preprocessing / loader settings.
IMG_SIZE = 224    # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
NUM_WORKERS = 0   # 0 = load batches in the main process

def confusion_matrix_torch(y_true, y_pred, num_classes):
    """Build a ``num_classes x num_classes`` confusion matrix.

    Rows index the true class and columns the predicted class, so
    ``cm[t, p]`` counts the samples whose true label is ``t`` and whose
    prediction is ``p``.

    Args:
        y_true: 1-D sequence or tensor of integer true labels.
        y_pred: 1-D sequence or tensor of integer predicted labels, same
            length as ``y_true``.
        num_classes: Number of classes; every label must lie in
            ``[0, num_classes)``.

    Returns:
        An int64 CPU tensor of shape ``(num_classes, num_classes)``.

    Raises:
        ValueError: If the inputs differ in length or contain a label outside
            ``[0, num_classes)``.
    """
    true_idx = torch.as_tensor(y_true, dtype=torch.int64).reshape(-1).cpu()
    pred_idx = torch.as_tensor(y_pred, dtype=torch.int64).reshape(-1).cpu()

    if true_idx.numel() != pred_idx.numel():
        raise ValueError(
            f"y_true and y_pred differ in length: "
            f"{true_idx.numel()} != {pred_idx.numel()}"
        )

    # Reject out-of-range labels explicitly: negative values (e.g. a -1
    # "no label" placeholder) would otherwise wrap around into the last
    # class, and large values would spill into a neighbouring row.
    for name, idx in (("y_true", true_idx), ("y_pred", pred_idx)):
        if idx.numel() and (idx.min().item() < 0 or idx.max().item() >= num_classes):
            raise ValueError(f"{name} contains labels outside [0, {num_classes})")

    # Encode each (true, pred) pair as one flat index and count them all in a
    # single vectorized pass instead of per-sample scalar tensor updates.
    flat = true_idx * num_classes + pred_idx
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def precision_recall_f1(cm):
    """Compute per-class precision, recall and F1 from a confusion matrix.

    ``cm`` is indexed as ``cm[true, pred]``. For class ``i``, the predicted
    total is column ``i``'s sum and the true total is row ``i``'s sum. Any
    ratio whose denominator is zero is reported as ``0.0``.

    Returns:
        A list with one ``(precision, recall, f1)`` tuple per class, in
        class-index order.
    """
    def _safe_div(numerator, denominator):
        # Zero denominators (class never predicted / never present) -> 0.0.
        return numerator / denominator if denominator else 0.0

    hits = cm.diag().tolist()
    predicted_totals = cm.sum(dim=0).tolist()  # column sums = tp + fp
    actual_totals = cm.sum(dim=1).tolist()     # row sums = tp + fn

    per_class = []
    for tp, n_predicted, n_actual in zip(hits, predicted_totals, actual_totals):
        precision = _safe_div(tp, n_predicted)
        recall = _safe_div(tp, n_actual)
        f1 = _safe_div(2 * precision * recall, precision + recall)
        per_class.append((precision, recall, f1))
    return per_class

def _build_val_transform():
    """Return the deterministic preprocessing pipeline for validation images."""
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics used by torchvision models;
        # presumably identical to the training-time preprocessing — confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def _collate_val(batch, transform):
    """Convert a list of dataset rows into ``{"pixel_values", "labels"}`` tensors.

    Defined at module level (not as a closure inside ``main``) so it stays
    picklable: DataLoader workers started with the "spawn" method (the
    default on Windows/macOS) must pickle ``collate_fn`` when NUM_WORKERS > 0.
    """
    imgs = [transform(row["image"].convert("RGB")) for row in batch]
    labels = torch.tensor([row["label"] for row in batch], dtype=torch.long)
    return {"pixel_values": torch.stack(imgs), "labels": labels}


def _build_val_loader(val_ds):
    """Wrap the validation dataset in an ordered (unshuffled) DataLoader."""
    from functools import partial

    return DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=partial(_collate_val, transform=_build_val_transform()),
    )


def _load_model(num_classes, device):
    """Rebuild ResNet-18 with a ``num_classes``-way head and load checkpoint weights."""
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()
    return model


def _predict(model, loader, device):
    """Run ``model`` over ``loader``; return ``(y_true, y_pred)`` as 1-D CPU tensors.

    Raises:
        ValueError: If the loader yields no batches (empty validation split).
    """
    y_true_parts = []
    y_pred_parts = []
    with torch.no_grad():
        for batch in loader:
            logits = model(batch["pixel_values"].to(device))
            # Labels never go through the model, so keep them on the CPU
            # rather than round-tripping them through the device.
            y_true_parts.append(batch["labels"])
            y_pred_parts.append(logits.argmax(dim=1).cpu())

    if not y_true_parts:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise ValueError(f"Validation split in {SPLIT_DIR!r} is empty; nothing to evaluate.")
    return torch.cat(y_true_parts), torch.cat(y_pred_parts)


def main():
    """Evaluate the saved checkpoint on the "val" split.

    Prints accuracy, the confusion matrix (rows=true, cols=pred) and
    per-class precision/recall/F1. It then saves the confusion matrix, label
    names and accuracy to ``artifacts/confusion_matrix.pt``.
    """
    # Label names are the source of truth for readable reporting.
    label_names = json.loads(LABELS_PATH.read_text(encoding="utf-8"))
    num_classes = len(label_names)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    print("Classes:", label_names)

    val_ds = load_from_disk(SPLIT_DIR)["val"]
    val_loader = _build_val_loader(val_ds)
    model = _load_model(num_classes, device)

    y_true, y_pred = _predict(model, val_loader, device)

    acc = (y_true == y_pred).float().mean().item()
    print(f"\nVAL Accuracy: {acc:.4f}")

    cm = confusion_matrix_torch(y_true, y_pred, num_classes)
    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(cm)

    metrics = precision_recall_f1(cm)
    print("\nPer-class metrics:")
    for name, (prec, rec, f1) in zip(label_names, metrics):
        print(f"- {name:<10} | P {prec:.3f} | R {rec:.3f} | F1 {f1:.3f}")

    # Save CM for later reporting.
    out_path = Path("artifacts/confusion_matrix.pt")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"confusion_matrix": cm, "label_names": label_names, "val_acc": acc}, out_path)
    print(f"\nSaved confusion matrix to: {out_path}")


if __name__ == "__main__":
    main()
# This script evaluates a trained ResNet18 model on the validation split of the
# "comprehensive-car-damage" dataset, computes accuracy, confusion matrix,
# precision, recall, and F1-score for each class, and saves the confusion matrix to disk.