Image Segmentation
English
File size: 5,266 Bytes
36b4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a8126c
 
 
 
36b4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a8126c
36b4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a8126c
36b4539
1a8126c
 
36b4539
 
 
 
 
 
1a8126c
 
 
36b4539
 
 
 
 
 
 
 
 
 
 
 
1a8126c
36b4539
 
 
 
 
 
 
 
1a8126c
36b4539
 
 
 
 
 
 
 
 
1a8126c
36b4539
 
 
 
1a8126c
 
 
36b4539
 
 
 
 
 
1a8126c
 
36b4539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Evaluation script for CASWiT model.

Evaluates a trained model on test/validation sets and computes metrics.
"""

import sys
import yaml
import logging
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put the directory two levels above this file at the front of sys.path so the
# project-local imports below (model, dataset, utils, train) resolve when this
# file is run directly as a script. This must execute before those imports.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from model.build_model import build_model
from dataset.definition_dataset import build_transforms
from dataset.factory import build_eval_dataset
from utils.metrics import compute_metrics_from_confusion, BoundaryIoUMeter
from train.train import load_config, TrainConfig

def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with ``module`` wrapper components removed from keys.

    Only whole, non-final path components equal to ``module`` are dropped
    (e.g. ``module.encoder.weight`` -> ``encoder.weight``). A plain substring
    replace of ``"module."`` would also mangle names that merely end in
    "module", such as ``fusion_module.conv.weight``.
    """
    cleaned = {}
    for key, value in state_dict.items():
        *parents, leaf = key.split(".")
        kept = [part for part in parents if part != "module"]
        cleaned[".".join(kept + [leaf])] = value
    return cleaned


def evaluate_model(cfg: TrainConfig, checkpoint_path: str, split: str = "test"):
    """
    Evaluate model on specified split.

    Builds the model from ``cfg``, loads weights from ``checkpoint_path``
    (non-strict), runs inference over the requested split and accumulates the
    cross-entropy loss, a confusion matrix and boundary IoU. A summary is
    printed to stdout.

    Args:
        cfg: Training configuration
        checkpoint_path: Path to model checkpoint
        split: Dataset split to evaluate ('test' or 'val')

    Returns:
        The metrics dict produced by ``compute_metrics_from_confusion``,
        extended with ``"mBIoU"`` and ``"BIoUs"`` (per-class boundary IoU as a
        NumPy array).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
        ValueError: If the data loader yields no batches for ``split``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # Validate checkpoint path (is_file() is also False for missing paths)
    checkpoint_file = Path(checkpoint_path)
    if not checkpoint_file.is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Load model
    model = build_model(cfg).to(device)

    # Load checkpoint
    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Checkpoints saved from a DataParallel/DDP-wrapped model carry "module"
    # components in their keys; remove them so keys match the bare model.
    state_dict = _strip_module_prefix(state_dict)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Successfully loaded checkpoint from: {checkpoint_path}")
    if missing:
        print(f"   Missing keys: {len(missing)}")
    if unexpected:
        print(f"   Unexpected keys: {len(unexpected)}")
    if not missing and not unexpected:
        print("   Perfect match! All weights loaded successfully.")
    model.eval()

    eval_transform = build_transforms()
    ds = build_eval_dataset(cfg, split=split, transform=eval_transform)

    # Pinned host memory only helps host-to-GPU copies, so enable it only on CUDA.
    dl = DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=use_cuda,
    )

    # Evaluate
    num_classes = cfg.num_classes
    criterion = torch.nn.CrossEntropyLoss(ignore_index=cfg.ignore_index)
    running_loss = 0.0
    num_batches = 0
    full_confmat = torch.zeros((num_classes, num_classes), dtype=torch.long, device=device)
    mbiou_meter = BoundaryIoUMeter(num_classes=num_classes, ignore_index=cfg.ignore_index, radius=10)

    with torch.inference_mode():
        for batch in tqdm(dl, desc=f"Evaluating {split}"):
            # Handle datasets that return meta dict (URURHRLRDataset returns 5 values)
            if len(batch) == 5:
                images_hr, masks_hr, images_lr, masks_lr, _ = batch
            else:
                images_hr, masks_hr, images_lr, masks_lr = batch
            images_hr = images_hr.to(device, non_blocking=True)
            masks_hr = masks_hr.to(device, non_blocking=True)
            images_lr = images_lr.to(device, non_blocking=True)
            masks_lr = masks_lr.to(device, non_blocking=True)

            with torch.amp.autocast("cuda", enabled=use_cuda):
                out = model(images_hr, images_lr)
                logits_hr = out["logits_hr"]
                # Resize logits to the mask resolution before computing the loss.
                logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False)
                loss = criterion(logits_hr, masks_hr)

            running_loss += float(loss.item())
            num_batches += 1

            preds = torch.argmax(logits_hr, dim=1)
            mbiou_meter.update(preds, masks_hr)

            # Keep only pixels whose label is a valid class index.
            valid = (masks_hr >= 0) & (masks_hr < num_classes)
            target_valid = masks_hr[valid]
            pred_valid = preds[valid]

            # Encode each (target, prediction) pair as one flat index so a
            # single bincount yields the confusion matrix (rows = target,
            # columns = prediction).
            cm = torch.bincount(
                (target_valid * num_classes + pred_valid).view(-1),
                minlength=num_classes * num_classes
            ).reshape(num_classes, num_classes)
            full_confmat += cm

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError(f"No batches to evaluate for split '{split}': the dataset is empty.")

    avg_loss = running_loss / num_batches
    confmat_np = full_confmat.cpu().numpy()
    metrics = compute_metrics_from_confusion(confmat_np)
    mbiou, bious_per_class = mbiou_meter.compute()
    metrics["mBIoU"] = mbiou
    # .cpu() is a no-op for CPU tensors and keeps this working if the meter
    # returns a CUDA tensor.
    metrics["BIoUs"] = bious_per_class.cpu().numpy()

    print(f"\n{split.upper()} Results:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  mIoU: {metrics['mIoU']:.4f}")
    print(f"  mF1: {metrics['mF1']:.4f}")
    print(f"  Per-class IoU: {metrics['IoUs']}")
    print(f"  mBIoU: {metrics['mBIoU']:.4f}")
    print(f"  Per-class BIoU: {metrics['BIoUs']}")

    return metrics


def main():
    """Command-line entry point.

    Reads ``<config_path> <checkpoint_path> [split]`` from ``sys.argv``; the
    split falls back to "test" when omitted. When fewer than two positional
    arguments are given, prints a usage line and exits with status 1.
    Otherwise configures logging at INFO level, loads the config and runs the
    evaluation.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python eval.py <config_path> <checkpoint_path> [split]")
        sys.exit(1)

    cfg_path, checkpoint_path = cli_args[0], cli_args[1]
    if len(cli_args) > 2:
        split = cli_args[2]
    else:
        split = "test"

    logging.basicConfig(level=logging.INFO)
    cfg = load_config(cfg_path)
    evaluate_model(cfg, checkpoint_path, split)


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()