pva22
committed on
Commit
·
2de36a8
1
Parent(s):
7ad97f5
hw6
Browse files- app.py +139 -177
- methods.py +212 -0
app.py
CHANGED
|
@@ -1,110 +1,13 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
import spaces #[uncomment to use ZeroGPU]
|
| 6 |
-
from diffusers import (
|
| 7 |
-
DiffusionPipeline,
|
| 8 |
-
StableDiffusionPipeline,
|
| 9 |
-
StableDiffusionControlNetPipeline,
|
| 10 |
-
StableDiffusionControlNetImg2ImgPipeline,
|
| 11 |
-
DPMSolverMultistepScheduler,
|
| 12 |
-
PNDMScheduler,
|
| 13 |
-
ControlNetModel
|
| 14 |
-
)
|
| 15 |
-
import torch
|
| 16 |
-
|
| 17 |
-
from peft import PeftModel, LoraConfig
|
| 18 |
-
import os
|
| 19 |
-
|
| 20 |
-
def get_lora_sd_pipeline(
|
| 21 |
-
ckpt_dir='./content/lora',
|
| 22 |
-
base_model_name_or_path=None,
|
| 23 |
-
dtype=torch.float16,
|
| 24 |
-
device="cuda",
|
| 25 |
-
adapter_name="default"
|
| 26 |
-
):
|
| 27 |
-
unet_sub_dir = os.path.join(ckpt_dir, "unet")
|
| 28 |
-
text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
|
| 29 |
-
if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
|
| 30 |
-
config = LoraConfig.from_pretrained(text_encoder_sub_dir)
|
| 31 |
-
base_model_name_or_path = config.base_model_name_or_path
|
| 32 |
-
|
| 33 |
-
if base_model_name_or_path is None:
|
| 34 |
-
raise ValueError("Please specify the base model name or path")
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
|
| 38 |
-
pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
|
| 39 |
-
|
| 40 |
-
if os.path.exists(text_encoder_sub_dir):
|
| 41 |
-
pipe.text_encoder = PeftModel.from_pretrained(
|
| 42 |
-
pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
if dtype in (torch.float16, torch.bfloat16):
|
| 46 |
-
pipe.unet.half()
|
| 47 |
-
pipe.text_encoder.half()
|
| 48 |
-
|
| 49 |
-
return pipe
|
| 50 |
-
|
| 51 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 52 |
|
| 53 |
model_id_default = "sd-legacy/stable-diffusion-v1-5"
|
| 54 |
model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4', 'sd-legacy/stable-diffusion-v1-5']
|
| 55 |
|
| 56 |
-
model_lora_default = "lora"
|
| 57 |
-
|
| 58 |
-
if torch.cuda.is_available():
|
| 59 |
-
torch_dtype = torch.float16
|
| 60 |
-
else:
|
| 61 |
-
torch_dtype = torch.float32
|
| 62 |
-
|
| 63 |
MAX_SEED = np.iinfo(np.int32).max
|
| 64 |
MAX_IMAGE_SIZE = 1024
|
| 65 |
|
| 66 |
-
@spaces.GPU #[uncomment to use ZeroGPU]
|
| 67 |
-
def infer(
|
| 68 |
-
prompt,
|
| 69 |
-
negative_prompt,
|
| 70 |
-
randomize_seed,
|
| 71 |
-
width=512,
|
| 72 |
-
height=512,
|
| 73 |
-
model_repo_id=model_id_default,
|
| 74 |
-
seed=22,
|
| 75 |
-
guidance_scale=7,
|
| 76 |
-
num_inference_steps=50,
|
| 77 |
-
model_lora_id=model_lora_default,
|
| 78 |
-
progress=gr.Progress(track_tqdm=True),
|
| 79 |
-
):
|
| 80 |
-
|
| 81 |
-
if randomize_seed:
|
| 82 |
-
seed = random.randint(0, MAX_SEED)
|
| 83 |
-
|
| 84 |
-
generator = torch.Generator().manual_seed(seed)
|
| 85 |
-
|
| 86 |
-
# добавляем обновление pipe по условию
|
| 87 |
-
if model_repo_id != model_id_default:
|
| 88 |
-
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
|
| 89 |
-
pipe.safety_checker = None
|
| 90 |
-
else:
|
| 91 |
-
# добавляем lora
|
| 92 |
-
pipe = get_lora_sd_pipeline(ckpt_dir='./' + model_lora_id, base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
|
| 93 |
-
pipe.safety_checker = None
|
| 94 |
-
print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
|
| 95 |
-
|
| 96 |
-
params = {
|
| 97 |
-
'prompt': prompt,
|
| 98 |
-
'negative_prompt': negative_prompt,
|
| 99 |
-
'guidance_scale': guidance_scale,
|
| 100 |
-
'num_inference_steps': num_inference_steps,
|
| 101 |
-
'width': width,
|
| 102 |
-
'height': height,
|
| 103 |
-
'generator': generator,
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
return pipe(**params).images[0], seed
|
| 107 |
-
|
| 108 |
|
| 109 |
examples = [
|
| 110 |
"Cartoon sticker of sad Elon Musk",
|
|
@@ -115,91 +18,150 @@ examples = [
|
|
| 115 |
"A parody cartoon sticker of Elon Musk arm-wrestling a robotic version of himself. The robot Musk has glowing red eyes and mechanical arms, while the real Musk smirks confidently. Sparks fly from the table as the intense match unfolds, and the background features a neon sign that reads 'Tesla vs. AI: Ultimate Showdown'."
|
| 116 |
]
|
| 117 |
|
| 118 |
-
css = """
|
| 119 |
-
#col-container {
|
| 120 |
-
margin: 0 auto;
|
| 121 |
-
max-width: 640px;
|
| 122 |
-
}
|
| 123 |
-
"""
|
| 124 |
-
|
| 125 |
-
with gr.Blocks(css=css) as demo:
|
| 126 |
-
with gr.Column(elem_id="col-container"):
|
| 127 |
-
gr.Markdown("# Generate LoRa stickers")
|
| 128 |
-
|
| 129 |
-
with gr.Row():
|
| 130 |
-
prompt = gr.Text(
|
| 131 |
-
label="Prompt",
|
| 132 |
-
show_label=False,
|
| 133 |
-
max_lines=1,
|
| 134 |
-
placeholder="Enter your prompt",
|
| 135 |
-
container=False,
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
run_button = gr.Button("Run", scale=0, variant="primary")
|
| 139 |
-
|
| 140 |
-
result = gr.Image(label="Result", show_label=False)
|
| 141 |
-
|
| 142 |
-
with gr.Accordion("Advanced Settings", open=False):
|
| 143 |
-
|
| 144 |
-
model_repo_id = gr.Dropdown(
|
| 145 |
-
label="Model Id",
|
| 146 |
-
choices=model_dropdown,
|
| 147 |
-
info="Choose model",
|
| 148 |
-
visible=True,
|
| 149 |
-
allow_custom_value=True,
|
| 150 |
-
value=model_id_default,
|
| 151 |
-
)
|
| 152 |
-
|
| 153 |
-
negative_prompt = gr.Text(
|
| 154 |
-
label="Negative prompt",
|
| 155 |
-
max_lines=1,
|
| 156 |
-
placeholder="Enter a negative prompt",
|
| 157 |
-
visible=True,
|
| 158 |
-
value="monochrome, lowres, bad anatomy, worst quality, low quality, medical mask"
|
| 159 |
-
)
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
value=512, # Replace with defaults that work for your model
|
| 178 |
)
|
| 179 |
|
| 180 |
-
|
| 181 |
-
label="
|
| 182 |
-
minimum=
|
| 183 |
-
maximum=
|
| 184 |
-
step=
|
| 185 |
-
value=
|
| 186 |
)
|
| 187 |
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
if __name__ == "__main__":
|
| 205 |
-
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
+
from methods import infer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
model_id_default = "sd-legacy/stable-diffusion-v1-5"
|
| 6 |
model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4', 'sd-legacy/stable-diffusion-v1-5']
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
MAX_SEED = np.iinfo(np.int32).max
|
| 9 |
MAX_IMAGE_SIZE = 1024
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
examples = [
|
| 13 |
"Cartoon sticker of sad Elon Musk",
|
|
|
|
| 18 |
"A parody cartoon sticker of Elon Musk arm-wrestling a robotic version of himself. The robot Musk has glowing red eyes and mechanical arms, while the real Musk smirks confidently. Sparks fly from the table as the intense match unfolds, and the background features a neon sign that reads 'Tesla vs. AI: Ultimate Showdown'."
|
| 19 |
]
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
def on_checkbox_change(use_advanced):
    """Show/hide the advanced-setting widgets driven by a checkbox.

    Args:
        use_advanced: bool state of the controlling gr.Checkbox.

    Returns:
        Three gr.update payloads toggling both visibility and
        interactivity of the wired output components.

    NOTE(review): the IPAdapter wiring passes only two output components
    while this handler returns three updates — confirm gradio tolerates
    the extra value, or add a two-output variant for that caller.
    """
    return tuple(
        gr.update(visible=use_advanced, interactive=use_advanced)
        for _ in range(3)
    )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
with gr.Blocks() as demo:
|
| 30 |
+
with gr.Row():
|
| 31 |
+
|
| 32 |
+
with gr.Column():
|
| 33 |
+
|
| 34 |
+
gr.Markdown("## ControlNet")
|
| 35 |
+
use_advanced_controlnet = gr.Checkbox(label="ControlNet Settings")
|
| 36 |
+
control_strength = gr.Slider(
|
| 37 |
+
label="control_strength",
|
| 38 |
+
minimum=0,
|
| 39 |
+
maximum=1,
|
| 40 |
+
step=0.01,
|
| 41 |
+
value=0.8,
|
| 42 |
+
visible=False)
|
| 43 |
+
mode = gr.Dropdown(["edge_detection", "pose_estimation"], label="Выбор режима", visible=False)
|
| 44 |
+
image_upload_cn = gr.Image(label="Загрузите изображение", visible=False)
|
| 45 |
+
use_advanced_controlnet.change(on_checkbox_change, use_advanced_controlnet, [control_strength, mode, image_upload_cn])
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
gr.Markdown("## IPAdapter")
|
| 49 |
+
use_advanced_ip = gr.Checkbox(label="ControlNet Settings")
|
| 50 |
+
ip_adapter_scale = gr.Slider(
|
| 51 |
+
label="ip_adapter_scale",
|
| 52 |
+
minimum=0,
|
| 53 |
+
maximum=1,
|
| 54 |
+
step=0.01,
|
| 55 |
+
value=0.8,
|
| 56 |
+
visible=False)
|
| 57 |
+
image_upload_ip = gr.Image(label="Загрузите изображение", visible=False)
|
| 58 |
+
use_advanced_ip.change(on_checkbox_change, use_advanced_ip, [ip_adapter_scale, image_upload_ip])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
with gr.Column():
|
| 62 |
+
gr.Markdown("## Generate")
|
| 63 |
+
|
| 64 |
+
with gr.Row():
|
| 65 |
+
prompt = gr.Text(
|
| 66 |
+
label="Prompt",
|
| 67 |
+
show_label=False,
|
| 68 |
+
max_lines=1,
|
| 69 |
+
placeholder="Enter your prompt",
|
| 70 |
+
container=False,
|
| 71 |
+
)
|
| 72 |
+
run_button = gr.Button("Run", scale=0, variant="primary")
|
| 73 |
+
|
| 74 |
+
result = gr.Image(label="Result", show_label=False)
|
| 75 |
+
|
| 76 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 77 |
+
|
| 78 |
+
model_repo_id = gr.Dropdown(
|
| 79 |
+
label="Model Id",
|
| 80 |
+
choices=model_dropdown,
|
| 81 |
+
info="Choose model",
|
| 82 |
+
visible=True,
|
| 83 |
+
allow_custom_value=True,
|
| 84 |
+
value=model_id_default,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
negative_prompt = gr.Text(
|
| 88 |
+
label="Negative prompt",
|
| 89 |
+
max_lines=1,
|
| 90 |
+
placeholder="Enter a negative prompt",
|
| 91 |
+
visible=True,
|
| 92 |
+
value="monochrome, lowres, bad anatomy, worst quality, low quality, medical mask"
|
| 93 |
+
)
|
| 94 |
|
| 95 |
+
seed = gr.Slider(
|
| 96 |
+
label="Seed",
|
| 97 |
+
minimum=0,
|
| 98 |
+
maximum=MAX_SEED,
|
| 99 |
+
step=1,
|
| 100 |
+
value=22,
|
| 101 |
+
)
|
| 102 |
|
| 103 |
+
guidance_scale = gr.Slider(
|
| 104 |
+
label="guidance_scale",
|
| 105 |
+
minimum=0,
|
| 106 |
+
maximum=100,
|
| 107 |
+
step=1,
|
| 108 |
+
value=7,
|
|
|
|
| 109 |
)
|
| 110 |
|
| 111 |
+
num_inference_steps = gr.Slider(
|
| 112 |
+
label="num_inference_steps",
|
| 113 |
+
minimum=0,
|
| 114 |
+
maximum=100,
|
| 115 |
+
step=1,
|
| 116 |
+
value=50,
|
| 117 |
)
|
| 118 |
|
| 119 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 120 |
+
|
| 121 |
+
with gr.Row():
|
| 122 |
+
width = gr.Slider(
|
| 123 |
+
label="Width",
|
| 124 |
+
minimum=256,
|
| 125 |
+
maximum=MAX_IMAGE_SIZE,
|
| 126 |
+
step=32,
|
| 127 |
+
value=512, # Replace with defaults that work for your model
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
height = gr.Slider(
|
| 131 |
+
label="Height",
|
| 132 |
+
minimum=256,
|
| 133 |
+
maximum=MAX_IMAGE_SIZE,
|
| 134 |
+
step=32,
|
| 135 |
+
value=512, # Replace with defaults that work for your model
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
with gr.Accordion("Prompt examples", open=False):
|
| 139 |
+
gr.Examples(examples=examples, inputs=[prompt])
|
| 140 |
+
|
| 141 |
+
gr.on(
|
| 142 |
+
triggers=[run_button.click, prompt.submit],
|
| 143 |
+
fn=infer,
|
| 144 |
+
inputs=[
|
| 145 |
+
prompt,
|
| 146 |
+
negative_prompt,
|
| 147 |
+
randomize_seed,
|
| 148 |
+
width,
|
| 149 |
+
height,
|
| 150 |
+
model_repo_id,
|
| 151 |
+
seed,
|
| 152 |
+
guidance_scale,
|
| 153 |
+
num_inference_steps,
|
| 154 |
+
|
| 155 |
+
use_advanced_controlnet,
|
| 156 |
+
control_strength,
|
| 157 |
+
image_upload_cn,
|
| 158 |
+
|
| 159 |
+
use_advanced_ip,
|
| 160 |
+
ip_adapter_scale,
|
| 161 |
+
image_upload_ip
|
| 162 |
+
],
|
| 163 |
+
outputs=[result, seed],
|
| 164 |
+
)
|
| 165 |
|
| 166 |
if __name__ == "__main__":
|
| 167 |
+
demo.launch(share=False, debug=True)
|
methods.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import PIL
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import cv2 as cv
|
| 7 |
+
import random
|
| 8 |
+
import os
|
| 9 |
+
import spaces
|
| 10 |
+
import gradio as gr
|
| 11 |
+
|
| 12 |
+
from diffusers import DiffusionPipeline
|
| 13 |
+
from peft import PeftModel, LoraConfig
|
| 14 |
+
|
| 15 |
+
from diffusers import (
|
| 16 |
+
StableDiffusionPipeline,
|
| 17 |
+
StableDiffusionControlNetPipeline,
|
| 18 |
+
StableDiffusionControlNetImg2ImgPipeline,
|
| 19 |
+
DPMSolverMultistepScheduler,
|
| 20 |
+
PNDMScheduler,
|
| 21 |
+
ControlNetModel
|
| 22 |
+
)
|
| 23 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 24 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps
|
| 25 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
from diffusers.utils import load_image, make_image_grid
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 31 |
+
MAX_IMAGE_SIZE = 1024
|
| 32 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
+
|
| 34 |
+
model_id_default = "sd-legacy/stable-diffusion-v1-5"
|
| 35 |
+
model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4', 'sd-legacy/stable-diffusion-v1-5']
|
| 36 |
+
model_lora_default = "lora"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_lora_sd_pipeline(
    ckpt_dir='./' + model_lora_default,
    base_model_name_or_path=None,
    dtype=torch.float16,
    device=DEVICE,
    adapter_name="default",
    controlnet=None,
    ip_adapter=None,
):
    """Build a Stable Diffusion pipeline with LoRA adapters applied.

    Optionally wraps the base model with a Canny ControlNet and/or loads
    IP-Adapter weights, then attaches LoRA weights for the UNet (and the
    text encoder, when a checkpoint exists) from *ckpt_dir*.

    Args:
        ckpt_dir: directory holding "unet" / "text_encoder" LoRA subfolders.
        base_model_name_or_path: HF id of the base SD model; inferred from
            the text-encoder LoRA config when omitted.
        dtype: torch dtype for the pipeline weights.
        device: target device string.
        adapter_name: PEFT adapter name to register.
        controlnet: truthy -> attach the Canny ControlNet.
        ip_adapter: truthy -> load the IP-Adapter weights.

    Returns:
        The configured diffusers pipeline, moved to *device*, with the
        safety checker disabled.

    Raises:
        ValueError: if the base model cannot be determined.
    """
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
        base_model_name_or_path = config.base_model_name_or_path

    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    if controlnet:
        # Both ControlNet variants used identical, duplicated loading code;
        # consolidated into a single load. A separate local name avoids
        # shadowing the boolean `controlnet` flag with the model object.
        cn_model = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny",
            cache_dir="./models_cache",
            torch_dtype=torch.float16,
        )
        print('Pipe with ControlNet and IpAdapter' if ip_adapter else 'Pipe with ControlNet')
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_name_or_path,
            torch_dtype=dtype,
            controlnet=cn_model,
        )
    else:
        print('Pipe with IpAdapter' if ip_adapter else 'Pipe with only SD')
        pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)

    if ip_adapter:
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter-plus_sd15.bin",
        )

    # Attach LoRA weights to the UNet (and text encoder, when present).
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
    if os.path.exists(text_encoder_sub_dir):
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
        )

    if dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()

    pipe.safety_checker = None
    pipe.to(device)
    return pipe
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _canny_edges(image):
    """Turn an uploaded image into the 3-channel Canny edge map ControlNet expects."""
    # cv.Canny requires explicit hysteresis thresholds (the original call
    # omitted them, raising TypeError); 100/200 are the conventional
    # defaults for ControlNet Canny preprocessing.
    edges = cv.Canny(np.array(image), 100, 200)
    edges = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges)


@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    randomize_seed,
    width=512,
    height=512,
    model_repo_id=model_id_default,  # forwarded as base_model_name_or_path
    seed=22,
    guidance_scale=7,
    num_inference_steps=50,

    use_advanced_controlnet=None,
    control_strength=None,
    image_upload_cn=None,

    use_advanced_ip=None,
    ip_adapter_scale=None,
    image_upload_ip=None,

    model_lora_id=model_lora_default,
    progress=gr.Progress(track_tqdm=True),
    dtype=torch.float16,
    device=DEVICE,
):
    """Generate an image with SD 1.5 + LoRA, optionally conditioned by
    ControlNet (Canny edges of *image_upload_cn*) and/or an IP-Adapter
    (*image_upload_ip*).

    Args:
        prompt / negative_prompt: text conditioning.
        randomize_seed: when truthy, draw a fresh seed in [0, MAX_SEED].
        width, height: output resolution.
        model_repo_id: base SD model id.
        seed, guidance_scale, num_inference_steps: sampling parameters.
        use_advanced_controlnet, use_advanced_ip: gr.Checkbox states
            selecting which extra conditioning path to run.
        control_strength: controlnet_conditioning_scale.
        ip_adapter_scale: IP-Adapter strength.

    Returns:
        (PIL image, seed actually used).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Bug fix: gradio Checkboxes deliver booleans (False when unchecked),
    # so the original `== None` tests never matched and every unchecked
    # run fell through into the combined ControlNet+IP branch. Test
    # truthiness instead.
    use_controlnet = bool(use_advanced_controlnet)
    use_ip = bool(use_advanced_ip)

    if not use_controlnet and not use_ip:
        # 1. SD 1.5 + LoRA
        pipe = get_lora_sd_pipeline(base_model_name_or_path=model_repo_id,
                                    dtype=dtype).to(device)
        image = pipe(prompt,
                     num_inference_steps=num_inference_steps,
                     guidance_scale=guidance_scale,
                     negative_prompt=negative_prompt,
                     width=width,
                     height=height,  # bug fix: was misspelled `heigth`
                     generator=generator).images[0]

    elif use_controlnet and not use_ip:
        # 2. SD 1.5 + LoRA + ControlNet
        edges = _canny_edges(image_upload_cn)
        pipe = get_lora_sd_pipeline(base_model_name_or_path=model_repo_id,
                                    controlnet=True,
                                    dtype=dtype).to(device)
        image = pipe(prompt,
                     edges,
                     num_inference_steps=num_inference_steps,
                     controlnet_conditioning_scale=control_strength,
                     negative_prompt=negative_prompt,
                     generator=generator).images[0]

    elif use_ip and not use_controlnet:
        # 3. SD 1.5 + LoRA + IP-Adapter
        # Bug fix: the original passed an undefined `edges` positional arg
        # here (NameError); the plain SD pipeline takes no control image.
        pipe = get_lora_sd_pipeline(base_model_name_or_path=model_repo_id,
                                    ip_adapter=True,
                                    dtype=dtype).to(device)
        pipe.set_ip_adapter_scale(ip_adapter_scale)
        image = pipe(prompt,
                     ip_adapter_image=image_upload_ip,
                     num_inference_steps=num_inference_steps,
                     guidance_scale=guidance_scale,
                     generator=generator).images[0]

    else:
        # 4. SD 1.5 + LoRA + IP-Adapter + ControlNet
        # Bug fix: `edges` was referenced but never computed in this branch.
        edges = _canny_edges(image_upload_cn)
        pipe = get_lora_sd_pipeline(base_model_name_or_path=model_repo_id,
                                    ip_adapter=True,
                                    controlnet=True,
                                    dtype=dtype).to(device)
        pipe.set_ip_adapter_scale(ip_adapter_scale)
        image = pipe(prompt,
                     edges,
                     ip_adapter_image=image_upload_ip,
                     num_inference_steps=num_inference_steps,
                     guidance_scale=guidance_scale,
                     controlnet_conditioning_scale=control_strength,
                     height=height,
                     width=width,
                     generator=generator,
                     ).images[0]

    return image, seed
|