Spaces:
Running on Zero
Running on Zero
| """CT workspace — paired image+mask generation with anatomy controls.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| import gradio as gr | |
| from pipelines import GenerationRequest | |
| from pipelines.ct import generate as generate_ct | |
| from utils.windowing import CT_PRESETS | |
| from viewer.niivue_embed import empty_html, render_viewer | |
| from viewer.colormaps import legend_html | |
| from .presets import CT_ANATOMY_CHOICES, CT_BODY_REGIONS, CT_SAMPLES, XY_CHOICES, Z_CHOICES | |
| def _spacing_default() -> tuple[float, float, float]: | |
| return (1.5, 1.5, 1.5) | |
def _on_preset(preset_name: str) -> tuple[float, float] | None:
    """Look up a CT window/level preset and return its display bounds.

    Args:
        preset_name: Key used to index ``CT_PRESETS``.

    Returns:
        The ``(min, max)`` pair produced by the preset's ``to_min_max()``, or
        ``None`` when nothing is registered under ``preset_name``.
    """
    found = CT_PRESETS.get(preset_name)
    if found is not None:
        # Unpack explicitly so a malformed preset (not exactly two bounds) fails loudly.
        low, high = found.to_min_max()
        return low, high
    return None
def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the CT workspace UI and wire up its event handlers.

    Args:
        spaces_gpu: Decorator applied to the generation handler before it is
            bound to the Generate button (presumably ``spaces.GPU`` for ZeroGPU
            hardware; confirm with the caller). Pass a falsy value to run the
            handler undecorated.

    Returns:
        The hidden workspace Group and the back-to-home button.
    """
    import html as html_lib
    import logging

    logger = logging.getLogger(__name__)
    # Single source of truth for the spacing defaults shown on the sliders.
    default_sx, default_sy, default_sz = _spacing_default()

    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--ct);color:var(--ct)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">CT</span>'
                '</div>'
            )
        gr.HTML(
            """
            <div class="ws-intro ws-intro-ct">
              <div class="ws-intro-left">
                <h2 class="ws-intro-title">NV-Generate · CT</h2>
                <p class="ws-intro-desc">
                  Whole-body synthetic CT volumes with paired 132-class anatomy masks.
                  Generate balanced training data for segmentation models, augment rare
                  pathologies with controllable organ and tumor size, or share
                  privacy-preserving samples for research.
                </p>
              </div>
              <div class="ws-intro-facts">
                <div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Body regions</span><span class="ws-fact-v">head · chest · thorax · abdomen · pelvis · lower</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Segmentation</span><span class="ws-fact-v">132 classes · paired</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 rectified-flow steps</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 768 vox</span></div>
              </div>
            </div>
            """
        )
        with gr.Row(elem_classes=["workspace-row"]):
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in CT_SAMPLES]

                gr.Markdown("##### Conditioning")
                generate_masks = gr.Checkbox(
                    label="Paired anatomy mask · 132 classes",
                    value=True,
                    info="Off → image-only generation (faster).",
                )
                body_region = gr.CheckboxGroup(
                    choices=CT_BODY_REGIONS,
                    value=["abdomen"],
                    label="Body region",
                )
                anatomy_list = gr.Dropdown(
                    choices=CT_ANATOMY_CHOICES,
                    value=["liver"],
                    multiselect=True,
                    label="Target anatomies",
                )

                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=256, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(1.0, 5.0, value=default_sx, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(1.0, 5.0, value=default_sy, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=default_sz, step=0.05, label="Spacing Z (mm)")
                gr.HTML('<div class="hint">Field of view needs at least 256 mm: increase X voxels or X spacing if shorter.</div>')

                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")

                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>', elem_classes=["status"])

            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                with gr.Row(elem_classes=["preset-row"]):
                    gr.HTML('<div class="preset-label">Window / level preset</div>')
                    preset = gr.Radio(
                        choices=[p.name for p in CT_PRESETS.values()],
                        value="Soft Tissue",
                        show_label=False,
                        container=False,
                        elem_classes=["preset-radio"],
                    )
                legend = gr.HTML("", elem_classes=["legend-host"], visible=False)
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])

        # State holding the most recent generation, so the W/L preset radio can
        # re-render the viewer without re-running inference.
        last_result = gr.State(None)

    # Map radio display names (the ``p.name`` values offered above) back to
    # CT_PRESETS keys. It is derived from CT_PRESETS itself so every preset the
    # radio offers resolves. A hand-written table silently fell back to soft
    # tissue for any preset it did not list.
    _PRESET_KEY = {p.name: key for key, p in CT_PRESETS.items()}

    # ---- helpers ----
    def _window_bounds(preset_name):
        """Return (window_min, window_max) for a radio value; unknown names use soft tissue."""
        wm = CT_PRESETS.get(_PRESET_KEY.get(preset_name, "soft_tissue"))
        return wm.to_min_max() if wm else (None, None)

    def _render(result, preset_name):
        """Render the viewer HTML for a generation result under the chosen W/L preset."""
        window_min, window_max = _window_bounds(preset_name)
        labels = result.used_anatomy_labels or {}
        return render_viewer(
            volume_path=result.volume_path,
            mask_path=result.mask_path,
            colormap="gray",  # CT base is grayscale; W/L preset drives windowing via cal_min/max
            used_label_ids=list(labels.keys()),
            window_min=window_min,
            window_max=window_max,
        )

    def _failure(err):
        """Build the handler outputs for a failed generation (viewer, legend, download, status, state)."""
        # Exception text can contain '<' / '&'; escape it before embedding in raw HTML.
        safe_err = html_lib.escape(str(err))
        return (
            # NOTE(review): unclear whether empty_html escapes its message; confirm.
            empty_html(f"Generation failed: {err}"),
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=None),
            f'<div class="stat-line"><span class="stat-err">✕ Generation failed</span> <span class="stat-chip"><span class="stat-k">ERR</span><span class="stat-v">{safe_err}</span></span></div>',
            None,
        )

    # ---- handlers ----
    def _generate(generate_masks, body_region, anatomy_list, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, preset):
        """Run CT generation and return (viewer, legend, download, status, last_result)."""
        try:
            # Build the request inside the try so bad inputs / validation errors
            # are reported in the status line instead of escaping the handler.
            req = GenerationRequest(
                model="ct",
                output_size=(int(dim_xy), int(dim_xy), int(dim_z)),
                spacing=(float(sp_x), float(sp_y), float(sp_z)),
                # gr.Number can deliver None when the field is cleared; fall back to the UI default.
                seed=int(seed) if seed is not None else 0,
                num_steps=int(steps),
                body_region=list(body_region or []) if generate_masks else None,
                anatomy_list=list(anatomy_list) if anatomy_list else ["liver"],
                generate_masks=bool(generate_masks),
            )
            result = generate_ct(req)
        except Exception as e:  # UI boundary: surface any failure to the user
            logger.exception("CT generation failed")
            return _failure(e)

        viewer_html = _render(result, preset)
        labels = result.used_anatomy_labels or {}
        legend_str = legend_html(labels) if result.mask_path else ""
        files = [result.volume_path] + ([result.mask_path] if result.mask_path else [])

        # Report the real geometry; X/Y and Z are chosen independently.
        size_x, size_y, size_z = req.output_size
        size_str = f"{size_x}³" if size_x == size_y == size_z else f"{size_x}×{size_y}×{size_z}"
        stat = (
            '<div class="stat-line">'
            '<span class="stat-mark"></span>'
            '<span class="stat-label">Generated</span>'
            f'<span class="stat-chip"><span class="stat-k">runtime</span><span class="stat-v">{result.runtime_seconds:.1f}s</span></span>'
            f'<span class="stat-chip"><span class="stat-k">seed</span><span class="stat-v">{result.seed}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">steps</span><span class="stat-v">{req.num_steps}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">size</span><span class="stat-v">{size_str}</span></span>'
            '</div>'
        )
        return (
            viewer_html,
            gr.update(visible=bool(legend_str), value=legend_str),
            gr.update(visible=True, value=files),
            stat,
            result,
        )

    def _reapply_preset(preset, result):
        """Re-render the viewer with a new W/L preset; no-op until something has been generated."""
        if not result or not getattr(result, "volume_path", None):
            return gr.update()
        return _render(result, preset)

    decorated = spaces_gpu(_generate) if spaces_gpu else _generate

    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            decorated,
            inputs=[generate_masks, body_region, anatomy_list, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, preset],
            outputs=[viewer, legend, download, status, last_result],
            show_progress="full",
        )
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    preset.change(_reapply_preset, inputs=[preset, last_result], outputs=[viewer])

    for btn, sample in zip(sample_btns, CT_SAMPLES):
        # Bind the sample as a default argument to avoid late-binding closures.
        def _apply(s=sample):
            return (
                s["body_region"],
                s["anatomy_list"],
                s["xy"], s["z"],
                s["spacing"][0], s["spacing"][1], s["spacing"][2],
            )

        btn.click(_apply, outputs=[body_region, anatomy_list, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn