# Page-header residue: nroggendorff's picture | "Update app.py" | cd3ae09 (verified)
import random
import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
# Model identifier passed to StableDiffusionXLPipeline.from_pretrained.
MODEL_ID = "glides/illustriousxl"
# Location handed to load_lora_weights; each adapter is loaded from it using
# the weight file name "<segment>.safetensors".
ADAPTER_BASE_PATH = "./creative-lora"
# Adapter names, one LoRA per schedule segment. Weight lists passed to
# set_adapters are index-aligned with this list, so the order matters.
ALL_SEGMENTS = ["early", "mid", "late"]
# Number of denoising steps requested from the pipeline on every generation.
NUM_INFERENCE_STEPS = 30
# Lengths, in steps, of the first two segments; the last segment covers
# every step from the second boundary onwards.
EARLY_SEG = 10
MID_SEG = 10
# Step indices at which the schedule hands over early -> mid and mid -> late.
_BOUNDARIES = [EARLY_SEG, EARLY_SEG + MID_SEG]
# Half-width, in steps, of the cross-fade window centred on each boundary.
_BLEND_HALF = 2
def _adapter_weights(
    step_index: int,
    strength: float,
    *,
    boundaries: "list[int] | None" = None,
    blend_half: "int | None" = None,
) -> list[float]:
    """Return the per-segment LoRA weights for one denoising step.

    The schedule is split into ``len(boundaries) + 1`` consecutive segments.
    Away from a boundary, the segment containing ``step_index`` gets the full
    ``strength`` and every other segment gets ``0.0``. Within ``blend_half``
    steps of a boundary, the two neighbouring segments are linearly
    cross-faded so that their weights always sum to ``strength``.

    Args:
        step_index: Index of the denoising step the weights are for.
        strength: Total LoRA weight to distribute across the segments.
        boundaries: Ascending step indices at which one segment hands over to
            the next. Defaults to the module-level ``_BOUNDARIES``.
        blend_half: Half-width, in steps, of the cross-fade window around
            each boundary. ``0`` (or a negative value) disables blending and
            gives a hard switch. Defaults to the module-level ``_BLEND_HALF``.

    Returns:
        A new list with one weight per segment, ordered like the segments.
    """
    if boundaries is None:
        boundaries = _BOUNDARIES
    if blend_half is None:
        blend_half = _BLEND_HALF

    weights = [0.0] * (len(boundaries) + 1)

    # Guarding on blend_half > 0 avoids a division by zero at a boundary step
    # when cross-fading is switched off.
    if blend_half > 0:
        for i, boundary in enumerate(boundaries):
            dist = step_index - boundary
            if abs(dist) <= blend_half:
                # t runs from 0.0 at the start of the window to 1.0 at its end.
                t = (dist + blend_half) / (2 * blend_half)
                weights[i] = (1.0 - t) * strength
                weights[i + 1] = t * strength
                # If two windows overlap, the earlier boundary wins.
                return weights

    # Outside every blend window: the active segment is the number of
    # boundaries already reached (boundaries are assumed to be ascending).
    segment = sum(step_index >= boundary for boundary in boundaries)
    weights[segment] = strength
    return weights
# Build the SDXL pipeline once at import time with float16 weights, then move
# it to the CUDA device.
pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Register one named LoRA adapter per schedule segment; the adapter name and
# the weight file stem are both the segment name.
for segment in ALL_SEGMENTS:
    pipe.load_lora_weights(
        ADAPTER_BASE_PATH, weight_name=f"{segment}.safetensors", adapter_name=segment
    )
@spaces.GPU
def generate(prompt, negative_prompt, guidance, strength, seed, width, height):
    """Generate one image while re-weighting the segment LoRAs at every step.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        guidance: Value forwarded as ``guidance_scale``.
        strength: Total LoRA weight distributed across the segment adapters.
        seed: Seed for the CUDA random generator; converted with ``int``.
        width: Output width in pixels; converted with ``int``.
        height: Output height in pixels; converted with ``int``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    seed = int(seed)
    # Slider values may arrive as floats; the pipeline needs integer sizes.
    width = int(width)
    height = int(height)
    last_step = NUM_INFERENCE_STEPS - 1

    pipe.enable_lora()
    # Weights in effect while step 0 runs.
    pipe.set_adapters(ALL_SEGMENTS, _adapter_weights(0, strength))

    def callback(p, step_index, timestep, callback_kwargs):
        """Install the adapter weights for the step that runs next."""
        if step_index == last_step:
            # No denoising step follows; zero every adapter.
            p.set_adapters(ALL_SEGMENTS, [0.0] * len(ALL_SEGMENTS))
        else:
            # This hook fires after step `step_index` has finished, so the
            # weights set here apply to step `step_index + 1`.
            p.set_adapters(
                ALL_SEGMENTS, _adapter_weights(step_index + 1, strength)
            )
        return callback_kwargs

    generator = torch.Generator(device="cuda").manual_seed(seed)
    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=guidance,
            generator=generator,
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=["latents"],
        )
        image = result.images[0]
        del result
        return image
    finally:
        # Runs on success and on failure: turn the LoRA layers back off and
        # release cached CUDA memory.
        pipe.disable_lora()
        torch.cuda.empty_cache()
# Gradio UI: prompt inputs and output image on top, an "Advanced Settings"
# accordion below, and the click handlers that drive `generate`.
# NOTE(review): the Row/Column nesting was reconstructed from a copy whose
# indentation had been stripped -- confirm the layout against the running app.
with gr.Blocks() as interface:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    info="What do you want?",
                    value="A woman with long, wavy pink hair is shown in profile",
                    lines=4,
                    interactive=True,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    info="What do you want to exclude from the image?",
                    value="ugly, low quality",
                    lines=4,
                    interactive=True,
                )
            with gr.Column():
                generate_btn = gr.Button("Generate")
                output = gr.Image()
        with gr.Row():
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        guidance = gr.Slider(
                            label="Guidance Scale",
                            value=7.0,
                            minimum=1.0,
                            maximum=15.0,
                            step=0.5,
                            interactive=True,
                        )
                        width = gr.Slider(
                            label="Width",
                            info="The width in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                        height = gr.Slider(
                            label="Height",
                            info="The height in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                    with gr.Column():
                        strength = gr.Slider(
                            label="LoRA Strength",
                            info="How strongly the LoRA influences output.",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.5,
                            step=0.05,
                            interactive=True,
                        )
                        seed = gr.Number(
                            label="Seed",
                            info="What initial image is passed to the model.",
                            value=43,
                            precision=0,
                            interactive=True,
                        )
                        # Button that replaces the seed with a random 32-bit value.
                        regen = gr.Button("\u21ba")
                        regen.click(fn=lambda: random.randint(0, 2**32 - 1), outputs=seed)
    # The input order must match the positional parameters of `generate`.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, guidance, strength, seed, width, height],
        outputs=[output],
    )
# Script entry point: enable the request queue, then start the web server.
if __name__ == "__main__":
    interface.queue().launch()