|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms, models |
|
|
from datasets import load_from_disk |
|
|
|
|
|
# Saved HF-datasets split directory (seed 42, 20% test split) produced upstream.
SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"

# Trained model checkpoint; expected to contain a "model_state_dict" entry.
CKPT_PATH = Path("artifacts/model.pt")

# JSON list of class names, in the same index order as the model's output logits.
LABELS_PATH = Path("artifacts/label_names.json")

# Input image side length (images are resized to IMG_SIZE x IMG_SIZE).
IMG_SIZE = 224

# Evaluation batch size.
BATCH_SIZE = 32

# DataLoader worker processes; 0 = load in the main process.
NUM_WORKERS = 0
|
|
|
|
|
def confusion_matrix_torch(y_true, y_pred, num_classes):
    """Build a ``num_classes x num_classes`` confusion matrix.

    Rows index the true label, columns the predicted label, so
    ``cm[t, p]`` counts samples with true class ``t`` predicted as ``p``.

    Args:
        y_true: 1-D sequence/tensor of integer true labels in ``[0, num_classes)``.
        y_pred: 1-D sequence/tensor of integer predicted labels, same length.
        num_classes: Number of classes (fixes the matrix size even if some
            classes never appear).

    Returns:
        ``torch.int64`` tensor of shape ``(num_classes, num_classes)``.
    """
    y_true = torch.as_tensor(y_true, dtype=torch.int64)
    y_pred = torch.as_tensor(y_pred, dtype=torch.int64)
    # Vectorized counting: flatten each (true, pred) pair to a single bin
    # index and histogram in one C-level pass instead of a Python loop.
    flat = y_true * num_classes + y_pred
    cm = torch.bincount(flat, minlength=num_classes * num_classes)
    return cm.reshape(num_classes, num_classes)
|
|
|
|
|
def precision_recall_f1(cm):
    """Derive per-class precision, recall and F1 from a confusion matrix.

    Args:
        cm: Square integer tensor where ``cm[t, p]`` counts samples of true
            class ``t`` predicted as class ``p``.

    Returns:
        List of ``(precision, recall, f1)`` float tuples, one per class,
        in class-index order. Undefined ratios (zero denominator) are 0.0.
    """
    n = cm.size(0)
    # Column sum = all predictions of a class (tp + fp);
    # row sum = all true members of a class (tp + fn).
    predicted_totals = cm.sum(dim=0)
    actual_totals = cm.sum(dim=1)

    results = []
    for cls in range(n):
        tp = cm[cls, cls].item()
        predicted = predicted_totals[cls].item()
        actual = actual_totals[cls].item()

        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append((precision, recall, f1))
    return results
|
|
|
|
|
def main():
    """Evaluate the saved ResNet-18 checkpoint on the held-out "val" split.

    Loads class names and model weights from ``artifacts/``, runs batched
    inference over the "val" split stored under ``SPLIT_DIR``, prints
    accuracy, the confusion matrix, and per-class precision/recall/F1,
    and saves the confusion matrix to ``artifacts/confusion_matrix.pt``.
    """
    # Class names in logit-index order; num_classes sizes the final layer.
    label_names = json.loads(LABELS_PATH.read_text(encoding="utf-8"))
    num_classes = len(label_names)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    print("Classes:", label_names)

    # DatasetDict saved by the upstream split step; only "val" is used here.
    splits = load_from_disk(SPLIT_DIR)
    val_ds = splits["val"]

    # Deterministic eval transform: resize + ImageNet mean/std normalization
    # (no augmentation at evaluation time).
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def collate_val(batch):
        # Each row is assumed to carry a PIL "image" and an int "label"
        # — TODO confirm against the dataset schema.
        imgs = [val_tf(row["image"].convert("RGB")) for row in batch]
        labels = torch.tensor([row["label"] for row in batch], dtype=torch.long)
        return {"pixel_values": torch.stack(imgs), "labels": labels}

    # shuffle=False keeps sample order stable for evaluation.
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, collate_fn=collate_val)

    # Rebuild the architecture without pretrained weights, resize the head
    # to this task's class count, then load the fine-tuned checkpoint.
    model = models.resnet18(weights=None)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)

    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    # Accumulate per-batch labels/predictions on CPU, concatenate at the end.
    y_true_all = []
    y_pred_all = []

    with torch.no_grad():
        for batch in val_loader:
            x = batch["pixel_values"].to(device)
            y = batch["labels"].to(device)

            logits = model(x)
            preds = logits.argmax(dim=1)

            y_true_all.append(y.cpu())
            y_pred_all.append(preds.cpu())

    y_true = torch.cat(y_true_all)
    y_pred = torch.cat(y_pred_all)

    acc = (y_true == y_pred).float().mean().item()
    print(f"\nVAL Accuracy: {acc:.4f}")

    cm = confusion_matrix_torch(y_true, y_pred, num_classes)
    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(cm)

    metrics = precision_recall_f1(cm)
    print("\nPer-class metrics:")
    for i, (prec, rec, f1) in enumerate(metrics):
        print(f"- {label_names[i]:<10} | P {prec:.3f} | R {rec:.3f} | F1 {f1:.3f}")

    # Persist the matrix with its labels and accuracy for later analysis.
    out_path = Path("artifacts/confusion_matrix.pt")
    torch.save({"confusion_matrix": cm, "label_names": label_names, "val_acc": acc}, out_path)
    print(f"\nSaved confusion matrix to: {out_path}")


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|