"""MR workspace — image-only multi-contrast MR generation.""" from __future__ import annotations from typing import Any import gradio as gr from pipelines import GenerationRequest from pipelines.mr import generate as generate_mr from viewer.niivue_embed import empty_html, render_viewer from .presets import XY_CHOICES, Z_CHOICES, MR_CONTRAST_CHOICES, MR_SAMPLES CONTRAST_LABELS = [c[0] for c in MR_CONTRAST_CHOICES] LABEL_TO_MODALITY = {c[0]: c[1] for c in MR_CONTRAST_CHOICES} def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]: with gr.Group(visible=False, elem_classes=["workspace"]) as group: with gr.Row(elem_classes=["workspace-header"]): back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0) gr.HTML( '
' '' 'NV-Generate' '/' 'MR' '
' ) gr.HTML( """

NV-Generate · MR

Multi-contrast MRI across brain, prostate, breast, and abdominal anatomy. Drive contrast through a modality embedding — T1, T2, FLAIR — at variable resolution and voxel spacing. Fine-tune on your own MRI data to extend to new modalities and regions.

ArchitectureMAISI-v2 · Rectified Flow
ContrastsT1 · T2 · FLAIR
Regionsbrain · prostate · breast · abdomen
Inference30 steps
Max volume512 × 512 × 128 vox
LicenseNVIDIA Non-Commercial
""" ) gr.HTML( '
' 'Non-commercial license. ' 'NV-Generate-MR weights are released under the NVIDIA OneWay Non-Commercial License. ' 'Use is permitted for academic research only.' '
' ) with gr.Row(elem_classes=["workspace-row"]): with gr.Column(scale=4, min_width=320, elem_classes=["controls"]): gr.Markdown("##### Quick presets") with gr.Row(): sample_btns = [gr.Button(s["label"], size="sm") for s in MR_SAMPLES] gr.Markdown("##### Conditioning") contrast = gr.Dropdown( choices=CONTRAST_LABELS, value="T2 prostate", label="Contrast & anatomy", info="Drives the modality embedding.", ) gr.Markdown("##### Geometry") dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)") dim_z = gr.Radio(choices=Z_CHOICES, value=128, label="Z (voxels)") with gr.Row(equal_height=True): sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)") sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)") sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)") gr.Markdown("##### Diffusion") with gr.Row(equal_height=True): seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"]) steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps") cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance") generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"]) status = gr.HTML('
Idle. Configure parameters and click Generate.
', elem_classes=["status"]) with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]): gr.HTML( '
' 'Viewport · Multiplanar' 'Axial · Coronal · Sagittal · 3D' '
' ) viewer = gr.HTML(empty_html(), elem_classes=["viewer"]) download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"]) # MR has no mask, but keep a legend slot so all workspaces share structure legend = gr.HTML("", elem_classes=["legend-host"], visible=False) def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg): req = GenerationRequest( model="mr", output_size=(int(dim_xy), int(dim_xy), int(dim_z)), spacing=(float(sp_x), float(sp_y), float(sp_z)), seed=int(seed), num_steps=int(steps), cfg_guidance_scale=float(cfg), modality_class=LABEL_TO_MODALITY.get(contrast, 9), ) try: result = generate_mr(req) except Exception as e: return ( empty_html(f"Generation failed: {e}"), gr.update(visible=False, value=None), f'
✕ Generation failed ERR{e}
', ) html = render_viewer(volume_path=result.volume_path, colormap="gray") stat = ( '
' '' 'Generated' f'runtime{result.runtime_seconds:.1f}s' f'seed{result.seed}' f'steps{req.num_steps}' f'size{req.output_size[0]}³' '
' ) return html, gr.update(visible=True, value=result.volume_path), stat decorated = spaces_gpu(_generate) if spaces_gpu else _generate ( generate_btn.click( lambda: gr.update(value="Generating volume…", interactive=False), outputs=[generate_btn], ) .then( decorated, inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg], outputs=[viewer, download, status], show_progress="full", ) .then( lambda: gr.update(value="Generate volume", interactive=True), outputs=[generate_btn], ) ) for btn, sample in zip(sample_btns, MR_SAMPLES): def _apply(s=sample): return ( s["modality_label"], s["xy"], s["z"], s["spacing"][0], s["spacing"][1], s["spacing"][2], ) btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z]) return group, back_btn