File size: 5,592 Bytes
815e698 f152a6a 815e698 b3eef45 815e698 cd3ae09 815e698 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | import random
import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
MODEL_ID = "glides/illustriousxl"
ADAPTER_BASE_PATH = "./creative-lora"
# One LoRA adapter per denoising segment; each is loaded from
# ADAPTER_BASE_PATH with weight file f"{segment}.safetensors".
# Must contain exactly len(_BOUNDARIES) + 1 names.
ALL_SEGMENTS = ["early", "mid", "late"]
NUM_INFERENCE_STEPS = 30
# Segment lengths in steps; the "late" segment covers the remaining
# NUM_INFERENCE_STEPS - EARLY_SEG - MID_SEG steps.
EARLY_SEG = 10
MID_SEG = 10
# Step indices at which one segment hands over to the next (ascending).
_BOUNDARIES = [EARLY_SEG, EARLY_SEG + MID_SEG]
# Half-width, in steps, of the linear cross-fade centred on each boundary.
# A value <= 0 means a hard switch at the boundary with no blending.
_BLEND_HALF = 2


def _adapter_weights(step_index: int, strength: float) -> list[float]:
    """Return the per-segment LoRA weights to use at denoising step *step_index*.

    Within ``_BLEND_HALF`` steps of a boundary, the weight cross-fades
    linearly from the segment before the boundary to the segment after it.
    The scale is ``strength`` at ``boundary - _BLEND_HALF`` and ``0`` at
    ``boundary + _BLEND_HALF``, with the complement on the next segment.
    Outside every blend window, only the current segment gets ``strength``.

    Args:
        step_index: Zero-based denoising step.
        strength: Overall LoRA scale to distribute between segments.

    Returns:
        One weight per segment (``len(_BOUNDARIES) + 1`` entries), ordered
        like ``ALL_SEGMENTS``.
    """
    weights = [0.0] * (len(_BOUNDARIES) + 1)

    # Guard the division below: a zero (or negative) half-width means
    # "no blending", not a ZeroDivisionError at the boundary step.
    if _BLEND_HALF > 0:
        for i, boundary in enumerate(_BOUNDARIES):
            dist = step_index - boundary
            if abs(dist) <= _BLEND_HALF:
                # t runs 0 -> 1 across the window [boundary - H, boundary + H].
                t = (dist + _BLEND_HALF) / (2 * _BLEND_HALF)
                weights[i] = (1.0 - t) * strength
                weights[i + 1] = t * strength
                return weights

    # The active segment index is the number of boundaries already reached
    # (requires _BOUNDARIES to be sorted ascending).
    weights[sum(step_index >= boundary for boundary in _BOUNDARIES)] = strength
    return weights
# Load the base SDXL pipeline in half precision, then move it to the GPU.
pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Register one named LoRA adapter per segment so generate() can re-weight
# them with set_adapters() between denoising steps.
for segment in ALL_SEGMENTS:
    pipe.load_lora_weights(
        ADAPTER_BASE_PATH,
        adapter_name=segment,
        weight_name=segment + ".safetensors",
    )
@spaces.GPU
def generate(prompt, negative_prompt, guidance, strength, seed, width, height):
    """Generate one image while cross-fading the segment LoRAs across denoising.

    The adapters are set to the step-0 weights before the denoising loop
    starts. After each step, the callback installs the weights that
    ``_adapter_weights`` gives for the following step. LoRA is always
    disabled again when the call finishes, whether it succeeds or fails.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what to exclude from the image.
        guidance: Guidance scale passed to the pipeline as ``guidance_scale``.
        strength: Scale applied to the active segment adapter(s).
        seed: Seed for the CUDA generator; converted with ``int``.
        width: Image width in pixels; converted with ``int``.
        height: Image height in pixels; converted with ``int``.

    Returns:
        The first image of the pipeline output.
    """
    seed = int(seed)
    # Slider values may arrive as floats; hand the pipeline integer pixel sizes.
    width = int(width)
    height = int(height)

    def callback(p, step_index, timestep, callback_kwargs):
        # diffusers invokes this at the END of step ``step_index``, so the
        # weights installed here are the ones used by step ``step_index + 1``.
        # After the final step there is nothing left to configure.
        next_step = step_index + 1
        if next_step < NUM_INFERENCE_STEPS:
            p.set_adapters(ALL_SEGMENTS, _adapter_weights(next_step, strength))
        return callback_kwargs

    generator = torch.Generator(device="cuda").manual_seed(seed)
    try:
        # Enable inside the try so the finally clause always undoes it, even
        # if installing the initial weights fails.
        pipe.enable_lora()
        pipe.set_adapters(ALL_SEGMENTS, _adapter_weights(0, strength))
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=guidance,
            generator=generator,
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=["latents"],
        )
        image = result.images[0]
        del result
        return image
    finally:
        pipe.disable_lora()
        torch.cuda.empty_cache()
# Gradio UI. The top row places the prompt inputs beside the generate button
# and output image. Below that sits a collapsed "Advanced Settings" accordion
# with the sampling and LoRA controls.
with gr.Blocks() as interface:
    with gr.Column():
        with gr.Row():
            # Text inputs.
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    info="What do you want?",
                    value="A woman with long, wavy pink hair is shown in profile",
                    lines=4,
                    interactive=True,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    info="What do you want to exclude from the image?",
                    value="ugly, low quality",
                    lines=4,
                    interactive=True,
                )
            # Trigger and result.
            with gr.Column():
                generate_btn = gr.Button("Generate")
                output = gr.Image()
        with gr.Row():
            # open=False: the accordion starts collapsed.
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    # Sampling controls: guidance scale and output size.
                    with gr.Column():
                        guidance = gr.Slider(
                            label="Guidance Scale",
                            value=7.0,
                            minimum=1.0,
                            maximum=15.0,
                            step=0.5,
                            interactive=True,
                        )
                        width = gr.Slider(
                            label="Width",
                            info="The width in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                        height = gr.Slider(
                            label="Height",
                            info="The height in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                    # LoRA strength and seed controls.
                    with gr.Column():
                        strength = gr.Slider(
                            label="LoRA Strength",
                            info="How strongly the LoRA influences output.",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.5,
                            step=0.05,
                            interactive=True,
                        )
                        # NOTE(review): the seed drives the torch.Generator used
                        # for sampling in generate(); the info text below
                        # describes it loosely.
                        seed = gr.Number(
                            label="Seed",
                            info="What initial image is passed to the model.",
                            value=43,
                            precision=0,
                            interactive=True,
                        )
                        # "\u21ba" is the anticlockwise open-circle arrow glyph.
                        # Clicking it writes a random seed in [0, 2**32 - 1]
                        # into the seed field.
                        regen = gr.Button("\u21ba")
                        regen.click(fn=lambda: random.randint(0, 2**32 - 1), outputs=seed)
    # The input order must match generate()'s positional parameters.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, guidance, strength, seed, width, height],
        outputs=[output],
    )
if __name__ == "__main__":
    # queue() returns the Blocks instance, so the launch can be chained;
    # requests are queued before the server starts.
    interface.queue().launch()
|