|
|
""" |
|
|
Evaluation script for CASWiT model. |
|
|
|
|
|
Evaluates a trained model on test/validation sets and computes metrics. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import yaml |
|
|
import logging |
|
|
from pathlib import Path |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Make the project root importable so that `model`, `dataset`, `utils`,
# and `train` packages resolve when this script is run directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
|
|
|
|
|
from model.build_model import build_model |
|
|
from dataset.definition_dataset import build_transforms |
|
|
from dataset.factory import build_eval_dataset |
|
|
from utils.metrics import compute_metrics_from_confusion, BoundaryIoUMeter |
|
|
from train.train import load_config, TrainConfig |
|
|
|
|
|
def evaluate_model(cfg: TrainConfig, checkpoint_path: str, split: str = "test"):
    """Evaluate a trained CASWiT model on the given dataset split.

    Args:
        cfg: Training configuration (model/dataset/loader settings).
        checkpoint_path: Path to the model checkpoint (a state dict).
        split: Dataset split to evaluate ('test' or 'val').

    Returns:
        Dict of metrics containing 'mIoU', 'mF1', per-class 'IoUs',
        'mBIoU', and per-class 'BIoUs'.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If the requested split yields an empty dataloader.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fail fast on a bad path before building the (potentially large) model.
    # is_file() already implies exists(), so a single check suffices.
    checkpoint_path_obj = Path(checkpoint_path)
    if not checkpoint_path_obj.is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    model = build_model(cfg).to(device)

    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources (consider weights_only=True).
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Strip DistributedDataParallel's "module." prefix so the weights load
    # into a plain (non-wrapped) model.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Successfully loaded checkpoint from: {checkpoint_path}")
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")
    if not missing and not unexpected:
        print(" Perfect match! All weights loaded successfully.")

    model.eval()

    # BUGFIX: the transform was previously bound to `t`, a name that was
    # shadowed below by the per-batch target tensor; use a distinct name.
    transform = build_transforms()
    ds = build_eval_dataset(cfg, split=split, transform=transform)
    dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=False,
                    num_workers=cfg.num_workers, pin_memory=True)
    if len(dl) == 0:
        # Avoid a confusing ZeroDivisionError at the averaging step.
        raise ValueError(f"Split '{split}' produced an empty dataloader")

    criterion = torch.nn.CrossEntropyLoss(ignore_index=cfg.ignore_index)
    running_loss = 0.0
    full_confmat = torch.zeros((cfg.num_classes, cfg.num_classes), dtype=torch.long, device=device)
    mbiou_meter = BoundaryIoUMeter(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index, radius=10)

    with torch.inference_mode():
        for batch in tqdm(dl, desc=f"Evaluating {split}"):
            # Some datasets yield an extra trailing element (presumably
            # metadata — TODO confirm against the dataset factory); drop it.
            if len(batch) == 5:
                images_hr, masks_hr, images_lr, masks_lr, _ = batch
            else:
                images_hr, masks_hr, images_lr, masks_lr = batch
            images_hr = images_hr.to(device, non_blocking=True)
            masks_hr = masks_hr.to(device, non_blocking=True)
            images_lr = images_lr.to(device, non_blocking=True)
            masks_lr = masks_lr.to(device, non_blocking=True)

            with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
                out = model(images_hr, images_lr)
                logits_hr = out["logits_hr"]
                # Upsample logits to the ground-truth resolution before the loss.
                logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False)
                loss = criterion(logits_hr, masks_hr)

            running_loss += float(loss.item())

            preds = torch.argmax(logits_hr, dim=1)
            mbiou_meter.update(preds, masks_hr)

            # Accumulate the confusion matrix over valid (non-ignored) pixels
            # via a flattened bincount of (target * C + prediction).
            valid = (masks_hr >= 0) & (masks_hr < cfg.num_classes)
            targets = masks_hr[valid]
            predictions = preds[valid]
            cm = torch.bincount(
                (targets * cfg.num_classes + predictions).view(-1),
                minlength=cfg.num_classes * cfg.num_classes,
            ).reshape(cfg.num_classes, cfg.num_classes)
            full_confmat += cm

    avg_loss = running_loss / len(dl)
    confmat_np = full_confmat.cpu().numpy()
    metrics = compute_metrics_from_confusion(confmat_np)
    mbiou, bious_per_class = mbiou_meter.compute()
    metrics["mBIoU"] = mbiou
    metrics["BIoUs"] = bious_per_class.numpy()

    print(f"\n{split.upper()} Results:")
    print(f" Loss: {avg_loss:.4f}")
    print(f" mIoU: {metrics['mIoU']:.4f}")
    print(f" mF1: {metrics['mF1']:.4f}")
    print(f" Per-class IoU: {metrics['IoUs']}")
    print(f" mBIoU: {metrics['mBIoU']:.4f}")
    print(f" Per-class BIoU: {metrics['BIoUs']}")

    return metrics
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse argv, load config, and run evaluation.

    Usage: python eval.py <config_path> <checkpoint_path> [split]

    Raises:
        SystemExit: With code 1 when required arguments are missing.
    """
    # `sys` is already imported at module level; the previous local
    # `import sys` was redundant and has been removed.
    if len(sys.argv) < 3:
        print("Usage: python eval.py <config_path> <checkpoint_path> [split]")
        sys.exit(1)

    cfg_path = sys.argv[1]
    checkpoint_path = sys.argv[2]
    split = sys.argv[3] if len(sys.argv) > 3 else "test"

    logging.basicConfig(level=logging.INFO)
    cfg = load_config(cfg_path)
    evaluate_model(cfg, checkpoint_path, split)
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|
|
|
|