primerz commited on
Commit
a0ff65c
·
verified ·
1 Parent(s): 174c055

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +48 -84
models.py CHANGED
@@ -1,15 +1,12 @@
1
  """
2
- Model loading for Pixagram - WORKING VERSION
3
- Following examplewithface.py pattern with modern diffusers compatibility
 
4
  """
5
  import torch
6
  import time
7
  import os
8
- from diffusers import (
9
- ControlNetModel,
10
- AutoencoderKL,
11
- LCMScheduler
12
- )
13
  from insightface.app import FaceAnalysis
14
  from controlnet_aux import ZoeDetector
15
  from huggingface_hub import hf_hub_download, snapshot_download
@@ -28,7 +25,6 @@ from config import (
28
 
29
 
30
  def download_model_with_retry(repo_id, filename, max_retries=None):
31
- """Download model with retry logic"""
32
  if max_retries is None:
33
  max_retries = DOWNLOAD_CONFIG['max_retries']
34
 
@@ -40,7 +36,6 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
40
 
41
  path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
42
  return path
43
-
44
  except Exception as e:
45
  if attempt < max_retries - 1:
46
  time.sleep(DOWNLOAD_CONFIG['retry_delay'])
@@ -50,7 +45,7 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
50
 
51
 
52
  def load_face_analysis():
53
- """Load face analysis - examplewithface.py line 113"""
54
  print("Loading face analysis...")
55
  try:
56
  snapshot_download(
@@ -58,7 +53,6 @@ def load_face_analysis():
58
  local_dir=FACE_DETECTION_CONFIG['local_dir']
59
  )
60
 
61
- # examplewithface.py line 113
62
  app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
63
  app.prepare(ctx_id=0, det_size=(640, 640))
64
 
@@ -70,20 +64,19 @@ def load_face_analysis():
70
 
71
 
72
  def load_depth_detector():
73
- """Load Zoe Depth - examplewithface.py line 151"""
74
  print("Loading Zoe Depth...")
75
  try:
76
  zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
77
- zoe = zoe.to("cpu")
78
  print(" [OK] Zoe Depth loaded")
79
  return zoe, True
80
  except Exception as e:
81
- print(f" [WARNING] Zoe Depth unavailable: {e}")
82
  return None, False
83
 
84
 
85
  def load_controlnets():
86
- """Load ControlNets - examplewithface.py lines 122-126"""
87
  print("Loading ControlNets...")
88
 
89
  identitynet = ControlNetModel.from_pretrained(
@@ -91,23 +84,23 @@ def load_controlnets():
91
  subfolder="ControlNetModel",
92
  torch_dtype=dtype
93
  )
94
- print(" [OK] InstantID ControlNet loaded")
95
 
96
  zoedepthnet = ControlNetModel.from_pretrained(
97
  "diffusers/controlnet-zoe-depth-sdxl-1.0",
98
  torch_dtype=dtype
99
  )
100
- print(" [OK] Zoe Depth ControlNet loaded")
101
 
102
  return identitynet, zoedepthnet
103
 
104
 
105
  def load_sdxl_pipeline(controlnets):
106
  """
107
- Load pipeline - examplewithface.py lines 128-145
108
- KEY: Pass controlnets as LIST directly, NO wrapper
109
  """
110
- print("Loading SDXL pipeline...")
111
 
112
  # Load VAE (line 128)
113
  vae = AutoencoderKL.from_pretrained(
@@ -116,103 +109,78 @@ def load_sdxl_pipeline(controlnets):
116
  )
117
  print(" [OK] VAE loaded")
118
 
119
- # Load pipeline (line 134) - controlnets as list!
120
  pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
121
  "frankjoshua/albedobaseXL_v21",
122
  vae=vae,
123
- controlnet=controlnets, # Direct list!
124
  torch_dtype=dtype
125
  )
 
126
 
127
- # LCM scheduler (user requested LCM)
128
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
129
- print(" [OK] LCM scheduler set")
130
 
131
- # Load IP-Adapter (line 139)
132
  ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
133
  pipe.load_ip_adapter_instantid(ip_adapter_path)
134
  pipe.set_ip_adapter_scale(0.8)
135
  print(" [OK] IP-Adapter loaded")
136
 
137
- # Move to device
138
  pipe = pipe.to(device)
139
-
140
- print(" [OK] Pipeline ready")
141
  return pipe, True
142
 
143
 
144
- # Global LoRA tracking
145
- loaded_lora_path = None
146
- current_lora_scale = None
147
 
148
 
149
  def load_lora(pipe):
150
- """
151
- Load LoRA - Don't fuse yet, will fuse per-generation
152
- """
153
  print("Loading LoRA...")
154
- global loaded_lora_path
155
 
156
  try:
157
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
158
- loaded_lora_path = lora_path
159
-
160
- print(f" [OK] LoRA path stored: {lora_path}")
161
- print(f" [INFO] LoRA will be fused before each generation")
162
  return True
163
  except Exception as e:
164
- print(f" [WARNING] LoRA load failed: {e}")
165
- loaded_lora_path = None
166
  return False
167
 
168
 
169
  def fuse_lora_with_scale(pipe, lora_scale):
170
  """
171
- Fuse LoRA with scale for generation
172
- Modern approach: Don't fuse, use cross_attention_kwargs instead
173
  """
174
- global loaded_lora_path, current_lora_scale
175
 
176
- if loaded_lora_path is None:
177
- print(" [WARNING] No LoRA available")
178
  return False
179
 
180
  try:
181
- # Check if we need to reload
182
- if current_lora_scale is None or current_lora_scale != lora_scale:
183
- print(f" [LORA] Loading LoRA with scale {lora_scale}...")
184
-
185
- # Unload previous if exists
186
- try:
187
- pipe.unload_lora_weights()
188
- except:
189
- pass
190
-
191
- # Load LoRA weights from path
192
- pipe.load_lora_weights(loaded_lora_path)
193
- current_lora_scale = lora_scale
194
-
195
- print(f" [OK] LoRA loaded with scale {lora_scale}")
196
- print(f" [INFO] Scale will be applied via cross_attention_kwargs at inference")
197
- else:
198
- print(f" [INFO] LoRA already loaded with scale {lora_scale}")
199
 
200
- return True
 
 
 
201
 
 
202
  except Exception as e:
203
- print(f" [ERROR] LoRA loading failed: {e}")
204
- import traceback
205
- traceback.print_exc()
206
  return False
207
 
208
 
209
- def get_lora_scale():
210
- """Get current LoRA scale for cross_attention_kwargs"""
211
- return current_lora_scale if current_lora_scale is not None else 1.0
212
-
213
-
214
  def setup_compel(pipe):
215
- """Setup Compel - examplewithface.py line 145"""
216
  print("Setting up Compel...")
217
  try:
218
  compel = Compel(
@@ -221,7 +189,7 @@ def setup_compel(pipe):
221
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
222
  requires_pooled=[False, True]
223
  )
224
- print(" [OK] Compel loaded")
225
  return compel, True
226
  except Exception as e:
227
  print(f" [WARNING] Compel unavailable: {e}")
@@ -229,12 +197,10 @@ def setup_compel(pipe):
229
 
230
 
231
  def setup_scheduler(pipe):
232
- """Already done in load_sdxl_pipeline"""
233
  pass
234
 
235
 
236
  def optimize_pipeline(pipe):
237
- """Apply optimizations"""
238
  if device == "cuda":
239
  try:
240
  pipe.enable_xformers_memory_efficient_attention()
@@ -249,31 +215,29 @@ def optimize_pipeline(pipe):
249
 
250
 
251
  def load_caption_model():
252
- """Load caption model"""
253
  print("Loading caption model...")
254
  try:
255
  from transformers import AutoProcessor, AutoModelForCausalLM
256
  processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
257
  model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco", torch_dtype=dtype).to("cpu")
258
- print(" [OK] GIT-Large loaded")
259
  return processor, model, True, 'git'
260
  except:
261
  try:
262
  from transformers import BlipProcessor, BlipForConditionalGeneration
263
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
264
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=dtype).to("cpu")
265
- print(" [OK] BLIP loaded")
266
  return processor, model, True, 'blip'
267
  except:
268
  return None, None, False, 'none'
269
 
270
 
271
  def set_clip_skip(pipe):
272
- """Set CLIP skip"""
273
  if hasattr(pipe, 'text_encoder'):
274
- print(f" [OK] CLIP skip set to {CLIP_SKIP}")
275
 
276
 
277
- __all__ = ['draw_kps', 'fuse_lora_with_scale', 'get_lora_scale']
278
 
279
- print("[OK] Models ready (examplewithface.py pattern + modern API)")
 
1
  """
2
+ Models.py - Following examplewithface.py EXACTLY
3
+ NO MultiControlNetModel wrapper!
4
+ NO fuse_lora with scale!
5
  """
6
  import torch
7
  import time
8
  import os
9
+ from diffusers import ControlNetModel, AutoencoderKL, LCMScheduler
 
 
 
 
10
  from insightface.app import FaceAnalysis
11
  from controlnet_aux import ZoeDetector
12
  from huggingface_hub import hf_hub_download, snapshot_download
 
25
 
26
 
27
  def download_model_with_retry(repo_id, filename, max_retries=None):
 
28
  if max_retries is None:
29
  max_retries = DOWNLOAD_CONFIG['max_retries']
30
 
 
36
 
37
  path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
38
  return path
 
39
  except Exception as e:
40
  if attempt < max_retries - 1:
41
  time.sleep(DOWNLOAD_CONFIG['retry_delay'])
 
45
 
46
 
47
  def load_face_analysis():
48
+ """examplewithface.py line 113"""
49
  print("Loading face analysis...")
50
  try:
51
  snapshot_download(
 
53
  local_dir=FACE_DETECTION_CONFIG['local_dir']
54
  )
55
 
 
56
  app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
57
  app.prepare(ctx_id=0, det_size=(640, 640))
58
 
 
64
 
65
 
66
  def load_depth_detector():
67
+ """examplewithface.py line 151"""
68
  print("Loading Zoe Depth...")
69
  try:
70
  zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
 
71
  print(" [OK] Zoe Depth loaded")
72
  return zoe, True
73
  except Exception as e:
74
+ print(f" [WARNING] Zoe unavailable: {e}")
75
  return None, False
76
 
77
 
78
  def load_controlnets():
79
+ """examplewithface.py lines 122-126"""
80
  print("Loading ControlNets...")
81
 
82
  identitynet = ControlNetModel.from_pretrained(
 
84
  subfolder="ControlNetModel",
85
  torch_dtype=dtype
86
  )
87
+ print(" [OK] InstantID ControlNet")
88
 
89
  zoedepthnet = ControlNetModel.from_pretrained(
90
  "diffusers/controlnet-zoe-depth-sdxl-1.0",
91
  torch_dtype=dtype
92
  )
93
+ print(" [OK] Zoe Depth ControlNet")
94
 
95
  return identitynet, zoedepthnet
96
 
97
 
98
  def load_sdxl_pipeline(controlnets):
99
  """
100
+ examplewithface.py lines 128-145
101
+ CRITICAL: Pass controlnets as LIST - NO MultiControlNetModel!
102
  """
103
+ print("Loading pipeline...")
104
 
105
  # Load VAE (line 128)
106
  vae = AutoencoderKL.from_pretrained(
 
109
  )
110
  print(" [OK] VAE loaded")
111
 
112
+ # Create pipeline (line 134) - controlnets as LIST!
113
  pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
114
  "frankjoshua/albedobaseXL_v21",
115
  vae=vae,
116
+ controlnet=controlnets, # LIST [identitynet, zoedepthnet] - NO WRAPPER!
117
  torch_dtype=dtype
118
  )
119
+ print(" [OK] Pipeline created with direct controlnet list")
120
 
121
+ # LCM scheduler
122
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
123
+ print(" [OK] LCM scheduler")
124
 
125
+ # IP-Adapter (line 139)
126
  ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
127
  pipe.load_ip_adapter_instantid(ip_adapter_path)
128
  pipe.set_ip_adapter_scale(0.8)
129
  print(" [OK] IP-Adapter loaded")
130
 
 
131
  pipe = pipe.to(device)
132
+ print(" [OK] Pipeline ready (following examplewithface.py EXACTLY)")
 
133
  return pipe, True
134
 
135
 
136
+ # Global LoRA state
137
+ lora_path_cached = None
 
138
 
139
 
140
  def load_lora(pipe):
141
+ """Load LoRA - store path for later use"""
 
 
142
  print("Loading LoRA...")
143
+ global lora_path_cached
144
 
145
  try:
146
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
147
+ lora_path_cached = lora_path
148
+ print(f" [OK] LoRA path stored")
 
 
149
  return True
150
  except Exception as e:
151
+ print(f" [WARNING] LoRA failed: {e}")
 
152
  return False
153
 
154
 
155
  def fuse_lora_with_scale(pipe, lora_scale):
156
  """
157
+ Modern approach: Load LoRA and let cross_attention_kwargs apply scale
 
158
  """
159
+ global lora_path_cached
160
 
161
+ if lora_path_cached is None:
 
162
  return False
163
 
164
  try:
165
+ # Unload previous
166
+ try:
167
+ pipe.unload_lora_weights()
168
+ except:
169
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
+ # Load LoRA
172
+ print(f" [LORA] Loading with scale {lora_scale}...")
173
+ pipe.load_lora_weights(lora_path_cached)
174
+ print(f" [OK] LoRA loaded (scale will be applied via cross_attention_kwargs)")
175
 
176
+ return True
177
  except Exception as e:
178
+ print(f" [ERROR] LoRA failed: {e}")
 
 
179
  return False
180
 
181
 
 
 
 
 
 
182
  def setup_compel(pipe):
183
+ """examplewithface.py line 145"""
184
  print("Setting up Compel...")
185
  try:
186
  compel = Compel(
 
189
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
190
  requires_pooled=[False, True]
191
  )
192
+ print(" [OK] Compel ready")
193
  return compel, True
194
  except Exception as e:
195
  print(f" [WARNING] Compel unavailable: {e}")
 
197
 
198
 
199
  def setup_scheduler(pipe):
 
200
  pass
201
 
202
 
203
  def optimize_pipeline(pipe):
 
204
  if device == "cuda":
205
  try:
206
  pipe.enable_xformers_memory_efficient_attention()
 
215
 
216
 
217
  def load_caption_model():
 
218
  print("Loading caption model...")
219
  try:
220
  from transformers import AutoProcessor, AutoModelForCausalLM
221
  processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
222
  model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco", torch_dtype=dtype).to("cpu")
223
+ print(" [OK] GIT-Large")
224
  return processor, model, True, 'git'
225
  except:
226
  try:
227
  from transformers import BlipProcessor, BlipForConditionalGeneration
228
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
229
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=dtype).to("cpu")
230
+ print(" [OK] BLIP")
231
  return processor, model, True, 'blip'
232
  except:
233
  return None, None, False, 'none'
234
 
235
 
236
  def set_clip_skip(pipe):
 
237
  if hasattr(pipe, 'text_encoder'):
238
+ print(f" [OK] CLIP skip {CLIP_SKIP}")
239
 
240
 
241
+ __all__ = ['draw_kps', 'fuse_lora_with_scale']
242
 
243
+ print("[OK] models.py ready - NO MultiControlNetModel, following examplewithface.py")