"""
ISDNet Inference Script for FLAIR Dataset

Evaluates a trained model on the test set with per-class IoU metrics.

Usage:
    python inference.py --checkpoint isdnet_flair_best.pth
    python inference.py --checkpoint isdnet_flair_best.pth --split valid
"""
| |
|
import argparse
import json

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from isdnet import ISDNet, FLAIRDataset, NUM_CLASSES, CROP_SIZE, DOWN_RATIO, IGNORE_INDEX
from isdnet.config import DATA_ROOT
| |
|
| |
|
| | |
# FLAIR land-cover class labels, indexed by model output channel (0-14).
CLASS_NAMES = [
    'building',
    'pervious',
    'impervious',
    'bare_soil',
    'water',
    'coniferous',
    'deciduous',
    'brushwood',
    'vineyard',
    'herbaceous',
    'agricultural',
    'plowed_land',
    'swimming_pool',
    'snow',
    'greenhouse',
]
| |
|
| |
|
def parse_args():
    """Build and parse the command-line interface for inference."""
    p = argparse.ArgumentParser(description='ISDNet Inference on FLAIR')
    p.add_argument('--checkpoint', type=str, default='isdnet_flair_best.pth',
                   help='Path to model checkpoint')
    p.add_argument('--data-root', type=str, default=DATA_ROOT,
                   help='Path to FLAIR dataset')
    p.add_argument('--split', type=str, default='test',
                   choices=['valid', 'test'],
                   help='Dataset split to evaluate')
    p.add_argument('--batch-size', type=int, default=16,
                   help='Batch size for inference')
    p.add_argument('--num-workers', type=int, default=4,
                   help='Number of data loading workers')
    p.add_argument('--output', type=str, default='',
                   help='Output JSON file for results (default: auto-generated)')
    return p.parse_args()
| |
|
| |
|
@torch.no_grad()
def evaluate(model, loader, num_classes, device):
    """Evaluate ``model`` over ``loader`` and compute segmentation metrics.

    Args:
        model: Segmentation network; called as ``model(imgs, return_loss=False)``
            and expected to return per-pixel class logits.
        loader: DataLoader yielding ``(imgs, masks)`` batches.
        num_classes: Number of semantic classes.
        device: Device on which metric accumulators are kept.

    Returns:
        dict with keys ``loss``, ``miou``, ``pixel_accuracy`` and
        ``class_iou`` (per-class IoU keyed by ``CLASS_NAMES``).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    total_loss = 0.0
    intersection = torch.zeros(num_classes, device=device)
    union = torch.zeros(num_classes, device=device)
    total_correct = 0
    total_pixels = 0

    for imgs, masks in tqdm(loader, desc="Evaluating"):
        imgs = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        outputs = model(imgs, return_loss=False)
        total_loss += criterion(outputs, masks).item()

        preds = outputs.argmax(dim=1)
        valid_mask = (masks != IGNORE_INDEX)

        # Pixel accuracy counts only non-ignored pixels.
        total_correct += ((preds == masks) & valid_mask).sum().item()
        total_pixels += valid_mask.sum().item()

        for cls in range(num_classes):
            cls_pred = (preds == cls) & valid_mask
            # Fix: restrict GT to valid pixels too, so ignored pixels can
            # never inflate the union (matters if IGNORE_INDEX happens to
            # fall inside the class-index range).
            cls_gt = (masks == cls) & valid_mask
            intersection[cls] += (cls_pred & cls_gt).sum()
            union[cls] += (cls_pred | cls_gt).sum()

    # Guard against an empty loader rather than dividing by zero.
    avg_loss = total_loss / len(loader) if len(loader) > 0 else 0.0
    pixel_acc = total_correct / total_pixels if total_pixels > 0 else 0
    class_iou = intersection / (union + 1e-6)

    # Fix: average IoU only over classes that actually occur in this split
    # (union > 0). Previously, absent classes contributed IoU ~= 0 via the
    # epsilon, silently deflating mIoU. Per-class report still lists every
    # class (absent classes show 0.0).
    present = union > 0
    miou = class_iou[present].mean().item() if present.any() else 0.0

    return {
        'loss': avg_loss,
        'miou': miou,
        'pixel_accuracy': pixel_acc,
        'class_iou': {CLASS_NAMES[i]: class_iou[i].item() for i in range(num_classes)}
    }
| |
|
| |
|
def main():
    """Run evaluation end-to-end: load checkpoint, build loader, report metrics."""
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Build the network skeleton first; all weights come from the checkpoint,
    # so ImageNet pretraining of the backbone is disabled here.
    print(f"Loading checkpoint: {args.checkpoint}")
    model = ISDNet(
        num_classes=NUM_CLASSES,
        backbone='resnet18',
        ch=128,
        down_ratio=DOWN_RATIO,
        pretrained=False
    ).to(device)

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources. map_location keeps CPU-only hosts working.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Checkpoint is assumed to be a dict with 'model_state_dict' plus optional
    # 'epoch'/'miou' metadata (as written by the training script) — .get()
    # keeps this tolerant of older checkpoints lacking the metadata.
    print(f"Loaded model from epoch {checkpoint.get('epoch', '?')} "
          f"with val mIoU {checkpoint.get('miou', 0):.4f}")

    # Deterministic evaluation: no augmentation, no shuffling.
    print(f"\nLoading {args.split} dataset from {args.data_root}")
    dataset = FLAIRDataset(
        args.data_root,
        split=args.split,
        crop_size=CROP_SIZE,
        augment=False
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
    print(f"Samples: {len(dataset)}, Batches: {len(loader)}")

    print(f"\nEvaluating on {args.split} set...")
    results = evaluate(model, loader, NUM_CLASSES, device)

    # Report overall metrics, then per-class IoU sorted best-first.
    print(f"\n{'='*50}")
    print(f"Results on {args.split} set:")
    print(f"{'='*50}")
    print(f"Loss: {results['loss']:.4f}")
    print(f"mIoU: {results['miou']*100:.2f}%")
    print(f"Pixel Accuracy: {results['pixel_accuracy']*100:.2f}%")
    print(f"\nPer-class IoU:")
    for name, iou in sorted(results['class_iou'].items(), key=lambda x: -x[1]):
        print(f"  {name:15s}: {iou*100:5.1f}%")

    # Persist the metrics as JSON (auto-named from the split unless --output).
    output_file = args.output or f"isdnet_flair_{args.split}_results.json"
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {output_file}")
| |
|
| |
|
# Script entry point: run evaluation only when executed directly.
if __name__ == "__main__":
    main()
| |
|