nv-generate / ui /workspace_mr.py
zephyrie's picture
Move Quick Presets to the top of the controls sidebar
7fb7cc3
"""MR workspace — image-only multi-contrast MR generation."""
from __future__ import annotations

import html
from typing import Any

import gradio as gr

from pipelines import GenerationRequest
from pipelines.mr import generate as generate_mr
from viewer.niivue_embed import empty_html, render_viewer

from .presets import XY_CHOICES, Z_CHOICES, MR_CONTRAST_CHOICES, MR_SAMPLES
# Labels offered by the "Contrast & anatomy" dropdown: element 0 of every
# MR_CONTRAST_CHOICES entry, in declaration order.
CONTRAST_LABELS = [choice[0] for choice in MR_CONTRAST_CHOICES]

# Maps a dropdown label (element 0) to the value stored at element 1 of the
# same entry; build() forwards that value as the request's modality_class.
LABEL_TO_MODALITY = {choice[0]: choice[1] for choice in MR_CONTRAST_CHOICES}
def _failure_status_html(exc: Exception) -> str:
    """Return the status-line markup that reports a failed generation.

    The exception text is HTML-escaped because it is interpolated into markup
    shown by a ``gr.HTML`` component; messages containing ``<``, ``>`` or ``&``
    would otherwise corrupt the status line.
    """
    detail = html.escape(str(exc))
    return (
        '<div class="stat-line"><span class="stat-err">✕ Generation failed</span> '
        f'<span class="stat-chip"><span class="stat-k">ERR</span><span class="stat-v">{detail}</span></span></div>'
    )


def _success_status_html(
    runtime_seconds: float,
    seed: Any,
    num_steps: int,
    output_size: tuple[int, int, int],
) -> str:
    """Return the status-line markup that summarises a finished generation.

    All three output dimensions are shown, since X/Y and Z are chosen
    independently in the UI and the volume is generally not a cube.
    """
    size_x, size_y, size_z = output_size
    chips = (
        ("runtime", f"{runtime_seconds:.1f}s"),
        ("seed", seed),
        ("steps", num_steps),
        ("size", f"{size_x} × {size_y} × {size_z}"),
    )
    chip_markup = "".join(
        f'<span class="stat-chip"><span class="stat-k">{key}</span><span class="stat-v">{value}</span></span>'
        for key, value in chips
    )
    return (
        '<div class="stat-line">'
        '<span class="stat-mark"></span>'
        '<span class="stat-label">Generated</span>'
        f'{chip_markup}'
        '</div>'
    )


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the (initially hidden) MR workspace and wire up its events.

    Args:
        spaces_gpu: Optional decorator. When truthy it is applied to the
            generation callback before the callback is bound to the
            "Generate volume" button; when falsy the callback is used as-is.

    Returns:
        A ``(group, back_btn)`` pair: the workspace's top-level ``gr.Group``
        and its "← Back" button, so the caller can attach navigation.

    NOTE(review): the container nesting below was reconstructed from a view of
    the source with indentation stripped. Confirm component placement against
    the rendered UI — in particular whether the CFG slider belongs inside the
    "Diffusion" row next to the seed and steps controls.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--mr);color:var(--mr)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">MR</span>'
                '</div>'
            )

        gr.HTML(
            """
            <div class="ws-intro ws-intro-mr">
              <div class="ws-intro-left">
                <h2 class="ws-intro-title">NV-Generate · MR</h2>
                <p class="ws-intro-desc">
                  Multi-contrast MRI across brain, prostate, breast, and abdominal anatomy.
                  Drive contrast through a modality embedding — T1, T2, FLAIR — at
                  variable resolution and voxel spacing. Fine-tune on your own MRI data
                  to extend to new modalities and regions.
                </p>
              </div>
              <div class="ws-intro-facts">
                <div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Contrasts</span><span class="ws-fact-v">T1 · T2 · FLAIR</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Regions</span><span class="ws-fact-v">brain · prostate · breast · abdomen</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 steps</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 128 vox</span></div>
                <div class="ws-fact"><span class="ws-fact-k">License</span><span class="ws-fact-v ws-fact-warn">NVIDIA Non-Commercial</span></div>
              </div>
            </div>
            """
        )
        gr.HTML(
            '<div class="license-banner">'
            '<strong>Non-commercial license.</strong> '
            'NV-Generate-MR weights are released under the <a href="https://developer.download.nvidia.com/licenses/NVIDIA-OneWay-Noncommercial-License-22Mar2022.pdf" target="_blank">NVIDIA OneWay Non-Commercial License</a>. '
            'Use is permitted for academic research only.'
            '</div>'
        )

        with gr.Row(elem_classes=["workspace-row"]):
            # Left column: presets, conditioning, geometry and diffusion controls.
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_SAMPLES]

                gr.Markdown("##### Conditioning")
                contrast = gr.Dropdown(
                    choices=CONTRAST_LABELS,
                    value="T2 prostate",
                    label="Contrast & anatomy",
                    info="Drives the modality embedding.",
                )

                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=128, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)")

                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")

                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>', elem_classes=["status"])

            # Right column: viewer, download slot and (unused) legend slot.
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                # MR has no mask, but keep a legend slot so all workspaces share structure
                legend = gr.HTML("", elem_classes=["legend-host"], visible=False)

    def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg):
        """Run one MR generation and return ``(viewer, download, status)`` updates.

        Any failure — including a control value that cannot be converted, such
        as a cleared seed field — is reported in the viewer and status line
        rather than raised.
        """
        try:
            # Build the request inside the ``try`` so conversion errors reach
            # the user through the same path as generation errors.
            req = GenerationRequest(
                model="mr",
                output_size=(int(dim_xy), int(dim_xy), int(dim_z)),
                spacing=(float(sp_x), float(sp_y), float(sp_z)),
                seed=int(seed),
                num_steps=int(steps),
                cfg_guidance_scale=float(cfg),
                # NOTE(review): unknown labels fall back to modality class 9;
                # the meaning of 9 is not visible in this module.
                modality_class=LABEL_TO_MODALITY.get(contrast, 9),
            )
            result = generate_mr(req)
        except Exception as e:  # UI boundary: show the error instead of raising
            return (
                # NOTE(review): empty_html receives the raw message; confirm
                # that it escapes its argument before embedding it in markup.
                empty_html(f"Generation failed: {e}"),
                gr.update(visible=False, value=None),
                _failure_status_html(e),
            )

        viewer_markup = render_viewer(volume_path=result.volume_path, colormap="gray")
        summary = _success_status_html(
            result.runtime_seconds,
            result.seed,
            req.num_steps,
            req.output_size,
        )
        return viewer_markup, gr.update(visible=True, value=result.volume_path), summary

    decorated = spaces_gpu(_generate) if spaces_gpu else _generate

    # Disable the button while generating, run the job, then restore the button.
    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            decorated,
            inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
            outputs=[viewer, download, status],
            show_progress="full",
        )
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    for btn, sample in zip(sample_btns, MR_SAMPLES):
        # Bind the sample as a default argument so each button keeps its own
        # preset instead of the loop's final value.
        def _apply(s=sample):
            return (
                s["modality_label"],
                s["xy"], s["z"],
                s["spacing"][0], s["spacing"][1], s["spacing"][2],
            )

        btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn