# CASWiT / train/eval.py — Image Segmentation
# Author: antoine.carreaud67
# Commit 1a8126c: Adding vrt inference/update
"""
Evaluation script for CASWiT model.
Evaluates a trained model on test/validation sets and computes metrics.
"""
import sys
import yaml
import logging
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
# Add project root to Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from model.build_model import build_model
from dataset.definition_dataset import build_transforms
from dataset.factory import build_eval_dataset
from utils.metrics import compute_metrics_from_confusion, BoundaryIoUMeter
from train.train import load_config, TrainConfig
def evaluate_model(cfg: TrainConfig, checkpoint_path: str, split: str = "test"):
    """
    Evaluate a trained CASWiT model on the given dataset split.

    Builds the model, loads the checkpoint weights, runs inference over the
    split's DataLoader, and aggregates a full confusion matrix plus boundary
    IoU statistics into summary metrics, which are printed and returned.

    Args:
        cfg: Training configuration (supplies num_classes, ignore_index,
            batch_size, num_workers and dataset settings).
        checkpoint_path: Path to the model checkpoint (a state_dict file).
        split: Dataset split to evaluate ('test' or 'val').

    Returns:
        Metrics dict from compute_metrics_from_confusion, extended with
        'mBIoU' (scalar) and 'BIoUs' (per-class array).

    Raises:
        FileNotFoundError: If checkpoint_path does not exist or is not a file.
        ValueError: If the evaluation DataLoader yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Validate the checkpoint path before building the (potentially large)
    # model. Path.is_file() already implies existence.
    checkpoint_path_obj = Path(checkpoint_path)
    if not checkpoint_path_obj.is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Build model and load checkpoint.
    model = build_model(cfg).to(device)
    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load deserializes via pickle — only load checkpoints
    # from trusted sources (consider weights_only=True on recent torch).
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Strip DistributedDataParallel's "module." prefix so keys match a
    # non-wrapped model.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Successfully loaded checkpoint from: {checkpoint_path}")
    if len(missing) > 0:
        print(f" Missing keys: {len(missing)}")
    if len(unexpected) > 0:
        print(f" Unexpected keys: {len(unexpected)}")
    if len(missing) == 0 and len(unexpected) == 0:
        print(" Perfect match! All weights loaded successfully.")
    model.eval()

    # Build the evaluation dataset/dataloader (fixed order, no shuffling).
    transform = build_transforms()
    ds = build_eval_dataset(cfg, split=split, transform=transform)
    dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)
    if len(dl) == 0:
        # Guard the average-loss division below against an empty split.
        raise ValueError(f"No batches to evaluate for split '{split}'")

    # Evaluate
    criterion = torch.nn.CrossEntropyLoss(ignore_index=cfg.ignore_index)
    running_loss = 0.0
    # Confusion matrix accumulated over the whole split
    # (rows = target class, cols = predicted class).
    full_confmat = torch.zeros((cfg.num_classes, cfg.num_classes), dtype=torch.long, device=device)
    mbiou_meter = BoundaryIoUMeter(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index, radius=10)

    with torch.inference_mode():
        for batch in tqdm(dl, desc=f"Evaluating {split}"):
            # Handle datasets that return meta dict (URURHRLRDataset returns 5 values)
            if len(batch) == 5:
                images_hr, masks_hr, images_lr, masks_lr, _ = batch
            else:
                images_hr, masks_hr, images_lr, masks_lr = batch
            images_hr = images_hr.to(device, non_blocking=True)
            masks_hr = masks_hr.to(device, non_blocking=True)
            images_lr = images_lr.to(device, non_blocking=True)
            # masks_lr is unused below; moved for interface parity — TODO confirm.
            masks_lr = masks_lr.to(device, non_blocking=True)

            # Mixed-precision inference on GPU; autocast is disabled on CPU.
            with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
                out = model(images_hr, images_lr)
                logits_hr = out["logits_hr"]
                # Upsample logits to the ground-truth resolution before the loss.
                logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False)
                loss = criterion(logits_hr, masks_hr)
            running_loss += float(loss.item())

            preds = torch.argmax(logits_hr, dim=1)
            mbiou_meter.update(preds, masks_hr)

            # Accumulate the confusion matrix over valid (non-ignored) pixels.
            valid = (masks_hr >= 0) & (masks_hr < cfg.num_classes)
            targets = masks_hr[valid]
            predictions = preds[valid]
            cm = torch.bincount(
                (targets * cfg.num_classes + predictions).view(-1),
                minlength=cfg.num_classes * cfg.num_classes
            ).reshape(cfg.num_classes, cfg.num_classes)
            full_confmat += cm

    avg_loss = running_loss / len(dl)
    confmat_np = full_confmat.cpu().numpy()
    metrics = compute_metrics_from_confusion(confmat_np)
    mbiou, bious_per_class = mbiou_meter.compute()
    metrics["mBIoU"] = mbiou
    metrics["BIoUs"] = bious_per_class.numpy()

    print(f"\n{split.upper()} Results:")
    print(f" Loss: {avg_loss:.4f}")
    print(f" mIoU: {metrics['mIoU']:.4f}")
    print(f" mF1: {metrics['mF1']:.4f}")
    print(f" Per-class IoU: {metrics['IoUs']}")
    print(f" mBIoU: {metrics['mBIoU']:.4f}")
    print(f" Per-class BIoU: {metrics['BIoUs']}")
    return metrics
def main():
    """CLI entry point: parse argv and run evaluation.

    Usage: python eval.py <config_path> <checkpoint_path> [split]

    Exits with status 1 when fewer than two positional arguments are given.
    """
    # sys is already imported at module level; the previous function-local
    # `import sys` was redundant and has been removed.
    if len(sys.argv) < 3:
        print("Usage: python eval.py <config_path> <checkpoint_path> [split]")
        sys.exit(1)
    cfg_path = sys.argv[1]
    checkpoint_path = sys.argv[2]
    # Optional third argument selects the split; default to 'test'.
    split = sys.argv[3] if len(sys.argv) > 3 else "test"
    logging.basicConfig(level=logging.INFO)
    cfg = load_config(cfg_path)
    evaluate_model(cfg, checkpoint_path, split)


if __name__ == "__main__":
    main()