#!/usr/bin/env python3
"""
ISDNet Inference Script for FLAIR Dataset

Evaluates a trained model on the test set with per-class IoU metrics.

Usage:
    python inference.py --checkpoint isdnet_flair_best.pth
    python inference.py --checkpoint isdnet_flair_best.pth --split valid
"""
import argparse
import json

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from isdnet import ISDNet, FLAIRDataset, NUM_CLASSES, CROP_SIZE, DOWN_RATIO, IGNORE_INDEX
from isdnet.config import DATA_ROOT

# FLAIR class names (index i is the name for class id i)
CLASS_NAMES = [
    'building', 'pervious', 'impervious', 'bare_soil', 'water',
    'coniferous', 'deciduous', 'brushwood', 'vineyard', 'herbaceous',
    'agricultural', 'plowed_land', 'swimming_pool', 'snow', 'greenhouse'
]


def parse_args():
    """Parse command-line arguments for checkpoint, data location and split."""
    parser = argparse.ArgumentParser(description='ISDNet Inference on FLAIR')
    parser.add_argument('--checkpoint', type=str, default='isdnet_flair_best.pth',
                        help='Path to model checkpoint')
    parser.add_argument('--data-root', type=str, default=DATA_ROOT,
                        help='Path to FLAIR dataset')
    parser.add_argument('--split', type=str, default='test',
                        choices=['valid', 'test'],
                        help='Dataset split to evaluate')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Batch size for inference')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='Number of data loading workers')
    parser.add_argument('--output', type=str, default='',
                        help='Output JSON file for results (default: auto-generated)')
    return parser.parse_args()


@torch.no_grad()
def evaluate(model, loader, num_classes, device):
    """Evaluate ``model`` on ``loader``; return loss, pixel accuracy and IoU.

    The confusion matrix is accumulated with one ``torch.bincount`` call per
    batch instead of a Python loop over classes, and mIoU is averaged ONLY
    over classes that actually occur (union > 0).  A class absent from both
    prediction and ground truth would otherwise contribute a spurious
    IoU of ~0 and drag the mean down.

    Args:
        model: segmentation model; called as ``model(imgs, return_loss=False)``.
            NOTE(review): assumes this returns raw class logits of shape
            (B, C, H, W) matching the mask resolution — confirm against ISDNet.
        loader: DataLoader yielding ``(imgs, masks)`` batches.
        num_classes: number of semantic classes.
        device: torch device metrics are accumulated on.

    Returns:
        dict with keys 'loss', 'miou', 'pixel_accuracy' and 'class_iou'
        (per-class IoU by name; ~0.0 for classes that never occur).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    total_loss = 0.0
    # conf[i, j] = number of valid pixels with ground truth i predicted as j.
    conf = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)

    for imgs, masks in tqdm(loader, desc="Evaluating"):
        imgs = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        outputs = model(imgs, return_loss=False)
        total_loss += criterion(outputs, masks).item()

        preds = outputs.argmax(dim=1)
        valid = masks != IGNORE_INDEX
        # Fold (gt, pred) pairs of valid pixels into a flat histogram that
        # reshapes into the confusion matrix. IGNORE_INDEX pixels are excluded
        # from accuracy and IoU, matching CrossEntropyLoss's ignore_index.
        idx = masks[valid].long() * num_classes + preds[valid]
        conf += torch.bincount(idx, minlength=num_classes ** 2).reshape(
            num_classes, num_classes
        )

    intersection = conf.diag().double()
    union = (conf.sum(dim=0) + conf.sum(dim=1)).double() - intersection
    class_iou = intersection / (union + 1e-6)

    # Average only over classes present in prediction or ground truth;
    # absent classes are excluded rather than counted as IoU == 0.
    present = union > 0
    miou = class_iou[present].mean().item() if present.any() else 0.0

    total_pixels = conf.sum().item()
    pixel_acc = intersection.sum().item() / total_pixels if total_pixels > 0 else 0
    # Guard the division like pixel accuracy does (empty loader -> 0 loss).
    avg_loss = total_loss / len(loader) if len(loader) > 0 else 0.0

    return {
        'loss': avg_loss,
        'miou': miou,
        'pixel_accuracy': pixel_acc,
        'class_iou': {CLASS_NAMES[i]: class_iou[i].item() for i in range(num_classes)}
    }


def main():
    """Load the checkpoint, run evaluation on the chosen split, save JSON."""
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Load model (pretrained=False: weights come entirely from the checkpoint)
    print(f"Loading checkpoint: {args.checkpoint}")
    model = ISDNet(
        num_classes=NUM_CLASSES,
        backbone='resnet18',
        ch=128,
        down_ratio=DOWN_RATIO,
        pretrained=False
    ).to(device)

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint.get('epoch', '?')} "
          f"with val mIoU {checkpoint.get('miou', 0):.4f}")

    # Dataset
    print(f"\nLoading {args.split} dataset from {args.data_root}")
    dataset = FLAIRDataset(
        args.data_root,
        split=args.split,
        crop_size=CROP_SIZE,
        augment=False
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
    print(f"Samples: {len(dataset)}, Batches: {len(loader)}")

    # Evaluate
    print(f"\nEvaluating on {args.split} set...")
    results = evaluate(model, loader, NUM_CLASSES, device)

    # Print results
    print(f"\n{'='*50}")
    print(f"Results on {args.split} set:")
    print(f"{'='*50}")
    print(f"Loss: {results['loss']:.4f}")
    print(f"mIoU: {results['miou']*100:.2f}%")
    print(f"Pixel Accuracy: {results['pixel_accuracy']*100:.2f}%")
    print(f"\nPer-class IoU:")
    # Sort classes best-first for readability
    for name, iou in sorted(results['class_iou'].items(), key=lambda x: -x[1]):
        print(f"  {name:15s}: {iou*100:5.1f}%")

    # Save results
    output_file = args.output or f"isdnet_flair_{args.split}_results.json"
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {output_file}")


if __name__ == "__main__":
    main()