smc_meissonic / app.py
cp524's picture
Add more examples
9c41927
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
import torch
import traceback
import gradio as gr
# Import your inference functions and dataclasses
# Adjust the import path if your file is located elsewhere
from src.smc.inference import (
infer_pretrained,
infer_smc_grad,
infer_ft,
PretrainedInferenceConfig,
SMCGradInferenceConfig,
FTInferenceConfig,
)
from run_examples import get_out_if_exists
# Height passed to each of the three output galleries (pretrained, SMC-grad, finetuned).
GALLERY_HEIGHT = "224px"
def get_device():
    """Pick the device string that the next inference call should use.

    When CUDA reports no available device, the generic ``"cuda"`` string is
    returned (the GPU is expected to be attached later by Spaces ZeroGPU).
    Otherwise the visible GPUs are handed out in round-robin order; the index
    of the most recent hand-out is remembered on the function object itself
    as ``get_device.last_allocated``.

    Returns:
        ``"cuda"`` or ``"cuda:<index>"``.
    """
    try:
        previous = get_device.last_allocated  # type: ignore
    except AttributeError:
        # First call: seed the counter so that the first pick is GPU 0.
        previous = get_device.last_allocated = -1  # type: ignore

    if torch.cuda.is_available():
        # Advance to the next GPU, wrapping around after the last one.
        chosen = (previous + 1) % torch.cuda.device_count()
        get_device.last_allocated = chosen  # type: ignore
        return f"cuda:{chosen}"

    # GPU will be dynamically allocated later using spaces ZeroGPU
    return "cuda"
# Example prompts offered in the UI; the first entry is also used as the
# initial value of the prompt textbox.
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A stylish dog wearing sunglasses",
    "A cat in the style of Van Gogh’s Starry Night",
]
def _format_inference_output(out) -> str:
    """Build the one-line status text shown in the UI for an inference result.

    Returns "No output" when ``out`` is None. Otherwise reads
    ``out.image_rewards`` and ``out.gpu_mem_used`` and formats them; if that
    fails for any reason, a fixed fallback message is returned instead of
    raising.
    """
    if out is None:
        return "No output"

    try:
        rewards, mem_gb = out.image_rewards, out.gpu_mem_used
        summary = f"Rewards: {rewards} | GPU mem (GB): {mem_gb:.3f}"
    except Exception:
        # Best effort only: a malformed result must not break the UI callback.
        summary = "Could not parse inference output"
    return summary
def try_load_saved_outputs(prompt):
    """
    Look up previously saved outputs for ``prompt``, one lookup per method.

    Returns a flat 6-tuple
    (pretrained_gallery, pretrained_info, smc_gallery, smc_info, ft_gallery, ft_info).
    A method without a saved output contributes an empty gallery and the info
    text "No saved output". If any lookup raises, the traceback is printed and
    every slot is filled with an empty gallery and "Error checking saved outputs".
    """
    flat = []
    try:
        # (key passed to get_out_if_exists, config class built from the prompt),
        # listed in the order the results are returned.
        lookups = (
            ("pretrained", PretrainedInferenceConfig),
            ("smc_grad", SMCGradInferenceConfig),
            ("ft", FTInferenceConfig),
        )
        for method_key, config_cls in lookups:
            saved = get_out_if_exists(method_key, config_cls(prompt=prompt))
            if saved is None:
                flat.extend(([], "No saved output"))
            else:
                flat.extend((saved.images, _format_inference_output(saved)))
    except Exception:
        # Don't crash the UI; print the traceback and return empty placeholders
        traceback.print_exc()
        failure = "Error checking saved outputs"
        return [], failure, [], failure, [], failure
    return tuple(flat)
# --- Per-method runner functions ---
def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
    """Run the pretrained inference method and return (gallery, info).

    Args:
        prompt: Text prompt passed to the config unchanged.
        pretrained_negative_prompt: Negative prompt; a falsy value becomes "".
        pretrained_CFG: CFG value from the UI, converted with ``float()``.
        pretrained_steps: Step count from the UI, converted with ``int()``.

    Returns:
        ``(images, info)`` on success. On any exception the traceback is
        printed, a warning toast is raised in the UI, and
        ``([], error_message)`` is returned.
    """
    try:
        pretrained_cfg = PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
        )
        out = infer_pretrained(pretrained_cfg, device=get_device())
        return out.images, _format_inference_output(out)
    except Exception as e:
        traceback.print_exc()
        err_msg = f"Pretrained inference error: {e}"
        # A gallery value must consist of images; a plain string item is treated
        # as an image path/URL, so putting the error text there would make the
        # error path itself fail. Clear the gallery and surface the message as a
        # toast instead (the info textbox alone may not be visible to the user).
        gr.Warning(err_msg)
        return [], err_msg
def run_smc_grad_ui(
    prompt,
    smc_grad_negative_prompt,
    smc_grad_CFG,
    smc_grad_steps,
    smc_grad_num_particles,
    smc_grad_ess_threshold,
    smc_grad_partial_resampling,
    smc_grad_resample_frequency,
    smc_grad_kl_weight,
    smc_grad_lambda_tempering,
    smc_grad_lambda_one_at,
    smc_grad_use_continuous_formulation,
    smc_grad_phi,
    smc_grad_tau,
):
    """Run the SMC-grad inference method and return (gallery, info).

    Every argument is a raw UI value. Each one is coerced to the type used
    below (``float``, ``int`` or ``bool``) and forwarded to
    ``SMCGradInferenceConfig`` under the same name without the ``smc_grad_``
    prefix. A falsy negative prompt becomes "".

    Returns:
        ``(images, info)`` on success. On any exception the traceback is
        printed, a warning toast is raised in the UI, and
        ``([], error_message)`` is returned.
    """
    try:
        smc_grad_cfg = SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_grad_negative_prompt or "",
            ess_threshold=float(smc_grad_ess_threshold),
            partial_resampling=bool(smc_grad_partial_resampling),
            resample_frequency=int(smc_grad_resample_frequency),
            CFG=float(smc_grad_CFG),
            steps=int(smc_grad_steps),
            kl_weight=float(smc_grad_kl_weight),
            lambda_tempering=bool(smc_grad_lambda_tempering),
            lambda_one_at=float(smc_grad_lambda_one_at),
            num_particles=int(smc_grad_num_particles),
            use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
            phi=int(smc_grad_phi),
            tau=float(smc_grad_tau),
        )
        out = infer_smc_grad(smc_grad_cfg, device=get_device())
        return out.images, _format_inference_output(out)
    except Exception as e:
        traceback.print_exc()
        err_msg = f"SMC-grad inference error: {e}"
        # A gallery value must consist of images; a plain string item is treated
        # as an image path/URL, so putting the error text there would make the
        # error path itself fail. Clear the gallery and surface the message as a
        # toast instead (the info textbox alone may not be visible to the user).
        gr.Warning(err_msg)
        return [], err_msg
def run_ft_ui(prompt, ft_negative_prompt, ft_CFG, ft_steps):
    """Run the finetuned model inference and return (gallery, info).

    Args:
        prompt: Text prompt passed to the config unchanged.
        ft_negative_prompt: Negative prompt; a falsy value becomes "".
        ft_CFG: CFG value from the UI, converted with ``float()``.
        ft_steps: Step count from the UI, converted with ``int()``.

    Returns:
        ``(images, info)`` on success. On any exception the traceback is
        printed, a warning toast is raised in the UI, and
        ``([], error_message)`` is returned.
    """
    try:
        ft_cfg = FTInferenceConfig(
            prompt=prompt,
            negative_prompt=ft_negative_prompt or "",
            CFG=float(ft_CFG),
            steps=int(ft_steps),
        )
        out = infer_ft(ft_cfg, device=get_device())
        return out.images, _format_inference_output(out)
    except Exception as e:
        traceback.print_exc()
        err_msg = f"FT inference error: {e}"
        # A gallery value must consist of images; a plain string item is treated
        # as an image path/URL, so putting the error text there would make the
        # error path itself fail. Clear the gallery and surface the message as a
        # toast instead (the info textbox alone may not be visible to the user).
        gr.Warning(err_msg)
        return [], err_msg
def mark_all_running():
    """Cheap callback that flips every output into a "Running..." state.

    It does no heavy work, so the UI can show feedback right away while the
    inference handlers are still queued or executing.

    Returns:
        Six component updates in (gallery, info) pairs — one pair per method.
        Each gallery is emptied and each info box is set to "Running...".
        The order must match the ``outputs`` list this callback is wired to.
    """
    busy_info = gr.update(value="Running...", interactive=False)
    cleared_gallery = gr.update(value=[])
    # The same (gallery, info) pair repeated for the three methods.
    return (cleared_gallery, busy_info) * 3
# UI definition: a prompt row, then one row per method (settings accordion on the
# left, output gallery plus a hidden info textbox on the right), then event wiring.
with gr.Blocks() as demo:
    gr.Markdown("# Prompt alignment for Meissonic using SMC")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", value=examples[0], lines=1)
        run_button = gr.Button("Run", variant="primary")

    examples_widget = gr.Examples(examples=examples, inputs=prompt)

    # --- Pretrained method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained model — settings", open=False):
                pretrained_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=PretrainedInferenceConfig.negative_prompt, lines=1
                )
                pretrained_CFG = gr.Slider(0.0, 30.0, step=0.1, value=PretrainedInferenceConfig.CFG, label="CFG")
                pretrained_steps = gr.Slider(1, 200, step=1, value=PretrainedInferenceConfig.steps, label="Steps")
        with gr.Column(scale=2):
            pretrained_gallery = gr.Gallery(
                label="Pretrained model outputs", show_label=True, elem_id="pretrained_gallery", height=GALLERY_HEIGHT, columns=4,
                object_fit="contain",
            )
            pretrained_info = gr.Textbox(label="Pretrained info", interactive=False, visible=False)

    # --- SMC-grad method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_grad_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=SMCGradInferenceConfig.negative_prompt, lines=1
                )
                smc_grad_CFG = gr.Slider(0.0, 30.0, step=0.1, value=SMCGradInferenceConfig.CFG, label="CFG")
                smc_grad_steps = gr.Slider(1, 200, step=1, value=SMCGradInferenceConfig.steps, label="Steps")
                smc_grad_num_particles = gr.Slider(
                    1, 64, step=1, value=SMCGradInferenceConfig.num_particles, label="SMC Num particles"
                )
                smc_grad_ess_threshold = gr.Slider(
                    0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.ess_threshold, label="ESS threshold"
                )
                smc_grad_partial_resampling = gr.Checkbox(
                    label="Partial resampling", value=SMCGradInferenceConfig.partial_resampling
                )
                smc_grad_resample_frequency = gr.Slider(
                    1, 50, step=1, value=SMCGradInferenceConfig.resample_frequency, label="Resample frequency"
                )
                smc_grad_kl_weight = gr.Slider(
                    0.0, 10.0, step=0.01, value=SMCGradInferenceConfig.kl_weight, label="KL weight"
                )
                smc_grad_lambda_tempering = gr.Checkbox(
                    label="Lambda tempering", value=SMCGradInferenceConfig.lambda_tempering
                )
                smc_grad_lambda_one_at = gr.Slider(
                    0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.lambda_one_at, label="Lambda one at (fraction of steps)"
                )
                smc_grad_use_continuous_formulation = gr.Checkbox(
                    label="Use continuous formulation", value=SMCGradInferenceConfig.use_continuous_formulation
                )
                smc_grad_phi = gr.Slider(1, 8, step=1, value=SMCGradInferenceConfig.phi, label="Phi")
                smc_grad_tau = gr.Slider(0.0, 1.0, step=0.001, value=SMCGradInferenceConfig.tau, label="Tau")
        with gr.Column(scale=2):
            smc_grad_gallery = gr.Gallery(
                label="SMC-grad outputs", show_label=True, elem_id="smc_grad_gallery", height=GALLERY_HEIGHT, columns=4,
                object_fit="contain",
            )
            smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False, visible=False)

    # --- FT method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Finetuned model — settings", open=False):
                ft_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=FTInferenceConfig.negative_prompt, lines=1
                )
                ft_CFG = gr.Slider(0.0, 30.0, step=0.1, value=FTInferenceConfig.CFG, label="CFG")
                ft_steps = gr.Slider(1, 200, step=1, value=FTInferenceConfig.steps, label="Steps")
        with gr.Column(scale=2):
            ft_gallery = gr.Gallery(
                label="Finetuned model outputs", show_label=True, elem_id="ft_gallery", height=GALLERY_HEIGHT, columns=4,
                object_fit="contain",
            )
            ft_info = gr.Textbox(label="Finetuned info", interactive=False, visible=False)

    # --- Wiring ---
    # Component lists shared by the handlers below. The order of each input list
    # must match the positional parameters of the function it is passed to.
    all_outputs = [pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info]
    pretrained_inputs = [prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps]
    smc_grad_inputs = [
        prompt,
        smc_grad_negative_prompt,
        smc_grad_CFG,
        smc_grad_steps,
        smc_grad_num_particles,
        smc_grad_ess_threshold,
        smc_grad_partial_resampling,
        smc_grad_resample_frequency,
        smc_grad_kl_weight,
        smc_grad_lambda_tempering,
        smc_grad_lambda_one_at,
        smc_grad_use_continuous_formulation,
        smc_grad_phi,
        smc_grad_tau,
    ]
    ft_inputs = [prompt, ft_negative_prompt, ft_CFG, ft_steps]

    # The Run button and pressing Enter in the prompt trigger the same set of handlers.
    for trigger in (run_button.click, prompt.submit):
        # 1) Quick 'running' update so the UI shows immediate feedback.
        started = trigger(fn=mark_all_running, inputs=[], outputs=all_outputs)

        # 2) Attach the per-method heavy functions separately, each chained after the
        #    'running' update. Chaining guarantees the reset is applied first, so it
        #    can never overwrite a handler result that happens to arrive early (for
        #    example a fast failure). Gradio's queue() will allow the three handlers
        #    to execute concurrently and update their respective outputs as they
        #    complete.
        started.then(
            fn=run_pretrained_ui,
            inputs=pretrained_inputs,
            outputs=[pretrained_gallery, pretrained_info],
        )
        started.then(
            fn=run_smc_grad_ui,
            inputs=smc_grad_inputs,
            outputs=[smc_grad_gallery, smc_grad_info],
        )
        started.then(
            fn=run_ft_ui,
            inputs=ft_inputs,
            outputs=[ft_gallery, ft_info],
        )

    # Trigger when an example is selected
    examples_widget.load_input_event.then(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=all_outputs,
    )

    # Trigger once on page load for the initial prompt value (so example[0] loads on startup)
    demo.load(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=all_outputs,
    )
# Enable Gradio queue to allow parallel execution of multiple handlers. The default
# concurrency limit is set to 3 (one per method: pretrained, SMC-grad, FT) — increase
# it if you add more methods.
# You can fine-tune max_size / default_concurrency_limit for your deployment.
# Important: call queue() before launch()
demo.queue(default_concurrency_limit=3)

# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)