#!/usr/bin/env python3
import argparse
import os
from typing import Optional

import numpy as np
import torch
from batchgenerators.utilities.file_and_folder_operations import isdir, maybe_mkdir_p

from nnInteractive.inference.inference_session import nnInteractiveInferenceSession


def run_inference(
    model_dir: str,
    input_image: np.ndarray,
    device: str = "cuda",
    fold: Optional[int] = None,
    checkpoint_name: str = "checkpoint_final.pth",
    use_torch_compile: bool = False,
    do_autozoom: bool = True,
    verbose: bool = False,
):
    """
    Run inference with the nnInteractiveInferenceSession.

    Args:
        model_dir: Path to the trained model directory
        input_image: 4D image data [c, x, y, z]
        device: Device to run inference on
        fold: Fold to use for inference
        checkpoint_name: Name of the checkpoint file to use
        use_torch_compile: Whether to use torch.compile
        do_autozoom: Whether to use auto-zooming
        verbose: Whether to print verbose output

    Returns:
        The initialized inference session, ready for interactive segmentation.
    """
    # Create the inference session
    inference_session = nnInteractiveInferenceSession(
        device=torch.device(device),
        use_torch_compile=use_torch_compile,
        verbose=verbose,
        do_autozoom=do_autozoom,
    )

    # Load the trained model
    inference_session.initialize_from_trained_model_folder(
        model_dir,
        use_fold=fold,
        checkpoint_name=checkpoint_name,
    )

    # Set the input image
    inference_session.set_image(input_image)

    # The target buffer receives the segmentation and is updated in place as
    # interactions are added, so it must match the spatial shape of the image
    target_shape = input_image.shape[1:]
    target_buffer = np.zeros(target_shape, dtype=np.uint8)
    inference_session.set_target_buffer(target_buffer)

    print(f"Initialized inference session from {model_dir}")
    print(f"Input image shape: {input_image.shape}")

    # Return the initialized session for interactive use
    return inference_session


def demonstrate_interaction(inference_session, input_image):
    """
    Demonstrate the different interaction types (bounding box, point, scribble).

    Args:
        inference_session: The initialized inference session
        input_image: The input image, shape [c, x, y, z]

    Returns:
        The target buffer holding the resulting segmentation.
    """
    shape = input_image.shape[1:]

    # Bounding box around the center of the image: [[x0, x1], [y0, y1], [z0, z1]]
    center = [s // 2 for s in shape]
    box_size = [s // 4 for s in shape]
    bbox_coords = [[c - s, c + s] for c, s in zip(center, box_size)]

    print("Adding bounding box interaction...")
    inference_session.add_bbox_interaction(bbox_coords, include_interaction=True)

    # Point interaction near the center, with a small random offset per axis
    point_coords = [c + np.random.randint(-5, 5) for c in center]
    print("Adding point interaction...")
    inference_session.add_point_interaction(point_coords, include_interaction=True)

    # Scribble: a straight line along the first spatial axis through the center
    scribble_image = np.zeros(shape, dtype=np.float32)
    scribble_start = [c - s // 2 for c, s in zip(center, box_size)]
    scribble_end = [c + s // 2 for c, s in zip(center, box_size)]
    scribble_image[scribble_start[0]:scribble_end[0], center[1], center[2]] = 1

    print("Adding scribble interaction...")
    inference_session.add_scribble_interaction(scribble_image, include_interaction=True)

    print("Inference with interactions complete!")
    return inference_session.target_buffer


def main():
    parser = argparse.ArgumentParser(description="Run interactive inference with a trained model")
    parser.add_argument("model_dir", type=str, help="Path to trained model directory")
    parser.add_argument("--input_file", type=str, help="Path to input image file (numpy array)")
    parser.add_argument("--fold", type=int, default=None, help="Fold to use")
use") parser.add_argument("--checkpoint", type=str, default="checkpoint_final.pth", help="Checkpoint name") parser.add_argument("--device", type=str, default="cuda", help="Device to use") parser.add_argument("--no_autozoom", action="store_false", dest="do_autozoom", help="Disable auto-zooming") parser.add_argument("--verbose", action="store_true", help="Enable verbose output") parser.add_argument("--output_file", type=str, help="Path to save output segmentation") parser.add_argument("--demo", action="store_true", help="Run a demo with sample interactions") args = parser.parse_args() # Check if model directory exists if not isdir(args.model_dir): raise ValueError(f"Model directory {args.model_dir} does not exist") # Create dummy input data if no input file is provided if args.input_file is None: print("No input file provided, creating dummy data...") # Create a 4D dummy image with a sphere in the center shape = (1, 128, 128, 128) input_image = np.zeros(shape, dtype=np.float32) # Create a sphere center = np.array([64, 64, 64]) radius = 20 x, y, z = np.ogrid[:128, :128, :128] dist = np.sqrt((x - center[0])**2 + (y - center[1])**2 + (z - center[2])**2) input_image[0, dist <= radius] = 1.0 else: # Load input image input_image = np.load(args.input_file) # Ensure 4D input if input_image.ndim == 3: input_image = input_image[np.newaxis] assert input_image.ndim == 4, f"Input image must be 4D, got shape {input_image.shape}" # Run inference inference_session = run_inference( model_dir=args.model_dir, input_image=input_image, device=args.device, fold=args.fold, checkpoint_name=args.checkpoint, do_autozoom=args.do_autozoom, verbose=args.verbose, output_file=args.output_file ) # Run demo if requested if args.demo: segmentation = demonstrate_interaction(inference_session, input_image) # Save output if requested if args.output_file: output_dir = os.path.dirname(args.output_file) if output_dir: maybe_mkdir_p(output_dir) np.save(args.output_file, segmentation) print(f"Saved segmentation to {args.output_file}") else: print("Inference session initialized and ready for interaction") print("You can now add interactions using the inference_session object") return inference_session if __name__ == "__main__": inference_session = main()