File size: 4,620 Bytes
fa5bb00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | """
nnU-Net prediction script that loads the trained model checkpoint
and generates predictions compatible with evaluate.py.
Usage:
python baselines/predict_nnunet.py \
--nnunet_results /workspace/data/nnUNet_results \
--test_dir data/flat_test \
--output_dir results/nnunet \
--num_samples 16
"""
import argparse
import os
import sys
import glob
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
def load_nnunet_model(nnunet_results, fold=0, dataset_name="Dataset001_LIDC",
                      checkpoint_name="checkpoint_best.pth"):
    """Load a trained nnU-Net predictor from an nnUNet_results directory.

    The model folder is ``<nnunet_results>/<dataset_name>/<trainer folder>``,
    where the trainer folder (e.g. ``nnUNetTrainer200epochs__nnUNetPlans__2d``)
    is the alphabetically first sub-directory of the dataset folder.

    Args:
        nnunet_results: Root nnUNet_results directory.
        fold: Cross-validation fold whose checkpoint is loaded.
        dataset_name: Dataset sub-folder under ``nnunet_results``.
        checkpoint_name: Checkpoint file name loaded from the fold folder.

    Returns:
        An initialized ``nnUNetPredictor`` (CUDA if available, else CPU).

    Raises:
        FileNotFoundError: If the dataset folder does not exist or contains
            no trainer sub-folders.
    """
    dataset_dir = os.path.join(nnunet_results, dataset_name)
    # os.listdir order is arbitrary; sort so the selected trainer folder is
    # reproducible when more than one exists.
    trainer_dirs = sorted(
        d for d in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, d))
    )
    if not trainer_dirs:
        raise FileNotFoundError(f"No trainer folders found in {dataset_dir}")
    if len(trainer_dirs) > 1:
        print(f"Warning: multiple trainer folders in {dataset_dir}: "
              f"{trainer_dirs}; using {trainer_dirs[0]}")
    model_folder = os.path.join(dataset_dir, trainer_dirs[0])
    print(f"Loading nnU-Net model from: {model_folder}")

    # Imported lazily (after path resolution) so path errors surface quickly
    # and the module can be imported without nnunetv2 installed.
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=device,
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=False,
    )
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=(fold,),
        checkpoint_name=checkpoint_name,
    )
    return predictor
def predict_single_image(predictor, image_path):
    """Run nnU-Net inference on one 2-D image and return a binary mask.

    Args:
        predictor: An initialized nnU-Net predictor exposing
            ``predict_single_npy_array`` (as returned by ``load_nnunet_model``).
        image_path: Path to an image file, which is converted to 8-bit
            grayscale, or a 2-D numpy array of shape (H, W) used as-is.

    Returns:
        ``np.uint8`` array of shape (H, W): 1 where the predicted label is
        non-zero, 0 elsewhere.

    Raises:
        ValueError: If an array input is not 2-D.
    """
    if isinstance(image_path, np.ndarray):
        img = image_path
        if img.ndim != 2:
            raise ValueError(f"Expected a 2-D (H, W) array, got shape {img.shape}")
    else:
        img = np.array(Image.open(image_path).convert('L'))  # (H, W)

    # nnU-Net always takes 4-D input (C, Z, Y, X); 2-D images are represented
    # as a single z-slice, i.e. (1, 1, H, W). Do NOT drop the leading axis.
    img_input = img[np.newaxis, np.newaxis].astype(np.float32)

    # A large dummy z-spacing marks the data as 2-D. Per the nnU-Net docstring
    # of predict_single_npy_array, 'spacing' is the only required property.
    props = {'spacing': [999.0, 1.0, 1.0]}

    # The keyword is `output_file_truncated`; None means "return the
    # segmentation instead of writing it to disk".
    prediction = predictor.predict_single_npy_array(
        input_image=img_input,
        image_properties=props,
        segmentation_previous_stage=None,
        output_file_truncated=None,
        save_or_return_probabilities=False,
    )

    binary_mask = (np.asarray(prediction) > 0).astype(np.uint8)
    # Drop the singleton z axis so the result is (H, W).
    if binary_mask.ndim == 3:
        binary_mask = binary_mask[0]
    return binary_mask
def main(argv=None):
    """Predict masks for every test image and write replicated sample files.

    For each ``*.png`` in ``--test_dir`` whose name does not contain
    ``_mask``, writes ``<basename>_sampleNN.png`` (NN = 00 ..
    num_samples-1) into ``--output_dir``. Mask pixels are 0 or 255.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.

    Exits with an error message if no test images are found.
    """
    parser = argparse.ArgumentParser(description="nnU-Net Prediction")
    parser.add_argument("--nnunet_results", type=str, required=True)
    parser.add_argument("--test_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=16)
    parser.add_argument("--fold", type=int, default=0)
    args = parser.parse_args(argv)

    # Discover inputs before the (slow) model load so an empty or mistyped
    # test directory fails immediately. Files containing "_mask" are labels.
    image_files = sorted(
        f for f in glob.glob(os.path.join(args.test_dir, "*.png"))
        if "_mask" not in os.path.basename(f)
    )
    if not image_files:
        sys.exit(f"No test images (*.png without '_mask') found in {args.test_dir}")
    print(f"Test dataset: {len(image_files)} images from {args.test_dir}")

    os.makedirs(args.output_dir, exist_ok=True)
    predictor = load_nnunet_model(args.nnunet_results, args.fold)

    total_saved = 0
    for img_path in tqdm(image_files, desc="Predicting"):
        basename = os.path.splitext(os.path.basename(img_path))[0]
        mask = predict_single_image(predictor, img_path)
        # The model is deterministic, so every "sample" is the same
        # prediction: build the image once and write it num_samples times.
        mask_img = Image.fromarray(mask * 255)
        for s in range(args.num_samples):
            out_path = os.path.join(args.output_dir, f"{basename}_sample{s:02d}.png")
            mask_img.save(out_path)
            total_saved += 1

    print(f"\nSaved {len(image_files)} predictions × {args.num_samples} samples = {total_saved} files")
    print(f"\nPredictions saved to {args.output_dir}")
    print("Ready for evaluation:")
    print(f" python evaluate.py --samples_dir {args.output_dir} --gt_dir data/testing --results_file results/nnunet_eval.csv")


if __name__ == "__main__":
    main()
|