File size: 15,125 Bytes
01f8311
 
 
 
4991517
05f9f55
640a1bf
cbf192b
 
 
f580729
 
 
3e0672b
f580729
 
3e0672b
f580729
2fe9c08
cbf192b
ff1e7f5
 
4991517
 
 
 
 
 
 
 
 
 
 
 
 
cbf192b
 
e335e34
9c41927
 
cbf192b
 
f580729
cbf192b
 
 
 
 
 
 
 
 
 
 
2fe9c08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbf192b
f580729
 
3e0672b
f580729
 
 
 
 
 
 
4991517
f580729
 
 
 
 
 
 
 
cbf192b
f580729
 
 
204cd3a
 
 
 
 
 
 
 
 
 
 
 
 
cbf192b
f580729
cbf192b
204cd3a
cbf192b
204cd3a
 
 
 
 
 
 
 
 
 
 
 
 
cbf192b
4991517
f580729
 
 
cbf192b
05f9f55
f580729
 
3e0672b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbf192b
 
f580729
 
cbf192b
f580729
 
 
 
 
 
3e0672b
cbf192b
 
 
e335e34
cbf192b
 
 
 
 
2fe9c08
cbf192b
 
 
 
3e0672b
f580729
 
 
204cd3a
 
cbf192b
 
f580729
ff1e7f5
e335e34
f580729
ff1e7f5
cbf192b
 
 
 
 
f580729
 
 
204cd3a
 
f580729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204cd3a
 
cbf192b
 
f580729
ff1e7f5
e335e34
f580729
ff1e7f5
3e0672b
 
 
 
 
 
 
 
 
 
 
 
 
ff1e7f5
3e0672b
 
ff1e7f5
cbf192b
f580729
 
cbf192b
f580729
 
3e0672b
f580729
cbf192b
f580729
 
 
 
 
 
 
204cd3a
f580729
 
 
 
204cd3a
 
 
 
 
 
 
 
 
 
 
 
 
cbf192b
f580729
cbf192b
3e0672b
 
 
 
 
 
cbf192b
f580729
 
 
 
3e0672b
f580729
cbf192b
f580729
 
 
 
 
 
cbf192b
 
204cd3a
 
 
 
 
 
 
 
 
 
 
 
 
cbf192b
f580729
cbf192b
3e0672b
 
 
 
 
2fe9c08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed17232
f580729
 
 
 
 
642e49d
ed17232
cbf192b
204cd3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import os
# CUDA debugging switches. They are set here, before `import torch` below, so
# they are already in the environment when torch is loaded.
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous so an error is
# reported at the call that caused it.
# NOTE(review): synchronous launches may slow inference — confirm this is wanted
# outside of debugging.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# NOTE(review): presumably meant to enable device-side assertions; confirm this
# has an effect as a runtime variable for the installed torch build.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

import torch
import traceback
import gradio as gr

# Import your inference functions and dataclasses
# Adjust the import path if your file is located elsewhere
from src.smc.inference import (
    infer_pretrained,
    infer_smc_grad,
    infer_ft,
    PretrainedInferenceConfig,
    SMCGradInferenceConfig,
    FTInferenceConfig,
)
from run_examples import get_out_if_exists

# Height passed to every gr.Gallery created in the layout below.
GALLERY_HEIGHT = "224px"

def get_device():
    """Return the device string to run the next inference call on.

    When CUDA is not available the generic ``"cuda"`` string is returned (the
    GPU is expected to be attached later by spaces ZeroGPU). Otherwise the
    visible GPUs are handed out in round-robin order as ``"cuda:<index>"``.
    The index of the last GPU handed out is remembered on the function object
    itself as ``get_device.last_allocated``.
    """
    try:
        previous = get_device.last_allocated  # type: ignore
    except AttributeError:
        # First call: start at -1 so the first GPU handed out is index 0.
        previous = -1
        get_device.last_allocated = previous  # type: ignore

    if torch.cuda.is_available():
        # Advance to the next GPU, wrapping around after the last one.
        gpu_count = torch.cuda.device_count()
        chosen = (previous + 1) % gpu_count
        get_device.last_allocated = chosen  # type: ignore
        return f"cuda:{chosen}"

    return "cuda"  # GPU will be dynamically allocated later using spaces ZeroGPU


# Example prompts offered through gr.Examples in the UI; the first entry is
# also used as the initial value of the prompt textbox.
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A stylish dog wearing sunglasses",
    "A cat in the style of Van Gogh’s Starry Night",
]


def _format_inference_output(out) -> str:
    """Return a short summary string for the UI"""
    if out is None:
        return "No output"
    try:
        rewards = out.image_rewards
        mem = out.gpu_mem_used
        return f"Rewards: {rewards} | GPU mem (GB): {mem:.3f}"
    except Exception:
        return "Could not parse inference output"

def try_load_saved_outputs(prompt):
    """
    Look up saved outputs for ``prompt`` for each of the three methods.

    For every method a default config carrying only the prompt is built and
    passed to ``get_out_if_exists``. The result is the flat 6-tuple
    (pretrained_gallery, pretrained_info, smc_gallery, smc_info, ft_gallery, ft_info).

    A method without a saved output contributes an empty gallery and
    "No saved output" as its info. If anything raises, the traceback is
    printed and every method gets an empty gallery with an error info string.
    """
    failure_info = "Error checking saved outputs"
    try:
        # (key given to get_out_if_exists, config class) in output order.
        lookups = (
            ("pretrained", PretrainedInferenceConfig),
            ("smc_grad", SMCGradInferenceConfig),
            ("ft", FTInferenceConfig),
        )

        flat_outputs = []
        for method_key, config_cls in lookups:
            saved = get_out_if_exists(method_key, config_cls(prompt=prompt))
            if saved is None:
                flat_outputs.extend(([], "No saved output"))
            else:
                flat_outputs.extend((saved.images, _format_inference_output(saved)))

        return tuple(flat_outputs)

    except Exception:
        # Don't crash the UI; print the traceback and return empty placeholders
        traceback.print_exc()
        return [], failure_info, [], failure_info, [], failure_info


# --- Per-method runner functions ---
def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
    """Run the pretrained inference method and return ``(gallery, info)``.

    Args:
        prompt: Text prompt passed to the config unchanged.
        pretrained_negative_prompt: Negative prompt; a falsy value becomes ``""``.
        pretrained_CFG: CFG value from the UI; converted with ``float()``.
        pretrained_steps: Step count from the UI; converted with ``int()``.

    Returns:
        On success, ``(out.images, summary)`` where ``summary`` comes from
        ``_format_inference_output``. On any exception, ``([], error_message)``;
        the traceback is printed and nothing is raised to the caller.
    """
    try:
        pretrained_cfg = PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
        )
        out = infer_pretrained(pretrained_cfg, device=get_device())
        return out.images, _format_inference_output(out)
    except Exception as e:
        # UI callback boundary: report the failure instead of propagating it.
        traceback.print_exc()
        err_msg = f"Pretrained inference error: {e}"
        # The gallery slot must only hold images. A plain string there is
        # treated by Gradio as an image path/URL, so the message goes to the
        # info slot only and the gallery is left empty.
        return [], err_msg


def run_smc_grad_ui(
    prompt,
    smc_grad_negative_prompt,
    smc_grad_CFG,
    smc_grad_steps,
    smc_grad_num_particles,
    smc_grad_ess_threshold,
    smc_grad_partial_resampling,
    smc_grad_resample_frequency,
    smc_grad_kl_weight,
    smc_grad_lambda_tempering,
    smc_grad_lambda_one_at,
    smc_grad_use_continuous_formulation,
    smc_grad_phi,
    smc_grad_tau,
):
    """Run the SMC-grad inference method and return ``(gallery, info)``.

    Every UI value is coerced to the type the config is built with here:
    ``float`` for CFG, ESS threshold, KL weight, lambda-one-at and tau;
    ``int`` for steps, number of particles, resample frequency and phi;
    ``bool`` for the three checkboxes. A falsy negative prompt becomes ``""``.

    Returns:
        On success, ``(out.images, summary)`` where ``summary`` comes from
        ``_format_inference_output``. On any exception, ``([], error_message)``;
        the traceback is printed and nothing is raised to the caller.
    """
    try:
        smc_grad_cfg = SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_grad_negative_prompt or "",
            ess_threshold=float(smc_grad_ess_threshold),
            partial_resampling=bool(smc_grad_partial_resampling),
            resample_frequency=int(smc_grad_resample_frequency),
            CFG=float(smc_grad_CFG),
            steps=int(smc_grad_steps),
            kl_weight=float(smc_grad_kl_weight),
            lambda_tempering=bool(smc_grad_lambda_tempering),
            lambda_one_at=float(smc_grad_lambda_one_at),
            num_particles=int(smc_grad_num_particles),
            use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
            phi=int(smc_grad_phi),
            tau=float(smc_grad_tau),
        )
        out = infer_smc_grad(smc_grad_cfg, device=get_device())
        return out.images, _format_inference_output(out)
    except Exception as e:
        # UI callback boundary: report the failure instead of propagating it.
        traceback.print_exc()
        err_msg = f"SMC-grad inference error: {e}"
        # The gallery slot must only hold images. A plain string there is
        # treated by Gradio as an image path/URL, so the message goes to the
        # info slot only and the gallery is left empty.
        return [], err_msg
    
def run_ft_ui(prompt, ft_negative_prompt, ft_CFG, ft_steps):
    """Run the finetuned model inference and return ``(gallery, info)``.

    Args:
        prompt: Text prompt passed to the config unchanged.
        ft_negative_prompt: Negative prompt; a falsy value becomes ``""``.
        ft_CFG: CFG value from the UI; converted with ``float()``.
        ft_steps: Step count from the UI; converted with ``int()``.

    Returns:
        On success, ``(out.images, summary)`` where ``summary`` comes from
        ``_format_inference_output``. On any exception, ``([], error_message)``;
        the traceback is printed and nothing is raised to the caller.
    """
    try:
        ft_cfg = FTInferenceConfig(
            prompt=prompt,
            negative_prompt=ft_negative_prompt or "",
            CFG=float(ft_CFG),
            steps=int(ft_steps),
        )
        out = infer_ft(ft_cfg, device=get_device())
        return out.images, _format_inference_output(out)
    except Exception as e:
        # UI callback boundary: report the failure instead of propagating it.
        traceback.print_exc()
        err_msg = f"FT inference error: {e}"
        # The gallery slot must only hold images. A plain string there is
        # treated by Gradio as an image path/URL, so the message goes to the
        # info slot only and the gallery is left empty.
        return [], err_msg


def mark_all_running():
    """Return updates that put all three method outputs into a "running" state.

    This does no heavy work: it builds one update that sets an info box to
    "Running..." (non-interactive) and one that empties a gallery, then
    returns them as (gallery, info) pairs for the pretrained, SMC-grad and
    finetuned outputs, in that order.
    """
    busy_info = gr.update(value="Running...", interactive=False)
    cleared_gallery = gr.update(value=[])
    # Six values: one (gallery, info) pair per method, matching the outputs
    # list this callback is wired to.
    return (cleared_gallery, busy_info) * 3


with gr.Blocks() as demo:
    gr.Markdown("# Prompt alignment for Meissonic using SMC")

    # Prompt entry plus the button that launches all three methods.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter prompt here",
            value=examples[0],
            lines=1,
        )
        run_button = gr.Button("Run", variant="primary")

    examples_widget = gr.Examples(examples=examples, inputs=prompt)

    # --- Pretrained method row: settings on the left, outputs on the right ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained model — settings", open=False):
                pretrained_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=PretrainedInferenceConfig.negative_prompt,
                    lines=1,
                )
                pretrained_CFG = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=PretrainedInferenceConfig.CFG,
                    label="CFG",
                )
                pretrained_steps = gr.Slider(
                    minimum=1,
                    maximum=200,
                    step=1,
                    value=PretrainedInferenceConfig.steps,
                    label="Steps",
                )

        with gr.Column(scale=2):
            pretrained_gallery = gr.Gallery(
                label="Pretrained model outputs",
                show_label=True,
                elem_id="pretrained_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            pretrained_info = gr.Textbox(
                label="Pretrained info",
                interactive=False,
                visible=False,
            )

    # --- SMC-grad method row: settings on the left, outputs on the right ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_grad_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=SMCGradInferenceConfig.negative_prompt,
                    lines=1,
                )
                smc_grad_CFG = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=SMCGradInferenceConfig.CFG,
                    label="CFG",
                )
                smc_grad_steps = gr.Slider(
                    minimum=1,
                    maximum=200,
                    step=1,
                    value=SMCGradInferenceConfig.steps,
                    label="Steps",
                )
                smc_grad_num_particles = gr.Slider(
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=SMCGradInferenceConfig.num_particles,
                    label="SMC Num particles",
                )
                smc_grad_ess_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.ess_threshold,
                    label="ESS threshold",
                )
                smc_grad_partial_resampling = gr.Checkbox(
                    label="Partial resampling",
                    value=SMCGradInferenceConfig.partial_resampling,
                )
                smc_grad_resample_frequency = gr.Slider(
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=SMCGradInferenceConfig.resample_frequency,
                    label="Resample frequency",
                )
                smc_grad_kl_weight = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.kl_weight,
                    label="KL weight",
                )
                smc_grad_lambda_tempering = gr.Checkbox(
                    label="Lambda tempering",
                    value=SMCGradInferenceConfig.lambda_tempering,
                )
                smc_grad_lambda_one_at = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.lambda_one_at,
                    label="Lambda one at (fraction of steps)",
                )
                smc_grad_use_continuous_formulation = gr.Checkbox(
                    label="Use continuous formulation",
                    value=SMCGradInferenceConfig.use_continuous_formulation,
                )
                smc_grad_phi = gr.Slider(
                    minimum=1,
                    maximum=8,
                    step=1,
                    value=SMCGradInferenceConfig.phi,
                    label="Phi",
                )
                smc_grad_tau = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.001,
                    value=SMCGradInferenceConfig.tau,
                    label="Tau",
                )

        with gr.Column(scale=2):
            smc_grad_gallery = gr.Gallery(
                label="SMC-grad outputs",
                show_label=True,
                elem_id="smc_grad_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            smc_grad_info = gr.Textbox(
                label="SMC-grad info",
                interactive=False,
                visible=False,
            )

    # --- FT method row: settings on the left, outputs on the right ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Finetuned model — settings", open=False):
                ft_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=FTInferenceConfig.negative_prompt,
                    lines=1,
                )
                ft_CFG = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=FTInferenceConfig.CFG,
                    label="CFG",
                )
                ft_steps = gr.Slider(
                    minimum=1,
                    maximum=200,
                    step=1,
                    value=FTInferenceConfig.steps,
                    label="Steps",
                )

        with gr.Column(scale=2):
            ft_gallery = gr.Gallery(
                label="Finetuned model outputs",
                show_label=True,
                elem_id="ft_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            ft_info = gr.Textbox(
                label="Finetuned info",
                interactive=False,
                visible=False,
            )

    # --- Wiring ---
    # All six output components, as (gallery, info) pairs in method order.
    _every_output = (
        pretrained_gallery, pretrained_info,
        smc_grad_gallery, smc_grad_info,
        ft_gallery, ft_info,
    )

    # One entry per method: (handler, its input components, its output components).
    _method_wiring = (
        (
            run_pretrained_ui,
            (prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps),
            (pretrained_gallery, pretrained_info),
        ),
        (
            run_smc_grad_ui,
            (
                prompt,
                smc_grad_negative_prompt,
                smc_grad_CFG,
                smc_grad_steps,
                smc_grad_num_particles,
                smc_grad_ess_threshold,
                smc_grad_partial_resampling,
                smc_grad_resample_frequency,
                smc_grad_kl_weight,
                smc_grad_lambda_tempering,
                smc_grad_lambda_one_at,
                smc_grad_use_continuous_formulation,
                smc_grad_phi,
                smc_grad_tau,
            ),
            (smc_grad_gallery, smc_grad_info),
        ),
        (
            run_ft_ui,
            (prompt, ft_negative_prompt, ft_CFG, ft_steps),
            (ft_gallery, ft_info),
        ),
    )

    # Clicking "Run" and pressing Enter in the prompt box trigger the same set
    # of handlers. For each trigger, first register the quick "running" update
    # for immediate feedback, then one separate heavy handler per method so
    # each method updates its own outputs when it completes.
    for _start_event in (run_button.click, prompt.submit):
        _start_event(
            fn=mark_all_running,
            inputs=[],
            outputs=list(_every_output),
        )
        for _handler, _handler_inputs, _handler_outputs in _method_wiring:
            _start_event(
                fn=_handler,
                inputs=list(_handler_inputs),
                outputs=list(_handler_outputs),
            )

    # After an example is selected, show any saved outputs for that prompt.
    examples_widget.load_input_event.then(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=list(_every_output),
    )

    # Do the same once on page load for the initial prompt value (examples[0]).
    demo.load(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=list(_every_output),
    )

# Enable the Gradio queue so that multiple handlers can execute in parallel. The
# default concurrency limit is set to 3 (one per method) — increase it if you add
# more methods.
# NOTE(review): confirm how Gradio scopes default_concurrency_limit (per event
# listener vs. app-wide) for the installed version before tuning this value.

# Important: call queue() before launch()
demo.queue(default_concurrency_limit=3)

if __name__ == "__main__":
    # Only launch when run as a script; share=True is passed to launch().
    demo.launch(share=True)