| import random |
| import gradio as gr |
|
|
| import spaces |
|
|
| import torch |
| from diffusers import StableDiffusionXLPipeline |
|
|
|
|
MODEL_ID = "glides/illustriousxl"
ADAPTER_BASE_PATH = "./creative-lora"

# Denoising schedule. One LoRA adapter per segment, listed in schedule order;
# each name is also the stem of that adapter's ".safetensors" file.
ALL_SEGMENTS = ["early", "mid", "late"]
NUM_INFERENCE_STEPS = 30
EARLY_SEG = 10  # length, in steps, of the "early" segment
MID_SEG = 10  # length, in steps, of the "mid" segment; "late" takes the rest
# Step indices at which one segment hands over to the next. A tuple because
# this is fixed configuration that must not be mutated at runtime.
_BOUNDARIES = (EARLY_SEG, EARLY_SEG + MID_SEG)
# Half-width, in steps, of the linear cross-fade window around each boundary.
_BLEND_HALF = 2

# Fail fast at import time on an inconsistent layout rather than inside a
# request: the per-step adapter weight list has one entry per boundary plus
# one, and it is paired element-wise with ALL_SEGMENTS.
if len(ALL_SEGMENTS) != len(_BOUNDARIES) + 1:
    raise ValueError(
        f"expected {len(_BOUNDARIES) + 1} segments for boundaries "
        f"{_BOUNDARIES}, got {len(ALL_SEGMENTS)}"
    )
# Boundaries must be strictly increasing and lie strictly inside
# (0, NUM_INFERENCE_STEPS) so that every segment owns at least one step.
if any(
    hi <= lo
    for lo, hi in zip((0, *_BOUNDARIES), (*_BOUNDARIES, NUM_INFERENCE_STEPS))
):
    raise ValueError(
        f"segment boundaries {_BOUNDARIES} must be strictly increasing and lie "
        f"strictly between 0 and NUM_INFERENCE_STEPS={NUM_INFERENCE_STEPS}"
    )
|
|
|
|
| def _adapter_weights(step_index: int, strength: float) -> list[float]: |
| for i, boundary in enumerate(_BOUNDARIES): |
| dist = step_index - boundary |
| if abs(dist) <= _BLEND_HALF: |
| t = (dist + _BLEND_HALF) / (2 * _BLEND_HALF) |
| weights = [0.0, 0.0, 0.0] |
| weights[i] = (1.0 - t) * strength |
| weights[i + 1] = t * strength |
| return weights |
|
|
| weights = [0.0, 0.0, 0.0] |
| if step_index < _BOUNDARIES[0]: |
| weights[0] = strength |
| elif step_index < _BOUNDARIES[1]: |
| weights[1] = strength |
| else: |
| weights[2] = strength |
| return weights |
|
|
|
|
# Load the SDXL model once, in half precision, and move it onto the GPU; every
# request reuses this single pipeline object.
pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Register one LoRA adapter per schedule segment: the adapter is named after
# the segment and read from "<segment>.safetensors" under ADAPTER_BASE_PATH.
for segment in ALL_SEGMENTS:
    pipe.load_lora_weights(
        ADAPTER_BASE_PATH,
        adapter_name=segment,
        weight_name=segment + ".safetensors",
    )
|
|
|
|
@spaces.GPU
def generate(prompt, negative_prompt, guidance, strength, seed, width, height):
    """Generate one image with the step-scheduled segment LoRAs.

    Before every denoising step the adapters in ``ALL_SEGMENTS`` are re-weighted
    with ``_adapter_weights``; the final step runs with every adapter weight at
    0. LoRA is disabled again when the call finishes, even on error.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what to keep out of the image.
        guidance: ``guidance_scale`` passed to the pipeline.
        strength: Peak LoRA weight passed to ``_adapter_weights``.
        seed: Seed for the CUDA ``torch.Generator``; converted with ``int``.
        width: Output width in pixels; converted with ``int``.
        height: Output height in pixels; converted with ``int``.

    Returns:
        The first image of the pipeline result.

    Raises:
        gr.Error: If ``seed`` is ``None`` (e.g. the Seed field was cleared).
    """
    if seed is None:
        raise gr.Error("Please enter a seed.")
    seed = int(seed)
    # Coerce defensively so integer pixel sizes always reach the pipeline.
    width, height = int(width), int(height)
    last_step = NUM_INFERENCE_STEPS - 1

    def weights_for(step):
        """Adapter weights to use while denoising step ``step``."""
        if step == last_step:
            # The final step runs without any segment LoRA (all weights 0).
            return [0.0] * len(ALL_SEGMENTS)
        return _adapter_weights(step, strength)

    def callback(p, step_index, timestep, callback_kwargs):
        # diffusers invokes callback_on_step_end *after* step `step_index` has
        # finished, so whatever is set here only affects the following step.
        # Schedule that next step, not the one that just ran.
        next_step = step_index + 1
        if next_step < NUM_INFERENCE_STEPS:
            p.set_adapters(ALL_SEGMENTS, weights_for(next_step))
        return callback_kwargs

    generator = torch.Generator(device="cuda").manual_seed(seed)

    try:
        # Enabled inside the try so the finally block always undoes it, even
        # if configuring the adapters fails.
        pipe.enable_lora()
        # No callback runs before step 0, so its weights are set up front.
        pipe.set_adapters(ALL_SEGMENTS, weights_for(0))
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=guidance,
            generator=generator,
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=["latents"],
        )
        image = result.images[0]
        del result
        return image
    finally:
        pipe.disable_lora()
        torch.cuda.empty_cache()
|
|
|
|
# Gradio UI: prompts and the result side by side, with the generation
# parameters in a collapsed "Advanced Settings" accordion. Component creation
# order inside each `with` context determines the on-screen layout.
with gr.Blocks() as interface:
    with gr.Column():
        with gr.Row():
            # Left: text prompts.
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    info="What do you want?",
                    value="A woman with long, wavy pink hair is shown in profile",
                    lines=4,
                    interactive=True,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    info="What do you want to exclude from the image?",
                    value="ugly, low quality",
                    lines=4,
                    interactive=True,
                )
            # Right: trigger button and the generated image.
            with gr.Column():
                generate_btn = gr.Button("Generate")
                output = gr.Image()
        with gr.Row():
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        guidance = gr.Slider(
                            label="Guidance Scale",
                            value=7.0,
                            minimum=1.0,
                            maximum=15.0,
                            step=0.5,
                            interactive=True,
                        )
                        width = gr.Slider(
                            label="Width",
                            info="The width in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                        height = gr.Slider(
                            label="Height",
                            info="The height in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                    with gr.Column():
                        strength = gr.Slider(
                            label="LoRA Strength",
                            info="How strongly the LoRA influences output.",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.5,
                            step=0.05,
                            interactive=True,
                        )
                        # The seed only feeds the torch.Generator that draws the
                        # initial noise; this text-to-image app takes no input image.
                        seed = gr.Number(
                            label="Seed",
                            info=(
                                "Seed for the initial random noise; reuse it with "
                                "the same settings to reproduce an image."
                            ),
                            value=43,
                            precision=0,
                            interactive=True,
                        )
                        # Re-roll button: writes a random 32-bit seed into the Seed field.
                        regen = gr.Button("\u21ba")

    regen.click(fn=lambda: random.randint(0, 2**32 - 1), outputs=seed)

    # The order of `inputs` must match generate()'s positional parameters.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, guidance, strength, seed, width, height],
        outputs=[output],
    )
|
|
if __name__ == "__main__":
    # Blocks.queue() returns the Blocks itself, so launch() chains straight onto it.
    interface.queue().launch()
|
|