Spaces:
Sleeping
Sleeping
| """Run single-case PanCancerSeg nnUNet CT inference and visualization.""" | |
| import argparse | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| import numpy as np | |
| import SimpleITK as sitk | |
| import torch | |
| from visualize import generate_outputs | |
# --- Model registry -----------------------------------------------------------
# One entry per cancer-specific nnUNet model. Fields of each entry:
#   dataset_id / dataset_name - nnUNet dataset the weights belong to; dataset_name
#                               is the folder looked up under --model_dir
#   display_name              - human-readable label handed to the visualizer
#   wl / ww                   - display window level / width handed to the
#                               visualizer (presumably CT HU window - confirm)
#   color                     - RGB overlay colour for the predicted mask
_SOFT_TISSUE_WINDOW = (40, 400)
_LUNG_WINDOW = (-600, 1500)
_MASK_RED = (255, 0, 0)


def _cancer_config(dataset_id, dataset_name, display_name, window, color=_MASK_RED):
    """Return one CANCER_CONFIGS entry; ``window`` is a (level, width) pair."""
    level, width = window
    return {
        "dataset_id": dataset_id,
        "dataset_name": dataset_name,
        "display_name": display_name,
        "wl": level,
        "ww": width,
        "color": color,
    }


CANCER_CONFIGS = {
    "kidney_cancer": _cancer_config(
        102, "Dataset102_Kidney", "Kidney cancer", _SOFT_TISSUE_WINDOW
    ),
    "liver_cancer": _cancer_config(
        103, "Dataset103_Liver", "Liver cancer", _SOFT_TISSUE_WINDOW
    ),
    "pancreatic_cancer": _cancer_config(
        104, "Dataset104_Pancreas", "Pancreatic cancer", _SOFT_TISSUE_WINDOW
    ),
    "lung_cancer": _cancer_config(
        105, "Dataset105_Lung", "Lung cancer", _LUNG_WINDOW
    ),
}

# Short legacy names still accepted on the command line (see normalize_cancer_type).
CANCER_TYPE_ALIASES = dict(
    kidney="kidney_cancer",
    liver="liver_cancer",
    pancreas="pancreatic_cancer",
    lung="lung_cancer",
)

# Parts of the nnUNet results folder name "<TRAINER>__<PLANS>__<CONFIGURATION>"
# and the checkpoint file loaded from its fold_0 subdirectory.
TRAINER_NAME = "nnUNetTrainerWandb2000"
PLANS_NAME = "nnUNetResEncUNetMPlans"
CONFIGURATION = "3d_fullres"
CHECKPOINT_NAME = "checkpoint_best.pth"
def parse_args(argv=None):
    """Parse the command-line options for a single-case inference run.

    ``--cancer_type`` is normalized to its canonical CANCER_CONFIGS key and
    ``--fps`` is checked to be positive while parsing. An invalid value
    therefore produces an argparse usage error (exit status 2) instead of a
    traceback from ``main()``.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input``, ``cancer_type``, ``model_dir``,
        ``output_dir``, ``fps`` and ``device``.
    """

    def cancer_type_arg(value):
        # Re-raise as ArgumentTypeError so argparse shows our message verbatim.
        try:
            return normalize_cancer_type(value)
        except ValueError as err:
            raise argparse.ArgumentTypeError(str(err)) from err

    def positive_int(value):
        try:
            number = int(value)
        except ValueError:
            number = 0
        if number <= 0:
            raise argparse.ArgumentTypeError(
                f"--fps must be a positive integer, got {value!r}"
            )
        return number

    parser = argparse.ArgumentParser(
        description="Run one PanCancerSeg cancer-specific nnUNet model on a single NIfTI image."
    )
    parser.add_argument("--input", required=True, help="Path to a single .nii.gz image")
    parser.add_argument(
        "--cancer_type",
        required=True,
        type=cancer_type_arg,
        help=(
            "Cancer-specific model to use. "
            f"Canonical values: {', '.join(sorted(CANCER_CONFIGS))}. "
            f"Legacy aliases still accepted: {', '.join(sorted(CANCER_TYPE_ALIASES))}."
        ),
    )
    parser.add_argument(
        "--model_dir",
        required=True,
        help="Path to nnUNet results directory containing DatasetXXX_* folders",
    )
    parser.add_argument("--output_dir", default="./output", help="Where to save results")
    parser.add_argument(
        "--fps", type=positive_int, default=10, help="Video frames per second"
    )
    parser.add_argument("--device", choices=["cuda", "cpu"], default="cuda")
    return parser.parse_args(argv)
def main():
    """Command-line entry point: segment one CT image and render visualizations.

    Steps:
    1. Validate the arguments.
    2. Install the custom trainer into nnunetv2.
    3. Run fold-0 nnUNet inference inside a temporary workspace.
    4. Copy the mask to ``<output_dir>/<case_id>_seg.nii.gz``.
    5. Render outputs via ``generate_outputs``.
    6. Print a summary with voxel count and volume.

    Raises:
        FileNotFoundError: the input image, model directory, checkpoint or the
            expected nnUNet output is missing.
        ValueError: unsupported cancer type, a non-``.nii.gz`` or ``._*`` input,
            or a non-positive ``--fps``.
        RuntimeError: CUDA was requested but is not available.
    """
    args = parse_args()
    # Idempotent, so it is safe even if parse_args already normalized the value.
    args.cancer_type = normalize_cancer_type(args.cancer_type)

    input_path = Path(args.input).expanduser().resolve()
    model_dir = Path(args.model_dir).expanduser().resolve()
    output_dir = Path(args.output_dir).expanduser().resolve()

    # is_file()/is_dir() instead of exists(): a directory given as --input (or a
    # file given as --model_dir) would otherwise only fail later inside nnUNet.
    if not input_path.is_file():
        raise FileNotFoundError(f"Input image not found or not a file: {input_path}")
    if input_path.name.startswith("._"):
        # "._*" files are presumably macOS AppleDouble metadata sidecars rather
        # than image data.
        raise ValueError(
            f"Got a '._' metadata file instead of an image: {input_path.name}"
        )
    if not input_path.name.endswith(".nii.gz"):
        raise ValueError(f"Expected a .nii.gz image, got: {input_path.name}")
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")
    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA was requested but torch.cuda.is_available() is False. "
            "Use --device cpu or install CUDA-enabled PyTorch."
        )
    if args.fps <= 0:
        raise ValueError("--fps must be a positive integer")

    output_dir.mkdir(parents=True, exist_ok=True)
    config = CANCER_CONFIGS[args.cancer_type]
    case_id = resolve_case_id(input_path)
    install_custom_trainer()
    model_folder = resolve_model_folder(model_dir, config["dataset_name"])

    seg_path = output_dir / f"{case_id}_seg.nii.gz"
    with tempfile.TemporaryDirectory(prefix="pancancerseg_") as tmp:
        tmp_path = Path(tmp)
        tmp_input_dir = tmp_path / "input"
        tmp_output_dir = tmp_path / "prediction"
        tmp_input_dir.mkdir()
        tmp_output_dir.mkdir()

        # nnUNet raw-input naming is <case_id>_<4-digit channel>.nii.gz; this
        # pipeline feeds a single channel, 0000.
        nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
        symlink_or_copy(input_path, nnunet_input)

        run_nnunet_prediction(
            model_folder=model_folder,
            input_dir=tmp_input_dir,
            output_dir=tmp_output_dir,
            device=args.device,
        )

        raw_seg = tmp_output_dir / f"{case_id}.nii.gz"
        if not raw_seg.exists():
            produced = sorted(tmp_output_dir.glob("*.nii.gz"))
            raise FileNotFoundError(
                f"nnUNet did not write the expected segmentation {raw_seg}. "
                f"Found: {[p.name for p in produced]}"
            )
        shutil.copy2(raw_seg, seg_path)

    # Visualization only needs the original image and the copied mask, so the
    # temporary nnUNet workspace is released before this (possibly slow) step.
    viz_outputs = generate_outputs(
        image_path=input_path,
        mask_path=seg_path,
        output_dir=output_dir,
        case_name=case_id,
        cancer_type=config["display_name"],
        wl=config["wl"],
        ww=config["ww"],
        color=config["color"],
        alpha=0.5,
        fps=args.fps,
    )
    positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)
    print_summary(seg_path, viz_outputs, positive_voxels, tumor_volume_ml)
def resolve_case_id(input_path):
    """Derive the nnUNet case identifier from an image path.

    The ``.nii.gz`` extension is stripped, then a single trailing ``_0000``
    tag if present, so ``foo_0000.nii.gz`` and ``foo.nii.gz`` both give ``foo``.

    Raises:
        ValueError: the file name does not end in ``.nii.gz``, or nothing is
            left after stripping.
    """
    extension, channel_tag = ".nii.gz", "_0000"
    file_name = input_path.name
    if not file_name.endswith(extension):
        raise ValueError(f"Expected a .nii.gz image, got: {file_name}")

    stem = file_name[: len(file_name) - len(extension)]
    if stem.endswith(channel_tag):
        stem = stem[: len(stem) - len(channel_tag)]

    if stem:
        return stem
    raise ValueError(f"Could not resolve a case ID from: {input_path}")
def normalize_cancer_type(cancer_type):
    """Map a canonical key or legacy alias to its canonical CANCER_CONFIGS key.

    Leading/trailing whitespace and letter case are ignored.

    Raises:
        ValueError: the value is neither a CANCER_CONFIGS key nor a
            CANCER_TYPE_ALIASES key.
    """
    key = cancer_type.strip().lower()
    canonical = CANCER_TYPE_ALIASES[key] if key in CANCER_TYPE_ALIASES else key
    if canonical in CANCER_CONFIGS:
        return canonical

    accepted = sorted([*CANCER_CONFIGS, *CANCER_TYPE_ALIASES])
    raise ValueError(
        f"Unsupported --cancer_type '{key}'. Valid values: {', '.join(accepted)}"
    )
def install_custom_trainer():
    """Place the bundled custom trainer inside the installed nnunetv2 package.

    ``trainers/<TRAINER_NAME>.py`` (next to this script) is symlinked, or copied
    where symlinks are unavailable, into nnunetv2's
    ``training/nnUNetTrainer/variants`` directory. Presumably this lets nnunetv2
    locate the trainer class named in the model folder.

    An existing install is left alone when it is a symlink to our file or a
    regular file with identical contents. A stale copy, or a link to anything
    else, is replaced.

    Returns:
        Path of the installed file inside the nnunetv2 package.

    Raises:
        FileNotFoundError: the bundled trainer file is missing.
    """
    import filecmp

    import nnunetv2

    src = Path(__file__).resolve().parent / "trainers" / f"{TRAINER_NAME}.py"
    if not src.exists():
        raise FileNotFoundError(f"Custom trainer file is missing: {src}")

    variants_dir = Path(nnunetv2.__path__[0]) / "training" / "nnUNetTrainer" / "variants"
    variants_dir.mkdir(parents=True, exist_ok=True)
    dst = variants_dir / src.name

    if dst.exists() or dst.is_symlink():
        try:
            up_to_date = dst.resolve() == src.resolve() or (
                # A copy made by the fallback below never resolves to src, so
                # compare contents to avoid re-copying on every run.
                not dst.is_symlink()
                and dst.is_file()
                and filecmp.cmp(src, dst, shallow=False)
            )
        except (OSError, RuntimeError):
            # Unresolvable (e.g. looping) link; older Pythons raise RuntimeError.
            up_to_date = False
        if up_to_date:
            return dst
        dst.unlink()

    try:
        dst.symlink_to(src.resolve())
    except (OSError, NotImplementedError):
        shutil.copy2(src, dst)
    print(f"Installed custom trainer: {dst}")
    return dst
def resolve_model_folder(model_dir, dataset_name):
    """Locate the trained nnUNet model folder for one dataset and sanity-check it.

    The folder is ``<model_dir>/<dataset_name>/<TRAINER>__<PLANS>__<CONFIGURATION>``.
    It must contain ``fold_0/<CHECKPOINT_NAME>`` (the only fold used for
    inference here). It must also contain ``dataset.json`` and ``plans.json``,
    which ``nnUNetPredictor.initialize_from_trained_model_folder`` loads from
    the same folder.

    Returns:
        Path to the model folder.

    Raises:
        FileNotFoundError: the checkpoint or a required JSON file is missing.
    """
    model_folder = (
        model_dir
        / dataset_name
        / f"{TRAINER_NAME}__{PLANS_NAME}__{CONFIGURATION}"
    )
    checkpoint = model_folder / "fold_0" / CHECKPOINT_NAME
    if not checkpoint.exists():
        raise FileNotFoundError(
            f"Expected checkpoint not found: {checkpoint}. "
            "Check --model_dir and make sure the trained weights are downloaded."
        )

    # Fail early with a clear message instead of a bare load error inside nnUNet.
    missing = [
        name
        for name in ("dataset.json", "plans.json")
        if not (model_folder / name).is_file()
    ]
    if missing:
        raise FileNotFoundError(
            f"Model folder {model_folder} is missing {', '.join(missing)}. "
            "Make sure the complete nnUNet results folder was downloaded."
        )
    return model_folder
def symlink_or_copy(src, dst):
    """Expose ``src`` at ``dst``, preferring a symlink over a copy.

    The link points at the fully resolved ``src``. If resolving or linking
    raises ``OSError`` or ``NotImplementedError``, ``src`` is copied with
    ``shutil.copy2`` instead. Causes include no symlink support and an
    already-existing ``dst``, which the copy then overwrites.
    """
    try:
        dst.symlink_to(src.resolve())
        return
    except (OSError, NotImplementedError):
        pass
    shutil.copy2(src, dst)
def run_nnunet_prediction(model_folder, input_dir, output_dir, device):
    """Segment every image in ``input_dir`` with fold 0 of ``model_folder``.

    Uses nnUNet's folder-based ``predict_from_files`` with one preprocessing
    and one export worker. Sliding-window settings: step 0.5, Gaussian
    weighting, no test-time mirroring. Masks are written to ``output_dir``;
    existing files are overwritten and no probabilities are saved.

    Args:
        model_folder: trained model folder (see ``resolve_model_folder``).
        input_dir: folder of nnUNet-named input images.
        output_dir: folder that receives the predicted masks.
        device: ``"cuda"`` or ``"cpu"``. On CUDA the whole pipeline is kept
            on the device.
    """
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    on_gpu = device == "cuda"
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        perform_everything_on_device=on_gpu,
        device=torch.device(device),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        str(model_folder), use_folds=(0,), checkpoint_name=CHECKPOINT_NAME
    )

    export_options = dict(
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=1,
        num_processes_segmentation_export=1,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0,
    )
    predictor.predict_from_files(str(input_dir), str(output_dir), **export_options)
def run_nnunet_prediction_single(model_folder, input_file, output_file, device):
    """Segment one image file in-process and write the mask to ``output_file``.

    Unlike ``run_nnunet_prediction``, this skips ``predict_from_files`` and its
    worker-process settings. It calls nnUNet's in-process
    ``predict_single_npy_array`` instead, so it can run in restricted
    environments (e.g. ZeroGPU daemon workers) where spawning child processes
    is forbidden.

    The predictor settings match ``run_nnunet_prediction``: fold 0,
    CHECKPOINT_NAME, step 0.5, Gaussian weighting, no mirroring.

    NOTE(review): the image is read and the mask written with ``SimpleITKIO``
    directly, not with the reader/writer recorded in the model's plans.
    Confirm the plans use SimpleITKIO too.
    """
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO

    on_gpu = device == "cuda"
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        perform_everything_on_device=on_gpu,
        device=torch.device(device),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        str(model_folder), use_folds=(0,), checkpoint_name=CHECKPOINT_NAME
    )

    reader_writer = SimpleITKIO()
    images, properties = reader_writer.read_images([str(input_file)])
    # Positional extras, in nnUNet's order: no previous-stage segmentation,
    # no output file (so the mask is returned), no probabilities.
    segmentation = predictor.predict_single_npy_array(
        images, properties, None, None, False
    )
    reader_writer.write_seg(segmentation, str(output_file), properties)
| def summarize_segmentation(seg_path): | |
| seg = sitk.ReadImage(str(seg_path)) | |
| seg_arr = sitk.GetArrayFromImage(seg) | |
| positive_voxels = int(np.count_nonzero(seg_arr)) | |
| spacing_x, spacing_y, spacing_z = seg.GetSpacing() | |
| tumor_volume_ml = positive_voxels * spacing_x * spacing_y * spacing_z / 1000.0 | |
| return positive_voxels, tumor_volume_ml | |
def print_summary(seg_path, viz_outputs, positive_voxels, tumor_volume_ml):
    """Report the output files and tumor statistics on stdout.

    ``viz_outputs`` must provide a ``"slices"`` mapping (label -> PNG path)
    and a ``"video"`` entry, as read below.
    """
    print(
        "\nPanCancerSeg inference complete",
        f"Segmentation mask : {seg_path}",
        "Slice PNGs :",
        sep="\n",
    )
    for view_name, png_path in viz_outputs["slices"].items():
        print(f" {view_name:9s} : {png_path}")
    print(
        f"Overlay video : {viz_outputs['video']}",
        f"Positive voxels : {positive_voxels}",
        sep="\n",
    )
    print(f"Tumor volume : {tumor_volume_ml:.3f} mL")
# Run the CLI only when executed as a script, so importing this module for its
# helper functions does not start an inference run.
if __name__ == "__main__":
    main()