"""MR workspace — image-only multi-contrast MR generation."""
from __future__ import annotations

import logging
from html import escape
from typing import Any

import gradio as gr

from pipelines import GenerationRequest
from pipelines.mr import generate as generate_mr
from viewer.niivue_embed import empty_html, render_viewer

from .presets import MR_CONTRAST_CHOICES, MR_SAMPLES, XY_CHOICES, Z_CHOICES
# Labels offered by the "Contrast & anatomy" dropdown, in preset order
# (first element of each MR_CONTRAST_CHOICES entry).
CONTRAST_LABELS = [label for label, *_ in MR_CONTRAST_CHOICES]
# Dropdown label -> value passed to GenerationRequest(modality_class=...)
# (second element of each MR_CONTRAST_CHOICES entry).
LABEL_TO_MODALITY = {label: modality for label, modality, *_ in MR_CONTRAST_CHOICES}
_LOGGER = logging.getLogger(__name__)


def _format_size(size: tuple[int, int, int]) -> str:
    """Render an (X, Y, Z) voxel grid as ``N³`` for cubes, else ``X×Y×Z``."""
    x, y, z = size
    if x == y == z:
        return f"{x}³"
    return f"{x}×{y}×{z}"


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the MR workspace panel and wire its event handlers.

    The panel is created hidden (``visible=False``). Showing it and handling
    clicks on the returned Back button are left to the caller; neither
    happens here.

    Args:
        spaces_gpu: Optional decorator applied to the generation callback
            (presumably ``spaces.GPU`` on Hugging Face Spaces — confirm at the
            call site). When falsy, the callback is registered undecorated.

    Returns:
        ``(group, back_btn)``: the workspace container and its Back button.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '\n'
                ''
                'NV-Generate'
                '/'
                'MR'
                '\n'
            )
        gr.HTML(
            """
NV-Generate · MR
Multi-contrast MRI across brain, prostate, breast, and abdominal anatomy.
Drive contrast through a modality embedding — T1, T2, FLAIR — at
variable resolution and voxel spacing. Fine-tune on your own MRI data
to extend to new modalities and regions.
ArchitectureMAISI-v2 · Rectified Flow
ContrastsT1 · T2 · FLAIR
Regionsbrain · prostate · breast · abdomen
Inference30 steps
Max volume512 × 512 × 128 vox
LicenseNVIDIA Non-Commercial
"""
        )
        gr.HTML(
            ''
        )
        with gr.Row(elem_classes=["workspace-row"]):
            # Left column: generation controls.
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_SAMPLES]
                gr.Markdown("##### Conditioning")
                contrast = gr.Dropdown(
                    choices=CONTRAST_LABELS,
                    value="T2 prostate",
                    label="Contrast & anatomy",
                    info="Drives the modality embedding.",
                )
                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=128, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)")
                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                    cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")
                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('Idle. Configure parameters and click Generate.\n', elem_classes=["status"])
            # Right column: volume viewer and download.
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    ''
                    'Viewport · Multiplanar'
                    'Axial · Coronal · Sagittal · 3D'
                    '\n'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                # MR has no mask, but keep a legend slot so all workspaces share structure
                legend = gr.HTML("", elem_classes=["legend-host"], visible=False)

    def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg):
        """Run one MR generation from the current widget values.

        Returns:
            ``(viewer_html, download_update, status_html)``, matching the
            ``outputs=[viewer, download, status]`` wiring below. Any failure
            is reported through the same three outputs instead of raising.
        """
        try:
            # Build the request inside the try so coercion errors from widget
            # values (e.g. ``int(None)``) surface in the status panel rather
            # than as an unhandled exception.
            req = GenerationRequest(
                model="mr",
                # X and Y share one control; Z is independent.
                output_size=(int(dim_xy), int(dim_xy), int(dim_z)),
                spacing=(float(sp_x), float(sp_y), float(sp_z)),
                seed=int(seed),
                num_steps=int(steps),
                cfg_guidance_scale=float(cfg),
                # NOTE(review): unknown labels silently fall back to modality
                # class 9; confirm that default is intended.
                modality_class=LABEL_TO_MODALITY.get(contrast, 9),
            )
            result = generate_mr(req)
        except Exception as e:  # UI boundary: report to the user, keep the app alive
            _LOGGER.exception("MR generation failed")
            return (
                # NOTE(review): confirm empty_html escapes its message.
                empty_html(f"Generation failed: {e}"),
                gr.update(visible=False, value=None),
                # gr.HTML renders markup, so escape the exception text.
                f'✕ Generation failed ERR{escape(str(e))}\n',
            )
        viewer_html = render_viewer(volume_path=result.volume_path, colormap="gray")
        stat = (
            ''
            ''
            'Generated'
            f'runtime{result.runtime_seconds:.1f}s'
            f'seed{result.seed}'
            f'steps{req.num_steps}'
            # Report the real grid; X/Y and Z may differ (e.g. 256×256×128).
            f'size{_format_size(req.output_size)}'
            '\n'
        )
        return viewer_html, gr.update(visible=True, value=result.volume_path), stat

    decorated = spaces_gpu(_generate) if spaces_gpu else _generate
    # Disable the button, run generation, then restore the button. `.then`
    # (unlike `.success`) fires even if the previous step raised, so the
    # button is always re-enabled.
    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            decorated,
            inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
            outputs=[viewer, download, status],
            show_progress="full",
        )
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    for btn, sample in zip(sample_btns, MR_SAMPLES):
        # Bind `sample` as a default argument; a plain closure would late-bind
        # and every button would apply the last preset.
        def _apply(s=sample):
            return (
                s["modality_label"],
                s["xy"], s["z"],
                s["spacing"][0], s["spacing"][1], s["spacing"][2],
            )

        btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn