Spaces:
Running on Zero
Running on Zero
| """MR-Brain workspace — image-only T1/T2/FLAIR/SWI brain MRI.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| import gradio as gr | |
| from pipelines import GenerationRequest | |
| from pipelines.mr_brain import generate as generate_mr_brain | |
| from viewer.niivue_embed import empty_html, render_viewer | |
| from .presets import XY_CHOICES, Z_CHOICES, MR_BRAIN_CONTRASTS, MR_BRAIN_SAMPLES | |
def _format_size(output_size: tuple[int, int, int]) -> str:
    """Return a human-readable voxel grid size for the status bar.

    Cubic grids keep the compact ``"N³"`` form; anisotropic grids (X/Y and Z
    are chosen independently in the UI) are shown as ``"X×Y×Z"`` so the label
    never claims a cube that was not generated.
    """
    x, y, z = output_size
    if x == y == z:
        return f"{x}³"
    return f"{x}×{y}×{z}"


def _stat_chip(key: str, value: object) -> str:
    """Return one ``stat-chip`` key/value span; *value* is HTML-escaped."""
    from html import escape  # function-scope import keeps this block self-contained

    return (
        f'<span class="stat-chip"><span class="stat-k">{key}</span>'
        f'<span class="stat-v">{escape(str(value))}</span></span>'
    )


def _error_status(message: str) -> str:
    """Return the status-bar HTML shown when generation fails.

    *message* usually comes from an exception and may contain ``<`` or ``&``;
    it is escaped so it renders as text rather than being parsed as markup.
    """
    return (
        '<div class="stat-line"><span class="stat-err">✕ Generation failed</span> '
        + _stat_chip("ERR", message)
        + '</div>'
    )


def _success_status(
    runtime_seconds: float,
    seed: object,
    num_steps: int,
    output_size: tuple[int, int, int],
) -> str:
    """Return the status-bar HTML summarising a successful generation."""
    chips = "".join(
        (
            _stat_chip("runtime", f"{runtime_seconds:.1f}s"),
            _stat_chip("seed", seed),
            _stat_chip("steps", num_steps),
            _stat_chip("size", _format_size(output_size)),
        )
    )
    return (
        '<div class="stat-line">'
        '<span class="stat-mark"></span>'
        '<span class="stat-label">Generated</span>'
        + chips
        + '</div>'
    )


def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg):
    """Run one MR-Brain generation for the values of the workspace widgets.

    Arguments arrive in the order of the ``inputs=`` list wired in ``build``.

    Returns:
        ``(viewer_html, download_update, status_html)``. Any failure — including
        converting the raw widget values — is reported in the viewer and the
        status bar instead of propagating to Gradio.
    """
    try:
        output_size = (int(dim_xy), int(dim_xy), int(dim_z))
        req = GenerationRequest(
            model="mr_brain",
            output_size=output_size,
            spacing=(float(sp_x), float(sp_y), float(sp_z)),
            # gr.Number can deliver None when the field is cleared; fall back to
            # the widget's default seed (0) rather than crashing in int().
            seed=int(seed) if seed is not None else 0,
            num_steps=int(steps),
            cfg_guidance_scale=float(cfg),
            contrast=contrast,
        )
        result = generate_mr_brain(req)
    except Exception as e:  # UI boundary: surface every failure to the user
        return (
            empty_html(f"Generation failed: {e}"),
            gr.update(visible=False, value=None),
            _error_status(str(e)),
        )

    viewer_html = render_viewer(volume_path=result.volume_path, colormap="gray")
    status_html = _success_status(
        result.runtime_seconds, result.seed, req.num_steps, output_size
    )
    return viewer_html, gr.update(visible=True, value=result.volume_path), status_html


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the MR-Brain workspace UI and wire its event handlers.

    The workspace group is created hidden (``visible=False``).

    Args:
        spaces_gpu: Decorator applied to the generation handler (presumably
            ``spaces.GPU`` on ZeroGPU — confirm at the call site). Pass a falsy
            value to run the handler undecorated.

    Returns:
        ``(group, back_btn)`` — the workspace container and its "← Back" button.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--mrb);color:var(--mrb)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">MR Brain</span>'
                '</div>'
            )
        gr.HTML(
            """
<div class="ws-intro ws-intro-mrb">
  <div class="ws-intro-left">
    <h2 class="ws-intro-title">NV-Generate · MR Brain</h2>
    <p class="ws-intro-desc">
      Multi-sequence brain MRI generation across T1, T2, FLAIR, and SWI —
      in both whole-brain and skull-stripped forms. Trained on the open
      MR-RATE dataset. Useful for downstream tumor and lesion segmentation
      studies that need controlled, synthetic data.
    </p>
  </div>
  <div class="ws-intro-facts">
    <div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
    <div class="ws-fact"><span class="ws-fact-k">Sequences</span><span class="ws-fact-v">T1 · T2 · FLAIR · SWI</span></div>
    <div class="ws-fact"><span class="ws-fact-k">Variants</span><span class="ws-fact-v">whole-brain · skull-stripped</span></div>
    <div class="ws-fact"><span class="ws-fact-k">Trained on</span><span class="ws-fact-v">MR-RATE dataset (open)</span></div>
    <div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 steps</span></div>
    <div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 256 vox</span></div>
  </div>
</div>
"""
        )
        with gr.Row(elem_classes=["workspace-row"]):
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_BRAIN_SAMPLES]
                gr.Markdown("##### Contrast")
                contrast = gr.Radio(
                    choices=MR_BRAIN_CONTRASTS,
                    value="T1",
                    label="Sequence",
                    info="Skull-stripped variants supported.",
                )
                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=256, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Z (mm)")
                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")
                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML(
                    '<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>',
                    elem_classes=["status"],
                )
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                # Hidden legend host; no handler in this function updates it.
                gr.HTML("", elem_classes=["legend-host"], visible=False)

    handler = spaces_gpu(_generate) if spaces_gpu else _generate
    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            handler,
            inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
            outputs=[viewer, download, status],
            show_progress="full",
        )
        # .then (not .success) so the button is re-enabled even if the handler raises.
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )

    for btn, sample in zip(sample_btns, MR_BRAIN_SAMPLES):
        # Bind the preset through a default argument; a plain closure would
        # late-bind and every button would apply the last preset.
        def _apply(s=sample):
            return (
                s["contrast"],
                s["xy"], s["z"],
                s["spacing"][0], s["spacing"][1], s["spacing"][2],
            )

        btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn