"""MR-Brain workspace — image-only T1/T2/FLAIR/SWI brain MRI."""

from __future__ import annotations

import html
import logging
from typing import Any

import gradio as gr

from pipelines import GenerationRequest
from pipelines.mr_brain import generate as generate_mr_brain
from viewer.niivue_embed import empty_html, render_viewer

from .presets import XY_CHOICES, Z_CHOICES, MR_BRAIN_CONTRASTS, MR_BRAIN_SAMPLES

logger = logging.getLogger(__name__)

# NOTE(review): the element/class structure of the HTML fragments in this
# module is a reconstruction; only the visible text is known to be original.
# Confirm tags and class names against the app stylesheet.


def _kv(key: str, value: object) -> str:
    """Render one key/value pair for a status line, HTML-escaping both parts."""
    return f"<span>{html.escape(key)}</span><b>{html.escape(str(value))}</b>"


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the MR-Brain workspace UI and wire its events.

    Args:
        spaces_gpu: Decorator applied to the generation callback when truthy
            (presumably a Hugging Face Spaces GPU wrapper — confirm with the
            caller); pass a falsy value to run the callback undecorated.

    Returns:
        ``(group, back_btn)``: the workspace group, created with
        ``visible=False``, and its "← Back" button. Showing the group and
        handling the back navigation are left to the caller.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                "<div>"
                "<span>NV-Generate</span>"
                "<span>/</span>"
                "<span>MR Brain</span>"
                "</div>"
            )

        gr.HTML(
            """
<div>
<div>NV-Generate · MR Brain</div>
<p>Multi-sequence brain MRI generation across T1, T2, FLAIR, and SWI — in both whole-brain and skull-stripped forms. Trained on the open MR-RATE dataset. Useful for downstream tumor and lesion segmentation studies that need controlled, synthetic data.</p>
<dl>
<dt>Architecture</dt><dd>MAISI-v2 · Rectified Flow</dd>
<dt>Sequences</dt><dd>T1 · T2 · FLAIR · SWI</dd>
<dt>Variants</dt><dd>whole-brain · skull-stripped</dd>
<dt>Trained on</dt><dd>MR-RATE dataset (open)</dd>
<dt>Inference</dt><dd>30 steps</dd>
<dt>Max volume</dt><dd>512 × 512 × 256 vox</dd>
</dl>
</div>
"""
        )

        with gr.Row(elem_classes=["workspace-row"]):
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_BRAIN_SAMPLES]

                gr.Markdown("##### Contrast")
                contrast = gr.Radio(
                    choices=MR_BRAIN_CONTRASTS,
                    value="T1",
                    label="Sequence",
                    info="Skull-stripped variants supported.",
                )

                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=256, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Z (mm)")

                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")

                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML(
                    "<div>Idle. Configure parameters and click Generate.</div>",
                    elem_classes=["status"],
                )

            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    "<div>"
                    "<span>Viewport · Multiplanar</span>"
                    "<span>Axial · Coronal · Sagittal · 3D</span>"
                    "</div>"
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                legend = gr.HTML("", elem_classes=["legend-host"], visible=False)

    def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg):
        """Run one MR-Brain generation from the current widget values.

        Returns a 3-tuple matching ``outputs=[viewer, download, status]``:
        viewer HTML, an update for the download component, and status HTML.
        Any failure — including an invalid request — is reported in the
        viewer and status panes (and logged) rather than raised to Gradio.
        """
        try:
            output_size = (int(dim_xy), int(dim_xy), int(dim_z))
            req = GenerationRequest(
                model="mr_brain",
                output_size=output_size,
                spacing=(float(sp_x), float(sp_y), float(sp_z)),
                # gr.Number can deliver None when the field is cleared; fall
                # back to the widget's default (0) instead of crashing outside
                # the error handler.
                seed=0 if seed is None else int(seed),
                num_steps=int(steps),
                cfg_guidance_scale=float(cfg),
                contrast=contrast,
            )
            result = generate_mr_brain(req)
        except Exception as e:  # UI boundary: surface every failure to the user.
            logger.exception("MR-Brain generation failed")
            return (
                # NOTE(review): confirm empty_html escapes its message argument.
                empty_html(f"Generation failed: {e}"),
                gr.update(visible=False, value=None),
                # Escaped via _kv: exception text may contain '<' or '&'.
                f"<div><span>✕ Generation failed</span> {_kv('ERR', e)}</div>",
            )

        viewer_markup = render_viewer(volume_path=result.volume_path, colormap="gray")
        # Report the full X × Y × Z size; Z can differ from X/Y.
        size_label = " × ".join(str(n) for n in output_size)
        fields = (
            _kv("runtime", f"{result.runtime_seconds:.1f}s"),
            _kv("seed", result.seed),
            _kv("steps", req.num_steps),
            _kv("size", size_label),
        )
        stat = "<div><span>Generated</span>" + "".join(fields) + "</div>"
        return viewer_markup, gr.update(visible=True, value=result.volume_path), stat

    decorated = spaces_gpu(_generate) if spaces_gpu else _generate

    # Disable the button while running, generate, then re-enable it.
    # `.then` runs after the previous step completes, even if it errored.
    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            decorated,
            inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
            outputs=[viewer, download, status],
            show_progress="full",
        )
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    for btn, sample in zip(sample_btns, MR_BRAIN_SAMPLES):
        # Bind `sample` as a default argument so each button keeps its own
        # preset instead of all closures seeing the last loop value.
        def _apply(s=sample):
            return (
                s["contrast"],
                s["xy"],
                s["z"],
                s["spacing"][0],
                s["spacing"][1],
                s["spacing"][2],
            )

        btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn