"""Gradio demo: SDXL image generation with step-scheduled LoRA adapters.

Three LoRA adapters ("early", "mid", "late") are loaded onto one
Stable Diffusion XL pipeline.  During denoising, the active adapter is
switched according to the step index, with a short linear cross-fade
around each segment boundary.
"""

import random

import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionXLPipeline

MODEL_ID = "glides/illustriousxl"
ADAPTER_BASE_PATH = "./creative-lora"
ALL_SEGMENTS = ["early", "mid", "late"]

NUM_INFERENCE_STEPS = 30
EARLY_SEG = 10
MID_SEG = 10

# Step indices at which one segment hands over to the next.
_BOUNDARIES = [EARLY_SEG, EARLY_SEG + MID_SEG]
# Half-width (in steps) of the cross-fade window centred on each boundary.
_BLEND_HALF = 2


def _adapter_weights(step_index: int, strength: float) -> list[float]:
    """Return the adapter weights to use for denoising step *step_index*.

    The result has one entry per adapter, in ``ALL_SEGMENTS`` order.

    * Within ``_BLEND_HALF`` steps of a boundary, the two adapters adjacent
      to that boundary are linearly cross-faded; their weights sum to
      *strength*.
    * Otherwise the whole *strength* goes to the single adapter whose
      segment contains the step.
    """
    for i, boundary in enumerate(_BOUNDARIES):
        dist = step_index - boundary
        if abs(dist) <= _BLEND_HALF:
            # t runs 0 -> 1 across the blend window [boundary - H, boundary + H].
            t = (dist + _BLEND_HALF) / (2 * _BLEND_HALF)
            weights = [0.0, 0.0, 0.0]
            weights[i] = (1.0 - t) * strength
            weights[i + 1] = t * strength
            return weights

    weights = [0.0, 0.0, 0.0]
    if step_index < _BOUNDARIES[0]:
        weights[0] = strength
    elif step_index < _BOUNDARIES[1]:
        weights[1] = strength
    else:
        weights[2] = strength
    return weights


pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16
).to("cuda")

# Register every segment's LoRA under its own adapter name so that they can
# be re-weighted independently at run time.
for segment in ALL_SEGMENTS:
    pipe.load_lora_weights(
        ADAPTER_BASE_PATH,
        weight_name=f"{segment}.safetensors",
        adapter_name=segment,
    )


@spaces.GPU
def generate(prompt, negative_prompt, guidance, strength, seed, width, height):
    """Generate one image, re-weighting the LoRA adapters at every step.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what to avoid.
        guidance: Guidance scale passed to the pipeline.
        strength: Total LoRA weight distributed over the adapters.
        seed: Seed for the CUDA random generator (cast to ``int``).
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).

    Returns:
        The first image of the pipeline result.
    """
    # UI number widgets may deliver floats; the generator seed and the image
    # dimensions must be integers.
    seed = int(seed)
    width = int(width)
    height = int(height)

    pipe.enable_lora()
    # Weights for the very first denoising step.
    pipe.set_adapters(ALL_SEGMENTS, _adapter_weights(0, strength))

    def callback(p, step_index, timestep, callback_kwargs):
        # This hook runs *after* step ``step_index`` has completed, so the
        # weights installed here take effect on the following step.  Using
        # ``step_index`` itself would make the schedule lag by one step.
        next_step = step_index + 1
        if next_step >= NUM_INFERENCE_STEPS:
            # No denoising step follows; neutralise all adapters.
            p.set_adapters(ALL_SEGMENTS, [0.0, 0.0, 0.0])
        else:
            p.set_adapters(ALL_SEGMENTS, _adapter_weights(next_step, strength))
        return callback_kwargs

    generator = torch.Generator(device="cuda").manual_seed(seed)
    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=guidance,
            generator=generator,
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=["latents"],
        )
        image = result.images[0]
        del result
        return image
    finally:
        # Always leave the shared pipeline with LoRA disabled and release
        # cached GPU memory, even when generation fails.
        pipe.disable_lora()
        torch.cuda.empty_cache()


with gr.Blocks() as interface:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    info="What do you want?",
                    value="A woman with long, wavy pink hair is shown in profile",
                    lines=4,
                    interactive=True,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    info="What do you want to exclude from the image?",
                    value="ugly, low quality",
                    lines=4,
                    interactive=True,
                )
            with gr.Column():
                generate_btn = gr.Button("Generate")
                output = gr.Image()
        with gr.Row():
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        guidance = gr.Slider(
                            label="Guidance Scale",
                            value=7.0,
                            minimum=1.0,
                            maximum=15.0,
                            step=0.5,
                            interactive=True,
                        )
                        width = gr.Slider(
                            label="Width",
                            info="The width in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                        height = gr.Slider(
                            label="Height",
                            info="The height in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                    with gr.Column():
                        strength = gr.Slider(
                            label="LoRA Strength",
                            info="How strongly the LoRA influences output.",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.5,
                            step=0.05,
                            interactive=True,
                        )
                        seed = gr.Number(
                            label="Seed",
                            info="What initial image is passed to the model.",
                            value=43,
                            precision=0,
                            interactive=True,
                        )
                        regen = gr.Button("\u21ba")

    # Replace the seed field with a fresh random 32-bit value.
    regen.click(fn=lambda: random.randint(0, 2**32 - 1), outputs=seed)

    generate_btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, guidance, strength, seed, width, height],
        outputs=[output],
    )

if __name__ == "__main__":
    interface.queue()
    interface.launch()