primerz committed
Commit f090813 · verified · 1 Parent(s): cb9063f

Update app.py

Files changed (1):
  app.py (+82 -65)
app.py CHANGED
@@ -173,23 +173,6 @@ button.addEventListener('click', function() {
 '''
 lora_archive = "/data"
 
-def resize_image_aspect_ratio(img, max_dim=512):
-    width, height = img.size
-    aspect_ratio = width / height
-
-    if aspect_ratio >= 1:  # Landscape or square
-        new_width = min(max_dim, width)
-        new_height = int(new_width / aspect_ratio)
-    else:  # Portrait
-        new_height = min(max_dim, height)
-        new_width = int(new_height * aspect_ratio)
-
-    new_width = (new_width // 8) * 8
-    new_height = (new_height // 8) * 8
-
-    return img.resize((new_width, new_height), Image.LANCZOS)
-
-
 def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
     lora_repo = sdxl_loras[selected_state.index]["repo"]
     new_placeholder = "Type a prompt to use your selected LoRA"
@@ -246,57 +229,22 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
     )
     del weights_sd
     del lora_model
 
-@spaces.GPU(duration=80)
-def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, face_detected):
-    global last_fused, last_lora
-
-    control_images = [face_kps, zoe(face_image)] if face_detected else [zoe(face_image)]
-    control_scales = [face_strength, depth_control_scale] if face_detected else [depth_control_scale]
-
-    if repo_name.startswith("https://huggingface.co"):
-        repo_id = repo_name.split("huggingface.co/")[-1]
-        weight_file = "pytorch_lora_weights.safetensors"
-        full_path_lora = hf_hub_download(repo_id=repo_id, filename=weight_file, repo_type="model")
-        loaded_state_dict = load_file(full_path_lora)
-
-    if last_lora != repo_name:
-        if last_fused:
-            pipe.unfuse_lora()
-            pipe.unload_lora_weights()
-            pipe.unload_textual_inversion()
-        pipe.load_lora_weights(loaded_state_dict)
-        pipe.fuse_lora(lora_scale)
-        last_fused = True
-        is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
-        if is_pivotal:
-            text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
-            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
-            state_dict_embedding = load_file(embedding_path)
-            pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
-            pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
-
-    conditioning, pooled = compel(prompt)
-    negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
-
-    image = pipe(
-        prompt_embeds=conditioning,
-        pooled_prompt_embeds=pooled,
-        negative_prompt_embeds=negative_conditioning,
-        negative_pooled_prompt_embeds=negative_pooled,
-        width=face_image.width,
-        height=face_image.height,
-        image_embeds=face_emb if face_detected else None,
-        image=face_image,
-        strength=1-image_strength,
-        control_image=control_images,
-        num_inference_steps=36,
-        guidance_scale=guidance_scale,
-        controlnet_conditioning_scale=control_scales,
-    ).images[0]
-
-    last_lora = repo_name
-    return image
+def resize_image_aspect_ratio(img, max_dim=512):
+    width, height = img.size
+    aspect_ratio = width / height
+
+    if aspect_ratio >= 1:  # Landscape or square
+        new_width = min(max_dim, width)
+        new_height = int(new_width / aspect_ratio)
+    else:  # Portrait
+        new_height = min(max_dim, height)
+        new_width = int(new_height * aspect_ratio)
+
+    new_width = (new_width // 8) * 8
+    new_height = (new_height // 8) * 8
+
+    return img.resize((new_width, new_height), Image.LANCZOS)
 
 
 def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora, progress=gr.Progress(track_tqdm=True)):
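For context, here is a minimal standalone sketch of the relocated resize_image_aspect_ratio helper (assuming only Pillow; the 1000×700 input is an arbitrary example, not taken from the app):

```python
# Standalone sketch: downscale so the long side is at most max_dim, then snap
# both dimensions down to multiples of 8 (SDXL expects sides divisible by 8).
from PIL import Image

def resize_image_aspect_ratio(img, max_dim=512):
    width, height = img.size
    aspect_ratio = width / height
    if aspect_ratio >= 1:  # Landscape or square
        new_width = min(max_dim, width)
        new_height = int(new_width / aspect_ratio)
    else:  # Portrait
        new_height = min(max_dim, height)
        new_width = int(new_height * aspect_ratio)
    new_width = (new_width // 8) * 8
    new_height = (new_height // 8) * 8
    return img.resize((new_width, new_height), Image.LANCZOS)

img = Image.new("RGB", (1000, 700))
print(resize_image_aspect_ratio(img).size)  # (512, 352): int(512 / (1000/700)) = 358, snapped down to 352
```

Because of the snap-down, the output can deviate slightly from the exact aspect ratio; the invariant is that both sides are multiples of 8 and neither exceeds max_dim.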
@@ -350,8 +298,77 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
 
     return (resized_image, image), gr.update(visible=True)
 
-
-run_lora.zerogpu = True
+@spaces.GPU(duration=100)
+def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, face_detected):
+    global last_fused, last_lora
+
+    print("Loaded state dict:", loaded_state_dict)
+    print("Last LoRA:", last_lora, "| Current LoRA:", repo_name)
+
+    control_images = [face_kps, zoe(face_image)] if face_detected else [zoe(face_image)]
+    control_scales = [face_strength, depth_control_scale] if face_detected else [depth_control_scale]
+
+    # Handle Hugging Face URL-based LoRA
+    if repo_name.startswith("https://huggingface.co"):
+        repo_id = repo_name.split("huggingface.co/")[-1]
+        fs = HfFileSystem()
+        files = fs.ls(repo_id, detail=False)
+        safetensors_files = [f for f in files if f.endswith(".safetensors")]
+
+        if not safetensors_files:
+            raise gr.Error("No .safetensors file found in this Hugging Face repository.")
+
+        weight_file = safetensors_files[0]  # Dynamically select the first available .safetensors file
+        full_path_lora = hf_hub_download(repo_id=repo_id, filename=weight_file, repo_type="model")
+        loaded_state_dict = load_file(full_path_lora)
+    else:
+        # Use the previously loaded state_dict if not using a Hugging Face URL
+        loaded_state_dict = state_dicts[repo_name]["state_dict"]
+
+    # Manage LoRA weights and textual inversion embeddings
+    if last_lora != repo_name:
+        if last_fused:
+            pipe.unfuse_lora()
+            pipe.unload_lora_weights()
+            pipe.unload_textual_inversion()
+        pipe.load_lora_weights(loaded_state_dict)
+        pipe.fuse_lora(lora_scale)
+        last_fused = True
+
+        # Handle pivotal tuning (textual inversion embeddings)
+        is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
+        if is_pivotal:
+            text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
+            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+            state_dict_embedding = load_file(embedding_path)
+            pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+            pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+
+    # Prompt embeddings
+    print("Processing prompt...")
+    conditioning, pooled = compel(prompt)
+    negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
+
+    # Image generation
+    print("Generating image...")
+    image = pipe(
+        prompt_embeds=conditioning,
+        pooled_prompt_embeds=pooled,
+        negative_prompt_embeds=negative_conditioning,
+        negative_pooled_prompt_embeds=negative_pooled,
+        width=face_image.width,
+        height=face_image.height,
+        image_embeds=face_emb if face_detected else None,
+        image=face_image,
+        strength=1-image_strength,
+        control_image=control_images,
+        num_inference_steps=36,
+        guidance_scale=guidance_scale,
+        controlnet_conditioning_scale=control_scales,
+    ).images[0]
+
+    last_lora = repo_name
+    return image
 
 def shuffle_gallery(sdxl_loras):
     random.shuffle(sdxl_loras)
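Beyond the GPU duration bump from 80 to 100 seconds, the main functional change in this hunk is that URL-supplied LoRAs no longer assume a hard-coded pytorch_lora_weights.safetensors; the weight file is discovered via HfFileSystem. A minimal sketch of that discovery in isolation ("author/some-lora" is a hypothetical repo id; note that HfFileSystem.ls returns paths prefixed with the repo id, so a standalone version strips that prefix before handing the name to hf_hub_download):

```python
# Minimal sketch, assuming huggingface_hub is installed and the repo is public.
# "author/some-lora" is a hypothetical repo id, not one used by the app.
from huggingface_hub import HfFileSystem, hf_hub_download

repo_id = "author/some-lora"
fs = HfFileSystem()
files = fs.ls(repo_id, detail=False)  # e.g. ["author/some-lora/pytorch_lora_weights.safetensors", ...]
safetensors_files = [f for f in files if f.endswith(".safetensors")]
if not safetensors_files:
    raise FileNotFoundError(f"No .safetensors file found in {repo_id}")

# HfFileSystem.ls returns repo-prefixed paths, while hf_hub_download expects a
# filename relative to the repo root, so strip the prefix first.
weight_file = safetensors_files[0].removeprefix(f"{repo_id}/")
local_path = hf_hub_download(repo_id=repo_id, filename=weight_file, repo_type="model")
print(local_path)
```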
 
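Both versions of generate_image unpack compel(prompt) into (conditioning, pooled). For an SDXL pipeline, that callable is typically constructed along these lines (a sketch based on the compel library's documented SDXL usage, not code from this commit; pipe stands in for the app's SDXL pipeline):

```python
# Sketch, assuming the compel library and an already-loaded SDXL `pipe`.
from compel import Compel, ReturnedEmbeddingsType

compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],  # only the second (OpenCLIP) text encoder yields a pooled embedding
)
conditioning, pooled = compel("a photo of a person")  # hypothetical prompt
```

With requires_pooled=[False, True], compel returns exactly the (embeds, pooled) pair that generate_image forwards to the pipeline as prompt_embeds and pooled_prompt_embeds.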