Spaces:
Running on Zero
Running on Zero
| """ | |
| Shared driver for the upstream NV-Generate-CTMR scripts. | |
| Strategy: the upstream code is structured around argparse + global filesystem layout | |
| (reads relative-path configs, writes to a relative output_dir). Rather than refactor | |
| its internals, we treat it as an external tool: chdir into the upstream root, write | |
| modified config copies that override the user-controlled fields, ensure weights are | |
| downloaded, then call the upstream entry-point function. We then return the path of | |
| the most recently produced NIfTI from the configured output dir. | |
| """ | |
| from __future__ import annotations | |
| import contextlib | |
| import importlib | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
# Directory one level above the directory that contains this file.
ROOT = Path(__file__).resolve().parent.parent
# Expected location of the upstream NV-Generate-CTMR checkout under ROOT.
UPSTREAM = ROOT / "repos" / "NV-Generate-CTMR"
# "output" directory inside the upstream checkout.
# NOTE(review): not referenced by the functions visible in this file; the
# runners read the output directory from the environment config instead.
GENERATED_OUTPUT = UPSTREAM / "output"
@contextlib.contextmanager
def upstream_context():
    """Temporarily add the upstream repo to sys.path and switch CWD to it.

    The function is a generator, so it must be wrapped with
    ``contextlib.contextmanager`` for the ``with upstream_context():`` call
    sites in this module to work.

    Yields:
        The ``UPSTREAM`` path, while the working directory is set to it.

    Raises:
        RuntimeError: If the upstream checkout does not exist on disk.
    """
    if not UPSTREAM.exists():
        raise RuntimeError(
            f"Upstream repo not found at {UPSTREAM}. Run `bash pre-build.sh` first."
        )
    # Bound before the ``try`` so the ``finally`` clause can always refer to it.
    upstream_str = str(UPSTREAM)
    prev_cwd = os.getcwd()
    added = False
    try:
        if upstream_str not in sys.path:
            sys.path.insert(0, upstream_str)
            added = True
        os.chdir(upstream_str)
        yield UPSTREAM
    finally:
        os.chdir(prev_cwd)
        # Only undo the sys.path change if this call made it.
        if added and upstream_str in sys.path:
            sys.path.remove(upstream_str)
def ensure_weights(version: str) -> None:
    """Fetch the weights for *version* through the upstream download script.

    Imports ``scripts.download_model_data`` from within the upstream context
    and calls its ``download_model_data`` function with ``"./"`` as the
    second argument (presumably the download root, i.e. the upstream
    checkout, since that is the working directory at call time).

    Args:
        version: Model version identifier. ``"rflow-ct"`` and ``"ddpm-ct"``
            are requested with ``model_only=False``; every other value is
            requested with ``model_only=True``.
    """
    wants_extra_data = version in ("rflow-ct", "ddpm-ct")
    with upstream_context():
        downloader = importlib.import_module("scripts.download_model_data")
        # Per the original author's note, the upstream downloader is
        # idempotent and skips files that already exist.
        downloader.download_model_data(version, "./", model_only=not wants_extra_data)
| def _list_outputs_before(output_dir: Path) -> set[str]: | |
| if not output_dir.exists(): | |
| return set() | |
| return {p.name for p in output_dir.glob("*.nii.gz")} | |
| def _newest_outputs(output_dir: Path, before: set[str]) -> list[Path]: | |
| if not output_dir.exists(): | |
| return [] | |
| new = [p for p in output_dir.glob("*.nii.gz") if p.name not in before] | |
| new.sort(key=lambda p: p.stat().st_mtime) | |
| return new | |
def _write_temp_configs(
    base_env_config: Path,
    base_model_config: Path,
    overrides: dict,
    tag: str,
) -> tuple[Path, Path]:
    """
    Write overridden copies of the env + model configs and return their paths.

    Both copies go into the shared ``UPSTREAM / configs / _temp`` directory
    (created on demand); each call gets its own file names built from *tag*
    plus eight random hex characters.

    Recognised keys in *overrides*:
      - ``"env"``: merged into the top level of the env config.
      - ``"diffusion_unet_inference"``: merged into the model config section
        of the same name, which is created if absent.

    Returns (env_path, model_path).
    """
    scratch_dir = UPSTREAM / "configs" / "_temp"
    scratch_dir.mkdir(parents=True, exist_ok=True)

    env_cfg = json.loads(base_env_config.read_text())
    model_cfg = json.loads(base_model_config.read_text())

    if "env" in overrides:
        env_cfg.update(overrides["env"])
    if "diffusion_unet_inference" in overrides:
        inference_section = model_cfg.setdefault("diffusion_unet_inference", {})
        inference_section.update(overrides["diffusion_unet_inference"])

    unique = f"{tag}_{uuid.uuid4().hex[:8]}"
    env_path = scratch_dir / f"env_{unique}.json"
    model_path = scratch_dir / f"model_{unique}.json"
    for target, payload in ((env_path, env_cfg), (model_path, model_cfg)):
        target.write_text(json.dumps(payload, indent=2))
    return env_path, model_path
def run_image_only(
    *,
    version: str,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    modality: int,
    seed: int,
    num_inference_steps: int = 30,
    cfg_guidance_scale: Optional[float] = None,
) -> Path:
    """
    Run the image-only diffusion pipeline (`scripts.diff_model_infer`) for the given
    version (rflow-ct / rflow-mr / rflow-mr-brain). Returns path to generated NIfTI.

    Args:
        version: Model version; selects the base env/model config file names.
        output_size: Written to the ``dim`` inference setting.
        spacing: Written to the ``spacing`` inference setting.
        modality: Written to the ``modality`` inference setting.
        seed: Written to the ``random_seed`` inference setting.
        num_inference_steps: Written to the ``num_inference_steps`` setting.
        cfg_guidance_scale: Written to ``cfg_guidance_scale`` only when not
            ``None``; otherwise the base config's value (if any) is kept.

    Returns:
        The most recently modified ``*.nii.gz`` file that appeared in the
        configured output directory during this call.

    Raises:
        RuntimeError: If the upstream checkout is missing or no new NIfTI
            file was produced.
    """
    ensure_weights(version)
    configs_dir = UPSTREAM / "configs"
    base_env = configs_dir / f"environment_maisi_diff_model_{version}.json"
    base_model = configs_dir / f"config_maisi_diff_model_{version}.json"
    network_def = configs_dir / "config_network_rflow.json"
    inference_overrides = {
        "dim": list(output_size),
        "spacing": list(spacing),
        "modality": modality,
        "random_seed": seed,
        "num_inference_steps": num_inference_steps,
    }
    if cfg_guidance_scale is not None:
        inference_overrides["cfg_guidance_scale"] = cfg_guidance_scale
    with upstream_context():
        env_path, model_path = _write_temp_configs(
            base_env_config=base_env,
            base_model_config=base_model,
            overrides={"diffusion_unet_inference": inference_overrides},
            tag=version,
        )
        try:
            # Read env to determine output_dir (relative to upstream root)
            env_data = json.loads(env_path.read_text())
            output_dir = (UPSTREAM / env_data["output_dir"]).resolve()
            existing = _list_outputs_before(output_dir)
            diff_mod = importlib.import_module("scripts.diff_model_infer")
            # Paths are passed relative to UPSTREAM, which is the CWD here.
            diff_mod.diff_model_infer(
                env_config_path=str(env_path.relative_to(UPSTREAM)),
                model_config_path=str(model_path.relative_to(UPSTREAM)),
                model_def_path=str(network_def.relative_to(UPSTREAM)),
                num_gpus=1,
            )
            new_files = _newest_outputs(output_dir, existing)
        finally:
            # Always remove the temp configs, including when inference fails,
            # so failed runs do not accumulate files in configs/_temp.
            # Cleanup itself is best-effort and never masks the real error.
            for p in (env_path, model_path):
                try:
                    p.unlink()
                except OSError:
                    pass
    if not new_files:
        raise RuntimeError(f"No new NIfTI produced in {output_dir}")
    # _newest_outputs sorts oldest-first, so the last entry is the newest.
    return new_files[-1]
def _pair_image_and_mask(new_files: list[Path]) -> tuple[Optional[Path], Optional[Path]]:
    """Split newly produced files (oldest first) into (image_path, mask_path)."""
    image_path: Optional[Path] = None
    mask_path: Optional[Path] = None
    # Convention: filenames contain "image" / "label" or are emitted as
    # adjacent files. When several files match, the newest one wins.
    for p in new_files:
        name = p.name.lower()
        if "label" in name or "_mask" in name or "seg" in name:
            mask_path = p
        elif "image" in name or "img" in name:
            image_path = p
    # Fallback: if naming is ambiguous, treat the file with the smaller
    # modification time as the image, preferring one not already claimed as
    # the mask so the same file is not returned twice.
    if image_path is None and new_files:
        unclaimed = [p for p in new_files if p != mask_path]
        # If only a mask-named file exists, keep the historical behaviour of
        # returning it as the image rather than failing.
        image_path = unclaimed[0] if unclaimed else new_files[0]
    if mask_path is None:
        others = [p for p in new_files if p != image_path]
        if others:
            mask_path = others[-1]
    return image_path, mask_path


def run_paired_ct(
    *,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    body_region: list[str],
    anatomy_list: list[str],
    seed: int,
    num_inference_steps: int = 30,
    num_output_samples: int = 1,
) -> tuple[Path, Optional[Path]]:
    """
    Run the paired CT image+mask pipeline (`scripts.inference`). Returns
    (image_path, mask_path). Mask is the corresponding label volume.

    Args:
        output_size: Written to ``output_size`` in the inference config.
        spacing: Written to ``spacing`` in the inference config.
        body_region: Written to ``body_region`` in the inference config.
        anatomy_list: Written to ``anatomy_list`` in the inference config.
        seed: Passed to the upstream script as ``--random-seed``.
        num_inference_steps: Written to ``num_inference_steps``.
        num_output_samples: Written to ``num_output_samples``.

    Returns:
        ``(image_path, mask_path)``; ``mask_path`` is ``None`` when only one
        new file was produced and it was not recognised as a mask.

    Raises:
        RuntimeError: If the upstream checkout is missing or no new NIfTI
            file was produced.
    """
    version = "rflow-ct"
    ensure_weights(version)
    base_env = UPSTREAM / "configs" / f"environment_{version}.json"
    base_infer = UPSTREAM / "configs" / "config_infer.json"
    # Build a custom config_infer with overrides
    infer_data = json.loads(base_infer.read_text())
    infer_data["output_size"] = list(output_size)
    infer_data["spacing"] = list(spacing)
    infer_data["body_region"] = list(body_region)
    infer_data["anatomy_list"] = list(anatomy_list)
    infer_data["num_inference_steps"] = num_inference_steps
    infer_data["num_output_samples"] = num_output_samples
    temp_dir = UPSTREAM / "configs" / "_temp"
    temp_dir.mkdir(parents=True, exist_ok=True)
    suffix = uuid.uuid4().hex[:8]
    infer_path = temp_dir / f"config_infer_{version}_{suffix}.json"
    infer_path.write_text(json.dumps(infer_data, indent=2))
    try:
        env_data = json.loads(base_env.read_text())
        output_dir = (UPSTREAM / env_data["output_dir"]).resolve()
        with upstream_context():
            existing = _list_outputs_before(output_dir)
            inference_mod = importlib.import_module("scripts.inference")
            # The upstream `main()` parses argv directly. Patch sys.argv around the call.
            old_argv = sys.argv
            sys.argv = [
                "scripts.inference",
                "-t", "./configs/config_network_rflow.json",
                "-i", str(infer_path.relative_to(UPSTREAM)),
                "-e", str(base_env.relative_to(UPSTREAM)),
                "--random-seed", str(seed),
                "--version", version,
            ]
            # Only sets the variable when the caller has not already done so.
            os.environ.setdefault("MONAI_DATA_DIRECTORY", str(UPSTREAM / "temp_work_dir"))
            try:
                inference_mod.main()
            finally:
                sys.argv = old_argv
            new_files = _newest_outputs(output_dir, existing)
    finally:
        # Always remove the temp config, including when inference fails;
        # cleanup is best-effort and never masks the real error.
        try:
            infer_path.unlink()
        except OSError:
            pass
    # Paired pipeline writes both image and label NIfTIs.
    image_path, mask_path = _pair_image_and_mask(new_files)
    if image_path is None:
        raise RuntimeError(f"No NIfTI produced in {output_dir}")
    return image_path, mask_path
def labels_present(mask_path: Path) -> set[int]:
    """Collect the distinct non-zero label IDs found in the mask volume.

    The volume is loaded with nibabel, its unique values are cast to ``int``,
    and the background value 0 is left out of the returned set.
    """
    import nibabel as nib
    import numpy as np

    volume = np.asarray(nib.load(str(mask_path)).dataobj)
    label_ids = np.unique(volume).astype(int).tolist()
    present = {int(value) for value in label_ids}
    present.discard(0)
    return present