""" Evaluation script for CASWiT model. Evaluates a trained model on test/validation sets and computes metrics. """ import sys import yaml import logging from pathlib import Path import torch import torch.nn.functional as F from torch.utils.data import DataLoader from tqdm import tqdm # Add project root to Python path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) from model.build_model import build_model from dataset.definition_dataset import build_transforms from dataset.factory import build_eval_dataset from utils.metrics import compute_metrics_from_confusion, BoundaryIoUMeter from train.train import load_config, TrainConfig def evaluate_model(cfg: TrainConfig, checkpoint_path: str, split: str = "test"): """ Evaluate model on specified split. Args: cfg: Training configuration checkpoint_path: Path to model checkpoint split: Dataset split to evaluate ('test' or 'val') """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Validate checkpoint path checkpoint_path_obj = Path(checkpoint_path) if not checkpoint_path_obj.exists() or not checkpoint_path_obj.is_file(): raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}") # Load model model = build_model(cfg).to(device) # Load checkpoint print(f"Loading checkpoint from: {checkpoint_path}") state_dict = torch.load(checkpoint_path, map_location=device) state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} missing, unexpected = model.load_state_dict(state_dict, strict=False) print(f"Successfully loaded checkpoint from: {checkpoint_path}") if len(missing) > 0: print(f" Missing keys: {len(missing)}") if len(unexpected) > 0: print(f" Unexpected keys: {len(unexpected)}") if len(missing) == 0 and len(unexpected) == 0: print(f" Perfect match! 
All weights loaded successfully.") model.eval() dataset_name = cfg.dataset_name t = build_transforms() ds = build_eval_dataset(cfg, split=split, transform=t) dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True) # Evaluate criterion = torch.nn.CrossEntropyLoss(ignore_index=cfg.ignore_index) running_loss = 0.0 full_confmat = torch.zeros((cfg.num_classes, cfg.num_classes), dtype=torch.long, device=device) mbiou_meter = BoundaryIoUMeter(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index, radius=10) device_type = "cuda" if torch.cuda.is_available() else "cpu" with torch.inference_mode(): for batch in tqdm(dl, desc=f"Evaluating {split}"): # Handle datasets that return meta dict (URURHRLRDataset returns 5 values) if len(batch) == 5: images_hr, masks_hr, images_lr, masks_lr, _ = batch else: images_hr, masks_hr, images_lr, masks_lr = batch images_hr = images_hr.to(device, non_blocking=True) masks_hr = masks_hr.to(device, non_blocking=True) images_lr = images_lr.to(device, non_blocking=True) masks_lr = masks_lr.to(device, non_blocking=True) with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()): out = model(images_hr, images_lr) logits_hr = out["logits_hr"] logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False) loss = criterion(logits_hr, masks_hr) running_loss += float(loss.item()) preds = torch.argmax(logits_hr, dim=1) mbiou_meter.update(preds, masks_hr) valid = (masks_hr >= 0) & (masks_hr < cfg.num_classes) t = masks_hr[valid] p = preds[valid] cm = torch.bincount( (t * cfg.num_classes + p).view(-1), minlength=cfg.num_classes * cfg.num_classes ).reshape(cfg.num_classes, cfg.num_classes) full_confmat += cm avg_loss = running_loss / len(dl) confmat_np = full_confmat.cpu().numpy() metrics = compute_metrics_from_confusion(confmat_np) mbiou, bious_per_class = mbiou_meter.compute() metrics["mBIoU"] = mbiou metrics["BIoUs"] = bious_per_class.numpy() print(f"\n{split.upper()} Results:") print(f" Loss: {avg_loss:.4f}") print(f" mIoU: {metrics['mIoU']:.4f}") print(f" mF1: {metrics['mF1']:.4f}") print(f" Per-class IoU: {metrics['IoUs']}") print(f" mBIoU: {metrics['mBIoU']:.4f}") print(f" Per-class BIoU: {metrics['BIoUs']}") return metrics def main(): """Main evaluation function.""" import sys if len(sys.argv) < 3: print("Usage: python eval.py [split]") sys.exit(1) cfg_path = sys.argv[1] checkpoint_path = sys.argv[2] split = sys.argv[3] if len(sys.argv) > 3 else "test" logging.basicConfig(level=logging.INFO) cfg = load_config(cfg_path) evaluate_model(cfg, checkpoint_path, split) if __name__ == "__main__": main()
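
# Example usage, matching the CLI parsing in main() above. The config and
# checkpoint paths below are illustrative placeholders, not files shipped
# with the repo:
#
#   python eval.py configs/caswit.yaml checkpoints/caswit_best.pth test
#
# evaluate_model() can also be called directly, e.g. from a notebook:
#
#   cfg = load_config("configs/caswit.yaml")
#   metrics = evaluate_model(cfg, "checkpoints/caswit_best.pth", split="val")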