"""CT workspace: paired image + mask generation with anatomy controls."""

from __future__ import annotations

from typing import Any

import gradio as gr

from pipelines import GenerationRequest
from pipelines.ct import generate as generate_ct
from utils.windowing import CT_PRESETS
from viewer.colormaps import legend_html
from viewer.niivue_embed import empty_html, render_viewer

from .presets import CT_ANATOMY_CHOICES, CT_BODY_REGIONS, CT_SAMPLES, XY_CHOICES, Z_CHOICES


def _spacing_default() -> tuple[float, float, float]:
    """Default voxel spacing (mm) along X, Y and Z."""
    return (1.5, 1.5, 1.5)


def _on_preset(preset_name: str) -> tuple[float, float] | None:
    """Return the (min, max) intensity window for a CT_PRESETS key, or None if the key is unknown."""
    preset = CT_PRESETS.get(preset_name)
    if preset is None:
        return None
    lo, hi = preset.to_min_max()
    return lo, hi


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the CT workspace.

    Returns the (initially hidden) workspace Group and its back-to-home button.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML("NV-Generate / CT")
        gr.HTML(
            """

NV-Generate · CT

Whole-body synthetic CT volumes with paired 132-class anatomy masks. Generate balanced training data for segmentation models, augment rare pathologies with controllable organ and tumor size, or share privacy-preserving samples for research.

Architecture: MAISI-v2 · Rectified Flow
Body regions: head · chest · thorax · abdomen · pelvis · lower
Segmentation: 132 classes · paired
Inference: 30 rectified-flow steps (default)
Max volume: 512 × 512 × 768 voxels
            """
        )

        with gr.Row(elem_classes=["workspace-row"]):
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in CT_SAMPLES]

                gr.Markdown("##### Conditioning")
                generate_masks = gr.Checkbox(
                    label="Paired anatomy mask · 132 classes",
                    value=True,
                    info="Off → image-only generation (faster).",
                )
                body_region = gr.CheckboxGroup(
                    choices=CT_BODY_REGIONS,
                    value=["abdomen"],
                    label="Body region",
                )
                anatomy_list = gr.Dropdown(
                    choices=CT_ANATOMY_CHOICES,
                    value=["liver"],
                    multiselect=True,
                    label="Target anatomies",
                )

                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=256, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(1.0, 5.0, value=1.5, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(1.0, 5.0, value=1.5, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)")
                gr.HTML(
                    "The X field of view (X voxels × X spacing) must be at least 256 mm; "
                    "if it is shorter, increase the X voxel count or the X spacing."
                )

                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")

                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML(
                    "Idle. Configure parameters and click Generate.",
                    elem_classes=["status"],
                )

            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML("Viewport · Multiplanar (Axial · Coronal · Sagittal · 3D)")
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                with gr.Row(elem_classes=["preset-row"]):
                    gr.HTML("Window / level preset")
                    preset = gr.Radio(
                        choices=[p.name for p in CT_PRESETS.values()],
                        value="Soft Tissue",
                        show_label=False,
                        container=False,
                        elem_classes=["preset-radio"],
                    )
                legend = gr.HTML("", elem_classes=["legend-host"], visible=False)
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])

        # State holding the most recent generation, so the W/L preset radio can
        # re-render the viewer without re-running inference.
        last_result = gr.State(None)

        # Map each radio label (a preset's display name) back to its CT_PRESETS key.
        # Deriving the map from CT_PRESETS keeps it in sync with the radio choices,
        # instead of silently falling back to soft tissue for unlisted presets.
        _PRESET_KEY = {p.name: key for key, p in CT_PRESETS.items()}

        def _window_for(preset_name):
            """Return (window_min, window_max) for a radio label, defaulting to soft tissue."""
            wm = CT_PRESETS.get(_PRESET_KEY.get(preset_name, "soft_tissue"))
            return wm.to_min_max() if wm else (None, None)

        # ---- handlers ----
        def _generate(generate_masks, body_region, anatomy_list, dim_xy, dim_z,
                      sp_x, sp_y, sp_z, seed, steps, preset):
            req = GenerationRequest(
                model="ct",
                output_size=(int(dim_xy), int(dim_xy), int(dim_z)),
                spacing=(float(sp_x), float(sp_y), float(sp_z)),
                # gr.Number yields None when the field is cleared; fall back to seed 0.
                seed=int(seed) if seed is not None else 0,
                num_steps=int(steps),
                # Body-region conditioning is only passed when paired masks are generated.
                body_region=list(body_region) if generate_masks else None,
                # An empty selection falls back to the default target anatomy.
                anatomy_list=list(anatomy_list) if anatomy_list else ["liver"],
                generate_masks=bool(generate_masks),
            )
            try:
                result = generate_ct(req)
            except Exception as e:
                return (
                    empty_html(f"Generation failed: {e}"),
                    gr.update(visible=False, value=""),
                    gr.update(visible=False, value=None),
                    f"✕ Generation failed: {e}",
                    None,
                )

            window_min, window_max = _window_for(preset)
            viewer_html = render_viewer(
                volume_path=result.volume_path,
                mask_path=result.mask_path,
                colormap="gray",  # CT base is grayscale; the W/L preset sets the window via cal_min/cal_max
                used_label_ids=list(result.used_anatomy_labels or {}),
                window_min=window_min,
                window_max=window_max,
            )
            legend_str = legend_html(result.used_anatomy_labels) if result.mask_path else ""
            files = [result.volume_path] + ([result.mask_path] if result.mask_path else [])
            # Report the full X × Y × Z size: Z can differ from X/Y, so "N³" would be wrong.
            x, y, z = req.output_size
            stat = (
                f"Generated · runtime {result.runtime_seconds:.1f}s · seed {result.seed} · "
                f"steps {req.num_steps} · size {x}×{y}×{z}"
            )
            return (
                viewer_html,
                gr.update(visible=bool(legend_str), value=legend_str),
                gr.update(visible=True, value=files),
                stat,
                result,
            )

        def _reapply_preset(preset, result):
            # Nothing generated yet: leave the viewer unchanged.
            if not result or not getattr(result, "volume_path", None):
                return gr.update()
            window_min, window_max = _window_for(preset)
            return render_viewer(
                volume_path=result.volume_path,
                mask_path=result.mask_path,
                colormap="gray",
                used_label_ids=list(result.used_anatomy_labels or {}),
                window_min=window_min,
                window_max=window_max,
            )

        # Wrap inference with the Spaces GPU decorator when one is provided.
        decorated = spaces_gpu(_generate) if spaces_gpu else _generate

        # Disable the button while generating. A .then() step runs whether or not
        # the previous step succeeded, so the button is always re-enabled.
        (
            generate_btn.click(
                lambda: gr.update(value="Generating volume…", interactive=False),
                outputs=[generate_btn],
            )
            .then(
                decorated,
                inputs=[generate_masks, body_region, anatomy_list, dim_xy, dim_z,
                        sp_x, sp_y, sp_z, seed, steps, preset],
                outputs=[viewer, legend, download, status, last_result],
                show_progress="full",
            )
            .then(
                lambda: gr.update(value="Generate volume", interactive=True),
                outputs=[generate_btn],
            )
        )

        # Re-window the existing volume without re-running inference.
        preset.change(_reapply_preset, inputs=[preset, last_result], outputs=[viewer])

        # Quick-preset buttons fill in the conditioning and geometry controls.
        for btn, sample in zip(sample_btns, CT_SAMPLES):
            def _apply(s=sample):  # default argument binds this iteration's sample (avoids late binding)
                return (
                    s["body_region"],
                    s["anatomy_list"],
                    s["xy"],
                    s["z"],
                    s["spacing"][0],
                    s["spacing"][1],
                    s["spacing"][2],
                )

            btn.click(_apply, outputs=[body_region, anatomy_list, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn