zephyrie's picture
Initial commit: NV-Generate Gradio showcase
ab1db83
from __future__ import annotations
import time
from .base import GenerationRequest, GenerationResult
from . import _runner
def generate(req: GenerationRequest) -> GenerationResult:
    """Generate a CT volume, optionally paired with a segmentation mask.

    When ``req.generate_masks`` is true, the paired image+mask pipeline
    (``_runner.run_paired_ct``) is used. Otherwise the image-only
    ``"rflow-ct"`` pipeline (``_runner.run_image_only``) is used.

    Args:
        req: Generation parameters. ``output_size``, ``spacing``, ``seed``
            and ``num_steps`` are forwarded to the runner in both modes.
            In paired mode, a falsy ``body_region`` is replaced by ``[]``
            and a falsy ``anatomy_list`` by ``["liver"]``.

    Returns:
        A ``GenerationResult`` carrying the volume path (as ``str``), the
        request seed, ``modality=1`` and the elapsed runtime in seconds.
        In paired mode it also carries ``mask_path`` (``None`` when the
        runner returned no mask) and ``used_anatomy_labels``. That mapping
        sends each label id found in the mask, in sorted order, to its name
        from ``CT_LABELS_BY_ID``, or to ``"label <id>"`` for unknown ids.
    """
    # Monotonic clock: unlike time.time(), perf_counter() cannot jump when
    # the system clock is adjusted, so the reported runtime is a valid
    # duration.
    t0 = time.perf_counter()
    # NOTE(review): 1 appears to be the CT modality code expected by the
    # runner and GenerationResult; confirm against the modality table.
    ct_modality = 1

    if not req.generate_masks:
        # Image-only path: no mask and no label scan.
        path = _runner.run_image_only(
            version="rflow-ct",
            output_size=req.output_size,
            spacing=req.spacing,
            modality=ct_modality,
            seed=req.seed,
            num_inference_steps=req.num_steps,
        )
        return GenerationResult(
            volume_path=str(path),
            runtime_seconds=time.perf_counter() - t0,
            seed=req.seed,
            modality=ct_modality,
        )

    # Paired path: the runner returns an image path and a mask path.
    # Presumably mask_path may be None if mask generation yields nothing,
    # given the check below; confirm with _runner.
    image_path, mask_path = _runner.run_paired_ct(
        output_size=req.output_size,
        spacing=req.spacing,
        body_region=req.body_region or [],
        anatomy_list=req.anatomy_list or ["liver"],
        seed=req.seed,
        num_inference_steps=req.num_steps,
    )

    used: dict[int, str] = {}
    if mask_path is not None:
        # Deferred import: presumably keeps the viewer module off the
        # image-only path (or avoids an import cycle); confirm.
        from viewer.colormaps import CT_LABELS_BY_ID

        present = _runner.labels_present(mask_path)
        used = {
            lbl: CT_LABELS_BY_ID.get(lbl, f"label {lbl}")
            for lbl in sorted(present)
        }

    # The runtime deliberately includes the mask label scan above,
    # matching the original measurement window.
    return GenerationResult(
        volume_path=str(image_path),
        mask_path=str(mask_path) if mask_path is not None else None,
        used_anatomy_labels=used,
        runtime_seconds=time.perf_counter() - t0,
        seed=req.seed,
        modality=ct_modality,
    )