pva22 committed
Commit 39aef48 · 1 Parent(s): e977ca2

fix lora and sd

Files changed (1): app.py (+166 -61)
app.py CHANGED
@@ -2,99 +2,204 @@ import gradio as gr
 import numpy as np
 import random
 
-import spaces # [uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch
 
 from peft import PeftModel, LoraConfig
 import os
 
 def get_lora_sd_pipeline(
-    ckpt_dir='./lora',
-    base_model_name_or_path=None,
-    dtype=torch.float16,
     adapter_name="default"
-):
-
     unet_sub_dir = os.path.join(ckpt_dir, "unet")
     text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
-
     if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
         config = LoraConfig.from_pretrained(text_encoder_sub_dir)
         base_model_name_or_path = config.base_model_name_or_path
-
     if base_model_name_or_path is None:
         raise ValueError("Please specify the base model name or path")
-
-    pipe = DiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
-    pipe.unet.set_adapter(adapter_name)
-
     if os.path.exists(text_encoder_sub_dir):
-        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
-
     if dtype in (torch.float16, torch.bfloat16):
         pipe.unet.half()
         pipe.text_encoder.half()
-
     return pipe
 
 def infer(
     prompt,
     negative_prompt,
     randomize_seed,
     width=512,
     height=512,
-    model_repo_id="sd-legacy/stable-diffusion-v1-5",
-    seed=42,
     guidance_scale=7,
-    num_inference_steps=20,
-    model_lora_id="lora",
-    lora_scale=0.5,
-    use_controlnet=False,
-    controlnet_image=None,
-    control_strength=0.5,
-    control_mode="edge_detection",
-    use_ip_adapter=False,
-    ip_adapter_image=None,
-    ip_adapter_scale=0.5
 ):
-
     if randomize_seed:
-        seed = random.randint(0, 1000)
 
     generator = torch.Generator().manual_seed(seed)
-    pipe = get_lora_sd_pipeline(ckpt_dir=f'./{model_lora_id}', base_model_name_or_path=model_repo_id).to("cuda")
-
-    if use_controlnet and controlnet_image is not None:
-        pipe.enable_controlnet(control_mode, controlnet_image, control_strength)
-
-    if use_ip_adapter and ip_adapter_image is not None:
-        pipe.enable_ip_adapter(ip_adapter_image, ip_adapter_scale)
 
-    return pipe(prompt=prompt, negative_prompt=negative_prompt, width=width, height=height, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator).images[0], seed
 
-with gr.Blocks() as demo:
-    gr.Markdown("# Generate LoRa stickers with ControlNet & IP-Adapter")
-
-    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt")
-    negative_prompt = gr.Text(label="Negative Prompt", placeholder="Enter a negative prompt")
-    randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
-    width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=512)
-    height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=512)
-
-    use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
-    controlnet_image = gr.Image(label="ControlNet Image")
-    control_strength = gr.Slider(label="ControlNet Strength", minimum=0, maximum=1, step=0.1, value=0.5)
-    control_mode = gr.Dropdown(label="ControlNet Mode", choices=["edge_detection", "pose_estimation"], value="edge_detection")
-
-    use_ip_adapter = gr.Checkbox(label="Use IP-Adapter", value=False)
-    ip_adapter_image = gr.Image(label="IP-Adapter Image")
-    ip_adapter_scale = gr.Slider(label="IP-Adapter Scale", minimum=0, maximum=1, step=0.1, value=0.5)
-
-    run_button = gr.Button("Run")
-    result = gr.Image(label="Result")
-
-    run_button.click(infer, inputs=[prompt, negative_prompt, randomize_seed, width, height, use_controlnet, controlnet_image, control_strength, control_mode, use_ip_adapter, ip_adapter_image, ip_adapter_scale], outputs=[result])
 
-demo.launch()
 
 
 import numpy as np
 import random
 
+import spaces #[uncomment to use ZeroGPU]
+from diffusers import (
+    DiffusionPipeline,
+    StableDiffusionPipeline,
+    StableDiffusionControlNetPipeline,
+    StableDiffusionControlNetImg2ImgPipeline,
+    DPMSolverMultistepScheduler,
+    PNDMScheduler,
+    ControlNetModel
+)
 import torch
 
 from peft import PeftModel, LoraConfig
 import os
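
Note: the ControlNet classes and schedulers added to this import block are not referenced anywhere else in this commit. Since the previous version of the app exposed ControlNet controls in the UI, here is a minimal sketch of the standard diffusers wiring these imports would enable; the lllyasviel/sd-controlnet-canny checkpoint and the precomputed edge image are assumptions, not part of this commit.

```python
# Sketch only, not part of app.py: standard diffusers ControlNet wiring.
# Assumes the public lllyasviel/sd-controlnet-canny checkpoint and a
# precomputed Canny edge map `edge_image` (a PIL.Image).
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
control_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "sd-legacy/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")
# image = control_pipe(prompt, image=edge_image).images[0]
```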
 
 def get_lora_sd_pipeline(
+    ckpt_dir='./content/lora',
+    base_model_name_or_path=None,
+    dtype=torch.float16,
+    device="cuda",
     adapter_name="default"
+):
     unet_sub_dir = os.path.join(ckpt_dir, "unet")
     text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
 
     if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
         config = LoraConfig.from_pretrained(text_encoder_sub_dir)
         base_model_name_or_path = config.base_model_name_or_path
+
     if base_model_name_or_path is None:
         raise ValueError("Please specify the base model name or path")
+
+    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
+
     if os.path.exists(text_encoder_sub_dir):
+        pipe.text_encoder = PeftModel.from_pretrained(
+            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
+        )
+
     if dtype in (torch.float16, torch.bfloat16):
         pipe.unet.half()
         pipe.text_encoder.half()
+
     return pipe
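
For reference, a minimal usage sketch of the loader above, assuming a local ./lora checkpoint directory containing the unet/ (and optionally text_encoder/) subfolders the function looks for:

```python
# Sketch only: load the base model with the LoRA adapter attached and
# generate one image. The ./lora path is an assumed local checkpoint.
pipe = get_lora_sd_pipeline(
    ckpt_dir="./lora",
    base_model_name_or_path="sd-legacy/stable-diffusion-v1-5",
    dtype=torch.float16,
    device="cuda",
)
image = pipe("a cartoon sticker of a corgi", num_inference_steps=30).images[0]
image.save("sticker.png")
```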
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model_id_default = "sd-legacy/stable-diffusion-v1-5"
+model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4', 'sd-legacy/stable-diffusion-v1-5']
+
+model_lora_default = "lora"
+
+if torch.cuda.is_available():
+    torch_dtype = torch.float16
+else:
+    torch_dtype = torch.float32
+
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
+
+@spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
     prompt,
     negative_prompt,
     randomize_seed,
     width=512,
     height=512,
+    model_repo_id=model_id_default,
+    seed=22,
     guidance_scale=7,
+    num_inference_steps=50,
+    model_lora_id=model_lora_default,
+    progress=gr.Progress(track_tqdm=True),
 ):
+
     if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator().manual_seed(seed)
 
+    # rebuild the pipe conditionally, depending on the selected model
+    if model_repo_id != model_id_default:
+        pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
+        pipe.safety_checker = None
+    else:
+        # attach the LoRA adapter
+        pipe = get_lora_sd_pipeline(ckpt_dir='./' + model_lora_id, base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
+        pipe.safety_checker = None
+        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
 
+    # arguments for the pipe call; prompts are passed as text
+    # (see the embedding sketch below)
+    params = {
+        'prompt': prompt,
+        'negative_prompt': negative_prompt,
+        'guidance_scale': guidance_scale,
+        'num_inference_steps': num_inference_steps,
+        'width': width,
+        'height': height,
+        'generator': generator,
+    }
+
+    return pipe(**params).images[0], seed
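
Passing raw strings under prompt_embeds/negative_prompt_embeds would be rejected by StableDiffusionPipeline, so the params dict above uses the plain prompt/negative_prompt keys. If precomputed embeddings are the eventual goal (for example, to push prompts through the LoRA-wrapped text encoder explicitly), here is a minimal sketch, assuming a StableDiffusionPipeline-compatible pipe:

```python
# Sketch only: compute prompt embeddings by hand instead of passing text.
# Assumes a StableDiffusionPipeline-style pipe with .tokenizer and
# .text_encoder attributes (true for SD 1.5).
def encode_prompt(pipe, text, device):
    tokens = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    with torch.no_grad():
        return pipe.text_encoder(tokens)[0]

# prompt_embeds = encode_prompt(pipe, prompt, device)
# negative_embeds = encode_prompt(pipe, negative_prompt, device)
# image = pipe(prompt_embeds=prompt_embeds,
#              negative_prompt_embeds=negative_embeds, ...).images[0]
```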
+
+
+examples = [
+    "A cartoon-style sticker of Elon Musk shaking hands with Donald Trump. Both figures have exaggerated facial expressions, with Musk grinning confidently and Trump giving a signature thumbs-up. The background features a patriotic red, white, and blue color scheme with fireworks exploding behind them.",
+    "A cyberpunk-themed cartoon sticker of Elon Musk standing atop a futuristic Tesla spaceship. He wears a sleek, neon-lit jacket with glowing circuits, while the city skyline behind him is filled with holographic billboards displaying SpaceX and Neuralink logos. His sunglasses reflect the distant stars, adding to the sci-fi aesthetic.",
+    "A medieval fantasy sticker of Elon Musk depicted as a wizard. He holds a glowing blue orb in one hand and a spellbook in the other, wearing a long, starry robe with intricate golden details. His expression is both wise and mischievous, as if he's about to reveal the secrets of the universe. The background features a mystical castle and a dragon flying in the sky.",
+    "A sticker of Elon Musk dressed as a cowboy in the Wild West. He wears a wide-brimmed hat, leather boots, and a long trench coat, standing in front of a saloon with a SpaceX rocket docked nearby instead of a horse. A wanted poster on the wall reads 'Wanted: Mars Pioneer', adding to the playful western theme.",
+    "A parody cartoon sticker of Elon Musk arm-wrestling a robotic version of himself. The robot Musk has glowing red eyes and mechanical arms, while the real Musk smirks confidently. Sparks fly from the table as the intense match unfolds, and the background features a neon sign that reads 'Tesla vs. AI: Ultimate Showdown'."
+]
+
+css = """
+#col-container {
+    margin: 0 auto;
+    max-width: 640px;
+}
+"""
+
+with gr.Blocks(css=css) as demo:
+    with gr.Column(elem_id="col-container"):
+        gr.Markdown("# Generate LoRa stickers")
+
+        with gr.Row():
+            prompt = gr.Text(
+                label="Prompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter your prompt",
+                container=False,
+            )
+
+            run_button = gr.Button("Run", scale=0, variant="primary")
+
+        result = gr.Image(label="Result", show_label=False)
+
+        with gr.Accordion("Advanced Settings", open=False):
+
+            model_repo_id = gr.Dropdown(
+                label="Model Id",
+                choices=model_dropdown,
+                info="Choose model",
+                visible=True,
+                allow_custom_value=True,
+                value=model_id_default,
+            )
+
+            negative_prompt = gr.Text(
+                label="Negative prompt",
+                max_lines=1,
+                placeholder="Enter a negative prompt",
+                visible=True,
+                value="bad face, bad quality, artifacts, low-res, black and white, blurry, low quality, distorted, low resolution, medical mask"
+            )
+
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=1000,
+                step=1,
+                value=42,
+            )
+
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+
+            with gr.Row():
+                width = gr.Slider(
+                    label="Width",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=512,  # Replace with defaults that work for your model
+                )
+
+                height = gr.Slider(
+                    label="Height",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=512,  # Replace with defaults that work for your model
+                )
+
+        gr.Examples(examples=examples, inputs=[prompt])
+    gr.on(
+        triggers=[run_button.click, prompt.submit],
+        fn=infer,
+        inputs=[
+            prompt,
+            negative_prompt,
+            randomize_seed,
+            width,
+            height,
+            model_repo_id,
+            seed
+        ],
+        outputs=[result, seed],
+    )
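
gr.on with two triggers registers the same handler on both events; an equivalent explicit wiring, shown as a sketch for clarity:

```python
# Sketch only: equivalent to the gr.on(...) call above.
handler_io = dict(
    fn=infer,
    inputs=[prompt, negative_prompt, randomize_seed,
            width, height, model_repo_id, seed],
    outputs=[result, seed],
)
run_button.click(**handler_io)
prompt.submit(**handler_io)
```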
 
+if __name__ == "__main__":
+    demo.launch()