primerz commited on
Commit
6daa11a
·
verified ·
1 Parent(s): 481c3ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +334 -661
app.py CHANGED
@@ -1,26 +1,9 @@
1
- import os
2
- # Fix transformers deprecation warning - use HF_HOME instead of TRANSFORMERS_CACHE
3
- if 'TRANSFORMERS_CACHE' in os.environ and 'HF_HOME' not in os.environ:
4
- os.environ['HF_HOME'] = os.environ['TRANSFORMERS_CACHE']
5
- if 'HF_HOME' not in os.environ:
6
- os.environ['HF_HOME'] = os.environ.get('TRANSFORMERS_CACHE', '/data/.cache/huggingface')
7
-
8
- # Suppress ONNX Runtime GPU discovery warnings (we use CPU for face detection)
9
- os.environ['ORT_LOGGING_LEVEL'] = '3' # Only show errors, not warnings
10
-
11
- import warnings
12
- warnings.filterwarnings('ignore', category=UserWarning, module='onnxruntime')
13
-
14
  import gradio as gr
15
  import torch
16
  import spaces
17
- import time
18
- import traceback
19
- from typing import Optional, List
20
- import numpy as np
21
- from PIL import Image, ImageEnhance
22
  torch.jit.script = lambda f: f
23
  import timm
 
24
 
25
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
26
  from safetensors.torch import load_file
@@ -45,6 +28,7 @@ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditio
45
  import cv2
46
  import torch
47
  import numpy as np
 
48
 
49
  from insightface.app import FaceAnalysis
50
  from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
@@ -80,248 +64,174 @@ with open("defaults_data.json", "r") as file:
80
 
81
  device = "cuda"
82
 
83
- print("=" * 80)
84
- print("🚀 Starting LucasArts Style App (VRAM Optimized)")
85
- print("=" * 80)
86
- print()
87
-
88
  # Cache for LoRA state dicts
89
- print("📦 Loading LoRA configurations...")
90
  state_dicts = {}
91
- try:
92
- for item in sdxl_loras_raw:
93
- saved_name = hf_hub_download(item["repo"], item["weights"])
94
-
95
- if not saved_name.endswith('.safetensors'):
96
- state_dict = torch.load(saved_name)
97
- else:
98
- state_dict = load_file(saved_name)
99
-
100
- state_dicts[item["repo"]] = {
101
- "saved_name": saved_name,
102
- "state_dict": state_dict
103
- }
104
- print(f"✅ Loaded {len(state_dicts)} LoRA configurations")
105
- except Exception as e:
106
- print(f"❌ Error loading LoRAs: {e}")
107
- raise
108
 
109
  sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
110
 
111
  # Download models
112
- print()
113
- print("📥 Downloading required models...")
114
- try:
115
- hf_hub_download(
116
- repo_id="InstantX/InstantID",
117
- filename="ControlNetModel/config.json",
118
- local_dir="/data/checkpoints",
119
- )
120
- hf_hub_download(
121
- repo_id="InstantX/InstantID",
122
- filename="ControlNetModel/diffusion_pytorch_model.safetensors",
123
- local_dir="/data/checkpoints",
124
- )
125
- hf_hub_download(
126
- repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
127
- )
128
- hf_hub_download(
129
- repo_id="latent-consistency/lcm-lora-sdxl",
130
- filename="pytorch_lora_weights.safetensors",
131
- local_dir="/data/checkpoints",
132
- )
133
- print("✅ Model checkpoints downloaded")
134
- except Exception as e:
135
- print(f"❌ Error downloading models: {e}")
136
- raise
137
 
138
  # Download antelopev2
139
- print()
140
- print("📥 Downloading face detection model...")
141
- try:
142
- antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
143
- print(f"✅ Face detection model: {antelope_download}")
144
- except Exception as e:
145
- print(f"❌ Error downloading face model: {e}")
146
- raise
147
-
148
- print()
149
- print("🔧 Initializing face detection...")
150
- # VRAM OPTIMIZED: Standard 768x768 for better memory usage
151
- try:
152
- app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
153
- app.prepare(ctx_id=0, det_size=(768, 768))
154
- print("✅ Face detection initialized at 768x768")
155
- except Exception as e:
156
- print(f"❌ Error initializing face detection: {e}")
157
- raise
158
 
159
  # Prepare models
160
  face_adapter = f'/data/checkpoints/ip-adapter.bin'
161
  controlnet_path = f'/data/checkpoints/ControlNetModel'
162
 
163
- print()
164
- print("🔧 Loading ControlNets...")
165
  st = time.time()
166
- try:
167
- identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
168
- zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
169
- et = time.time()
170
- print(f'✅ ControlNet loaded in {et - st:.2f} seconds')
171
- except Exception as e:
172
- print(f"❌ Error loading ControlNet: {e}")
173
- raise
174
 
175
- print()
176
- print("🔧 Loading VAE...")
177
  st = time.time()
178
- try:
179
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
180
- et = time.time()
181
- print(f'✅ VAE loaded in {et - st:.2f} seconds')
182
- except Exception as e:
183
- print(f"❌ Error loading VAE: {e}")
184
- raise
185
 
186
- print()
187
- print("🔧 Loading main pipeline...")
188
  st = time.time()
189
- try:
190
- pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
191
- "frankjoshua/albedobaseXL_v21",
192
- vae=vae,
193
- controlnet=[identitynet, zoedepthnet],
194
- torch_dtype=torch.float16
195
- )
 
 
 
 
 
196
 
197
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
198
- pipe.load_ip_adapter_instantid(face_adapter)
199
- pipe.set_ip_adapter_scale(1.0)
200
- et = time.time()
201
- print(f'✅ Pipeline loaded in {et - st:.2f} seconds')
202
- except Exception as e:
203
- print(f"❌ Error loading pipeline: {e}")
204
- raise
205
-
206
- print()
207
- print("🔧 Loading Compel (prompt processor)...")
208
  st = time.time()
209
- try:
210
- compel = Compel(
211
- tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
212
- text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
213
- returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
214
- requires_pooled=[False, True]
215
- )
216
- et = time.time()
217
- print(f'✅ Compel loaded in {et - st:.2f} seconds')
218
- except Exception as e:
219
- print(f"❌ Error loading Compel: {e}")
220
- raise
221
 
222
- print()
223
- print("🔧 Loading Zoe (depth detector)...")
224
  st = time.time()
225
- try:
226
- zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
227
- et = time.time()
228
- print(f'✅ Zoe loaded in {et - st:.2f} seconds')
229
- except Exception as e:
230
- print(f"❌ Error loading Zoe: {e}")
231
- raise
232
-
233
- print()
234
- print("🔧 Moving models to GPU...")
235
- try:
236
- zoe.to(device)
237
- pipe.to(device)
238
- print("✅ Models moved to GPU")
239
- except Exception as e:
240
- print(f"❌ Error moving to GPU: {e}")
241
- raise
242
-
243
- print()
244
- print("=" * 80)
245
- print("✅ All models loaded successfully!")
246
- print("âš¡ VRAM Optimized Configuration:")
247
- print(" • 768x768 face detection")
248
- print(" • 768px max output resolution")
249
- print(" • 35 inference steps")
250
- print(" • Enhanced error reporting")
251
- print(" • Fixed version compatibility (diffusers 0.21.4)")
252
- print("=" * 80)
253
- print()
254
 
255
  last_lora = ""
256
  last_fused = False
257
  lora_archive = "/data"
258
 
259
-
260
- def enhance_details(image, strength=1.15):
261
- """Post-process to enhance details"""
 
 
 
262
  try:
263
- sharpener = ImageEnhance.Sharpness(image)
264
- image = sharpener.enhance(strength)
265
-
266
- contrast = ImageEnhance.Contrast(image)
267
- image = contrast.enhance(1.08)
268
 
269
- return image
270
- except Exception as e:
271
- print(f"Warning: Detail enhancement failed: {e}")
272
- return image
273
-
274
-
275
- def enhanced_depth_map(image, face_detected=False):
276
- """Better depth map generation"""
277
- try:
278
- original_size = image.size
279
 
280
- # Skip upscaling for VRAM optimization
281
- depth = zoe(image)
 
 
 
 
282
 
283
- return depth
 
 
 
 
 
 
284
  except Exception as e:
285
- print(f"Error in depth map generation: {e}")
286
- # Return a blank depth map as fallback
287
- return Image.new('L', image.size, color=128)
288
-
289
 
290
- def process_face_embeddings_separately(face_info_list):
291
- """Process face embeddings separately for multi-face generation"""
 
 
292
  if not face_info_list:
293
- return []
294
 
 
 
 
 
295
  embeddings = [face_info['embedding'] for face_info in face_info_list]
296
- return embeddings
297
-
298
 
299
  def create_face_kps_image(face_image, face_info_list):
300
- """Create keypoints image from face info with enhanced visibility"""
 
 
301
  if not face_info_list:
302
  return face_image
303
 
 
304
  if len(face_info_list) > 1:
305
  return draw_multiple_kps(face_image, [f['kps'] for f in face_info_list])
306
  else:
307
  return draw_kps(face_image, face_info_list[0]['kps'])
308
 
309
-
310
  def draw_multiple_kps(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
311
- """Draw keypoints for multiple faces with enhanced visibility"""
 
 
312
  stickwidth = 4
313
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
314
 
315
  w, h = image_pil.size
316
  out_img = np.zeros([h, w, 3])
317
 
318
- for idx, kps in enumerate(kps_list):
319
  kps = np.array(kps)
320
- color_offset = idx % len(color_list)
321
 
322
  for i in range(len(limbSeq)):
323
  index = limbSeq[i]
324
- color = color_list[(index[0] + color_offset) % len(color_list)]
325
 
326
  x = kps[index][:, 0]
327
  y = kps[index][:, 1]
@@ -335,25 +245,24 @@ def draw_multiple_kps(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0),
335
  out_img = (out_img * 0.6).astype(np.uint8)
336
 
337
  for idx_kp, kp in enumerate(kps):
338
- color = color_list[(idx_kp + color_offset) % len(color_list)]
339
  x, y = kp
340
  out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
341
 
342
  out_img_pil = Image.fromarray(out_img.astype(np.uint8))
343
  return out_img_pil
344
 
345
-
346
  def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
347
  lora_repo = sdxl_loras[selected_state.index]["repo"]
348
  new_placeholder = "Type a prompt to use your selected LoRA"
349
  weight_name = sdxl_loras[selected_state.index]["weights"]
350
- updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
351
 
352
  for lora_list in lora_defaults:
353
  if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
354
- face_strength = lora_list.get("face_strength", 1.0)
355
- image_strength = lora_list.get("image_strength", 0.15)
356
- weight = lora_list.get("weight", 1.0)
357
  depth_control_scale = lora_list.get("depth_control_scale", 0.8)
358
  negative = lora_list.get("negative", "")
359
 
@@ -374,33 +283,230 @@ def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, i
374
  selected_state
375
  )
376
 
377
-
378
  def check_selected(selected_state, custom_lora):
379
  if not selected_state and not custom_lora:
380
  raise gr.Error("You must select a style")
381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
  def shuffle_gallery(sdxl_loras):
384
  random.shuffle(sdxl_loras)
385
  return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
386
 
387
-
388
  def classify_gallery(sdxl_loras):
389
  sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
390
  return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
391
 
392
-
393
  def swap_gallery(order, sdxl_loras):
394
  if(order == "random"):
395
  return shuffle_gallery(sdxl_loras)
396
  else:
397
  return classify_gallery(sdxl_loras)
398
 
399
-
400
  def deselect():
401
  return gr.Gallery(selected_index=None)
402
 
403
-
404
  def get_huggingface_safetensors(link):
405
  split_link = link.split("/")
406
  if(len(split_link) == 2):
@@ -424,7 +530,6 @@ def get_huggingface_safetensors(link):
424
  raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
425
  return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
426
 
427
-
428
  def get_civitai_safetensors(link):
429
  link_split = link.split("civitai.com/")
430
  pattern = re.compile(r'models\/(\d+)')
@@ -469,7 +574,6 @@ def get_civitai_safetensors(link):
469
  raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
470
  return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
471
 
472
-
473
  def check_custom_model(link):
474
  if(link.startswith("https://")):
475
  if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
@@ -480,7 +584,6 @@ def check_custom_model(link):
480
  else:
481
  return get_huggingface_safetensors(link)
482
 
483
-
484
  def load_custom_lora(link):
485
  if(link):
486
  try:
@@ -504,419 +607,20 @@ def load_custom_lora(link):
504
  else:
505
  return gr.update(visible=False), "", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
506
 
507
-
508
  def remove_custom_lora():
509
  return "", gr.update(visible=False), gr.update(visible=False), None
510
 
511
-
512
- @spaces.GPU(duration=120)
513
- def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength,
514
- guidance_scale, depth_control_scale, sdxl_loras, custom_lora, use_multiple_faces=False,
515
- progress=gr.Progress(track_tqdm=True)):
516
- """
517
- VRAM optimized generation with enhanced error reporting
518
- """
519
- print("=" * 80)
520
- print("🚀 FUNCTION CALLED: run_lora")
521
- print(f"📊 Inputs received:")
522
- print(f" - face_image: {type(face_image)} - {face_image.size if face_image else 'None'}")
523
- print(f" - prompt: '{prompt}'")
524
- print(f" - selected_state: {type(selected_state)} - {selected_state}")
525
- print(f" - custom_lora: {custom_lora}")
526
- print(f" - use_multiple_faces: {use_multiple_faces}")
527
- print("=" * 80)
528
-
529
- try:
530
- print("Starting generation...")
531
- print("Custom LoRA:", custom_lora)
532
- custom_lora_path = custom_lora[0] if custom_lora else None
533
-
534
- # Extract index from selected_state (handle Gradio SelectData object)
535
- if selected_state:
536
- print(f" selected_state exists: {selected_state}")
537
- print(f" selected_state type: {type(selected_state)}")
538
- print(f" selected_state dir: {dir(selected_state)}")
539
- if hasattr(selected_state, 'index'):
540
- selected_state_index = selected_state.index
541
- print(f" ✓ Extracted index: {selected_state_index}")
542
- else:
543
- selected_state_index = -1
544
- print(f" ❌ No index attribute, using -1")
545
- else:
546
- selected_state_index = -1
547
- print(f" ❌ selected_state is None or False")
548
-
549
- print(f"🔍 VALIDATION CHECK:")
550
- print(f" - selected_state_index: {selected_state_index}")
551
- print(f" - custom_lora_path: {custom_lora_path}")
552
- print(f" - len(sdxl_loras): {len(sdxl_loras)}")
553
-
554
- # Validate selection immediately
555
- if (selected_state_index is None or selected_state_index < 0) and not custom_lora_path:
556
- error_msg = "❌ You must select a style before generating"
557
- print(error_msg)
558
- return gr.update(), gr.update(visible=False), gr.update(visible=True, value=error_msg)
559
-
560
- # Validate selected_state_index is valid (only check positive indices)
561
- if not custom_lora_path and selected_state_index >= 0 and selected_state_index >= len(sdxl_loras):
562
- error_msg = f"❌ Invalid style selection (index: {selected_state_index}, available: {len(sdxl_loras)})"
563
- print(error_msg)
564
- return gr.update(), gr.update(visible=False), gr.update(visible=True, value=error_msg)
565
-
566
- st = time.time()
567
-
568
- pipe.to(device)
569
- zoe.to(device)
570
-
571
- # VRAM OPTIMIZED: Reduced max dimension to 768
572
- face_image = resize_image_aspect_ratio(face_image)
573
- print(f"Resized image to {face_image.size}")
574
-
575
- # Face detection with better error handling
576
- try:
577
- face_info_list = detect_faces(face_image, use_multiple_faces)
578
- face_detected = len(face_info_list) > 0
579
- except Exception as e:
580
- print(f"Face detection error: {e}")
581
- face_detected = False
582
- face_info_list = []
583
-
584
- if face_detected:
585
- face_embeddings = process_face_embeddings_separately(face_info_list)
586
- face_kps = create_face_kps_image(face_image, face_info_list)
587
- print(f"Processing with {len(face_info_list)} face(s) detected")
588
- face_emb = face_embeddings[0]
589
- else:
590
- face_emb = None
591
- face_kps = face_image
592
- print("No faces detected - landscape mode")
593
-
594
- et = time.time()
595
- print(f'Face processing took: {et - st:.2f}s')
596
-
597
- st = time.time()
598
-
599
- # Enhanced prompt processing
600
- if custom_lora_path and custom_lora[1]:
601
- prompt = f"{prompt} {custom_lora[1]}"
602
- else:
603
- for lora_list in lora_defaults:
604
- if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
605
- prompt_full = lora_list.get("prompt", None)
606
- if prompt_full:
607
- prompt = prompt_full.replace("<subject>", prompt)
608
-
609
- if "lucasarts artstyle" not in prompt.lower():
610
- prompt = f"{prompt}, lucasarts artstyle"
611
-
612
- print("Prompt:", prompt)
613
- if prompt == "":
614
- prompt = "a beautiful cinematic scene" if not face_detected else "a person in cinematic lighting"
615
- print(f"Executing prompt: {prompt}")
616
-
617
- if negative == "":
618
- if not face_detected:
619
- negative = "worst quality, low quality, blurry, distorted, deformed, ugly, bad anatomy"
620
- else:
621
- negative = "worst quality, low quality, blurry, distorted, deformed, ugly, bad anatomy, bad proportions"
622
-
623
- print("Custom Loaded LoRA:", custom_lora_path)
624
-
625
- if custom_lora_path:
626
- repo_name = custom_lora_path
627
- full_path_lora = custom_lora_path
628
- else:
629
- repo_name = sdxl_loras[selected_state_index]["repo"]
630
- if repo_name not in state_dicts:
631
- error_msg = f"❌ LoRA not loaded: {repo_name}\nAvailable: {list(state_dicts.keys())[:5]}"
632
- print(error_msg)
633
- return gr.update(), gr.update(visible=False), gr.update(visible=True, value=error_msg)
634
- full_path_lora = state_dicts[repo_name]["saved_name"]
635
-
636
- repo_name = repo_name.rstrip("/").lower()
637
-
638
- et = time.time()
639
- print(f'Prompt processing took: {et - st:.2f}s')
640
-
641
- # Optimized parameters based on mode
642
- if not face_detected:
643
- face_strength = 0.0
644
- depth_control_scale = 1.0
645
- image_strength = 0.25
646
- guidance_scale = max(guidance_scale, 8.5)
647
- print("Optimized for landscape mode")
648
- else:
649
- face_strength = max(face_strength, 1.0)
650
- depth_control_scale = max(depth_control_scale, 0.8)
651
- guidance_scale = max(guidance_scale, 7.5)
652
- print("Optimized for face preservation")
653
-
654
- st = time.time()
655
-
656
- image = generate_image_inline(
657
- prompt, negative, face_emb, face_image, face_kps, image_strength,
658
- guidance_scale, face_strength, depth_control_scale, repo_name,
659
- full_path_lora, lora_scale, sdxl_loras, selected_state_index, face_detected, st
660
- )
661
-
662
- torch.cuda.empty_cache()
663
-
664
- print("Generation complete!")
665
- print("=" * 80)
666
- return (face_image, image), gr.update(visible=True), gr.update(visible=False)
667
-
668
- except torch.cuda.OutOfMemoryError as e:
669
- error_msg = (
670
- "GPU OUT OF MEMORY!\n\n"
671
- "Your image is too large for available VRAM.\n\n"
672
- "Solutions:\n"
673
- "1. Try a smaller image (current max: 768px)\n"
674
- "2. Upgrade to A10G GPU in Space settings\n"
675
- "3. Reduce image strength parameter\n\n"
676
- f"Technical details: {str(e)}"
677
- )
678
- print("=" * 80)
679
- print(error_msg)
680
- print("=" * 80)
681
- torch.cuda.empty_cache()
682
- return gr.update(), gr.update(visible=False), gr.update(visible=True, value=error_msg)
683
- except RuntimeError as e:
684
- if "out of memory" in str(e).lower():
685
- error_msg = (
686
- "GPU OUT OF MEMORY!\n\n"
687
- f"Error: {str(e)}\n\n"
688
- "Solutions:\n"
689
- "1. Upload smaller image\n"
690
- "2. Upgrade GPU in Settings\n"
691
- "3. Reduce parameters"
692
- )
693
- else:
694
- error_msg = f"Runtime error: {str(e)}\n\nFull trace:\n{traceback.format_exc()}"
695
- print("=" * 80)
696
- print(error_msg)
697
- print("=" * 80)
698
- torch.cuda.empty_cache()
699
- return gr.update(), gr.update(visible=False), gr.update(visible=True, value=error_msg)
700
- except Exception as e:
701
- error_msg = f"Generation failed: {str(e)}\n\nFull error:\n{traceback.format_exc()}"
702
- print("=" * 80)
703
- print("ERROR:")
704
- print(error_msg)
705
- print("=" * 80)
706
- torch.cuda.empty_cache()
707
- return gr.update(), gr.update(visible=False), gr.update(visible=True, value=error_msg)
708
-
709
-
710
- def generate_image_inline(prompt, negative, face_emb, face_image, face_kps, image_strength,
711
- guidance_scale, face_strength, depth_control_scale, repo_name,
712
- loaded_state_dict, lora_scale, sdxl_loras, selected_state_index,
713
- face_detected, st):
714
- """Generation with VRAM optimization"""
715
- global last_fused, last_lora
716
-
717
- try:
718
- print("Loaded state dict:", loaded_state_dict)
719
- print("Last LoRA:", last_lora, "| Current LoRA:", repo_name)
720
-
721
- # Enhanced depth map generation
722
- depth_image = enhanced_depth_map(face_image, face_detected)
723
-
724
- # CRITICAL FIX: Pipeline has 2 controlnets, must always pass 2 control images
725
- if face_detected:
726
- control_images = [face_kps, depth_image]
727
- control_scales = [face_strength, depth_control_scale]
728
- else:
729
- # When no face detected, pass dummy black image for face controlnet with scale 0.0
730
- dummy_face = Image.new('RGB', face_image.size, color=(0, 0, 0))
731
- control_images = [dummy_face, depth_image]
732
- control_scales = [0.0, depth_control_scale] # Face control disabled
733
-
734
- # Handle custom LoRA from HuggingFace
735
- if repo_name.startswith("https://huggingface.co"):
736
- repo_id = repo_name.split("huggingface.co/")[-1]
737
- fs = HfFileSystem()
738
- files = fs.ls(repo_id, detail=False)
739
- safetensors_files = [f for f in files if f.endswith(".safetensors")]
740
-
741
- if not safetensors_files:
742
- raise Exception("No .safetensors file found in this Hugging Face repository.")
743
-
744
- weight_file = safetensors_files[0]
745
- full_path_lora = hf_hub_download(repo_id=repo_id, filename=weight_file, repo_type="model")
746
- else:
747
- full_path_lora = loaded_state_dict
748
-
749
- # LoRA loading
750
- if last_lora != repo_name:
751
- if last_fused:
752
- pipe.unfuse_lora()
753
- pipe.unload_lora_weights()
754
- pipe.unload_textual_inversion()
755
-
756
- try:
757
- pipe.load_lora_weights(full_path_lora)
758
- pipe.fuse_lora(lora_scale)
759
- last_fused = True
760
-
761
- is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
762
- if is_pivotal:
763
- text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
764
- embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
765
- state_dict_embedding = load_file(embedding_path)
766
- pipe.load_textual_inversion(
767
- state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"],
768
- token=["<s0>", "<s1>"],
769
- text_encoder=pipe.text_encoder,
770
- tokenizer=pipe.tokenizer
771
- )
772
- pipe.load_textual_inversion(
773
- state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"],
774
- token=["<s0>", "<s1>"],
775
- text_encoder=pipe.text_encoder_2,
776
- tokenizer=pipe.tokenizer_2
777
- )
778
- except Exception as e:
779
- raise Exception(f"Failed to load LoRA: {str(e)}")
780
-
781
- print("✓ Processing embeddings...")
782
- conditioning, pooled = compel(prompt)
783
- negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
784
-
785
- # VRAM OPTIMIZED: Reduced to 35 steps
786
- num_inference_steps = 35
787
-
788
- print(f"✓ Generating image ({num_inference_steps} steps)...")
789
- print(f" Image size: {face_image.width}x{face_image.height}")
790
- print(f" Face detected: {face_detected}")
791
- print(f" Control images: {len(control_images)}")
792
- print(f" GPU Memory before generation: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
793
-
794
- try:
795
- image = pipe(
796
- prompt_embeds=conditioning,
797
- pooled_prompt_embeds=pooled,
798
- negative_prompt_embeds=negative_conditioning,
799
- negative_pooled_prompt_embeds=negative_pooled,
800
- width=face_image.width,
801
- height=face_image.height,
802
- image_embeds=face_emb if face_detected else None,
803
- image=face_image,
804
- strength=1-image_strength,
805
- control_image=control_images,
806
- num_inference_steps=num_inference_steps,
807
- guidance_scale=guidance_scale,
808
- controlnet_conditioning_scale=control_scales,
809
- ).images[0]
810
- except torch.cuda.OutOfMemoryError as e:
811
- print(f" GPU Memory at error: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
812
- raise Exception(f"CUDA out of memory during generation. Image size {face_image.width}x{face_image.height} is too large. Error: {str(e)}")
813
- except RuntimeError as e:
814
- if "out of memory" in str(e).lower():
815
- print(f" GPU Memory at error: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
816
- raise Exception(f"GPU out of memory. Try smaller image or upgrade GPU. Error: {str(e)}")
817
- else:
818
- raise Exception(f"Runtime error during generation: {str(e)}")
819
- except Exception as e:
820
- raise Exception(f"Pipeline generation failed: {str(e)}")
821
-
822
- # Post-processing detail enhancement
823
- print("✓ Enhancing details...")
824
- image = enhance_details(image, strength=1.15)
825
-
826
- print(f"✓ Generation complete! GPU Memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
827
-
828
- last_lora = repo_name
829
- return image
830
-
831
- except Exception as e:
832
- raise Exception(f"Image generation failed: {str(e)}\n{traceback.format_exc()}")
833
-
834
-
835
- def detect_faces(face_image, use_multiple_faces=False):
836
- """
837
- Enhanced face detection with better filtering
838
- """
839
- try:
840
- face_info_list = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
841
-
842
- if not face_info_list or len(face_info_list) == 0:
843
- print("No faces detected")
844
- return []
845
-
846
- # Enhanced: Stricter quality filtering
847
- filtered_faces = []
848
- for face_info in face_info_list:
849
- # Higher confidence threshold (0.6 instead of 0.5)
850
- if 'det_score' in face_info and face_info['det_score'] > 0.6:
851
- # Check minimum face size (80x80 instead of default)
852
- bbox = face_info['bbox']
853
- width = bbox[2] - bbox[0]
854
- height = bbox[3] - bbox[1]
855
-
856
- if width >= 80 and height >= 80:
857
- # Check reasonable aspect ratio
858
- aspect_ratio = width / height
859
- if 0.6 <= aspect_ratio <= 1.4:
860
- filtered_faces.append(face_info)
861
- elif 'det_score' not in face_info:
862
- filtered_faces.append(face_info)
863
-
864
- if not filtered_faces:
865
- print("No high-quality faces detected (strict filtering)")
866
- return []
867
-
868
- # Sort by size (largest first)
869
- filtered_faces = sorted(
870
- filtered_faces,
871
- key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
872
- reverse=True
873
- )
874
-
875
- if use_multiple_faces:
876
- print(f"✓ Detected {len(filtered_faces)} high-quality faces")
877
- return filtered_faces
878
- else:
879
- print(f"✓ Using largest face (detected {len(filtered_faces)} total)")
880
- return [filtered_faces[0]]
881
-
882
- except Exception as e:
883
- print(f"Face detection error: {e}")
884
- return []
885
-
886
-
887
- def resize_image_aspect_ratio(img, max_dim=768):
888
- """
889
- VRAM OPTIMIZED: Reduced max dimension to 768 to prevent CUDA OOM errors
890
- """
891
- width, height = img.size
892
- aspect_ratio = width / height
893
-
894
- if aspect_ratio >= 1:
895
- new_width = min(max_dim, width)
896
- new_height = int(new_width / aspect_ratio)
897
- else:
898
- new_height = min(max_dim, height)
899
- new_width = int(new_height * aspect_ratio)
900
-
901
- new_width = (new_width // 8) * 8
902
- new_height = (new_height // 8) * 8
903
-
904
- return img.resize((new_width, new_height), Image.LANCZOS)
905
-
906
-
907
  # Build Gradio interface
908
  with gr.Blocks(css="custom.css") as demo:
909
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
910
  title = gr.HTML(
911
  """<h1><img src="https://i.imgur.com/DVoGw04.png">
912
- <span>LucasArts Style - VRAM Optimized âš¡<br><small style="
913
  font-size: 13px;
914
  display: block;
915
  font-weight: normal;
916
  opacity: 0.75;
917
- ">🔥 Fixed: Version compatibility + CUDA OOM errors resolved<br>
918
- ✨ 768px max output | 35 inference steps | Detailed error messages<br>
919
- AlbedoBase XL v2.1 + InstantID + ControlNet</small></span></h1>""",
920
  elem_id="title",
921
  )
922
  selected_state = gr.State()
@@ -928,7 +632,7 @@ AlbedoBase XL v2.1 + InstantID + ControlNet</small></span></h1>""",
928
  photo = gr.Image(label="Upload a picture (with or without faces)", interactive=True, type="pil", height=300)
929
  selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected")
930
  gallery = gr.Gallery(
931
- label="LucasArts Style",
932
  allow_preview=False,
933
  columns=4,
934
  elem_id="gallery",
@@ -949,55 +653,24 @@ AlbedoBase XL v2.1 + InstantID + ControlNet</small></span></h1>""",
949
  interactive=False, label="Generated Image", elem_id="result-image", position=0.1
950
  )
951
 
952
- error_message = gr.Textbox(
953
- label="Error Details",
954
- visible=False,
955
- elem_id="error-message",
956
- lines=5,
957
- max_lines=10
958
- )
959
-
960
  with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
961
  community_icon = gr.HTML(community_icon_html)
962
  loading_icon = gr.HTML(loading_icon_html)
963
  share_button = gr.Button("Share to community", elem_id="share-btn")
964
 
965
  with gr.Accordion("Advanced options", open=False):
966
- gr.Markdown("""
967
- ### VRAM Optimizations Active âš¡
968
- - 🎯 768x768 face detection
969
- - 📐 768px max output resolution (reduced for stability)
970
- - âš¡ 35 inference steps (balanced quality/speed)
971
- - 🔍 Detailed error messages with solutions
972
- - 💾 Reduced memory usage
973
- - ✅ Fixed version compatibility issues
974
- """)
975
- use_multiple_faces = gr.Checkbox(
976
- label="Process multiple faces separately",
977
- value=False,
978
- info="Generate separate outputs for each detected face"
979
- )
980
  negative = gr.Textbox(label="Negative Prompt")
981
- weight = gr.Slider(0, 10, value=1.0, step=0.1, label="LoRA weight")
982
- face_strength = gr.Slider(
983
- 0, 2, value=1.0, step=0.01, label="Face identity strength",
984
- info="Higher = stronger face preservation (auto-adjusted)"
985
- )
986
- image_strength = gr.Slider(
987
- 0, 1, value=0.15, step=0.01, label="Image structure strength",
988
- info="Lower = more transformation"
989
- )
990
- guidance_scale = gr.Slider(
991
- 0, 50, value=7.5, step=0.1, label="Guidance Scale",
992
- info="Auto-optimized per mode (7.5 faces, 8.5 landscapes)"
993
- )
994
- depth_control_scale = gr.Slider(
995
- 0, 1, value=0.8, step=0.01, label="Depth ControlNet strength",
996
- info="3D structure preservation (auto-optimized)"
997
- )
998
 
999
  prompt_title = gr.Markdown(
1000
- value="### Click 'Run' to generate (VRAM optimized) âš¡",
1001
  visible=True,
1002
  elem_id="selected_lora",
1003
  )
@@ -1029,7 +702,7 @@ AlbedoBase XL v2.1 + InstantID + ControlNet</small></span></h1>""",
1029
  fn=run_lora,
1030
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength,
1031
  guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora, use_multiple_faces],
1032
- outputs=[result, share_group, error_message],
1033
  )
1034
 
1035
  button.click(
@@ -1040,7 +713,7 @@ AlbedoBase XL v2.1 + InstantID + ControlNet</small></span></h1>""",
1040
  fn=run_lora,
1041
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength,
1042
  guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora, use_multiple_faces],
1043
- outputs=[result, share_group, error_message],
1044
  )
1045
 
1046
  share_button.click(None, [], [], js=share_js)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import spaces
 
 
 
 
 
4
  torch.jit.script = lambda f: f
5
  import timm
6
+ import time
7
 
8
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
9
  from safetensors.torch import load_file
 
28
  import cv2
29
  import torch
30
  import numpy as np
31
+ from PIL import Image
32
 
33
  from insightface.app import FaceAnalysis
34
  from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
 
64
 
65
  device = "cuda"
66
 
 
 
 
 
 
67
  # Cache for LoRA state dicts
 
68
  state_dicts = {}
69
# Download every gallery LoRA up front and cache its local path and weights,
# keyed by the repo id, so later requests can look them up in state_dicts.
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])

    if saved_name.endswith('.safetensors'):
        state_dict = load_file(saved_name)
    else:
        # Anything that is not safetensors is a pickle fetched from the Hub.
        # weights_only=True makes torch refuse arbitrary pickled objects, and
        # map_location="cpu" keeps start-up independent of the device the
        # checkpoint was saved from.
        state_dict = torch.load(saved_name, map_location="cpu", weights_only=True)

    state_dicts[item["repo"]] = {
        "saved_name": saved_name,
        "state_dict": state_dict,
    }
 
 
 
 
 
81
 
82
  sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
83
 
84
  # Download models
85
+ hf_hub_download(
86
+ repo_id="InstantX/InstantID",
87
+ filename="ControlNetModel/config.json",
88
+ local_dir="/data/checkpoints",
89
+ )
90
+ hf_hub_download(
91
+ repo_id="InstantX/InstantID",
92
+ filename="ControlNetModel/diffusion_pytorch_model.safetensors",
93
+ local_dir="/data/checkpoints",
94
+ )
95
+ hf_hub_download(
96
+ repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
97
+ )
98
+ hf_hub_download(
99
+ repo_id="latent-consistency/lcm-lora-sdxl",
100
+ filename="pytorch_lora_weights.safetensors",
101
+ local_dir="/data/checkpoints",
102
+ )
 
 
 
 
 
 
 
103
 
104
  # Download antelopev2
105
+ antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
106
+ print(antelope_download)
107
+ app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
108
+ app.prepare(ctx_id=0, det_size=(768, 768))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  # Prepare models
111
  face_adapter = f'/data/checkpoints/ip-adapter.bin'
112
  controlnet_path = f'/data/checkpoints/ControlNetModel'
113
 
 
 
114
  st = time.time()
115
+ identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
116
+ zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
117
+ et = time.time()
118
+ print('Loading ControlNet took: ', et - st, 'seconds')
 
 
 
 
119
 
 
 
120
  st = time.time()
121
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
122
+ et = time.time()
123
+ print('Loading VAE took: ', et - st, 'seconds')
 
 
 
 
124
 
 
 
125
  st = time.time()
126
+ pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
127
+ "SG161222/RealVisXL_V5.0",
128
+ vae=vae,
129
+ controlnet=[identitynet, zoedepthnet],
130
+ torch_dtype=torch.float16
131
+ )
132
+
133
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
134
+ pipe.load_ip_adapter_instantid(face_adapter)
135
+ pipe.set_ip_adapter_scale(0.9)
136
+ et = time.time()
137
+ print('Loading pipeline took: ', et - st, 'seconds')
138
 
 
 
 
 
 
 
 
 
 
 
 
139
  st = time.time()
140
+ compel = Compel(
141
+ tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
142
+ text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
143
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
144
+ requires_pooled=[False, True]
145
+ )
146
+ et = time.time()
147
+ print('Loading Compel took: ', et - st, 'seconds')
 
 
 
 
148
 
 
 
149
  st = time.time()
150
+ zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
151
+ et = time.time()
152
+ print('Loading Zoe took: ', et - st, 'seconds')
153
+ zoe.to(device)
154
+ pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  last_lora = ""
157
  last_fused = False
158
  lora_archive = "/data"
159
 
160
+ # Improved face detection with multi-face support
161
+ def detect_faces(face_image, use_multiple_faces=False):
162
+ """
163
+ Detect faces in the image
164
+ Returns: list of face info dictionaries, or empty list if no faces
165
+ """
166
  try:
167
+ face_info_list = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
 
 
 
 
168
 
169
+ if not face_info_list or len(face_info_list) == 0:
170
+ print("No faces detected")
171
+ return []
 
 
 
 
 
 
 
172
 
173
+ # Sort faces by size (largest first)
174
+ face_info_list = sorted(
175
+ face_info_list,
176
+ key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
177
+ reverse=True
178
+ )
179
 
180
+ if use_multiple_faces:
181
+ print(f"Detected {len(face_info_list)} faces")
182
+ return face_info_list
183
+ else:
184
+ print(f"Using largest face (detected {len(face_info_list)} total)")
185
+ return [face_info_list[0]]
186
+
187
  except Exception as e:
188
+ print(f"Face detection error: {e}")
189
+ return []
 
 
190
 
191
def process_face_embeddings(face_info_list):
    """
    Collapse the detected faces into a single embedding.

    Returns:
        None when the list is empty (or None), the lone face's embedding
        object unchanged when there is exactly one face, and otherwise the
        element-wise mean of all embeddings.
    """
    face_count = len(face_info_list) if face_info_list else 0

    if face_count == 0:
        return None
    if face_count == 1:
        return face_info_list[0]['embedding']

    stacked = [entry['embedding'] for entry in face_info_list]
    return np.mean(stacked, axis=0)
205
 
206
  def create_face_kps_image(face_image, face_info_list):
207
+ """
208
+ Create keypoints image from face info
209
+ """
210
  if not face_info_list:
211
  return face_image
212
 
213
+ # For multiple faces, draw all keypoints
214
  if len(face_info_list) > 1:
215
  return draw_multiple_kps(face_image, [f['kps'] for f in face_info_list])
216
  else:
217
  return draw_kps(face_image, face_info_list[0]['kps'])
218
 
 
219
  def draw_multiple_kps(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
220
+ """
221
+ Draw keypoints for multiple faces
222
+ """
223
  stickwidth = 4
224
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
225
 
226
  w, h = image_pil.size
227
  out_img = np.zeros([h, w, 3])
228
 
229
+ for kps in kps_list:
230
  kps = np.array(kps)
 
231
 
232
  for i in range(len(limbSeq)):
233
  index = limbSeq[i]
234
+ color = color_list[index[0]]
235
 
236
  x = kps[index][:, 0]
237
  y = kps[index][:, 1]
 
245
  out_img = (out_img * 0.6).astype(np.uint8)
246
 
247
  for idx_kp, kp in enumerate(kps):
248
+ color = color_list[idx_kp]
249
  x, y = kp
250
  out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
251
 
252
  out_img_pil = Image.fromarray(out_img.astype(np.uint8))
253
  return out_img_pil
254
 
 
255
  def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
256
  lora_repo = sdxl_loras[selected_state.index]["repo"]
257
  new_placeholder = "Type a prompt to use your selected LoRA"
258
  weight_name = sdxl_loras[selected_state.index]["weights"]
259
+ updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
260
 
261
  for lora_list in lora_defaults:
262
  if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
263
+ face_strength = lora_list.get("face_strength", 0.9)
264
+ image_strength = lora_list.get("image_strength", 0.2)
265
+ weight = lora_list.get("weight", 0.95)
266
  depth_control_scale = lora_list.get("depth_control_scale", 0.8)
267
  negative = lora_list.get("negative", "")
268
 
 
283
  selected_state
284
  )
285
 
 
286
  def check_selected(selected_state, custom_lora):
287
  if not selected_state and not custom_lora:
288
  raise gr.Error("You must select a style")
289
 
290
def resize_image_aspect_ratio(img, max_dim=1280):
    """
    Resize an image so its longer side is at most max_dim, keeping proportions.

    The longer side is capped at max_dim (never enlarged), the other side is
    derived from the aspect ratio, and both are then rounded down to a
    multiple of 8.

    Args:
        img: Image object exposing .size as (width, height) and .resize().
        max_dim: Upper bound for the longer side, in pixels.

    Returns:
        The image resampled with LANCZOS to the computed size.
    """
    width, height = img.size
    aspect_ratio = width / height

    if aspect_ratio >= 1:  # Landscape or square
        new_width = min(max_dim, width)
        new_height = int(new_width / aspect_ratio)
    else:  # Portrait
        new_height = min(max_dim, height)
        new_width = int(new_height * aspect_ratio)

    # (n // 8) * 8 collapses to 0 for any side shorter than 8 px (tiny or very
    # elongated inputs), and a zero-sized resize fails; clamp to one 8 px cell.
    new_width = max(8, (new_width // 8) * 8)
    new_height = max(8, (new_height // 8) * 8)

    return img.resize((new_width, new_height), Image.LANCZOS)
305
+
306
+
307
def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength,
             guidance_scale, depth_control_scale, sdxl_loras, custom_lora, use_multiple_faces=False,
             progress=gr.Progress(track_tqdm=True)):
    """
    Validate a generation request, prepare its inputs and render the image.

    Steps: check that a style and a picture were supplied, resize the picture,
    detect faces, build the final prompt / negative prompt, resolve which LoRA
    file to use, adjust strengths when no face was found, then delegate to
    generate_image.

    Args:
        face_image: Uploaded picture (must expose .size / .resize).
        prompt: User prompt; may be empty.
        negative: Negative prompt; may be empty.
        lora_scale: Scale passed on for LoRA fusing.
        selected_state: Gallery selection exposing .index, or a falsy value.
        face_strength, image_strength, guidance_scale, depth_control_scale:
            Generation controls forwarded to generate_image.
        sdxl_loras: List of gallery LoRA dicts (each read for its "repo").
        custom_lora: Falsy, or a sequence whose [0] is the LoRA path and whose
            [1] is text appended to the prompt (presumably the trigger word).
        use_multiple_faces: Forwarded to detect_faces.
        progress: Gradio progress tracker; not used directly here.

    Returns:
        ((resized input image, generated image), gr.update(visible=True)).

    Raises:
        gr.Error: If neither a gallery style nor a custom LoRA is selected, or
            if no picture was provided.
    """
    print("Custom LoRA:", custom_lora)
    custom_lora_path = custom_lora[0] if custom_lora else None
    selected_state_index = selected_state.index if selected_state else -1

    # Fail fast: selected_state_index is -1 when nothing is selected, and
    # indexing sdxl_loras with it would silently pick the last gallery entry.
    if not selected_state and not custom_lora_path:
        raise gr.Error("You must select a style")
    if face_image is None:
        raise gr.Error("Please upload a picture first")

    st = time.time()
    face_image = resize_image_aspect_ratio(face_image)

    # Enhanced face detection
    face_info_list = detect_faces(face_image, use_multiple_faces)
    face_detected = len(face_info_list) > 0

    if face_detected:
        face_emb = process_face_embeddings(face_info_list)
        face_kps = create_face_kps_image(face_image, face_info_list)
        print(f"Processing with {len(face_info_list)} face(s)")
    else:
        face_emb = None
        face_kps = face_image
        print("No faces detected - using landscape/depth mode only")

    et = time.time()
    print('Face processing took:', et - st, 'seconds')

    st = time.time()

    # Enhanced prompt processing
    if custom_lora_path:
        # A custom LoRA never inherits a gallery LoRA's prompt template; it
        # only gets its own extra prompt text, when it has any.
        if custom_lora[1]:
            prompt = f"{prompt} {custom_lora[1]}"
    else:
        selected_repo = sdxl_loras[selected_state_index]["repo"]
        for lora_list in lora_defaults:
            if lora_list["model"] == selected_repo:
                prompt_full = lora_list.get("prompt", None)
                if prompt_full:
                    prompt = prompt_full.replace("<subject>", prompt)

    print("Prompt:", prompt)
    if prompt == "":
        prompt = "a beautiful scene" if not face_detected else "a person"
    print(f"Executing prompt: {prompt}")

    if negative == "":
        # Enhanced negative prompt for better quality
        negative = "worst quality, low quality, blurry, distorted, deformed" if not face_detected else None

    print("Custom Loaded LoRA:", custom_lora_path)

    if custom_lora_path:
        repo_name = custom_lora_path
        full_path_lora = custom_lora_path
    else:
        repo_name = sdxl_loras[selected_state_index]["repo"]
        full_path_lora = state_dicts[repo_name]["saved_name"]

    repo_name = repo_name.rstrip("/").lower()

    print("Full path LoRA", full_path_lora)

    et = time.time()
    print('Prompt processing took:', et - st, 'seconds')

    # Adjust parameters based on face detection
    if not face_detected:
        # For landscape/no face mode, reduce face strength and increase depth control
        face_strength = 0.0
        depth_control_scale = max(depth_control_scale, 0.9)
        image_strength = min(image_strength, 0.4)
        print("Adjusted parameters for no-face mode")

    st = time.time()
    image = generate_image(
        prompt, negative, face_emb, face_image, face_kps, image_strength,
        guidance_scale, face_strength, depth_control_scale, repo_name,
        full_path_lora, lora_scale, sdxl_loras, selected_state_index, face_detected, st
    )

    return (face_image, image), gr.update(visible=True)
394
+
395
+ run_lora.zerogpu = True
396
+
397
+
398
@spaces.GPU(duration=75)
def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale,
                   face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale,
                   sdxl_loras, selected_state_index, face_detected, st):
    """
    Load the requested LoRA into the shared pipeline (if needed) and render.

    Args:
        prompt, negative: Prompt texts; negative may be None or empty.
        face_emb: Face embedding, used only when face_detected is true.
        face_image: Source image; also defines output width and height.
        face_kps: Control image for the identity ControlNet.
        image_strength: The pipeline strength is 1 - image_strength.
        guidance_scale: Forwarded to the pipeline.
        face_strength, depth_control_scale: ControlNet conditioning scales.
        repo_name: Normalised LoRA identifier, used as the cache key.
        loaded_state_dict: Despite the name, the local LoRA file path.
        lora_scale: Scale used when fusing the LoRA.
        sdxl_loras, selected_state_index: Gallery list and selected index
            (-1 when nothing is selected), used for pivotal-tuning embeddings.
        face_detected: Whether a face was found in face_image.
        st: Start timestamp supplied by the caller; unused here.

    Returns:
        The first generated image.

    Raises:
        gr.Error: If a Hub repo has no .safetensors file or the LoRA (or its
            embeddings) cannot be loaded.
    """
    global last_fused, last_lora

    print("Loaded state dict:", loaded_state_dict)
    print("Last LoRA:", last_lora, "| Current LoRA:", repo_name)

    # The pipeline is built with two ControlNets (identity, depth), so one
    # control image and one scale are supplied per net in every mode. Without
    # a face the identity net is muted with a zero scale instead of dropped.
    identity_scale = face_strength if face_detected else 0.0
    control_images = [face_kps, zoe(face_image)]
    control_scales = [identity_scale, depth_control_scale]

    # Handle custom LoRA from HuggingFace
    if repo_name.startswith("https://huggingface.co"):
        repo_id = repo_name.split("huggingface.co/")[-1]
        fs = HfFileSystem()
        files = fs.ls(repo_id, detail=False)
        safetensors_files = [f for f in files if f.endswith(".safetensors")]

        if not safetensors_files:
            raise gr.Error("No .safetensors file found in this Hugging Face repository.")

        weight_file = safetensors_files[0]
        # HfFileSystem.ls returns paths prefixed with the repo id, whereas
        # hf_hub_download expects a filename relative to the repo root.
        repo_prefix = f"{repo_id}/"
        if weight_file.lower().startswith(repo_prefix.lower()):
            weight_file = weight_file[len(repo_prefix):]
        full_path_lora = hf_hub_download(repo_id=repo_id, filename=weight_file, repo_type="model")
    else:
        full_path_lora = loaded_state_dict

    # Improved LoRA loading and caching
    if last_lora != repo_name:
        if last_fused:
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
            pipe.unload_textual_inversion()

        # Until the new LoRA is completely loaded the pipeline matches no
        # cache key; clear the tracking so a failure below (or during
        # generation) is never mistaken for a cache hit on a later call.
        last_fused = False
        last_lora = ""

        try:
            pipe.load_lora_weights(full_path_lora)
            pipe.fuse_lora(lora_scale)
            last_fused = True

            # Pivotal-tuning embeddings belong to a gallery entry. Only use
            # them when the LoRA being loaded really is the selected entry;
            # an index of -1 (custom LoRA) must not fall through to the last
            # gallery item.
            selected_lora = sdxl_loras[selected_state_index] if selected_state_index >= 0 else None
            is_selected_lora = (
                selected_lora is not None
                and selected_lora["repo"].rstrip("/").lower() == repo_name
            )
            if is_selected_lora and selected_lora.get("is_pivotal"):
                text_embedding_name = selected_lora["text_embedding_weights"]
                embedding_path = hf_hub_download(
                    repo_id=selected_lora["repo"], filename=text_embedding_name, repo_type="model"
                )
                state_dict_embedding = load_file(embedding_path)
                pipe.load_textual_inversion(
                    state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"],
                    token=["<s0>", "<s1>"],
                    text_encoder=pipe.text_encoder,
                    tokenizer=pipe.tokenizer
                )
                pipe.load_textual_inversion(
                    state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"],
                    token=["<s0>", "<s1>"],
                    text_encoder=pipe.text_encoder_2,
                    tokenizer=pipe.tokenizer_2
                )
        except Exception as e:
            print(f"Error loading LoRA: {e}")
            raise gr.Error(f"Failed to load LoRA: {str(e)}") from e

        # Record the cache key as soon as the pipeline state matches it.
        last_lora = repo_name

    print("Processing prompt...")
    conditioning, pooled = compel(prompt)
    negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)

    # Enhanced generation parameters
    num_inference_steps = 40  # Increased for better quality

    print("Generating image...")
    image = pipe(
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative_conditioning,
        negative_pooled_prompt_embeds=negative_pooled,
        width=face_image.width,
        height=face_image.height,
        # NOTE(review): None is passed when no face was found; confirm the
        # InstantID pipeline accepts a missing image embedding.
        image_embeds=face_emb if face_detected else None,
        image=face_image,
        strength=1-image_strength,
        control_image=control_images,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=control_scales,
    ).images[0]

    return image
492
 
493
  def shuffle_gallery(sdxl_loras):
494
  random.shuffle(sdxl_loras)
495
  return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
496
 
 
497
  def classify_gallery(sdxl_loras):
498
  sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
499
  return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
500
 
 
501
  def swap_gallery(order, sdxl_loras):
502
  if(order == "random"):
503
  return shuffle_gallery(sdxl_loras)
504
  else:
505
  return classify_gallery(sdxl_loras)
506
 
 
507
  def deselect():
508
  return gr.Gallery(selected_index=None)
509
 
 
510
  def get_huggingface_safetensors(link):
511
  split_link = link.split("/")
512
  if(len(split_link) == 2):
 
530
  raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
531
  return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
532
 
 
533
  def get_civitai_safetensors(link):
534
  link_split = link.split("civitai.com/")
535
  pattern = re.compile(r'models\/(\d+)')
 
574
  raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
575
  return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
576
 
 
577
  def check_custom_model(link):
578
  if(link.startswith("https://")):
579
  if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
 
584
  else:
585
  return get_huggingface_safetensors(link)
586
 
 
587
  def load_custom_lora(link):
588
  if(link):
589
  try:
 
607
  else:
608
  return gr.update(visible=False), "", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
609
 
 
610
  def remove_custom_lora():
611
  return "", gr.update(visible=False), gr.update(visible=False), None
612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  # Build Gradio interface
614
  with gr.Blocks(css="custom.css") as demo:
615
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
616
  title = gr.HTML(
617
  """<h1><img src="https://i.imgur.com/DVoGw04.png">
618
+ <span>Face to All - Enhanced<br><small style="
619
  font-size: 13px;
620
  display: block;
621
  font-weight: normal;
622
  opacity: 0.75;
623
+ ">🔥 Supports: No faces (landscape), Multiple faces, Improved quality, Custom LoRAs<br> diffusers InstantID + ControlNet</small></span></h1>""",
 
 
624
  elem_id="title",
625
  )
626
  selected_state = gr.State()
 
632
  photo = gr.Image(label="Upload a picture (with or without faces)", interactive=True, type="pil", height=300)
633
  selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected")
634
  gallery = gr.Gallery(
635
+ label="Pick a style from the gallery",
636
  allow_preview=False,
637
  columns=4,
638
  elem_id="gallery",
 
653
  interactive=False, label="Generated Image", elem_id="result-image", position=0.1
654
  )
655
 
 
 
 
 
 
 
 
 
656
  with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
657
  community_icon = gr.HTML(community_icon_html)
658
  loading_icon = gr.HTML(loading_icon_html)
659
  share_button = gr.Button("Share to community", elem_id="share-btn")
660
 
661
  with gr.Accordion("Advanced options", open=False):
662
+ use_multiple_faces = gr.Checkbox(label="Use multiple faces (if detected)", value=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
663
  negative = gr.Textbox(label="Negative Prompt")
664
+ weight = gr.Slider(0, 10, value=0.9, step=0.1, label="LoRA weight")
665
+ face_strength = gr.Slider(0, 2, value=0.9, step=0.01, label="Face strength",
666
+ info="Higher values increase face likeness (auto-adjusted for no-face images)")
667
+ image_strength = gr.Slider(0, 1, value=0.20, step=0.01, label="Image strength",
668
+ info="Higher values increase similarity with original structure/colors")
669
+ guidance_scale = gr.Slider(0, 50, value=8, step=0.1, label="Guidance Scale")
670
+ depth_control_scale = gr.Slider(0, 1, value=0.8, step=0.01, label="Zoe Depth ControlNet strength")
 
 
 
 
 
 
 
 
 
 
671
 
672
  prompt_title = gr.Markdown(
673
+ value="### Click on a LoRA in the gallery to select it",
674
  visible=True,
675
  elem_id="selected_lora",
676
  )
 
702
  fn=run_lora,
703
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength,
704
  guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora, use_multiple_faces],
705
+ outputs=[result, share_group],
706
  )
707
 
708
  button.click(
 
713
  fn=run_lora,
714
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength,
715
  guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora, use_multiple_faces],
716
+ outputs=[result, share_group],
717
  )
718
 
719
  share_button.click(None, [], [], js=share_js)