Spaces:
Sleeping
Sleeping
import os

# Debug-oriented CUDA settings. They are written to the environment before
# `import torch` so they are already present when torch / CUDA initialize.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous.
# That helps pinpoint CUDA errors but slows inference; confirm it should stay on.
# NOTE(review): TORCH_USE_CUDA_DSA looks like a build-time PyTorch option;
# confirm that setting it at runtime has any effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
import torch
import traceback
import gradio as gr

# Inference entry points for the three methods and their config dataclasses.
# Adjust the import path if the inference module lives elsewhere.
from src.smc.inference import (
    infer_pretrained,
    infer_smc_grad,
    infer_ft,
    PretrainedInferenceConfig,
    SMCGradInferenceConfig,
    FTInferenceConfig,
)
from run_examples import get_out_if_exists

# Height passed to every output gr.Gallery in the layout.
GALLERY_HEIGHT = "224px"
def get_device():
    """Return the device string for the next inference call.

    When CUDA is not visible to this process, ``"cuda"`` is returned as-is
    (on a ZeroGPU Space the GPU is attached later, at call time). Otherwise
    devices are handed out round-robin as ``"cuda:<i>"``.

    The index of the most recently handed-out device is stored on the function
    object itself (``get_device.last_allocated``); it starts at -1.
    """
    previous = getattr(get_device, "last_allocated", -1)
    get_device.last_allocated = previous  # type: ignore[attr-defined]

    if not torch.cuda.is_available():
        # GPU will be dynamically allocated later using spaces ZeroGPU.
        return "cuda"

    chosen = (previous + 1) % torch.cuda.device_count()
    get_device.last_allocated = chosen  # type: ignore[attr-defined]
    return f"cuda:{chosen}"
# Example prompts offered under the prompt box; the first entry is also the
# prompt textbox's initial value.
examples = [
    "A photo of a yellow bird and a black motorcycle",
    "A stylish dog wearing sunglasses",
    "A cat in the style of Van Gogh’s Starry Night",
]
def _format_inference_output(out) -> str:
    """Return a one-line summary of an inference result for the UI.

    Args:
        out: An inference result exposing ``image_rewards`` and
            ``gpu_mem_used`` attributes, or ``None``.

    Returns:
        ``"No output"`` when ``out`` is ``None``. Otherwise
        ``"Rewards: <rewards> | GPU mem (GB): <mem>"``, where ``mem`` is shown
        with three decimals if it supports float formatting and via ``str()``
        otherwise (e.g. ``None``). If the attributes are missing or formatting
        fails, returns ``"Could not parse inference output"``; this function
        never raises.
    """
    if out is None:
        return "No output"
    try:
        rewards = out.image_rewards
        mem = out.gpu_mem_used
        try:
            mem_text = f"{mem:.3f}"
        except (TypeError, ValueError):
            # A non-numeric memory value (e.g. None) must not hide the rewards.
            mem_text = str(mem)
        return f"Rewards: {rewards} | GPU mem (GB): {mem_text}"
    except Exception:
        # Best effort: summary formatting must never break a UI callback.
        return "Could not parse inference output"
def _load_saved_output(method, cfg_cls, prompt):
    """Return ``(gallery, info)`` for one method's saved output for ``prompt``.

    Builds ``cfg_cls(prompt=prompt)`` and asks ``get_out_if_exists(method, cfg)``
    for a cached result. Returns:
      * the saved images and their summary when a result exists;
      * ``[]`` and ``"No saved output"`` when nothing is cached;
      * ``[]`` and ``"Error checking saved outputs"`` if building the config or
        the lookup raises. The traceback is printed so the UI callback itself
        never fails.
    """
    try:
        out = get_out_if_exists(method, cfg_cls(prompt=prompt))
        if out is None:
            return [], "No saved output"
        return out.images, _format_inference_output(out)
    except Exception:
        traceback.print_exc()
        return [], "Error checking saved outputs"


def try_load_saved_outputs(prompt):
    """
    Check for saved outputs for the given prompt for each method and return
    (pretrained_gallery, pretrained_info, smc_gallery, smc_info, ft_gallery, ft_info).

    If no saved output exists for a method, that method gets an empty gallery
    and "No saved output" as its info. Each method is looked up independently,
    so an error for one method ("Error checking saved outputs") does not blank
    the results of the others.
    """
    results = []
    for method, cfg_cls in (
        ("pretrained", PretrainedInferenceConfig),
        ("smc_grad", SMCGradInferenceConfig),
        ("ft", FTInferenceConfig),
    ):
        results.extend(_load_saved_output(method, cfg_cls, prompt))
    return tuple(results)
| # --- Per-method runner functions --- | |
def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
    """Run the pretrained inference method and return ``(gallery, info)``.

    Args:
        prompt: Text prompt from the UI.
        pretrained_negative_prompt: Negative prompt; ``None``/empty becomes ``""``.
        pretrained_CFG: CFG scale from the slider, coerced to ``float``.
        pretrained_steps: Number of steps from the slider, coerced to ``int``.

    Returns:
        ``(images, summary)`` on success. On any exception, an empty gallery and
        an error message. The message is also shown to the user via
        ``gr.Warning``, because the info textbox is created with
        ``visible=False``.
    """
    try:
        pretrained_cfg = PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
        )
        out = infer_pretrained(pretrained_cfg, device=get_device())
        gallery = out.images
        info = _format_inference_output(out)
        return gallery, info
    except Exception as e:
        traceback.print_exc()
        err_msg = f"Pretrained inference error: {e}"
        # Gallery items are interpreted as images (a string as a file path/URL),
        # so an error string is not a valid gallery value. Clear the gallery and
        # surface the message as a toast instead.
        gr.Warning(err_msg)
        return [], err_msg
def run_smc_grad_ui(
    prompt,
    smc_grad_negative_prompt,
    smc_grad_CFG,
    smc_grad_steps,
    smc_grad_num_particles,
    smc_grad_ess_threshold,
    smc_grad_partial_resampling,
    smc_grad_resample_frequency,
    smc_grad_kl_weight,
    smc_grad_lambda_tempering,
    smc_grad_lambda_one_at,
    smc_grad_use_continuous_formulation,
    smc_grad_phi,
    smc_grad_tau,
):
    """Run the SMC-grad inference method and return ``(gallery, info)``.

    All arguments are raw widget values from the UI. They are coerced here to
    the types ``SMCGradInferenceConfig`` is built with (``float``/``int``/
    ``bool``), and an empty/``None`` negative prompt becomes ``""``.

    Returns:
        ``(images, summary)`` on success. On any exception, an empty gallery and
        an error message. The message is also shown to the user via
        ``gr.Warning``, because the info textbox is created with
        ``visible=False``.
    """
    try:
        smc_grad_cfg = SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_grad_negative_prompt or "",
            ess_threshold=float(smc_grad_ess_threshold),
            partial_resampling=bool(smc_grad_partial_resampling),
            resample_frequency=int(smc_grad_resample_frequency),
            CFG=float(smc_grad_CFG),
            steps=int(smc_grad_steps),
            kl_weight=float(smc_grad_kl_weight),
            lambda_tempering=bool(smc_grad_lambda_tempering),
            lambda_one_at=float(smc_grad_lambda_one_at),
            num_particles=int(smc_grad_num_particles),
            use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
            phi=int(smc_grad_phi),
            tau=float(smc_grad_tau),
        )
        out = infer_smc_grad(smc_grad_cfg, device=get_device())
        gallery = out.images
        info = _format_inference_output(out)
        return gallery, info
    except Exception as e:
        traceback.print_exc()
        err_msg = f"SMC-grad inference error: {e}"
        # Gallery items are interpreted as images (a string as a file path/URL),
        # so an error string is not a valid gallery value. Clear the gallery and
        # surface the message as a toast instead.
        gr.Warning(err_msg)
        return [], err_msg
def run_ft_ui(prompt, ft_negative_prompt, ft_CFG, ft_steps):
    """Run the finetuned model inference and return ``(gallery, info)``.

    Args:
        prompt: Text prompt from the UI.
        ft_negative_prompt: Negative prompt; ``None``/empty becomes ``""``.
        ft_CFG: CFG scale from the slider, coerced to ``float``.
        ft_steps: Number of steps from the slider, coerced to ``int``.

    Returns:
        ``(images, summary)`` on success. On any exception, an empty gallery and
        an error message. The message is also shown to the user via
        ``gr.Warning``, because the info textbox is created with
        ``visible=False``.
    """
    try:
        ft_cfg = FTInferenceConfig(
            prompt=prompt,
            negative_prompt=ft_negative_prompt or "",
            CFG=float(ft_CFG),
            steps=int(ft_steps),
        )
        out = infer_ft(ft_cfg, device=get_device())
        gallery = out.images
        info = _format_inference_output(out)
        return gallery, info
    except Exception as e:
        traceback.print_exc()
        err_msg = f"FT inference error: {e}"
        # Gallery items are interpreted as images (a string as a file path/URL),
        # so an error string is not a valid gallery value. Clear the gallery and
        # surface the message as a toast instead.
        gr.Warning(err_msg)
        return [], err_msg
def mark_all_running():
    """Placeholder updates applied immediately when a run is triggered.

    Returns six updates, in the order (gallery, info) x (pretrained, SMC-grad,
    finetuned): each gallery is cleared and each info textbox is set to
    "Running..." and made non-interactive. It does no heavy work, so it returns
    quickly while the inference handlers queue up.

    NOTE(review): the info textboxes are created with visible=False, so only
    the gallery clearing is actually visible to the user.
    """
    busy = gr.update(value="Running...", interactive=False)
    cleared = gr.update(value=[])
    # Repeat the (gallery, info) pair once per method.
    return (cleared, busy) * 3
with gr.Blocks() as demo:
    gr.Markdown("# Prompt alignment for Meissonic using SMC")

    # Prompt entry plus the single button that launches every method.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", value=examples[0], lines=1)
        run_button = gr.Button("Run", variant="primary")
    examples_widget = gr.Examples(examples=examples, inputs=prompt)

    # --- Pretrained method row: settings (left) and outputs (right) ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained model — settings", open=False):
                pretrained_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=PretrainedInferenceConfig.negative_prompt,
                    lines=1,
                )
                pretrained_CFG = gr.Slider(
                    0.0, 30.0, step=0.1, value=PretrainedInferenceConfig.CFG, label="CFG"
                )
                pretrained_steps = gr.Slider(
                    1, 200, step=1, value=PretrainedInferenceConfig.steps, label="Steps"
                )
        with gr.Column(scale=2):
            pretrained_gallery = gr.Gallery(
                label="Pretrained model outputs",
                show_label=True,
                elem_id="pretrained_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            pretrained_info = gr.Textbox(label="Pretrained info", interactive=False, visible=False)

    # --- SMC-grad method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_grad_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=SMCGradInferenceConfig.negative_prompt,
                    lines=1,
                )
                smc_grad_CFG = gr.Slider(
                    0.0, 30.0, step=0.1, value=SMCGradInferenceConfig.CFG, label="CFG"
                )
                smc_grad_steps = gr.Slider(
                    1, 200, step=1, value=SMCGradInferenceConfig.steps, label="Steps"
                )
                smc_grad_num_particles = gr.Slider(
                    1, 64, step=1, value=SMCGradInferenceConfig.num_particles, label="SMC Num particles"
                )
                smc_grad_ess_threshold = gr.Slider(
                    0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.ess_threshold, label="ESS threshold"
                )
                smc_grad_partial_resampling = gr.Checkbox(
                    label="Partial resampling", value=SMCGradInferenceConfig.partial_resampling
                )
                smc_grad_resample_frequency = gr.Slider(
                    1, 50, step=1, value=SMCGradInferenceConfig.resample_frequency, label="Resample frequency"
                )
                smc_grad_kl_weight = gr.Slider(
                    0.0, 10.0, step=0.01, value=SMCGradInferenceConfig.kl_weight, label="KL weight"
                )
                smc_grad_lambda_tempering = gr.Checkbox(
                    label="Lambda tempering", value=SMCGradInferenceConfig.lambda_tempering
                )
                smc_grad_lambda_one_at = gr.Slider(
                    0.0,
                    1.0,
                    step=0.01,
                    value=SMCGradInferenceConfig.lambda_one_at,
                    label="Lambda one at (fraction of steps)",
                )
                smc_grad_use_continuous_formulation = gr.Checkbox(
                    label="Use continuous formulation",
                    value=SMCGradInferenceConfig.use_continuous_formulation,
                )
                smc_grad_phi = gr.Slider(1, 8, step=1, value=SMCGradInferenceConfig.phi, label="Phi")
                smc_grad_tau = gr.Slider(0.0, 1.0, step=0.001, value=SMCGradInferenceConfig.tau, label="Tau")
        with gr.Column(scale=2):
            smc_grad_gallery = gr.Gallery(
                label="SMC-grad outputs",
                show_label=True,
                elem_id="smc_grad_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False, visible=False)

    # --- Finetuned (FT) method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Finetuned model — settings", open=False):
                ft_negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=FTInferenceConfig.negative_prompt,
                    lines=1,
                )
                ft_CFG = gr.Slider(0.0, 30.0, step=0.1, value=FTInferenceConfig.CFG, label="CFG")
                ft_steps = gr.Slider(1, 200, step=1, value=FTInferenceConfig.steps, label="Steps")
        with gr.Column(scale=2):
            ft_gallery = gr.Gallery(
                label="Finetuned model outputs",
                show_label=True,
                elem_id="ft_gallery",
                height=GALLERY_HEIGHT,
                columns=4,
                object_fit="contain",
            )
            ft_info = gr.Textbox(label="Finetuned info", interactive=False, visible=False)

    # --- Wiring ---
    # Every (gallery, info) pair, in the order used by mark_all_running and
    # try_load_saved_outputs.
    all_outputs = [pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info]
    pretrained_inputs = [prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps]
    smc_grad_inputs = [
        prompt,
        smc_grad_negative_prompt,
        smc_grad_CFG,
        smc_grad_steps,
        smc_grad_num_particles,
        smc_grad_ess_threshold,
        smc_grad_partial_resampling,
        smc_grad_resample_frequency,
        smc_grad_kl_weight,
        smc_grad_lambda_tempering,
        smc_grad_lambda_one_at,
        smc_grad_use_continuous_formulation,
        smc_grad_phi,
        smc_grad_tau,
    ]
    ft_inputs = [prompt, ft_negative_prompt, ft_CFG, ft_steps]

    # Clicking Run and pressing Enter in the prompt box register the same four
    # handlers, in the same order:
    #   1) a quick placeholder update so the UI reacts immediately;
    #   2) one separate listener per method, each updating only its own
    #      gallery/info pair, so the methods can finish independently under the
    #      queue.
    for trigger in (run_button.click, prompt.submit):
        trigger(fn=mark_all_running, inputs=[], outputs=all_outputs)
        trigger(fn=run_pretrained_ui, inputs=pretrained_inputs, outputs=[pretrained_gallery, pretrained_info])
        trigger(fn=run_smc_grad_ui, inputs=smc_grad_inputs, outputs=[smc_grad_gallery, smc_grad_info])
        trigger(fn=run_ft_ui, inputs=ft_inputs, outputs=[ft_gallery, ft_info])

    # After an example prompt is loaded into the textbox, show any saved outputs.
    examples_widget.load_input_event.then(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=all_outputs,
    )
    # Same lookup once on page load, for the initial prompt (examples[0]).
    demo.load(
        fn=try_load_saved_outputs,
        inputs=[prompt],
        outputs=all_outputs,
    )
# Enable Gradio's queue; it must be configured before launch().
# default_concurrency_limit is the concurrency limit used by every event
# listener that does not set its own concurrency_limit. Each method above is
# wired as its own listener, so this value caps how many runs of the *same*
# method (e.g. from different users) may execute at once. It is not a count of
# methods. Lower it if concurrent runs on one GPU exhaust memory; tune max_size
# to bound the waiting queue.
# NOTE(review): share=True is presumably ignored when hosted on Hugging Face
# Spaces; confirm it is only intended for local runs.
demo.queue(default_concurrency_limit=3)

if __name__ == "__main__":
    demo.launch(share=True)