import os
from typing import Optional

import gradio as gr
import lightning as L
import numpy as np
import spaces
from PIL import Image

from src.sw_sdthree_guidance import create_pipeline
from src.sw_sdthree_guidance import run as sw_guidance_run

# Bounds of the SW learning-rate slider. The slider works in log10 space,
# so the bounds are converted once here.
preset_lrs = [1e-6, 1.0]
log_lrs = np.log10(preset_lrs)

# Per-model UI defaults. "sw_u_lr" is stored as log10 of the learning rate
# (matching the log-scale slider); convert it with log_slider_to_lr() before
# passing it to the guidance run.
models = {
    "stabilityai/stable-diffusion-3.5-large": {
        "num_inference_steps": 30,
        "guidance_scale": 7.0,
        "sw_u_lr": np.log10(3.2e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.7,
    },
    "stabilityai/stable-diffusion-3.5-medium": {
        "num_inference_steps": 30,
        "guidance_scale": 3.5,
        "sw_u_lr": np.log10(3.2e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.7,
    },
    "stabilityai/stable-diffusion-3.5-large-turbo": {
        "num_inference_steps": 4,
        "guidance_scale": 1.0,
        "sw_u_lr": np.log10(5e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.65,
    },
}


def log_slider_to_lr(log_lr):
    """Convert a log10 slider value to a learning rate with 2 significant digits.

    The value is rounded through ``.1e`` formatting, e.g. ``-2.5`` ->
    ``10**-2.5 ~= 3.16e-3`` -> ``3.2e-3``.
    """
    return float(f"{10**log_lr:.1e}")


def create_sw_guidance(
    model_name: str = "stabilityai/stable-diffusion-3.5-large",
):
    """
    Creates the Gradio interface for SW guidance with SD3.5.

    Builds the input/output panels, the cached prompt + reference-image
    examples and the configuration controls, then wires the "Generate Image"
    button to the SW guidance run. Components and event listeners are created
    without an enclosing ``gr.Blocks`` here, so this function is meant to be
    called inside one.

    Args:
        model_name: Model id to use for guidance; must be a key of ``models``.

    Raises:
        KeyError: If ``model_name`` has no entry in ``models``.
    """
    # Resolve the per-model defaults before loading the pipeline so an
    # unsupported model name fails immediately rather than after the
    # (expensive) pipeline load.
    model_config = models[model_name]

    gr.Markdown(
        f"""
        # SW Guidance with {model_name.split('/')[-1]}
        Generates images using SW Guidance with a reference image and text prompt.
        """
    )

    pipe = create_pipeline(
        model_name,
        device="cuda",
        compile=False,
    )

    @spaces.GPU
    def run_sw_guidance(
        num_inference_steps: int,
        num_guided_steps_perc: float,
        guidance_scale: float,
        sw_u_lr: float,
        sw_steps: int,
        num_projections: int,
        control_variates: str,
        distance: str,
        candidates_per_pass: int,
        subsampling_factor: int,
        sampling_mode: str,
        cfg_rescale_phi: float,
        prompt: str,
        reference_image: np.ndarray,
        seed: Optional[int] = None,
    ):
        """
        Runs the SW guidance process with the given parameters.

        The positional order matches the ``inputs`` list wired to the
        "Generate Image" button below.

        Args:
            num_inference_steps: Total number of inference steps.
            num_guided_steps_perc: Fraction of the inference steps to apply
                SW guidance to; multiplied by ``num_inference_steps`` and
                truncated to an int.
            guidance_scale: Classifier-free guidance scale.
            sw_u_lr: SW guidance learning rate in log10 space (slider value).
            sw_steps: Number of SW guidance steps; 0 means no SW guidance.
            num_projections: Number of projections.
            control_variates: "None", "Lower" or "Upper".
            distance: Distance metric, "l1" or "l2".
            candidates_per_pass: New candidates per pass; 0 means no
                reservoir sampling.
            subsampling_factor: Subsampling factor.
            sampling_mode: Projection sampling mode, "gaussian" or "qmc".
            cfg_rescale_phi: Rescaling factor for classifier-free guidance.
            prompt: Text prompt.
            reference_image: Reference image as a numpy array.
            seed: Seed; ``None`` when the seed field is left empty.

        Returns:
            The generated image as a numpy array.

        Raises:
            gr.Error: If the reference image or the prompt is missing.
        """
        if reference_image is None:
            raise gr.Error("Please provide a reference image")
        if not prompt:
            raise gr.Error("Please provide a prompt")

        # Convert numpy array to PIL Image
        ref_img = Image.fromarray(reference_image)

        # Run SW guidance at a fixed 1024x1024 output resolution.
        image = sw_guidance_run(
            prompt=prompt,
            reference_image=ref_img,
            model_path=model_name,
            num_inference_steps=num_inference_steps,
            num_guided_steps=int(num_guided_steps_perc * num_inference_steps),
            guidance_scale=guidance_scale,
            sw_u_lr=log_slider_to_lr(sw_u_lr),
            sw_steps=sw_steps,
            height=1024,
            width=1024,
            device="cuda",
            num_projections=num_projections,
            use_ucv=control_variates == "Upper",
            use_lcv=control_variates == "Lower",
            distance=distance,
            num_new_candidates=candidates_per_pass,
            subsampling_factor=subsampling_factor,
            sampling_mode=sampling_mode,
            pipe=pipe,
            compile=True,
            seed=seed,
            cfg_rescale_phi=cfg_rescale_phi,
        )
        return np.array(image)

    gr.Markdown("## Input")
    with gr.Row(equal_height=True):
        with gr.Column(variant="panel"):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=2,
            )
            reference_image = gr.Image(label="Reference Image", height=512)
        with gr.Column(variant="panel"):
            output_image = gr.Image(label="Generated Image", height=512)

    example_pairs = [
        ("A diver discovering an underwater city", "sunken_boat.jpg"),
        ("A raccoon in a forest", "waterfall.jpg"),
        ("A cat detective solving a mystery", "building.jpg"),
        ("A raccoon reading a book by candlelight.", "food_stand.jpg"),
        ("A lion meditating on a mountain", "mountain.jpg"),
        ("A squirrel kayaking", "lake_green.jpg"),
        ("A young dragon roasting marshmallows", "canyon.jpg"),
        ("A family picnic beneath floating lanterns", "lake_sunset.jpg"),
        ("Elephants holding umbrellas in a rainstorm", "fisher.jpg"),
        ("A boy and his dog exploring a crystal cave", "lake.jpg"),
        ("A snowman sharing cocoa with woodland animals", "snow.jpg"),
        ("A kitten exploring an antique library", "ornament.jpg"),
        ("Children riding flying bicycles over mountains", "sky_pier.jpg"),
        ("An ancient tree whispering stories to deer", "greenhouse_2.jpg"),
        ("Owls with monocles in treetops", "path.jpg"),
    ]

    # Settings used for the cached examples and as the widget defaults below.
    # Model-specific values take precedence (right operand of `|`).
    default_config = {
        "num_guided_steps_perc": 0.95,
        "num_projections": 32,
        "control_variates": "None",
        "distance": "l1",
        "candidates_per_pass": 8,
        "subsampling_factor": 1,
        "sampling_mode": "gaussian",
    } | model_config

    def run_example(prompt: str, reference_image: np.ndarray):
        """Run SW guidance for an example using ``default_config`` settings."""
        return run_sw_guidance(
            **default_config,
            prompt=prompt,
            reference_image=reference_image,
        )

    example_inputs = [
        [example_prompt, os.path.join("example", "guidance", img_file)]
        for example_prompt, img_file in example_pairs
    ]

    run_button = gr.Button("Generate Image", variant="primary")

    gr.Examples(
        examples=example_inputs,
        inputs=[prompt, reference_image],
        outputs=[output_image],
        fn=run_example,
        label="Prompt + Reference Image Examples",
        examples_per_page=5,
        cache_examples=True,
        cache_mode="lazy",
    )

    gr.Markdown(
        """
        # Configuration
        Adjust the parameters for the SW guidance process.
        """
    )

    with gr.Accordion("Basic Config", open=True):
        num_inference_steps_slider = gr.Slider(
            1,
            100,
            value=model_config["num_inference_steps"],
            step=1,
            label="Number of inference steps",
        )
        guidance_scale_slider = gr.Slider(
            0.0,
            20.0,
            value=model_config["guidance_scale"],
            step=0.1,
            label="Guidance scale",
        )
        cfg_rescale_phi_slider = gr.Slider(
            0.0,
            1.0,
            value=model_config["cfg_rescale_phi"],
            step=0.05,
            label="CFG Rescale Phi",
            info="Controls the rescaling of classifier-free guidance",
        )
        seed_input = gr.Number(
            value=lambda: None,
            label="Seed (leave empty for random)",
            precision=0,
            interactive=True,
        )
        with gr.Row(variant="panel", equal_height=True):
            # The slider is in log10 space; the read-only textbox next to it
            # shows the actual learning rate.
            sw_u_lr_slider = gr.Slider(
                minimum=log_lrs.min(),
                maximum=log_lrs.max(),
                value=model_config["sw_u_lr"],
                step=0.05,
                label="SW guidance learning rate (Log scale)",
                interactive=True,
                scale=4,
            )
            lr_display = gr.Textbox(
                label="Learning Rate",
                value=f"{log_slider_to_lr(model_config['sw_u_lr']):.1e}",
                interactive=False,
                scale=1,
            )
        sw_u_lr_slider.change(
            lambda x: gr.update(value=f"{log_slider_to_lr(x):.1e}"),
            inputs=sw_u_lr_slider,
            outputs=lr_display,
            show_progress=False,
        )

    with gr.Accordion("Advanced Config", open=False):
        num_guided_steps_perc_slider = gr.Slider(
            0.0,
            1.0,
            value=default_config["num_guided_steps_perc"],
            step=0.05,
            label="Percentage of steps to apply SW guidance",
        )
        sw_steps_slider = gr.Slider(
            0,
            32,
            value=model_config["sw_steps"],
            step=1,
            label="Number of SW guidance steps, 0 means no SW guidance",
        )
        num_projections_slider = gr.Slider(
            16,
            1024,
            value=default_config["num_projections"],
            step=16,
            label="Number of projections",
        )
        control_variates_dropdown = gr.Dropdown(
            choices=["None", "Lower", "Upper"],
            value=default_config["control_variates"],
            label="Control Variates",
            info="Select which control variates to use for optimization",
        )
        distance_dropdown = gr.Dropdown(
            choices=["l1", "l2"],
            value=default_config["distance"],
            label="Distance metric",
            info="Select which distance metric to use",
        )
        candidates_per_pass_slider = gr.Slider(
            0,
            64,
            value=default_config["candidates_per_pass"],
            step=1,
            label="Number of new candidates per pass. 0 means no reservoir sampling",
        )
        subsampling_factor_slider = gr.Slider(
            1,
            16,
            value=default_config["subsampling_factor"],
            step=1,
            label="Subsampling factor",
        )
        # Read the default from default_config like every other widget, so
        # the UI default matches the settings used for the cached examples.
        sampling_mode_dropdown = gr.Dropdown(
            choices=["gaussian", "qmc"],
            value=default_config["sampling_mode"],
            label="Sampling Mode",
            info="Select which sampling mode to use for projections",
        )

    # Order must match the positional parameters of run_sw_guidance.
    run_button.click(
        run_sw_guidance,
        inputs=[
            num_inference_steps_slider,
            num_guided_steps_perc_slider,
            guidance_scale_slider,
            sw_u_lr_slider,
            sw_steps_slider,
            num_projections_slider,
            control_variates_dropdown,
            distance_dropdown,
            candidates_per_pass_slider,
            subsampling_factor_slider,
            sampling_mode_dropdown,
            cfg_rescale_phi_slider,
            prompt,
            reference_image,
            seed_input,
        ],
        outputs=[output_image],
    )

    # ClearButton resets the listed components by itself; no extra handler
    # is needed.
    gr.ClearButton([prompt, reference_image, output_image, seed_input])