primerz committed on
Commit
171e0fc
·
verified ·
1 Parent(s): 8064305

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +46 -45
models.py CHANGED
@@ -13,7 +13,7 @@ from diffusers import (
13
  from diffusers.models.attention_processor import AttnProcessor2_0
14
  from transformers import CLIPVisionModelWithProjection
15
  from insightface.app import FaceAnalysis
16
- from controlnet_aux import MidasDetector, LeresDetector
17
  from huggingface_hub import hf_hub_download
18
  from compel import Compel, ReturnedEmbeddingsType
19
 
@@ -82,26 +82,30 @@ def load_face_analysis():
82
 
83
 
84
  def load_depth_detector():
85
- """Load LeRes++ Depth detector (superior to Midas/Zoe for detailed depth estimation)."""
86
- print("Loading LeRes++ Depth detector...")
87
  try:
88
- from controlnet_aux import LeresDetector
89
- leres_depth = LeresDetector.from_pretrained("lllyasviel/Annotators")
90
- leres_depth.to(device)
91
- print(" [OK] LeRes++ Depth loaded successfully (+15-20% accuracy over Midas/Zoe)")
92
- return leres_depth, True
93
  except Exception as e:
94
- print(f" [WARNING] LeRes++ Depth not available: {e}")
95
- print(" Attempting fallback to Midas Depth...")
96
- try:
97
- midas_depth = MidasDetector.from_pretrained("lllyasviel/Annotators")
98
- midas_depth.to(device)
99
- print(" [OK] Midas Depth loaded as fallback")
100
- return midas_depth, True
101
- except Exception as e2:
102
- print(f" [ERROR] All depth detectors failed: {e2}")
103
- return None, False
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def load_controlnets():
107
  """Load ControlNet models."""
@@ -111,6 +115,19 @@ def load_controlnets():
111
  torch_dtype=dtype
112
  ).to(device)
113
  print(" [OK] ControlNet Depth loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  print("Loading InstantID ControlNet...")
116
  try:
@@ -120,10 +137,12 @@ def load_controlnets():
120
  torch_dtype=dtype
121
  ).to(device)
122
  print(" [OK] InstantID ControlNet loaded successfully")
123
- return controlnet_depth, controlnet_instantid, True
 
124
  except Exception as e:
125
  print(f" [WARNING] InstantID ControlNet not available: {e}")
126
- return controlnet_depth, None, False
 
127
 
128
 
129
  def load_image_encoder():
@@ -150,7 +169,7 @@ def load_sdxl_pipeline(controlnets):
150
 
151
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
152
  model_path,
153
- controlnet=controlnets,
154
  torch_dtype=dtype,
155
  use_safetensors=True
156
  ).to(device)
@@ -161,7 +180,7 @@ def load_sdxl_pipeline(controlnets):
161
  print(" Using default SDXL base model")
162
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
163
  "stabilityai/stable-diffusion-xl-base-1.0",
164
- controlnet=controlnets,
165
  torch_dtype=dtype,
166
  use_safetensors=True
167
  ).to(device)
@@ -173,7 +192,7 @@ def load_lora(pipe):
173
  print("Loading LORA (retroart) from HuggingFace Hub...")
174
  try:
175
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
176
- pipe.load_lora_weights(lora_path)
177
  print(f" [OK] LORA loaded successfully")
178
  return True
179
  except Exception as e:
@@ -285,7 +304,7 @@ def setup_ip_adapter(pipe, image_encoder):
285
 
286
  print(" [OK] IP-Adapter fully loaded with InstantID architecture")
287
  print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
288
- print(f" - Face embeddings: 512D → 16x2048D")
289
 
290
  return image_proj_model, True
291
 
@@ -297,37 +316,19 @@ def setup_ip_adapter(pipe, image_encoder):
297
 
298
 
299
  def setup_compel(pipe):
300
- """Setup Compel for better SDXL prompt handling with robust error handling."""
301
  print("Setting up Compel for enhanced prompt processing...")
302
  try:
303
- # FIXED: Handle SDXL dual tokenizer setup more carefully
304
  compel = Compel(
305
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
306
  text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
307
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
308
- requires_pooled=[False, True],
309
- padding_get_round_multiple=False # Disable padding that might cause mismatches
310
  )
311
- print(" [OK] Compel loaded successfully with SDXL dual tokenizers")
312
  return compel, True
313
- except TypeError:
314
- # Fallback for older Compel versions without padding parameter
315
- try:
316
- compel = Compel(
317
- tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
318
- text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
319
- returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
320
- requires_pooled=[False, True]
321
- )
322
- print(" [OK] Compel loaded (standard config)")
323
- return compel, True
324
- except Exception as e:
325
- print(f" [WARNING] Compel not available: {e}")
326
- print(" [INFO] Will use standard prompt encoding instead")
327
- return None, False
328
  except Exception as e:
329
  print(f" [WARNING] Compel not available: {e}")
330
- print(" [INFO] Will use standard prompt encoding instead")
331
  return None, False
332
 
333
 
 
13
  from diffusers.models.attention_processor import AttnProcessor2_0
14
  from transformers import CLIPVisionModelWithProjection
15
  from insightface.app import FaceAnalysis
16
+ from controlnet_aux import ZoeDetector, OpenposeDetector # <-- NEW
17
  from huggingface_hub import hf_hub_download
18
  from compel import Compel, ReturnedEmbeddingsType
19
 
 
82
 
83
 
84
  def load_depth_detector():
85
+ """Load Zoe Depth detector."""
86
+ print("Loading Zoe Depth detector...")
87
  try:
88
+ zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
89
+ zoe_depth.to(device)
90
+ print(" [OK] Zoe Depth loaded successfully")
91
+ return zoe_depth, True
 
92
  except Exception as e:
93
+ print(f" [WARNING] Zoe Depth not available: {e}")
94
+ return None, False
 
 
 
 
 
 
 
 
95
 
96
+ # --- NEW FUNCTION ---
97
+ def load_openpose_detector():
98
+ """Load OpenPose detector."""
99
+ print("Loading OpenPose detector...")
100
+ try:
101
+ openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
102
+ openpose.to(device)
103
+ print(" [OK] OpenPose loaded successfully")
104
+ return openpose, True
105
+ except Exception as e:
106
+ print(f" [WARNING] OpenPose not available: {e}")
107
+ return None, False
108
+ # --- END NEW FUNCTION ---
109
 
110
  def load_controlnets():
111
  """Load ControlNet models."""
 
115
  torch_dtype=dtype
116
  ).to(device)
117
  print(" [OK] ControlNet Depth loaded")
118
+
119
+ # --- NEW: Load OpenPose ControlNet ---
120
+ print("Loading ControlNet OpenPose model...")
121
+ try:
122
+ controlnet_openpose = ControlNetModel.from_pretrained(
123
+ "diffusers/controlnet-openpose-sdxl-1.0",
124
+ torch_dtype=dtype
125
+ ).to(device)
126
+ print(" [OK] ControlNet OpenPose loaded")
127
+ except Exception as e:
128
+ print(f" [WARNING] ControlNet OpenPose not available: {e}")
129
+ controlnet_openpose = None
130
+ # --- END NEW ---
131
 
132
  print("Loading InstantID ControlNet...")
133
  try:
 
137
  torch_dtype=dtype
138
  ).to(device)
139
  print(" [OK] InstantID ControlNet loaded successfully")
140
+ # Return all three models
141
+ return controlnet_depth, controlnet_instantid, controlnet_openpose, True
142
  except Exception as e:
143
  print(f" [WARNING] InstantID ControlNet not available: {e}")
144
+ # Return models, indicating InstantID failure
145
+ return controlnet_depth, None, controlnet_openpose, False
146
 
147
 
148
  def load_image_encoder():
 
169
 
170
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
171
  model_path,
172
+ controlnet=controlnets, # Pass the list of 3 controlnets
173
  torch_dtype=dtype,
174
  use_safetensors=True
175
  ).to(device)
 
180
  print(" Using default SDXL base model")
181
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
182
  "stabilityai/stable-diffusion-xl-base-1.0",
183
+ controlnet=controlnets, # Pass the list of 3 controlnets
184
  torch_dtype=dtype,
185
  use_safetensors=True
186
  ).to(device)
 
192
  print("Loading LORA (retroart) from HuggingFace Hub...")
193
  try:
194
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
195
+ pipe.load_lora_weights(lora_path, adapter_name="retroart")
196
  print(f" [OK] LORA loaded successfully")
197
  return True
198
  except Exception as e:
 
304
 
305
  print(" [OK] IP-Adapter fully loaded with InstantID architecture")
306
  print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
307
+ print(f" - Face embeddings: 512D -> 16x2048D")
308
 
309
  return image_proj_model, True
310
 
 
316
 
317
 
318
  def setup_compel(pipe):
319
+ """Setup Compel for better SDXL prompt handling."""
320
  print("Setting up Compel for enhanced prompt processing...")
321
  try:
 
322
  compel = Compel(
323
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
324
  text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
325
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
326
+ requires_pooled=[False, True]
 
327
  )
328
+ print(" [OK] Compel loaded successfully")
329
  return compel, True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  except Exception as e:
331
  print(f" [WARNING] Compel not available: {e}")
 
332
  return None, False
333
 
334