# nv-generate / ui / workspace_mr_brain.py
# Last change (zephyrie, 7fb7cc3): Move Quick Presets to the top of the controls sidebar
"""MR-Brain workspace — image-only T1/T2/FLAIR/SWI brain MRI."""
from __future__ import annotations

import html
import logging
from typing import Any

import gradio as gr

from pipelines import GenerationRequest
from pipelines.mr_brain import generate as generate_mr_brain
from viewer.niivue_embed import empty_html, render_viewer

from .presets import XY_CHOICES, Z_CHOICES, MR_BRAIN_CONTRASTS, MR_BRAIN_SAMPLES
def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the MR-Brain workspace: controls sidebar, viewer, and event wiring.

    Args:
        spaces_gpu: Decorator applied to the generation callback (e.g. a GPU
            allocator for hosted Spaces). When falsy, the callback is used
            undecorated.

    Returns:
        ``(group, back_btn)``: the workspace container, created hidden
        (``visible=False``), and its "← Back" button. The back button is not
        bound to any event here, so the caller is expected to wire it.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--mrb);color:var(--mrb)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">MR Brain</span>'
                '</div>'
            )
        gr.HTML(
            """
<div class="ws-intro ws-intro-mrb">
<div class="ws-intro-left">
<h2 class="ws-intro-title">NV-Generate · MR Brain</h2>
<p class="ws-intro-desc">
Multi-sequence brain MRI generation across T1, T2, FLAIR, and SWI —
in both whole-brain and skull-stripped forms. Trained on the open
MR-RATE dataset. Useful for downstream tumor and lesion segmentation
studies that need controlled, synthetic data.
</p>
</div>
<div class="ws-intro-facts">
<div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
<div class="ws-fact"><span class="ws-fact-k">Sequences</span><span class="ws-fact-v">T1 · T2 · FLAIR · SWI</span></div>
<div class="ws-fact"><span class="ws-fact-k">Variants</span><span class="ws-fact-v">whole-brain · skull-stripped</span></div>
<div class="ws-fact"><span class="ws-fact-k">Trained on</span><span class="ws-fact-v">MR-RATE dataset (open)</span></div>
<div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 steps</span></div>
<div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 256 vox</span></div>
</div>
</div>
"""
        )
        with gr.Row(elem_classes=["workspace-row"]):
            # Left column: generation controls.
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    # One button per preset; click handlers are bound at the end of build().
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_BRAIN_SAMPLES]
                gr.Markdown("##### Contrast")
                contrast = gr.Radio(
                    choices=MR_BRAIN_CONTRASTS,
                    value="T1",
                    label="Sequence",
                    info="Skull-stripped variants supported.",
                )
                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=256, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Z (mm)")
                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                    cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")
                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>', elem_classes=["status"])
            # Right column: multiplanar viewer and NIfTI download.
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                # Legend host is created hidden; no event in this function updates it.
                gr.HTML("", elem_classes=["legend-host"], visible=False)

    def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg):
        """Run one MR-Brain generation.

        Returns a ``(viewer_html, download_update, status_html)`` triple
        matching the ``outputs=[viewer, download, status]`` wiring below. Any
        failure is reported in the viewer and the status line instead of
        propagating. This covers request construction (e.g. a cleared seed
        field making ``int(seed)`` fail), generation, and viewer rendering.
        """
        try:
            req = GenerationRequest(
                model="mr_brain",
                # X and Y share one control; Z is chosen independently.
                output_size=(int(dim_xy), int(dim_xy), int(dim_z)),
                spacing=(float(sp_x), float(sp_y), float(sp_z)),
                seed=int(seed),
                num_steps=int(steps),
                cfg_guidance_scale=float(cfg),
                contrast=contrast,
            )
            result = generate_mr_brain(req)
            viewer_html = render_viewer(volume_path=result.volume_path, colormap="gray")
        except Exception as e:  # UI boundary: surface every failure to the user.
            logging.getLogger(__name__).exception("MR-Brain generation failed")
            # The message is embedded in raw HTML, so escape it to keep the markup intact.
            err_text = html.escape(str(e))
            return (
                empty_html(f"Generation failed: {e}"),
                gr.update(visible=False, value=None),
                f'<div class="stat-line"><span class="stat-err">✕ Generation failed</span> <span class="stat-chip"><span class="stat-k">ERR</span><span class="stat-v">{err_text}</span></span></div>',
            )
        # Show the full X×Y×Z size: Z is independent of X/Y, so the volume is not necessarily a cube.
        size_text = "×".join(str(n) for n in req.output_size)
        stat = (
            '<div class="stat-line">'
            '<span class="stat-mark"></span>'
            '<span class="stat-label">Generated</span>'
            f'<span class="stat-chip"><span class="stat-k">runtime</span><span class="stat-v">{result.runtime_seconds:.1f}s</span></span>'
            f'<span class="stat-chip"><span class="stat-k">seed</span><span class="stat-v">{result.seed}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">steps</span><span class="stat-v">{req.num_steps}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">size</span><span class="stat-v">{size_text}</span></span>'
            '</div>'
        )
        return viewer_html, gr.update(visible=True, value=result.volume_path), stat

    decorated = spaces_gpu(_generate) if spaces_gpu else _generate
    # Disable the button while generating, run the job, then restore the button.
    # NOTE(review): relies on Gradio's `.then` running even if the previous
    # step raised; confirm against the Gradio version in use.
    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            decorated,
            inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
            outputs=[viewer, download, status],
            show_progress="full",
        )
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    for btn, sample in zip(sample_btns, MR_BRAIN_SAMPLES):
        # `s=sample` binds this iteration's preset at definition time; a plain
        # closure over `sample` would make every button apply the last preset.
        def _apply(s=sample):
            return (
                s["contrast"],
                s["xy"], s["z"],
                s["spacing"][0], s["spacing"][1], s["spacing"][2],
            )
        btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])
    return group, back_btn