# smc_meissonic / app.py
# Hosting-page metadata captured with this file: author cp524; latest commit
# 05f9f55 ("Fix decoding batch size bug"); file size 8.87 kB.
import traceback
import gradio as gr
# Import your inference functions and dataclasses
# Adjust the import path if your file is located elsewhere
from src.smc.inference import infer_pretrained, infer_smc_grad, PretrainedInferenceConfig, SMCGradInferenceConfig
# Device string handed to both inference entry points via ``device=DEVICE``.
DEVICE = "cuda"
# Sample prompts: offered through ``gr.Examples``; the first one also pre-fills
# the prompt textbox.
examples = [
    "A dreamy Monet-style landscape with soft brush strokes",
    "Vibrant city street at dawn in impressionist style",
]
def _format_inference_output(out) -> str:
    """Return a short one-line summary of an inference result for the UI.

    Args:
        out: Result object returned by an inference call, or ``None`` when
            no result is available. It is read through its ``image_rewards``
            and ``gpu_mem_used`` attributes.

    Returns:
        ``"No output"`` when ``out`` is ``None``; a
        ``"Rewards: ... | GPU mem (GB): ..."`` line when both attributes can
        be read and formatted; otherwise
        ``"Could not parse inference output"``.
    """
    if out is None:
        return "No output"
    try:
        rewards = out.image_rewards
        mem = out.gpu_mem_used
        return f"Rewards: {rewards} | GPU mem (GB): {mem:.3f}"
    except Exception:
        # Best-effort formatter: it must never break the UI callback, but the
        # cause is logged (like the other handlers in this file) instead of
        # being silently discarded.
        traceback.print_exc()
        return "Could not parse inference output"
def _run_method_guarded(error_prefix, run):
    """Run one inference method and shape its result for the UI.

    Args:
        error_prefix: Text placed before the exception message when the
            method fails.
        run: Zero-argument callable that builds the method's config, performs
            inference and returns a result object exposing ``images``.

    Returns:
        A ``(gallery_items, info_text)`` pair. On success ``gallery_items`` is
        the result's ``images`` as a list and ``info_text`` is the summary
        from ``_format_inference_output``. On failure ``gallery_items`` is
        empty and ``info_text`` carries the error message.
    """
    try:
        output = run()
        images = output.images
    except Exception as exc:
        # UI boundary: log the traceback and report the failure in the info
        # text. The message must not go into the gallery list, because the
        # gallery treats string items as image paths/URLs.
        traceback.print_exc()
        return [], f"{error_prefix}: {exc}"

    if images is None:
        gallery = []
    elif isinstance(images, list):
        gallery = images
    else:
        gallery = [images]
    return gallery, _format_inference_output(output)


def run_inference_all(
    prompt,
    # Pretrained method controls
    pretrained_negative_prompt,
    pretrained_CFG,
    pretrained_steps,
    # SMC-grad method controls
    smc_grad_negative_prompt,
    smc_grad_CFG,
    smc_grad_steps,
    smc_grad_num_particles,
    smc_grad_ess_threshold,
    smc_grad_partial_resampling,
    smc_grad_resample_frequency,
    smc_grad_kl_weight,
    smc_grad_lambda_tempering,
    smc_grad_lambda_one_at,
    smc_grad_use_continuous_formulation,
    smc_grad_phi,
    smc_grad_tau,
):
    """Run both inference methods and return UI-friendly outputs.

    The two methods are executed independently: a failure in one is logged
    and reported in that method's info text, and the other method still runs.
    Raw widget values are coerced (``float``/``int``/``bool``) before being
    placed in the config objects; an empty or missing negative prompt becomes
    ``""``.

    Returns:
        ``(pretrained_images, pretrained_info, smc_grad_images,
        smc_grad_info)`` where each ``*_images`` value is a list suitable for
        a gallery (empty on failure) and each ``*_info`` value is either a
        result summary or an error message.
    """

    def _pretrained():
        cfg = PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
        )
        return infer_pretrained(cfg, device=DEVICE)

    def _smc_grad():
        cfg = SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_grad_negative_prompt or "",
            ess_threshold=float(smc_grad_ess_threshold),
            partial_resampling=bool(smc_grad_partial_resampling),
            resample_frequency=int(smc_grad_resample_frequency),
            CFG=float(smc_grad_CFG),
            steps=int(smc_grad_steps),
            kl_weight=float(smc_grad_kl_weight),
            lambda_tempering=bool(smc_grad_lambda_tempering),
            lambda_one_at=float(smc_grad_lambda_one_at),
            num_particles=int(smc_grad_num_particles),
            use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
            phi=int(smc_grad_phi),
            tau=float(smc_grad_tau),
        )
        return infer_smc_grad(cfg, device=DEVICE)

    pretrained_gallery, pretrained_info = _run_method_guarded(
        "Pretrained inference error", _pretrained
    )
    smc_grad_gallery, smc_grad_info = _run_method_guarded(
        "SMC inference error", _smc_grad
    )
    return pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info
with gr.Blocks() as demo:
    gr.Markdown("# Monetico — Multi-method Inference Playground")

    # Prompt entry and the trigger button share the first row.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter prompt here",
            value=examples[0],
            lines=1,
        )
        run_button = gr.Button("Run", variant="primary")

    gr.Examples(examples=examples, inputs=prompt)

    # Pretrained method: collapsible settings on the left, results on the right.
    # Widget defaults are read from PretrainedInferenceConfig class attributes.
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained method — settings", open=False):
                pretrained_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=PretrainedInferenceConfig.negative_prompt,
                    lines=1,
                )
                pretrained_CFG = gr.Slider(
                    0.0,
                    30.0,
                    step=0.1,
                    value=PretrainedInferenceConfig.CFG,
                    label="CFG",
                )
                pretrained_steps = gr.Slider(
                    1,
                    200,
                    step=1,
                    value=PretrainedInferenceConfig.steps,
                    label="Steps",
                )
        with gr.Column(scale=2):
            pretrained_gallery = gr.Gallery(
                label="Pretrained outputs",
                show_label=True,
                elem_id="pretrained_gallery",
                height="70vw",
                columns=4,
            )
            pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)

    # SMC-grad method: same layout; defaults come from SMCGradInferenceConfig
    # class attributes.
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_grad_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=SMCGradInferenceConfig.negative_prompt,
                    lines=1,
                )
                smc_grad_CFG = gr.Slider(
                    0.0,
                    30.0,
                    step=0.1,
                    value=SMCGradInferenceConfig.CFG,
                    label="CFG",
                )
                smc_grad_steps = gr.Slider(
                    1,
                    200,
                    step=1,
                    value=SMCGradInferenceConfig.steps,
                    label="Steps",
                )
                smc_grad_num_particles = gr.Slider(
                    1,
                    64,
                    step=1,
                    value=SMCGradInferenceConfig.num_particles,
                    label="SMC Num particles",
                )
                smc_grad_ess_threshold = gr.Slider(
                    0.0,
                    1.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.ess_threshold,
                    label="ESS threshold",
                )
                smc_grad_partial_resampling = gr.Checkbox(
                    label="Partial resampling",
                    value=SMCGradInferenceConfig.partial_resampling,
                )
                smc_grad_resample_frequency = gr.Slider(
                    1,
                    50,
                    step=1,
                    value=SMCGradInferenceConfig.resample_frequency,
                    label="Resample frequency",
                )
                smc_grad_kl_weight = gr.Slider(
                    0.0,
                    10.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.kl_weight,
                    label="KL weight",
                )
                smc_grad_lambda_tempering = gr.Checkbox(
                    label="Lambda tempering",
                    value=SMCGradInferenceConfig.lambda_tempering,
                )
                smc_grad_lambda_one_at = gr.Slider(
                    0.0,
                    1.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.lambda_one_at,
                    label="Lambda one at (fraction of steps)",
                )
                smc_grad_use_continuous_formulation = gr.Checkbox(
                    label="Use continuous formulation",
                    value=SMCGradInferenceConfig.use_continuous_formulation,
                )
                smc_grad_phi = gr.Slider(
                    1,
                    8,
                    step=1,
                    value=SMCGradInferenceConfig.phi,
                    label="Phi",
                )
                smc_grad_tau = gr.Slider(
                    0.0,
                    1.0,
                    step=0.001,
                    value=SMCGradInferenceConfig.tau,
                    label="Tau",
                )
        with gr.Column(scale=2):
            smc_grad_gallery = gr.Gallery(
                label="SMC-grad outputs",
                show_label=True,
                elem_id="smc_grad_gallery",
                height="70vw",
                columns=4,
            )
            smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False)

    # Both triggers (button click, Enter in the prompt box) call the same
    # runner with the same components. The input order must match the
    # positional parameter order of run_inference_all.
    run_inputs = [
        prompt,
        pretrained_negative_prompt,
        pretrained_CFG,
        pretrained_steps,
        smc_grad_negative_prompt,
        smc_grad_CFG,
        smc_grad_steps,
        smc_grad_num_particles,
        smc_grad_ess_threshold,
        smc_grad_partial_resampling,
        smc_grad_resample_frequency,
        smc_grad_kl_weight,
        smc_grad_lambda_tempering,
        smc_grad_lambda_one_at,
        smc_grad_use_continuous_formulation,
        smc_grad_phi,
        smc_grad_tau,
    ]
    run_outputs = [
        pretrained_gallery,
        pretrained_info,
        smc_grad_gallery,
        smc_grad_info,
    ]

    run_button.click(fn=run_inference_all, inputs=run_inputs, outputs=run_outputs)
    prompt.submit(fn=run_inference_all, inputs=run_inputs, outputs=run_outputs)
# Launch the UI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a publicly reachable link —
    # confirm this is intended for the deployment target.
    demo.launch(share=True)