File size: 9,177 Bytes
ab1db83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""
Shared driver for the upstream NV-Generate-CTMR scripts.

Strategy: the upstream code is structured around argparse + global filesystem layout
(reads relative-path configs, writes to a relative output_dir). Rather than refactor
its internals, we treat it as an external tool: chdir into the upstream root, write
modified config copies that override the user-controlled fields, ensure weights are
downloaded, then call the upstream entry-point function. We then return the path of
the most recently produced NIfTI from the configured output dir.
"""
from __future__ import annotations

import contextlib
import importlib
import json
import os
import sys
import time
import uuid
from pathlib import Path
from typing import Iterable, Optional

# Project root: the directory two levels above this file.
ROOT = Path(__file__).resolve().parents[1]
# Vendored checkout of the upstream NV-Generate-CTMR repository.
UPSTREAM = ROOT.joinpath("repos", "NV-Generate-CTMR")
# Default "output" directory inside the upstream checkout.
GENERATED_OUTPUT = UPSTREAM.joinpath("output")


@contextlib.contextmanager
def upstream_context():
    """
    Execute the enclosed code from inside the upstream checkout.

    While the context is active, the upstream root is the current working directory
    and is importable (it is prepended to ``sys.path`` unless it is already listed).
    On exit the previous working directory is restored, and the ``sys.path`` entry is
    removed again only if this call inserted it.

    Yields:
        The ``UPSTREAM`` path.

    Raises:
        RuntimeError: if the upstream checkout does not exist.
    """
    if not UPSTREAM.exists():
        raise RuntimeError(
            f"Upstream repo not found at {UPSTREAM}. Run `bash pre-build.sh` first."
        )
    saved_cwd = os.getcwd()
    entry = str(UPSTREAM)
    # Remember whether we own the sys.path entry so we never remove one a caller added.
    inserted = entry not in sys.path
    try:
        if inserted:
            sys.path.insert(0, entry)
        os.chdir(entry)
        yield UPSTREAM
    finally:
        os.chdir(saved_cwd)
        if inserted and entry in sys.path:
            sys.path.remove(entry)


def ensure_weights(version: str) -> None:
    """
    Download model weights for ``version`` if they are not already on disk.

    The upstream ``scripts.download_model_data.download_model_data`` helper is called
    from inside the upstream root with ``"./"`` as the destination. Only the
    ``rflow-ct`` and ``ddpm-ct`` versions pass ``model_only=False``. Every other
    version passes ``model_only=True``.
    """
    # The upstream helper is expected to skip files that already exist, so
    # repeated calls should be cheap.
    model_only = version not in ("rflow-ct", "ddpm-ct")
    with upstream_context():
        downloader = importlib.import_module("scripts.download_model_data")
        downloader.download_model_data(version, "./", model_only=model_only)


def _list_outputs_before(output_dir: Path) -> set[str]:
    if not output_dir.exists():
        return set()
    return {p.name for p in output_dir.glob("*.nii.gz")}


def _newest_outputs(output_dir: Path, before: set[str]) -> list[Path]:
    if not output_dir.exists():
        return []
    new = [p for p in output_dir.glob("*.nii.gz") if p.name not in before]
    new.sort(key=lambda p: p.stat().st_mtime)
    return new


def _write_temp_configs(
    base_env_config: Path,
    base_model_config: Path,
    overrides: dict,
    tag: str,
) -> tuple[Path, Path]:
    """
    Write patched copies of an environment config and a model config.

    If present, ``overrides["env"]`` is shallow-merged into the environment JSON.
    If present, ``overrides["diffusion_unet_inference"]`` is shallow-merged into that
    section of the model JSON, and the section is created when it is missing. Both
    copies go to ``UPSTREAM / "configs" / "_temp"``. Each name is made unique by
    ``tag`` plus a random 8-hex-digit suffix.

    Returns:
        ``(env_path, model_path)`` of the newly written files.
    """
    out_dir = UPSTREAM / "configs" / "_temp"
    out_dir.mkdir(parents=True, exist_ok=True)

    env_cfg = json.loads(base_env_config.read_text())
    model_cfg = json.loads(base_model_config.read_text())

    if "env" in overrides:
        env_cfg.update(overrides["env"])
    if "diffusion_unet_inference" in overrides:
        section = model_cfg.setdefault("diffusion_unet_inference", {})
        section.update(overrides["diffusion_unet_inference"])

    stem = f"{tag}_{uuid.uuid4().hex[:8]}"
    env_out = out_dir / f"env_{stem}.json"
    model_out = out_dir / f"model_{stem}.json"
    env_out.write_text(json.dumps(env_cfg, indent=2))
    model_out.write_text(json.dumps(model_cfg, indent=2))
    return env_out, model_out


def run_image_only(
    *,
    version: str,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    modality: int,
    seed: int,
    num_inference_steps: int = 30,
    cfg_guidance_scale: Optional[float] = None,
) -> Path:
    """
    Run the image-only diffusion pipeline (`scripts.diff_model_infer`) for one version.

    Patched copies of ``configs/environment_maisi_diff_model_{version}.json`` and
    ``configs/config_maisi_diff_model_{version}.json`` are written to
    ``configs/_temp``. The upstream entry point is then called from inside the
    upstream root. The temp copies are removed afterwards, even if inference fails.

    Args:
        version: upstream model version (e.g. rflow-ct / rflow-mr / rflow-mr-brain);
            selects which config files are used.
        output_size: forwarded as ``diffusion_unet_inference.dim``.
        spacing: forwarded as ``diffusion_unet_inference.spacing``.
        modality: forwarded as ``diffusion_unet_inference.modality``.
        seed: forwarded as ``diffusion_unet_inference.random_seed``.
        num_inference_steps: forwarded as ``diffusion_unet_inference.num_inference_steps``.
        cfg_guidance_scale: forwarded as ``diffusion_unet_inference.cfg_guidance_scale``;
            when ``None`` the value from the base config is left untouched.

    Returns:
        Path of the most recently modified new ``*.nii.gz`` in the env config's
        ``output_dir``.

    Raises:
        RuntimeError: if the upstream repo is missing or no new NIfTI was produced.
    """
    ensure_weights(version)

    base_env = UPSTREAM / "configs" / f"environment_maisi_diff_model_{version}.json"
    base_model = UPSTREAM / "configs" / f"config_maisi_diff_model_{version}.json"
    network_def = UPSTREAM / "configs" / "config_network_rflow.json"

    inference_overrides = {
        "dim": list(output_size),
        "spacing": list(spacing),
        "modality": modality,
        "random_seed": seed,
        "num_inference_steps": num_inference_steps,
    }
    if cfg_guidance_scale is not None:
        inference_overrides["cfg_guidance_scale"] = cfg_guidance_scale

    with upstream_context():
        env_path, model_path = _write_temp_configs(
            base_env_config=base_env,
            base_model_config=base_model,
            overrides={"diffusion_unet_inference": inference_overrides},
            tag=version,
        )
        try:
            # output_dir in the env config is relative to the upstream root.
            env_data = json.loads(env_path.read_text())
            output_dir = (UPSTREAM / env_data["output_dir"]).resolve()
            existing = _list_outputs_before(output_dir)

            diff_mod = importlib.import_module("scripts.diff_model_infer")
            diff_mod.diff_model_infer(
                env_config_path=str(env_path.relative_to(UPSTREAM)),
                model_config_path=str(model_path.relative_to(UPSTREAM)),
                model_def_path=str(network_def.relative_to(UPSTREAM)),
                num_gpus=1,
            )
            new_files = _newest_outputs(output_dir, existing)
        finally:
            # Best-effort cleanup that also runs on failure, so configs/_temp does
            # not accumulate stale copies from crashed runs.
            for p in (env_path, model_path):
                with contextlib.suppress(OSError):
                    p.unlink()

    if not new_files:
        raise RuntimeError(f"No new NIfTI produced in {output_dir}")
    return new_files[-1]


def _split_image_mask(files: list[Path]) -> tuple[Optional[Path], Optional[Path]]:
    """Pick ``(image, mask)`` from mtime-ordered outputs using filename hints, never reusing a file."""
    image_path: Optional[Path] = None
    mask_path: Optional[Path] = None
    unnamed: list[Path] = []
    for p in files:
        name = p.name.lower()
        if "label" in name or "_mask" in name or "seg" in name:
            mask_path = p
        elif "image" in name or "img" in name:
            image_path = p
        else:
            unnamed.append(p)
    # Ambiguous names: fall back to mtime order among files matching neither
    # pattern. The oldest such file is taken as the image and the newest remaining
    # one as the mask. A label-named file is never returned as the image, and no
    # file is returned in both roles.
    if image_path is None and unnamed:
        image_path = unnamed[0]
    if mask_path is None:
        mask_path = next((p for p in reversed(unnamed) if p != image_path), None)
    return image_path, mask_path


def run_paired_ct(
    *,
    output_size: tuple[int, int, int],
    spacing: tuple[float, float, float],
    body_region: list[str],
    anatomy_list: list[str],
    seed: int,
    num_inference_steps: int = 30,
    num_output_samples: int = 1,
) -> tuple[Path, Optional[Path]]:
    """
    Run the paired CT image+mask pipeline (`scripts.inference`) with ``rflow-ct``.

    A copy of ``configs/config_infer.json`` is written to ``configs/_temp``, with the
    arguments below replacing the same-named keys. Upstream ``main()`` is then
    invoked with a temporarily patched ``sys.argv``. The temp config is removed
    afterwards, even if inference fails.

    Args:
        output_size, spacing, body_region, anatomy_list, num_inference_steps,
        num_output_samples: written to the same-named keys of the infer config.
        seed: passed to upstream as ``--random-seed``.

    Returns:
        ``(image_path, mask_path)`` chosen from the new ``*.nii.gz`` files in the env
        config's ``output_dir``. ``mask_path`` is ``None`` when no distinct mask file
        could be identified.

    Raises:
        RuntimeError: if the upstream repo is missing, no NIfTI was produced, or
            only label-like files were produced.

    Note:
        ``MONAI_DATA_DIRECTORY`` is set (only if unset) and stays set for the rest
        of the process.
    """
    version = "rflow-ct"
    ensure_weights(version)

    base_env = UPSTREAM / "configs" / f"environment_{version}.json"
    base_infer = UPSTREAM / "configs" / "config_infer.json"

    # Resolve output_dir before creating any temp file, so a broken env config
    # cannot leave a stray copy behind.
    env_data = json.loads(base_env.read_text())
    output_dir = (UPSTREAM / env_data["output_dir"]).resolve()

    infer_data = json.loads(base_infer.read_text())
    infer_data.update(
        {
            "output_size": list(output_size),
            "spacing": list(spacing),
            "body_region": list(body_region),
            "anatomy_list": list(anatomy_list),
            "num_inference_steps": num_inference_steps,
            "num_output_samples": num_output_samples,
        }
    )

    temp_dir = UPSTREAM / "configs" / "_temp"
    temp_dir.mkdir(parents=True, exist_ok=True)
    suffix = uuid.uuid4().hex[:8]
    infer_path = temp_dir / f"config_infer_{version}_{suffix}.json"
    infer_path.write_text(json.dumps(infer_data, indent=2))

    try:
        with upstream_context():
            existing = _list_outputs_before(output_dir)

            inference_mod = importlib.import_module("scripts.inference")
            # The upstream `main()` parses argv directly. Patch sys.argv around the call.
            old_argv = sys.argv
            sys.argv = [
                "scripts.inference",
                "-t", "./configs/config_network_rflow.json",
                "-i", str(infer_path.relative_to(UPSTREAM)),
                "-e", str(base_env.relative_to(UPSTREAM)),
                "--random-seed", str(seed),
                "--version", version,
            ]
            os.environ.setdefault("MONAI_DATA_DIRECTORY", str(UPSTREAM / "temp_work_dir"))
            try:
                inference_mod.main()
            finally:
                sys.argv = old_argv

            new_files = _newest_outputs(output_dir, existing)
    finally:
        # Best-effort cleanup that also runs when inference raises.
        with contextlib.suppress(OSError):
            infer_path.unlink()

    if not new_files:
        raise RuntimeError(f"No NIfTI produced in {output_dir}")

    image_path, mask_path = _split_image_mask(new_files)
    if image_path is None:
        names = [p.name for p in new_files]
        raise RuntimeError(
            f"No image NIfTI produced in {output_dir}; only label-like files: {names}"
        )
    return image_path, mask_path


def labels_present(mask_path: Path) -> set[int]:
    """
    Collect the distinct non-zero label IDs stored in a label volume.

    The unique voxel values are cast to ``int`` (truncating any fractional part)
    before background ``0`` is dropped, and are returned as plain Python ints.
    """
    import nibabel as nib
    import numpy as np

    volume = np.asarray(nib.load(str(mask_path)).dataobj)
    ids = np.unique(volume).astype(int)
    return {int(i) for i in ids[ids != 0]}