# Source: nv-generate / ui / workspace_ct.py
# Latest commit 7fb7cc3 (zephyrie): Move Quick Presets to the top of the controls sidebar
"""CT workspace — paired image+mask generation with anatomy controls."""
from __future__ import annotations
from typing import Any
import gradio as gr
from pipelines import GenerationRequest
from pipelines.ct import generate as generate_ct
from utils.windowing import CT_PRESETS
from viewer.niivue_embed import empty_html, render_viewer
from viewer.colormaps import legend_html
from .presets import CT_ANATOMY_CHOICES, CT_BODY_REGIONS, CT_SAMPLES, XY_CHOICES, Z_CHOICES
def _spacing_default() -> tuple[float, float, float]:
return (1.5, 1.5, 1.5)
def _on_preset(preset_name: str) -> tuple[float, float] | None:
    """Translate a preset identifier into its display window bounds.

    Args:
        preset_name: Value used as-is to look the preset up in
            ``CT_PRESETS``.

    Returns:
        ``(lo, hi)`` as produced by the preset's ``to_min_max()``, or
        ``None`` when ``CT_PRESETS`` has no entry for ``preset_name``.

    NOTE(review): ``build`` maps radio display names (e.g. "Soft Tissue")
    to keys (e.g. "soft_tissue") before querying ``CT_PRESETS``; this
    helper does not, so it presumably expects the key form — confirm
    against callers.
    """
    selected = CT_PRESETS.get(preset_name)
    if selected is not None:
        window_lo, window_hi = selected.to_min_max()
        return window_lo, window_hi
    return None
def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the CT workspace UI and wire up its event handlers.

    Must be called inside an active ``gr.Blocks`` context, since it creates
    components and registers event listeners.

    Args:
        spaces_gpu: Optional decorator applied to the generation handler.
            When falsy, the handler is registered undecorated.

    Returns:
        The hidden workspace ``gr.Group`` and the back-to-home ``gr.Button``.
    """
    # Local import: used to escape exception text before embedding it in HTML.
    from html import escape

    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--ct);color:var(--ct)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">CT</span>'
                '</div>'
            )
        gr.HTML(
            """
<div class="ws-intro ws-intro-ct">
<div class="ws-intro-left">
<h2 class="ws-intro-title">NV-Generate · CT</h2>
<p class="ws-intro-desc">
Whole-body synthetic CT volumes with paired 132-class anatomy masks.
Generate balanced training data for segmentation models, augment rare
pathologies with controllable organ and tumor size, or share
privacy-preserving samples for research.
</p>
</div>
<div class="ws-intro-facts">
<div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
<div class="ws-fact"><span class="ws-fact-k">Body regions</span><span class="ws-fact-v">head · chest · thorax · abdomen · pelvis · lower</span></div>
<div class="ws-fact"><span class="ws-fact-k">Segmentation</span><span class="ws-fact-v">132 classes · paired</span></div>
<div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 rectified-flow steps</span></div>
<div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 768 vox</span></div>
</div>
</div>
"""
        )
        with gr.Row(elem_classes=["workspace-row"]):
            # ---- left column: generation controls ----
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in CT_SAMPLES]
                gr.Markdown("##### Conditioning")
                generate_masks = gr.Checkbox(
                    label="Paired anatomy mask · 132 classes",
                    value=True,
                    info="Off → image-only generation (faster).",
                )
                body_region = gr.CheckboxGroup(
                    choices=CT_BODY_REGIONS,
                    value=["abdomen"],
                    label="Body region",
                )
                anatomy_list = gr.Dropdown(
                    choices=CT_ANATOMY_CHOICES,
                    value=["liver"],
                    multiselect=True,
                    label="Target anatomies",
                )
                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=256, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(1.0, 5.0, value=1.5, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(1.0, 5.0, value=1.5, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)")
                gr.HTML('<div class="hint">Field of view needs at least 256 mm: increase X voxels or X spacing if shorter.</div>')
                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>', elem_classes=["status"])
            # ---- right column: viewer, window/level presets, legend, download ----
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                with gr.Row(elem_classes=["preset-row"]):
                    gr.HTML('<div class="preset-label">Window / level preset</div>')
                    preset = gr.Radio(
                        choices=[p.name for p in CT_PRESETS.values()],
                        value="Soft Tissue",
                        show_label=False,
                        container=False,
                        elem_classes=["preset-radio"],
                    )
                legend = gr.HTML("", elem_classes=["legend-host"], visible=False)
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])

    # State holding the most recent generation, so the W/L preset radio can
    # re-render the viewer without re-running inference.
    last_result = gr.State(None)

    # Display name -> CT_PRESETS key. The radio choices are built from
    # ``CT_PRESETS.values()``, so the map is extended from the same source;
    # otherwise any preset missing from the literal would silently fall back
    # to soft tissue.
    _PRESET_KEY = {
        "Soft Tissue": "soft_tissue",
        "Lung": "lung",
        "Bone": "bone",
        "Brain": "brain",
        **{p.name: key for key, p in CT_PRESETS.items()},
    }

    # ---- handlers ----
    def _render(result, preset_name):
        """Render viewer HTML for ``result`` windowed by the named W/L preset."""
        wm = CT_PRESETS.get(_PRESET_KEY.get(preset_name, "soft_tissue"))
        window_min, window_max = wm.to_min_max() if wm else (None, None)
        # Image-only runs may carry no labels; treat a missing/empty value as none.
        labels = result.used_anatomy_labels or {}
        return render_viewer(
            volume_path=result.volume_path,
            mask_path=result.mask_path,
            colormap="gray",  # CT base is grayscale; W/L preset drives windowing via cal_min/max
            used_label_ids=list(labels.keys()),
            window_min=window_min,
            window_max=window_max,
        )

    def _generate(generate_masks, body_region, anatomy_list, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, preset):
        """Run CT generation and return updates for
        (viewer, legend, download, status, last_result)."""
        output_size = (int(dim_xy), int(dim_xy), int(dim_z))
        req = GenerationRequest(
            model="ct",
            output_size=output_size,
            spacing=(float(sp_x), float(sp_y), float(sp_z)),
            # A cleared Number field arrives as None; fall back to the
            # field's initial value instead of raising.
            seed=int(seed) if seed is not None else 0,
            num_steps=int(steps),
            body_region=list(body_region) if generate_masks else None,
            anatomy_list=list(anatomy_list) if anatomy_list else ["liver"],
            generate_masks=bool(generate_masks),
        )
        try:
            result = generate_ct(req)
        except Exception as e:
            # Exception text is arbitrary; escape it before embedding in markup.
            err_text = escape(str(e))
            return (
                # NOTE(review): whether empty_html escapes its message is not
                # visible here, so the raw text is passed through unchanged.
                empty_html(f"Generation failed: {e}"),
                gr.update(visible=False, value=""),
                gr.update(visible=False, value=None),
                f'<div class="stat-line"><span class="stat-err">✕ Generation failed</span> <span class="stat-chip"><span class="stat-k">ERR</span><span class="stat-v">{err_text}</span></span></div>',
                None,
            )

        viewer_html = _render(result, preset)
        legend_str = legend_html(result.used_anatomy_labels) if result.mask_path else ""
        files = [result.volume_path] + ([result.mask_path] if result.mask_path else [])

        # Z is chosen independently of X/Y, so only use cube notation when
        # the volume really is cubic.
        nx, ny, nz = output_size
        size_text = f"{nx}³" if nx == ny == nz else f"{nx} × {ny} × {nz}"
        stat = (
            '<div class="stat-line">'
            '<span class="stat-mark"></span>'
            '<span class="stat-label">Generated</span>'
            f'<span class="stat-chip"><span class="stat-k">runtime</span><span class="stat-v">{result.runtime_seconds:.1f}s</span></span>'
            f'<span class="stat-chip"><span class="stat-k">seed</span><span class="stat-v">{result.seed}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">steps</span><span class="stat-v">{req.num_steps}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">size</span><span class="stat-v">{size_text}</span></span>'
            '</div>'
        )
        return (
            viewer_html,
            gr.update(visible=bool(legend_str), value=legend_str),
            gr.update(visible=True, value=files),
            stat,
            result,
        )

    def _reapply_preset(preset, result):
        """Re-render the last generated volume with a different W/L preset."""
        if not result or not getattr(result, "volume_path", None):
            # Nothing generated yet: leave the viewer untouched.
            return gr.update()
        return _render(result, preset)

    decorated = spaces_gpu(_generate) if spaces_gpu else _generate

    # Disable the button while generating, run the handler, then restore it.
    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            decorated,
            inputs=[generate_masks, body_region, anatomy_list, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, preset],
            outputs=[viewer, legend, download, status, last_result],
            show_progress="full",
        )
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    preset.change(_reapply_preset, inputs=[preset, last_result], outputs=[viewer])

    for btn, sample in zip(sample_btns, CT_SAMPLES):
        # Bind the sample as a default argument so each button keeps its own
        # preset rather than the loop's final value.
        def _apply(s=sample):
            return (
                s["body_region"],
                s["anatomy_list"],
                s["xy"], s["z"],
                s["spacing"][0], s["spacing"][1], s["spacing"][2],
            )

        btn.click(_apply, outputs=[body_region, anatomy_list, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn