File size: 5,592 Bytes
815e698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f152a6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
815e698
 
 
 
 
 
 
 
 
b3eef45
815e698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3ae09
815e698
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import random
import gradio as gr

import spaces

import torch
from diffusers import StableDiffusionXLPipeline


# Checkpoint id handed to StableDiffusionXLPipeline.from_pretrained.
MODEL_ID = "glides/illustriousxl"
# Directory containing one LoRA file per denoising segment.
ADAPTER_BASE_PATH = "./creative-lora"

# Denoising schedule: total step count plus the lengths of the first two
# segments (the final segment covers every remaining step).
NUM_INFERENCE_STEPS = 30
EARLY_SEG = 10
MID_SEG = 10

# Adapter names in schedule order; each is loaded from "<name>.safetensors".
ALL_SEGMENTS = "early mid late".split()

# Step indices at which one segment hands over to the next.
_BOUNDARIES = [EARLY_SEG, MID_SEG + EARLY_SEG]
# Half-width, in steps, of the linear cross-fade centred on each boundary.
_BLEND_HALF = 2


def _adapter_weights(
    step_index: int,
    strength: float,
    boundaries: "list[int] | None" = None,
    blend_half: "int | None" = None,
) -> list[float]:
    """Return the per-segment LoRA weights to use for one denoising step.

    Outside any blend window exactly one segment adapter receives
    ``strength`` and all others receive 0. Within ``blend_half`` steps of a
    boundary, the two neighbouring adapters are cross-faded linearly. Their
    weights sum to ``strength``, and the boundary step itself is a 50/50
    split.

    Args:
        step_index: Zero-based denoising step.
        strength: Peak LoRA weight.
        boundaries: Ascending step indices where one segment hands over to
            the next. Defaults to the module-level ``_BOUNDARIES``.
        blend_half: Half-width of each cross-fade, in steps. Defaults to the
            module-level ``_BLEND_HALF``. ``0`` means a hard switch at each
            boundary (no cross-fade, no division by zero).

    Returns:
        ``len(boundaries) + 1`` weights, one per segment in schedule order.
    """
    if boundaries is None:
        boundaries = _BOUNDARIES
    if blend_half is None:
        blend_half = _BLEND_HALF

    # One slot per segment: N boundaries split the schedule into N + 1 parts.
    weights = [0.0] * (len(boundaries) + 1)

    if blend_half > 0:
        for i, boundary in enumerate(boundaries):
            dist = step_index - boundary
            if abs(dist) <= blend_half:
                # t runs 0 -> 1 across the window [boundary - h, boundary + h].
                # If windows overlap, the earliest boundary wins.
                t = (dist + blend_half) / (2 * blend_half)
                weights[i] = (1.0 - t) * strength
                weights[i + 1] = t * strength
                return weights

    # Not blending: the active segment index equals the number of boundaries
    # already reached.
    active = sum(1 for boundary in boundaries if step_index >= boundary)
    weights[active] = strength
    return weights


# Build the SDXL pipeline once at import time, in float16, on the GPU.
pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Register each segment's LoRA file under the segment's own name, so the
# adapters can later be addressed by name via ``set_adapters``.
for segment in ALL_SEGMENTS:
    pipe.load_lora_weights(
        ADAPTER_BASE_PATH,
        adapter_name=segment,
        weight_name=segment + ".safetensors",
    )


@spaces.GPU
def generate(prompt, negative_prompt, guidance, strength, seed, width, height):
    """Render one image with the segment-scheduled LoRA adapters.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Text describing what to steer away from.
        guidance: Guidance scale forwarded as ``guidance_scale``.
        strength: Peak LoRA weight; distributed across segments per step by
            ``_adapter_weights``.
        seed: Seed for the CUDA ``torch.Generator`` (coerced with ``int``).
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        The first image of the pipeline result (presumably a PIL image, the
        pipeline's default output type -- confirm if ``output_type`` changes).
    """
    seed = int(seed)
    last_step = NUM_INFERENCE_STEPS - 1

    pipe.enable_lora()
    # Step 0 runs before any callback fires, so its weights are set up front.
    pipe.set_adapters(ALL_SEGMENTS, _adapter_weights(0, strength))

    def callback(p, step_index, timestep, callback_kwargs):
        # diffusers invokes this at the END of step ``step_index``, so the
        # weights set here drive the NEXT step. Schedule ``step_index + 1``.
        # Passing ``step_index`` would make every step after the first reuse
        # the previous step's weights, shifting each blend window by one.
        if step_index == last_step:
            # No denoising step remains; zero the adapters as cleanup.
            p.set_adapters(ALL_SEGMENTS, [0.0] * len(ALL_SEGMENTS))
        else:
            p.set_adapters(
                ALL_SEGMENTS, _adapter_weights(step_index + 1, strength)
            )
        return callback_kwargs

    generator = torch.Generator(device="cuda").manual_seed(seed)

    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=guidance,
            generator=generator,
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=["latents"],
        )
        image = result.images[0]
        del result
        return image
    finally:
        # Always leave the shared pipeline with LoRA off and cached GPU
        # memory released, even if generation raised.
        pipe.disable_lora()
        torch.cuda.empty_cache()


with gr.Blocks() as interface:
    with gr.Column():
        # Top row: prompts on the left, trigger button and result on the right.
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    info="What do you want?",
                    value="A woman with long, wavy pink hair is shown in profile",
                    lines=4,
                    interactive=True,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    info="What do you want to exclude from the image?",
                    value="ugly, low quality",
                    lines=4,
                    interactive=True,
                )
            with gr.Column():
                generate_btn = gr.Button("Generate")
                output = gr.Image()
        # Collapsed panel with sampling, size, LoRA and seed controls.
        with gr.Row():
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        guidance = gr.Slider(
                            label="Guidance Scale",
                            value=7.0,
                            minimum=1.0,
                            maximum=15.0,
                            step=0.5,
                            interactive=True,
                        )
                        width = gr.Slider(
                            label="Width",
                            info="The width in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                        height = gr.Slider(
                            label="Height",
                            info="The height in pixels of the generated image.",
                            value=1024,
                            minimum=128,
                            maximum=4096,
                            step=64,
                            interactive=True,
                        )
                    with gr.Column():
                        strength = gr.Slider(
                            label="LoRA Strength",
                            info="How strongly the LoRA influences output.",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.5,
                            step=0.05,
                            interactive=True,
                        )
                        # The seed only initialises the random starting noise;
                        # no image is fed to the model.
                        seed = gr.Number(
                            label="Seed",
                            info="Seed for the random starting noise; reuse it with the same settings to reproduce a result.",
                            value=43,
                            precision=0,
                            interactive=True,
                        )

                        regen = gr.Button("\u21ba")

    # The re-roll button only draws a fresh seed; it does not start generation.
    regen.click(fn=lambda: random.randint(0, 2**32 - 1), outputs=seed)

    generate_btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, guidance, strength, seed, width, height],
        outputs=[output],
    )

if __name__ == "__main__":
    # queue() enables request queuing and returns the same Blocks object,
    # so the server can be started directly from its result.
    interface.queue().launch()