# ISDNet-pytorch / inference.py
# Uploaded via huggingface_hub by Antoine1091 (revision 49d2955, verified).
#!/usr/bin/env python3
"""
ISDNet Inference Script for FLAIR Dataset
Evaluates a trained model on the test set with per-class IoU metrics.
Usage:
python inference.py --checkpoint isdnet_flair_best.pth
python inference.py --checkpoint isdnet_flair_best.pth --split valid
"""
import argparse
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from isdnet import ISDNet, FLAIRDataset, NUM_CLASSES, CROP_SIZE, DOWN_RATIO, IGNORE_INDEX
from isdnet.config import DATA_ROOT
# FLAIR semantic-segmentation class names, indexed by label id (0-14).
# Order matters: index i here names channel i of the model output
# (see the per-class IoU dict built in evaluate()).
CLASS_NAMES = [
    'building', 'pervious', 'impervious', 'bare_soil', 'water',
    'coniferous', 'deciduous', 'brushwood', 'vineyard', 'herbaceous',
    'agricultural', 'plowed_land', 'swimming_pool', 'snow', 'greenhouse'
]
def parse_args():
    """Build and parse the command-line options for this script.

    Returns:
        argparse.Namespace with checkpoint, data_root, split, batch_size,
        num_workers and output attributes.
    """
    parser = argparse.ArgumentParser(description='ISDNet Inference on FLAIR')
    # Declarative option table: (flags, keyword arguments) pairs.
    option_specs = [
        (('--checkpoint',), dict(type=str, default='isdnet_flair_best.pth',
                                 help='Path to model checkpoint')),
        (('--data-root',), dict(type=str, default=DATA_ROOT,
                                help='Path to FLAIR dataset')),
        (('--split',), dict(type=str, default='test', choices=['valid', 'test'],
                            help='Dataset split to evaluate')),
        (('--batch-size',), dict(type=int, default=16,
                                 help='Batch size for inference')),
        (('--num-workers',), dict(type=int, default=4,
                                  help='Number of data loading workers')),
        (('--output',), dict(type=str, default='',
                             help='Output JSON file for results (default: auto-generated)')),
    ]
    for flags, kwargs in option_specs:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
@torch.no_grad()
def evaluate(model, loader, num_classes, device):
    """Evaluate `model` on `loader` and compute segmentation metrics.

    Args:
        model: network called as ``model(imgs, return_loss=False)``; expected
            to return per-pixel class logits of shape (B, num_classes, H, W).
        loader: iterable yielding (imgs, masks) batches; masks are long
            tensors of class ids with IGNORE_INDEX marking void pixels.
        num_classes: number of semantic classes (indexes CLASS_NAMES).
        device: torch device on which accumulators live.

    Returns:
        dict with 'loss' (mean per-batch CE loss), 'miou', 'pixel_accuracy',
        and 'class_iou' (class-name -> IoU mapping).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    total_loss = 0.0
    intersection = torch.zeros(num_classes, device=device)
    union = torch.zeros(num_classes, device=device)
    total_correct = 0
    total_pixels = 0
    num_batches = 0  # counted explicitly so generators without __len__ work
    for imgs, masks in tqdm(loader, desc="Evaluating"):
        imgs = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        outputs = model(imgs, return_loss=False)
        total_loss += criterion(outputs, masks).item()
        num_batches += 1
        preds = outputs.argmax(dim=1)
        valid_mask = (masks != IGNORE_INDEX)
        # Pixel accuracy over non-ignored pixels only.
        total_correct += ((preds == masks) & valid_mask).sum().item()
        total_pixels += valid_mask.sum().item()
        # Per-class IoU. Ignored pixels are excluded from BOTH prediction and
        # ground truth (the original only masked predictions, which would
        # corrupt IoU if IGNORE_INDEX ever fell inside [0, num_classes)).
        for cls in range(num_classes):
            cls_pred = (preds == cls) & valid_mask
            cls_gt = (masks == cls) & valid_mask
            intersection[cls] += (cls_pred & cls_gt).sum()
            union[cls] += (cls_pred | cls_gt).sum()
    # Guard against an empty loader (original divided by len(loader) and
    # would raise ZeroDivisionError here).
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    pixel_acc = total_correct / total_pixels if total_pixels > 0 else 0
    # Epsilon keeps classes absent from both pred and GT at IoU ~0 instead of NaN.
    class_iou = intersection / (union + 1e-6)
    miou = class_iou.mean().item()
    return {
        'loss': avg_loss,
        'miou': miou,
        'pixel_accuracy': pixel_acc,
        'class_iou': {CLASS_NAMES[i]: class_iou[i].item() for i in range(num_classes)}
    }
def _load_model(checkpoint_path, device):
    """Instantiate ISDNet and restore weights from `checkpoint_path`."""
    net = ISDNet(
        num_classes=NUM_CLASSES,
        backbone='resnet18',
        ch=128,
        down_ratio=DOWN_RATIO,
        pretrained=False
    ).to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    net.load_state_dict(ckpt['model_state_dict'])
    print(f"Loaded model from epoch {ckpt.get('epoch', '?')} "
          f"with val mIoU {ckpt.get('miou', 0):.4f}")
    return net


def _report(results, split):
    """Pretty-print the evaluation summary and per-class IoU (best first)."""
    banner = '=' * 50
    print(f"\n{banner}")
    print(f"Results on {split} set:")
    print(f"{banner}")
    print(f"Loss: {results['loss']:.4f}")
    print(f"mIoU: {results['miou']*100:.2f}%")
    print(f"Pixel Accuracy: {results['pixel_accuracy']*100:.2f}%")
    print(f"\nPer-class IoU:")
    for name, iou in sorted(results['class_iou'].items(), key=lambda kv: -kv[1]):
        print(f" {name:15s}: {iou*100:5.1f}%")


def main():
    """Entry point: restore a checkpoint and evaluate it on a FLAIR split."""
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    print(f"Loading checkpoint: {args.checkpoint}")
    model = _load_model(args.checkpoint, device)

    print(f"\nLoading {args.split} dataset from {args.data_root}")
    dataset = FLAIRDataset(
        args.data_root,
        split=args.split,
        crop_size=CROP_SIZE,
        augment=False
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
    print(f"Samples: {len(dataset)}, Batches: {len(loader)}")

    print(f"\nEvaluating on {args.split} set...")
    results = evaluate(model, loader, NUM_CLASSES, device)
    _report(results, args.split)

    # Persist metrics as JSON next to the script.
    output_file = args.output or f"isdnet_flair_{args.split}_results.json"
    with open(output_file, 'w') as fh:
        json.dump(results, fh, indent=2)
    print(f"\nResults saved to {output_file}")


if __name__ == "__main__":
    main()