File size: 8,554 Bytes
ab1db83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e66cee
ab1db83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fb7cc3
 
 
 
ab1db83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3188a6e
ab1db83
407e4a9
ab1db83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""MR workspace — image-only multi-contrast MR generation."""
from __future__ import annotations

from typing import Any

import gradio as gr

from pipelines import GenerationRequest
from pipelines.mr import generate as generate_mr
from viewer.niivue_embed import empty_html, render_viewer
from .presets import XY_CHOICES, Z_CHOICES, MR_CONTRAST_CHOICES, MR_SAMPLES


# Each entry of MR_CONTRAST_CHOICES is indexed as (label, modality): position 0
# is the text shown in the dropdown, position 1 is the value handed to the
# generation request as ``modality_class``.

# Dropdown labels, in the order the presets module declares them.
CONTRAST_LABELS = [choice[0] for choice in MR_CONTRAST_CHOICES]
# Reverse lookup used when a generation is requested: label -> modality value.
LABEL_TO_MODALITY = {choice[0]: choice[1] for choice in MR_CONTRAST_CHOICES}


def build(spaces_gpu: Any) -> tuple[gr.Group, gr.Button]:
    """Build the MR workspace UI and wire up its event handlers.

    The workspace is created hidden (``visible=False``); the caller decides
    when to show it and what the returned back button does.

    Args:
        spaces_gpu: Optional decorator-like callable. When truthy it is applied
            to the generation handler before the handler is registered; when
            falsy the handler is registered undecorated.

    Returns:
        A ``(group, back_btn)`` tuple: the top-level workspace group and the
        "← Back" button from the workspace header.
    """
    # Local stdlib import so this function is self-contained; used to escape
    # exception text before it is embedded in status HTML.
    from html import escape

    with gr.Group(visible=False, elem_classes=["workspace"]) as group:
        # --- Header: back button + breadcrumb ------------------------------
        with gr.Row(elem_classes=["workspace-header"]):
            back_btn = gr.Button("← Back", elem_classes=["back-btn"], scale=0)
            gr.HTML(
                '<div class="workspace-title">'
                '<span class="ws-dot" style="background:var(--mr);color:var(--mr)"></span>'
                '<span class="ws-crumb">NV-Generate</span>'
                '<span class="ws-crumb-sep">/</span>'
                '<span class="ws-active">MR</span>'
                '</div>'
            )
        # --- Intro panel: description + fact sheet -------------------------
        gr.HTML(
            """
            <div class="ws-intro ws-intro-mr">
              <div class="ws-intro-left">
                <h2 class="ws-intro-title">NV-Generate · MR</h2>
                <p class="ws-intro-desc">
                  Multi-contrast MRI across brain, prostate, breast, and abdominal anatomy.
                  Drive contrast through a modality embedding — T1, T2, FLAIR — at
                  variable resolution and voxel spacing. Fine-tune on your own MRI data
                  to extend to new modalities and regions.
                </p>
              </div>
              <div class="ws-intro-facts">
                <div class="ws-fact"><span class="ws-fact-k">Architecture</span><span class="ws-fact-v">MAISI-v2 · Rectified Flow</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Contrasts</span><span class="ws-fact-v">T1 · T2 · FLAIR</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Regions</span><span class="ws-fact-v">brain · prostate · breast · abdomen</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Inference</span><span class="ws-fact-v">30 steps</span></div>
                <div class="ws-fact"><span class="ws-fact-k">Max volume</span><span class="ws-fact-v">512 × 512 × 128 vox</span></div>
                <div class="ws-fact"><span class="ws-fact-k">License</span><span class="ws-fact-v ws-fact-warn">NVIDIA Non-Commercial</span></div>
              </div>
            </div>
            """
        )

        # --- License notice ------------------------------------------------
        gr.HTML(
            '<div class="license-banner">'
            '<strong>Non-commercial license.</strong> '
            'NV-Generate-MR weights are released under the <a href="https://developer.download.nvidia.com/licenses/NVIDIA-OneWay-Noncommercial-License-22Mar2022.pdf" target="_blank">NVIDIA OneWay Non-Commercial License</a>. '
            'Use is permitted for academic research only.'
            '</div>'
        )

        with gr.Row(elem_classes=["workspace-row"]):
            # --- Left column: controls -------------------------------------
            with gr.Column(scale=4, min_width=320, elem_classes=["controls"]):
                gr.Markdown("##### Quick presets")
                with gr.Row():
                    # One button per preset, in MR_SAMPLES order; they are
                    # paired with their preset again when handlers are wired.
                    sample_btns = [gr.Button(s["label"], size="sm") for s in MR_SAMPLES]

                gr.Markdown("##### Conditioning")
                contrast = gr.Dropdown(
                    choices=CONTRAST_LABELS,
                    value="T2 prostate",
                    label="Contrast & anatomy",
                    info="Drives the modality embedding.",
                )

                gr.Markdown("##### Geometry")
                dim_xy = gr.Radio(choices=XY_CHOICES, value=256, label="X / Y (voxels)")
                dim_z = gr.Radio(choices=Z_CHOICES, value=128, label="Z (voxels)")
                with gr.Row(equal_height=True):
                    sp_x = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing X (mm)")
                    sp_y = gr.Slider(0.5, 5.0, value=1.0, step=0.05, label="Spacing Y (mm)")
                    sp_z = gr.Slider(0.5, 5.0, value=1.5, step=0.05, label="Spacing Z (mm)")

                gr.Markdown("##### Diffusion")
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=0, label="Seed", precision=0, elem_classes=["seed-field"])
                    steps = gr.Slider(10, 60, value=30, step=1, label="Inference steps")
                    cfg = gr.Slider(0.0, 20.0, value=10.0, step=0.5, label="CFG guidance")

                generate_btn = gr.Button("Generate volume", variant="primary", elem_classes=["primary-cta"])
                status = gr.HTML('<div class="stat-line"><span class="stat-label" style="color:var(--muted)">Idle. Configure parameters and click Generate.</span></div>', elem_classes=["status"])

            # --- Right column: viewer + download ---------------------------
            with gr.Column(scale=8, min_width=520, elem_classes=["viewer-col"]):
                gr.HTML(
                    '<div class="viewer-strip">'
                    '<span class="viewer-strip-left">Viewport · Multiplanar</span>'
                    '<span class="viewer-strip-right">Axial · Coronal · Sagittal · 3D</span>'
                    '</div>'
                )
                viewer = gr.HTML(empty_html(), elem_classes=["viewer"])
                download = gr.File(label="Download generated NIfTI", visible=False, elem_classes=["nv-download"])
                # MR has no mask, but keep a legend slot so all workspaces share structure
                legend = gr.HTML("", elem_classes=["legend-host"], visible=False)

        def _generate(contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg):
            """Run one MR generation and return updates for the output widgets.

            Returns a 3-tuple matching ``outputs=[viewer, download, status]``:
            viewer HTML, an update for the download file component, and the
            status-line HTML. Failures while building the request or running
            the pipeline are reported through the same three outputs.
            """
            try:
                # Request construction lives inside the ``try`` so that bad
                # widget values (e.g. a cleared Seed field arriving as None)
                # surface as a failure status instead of an unhandled error.
                req = GenerationRequest(
                    model="mr",
                    # A single X/Y choice is used for both in-plane dimensions.
                    output_size=(int(dim_xy), int(dim_xy), int(dim_z)),
                    spacing=(float(sp_x), float(sp_y), float(sp_z)),
                    seed=int(seed),
                    num_steps=int(steps),
                    cfg_guidance_scale=float(cfg),
                    # NOTE(review): 9 is the fallback for an unknown label; its
                    # meaning is defined by the pipeline — confirm it is intended.
                    modality_class=LABEL_TO_MODALITY.get(contrast, 9),
                )
                result = generate_mr(req)
            except Exception as e:  # UI boundary: report instead of crashing the event
                # Exception text is arbitrary; escape it before embedding in HTML.
                err_text = escape(str(e))
                return (
                    # NOTE(review): empty_html receives the raw message; confirm
                    # it escapes its argument.
                    empty_html(f"Generation failed: {e}"),
                    gr.update(visible=False, value=None),
                    f'<div class="stat-line"><span class="stat-err">✕ Generation failed</span> <span class="stat-chip"><span class="stat-k">ERR</span><span class="stat-v">{err_text}</span></span></div>',
                )
            viewer_html = render_viewer(volume_path=result.volume_path, colormap="gray")
            # Report all three dimensions: Z is selected independently of X/Y,
            # so the volume is generally not a cube.
            size_text = f"{req.output_size[0]} × {req.output_size[1]} × {req.output_size[2]}"
            stat = (
                '<div class="stat-line">'
                '<span class="stat-mark"></span>'
                '<span class="stat-label">Generated</span>'
                f'<span class="stat-chip"><span class="stat-k">runtime</span><span class="stat-v">{result.runtime_seconds:.1f}s</span></span>'
                f'<span class="stat-chip"><span class="stat-k">seed</span><span class="stat-v">{result.seed}</span></span>'
                f'<span class="stat-chip"><span class="stat-k">steps</span><span class="stat-v">{req.num_steps}</span></span>'
                f'<span class="stat-chip"><span class="stat-k">size</span><span class="stat-v">{size_text}</span></span>'
                '</div>'
            )
            return viewer_html, gr.update(visible=True, value=result.volume_path), stat

        # Apply the caller-supplied wrapper only when one was provided.
        decorated = spaces_gpu(_generate) if spaces_gpu else _generate
        # Three-step chain: disable the button, run generation, restore the button.
        (
            generate_btn.click(
                lambda: gr.update(value="Generating volume…", interactive=False),
                outputs=[generate_btn],
            )
            .then(
                decorated,
                inputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z, seed, steps, cfg],
                outputs=[viewer, download, status],
                show_progress="full",
            )
            .then(
                lambda: gr.update(value="Generate volume", interactive=True),
                outputs=[generate_btn],
            )
        )

        for btn, sample in zip(sample_btns, MR_SAMPLES):
            # Bind ``sample`` as a default argument so each handler keeps its
            # own preset rather than the loop variable's final value.
            def _apply(s=sample):
                """Return the preset's values in the order of the click outputs."""
                return (
                    s["modality_label"],
                    s["xy"], s["z"],
                    s["spacing"][0], s["spacing"][1], s["spacing"][2],
                )
            btn.click(_apply, outputs=[contrast, dim_xy, dim_z, sp_x, sp_y, sp_z])

    return group, back_btn