# Source: ReSWD / src/gradio_demo/sw_guidance.py
# Last commit: d0dc81a "Updates for arXiv release" (mboss)
import os
from typing import Optional
import gradio as gr
import lightning as L
import numpy as np
import spaces
from PIL import Image
from src.sw_sdthree_guidance import create_pipeline
from src.sw_sdthree_guidance import run as sw_guidance_run
# Lower and upper bound of the SW guidance learning rate exposed in the UI.
preset_lrs = [1e-6, 1.0]
# The learning-rate slider works in log10 space; these are its end points.
log_lrs = np.log10(preset_lrs)

# Per-model presets, keyed by model identifier. "sw_u_lr" is stored as a
# log10 value so that it can be used directly as the log-scale slider value.
models = {
    "stabilityai/stable-diffusion-3.5-large": {
        "num_inference_steps": 30,
        "guidance_scale": 7.0,
        "sw_u_lr": np.log10(3.2e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.7,
    },
    "stabilityai/stable-diffusion-3.5-medium": {
        "num_inference_steps": 30,
        "guidance_scale": 3.5,
        "sw_u_lr": np.log10(3.2e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.7,
    },
    "stabilityai/stable-diffusion-3.5-large-turbo": {
        "num_inference_steps": 4,
        "guidance_scale": 1.0,
        "sw_u_lr": np.log10(5e-3),
        "sw_steps": 6,
        "cfg_rescale_phi": 0.65,
    },
}
def log_slider_to_lr(log_lr):
    """Convert a log10 slider position into a learning rate.

    The value ``10 ** log_lr`` is rendered in scientific notation with one
    digit after the decimal point and parsed back, so the returned float is
    rounded to two significant digits.
    """
    rendered = format(10 ** log_lr, ".1e")
    return float(rendered)
def create_sw_guidance(
    model_name: str = "stabilityai/stable-diffusion-3.5-large"
):
    """
    Creates the Gradio interface for SW guidance with SD3.5.

    The function emits Gradio components directly (heading, prompt and
    reference-image inputs, output image, cached examples, configuration
    accordions, and the generate/clear buttons) and wires them to a
    GPU-decorated callback that runs SW guidance through a pipeline created
    once per call of this function.

    NOTE(review): no ``gr.Blocks`` context is opened here, so this is
    presumably meant to be called from inside one -- confirm against caller.

    Args:
        model_name: The model to use for guidance. Must be a key of the
            module-level ``models`` preset table.

    Raises:
        KeyError: If ``model_name`` has no entry in ``models``.
    """
    # Resolve the preset first so that an unsupported model name fails before
    # any UI is emitted and before the pipeline is built.
    try:
        model_config = models[model_name]
    except KeyError:
        raise KeyError(
            f"Unsupported model {model_name!r}; expected one of {sorted(models)}"
        ) from None

    short_name = model_name.split("/")[-1]
    gr.Markdown(
        f"\n# SW Guidance with {short_name}\n"
        "Generates images using SW Guidance with a reference image and text prompt.\n"
    )

    # Created once here and shared by every callback invocation via closure.
    pipe = create_pipeline(
        model_name,
        device="cuda",
        compile=False,
    )

    @spaces.GPU
    def run_sw_guidance(
        num_inference_steps: int,
        num_guided_steps_perc: float,
        guidance_scale: float,
        sw_u_lr: float,
        sw_steps: int,
        num_projections: int,
        control_variates: str,
        distance: str,
        candidates_per_pass: int,
        subsampling_factor: int,
        sampling_mode: str,
        cfg_rescale_phi: float,
        prompt: str,
        reference_image: np.ndarray,
        seed: Optional[int] = None,
    ):
        """
        Runs the SW guidance process with the given parameters.

        Args:
            num_inference_steps: Forwarded as ``num_inference_steps``.
            num_guided_steps_perc: Fraction of the inference steps that are
                guided; the product with ``num_inference_steps`` is truncated
                to an integer.
            guidance_scale: Forwarded as ``guidance_scale``.
            sw_u_lr: Learning rate on the log10 slider scale; converted with
                ``log_slider_to_lr`` before being forwarded.
            sw_steps: Forwarded as ``sw_steps``.
            num_projections: Forwarded as ``num_projections``.
            control_variates: ``"Upper"`` enables ``use_ucv``, ``"Lower"``
                enables ``use_lcv``; any other value enables neither.
            distance: Forwarded as ``distance``.
            candidates_per_pass: Forwarded as ``num_new_candidates``.
            subsampling_factor: Forwarded as ``subsampling_factor``.
            sampling_mode: Forwarded as ``sampling_mode``.
            cfg_rescale_phi: Forwarded as ``cfg_rescale_phi``.
            prompt: Text prompt; must be non-empty.
            reference_image: Reference image as a numpy array; must not be
                ``None``.
            seed: Forwarded as ``seed``; ``None`` when the seed field is
                left empty.

        Returns:
            The result of ``sw_guidance_run`` converted with ``np.array``.

        Raises:
            gr.Error: If the reference image or the prompt is missing.
        """
        if reference_image is None:
            raise gr.Error("Please provide a reference image")
        if not prompt:
            raise gr.Error("Please provide a prompt")

        # Convert numpy array to PIL Image
        ref_img = Image.fromarray(reference_image)

        # Run SW guidance
        image = sw_guidance_run(
            prompt=prompt,
            reference_image=ref_img,
            model_path=model_name,
            num_inference_steps=num_inference_steps,
            # Truncates towards zero, e.g. 0.95 * 30 -> 28 guided steps.
            num_guided_steps=int(num_guided_steps_perc * num_inference_steps),
            guidance_scale=guidance_scale,
            # The slider operates in log10 space; convert back to a plain LR.
            sw_u_lr=log_slider_to_lr(sw_u_lr),
            sw_steps=sw_steps,
            height=1024,
            width=1024,
            device="cuda",
            num_projections=num_projections,
            use_ucv=control_variates == "Upper",
            use_lcv=control_variates == "Lower",
            distance=distance,
            num_new_candidates=candidates_per_pass,
            subsampling_factor=subsampling_factor,
            sampling_mode=sampling_mode,
            pipe=pipe,
            # NOTE(review): the pipeline above is created with compile=False
            # while the run requests compile=True -- confirm this is intended.
            compile=True,
            seed=seed,
            cfg_rescale_phi=cfg_rescale_phi,
        )
        return np.array(image)

    gr.Markdown("## Input")
    with gr.Row(equal_height=True):
        with gr.Column(variant="panel"):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=2,
            )
            reference_image = gr.Image(label="Reference Image", height=512)
        with gr.Column(variant="panel"):
            output_image = gr.Image(label="Generated Image", height=512)

    # (prompt, reference image file name) pairs offered as clickable examples.
    example_pairs = [
        ("A diver discovering an underwater city", "sunken_boat.jpg"),
        ("A raccoon in a forest", "waterfall.jpg"),
        ("A cat detective solving a mystery", "building.jpg"),
        ("A raccoon reading a book by candlelight.", "food_stand.jpg"),
        ("A lion meditating on a mountain", "mountain.jpg"),
        ("A squirrel kayaking", "lake_green.jpg"),
        ("A young dragon roasting marshmallows", "canyon.jpg"),
        ("A family picnic beneath floating lanterns", "lake_sunset.jpg"),
        ("Elephants holding umbrellas in a rainstorm", "fisher.jpg"),
        ("A boy and his dog exploring a crystal cave", "lake.jpg"),
        ("A snowman sharing cocoa with woodland animals", "snow.jpg"),
        ("A kitten exploring an antique library", "ornament.jpg"),
        ("Children riding flying bicycles over mountains", "sky_pier.jpg"),
        ("An ancient tree whispering stories to deer", "greenhouse_2.jpg"),
        ("Owls with monocles in treetops", "path.jpg"),
    ]

    # Model-independent defaults merged with the per-model preset. Together
    # they supply every keyword of run_sw_guidance except prompt,
    # reference_image and seed.
    default_config = {
        "num_guided_steps_perc": 0.95,
        "num_projections": 32,
        "control_variates": "None",
        "distance": "l1",
        "candidates_per_pass": 8,
        "subsampling_factor": 1,
        "sampling_mode": "gaussian",
    } | model_config

    def run_example(prompt: str, reference_image: np.ndarray):
        """Run SW guidance for an example using ``default_config`` only."""
        return run_sw_guidance(
            **default_config,
            prompt=prompt,
            reference_image=reference_image,
        )

    example_inputs = [
        [example_prompt, os.path.join("example", "guidance", img_file)]
        for example_prompt, img_file in example_pairs
    ]

    run_button = gr.Button("Generate Image", variant="primary")

    gr.Examples(
        examples=example_inputs,
        inputs=[prompt, reference_image],
        outputs=[output_image],
        fn=run_example,
        label="Prompt + Reference Image Examples",
        examples_per_page=5,
        cache_examples=True,
        cache_mode="lazy",
    )

    gr.Markdown(
        "\n# Configuration\n"
        "Adjust the parameters for the SW guidance process.\n"
    )

    with gr.Accordion("Basic Config", open=True):
        num_inference_steps_slider = gr.Slider(
            1,
            100,
            value=model_config["num_inference_steps"],
            step=1,
            label="Number of inference steps",
        )
        guidance_scale_slider = gr.Slider(
            0.0,
            20.0,
            value=model_config["guidance_scale"],
            step=0.1,
            label="Guidance scale",
        )
        cfg_rescale_phi_slider = gr.Slider(
            0.0,
            1.0,
            value=model_config["cfg_rescale_phi"],
            step=0.05,
            label="CFG Rescale Phi",
            info="Controls the rescaling of classifier-free guidance",
        )
        seed_input = gr.Number(
            value=lambda: None,
            label="Seed (leave empty for random)",
            precision=0,
            interactive=True,
        )
        with gr.Row(variant="panel", equal_height=True):
            # The slider holds log10(lr); the read-only textbox next to it
            # shows the corresponding learning rate.
            sw_u_lr_slider = gr.Slider(
                minimum=log_lrs.min(),
                maximum=log_lrs.max(),
                value=model_config["sw_u_lr"],
                step=0.05,
                label="SW guidance learning rate (Log scale)",
                interactive=True,
                scale=4,
            )
            lr_display = gr.Textbox(
                label="Learning Rate",
                value=f"{log_slider_to_lr(model_config['sw_u_lr']):.1e}",
                interactive=False,
                scale=1,
            )
        # Keep the displayed learning rate in sync with the slider.
        sw_u_lr_slider.change(
            lambda x: gr.update(value=f"{log_slider_to_lr(x):.1e}"),
            inputs=sw_u_lr_slider,
            outputs=lr_display,
            show_progress=False,
        )

    with gr.Accordion("Advanced Config", open=False):
        num_guided_steps_perc_slider = gr.Slider(
            0.0,
            1.0,
            value=default_config["num_guided_steps_perc"],
            step=0.05,
            label="Percentage of steps to apply SW guidance",
        )
        sw_steps_slider = gr.Slider(
            0,
            32,
            value=model_config["sw_steps"],
            step=1,
            label="Number of SW guidance steps, 0 means no SW guidance",
        )
        num_projections_slider = gr.Slider(
            16,
            1024,
            value=default_config["num_projections"],
            step=16,
            label="Number of projections",
        )
        control_variates_dropdown = gr.Dropdown(
            choices=["None", "Lower", "Upper"],
            value=default_config["control_variates"],
            label="Control Variates",
            info="Select which control variates to use for optimization",
        )
        distance_dropdown = gr.Dropdown(
            choices=["l1", "l2"],
            value=default_config["distance"],
            label="Distance metric",
            info="Select which distance metric to use",
        )
        candidates_per_pass_slider = gr.Slider(
            0,
            64,
            value=default_config["candidates_per_pass"],
            step=1,
            label="Number of new candidates per pass. 0 means no reservoir sampling",
        )
        subsampling_factor_slider = gr.Slider(
            1,
            16,
            value=default_config["subsampling_factor"],
            step=1,
            label="Subsampling factor",
        )
        # NOTE(review): this is the only advanced widget that does not take
        # its initial value from default_config, which says "gaussian" (the
        # value used by the cached examples). Confirm which default is
        # intended before unifying them.
        sampling_mode_dropdown = gr.Dropdown(
            choices=["gaussian", "qmc"],
            value="qmc",
            label="Sampling Mode",
            info="Select which sampling mode to use for projections",
        )

    # The input order must match the positional parameters of run_sw_guidance.
    run_button.click(
        run_sw_guidance,
        inputs=[
            num_inference_steps_slider,
            num_guided_steps_perc_slider,
            guidance_scale_slider,
            sw_u_lr_slider,
            sw_steps_slider,
            num_projections_slider,
            control_variates_dropdown,
            distance_dropdown,
            candidates_per_pass_slider,
            subsampling_factor_slider,
            sampling_mode_dropdown,
            cfg_rescale_phi_slider,
            prompt,
            reference_image,
            seed_input,
        ],
        outputs=[output_image],
    )

    clear_button = gr.ClearButton([prompt, reference_image, output_image, seed_input])
    # Extra click listener with no inputs/outputs; kept as in the original.
    clear_button.click(lambda: None)