ct-heart-segmentation / scripts / run_pipeline.py
kbressem's picture
Upload scripts/run_pipeline.py with huggingface_hub
0900543 verified
#!/usr/bin/env python3
"""End-to-end coronary segmentation: CT image → binary mask → segmental labels.
Usage:
python scripts/run_pipeline.py --input /path/to/dicoms --output /path/to/results
docker run --gpus all -v /data/in:/input -v /data/out:/output ct-heart-seg
"""
import argparse
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
# Module logger; its records reach the console/file handlers that ``main``
# installs on the root logger with ``logging.basicConfig``.
logger: logging.Logger = logging.getLogger(name=__name__)

# Directory names of the two bundles run in sequence by ``main``: first the
# binary bundle, then the segmental bundle fed with the binary output.
BINARY_BUNDLE: str = "ct_binary_coronary_segmentation"
SEGMENTAL_BUNDLE: str = "ct_segmental_coronary_segmentation"
def find_bundle_dir(root: Path, name: str) -> Path:
    """Locate the bundle directory called *name* below *root*.

    Two locations are tried, in this order: ``root/name`` and
    ``root/bundle/name``. A location qualifies only if it contains a
    ``configs`` sub-directory.

    Args:
        root: Directory to search under.
        name: Name of the bundle directory.

    Returns:
        The first qualifying location.

    Raises:
        FileNotFoundError: If neither location contains ``configs/``.
    """
    candidates = (root / name, root / "bundle" / name)
    found = next((cand for cand in candidates if (cand / "configs").is_dir()), None)
    if found is None:
        raise FileNotFoundError(f"Cannot find {name}/ in {root} or {root / 'bundle'}")
    return found
def run_inference(bundle_dir: Path, extra_args: list[str], gpu: int, label: str):
    """Run a bundle's ``inference`` workflow in a child process, logging its output live.

    Executes ``python -m monai.bundle run inference --config_file
    configs/ensemble_inference.yaml <extra_args>`` with *bundle_dir* as the
    working directory, so the relative config path resolves inside the bundle,
    and with ``CUDA_VISIBLE_DEVICES`` set to *gpu*.

    The child's stdout and stderr are merged and each non-blank line is logged
    at INFO as soon as it is produced, prefixed with ``[label]``.

    Args:
        bundle_dir: Bundle directory; used as the child's cwd.
        extra_args: Extra CLI arguments appended to the command. Relative paths
            in them are presumably resolved against *bundle_dir* (the child's
            cwd), so callers should pass absolute paths.
        gpu: GPU index exported as ``CUDA_VISIBLE_DEVICES``.
        label: Tag used to prefix every log line.

    Raises:
        SystemExit: If the child fails. The exit status is the child's exit
            code, or ``128 + N`` if it was killed by signal ``N`` (POSIX).
    """
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu)
    cmd = [
        sys.executable, "-m", "monai.bundle", "run", "inference",
        "--config_file", "configs/ensemble_inference.yaml",
        *extra_args,
    ]
    logger.info("[%s] %s (cwd=%s)", label, " ".join(cmd), bundle_dir)
    t0 = time.monotonic()  # monotonic: immune to wall-clock adjustments
    # Stream rather than capture: inference can run for minutes, and buffering
    # the whole output would hide progress (or a hang) until the child exits.
    # stderr is folded into stdout so lines keep their original interleaving.
    with subprocess.Popen(
        cmd,
        cwd=str(bundle_dir),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",  # undecodable bytes must not crash the pipeline
        bufsize=1,
    ) as proc:
        for raw_line in proc.stdout:
            line = raw_line.rstrip()
            if line:
                logger.info("[%s] %s", label, line)
        returncode = proc.wait()
    elapsed = time.monotonic() - t0
    if returncode != 0:
        logger.error("[%s] FAILED (exit %d) after %.1fs", label, returncode, elapsed)
        # A negative code means "killed by signal N"; sys.exit(-N) would surface
        # as the misleading status 256-N, so use the shell convention 128+N.
        sys.exit(128 - returncode if returncode < 0 else returncode)
    logger.info("[%s] Done in %.1fs", label, elapsed)
def main():
    """Run the two-stage coronary segmentation pipeline from the command line.

    Stage 1 runs the binary bundle on ``--input``. Stage 2 runs the segmental
    bundle on the stage-1 output. Results are written under ``--output``:

    * ``pipeline.log`` - everything logged during the run (also sent to stderr);
    * ``binary/``      - files produced by the binary bundle;
    * ``segmental/``   - files produced by the segmental bundle.

    Stage 1 writes into a temporary directory. That directory is copied to
    ``binary/``, passed to stage 2 as ``--binary_label_dir``, and deleted when
    the run ends. A failing stage terminates the process via ``run_inference``.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--input", required=True,
                        help="input images, passed to the binary bundle as --dataset_dir")
    parser.add_argument("--output", required=True,
                        help="output directory (created if missing)")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU index exported as CUDA_VISIBLE_DEVICES (default: 0)")
    args = parser.parse_args()

    # Both paths are made absolute because run_inference starts each bundle with
    # the bundle directory as cwd. Left relative, --output would make the
    # segmental stage write inside the bundle rather than the caller's directory.
    input_path = Path(args.input).resolve()
    if not input_path.exists():
        parser.error(f"--input does not exist: {input_path}")
    output_dir = Path(args.output).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(), logging.FileHandler(output_dir / "pipeline.log")],
        force=True,
    )

    # Bundles are looked up relative to the repository root (parent of scripts/).
    repo_root = Path(__file__).resolve().parent.parent
    binary_dir = find_bundle_dir(repo_root, BINARY_BUNDLE)
    segmental_dir = find_bundle_dir(repo_root, SEGMENTAL_BUNDLE)
    logger.info("Input: %s | Output: %s | GPU: %d", input_path, output_dir, args.gpu)

    binary_out = output_dir / "binary"
    segmental_out = output_dir / "segmental"
    t_start = time.monotonic()
    with tempfile.TemporaryDirectory(prefix="binary_output_") as tmp_binary:
        run_inference(
            binary_dir,
            ["--dataset_dir", str(input_path), "--output_dir", tmp_binary],
            args.gpu,
            "binary",
        )

        # Copy entry by entry. shutil.copy2 alone fails on sub-directories, and
        # copytree() on the temp dir itself would copy its owner-only (0o700)
        # mode onto binary/.
        binary_out.mkdir(parents=True, exist_ok=True)
        for entry in Path(tmp_binary).iterdir():
            if entry.is_dir():
                shutil.copytree(entry, binary_out / entry.name, dirs_exist_ok=True)
            else:
                shutil.copy2(entry, binary_out / entry.name)
        # rglob so that masks written into per-case sub-directories are counted too.
        n_binary = sum(1 for _ in binary_out.rglob("*.nii.gz"))
        logger.info("Binary: %d files → %s", n_binary, binary_out)

        segmental_out.mkdir(parents=True, exist_ok=True)
        # Must run inside the `with`: stage 2 reads the still-existing temp dir.
        run_inference(
            segmental_dir,
            ["--binary_label_dir", tmp_binary, "--output_dir", str(segmental_out)],
            args.gpu,
            "segmental",
        )
        n_segmental = sum(1 for _ in segmental_out.rglob("*.nii.gz"))
        logger.info("Segmental: %d files → %s", n_segmental, segmental_out)

    logger.info("Pipeline complete in %.1fs", time.monotonic() - t_start)
if __name__ == "__main__":
    # Script entry point. main() returns None on success, and sys.exit(None)
    # ends the interpreter with status 0, just like falling off the end.
    sys.exit(main())