# File size: 4,620 Bytes - fa5bb00
"""
nnU-Net prediction script that loads the trained model checkpoint
and generates predictions compatible with evaluate.py.

Usage:
    python baselines/predict_nnunet.py \
        --nnunet_results /workspace/data/nnUNet_results \
        --test_dir data/flat_test \
        --output_dir results/nnunet \
        --num_samples 16
"""
import argparse
import os
import sys
import glob
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm


def load_nnunet_model(nnunet_results, fold=0):
    """Load an nnU-Net predictor from a trained-model checkpoint.

    The trainer folder is discovered inside
    ``<nnunet_results>/Dataset001_LIDC``. Sub-directories are sorted by name
    and the first one is used, so the selection is reproducible even when
    several trainer folders exist.

    Args:
        nnunet_results: Path to the nnU-Net results root directory.
        fold: Index of the fold whose checkpoint is loaded.

    Returns:
        An ``nnUNetPredictor`` initialised from ``checkpoint_best.pth`` of
        the selected trainer folder.

    Raises:
        FileNotFoundError: If the dataset directory does not exist or holds
            no trainer sub-directory.
    """
    dataset_dir = os.path.join(nnunet_results, "Dataset001_LIDC")
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(
            f"nnU-Net dataset directory not found: {dataset_dir}"
        )

    # os.listdir() yields entries in arbitrary order; sorting makes the
    # chosen trainer folder deterministic across machines and runs.
    trainer_dirs = sorted(
        d for d in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, d))
    )
    if not trainer_dirs:
        raise FileNotFoundError(
            f"No trainer folder found inside: {dataset_dir}"
        )
    if len(trainer_dirs) > 1:
        print(f"Found {len(trainer_dirs)} trainer folders, using the first: {trainer_dirs[0]}")

    trainer_name = trainer_dirs[0]  # e.g., nnUNetTrainer200epochs__nnUNetPlans__2d
    model_folder = os.path.join(dataset_dir, trainer_name)

    # Imported only after the paths are validated so that a wrong path is
    # reported immediately, before the heavy dependency is loaded.
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    print(f"Loading nnU-Net model from: {model_folder}")

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=device,
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=False,
    )
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=(fold,),
        checkpoint_name='checkpoint_best.pth',
    )
    return predictor


def predict_single_image(predictor, image_path):
    """Run nnU-Net prediction on a single image and return a binary mask.

    Args:
        predictor: Initialised predictor object exposing
            ``predict_single_npy_array``.
        image_path: Path of the image file; it is converted to 8-bit
            grayscale before prediction.

    Returns:
        A ``uint8`` array holding 1 wherever the predicted label is greater
        than 0 and 0 elsewhere. If the prediction is 3-D, its leading axis
        is dropped.
    """
    # The context manager closes the underlying file handle as soon as the
    # pixel data has been copied into the numpy array.
    with Image.open(image_path) as pil_img:
        img = np.array(pil_img.convert('L'))  # (H, W)

    # Add channel and z axes: (1, 1, H, W), i.e. (C, Z, Y, X).
    img_input = img[np.newaxis, np.newaxis].astype(np.float32)
    spatial_shape = img_input.shape[1:]  # (1, H, W)

    # nnU-Net properties dict for 2D data. It describes a (1, H, W) spatial
    # volume with three spacings, so the array handed to the predictor must
    # carry the matching 4-D layout.
    props = {
        'spacing': [999.0, 1.0, 1.0],
        'shape_after_cropping': spatial_shape,
        'bbox_used_for_cropping': [[0, s] for s in spatial_shape],
        'shape_before_cropping': spatial_shape,
        'class_locations': {},
    }

    # Arguments are passed positionally (image, properties, previous-stage
    # segmentation, output file, return-probabilities flag) so the call does
    # not depend on the keyword name of the output-file parameter.
    prediction = predictor.predict_single_npy_array(
        img_input,  # full 4-D array, consistent with ``props``
        props,
        None,
        None,
        False,
    )

    # prediction is an array of segmentation labels; any label > 0 is foreground
    binary_mask = (prediction > 0).astype(np.uint8)

    # Remove z dimension if present
    if binary_mask.ndim == 3:
        binary_mask = binary_mask[0]

    return binary_mask


def main():
    """Command-line entry point: predict every test image and save samples.

    Parses the command-line arguments, collects the ``*.png`` files in
    ``--test_dir`` whose basename does not contain ``_mask``, loads the
    predictor, and writes ``--num_samples`` identical copies of each
    predicted mask to ``--output_dir`` as ``<basename>_sampleNN.png``.
    If no test image is found, a message is printed to stderr and the
    function returns without loading the model.
    """
    parser = argparse.ArgumentParser(description="nnU-Net Prediction")
    parser.add_argument("--nnunet_results", type=str, required=True)
    parser.add_argument("--test_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=16)
    parser.add_argument("--fold", type=int, default=0)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Get test images, keeping only images (not masks). This is done before
    # loading the model so an empty or wrong directory is reported quickly.
    image_files = sorted(
        f for f in glob.glob(os.path.join(args.test_dir, "*.png"))
        if "_mask" not in os.path.basename(f)
    )
    print(f"Test dataset: {len(image_files)} images from {args.test_dir}")
    if not image_files:
        print(f"No test images found in {args.test_dir}; nothing to predict.", file=sys.stderr)
        return

    # Load model
    predictor = load_nnunet_model(args.nnunet_results, args.fold)

    total_saved = 0
    for img_path in tqdm(image_files, desc="Predicting"):
        basename = os.path.splitext(os.path.basename(img_path))[0]

        # Get prediction
        mask = predict_single_image(predictor, img_path)

        # The model is deterministic, so every sample is the same mask:
        # build the output image once and save it N times.
        mask_img = Image.fromarray(mask * 255)
        for s in range(args.num_samples):
            out_path = os.path.join(args.output_dir, f"{basename}_sample{s:02d}.png")
            mask_img.save(out_path)
            total_saved += 1

    print(f"\nSaved {len(image_files)} predictions × {args.num_samples} samples = {total_saved} files")
    print(f"\nPredictions saved to {args.output_dir}")
    print("Ready for evaluation:")
    print(f"  python evaluate.py --samples_dir {args.output_dir} --gt_dir data/testing --results_file results/nnunet_eval.csv")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()