"""
Shared driver for the upstream NV-Generate-CTMR scripts.

Strategy: the upstream code is structured around argparse + global filesystem
layout (reads relative-path configs, writes to a relative output_dir). Rather
than refactor its internals, we treat it as an external tool: chdir into the
upstream root, write modified config copies that override the user-controlled
fields, ensure weights are downloaded, then call the upstream entry-point
function. We then return the path of the most recently produced NIfTI from the
configured output dir.
"""

from __future__ import annotations

import contextlib
import importlib
import json
import os
import sys
import time
import uuid
from pathlib import Path
from typing import Iterable, Optional

ROOT = Path(__file__).resolve().parent.parent
UPSTREAM = ROOT / "repos" / "NV-Generate-CTMR"
GENERATED_OUTPUT = UPSTREAM / "output"

# Versions for which `ensure_weights` asks upstream for more than the model
# files (i.e. passes model_only=False).
_FULL_DOWNLOAD_VERSIONS = ("rflow-ct", "ddpm-ct")


@contextlib.contextmanager
def upstream_context():
    """Temporarily add the upstream repo to sys.path and switch CWD to it.

    Yields:
        The ``UPSTREAM`` path.

    Raises:
        RuntimeError: if the upstream checkout does not exist.

    On exit the previous working directory is restored, and the sys.path
    entry is removed only if this call was the one that inserted it.
    """
    if not UPSTREAM.exists():
        raise RuntimeError(
            f"Upstream repo not found at {UPSTREAM}. Run `bash pre-build.sh` first."
        )
    prev_cwd = os.getcwd()
    upstream_str = str(UPSTREAM)
    added = False
    try:
        if upstream_str not in sys.path:
            sys.path.insert(0, upstream_str)
            added = True
        os.chdir(upstream_str)
        yield UPSTREAM
    finally:
        os.chdir(prev_cwd)
        if added and upstream_str in sys.path:
            sys.path.remove(upstream_str)


def ensure_weights(version: str) -> None:
    """Download model weights from HF Hub if not already on disk.

    Args:
        version: Upstream model version string (e.g. ``"rflow-ct"``).
    """
    with upstream_context():
        download_mod = importlib.import_module("scripts.download_model_data")
        # download_model_data is idempotent — it skips files that already exist.
        download_mod.download_model_data(
            version,
            "./",
            model_only=version not in _FULL_DOWNLOAD_VERSIONS,
        )


def _list_outputs_before(output_dir: Path) -> set[str]:
    """Return the names of the ``*.nii.gz`` files currently in *output_dir*.

    A missing directory yields an empty set.
    """
    if not output_dir.exists():
        return set()
    return {p.name for p in output_dir.glob("*.nii.gz")}


def _newest_outputs(output_dir: Path, before: set[str]) -> list[Path]:
    """Return ``*.nii.gz`` files whose names are not in *before*.

    The result is sorted by modification time, oldest first. Detection is by
    file name only, so a file rewritten under a pre-existing name is not
    reported as new.
    """
    if not output_dir.exists():
        return []
    new = [p for p in output_dir.glob("*.nii.gz") if p.name not in before]
    new.sort(key=lambda p: p.stat().st_mtime)
    return new


def _remove_quietly(paths: Iterable[Path]) -> None:
    """Best-effort removal of temp files; OS-level failures are ignored."""
    for path in paths:
        with contextlib.suppress(OSError):
            path.unlink()


def _write_temp_configs(
    base_env_config: Path,
    base_model_config: Path,
    overrides: dict,
    tag: str,
) -> tuple[Path, Path]:
    """
    Write modified copies of the env + model configs into the shared temp dir
    UPSTREAM / configs / _temp /, using a unique per-call file-name suffix.

    Only two keys of *overrides* are consulted: ``"env"`` (merged into the
    top level of the env config) and ``"diffusion_unet_inference"`` (merged
    into that section of the model config). Other keys are ignored.

    Returns (env_path, model_path).
    """
    temp_dir = UPSTREAM / "configs" / "_temp"
    temp_dir.mkdir(parents=True, exist_ok=True)

    env = json.loads(base_env_config.read_text())
    model = json.loads(base_model_config.read_text())

    if "env" in overrides:
        env.update(overrides["env"])
    if "diffusion_unet_inference" in overrides:
        model.setdefault("diffusion_unet_inference", {}).update(
            overrides["diffusion_unet_inference"]
        )

    suffix = f"{tag}_{uuid.uuid4().hex[:8]}"
    env_path = temp_dir / f"env_{suffix}.json"
    model_path = temp_dir / f"model_{suffix}.json"
    env_path.write_text(json.dumps(env, indent=2))
    model_path.write_text(json.dumps(model, indent=2))
    return env_path, model_path


def run_image_only(
    *,
    version: str,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    modality: int,
    seed: int,
    num_inference_steps: int = 30,
    cfg_guidance_scale: Optional[float] = None,
) -> Path:
    """
    Run the image-only diffusion pipeline (`scripts.diff_model_infer`) for the
    given version (rflow-ct / rflow-mr / rflow-mr-brain).

    Args:
        version: Upstream model version; selects the base config files.
        output_size: Written to the inference config as ``dim``.
        spacing: Written to the inference config as ``spacing``.
        modality: Written to the inference config as ``modality``.
        seed: Written to the inference config as ``random_seed``.
        num_inference_steps: Written as ``num_inference_steps``.
        cfg_guidance_scale: Written as ``cfg_guidance_scale`` only when not
            None.

    Returns:
        Path to the generated NIfTI (the most recently modified new file in
        the configured output dir).

    Raises:
        RuntimeError: if the upstream repo is missing or no new NIfTI appears.
    """
    ensure_weights(version)

    base_env = UPSTREAM / "configs" / f"environment_maisi_diff_model_{version}.json"
    base_model = UPSTREAM / "configs" / f"config_maisi_diff_model_{version}.json"
    network_def = UPSTREAM / "configs" / "config_network_rflow.json"

    inference_overrides = {
        "dim": list(output_size),
        "spacing": list(spacing),
        "modality": modality,
        "random_seed": seed,
        "num_inference_steps": num_inference_steps,
    }
    if cfg_guidance_scale is not None:
        inference_overrides["cfg_guidance_scale"] = cfg_guidance_scale

    with upstream_context():
        env_path, model_path = _write_temp_configs(
            base_env_config=base_env,
            base_model_config=base_model,
            overrides={"diffusion_unet_inference": inference_overrides},
            tag=version,
        )
        try:
            # Read env to determine output_dir (relative to upstream root)
            env_data = json.loads(env_path.read_text())
            output_dir = (UPSTREAM / env_data["output_dir"]).resolve()
            existing = _list_outputs_before(output_dir)

            diff_mod = importlib.import_module("scripts.diff_model_infer")
            diff_mod.diff_model_infer(
                env_config_path=str(env_path.relative_to(UPSTREAM)),
                model_config_path=str(model_path.relative_to(UPSTREAM)),
                model_def_path=str(network_def.relative_to(UPSTREAM)),
                num_gpus=1,
            )
            new_files = _newest_outputs(output_dir, existing)
        finally:
            # Remove the temp configs even when upstream inference raises, so
            # failed runs do not accumulate files under configs/_temp.
            _remove_quietly((env_path, model_path))

    if not new_files:
        raise RuntimeError(f"No new NIfTI produced in {output_dir}")
    return new_files[-1]


def _pick_image_and_mask(
    new_files: list[Path],
) -> tuple[Optional[Path], Optional[Path]]:
    """Split freshly written NIfTI files into (image, mask) by file name.

    A name containing "label", "_mask" or "seg" marks a mask; otherwise a name
    containing "image" or "img" marks an image. When several files match the
    same role, the last one in *new_files* wins.

    If a role is still unassigned, it falls back to a file that has NOT been
    assigned to the other role: the first such file for the image, the last
    such file for the mask. This prevents the same file being returned as
    both image and mask.
    """
    image_path: Optional[Path] = None
    mask_path: Optional[Path] = None
    for p in new_files:
        name = p.name.lower()
        if "label" in name or "_mask" in name or "seg" in name:
            mask_path = p
        elif "image" in name or "img" in name:
            image_path = p

    # Fallback: if naming is ambiguous, treat the earliest-modified remaining
    # file as the image and the latest-modified remaining file as the mask
    # (callers pass new_files sorted by modification time).
    if image_path is None:
        remaining = [p for p in new_files if p != mask_path]
        if remaining:
            image_path = remaining[0]
    if mask_path is None:
        remaining = [p for p in new_files if p != image_path]
        if remaining:
            mask_path = remaining[-1]
    return image_path, mask_path


def run_paired_ct(
    *,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    body_region: list[str],
    anatomy_list: list[str],
    seed: int,
    num_inference_steps: int = 30,
    num_output_samples: int = 1,
) -> tuple[Path, Optional[Path]]:
    """
    Run the paired CT image+mask pipeline (`scripts.inference`).

    Args:
        output_size: Written to the infer config as ``output_size``.
        spacing: Written to the infer config as ``spacing``.
        body_region: Written to the infer config as ``body_region``.
        anatomy_list: Written to the infer config as ``anatomy_list``.
        seed: Passed to upstream as ``--random-seed``.
        num_inference_steps: Written as ``num_inference_steps``.
        num_output_samples: Written as ``num_output_samples``. Only one
            (image, mask) pair is returned regardless of this value.

    Returns (image_path, mask_path). Mask is the corresponding label volume,
    or None when no second file could be identified.

    Raises:
        RuntimeError: if the upstream repo is missing, no new NIfTI appears,
            or the only new file looks like a mask.

    Side effect: sets ``MONAI_DATA_DIRECTORY`` in ``os.environ`` if it is not
    already set; the variable is left in place after the call.
    """
    version = "rflow-ct"
    ensure_weights(version)

    base_env = UPSTREAM / "configs" / f"environment_{version}.json"
    base_infer = UPSTREAM / "configs" / "config_infer.json"

    # Build a custom config_infer with overrides
    infer_data = json.loads(base_infer.read_text())
    infer_data["output_size"] = list(output_size)
    infer_data["spacing"] = list(spacing)
    infer_data["body_region"] = list(body_region)
    infer_data["anatomy_list"] = list(anatomy_list)
    infer_data["num_inference_steps"] = num_inference_steps
    infer_data["num_output_samples"] = num_output_samples

    temp_dir = UPSTREAM / "configs" / "_temp"
    temp_dir.mkdir(parents=True, exist_ok=True)
    suffix = uuid.uuid4().hex[:8]
    infer_path = temp_dir / f"config_infer_{version}_{suffix}.json"
    infer_path.write_text(json.dumps(infer_data, indent=2))

    try:
        env_data = json.loads(base_env.read_text())
        output_dir = (UPSTREAM / env_data["output_dir"]).resolve()

        with upstream_context():
            existing = _list_outputs_before(output_dir)
            inference_mod = importlib.import_module("scripts.inference")
            os.environ.setdefault(
                "MONAI_DATA_DIRECTORY", str(UPSTREAM / "temp_work_dir")
            )

            # The upstream `main()` parses argv directly. Patch sys.argv around
            # the call.
            old_argv = sys.argv
            sys.argv = [
                "scripts.inference",
                "-t", "./configs/config_network_rflow.json",
                "-i", str(infer_path.relative_to(UPSTREAM)),
                "-e", str(base_env.relative_to(UPSTREAM)),
                "--random-seed", str(seed),
                "--version", version,
            ]
            try:
                inference_mod.main()
            finally:
                sys.argv = old_argv

            new_files = _newest_outputs(output_dir, existing)
    finally:
        # Remove the temp infer config even when upstream inference raises.
        _remove_quietly((infer_path,))

    # Paired pipeline writes both image and label NIfTIs. Convention: filenames
    # contain "image" / "label" or are emitted as adjacent files.
    image_path, mask_path = _pick_image_and_mask(new_files)

    if image_path is None:
        if new_files:
            # Every new file looks like a mask; refuse to pass a label volume
            # off as the image.
            raise RuntimeError(
                f"Only a mask NIfTI was produced in {output_dir}: {mask_path}"
            )
        raise RuntimeError(f"No NIfTI produced in {output_dir}")
    return image_path, mask_path


def labels_present(mask_path: Path) -> set[int]:
    """Return the set of unique non-zero label IDs present in the mask volume.

    Values are cast to ``int`` before the zero check, so any value that
    truncates to 0 is treated as background.
    """
    # Imported lazily so the rest of the module works without these packages.
    import nibabel as nib
    import numpy as np

    img = nib.load(str(mask_path))
    data = np.asarray(img.dataobj)
    uniq = np.unique(data).astype(int).tolist()
    return {int(u) for u in uniq if u != 0}