"""
Evaluation script for CASWiT model.
Evaluates a trained model on test/validation sets and computes metrics.
"""
import sys
import yaml
import logging
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
# Add project root to Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from model.build_model import build_model
from dataset.definition_dataset import build_transforms
from dataset.factory import build_eval_dataset
from utils.metrics import compute_metrics_from_confusion, BoundaryIoUMeter
from train.train import load_config, TrainConfig
def evaluate_model(cfg: TrainConfig, checkpoint_path: str, split: str = "test"):
    """
    Evaluate a trained model on the specified dataset split.

    Args:
        cfg: Training configuration (model, dataset, and loader parameters).
        checkpoint_path: Path to a model checkpoint containing a state_dict.
        split: Dataset split to evaluate ('test' or 'val').

    Returns:
        Dict of metrics from ``compute_metrics_from_confusion`` augmented with
        'mBIoU' and per-class 'BIoUs'.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not point to a file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Validate the checkpoint path before building the (expensive) model.
    # Path.is_file() already implies existence, so one check suffices.
    checkpoint_file = Path(checkpoint_path)
    if not checkpoint_file.is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    model = build_model(cfg).to(device)

    # Load weights; strip any DataParallel/DDP "module." prefixes so the
    # state_dict keys match a plain (unwrapped) model.
    print(f"Loading checkpoint from: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=device)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Successfully loaded checkpoint from: {checkpoint_path}")
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")
    if not missing and not unexpected:
        print(" Perfect match! All weights loaded successfully.")
    model.eval()

    # Build the evaluation pipeline. (Renamed from `t` to avoid shadowing by
    # the per-batch target tensor below.)
    transform = build_transforms()
    ds = build_eval_dataset(cfg, split=split, transform=transform)
    dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=False,
                    num_workers=cfg.num_workers, pin_memory=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=cfg.ignore_index)
    running_loss = 0.0
    # One confusion matrix accumulated over the whole split.
    full_confmat = torch.zeros((cfg.num_classes, cfg.num_classes), dtype=torch.long, device=device)
    mbiou_meter = BoundaryIoUMeter(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index, radius=10)

    with torch.inference_mode():
        for batch in tqdm(dl, desc=f"Evaluating {split}"):
            # Handle datasets that return a meta dict as a 5th element
            # (URURHRLRDataset returns 5 values).
            if len(batch) == 5:
                images_hr, masks_hr, images_lr, masks_lr, _ = batch
            else:
                images_hr, masks_hr, images_lr, masks_lr = batch
            images_hr = images_hr.to(device, non_blocking=True)
            masks_hr = masks_hr.to(device, non_blocking=True)
            images_lr = images_lr.to(device, non_blocking=True)
            masks_lr = masks_lr.to(device, non_blocking=True)

            with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
                out = model(images_hr, images_lr)
                logits_hr = out["logits_hr"]
                # Upsample logits to ground-truth resolution before the loss.
                logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:],
                                          mode="bilinear", align_corners=False)
                loss = criterion(logits_hr, masks_hr)
            running_loss += float(loss.item())

            preds = torch.argmax(logits_hr, dim=1)
            mbiou_meter.update(preds, masks_hr)

            # Confusion-matrix update over valid (non-ignored) pixels only.
            valid = (masks_hr >= 0) & (masks_hr < cfg.num_classes)
            targets_flat = masks_hr[valid]
            preds_flat = preds[valid]
            cm = torch.bincount(
                (targets_flat * cfg.num_classes + preds_flat).view(-1),
                minlength=cfg.num_classes * cfg.num_classes
            ).reshape(cfg.num_classes, cfg.num_classes)
            full_confmat += cm

    avg_loss = running_loss / len(dl)
    confmat_np = full_confmat.cpu().numpy()
    metrics = compute_metrics_from_confusion(confmat_np)
    mbiou, bious_per_class = mbiou_meter.compute()
    metrics["mBIoU"] = mbiou
    metrics["BIoUs"] = bious_per_class.numpy()

    print(f"\n{split.upper()} Results:")
    print(f" Loss: {avg_loss:.4f}")
    print(f" mIoU: {metrics['mIoU']:.4f}")
    print(f" mF1: {metrics['mF1']:.4f}")
    print(f" Per-class IoU: {metrics['IoUs']}")
    print(f" mBIoU: {metrics['mBIoU']:.4f}")
    print(f" Per-class BIoU: {metrics['BIoUs']}")
    return metrics
def main():
    """CLI entry point: parse argv, load the config, and run evaluation.

    Usage: python eval.py <config_path> <checkpoint_path> [split]
    """
    # NOTE: removed the redundant function-local `import sys`; the module
    # already imports sys at the top of the file.
    if len(sys.argv) < 3:
        print("Usage: python eval.py <config_path> <checkpoint_path> [split]")
        sys.exit(1)
    cfg_path = sys.argv[1]
    checkpoint_path = sys.argv[2]
    # Optional third argument selects the split; default to the test set.
    split = sys.argv[3] if len(sys.argv) > 3 else "test"
    logging.basicConfig(level=logging.INFO)
    cfg = load_config(cfg_path)
    evaluate_model(cfg, checkpoint_path, split)
# Run evaluation only when executed as a script (not when imported).
if __name__ == "__main__":
    main()