Spaces:
Running on Zero
Running on Zero
File size: 8,554 Bytes
"""MR workspace — image-only multi-contrast MR generation."""
from __future__ import annotations

import html
import logging
from typing import Any

import gradio as gr

from pipelines import GenerationRequest
from pipelines.mr import generate as generate_mr
from viewer.niivue_embed import empty_html, render_viewer

from .presets import XY_CHOICES, Z_CHOICES, MR_CONTRAST_CHOICES, MR_SAMPLES
# Dropdown labels, in the order MR_CONTRAST_CHOICES lists them. Each entry is
# indexed as entry[0] = display label, entry[1] = modality class.
CONTRAST_LABELS = [entry[0] for entry in MR_CONTRAST_CHOICES]
# Display label -> modality class handed to the generator (a repeated label
# keeps the value from its last occurrence).
LABEL_TO_MODALITY = {entry[0]: entry[1] for entry in MR_CONTRAST_CHOICES}
# Intro card shown at the top of the MR workspace. Kept at module level so the
# (whitespace-insensitive) HTML needs no extra indentation.
_MR_INTRO_HTML = """
<div class="ws-intro ws-intro-mr">
<div class="ws-intro-left">
<h2 class="ws-intro-title">NV-Generate · MR</h2>
<p class="ws-intro-desc">
Multi-contrast MRI across brain, prostate, breast, and abdominal anatomy.
Drive contrast through a modality embedding — T1, T2, FLAIR — at
variable resolution and voxel spacing. Fine-tune on your own MRI data
to extend to new modalities and regions.
</p>
</div>
<div class="ws-intro-facts">
<div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
<div class="ws-fact"><span class="ws-fact-k">Contrasts</span><span class="ws-fact-v">T1 · T2 · FLAIR</span></div>
<div class="ws-fact"><span class="ws-fact-k">Regions</span><span class="ws-fact-v">brain · prostate · breast · abdomen</span></div>
<div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 steps</span></div>
<div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 128 vox</span></div>
<div class="ws-fact"><span class="ws-fact-k">License</span><span class="ws-fact-v ws-fact-warn">NVIDIA Non-Commercial</span></div>
</div>
</div>
"""


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the MR workspace UI and wire its generation and preset events.

    Args:
        spaces_gpu: Decorator applied to the generation callback when truthy
            (presumably a ZeroGPU ``spaces.GPU`` wrapper — confirm at the
            call site); when falsy, the callback is used undecorated.

    Returns:
        ``(group, back_btn)``: the workspace container, created with
        ``visible=False``, and its "← Back" button so the caller can wire
        navigation between workspaces.
    """
    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--mr);color:var(--mr)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">MR</span>'
                '</div>'
            )
        gr.HTML(_MR_INTRO_HTML)
        gr.HTML(
            '<div class="license-banner">'
            '<strong>Non-commercial license.</strong> '
            'NV-Generate-MR weights are released under the <a href="https://developer.download.nvidia.com/licenses/NVIDIA-OneWay-Noncommercial-License-22Mar2022.pdf" target="_blank">NVIDIA OneWay Non-Commercial License</a>. '
            'Use is permitted for academic research only.'
            '</div>'
        )
        with gr.Row(elem_classes=["workspace-row"]):
            # Left column: presets and generation parameters.
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_SAMPLES]
                gr.Markdown("##### Conditioning")
                contrast = gr.Dropdown(
                    choices=CONTRAST_LABELS,
                    value="T2 prostate",
                    label="Contrast & anatomy",
                    info="Drives the modality embedding.",
                )
                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=128, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)")
                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                    cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")
                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>', elem_classes=["status"])
            # Right column: volume viewer and download.
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                # MR has no mask, but keep a legend slot so all workspaces share structure
                gr.HTML("", elem_classes=["legend-host"], visible=False)

    def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg):
        """Run one MR generation from the UI values.

        Returns:
            ``(viewer_html, download_update, status_html)`` matching the
            ``[viewer, download, status]`` outputs. On failure the viewer shows
            an empty placeholder, the download is hidden, and the status shows
            the (HTML-escaped) error.
        """
        try:
            # Build the request inside the try so invalid UI values (e.g. a
            # cleared seed field -> int(None)) are reported in the status line.
            req = GenerationRequest(
                model="mr",
                output_size=(int(dim_xy), int(dim_xy), int(dim_z)),
                spacing=(float(sp_x), float(sp_y), float(sp_z)),
                seed=int(seed),
                num_steps=int(steps),
                cfg_guidance_scale=float(cfg),
                # Labels not in LABEL_TO_MODALITY fall back to modality class 9.
                # NOTE(review): confirm what class 9 denotes in pipelines.mr.
                modality_class=LABEL_TO_MODALITY.get(contrast, 9),
            )
            result = generate_mr(req)
        except Exception as e:
            # UI boundary: keep the app alive, but record the traceback server-side.
            logging.getLogger(__name__).exception("MR generation failed")
            return (
                # NOTE(review): confirm empty_html escapes its message argument.
                empty_html(f"Generation failed: {e}"),
                gr.update(visible=False, value=None),
                '<div class="stat-line"><span class="stat-err">✕ Generation failed</span> '
                f'<span class="stat-chip"><span class="stat-k">ERR</span><span class="stat-v">{html.escape(str(e))}</span></span></div>',
            )
        viewer_html = render_viewer(volume_path=result.volume_path, colormap="gray")
        # The volume is X×Y×Z (Z chosen independently of X/Y), not a cube.
        size_label = "×".join(str(extent) for extent in req.output_size)
        stat = (
            '<div class="stat-line">'
            '<span class="stat-mark"></span>'
            '<span class="stat-label">Generated</span>'
            f'<span class="stat-chip"><span class="stat-k">runtime</span><span class="stat-v">{result.runtime_seconds:.1f}s</span></span>'
            f'<span class="stat-chip"><span class="stat-k">seed</span><span class="stat-v">{result.seed}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">steps</span><span class="stat-v">{req.num_steps}</span></span>'
            f'<span class="stat-chip"><span class="stat-k">size</span><span class="stat-v">{size_label}</span></span>'
            '</div>'
        )
        return viewer_html, gr.update(visible=True, value=result.volume_path), stat

    decorated = spaces_gpu(_generate) if spaces_gpu else _generate
    # Disable the button while generating, run the (possibly GPU-decorated)
    # callback, then restore the button.
    (
        generate_btn.click(
            lambda: gr.update(value="Generating volume…", interactive=False),
            outputs=[generate_btn],
        )
        .then(
            decorated,
            inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
            outputs=[viewer, download, status],
            show_progress="full",
        )
        .then(
            lambda: gr.update(value="Generate volume", interactive=True),
            outputs=[generate_btn],
        )
    )
    for btn, sample in zip(sample_btns, MR_SAMPLES):
        # `s=sample` binds the current preset now; a plain closure would see
        # only the last loop value.
        def _apply(s=sample):
            """Return the preset's values in the order of the outputs below."""
            return (
                s["modality_label"],
                s["xy"], s["z"],
                s["spacing"][0], s["spacing"][1], s["spacing"][2],
            )
        btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])
    return group, back_btn
|