Spaces:
Running on Zero
Running on Zero
| """MR workspace — image-only multi-contrast MR generation.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| import gradio as gr | |
| from pipelines import GenerationRequest | |
| from pipelines.mr import generate as generate_mr | |
| from viewer.niivue_embed import empty_html, render_viewer | |
| from .presets import XY_CHOICES, Z_CHOICES, MR_CONTRAST_CHOICES, MR_SAMPLES | |
# Dropdown labels for the "Contrast & anatomy" selector, in preset order.
CONTRAST_LABELS = [choice[0] for choice in MR_CONTRAST_CHOICES]
# Reverse lookup used at generation time: dropdown label -> modality class
# passed to the request as ``modality_class``.
LABEL_TO_MODALITY = dict((choice[0], choice[1]) for choice in MR_CONTRAST_CHOICES)
def _stat_chip(key: str, value: object) -> str:
    """Render one ``key``/``value`` chip for the status line.

    The value is HTML-escaped because it may carry arbitrary text (for
    example an exception message) that must not be parsed as markup.
    """
    from html import escape

    return (
        '<span class="stat-chip">'
        f'<span class="stat-k">{key}</span>'
        f'<span class="stat-v">{escape(str(value))}</span>'
        "</span>"
    )


def _format_size(size: tuple[int, ...]) -> str:
    """Render a voxel grid size as ``X×Y×Z`` (e.g. ``256×256×128``)."""
    return "×".join(str(int(dim)) for dim in size)


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the MR workspace panel and wire up its event handlers.

    The panel is created hidden (``visible=False``). Showing it and reacting
    to the returned back button is left to the caller. Event listeners are
    attached here, so this is presumably called inside an active
    ``gr.Blocks`` context; confirm at the call site.

    Args:
        spaces_gpu: Decorator applied to the generation handler (presumably
            ``spaces.GPU`` on a ZeroGPU Space). Pass a falsy value to run the
            handler undecorated.

    Returns:
        ``(group, back_btn)``: the workspace container and its "← Back" button.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--mr);color:var(--mr)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">MR</span>'
                '</div>'
            )
        gr.HTML(
            """
            <div class="ws-intro ws-intro-mr">
              <div class="ws-intro-left">
                <h2 class="ws-intro-title">NV-Generate · MR</h2>
                <p class="ws-intro-desc">
                  Multi-contrast MRI across brain, prostate, breast, and abdominal anatomy.
                  Drive contrast through a modality embedding — T1, T2, FLAIR — at
                  variable resolution and voxel spacing. Fine-tune on your own MRI data
                  to extend to new modalities and regions.
                </p>
              </div>
              <div class="ws-intro-facts">
                <div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Contrasts</span><span class="ws-fact-v">T1 · T2 · FLAIR</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Regions</span><span class="ws-fact-v">brain · prostate · breast · abdomen</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 steps</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 128 vox</span></div>
                <div class="ws-fact"><span class="ws-fact-k">License</span><span class="ws-fact-v ws-fact-warn">NVIDIA Non-Commercial</span></div>
              </div>
            </div>
            """
        )
        gr.HTML(
            '<div class="license-banner">'
            '<strong>Non-commercial license.</strong> '
            'NV-Generate-MR weights are released under the <a href="https://developer.download.nvidia.com/licenses/NVIDIA-OneWay-Noncommercial-License-22Mar2022.pdf" target="_blank">NVIDIA OneWay Non-Commercial License</a>. '
            'Use is permitted for academic research only.'
            '</div>'
        )
        with gr.Row(elem_classes=["workspace-row"]):
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_SAMPLES]
                gr.Markdown("##### Conditioning")
                contrast = gr.Dropdown(
                    choices=CONTRAST_LABELS,
                    value="T2 prostate",
                    label="Contrast & anatomy",
                    info="Drives the modality embedding.",
                )
                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=128, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)")
                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                    cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")
                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>', elem_classes=["status"])
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                # MR has no mask, but keep a legend slot so all workspaces share structure.
                gr.HTML("", elem_classes=["legend-host"], visible=False)

    def _generate(contrast_label, xy, z, spacing_x, spacing_y, spacing_z, seed_value, num_steps, guidance):
        """Run one MR generation and return ``(viewer_html, download_update, status_html)``.

        Failures, including unusable widget values such as a cleared seed
        field, are reported in the viewer and status line rather than raised,
        so the UI always receives a complete set of outputs.
        """
        try:
            req = GenerationRequest(
                model="mr",
                # In-plane size is square (X == Y); Z is chosen independently.
                output_size=(int(xy), int(xy), int(z)),
                spacing=(float(spacing_x), float(spacing_y), float(spacing_z)),
                seed=int(seed_value),
                num_steps=int(num_steps),
                cfg_guidance_scale=float(guidance),
                # NOTE(review): unknown labels fall back to modality class 9;
                # what 9 denotes is defined by the model, not here. Confirm.
                modality_class=LABEL_TO_MODALITY.get(contrast_label, 9),
            )
            result = generate_mr(req)
        except Exception as exc:  # UI boundary: surface any failure to the user.
            import logging

            logging.getLogger(__name__).exception("MR generation failed")
            # Some exceptions carry no message; fall back to the type name.
            detail = str(exc) or type(exc).__name__
            error_status = (
                '<div class="stat-line">'
                '<span class="stat-err">✕ Generation failed</span> '
                + _stat_chip("ERR", detail)
                + "</div>"
            )
            # NOTE(review): empty_html is defined elsewhere; confirm whether it
            # escapes its message argument.
            return (
                empty_html(f"Generation failed: {detail}"),
                gr.update(visible=False, value=None),
                error_status,
            )

        viewer_html = render_viewer(volume_path=result.volume_path, colormap="gray")
        status_html = "".join(
            [
                '<div class="stat-line">',
                '<span class="stat-mark"></span>',
                '<span class="stat-label">Generated</span>',
                _stat_chip("runtime", f"{result.runtime_seconds:.1f}s"),
                _stat_chip("seed", result.seed),
                _stat_chip("steps", req.num_steps),
                # Report the full grid: X/Y and Z can differ, so "N³" was misleading.
                _stat_chip("size", _format_size(req.output_size)),
                "</div>",
            ]
        )
        return viewer_html, gr.update(visible=True, value=result.volume_path), status_html

    run_generation = spaces_gpu(_generate) if spaces_gpu else _generate

    (
        generate_btn.click(
            # Disable the button while the job runs to avoid double submissions.
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            run_generation,
            inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
            outputs=[viewer, download, status],
            show_progress="full",
        )
        # Per the Gradio docs, .then() runs after the previous step regardless
        # of its outcome, so the button is re-enabled even on failure.
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    for btn, sample in zip(sample_btns, MR_SAMPLES):
        # The default argument pins the current sample; a plain closure would
        # late-bind and every button would apply the last preset.
        def _apply(s=sample):
            return (
                s["modality_label"],
                s["xy"],
                s["z"],
                s["spacing"][0],
                s["spacing"][1],
                s["spacing"][2],
            )

        btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn