# nv-generate/pipelines/_runner.py
# Author: zephyrie
# Initial commit: NV-Generate Gradio showcase
# Commit: ab1db83
"""
Shared driver for the upstream NV-Generate-CTMR scripts.
Strategy: the upstream code is structured around argparse + global filesystem layout
(reads relative-path configs, writes to a relative output_dir). Rather than refactor
its internals, we treat it as an external tool: chdir into the upstream root, write
modified config copies that override the user-controlled fields, ensure weights are
downloaded, then call the upstream entry-point function. We then return the path of
the most recently produced NIfTI from the configured output dir.
"""
from __future__ import annotations
import contextlib
import importlib
import json
import os
import sys
import time
import uuid
from pathlib import Path
from typing import Iterable, Optional
# Project root: two directory levels above this file.
ROOT = Path(__file__).resolve().parent.parent
# Vendored checkout of the upstream NV-Generate-CTMR repository.
UPSTREAM = ROOT.joinpath("repos", "NV-Generate-CTMR")
# Upstream "output" directory under the checkout.
GENERATED_OUTPUT = UPSTREAM.joinpath("output")
@contextlib.contextmanager
def upstream_context():
    """Make the upstream checkout importable and the working directory.

    On entry the upstream repo root is prepended to ``sys.path`` (unless it is
    already listed) and the process CWD is switched to it. On exit the previous
    CWD is restored. The ``sys.path`` entry is removed again only if this
    context inserted it.

    Yields:
        The upstream repo root (``UPSTREAM``).

    Raises:
        RuntimeError: if the upstream checkout does not exist.
    """
    if not UPSTREAM.exists():
        raise RuntimeError(
            f"Upstream repo not found at {UPSTREAM}. Run `bash pre-build.sh` first."
        )
    original_cwd = os.getcwd()
    we_inserted = False
    try:
        root_entry = str(UPSTREAM)
        we_inserted = root_entry not in sys.path
        if we_inserted:
            sys.path.insert(0, root_entry)
        os.chdir(root_entry)
        yield UPSTREAM
    finally:
        # NOTE: os.chdir is process-wide; other threads see this CWD while inside.
        os.chdir(original_cwd)
        if we_inserted and root_entry in sys.path:
            sys.path.remove(root_entry)
def ensure_weights(version: str) -> None:
    """Ensure the weights for ``version`` are available inside the upstream checkout.

    Calls the upstream ``scripts.download_model_data.download_model_data`` from
    within the upstream root, so the ``"./"`` target resolves there. For
    ``rflow-ct`` and ``ddpm-ct`` it passes ``model_only=False``; every other
    version gets ``model_only=True``.
    """
    fetch_everything = version in ("rflow-ct", "ddpm-ct")
    with upstream_context():
        downloader = importlib.import_module("scripts.download_model_data")
        # Assumed idempotent (files already on disk are skipped), which makes it
        # cheap to call before every run — confirm against upstream.
        downloader.download_model_data(version, "./", model_only=not fetch_everything)
def _list_outputs_before(output_dir: Path) -> set[str]:
if not output_dir.exists():
return set()
return {p.name for p in output_dir.glob("*.nii.gz")}
def _newest_outputs(output_dir: Path, before: set[str]) -> list[Path]:
if not output_dir.exists():
return []
new = [p for p in output_dir.glob("*.nii.gz") if p.name not in before]
new.sort(key=lambda p: p.stat().st_mtime)
return new
def _write_temp_configs(
base_env_config: Path,
base_model_config: Path,
overrides: dict,
tag: str,
) -> tuple[Path, Path]:
"""
Write modified copies of the env + model configs into a per-call temp dir under
UPSTREAM / configs / _temp /. Returns (env_path, model_path).
"""
temp_dir = UPSTREAM / "configs" / "_temp"
temp_dir.mkdir(parents=True, exist_ok=True)
env = json.loads(base_env_config.read_text())
model = json.loads(base_model_config.read_text())
if "env" in overrides:
env.update(overrides["env"])
if "diffusion_unet_inference" in overrides:
model.setdefault("diffusion_unet_inference", {}).update(overrides["diffusion_unet_inference"])
suffix = f"{tag}_{uuid.uuid4().hex[:8]}"
env_path = temp_dir / f"env_{suffix}.json"
model_path = temp_dir / f"model_{suffix}.json"
env_path.write_text(json.dumps(env, indent=2))
model_path.write_text(json.dumps(model, indent=2))
return env_path, model_path
def run_image_only(
    *,
    version: str,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    modality: int,
    seed: int,
    num_inference_steps: int = 30,
    cfg_guidance_scale: Optional[float] = None,
) -> Path:
    """Run the upstream image-only diffusion pipeline and return the new NIfTI.

    Calls ``scripts.diff_model_infer.diff_model_infer`` for ``version``
    (rflow-ct / rflow-mr / rflow-mr-brain). It uses patched temp copies of
    ``configs/environment_maisi_diff_model_<version>.json`` and
    ``configs/config_maisi_diff_model_<version>.json``. The temp copies are
    deleted afterwards, whether or not inference succeeds.

    Args:
        version: Upstream model version; selects the config files and weights.
        output_size: Written to ``diffusion_unet_inference.dim``.
        spacing: Written to ``diffusion_unet_inference.spacing``.
        modality: Written to ``diffusion_unet_inference.modality``
            (its meaning is defined upstream).
        seed: Written to ``diffusion_unet_inference.random_seed``.
        num_inference_steps: Written to ``diffusion_unet_inference.num_inference_steps``.
        cfg_guidance_scale: If given, written to
            ``diffusion_unet_inference.cfg_guidance_scale``; otherwise the base
            model config's value (if any) is kept.

    Returns:
        Path of the newest ``*.nii.gz`` that appeared in the env config's
        ``output_dir`` during this call.

    Raises:
        RuntimeError: if the upstream repo is missing or no new NIfTI appeared.
    """
    ensure_weights(version)
    base_env = UPSTREAM / "configs" / f"environment_maisi_diff_model_{version}.json"
    base_model = UPSTREAM / "configs" / f"config_maisi_diff_model_{version}.json"
    network_def = UPSTREAM / "configs" / "config_network_rflow.json"
    inference_overrides = {
        "dim": list(output_size),
        "spacing": list(spacing),
        "modality": modality,
        "random_seed": seed,
        "num_inference_steps": num_inference_steps,
    }
    if cfg_guidance_scale is not None:
        inference_overrides["cfg_guidance_scale"] = cfg_guidance_scale

    with upstream_context():
        env_path, model_path = _write_temp_configs(
            base_env_config=base_env,
            base_model_config=base_model,
            overrides={"diffusion_unet_inference": inference_overrides},
            tag=version,
        )
        try:
            # The env config's output_dir is relative to the upstream root.
            env_data = json.loads(env_path.read_text(encoding="utf-8"))
            output_dir = (UPSTREAM / env_data["output_dir"]).resolve()
            existing = _list_outputs_before(output_dir)
            diff_mod = importlib.import_module("scripts.diff_model_infer")
            # Paths are passed relative to the upstream root, which is the CWD here.
            diff_mod.diff_model_infer(
                env_config_path=str(env_path.relative_to(UPSTREAM)),
                model_config_path=str(model_path.relative_to(UPSTREAM)),
                model_def_path=str(network_def.relative_to(UPSTREAM)),
                num_gpus=1,
            )
            new_files = _newest_outputs(output_dir, existing)
        finally:
            # Best-effort cleanup on every path (success, inference error, or no
            # output), so failed runs don't accumulate temp configs.
            for temp_path in (env_path, model_path):
                with contextlib.suppress(OSError):
                    temp_path.unlink()

    if not new_files:
        raise RuntimeError(f"No new NIfTI produced in {output_dir}")
    return new_files[-1]
def run_paired_ct(
    *,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    body_region: list[str],
    anatomy_list: list[str],
    seed: int,
    num_inference_steps: int = 30,
    num_output_samples: int = 1,
) -> tuple[Path, Optional[Path]]:
    """Run the upstream paired CT image+mask pipeline (``scripts.inference``).

    A patched copy of ``configs/config_infer.json`` is written to
    ``configs/_temp/``. The upstream ``main()`` is then invoked with a synthetic
    ``sys.argv``, because it parses its arguments from argv. The temp config is
    deleted afterwards, whether or not inference succeeds.

    Args:
        output_size: Written to the infer config's ``output_size``.
        spacing: Written to ``spacing``.
        body_region: Written to ``body_region``.
        anatomy_list: Written to ``anatomy_list``.
        seed: Passed on the command line as ``--random-seed``.
        num_inference_steps: Written to ``num_inference_steps``.
        num_output_samples: Written to ``num_output_samples``.

    Returns:
        ``(image_path, mask_path)``, both chosen among the ``*.nii.gz`` files
        that appeared in the env config's ``output_dir`` during this call.
        The two paths are always different files. ``mask_path`` is ``None``
        if no distinct mask file could be identified.

    Raises:
        RuntimeError: if the upstream repo is missing or no image NIfTI appeared.

    Side effects:
        Sets ``MONAI_DATA_DIRECTORY`` in ``os.environ`` if it is not already set.
    """
    version = "rflow-ct"
    ensure_weights(version)
    base_env = UPSTREAM / "configs" / f"environment_{version}.json"
    base_infer = UPSTREAM / "configs" / "config_infer.json"

    # The env config's output_dir is relative to the upstream root. It is read
    # before the temp file is created, so a broken env config cannot leak it.
    env_data = json.loads(base_env.read_text(encoding="utf-8"))
    output_dir = (UPSTREAM / env_data["output_dir"]).resolve()

    # Build a patched copy of config_infer.json with the user-controlled fields.
    infer_data = json.loads(base_infer.read_text(encoding="utf-8"))
    infer_data["output_size"] = list(output_size)
    infer_data["spacing"] = list(spacing)
    infer_data["body_region"] = list(body_region)
    infer_data["anatomy_list"] = list(anatomy_list)
    infer_data["num_inference_steps"] = num_inference_steps
    infer_data["num_output_samples"] = num_output_samples
    temp_dir = UPSTREAM / "configs" / "_temp"
    temp_dir.mkdir(parents=True, exist_ok=True)
    infer_path = temp_dir / f"config_infer_{version}_{uuid.uuid4().hex[:8]}.json"
    infer_path.write_text(json.dumps(infer_data, indent=2), encoding="utf-8")

    try:
        with upstream_context():
            existing = _list_outputs_before(output_dir)
            inference_mod = importlib.import_module("scripts.inference")
            # The upstream `main()` parses argv directly. Patch sys.argv around the call.
            old_argv = sys.argv
            sys.argv = [
                "scripts.inference",
                "-t", "./configs/config_network_rflow.json",
                "-i", str(infer_path.relative_to(UPSTREAM)),
                "-e", str(base_env.relative_to(UPSTREAM)),
                "--random-seed", str(seed),
                "--version", version,
            ]
            os.environ.setdefault("MONAI_DATA_DIRECTORY", str(UPSTREAM / "temp_work_dir"))
            try:
                inference_mod.main()
            finally:
                sys.argv = old_argv
            new_files = _newest_outputs(output_dir, existing)
    finally:
        # Best-effort cleanup, also when inference raises.
        with contextlib.suppress(OSError):
            infer_path.unlink()

    image_path, mask_path = _split_paired_outputs(new_files)
    if image_path is None:
        if not new_files:
            raise RuntimeError(f"No NIfTI produced in {output_dir}")
        names = ", ".join(p.name for p in new_files)
        raise RuntimeError(f"No image NIfTI among new outputs in {output_dir}: {names}")
    return image_path, mask_path


def _split_paired_outputs(
    files: list[Path],
) -> tuple[Optional[Path], Optional[Path]]:
    """Assign new paired-pipeline outputs (ordered oldest -> newest) to image / mask roles.

    Classification is by lower-cased file name:
    - names containing "label", "_mask" or "seg" are masks;
    - otherwise, names containing "image" or "img" are images;
    - anything else is unclassified.

    The newest file of each kind wins its role. If no image was recognised, the
    oldest unclassified file is used. If no mask was recognised, the newest
    unclassified file other than the chosen image is used. As a result, image
    and mask are never the same file, and a recognised mask is never returned
    as the image (or vice versa).
    """
    masks: list[Path] = []
    images: list[Path] = []
    unclassified: list[Path] = []
    for p in files:
        name = p.name.lower()
        if "label" in name or "_mask" in name or "seg" in name:
            masks.append(p)
        elif "image" in name or "img" in name:
            images.append(p)
        else:
            unclassified.append(p)
    image_path = images[-1] if images else next(iter(unclassified), None)
    if masks:
        mask_path = masks[-1]
    else:
        mask_path = next((p for p in reversed(unclassified) if p != image_path), None)
    return image_path, mask_path
def labels_present(mask_path: Path) -> set[int]:
    """Collect the non-zero label IDs that occur in a segmentation volume.

    The volume is loaded with nibabel. Its distinct voxel values (``np.unique``)
    are cast to ``int`` and the background value 0 is dropped. ``astype(int)``
    truncates toward zero, so any non-integer value in (-1, 1) counts as
    background.
    """
    import nibabel as nib
    import numpy as np

    volume = np.asarray(nib.load(str(mask_path)).dataobj)
    distinct_ids = np.unique(volume).astype(int).tolist()
    return set(label for label in distinct_ids if label != 0)