Spaces:
Sleeping
Sleeping
| import traceback | |
| import gradio as gr | |
| # Import your inference functions and dataclasses | |
| # Adjust the import path if your file is located elsewhere | |
| from src.smc.inference import infer_pretrained, infer_smc_grad, PretrainedInferenceConfig, SMCGradInferenceConfig | |
# Device string handed to both inference entry points.
DEVICE = "cuda"

# Prompts offered as one-click examples; the first entry also pre-fills the
# prompt textbox.
examples = [
    "A dreamy Monet-style landscape with soft brush strokes",
    "Vibrant city street at dawn in impressionist style",
]
def _format_inference_output(out) -> str:
    """Build the one-line summary shown in a method's info textbox.

    Args:
        out: Inference result exposing ``image_rewards`` and
            ``gpu_mem_used``, or ``None`` when there is no result.

    Returns:
        ``"No output"`` for ``None``; a ``"Rewards: ... | GPU mem (GB): ..."``
        line when both attributes can be read and formatted; otherwise
        ``"Could not parse inference output"``.
    """
    if out is not None:
        try:
            reward_values = out.image_rewards
            gpu_gb = out.gpu_mem_used
            # ``:.3f`` raises for non-numeric values; that is reported through
            # the generic fallback below, same as a missing attribute.
            summary = f"Rewards: {reward_values} | GPU mem (GB): {gpu_gb:.3f}"
        except Exception:
            # Best-effort formatting: the UI must never fail on a summary.
            summary = "Could not parse inference output"
        return summary
    return "No output"
def _run_inference_method(error_prefix, run):
    """Execute one inference callable and convert its result into UI values.

    Args:
        error_prefix: Text placed before the exception message when the run
            fails (e.g. ``"Pretrained inference error"``).
        run: Zero-argument callable that builds the method's config, runs
            inference and returns an output object exposing ``images``.

    Returns:
        A ``(gallery_items, info_text)`` pair. On failure the gallery is empty
        and ``info_text`` carries the error message, so the user sees the
        problem in the info textbox rather than as a broken gallery entry.
    """
    try:
        output = run()
        # Read ``images`` inside the guard so a malformed result is reported
        # like any other failure instead of crashing the callback.
        images = output.images
    except Exception as exc:  # UI boundary: report the failure, keep the app alive.
        traceback.print_exc()
        return [], f"{error_prefix}: {exc}"

    # The gallery output needs a list of items; normalise whatever came back.
    if images is None:
        gallery_items = []
    elif isinstance(images, (list, tuple)):
        gallery_items = list(images)
    else:
        gallery_items = [images]
    return gallery_items, _format_inference_output(output)


def run_inference_all(
    prompt,
    # Pretrained method controls
    pretrained_negative_prompt,
    pretrained_CFG,
    pretrained_steps,
    # SMC-grad method controls
    smc_grad_negative_prompt,
    smc_grad_CFG,
    smc_grad_steps,
    smc_grad_num_particles,
    smc_grad_ess_threshold,
    smc_grad_partial_resampling,
    smc_grad_resample_frequency,
    smc_grad_kl_weight,
    smc_grad_lambda_tempering,
    smc_grad_lambda_one_at,
    smc_grad_use_continuous_formulation,
    smc_grad_phi,
    smc_grad_tau,
):
    """Run both inference methods for one prompt and return UI-friendly outputs.

    The raw widget values are coerced to the types passed to the config
    constructors (``float``/``int``/``bool``); an empty or missing negative
    prompt becomes ``""``. The two methods run one after the other and fail
    independently: an error in one does not prevent the other from running.

    Returns:
        ``(pretrained_images, pretrained_info, smc_grad_images, smc_grad_info)``.
        Each ``*_images`` value is a list for a gallery component and each
        ``*_info`` value is a summary string. When a method fails, its image
        list is empty and its info string holds the error message.
    """

    def _pretrained():
        config = PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
        )
        return infer_pretrained(config, device=DEVICE)

    def _smc_grad():
        config = SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_grad_negative_prompt or "",
            ess_threshold=float(smc_grad_ess_threshold),
            partial_resampling=bool(smc_grad_partial_resampling),
            resample_frequency=int(smc_grad_resample_frequency),
            CFG=float(smc_grad_CFG),
            steps=int(smc_grad_steps),
            kl_weight=float(smc_grad_kl_weight),
            lambda_tempering=bool(smc_grad_lambda_tempering),
            lambda_one_at=float(smc_grad_lambda_one_at),
            num_particles=int(smc_grad_num_particles),
            use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
            phi=int(smc_grad_phi),
            tau=float(smc_grad_tau),
        )
        return infer_smc_grad(config, device=DEVICE)

    pretrained_gallery, pretrained_info = _run_inference_method(
        "Pretrained inference error", _pretrained
    )
    smc_grad_gallery, smc_grad_info = _run_inference_method(
        "SMC inference error", _smc_grad
    )
    return pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info
with gr.Blocks() as demo:
    # NOTE(review): the nesting below was reconstructed from a flattened paste;
    # confirm that gr.Examples sits outside the prompt row and that each info
    # textbox shares a column with its gallery.
    gr.Markdown("# Monetico — Multi-method Inference Playground")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter prompt here",
            value=examples[0],
            lines=1,
        )
        run_button = gr.Button("Run", variant="primary")
    gr.Examples(examples=examples, inputs=prompt)

    # --- Pretrained method: settings on the left, results on the right ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained method — settings", open=False):
                # Initial widget values are read from the config class attributes.
                pretrained_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=PretrainedInferenceConfig.negative_prompt,
                    lines=1,
                )
                pretrained_CFG = gr.Slider(
                    0.0, 30.0, step=0.1,
                    value=PretrainedInferenceConfig.CFG,
                    label="CFG",
                )
                pretrained_steps = gr.Slider(
                    1, 200, step=1,
                    value=PretrainedInferenceConfig.steps,
                    label="Steps",
                )
        with gr.Column(scale=2):
            pretrained_gallery = gr.Gallery(
                label="Pretrained outputs",
                show_label=True,
                elem_id="pretrained_gallery",
                height="70vw",
                columns=4,
            )
            pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)

    # --- SMC-grad method: settings on the left, results on the right ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_grad_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=SMCGradInferenceConfig.negative_prompt,
                    lines=1,
                )
                smc_grad_CFG = gr.Slider(
                    0.0, 30.0, step=0.1,
                    value=SMCGradInferenceConfig.CFG,
                    label="CFG",
                )
                smc_grad_steps = gr.Slider(
                    1, 200, step=1,
                    value=SMCGradInferenceConfig.steps,
                    label="Steps",
                )
                smc_grad_num_particles = gr.Slider(
                    1, 64, step=1,
                    value=SMCGradInferenceConfig.num_particles,
                    label="SMC Num particles",
                )
                smc_grad_ess_threshold = gr.Slider(
                    0.0, 1.0, step=0.01,
                    value=SMCGradInferenceConfig.ess_threshold,
                    label="ESS threshold",
                )
                smc_grad_partial_resampling = gr.Checkbox(
                    label="Partial resampling",
                    value=SMCGradInferenceConfig.partial_resampling,
                )
                smc_grad_resample_frequency = gr.Slider(
                    1, 50, step=1,
                    value=SMCGradInferenceConfig.resample_frequency,
                    label="Resample frequency",
                )
                smc_grad_kl_weight = gr.Slider(
                    0.0, 10.0, step=0.01,
                    value=SMCGradInferenceConfig.kl_weight,
                    label="KL weight",
                )
                smc_grad_lambda_tempering = gr.Checkbox(
                    label="Lambda tempering",
                    value=SMCGradInferenceConfig.lambda_tempering,
                )
                smc_grad_lambda_one_at = gr.Slider(
                    0.0, 1.0, step=0.01,
                    value=SMCGradInferenceConfig.lambda_one_at,
                    label="Lambda one at (fraction of steps)",
                )
                smc_grad_use_continuous_formulation = gr.Checkbox(
                    label="Use continuous formulation",
                    value=SMCGradInferenceConfig.use_continuous_formulation,
                )
                smc_grad_phi = gr.Slider(
                    1, 8, step=1,
                    value=SMCGradInferenceConfig.phi,
                    label="Phi",
                )
                smc_grad_tau = gr.Slider(
                    0.0, 1.0, step=0.001,
                    value=SMCGradInferenceConfig.tau,
                    label="Tau",
                )
        with gr.Column(scale=2):
            smc_grad_gallery = gr.Gallery(
                label="SMC-grad outputs",
                show_label=True,
                elem_id="smc_grad_gallery",
                height="70vw",
                columns=4,
            )
            smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False)

    # Component order must match the positional parameters of run_inference_all.
    runner_inputs = [
        prompt,
        pretrained_negative_prompt,
        pretrained_CFG,
        pretrained_steps,
        smc_grad_negative_prompt,
        smc_grad_CFG,
        smc_grad_steps,
        smc_grad_num_particles,
        smc_grad_ess_threshold,
        smc_grad_partial_resampling,
        smc_grad_resample_frequency,
        smc_grad_kl_weight,
        smc_grad_lambda_tempering,
        smc_grad_lambda_one_at,
        smc_grad_use_continuous_formulation,
        smc_grad_phi,
        smc_grad_tau,
    ]
    runner_outputs = [
        pretrained_gallery,
        pretrained_info,
        smc_grad_gallery,
        smc_grad_info,
    ]

    # The Run button and pressing Enter in the prompt box trigger the same runner.
    for register in (run_button.click, prompt.submit):
        register(fn=run_inference_all, inputs=runner_inputs, outputs=runner_outputs)
if __name__ == "__main__":
    # Only start the server when executed as a script; share=True asks Gradio
    # for a public share link in addition to the local URL.
    demo.launch(share=True)