# smc_meissonic/app.py — Gradio playground comparing inference methods.
# (Hugging Face Space file; last commit "Add concurrent execution of methods"
# by cp524, blob 4991517, ~10.7 kB — header reconstructed from the raw view.)
import torch
import traceback
import gradio as gr
# Import your inference functions and dataclasses
# Adjust the import path if your file is located elsewhere
from src.smc.inference import (
infer_pretrained,
infer_smc_grad,
PretrainedInferenceConfig,
SMCGradInferenceConfig,
)
def get_device():
    """Return a torch device string, cycling round-robin over CUDA devices.

    The index of the most recently handed-out device is kept as a function
    attribute so successive calls rotate through ``cuda:0``, ``cuda:1``, ...
    When CUDA is not available the bare string ``"cuda"`` is returned instead
    (the GPU is allocated dynamically later by HF Spaces ZeroGPU).
    """
    if not hasattr(get_device, "last_allocated"):
        get_device.last_allocated = -1  # type: ignore
    if not torch.cuda.is_available():
        # GPU will be dynamically allocated later using spaces ZeroGPU
        return "cuda"
    # Round-robin allocation over however many devices are visible.
    n_devices = torch.cuda.device_count()
    next_idx = (get_device.last_allocated + 1) % n_devices  # type: ignore
    get_device.last_allocated = next_idx  # type: ignore
    return f"cuda:{next_idx}"
# Example prompts: offered via the gr.Examples widget below and used as the
# prompt textbox's initial value (examples[0]).
examples = [
    "A dreamy Monet-style landscape with soft brush strokes",
    "Vibrant city street at dawn in impressionist style",
]
def _format_inference_output(out) -> str:
"""Return a short summary string for the UI"""
if out is None:
return "No output"
try:
rewards = out.image_rewards
mem = out.gpu_mem_used
return f"Rewards: {rewards} | GPU mem (GB): {mem:.3f}"
except Exception:
return "Could not parse inference output"
# --- Per-method runner functions ---
def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
    """Run the pretrained inference method and return (gallery, info).

    Designed to be attached directly to a Gradio event so it executes
    independently and returns only the two components it owns.
    """
    try:
        config = PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
        )
        result = infer_pretrained(config, device=get_device())
        return result.images, _format_inference_output(result)
    except Exception as exc:
        traceback.print_exc()
        # Surface the error both in the gallery and the info textbox.
        message = f"Pretrained inference error: {exc}"
        return [message], message
def run_smc_grad_ui(
    prompt,
    smc_grad_negative_prompt,
    smc_grad_CFG,
    smc_grad_steps,
    smc_grad_num_particles,
    smc_grad_ess_threshold,
    smc_grad_partial_resampling,
    smc_grad_resample_frequency,
    smc_grad_kl_weight,
    smc_grad_lambda_tempering,
    smc_grad_lambda_one_at,
    smc_grad_use_continuous_formulation,
    smc_grad_phi,
    smc_grad_tau,
):
    """Run the SMC-grad inference method and return (gallery, info).

    Coerces every raw widget value to the type the config expects, then
    delegates to ``infer_smc_grad``; any exception is reported in the UI
    rather than propagated.
    """
    try:
        config = SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_grad_negative_prompt or "",
            CFG=float(smc_grad_CFG),
            steps=int(smc_grad_steps),
            num_particles=int(smc_grad_num_particles),
            ess_threshold=float(smc_grad_ess_threshold),
            partial_resampling=bool(smc_grad_partial_resampling),
            resample_frequency=int(smc_grad_resample_frequency),
            kl_weight=float(smc_grad_kl_weight),
            lambda_tempering=bool(smc_grad_lambda_tempering),
            lambda_one_at=float(smc_grad_lambda_one_at),
            use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
            phi=int(smc_grad_phi),
            tau=float(smc_grad_tau),
        )
        result = infer_smc_grad(config, device=get_device())
        return result.images, _format_inference_output(result)
    except Exception as exc:
        traceback.print_exc()
        message = f"SMC-grad inference error: {exc}"
        return [message], message
def mark_all_running():
    """Instant UI feedback: clear both galleries and show a "Running..." state.

    This lightweight callback fires first so the interface reflects progress
    while the heavy inference handlers are still queued/executing.
    """
    busy_info = gr.update(value="Running...", interactive=False)
    cleared_gallery = gr.update(value=[])
    # Order must match the outputs this callback is wired to:
    # (pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info)
    return cleared_gallery, busy_info, cleared_gallery, busy_info
# UI definition. Component creation order inside each container determines the
# rendered layout (top-to-bottom, left-to-right).
with gr.Blocks() as demo:
    gr.Markdown("# Monetico — Multi-method Inference Playground")
    # Shared prompt textbox plus one Run button that drives both methods.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", value=examples[0], lines=1)
        run_button = gr.Button("Run", variant="primary")
    gr.Examples(examples=examples, inputs=prompt)
    # --- Pretrained method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained method — settings", open=False):
                # Widget defaults come from the config dataclass's class-level
                # field defaults.
                pretrained_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=PretrainedInferenceConfig.negative_prompt, lines=1
                )
                pretrained_CFG = gr.Slider(0.0, 30.0, step=0.1, value=PretrainedInferenceConfig.CFG, label="CFG")
                pretrained_steps = gr.Slider(1, 200, step=1, value=PretrainedInferenceConfig.steps, label="Steps")
        with gr.Column(scale=2):
            # NOTE(review): height="70vw" uses viewport-*width* units for a
            # height — confirm this is intentional (same below).
            pretrained_gallery = gr.Gallery(
                label="Pretrained outputs", show_label=True, elem_id="pretrained_gallery", height="70vw", columns=4
            )
            pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)
    # --- SMC-grad method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_grad_negative_prompt = gr.Textbox(
                    label="Negative prompt", value=SMCGradInferenceConfig.negative_prompt, lines=1
                )
                smc_grad_CFG = gr.Slider(0.0, 30.0, step=0.1, value=SMCGradInferenceConfig.CFG, label="CFG")
                smc_grad_steps = gr.Slider(1, 200, step=1, value=SMCGradInferenceConfig.steps, label="Steps")
                smc_grad_num_particles = gr.Slider(
                    1, 64, step=1, value=SMCGradInferenceConfig.num_particles, label="SMC Num particles"
                )
                smc_grad_ess_threshold = gr.Slider(
                    0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.ess_threshold, label="ESS threshold"
                )
                smc_grad_partial_resampling = gr.Checkbox(
                    label="Partial resampling", value=SMCGradInferenceConfig.partial_resampling
                )
                smc_grad_resample_frequency = gr.Slider(
                    1, 50, step=1, value=SMCGradInferenceConfig.resample_frequency, label="Resample frequency"
                )
                smc_grad_kl_weight = gr.Slider(
                    0.0, 10.0, step=0.01, value=SMCGradInferenceConfig.kl_weight, label="KL weight"
                )
                smc_grad_lambda_tempering = gr.Checkbox(
                    label="Lambda tempering", value=SMCGradInferenceConfig.lambda_tempering
                )
                smc_grad_lambda_one_at = gr.Slider(
                    0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.lambda_one_at, label="Lambda one at (fraction of steps)"
                )
                smc_grad_use_continuous_formulation = gr.Checkbox(
                    label="Use continuous formulation", value=SMCGradInferenceConfig.use_continuous_formulation
                )
                smc_grad_phi = gr.Slider(1, 8, step=1, value=SMCGradInferenceConfig.phi, label="Phi")
                smc_grad_tau = gr.Slider(0.0, 1.0, step=0.001, value=SMCGradInferenceConfig.tau, label="Tau")
        with gr.Column(scale=2):
            smc_grad_gallery = gr.Gallery(
                label="SMC-grad outputs", show_label=True, elem_id="smc_grad_gallery", height="70vw", columns=4
            )
            smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False)
    # --- Wiring ---
    # 1) Quick 'running' update attached to the button so the UI shows
    #    immediate feedback. Its four return values match this outputs list.
    run_button.click(
        fn=mark_all_running,
        inputs=[],
        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
    )
    # 2) Attach the per-method heavy functions as separate listeners on the
    #    same event. Gradio's queue() will allow them to execute concurrently
    #    and update their respective outputs as they complete.
    run_button.click(
        fn=run_pretrained_ui,
        inputs=[prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps],
        outputs=[pretrained_gallery, pretrained_info],
    )
    run_button.click(
        fn=run_smc_grad_ui,
        inputs=[
            prompt,
            smc_grad_negative_prompt,
            smc_grad_CFG,
            smc_grad_steps,
            smc_grad_num_particles,
            smc_grad_ess_threshold,
            smc_grad_partial_resampling,
            smc_grad_resample_frequency,
            smc_grad_kl_weight,
            smc_grad_lambda_tempering,
            smc_grad_lambda_one_at,
            smc_grad_use_continuous_formulation,
            smc_grad_phi,
            smc_grad_tau,
        ],
        outputs=[smc_grad_gallery, smc_grad_info],
    )
    # Also allow pressing Enter in the prompt to trigger the same set of
    # handlers (same fn/inputs/outputs triples as the button above).
    prompt.submit(
        fn=mark_all_running,
        inputs=[],
        outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
    )
    prompt.submit(
        fn=run_pretrained_ui,
        inputs=[prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps],
        outputs=[pretrained_gallery, pretrained_info],
    )
    prompt.submit(
        fn=run_smc_grad_ui,
        inputs=[
            prompt,
            smc_grad_negative_prompt,
            smc_grad_CFG,
            smc_grad_steps,
            smc_grad_num_particles,
            smc_grad_ess_threshold,
            smc_grad_partial_resampling,
            smc_grad_resample_frequency,
            smc_grad_kl_weight,
            smc_grad_lambda_tempering,
            smc_grad_lambda_one_at,
            smc_grad_use_continuous_formulation,
            smc_grad_phi,
            smc_grad_tau,
        ],
        outputs=[smc_grad_gallery, smc_grad_info],
    )
# Enable the Gradio queue so the separately-attached handlers can execute in
# parallel. Concurrency is 2 (one per method) — increase if more methods are
# added. Fine-tune max_size / concurrency settings for your deployment.
# Important: call queue() before launch().
demo.queue(default_concurrency_limit=2)
if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link when run directly.
    demo.launch(share=True)