primerz commited on
Commit
1830448
·
verified ·
1 Parent(s): 09560c2

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +125 -170
models.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
- Model loading and initialization for Pixagram AI Pixel Art Generator
3
- FIXED VERSION - Following examplewithface.py EXACTLY for LORA handling
4
  """
5
  import torch
6
  import time
@@ -13,10 +13,9 @@ from diffusers import (
13
  from insightface.app import FaceAnalysis
14
  from controlnet_aux import ZoeDetector
15
  from huggingface_hub import hf_hub_download, snapshot_download
16
- from safetensors.torch import load_file # CRITICAL: For loading state_dict
17
  from compel import Compel, ReturnedEmbeddingsType
18
 
19
- # Use InstantID pipeline
20
  from pipeline_stable_diffusion_xl_instantid_img2img import (
21
  StableDiffusionXLInstantIDImg2ImgPipeline,
22
  draw_kps
@@ -29,94 +28,71 @@ from config import (
29
 
30
 
31
  def download_model_with_retry(repo_id, filename, max_retries=None):
32
- """Download model with retry logic and proper token handling."""
33
  if max_retries is None:
34
  max_retries = DOWNLOAD_CONFIG['max_retries']
35
 
36
  for attempt in range(max_retries):
37
  try:
38
- print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
39
 
40
  kwargs = {"repo_type": "model"}
41
  if HUGGINGFACE_TOKEN:
42
  kwargs["token"] = HUGGINGFACE_TOKEN
43
 
44
- path = hf_hub_download(
45
- repo_id=repo_id,
46
- filename=filename,
47
- **kwargs
48
- )
49
  print(f" [OK] Downloaded: {filename}")
50
  return path
51
 
52
  except Exception as e:
53
- print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
54
-
55
  if attempt < max_retries - 1:
56
  print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
57
  time.sleep(DOWNLOAD_CONFIG['retry_delay'])
58
  else:
59
- print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
60
  raise
61
-
62
  return None
63
 
64
 
65
  def load_face_analysis():
66
- """
67
- Load face analysis model - SIMPLIFIED to match working example
68
- """
69
- print("Loading face analysis model...")
70
-
71
  try:
72
- # Download antelopev2 model files
73
- print(" Downloading antelopev2 model files...")
74
  snapshot_download(
75
  repo_id=FACE_DETECTION_CONFIG['download_repo'],
76
  local_dir=FACE_DETECTION_CONFIG['local_dir']
77
  )
78
  print(" [OK] Antelopev2 downloaded")
79
 
80
- # Initialize FaceAnalysis (like examplewithface.py line 113)
81
- face_app = FaceAnalysis(
82
- name='antelopev2',
83
- root='/data',
84
- providers=['CPUExecutionProvider']
85
- )
86
-
87
- # Prepare the model (like examplewithface.py line 114)
88
- face_app.prepare(ctx_id=0, det_size=(640, 640))
89
-
90
- print(f" [OK] Face analysis model loaded successfully")
91
- return face_app, True
92
 
 
 
93
  except Exception as e:
94
- print(f" [ERROR] Face analysis loading failed: {e}")
95
- import traceback
96
- traceback.print_exc()
97
  return None, False
98
 
99
 
100
  def load_depth_detector():
101
- """Load Zoe Depth detector (like examplewithface.py line 151)"""
102
- print("Loading Zoe Depth detector...")
103
  try:
104
- zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
105
- # Start on CPU to save memory
106
- zoe_depth = zoe_depth.to("cpu")
107
- print(" [OK] Zoe Depth loaded (on CPU, will move to GPU when needed)")
108
- return zoe_depth, True
109
  except Exception as e:
110
- print(f" [WARNING] Zoe Depth not available: {e}")
111
  return None, False
112
 
113
 
114
  def load_controlnets():
115
- """
116
- Load ControlNets for InstantID pipeline
117
- Following examplewithface.py lines 122-126
118
- """
119
- print("Loading InstantID ControlNet...")
120
  identitynet = ControlNetModel.from_pretrained(
121
  "InstantX/InstantID",
122
  subfolder="ControlNetModel",
@@ -124,7 +100,6 @@ def load_controlnets():
124
  )
125
  print(" [OK] InstantID ControlNet loaded")
126
 
127
- print("Loading Zoe Depth ControlNet...")
128
  zoedepthnet = ControlNetModel.from_pretrained(
129
  "diffusers/controlnet-zoe-depth-sdxl-1.0",
130
  torch_dtype=dtype
@@ -136,133 +111,132 @@ def load_controlnets():
136
 
137
  def load_sdxl_pipeline(controlnets):
138
  """
139
- Load SDXL pipeline with InstantID support
140
- Following examplewithface.py lines 128-145 EXACTLY
141
  """
142
- print("Loading SDXL checkpoint with InstantID pipeline...")
143
 
144
- try:
145
- # Load VAE first (line 128)
146
- print("Loading FP16-fixed VAE...")
147
- vae = AutoencoderKL.from_pretrained(
148
- "madebyollin/sdxl-vae-fp16-fix",
149
- torch_dtype=dtype
150
- )
151
- print(" [OK] VAE loaded")
152
-
153
- # Create pipeline with direct controlnet list (line 134)
154
- print("Creating InstantID pipeline...")
155
- pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
156
- "frankjoshua/albedobaseXL_v21",
157
- vae=vae,
158
- controlnet=controlnets, # Pass list directly [identitynet, zoedepthnet]
159
- torch_dtype=dtype
160
- )
161
-
162
- # Use LCM scheduler (user wants LCM, not DPM)
163
- print("Setting up LCM scheduler...")
164
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
165
- print(" [OK] LCM scheduler configured")
166
-
167
- # Load IP-Adapter (line 139)
168
- print("Loading IP-Adapter for InstantID...")
169
- ip_adapter_path = download_model_with_retry(
170
- "InstantX/InstantID",
171
- "ip-adapter.bin"
172
- )
173
- pipe.load_ip_adapter_instantid(ip_adapter_path)
174
-
175
- # Set initial scale (line 140)
176
- pipe.set_ip_adapter_scale(0.8)
177
- print(" [OK] IP-Adapter loaded with scale 0.8")
178
-
179
- # Move to device
180
- pipe = pipe.to(device)
181
-
182
- print(" [OK] InstantID pipeline loaded (following examplewithface.py)")
183
- return pipe, True
184
-
185
- except Exception as e:
186
- print(f" [ERROR] Could not load InstantID pipeline: {e}")
187
- import traceback
188
- traceback.print_exc()
189
- raise
190
 
191
 
192
- # Global variable to track LORA state (like examplewithface.py lines 158-159)
193
- last_lora_fused = False
194
  loaded_lora_state_dict = None
 
 
195
 
196
 
197
  def load_lora(pipe):
198
  """
199
- Load LORA state_dict into memory
200
- Following examplewithface.py lines 72-83 pattern
201
  """
202
- print("Loading LORA state_dict from HuggingFace Hub...")
203
  global loaded_lora_state_dict
204
 
205
  try:
206
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
207
 
208
- # Load state_dict like the working example (line 78)
209
  if lora_path.endswith('.safetensors'):
210
  loaded_lora_state_dict = load_file(lora_path)
211
  else:
212
  loaded_lora_state_dict = torch.load(lora_path)
213
 
214
- print(f" [OK] LORA state_dict loaded (will be fused before generation)")
215
  return True
216
-
217
  except Exception as e:
218
- print(f" [WARNING] Could not load LORA: {e}")
219
  loaded_lora_state_dict = None
220
  return False
221
 
222
 
223
  def fuse_lora_with_scale(pipe, lora_scale):
224
  """
225
- Fuse LORA with scale before generation
226
- Following examplewithface.py lines 256-271 pattern EXACTLY
 
 
227
  """
228
- global last_lora_fused, loaded_lora_state_dict
229
 
230
  if loaded_lora_state_dict is None:
231
- print(" [WARNING] No LORA state_dict loaded")
232
  return False
233
 
234
  try:
235
- # Unfuse if already fused (line 257-264)
236
- if last_lora_fused:
237
- print(f" [LORA] Unfusing previous...")
238
- pipe.unfuse_lora()
239
- pipe.unload_lora_weights()
 
 
 
 
 
 
240
 
241
- # Load and fuse with scale (lines 266-267)
242
- print(f" [LORA] Loading state_dict...")
243
- pipe.load_lora_weights(loaded_lora_state_dict)
244
 
245
- print(f" [LORA] Fusing with scale: {lora_scale}")
246
- pipe.fuse_lora(lora_scale)
 
 
 
 
 
247
 
248
- last_lora_fused = True
249
- print(f" [OK] LORA fused successfully")
 
250
 
 
 
251
  return True
252
 
253
  except Exception as e:
254
- print(f" [WARNING] Could not fuse LORA: {e}")
255
  import traceback
256
  traceback.print_exc()
257
  return False
258
 
259
 
260
  def setup_compel(pipe):
261
- """
262
- Setup Compel for SDXL prompt handling
263
- Following examplewithface.py line 145 pattern
264
- """
265
- print("Setting up Compel for enhanced prompt processing...")
266
  try:
267
  compel = Compel(
268
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
@@ -270,78 +244,59 @@ def setup_compel(pipe):
270
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
271
  requires_pooled=[False, True]
272
  )
273
- print(" [OK] Compel loaded successfully")
274
  return compel, True
275
  except Exception as e:
276
- print(f" [WARNING] Compel not available: {e}")
277
  return None, False
278
 
279
 
280
  def setup_scheduler(pipe):
281
- """Setup scheduler - already done in load_sdxl_pipeline()"""
282
  pass
283
 
284
 
285
  def optimize_pipeline(pipe):
286
- """Apply optimizations to pipeline"""
287
  if device == "cuda":
288
  try:
289
  pipe.enable_xformers_memory_efficient_attention()
290
  print(" [OK] xformers enabled")
291
- except Exception as e:
292
- print(f" [INFO] xformers not available: {e}")
293
 
294
- # VAE optimizations
295
  if hasattr(pipe, 'enable_vae_slicing'):
296
  pipe.enable_vae_slicing()
297
- print(" [OK] VAE slicing enabled")
298
-
299
  if hasattr(pipe, 'enable_vae_tiling'):
300
  pipe.enable_vae_tiling()
301
- print(" [OK] VAE tiling enabled")
302
 
303
 
304
  def load_caption_model():
305
- """Load caption model with proper error handling"""
306
  print("Loading caption model...")
307
-
308
  try:
309
  from transformers import AutoProcessor, AutoModelForCausalLM
310
-
311
- print(" Attempting GIT-Large...")
312
- caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
313
- caption_model = AutoModelForCausalLM.from_pretrained(
314
- "microsoft/git-large-coco",
315
- torch_dtype=dtype
316
- ).to("cpu")
317
- print(" [OK] GIT-Large model loaded")
318
- return caption_processor, caption_model, True, 'git'
319
- except Exception as e1:
320
- print(f" [INFO] GIT-Large not available: {e1}")
321
-
322
  try:
323
  from transformers import BlipProcessor, BlipForConditionalGeneration
324
-
325
- print(" Attempting BLIP base...")
326
- caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
327
- caption_model = BlipForConditionalGeneration.from_pretrained(
328
- "Salesforce/blip-image-captioning-base",
329
- torch_dtype=dtype
330
- ).to("cpu")
331
- print(" [OK] BLIP base model loaded")
332
- return caption_processor, caption_model, True, 'blip'
333
- except Exception as e2:
334
- print(f" [WARNING] Caption models not available: {e2}")
335
  return None, None, False, 'none'
336
 
337
 
338
  def set_clip_skip(pipe):
339
- """Set CLIP skip value"""
340
  if hasattr(pipe, 'text_encoder'):
341
  print(f" [OK] CLIP skip set to {CLIP_SKIP}")
342
 
343
 
344
- # Export for use in generator
345
- __all__ = ['draw_kps', 'fuse_lora_with_scale', 'loaded_lora_state_dict', 'last_lora_fused']
346
 
347
- print("[OK] Model loading functions ready (following examplewithface.py EXACTLY)")
 
1
  """
2
+ Model loading for Pixagram - Following examplewithface.py EXACTLY
3
+ Fixed for modern diffusers API (no scale argument to fuse_lora)
4
  """
5
  import torch
6
  import time
 
13
  from insightface.app import FaceAnalysis
14
  from controlnet_aux import ZoeDetector
15
  from huggingface_hub import hf_hub_download, snapshot_download
16
+ from safetensors.torch import load_file
17
  from compel import Compel, ReturnedEmbeddingsType
18
 
 
19
  from pipeline_stable_diffusion_xl_instantid_img2img import (
20
  StableDiffusionXLInstantIDImg2ImgPipeline,
21
  draw_kps
 
28
 
29
 
30
  def download_model_with_retry(repo_id, filename, max_retries=None):
31
+ """Download model with retry logic"""
32
  if max_retries is None:
33
  max_retries = DOWNLOAD_CONFIG['max_retries']
34
 
35
  for attempt in range(max_retries):
36
  try:
37
+ print(f" Attempting download {filename} (attempt {attempt + 1}/{max_retries})...")
38
 
39
  kwargs = {"repo_type": "model"}
40
  if HUGGINGFACE_TOKEN:
41
  kwargs["token"] = HUGGINGFACE_TOKEN
42
 
43
+ path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
 
 
 
 
44
  print(f" [OK] Downloaded: {filename}")
45
  return path
46
 
47
  except Exception as e:
48
+ print(f" [WARNING] Attempt {attempt + 1} failed: {e}")
 
49
  if attempt < max_retries - 1:
50
  print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
51
  time.sleep(DOWNLOAD_CONFIG['retry_delay'])
52
  else:
53
+ print(f" [ERROR] Failed after {max_retries} attempts")
54
  raise
 
55
  return None
56
 
57
 
58
  def load_face_analysis():
59
+ """Load face analysis - simplified to match examplewithface.py line 113"""
60
+ print("Loading face analysis...")
 
 
 
61
  try:
 
 
62
  snapshot_download(
63
  repo_id=FACE_DETECTION_CONFIG['download_repo'],
64
  local_dir=FACE_DETECTION_CONFIG['local_dir']
65
  )
66
  print(" [OK] Antelopev2 downloaded")
67
 
68
+ # Like examplewithface.py line 113
69
+ app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
70
+ app.prepare(ctx_id=0, det_size=(640, 640))
 
 
 
 
 
 
 
 
 
71
 
72
+ print(" [OK] Face analysis loaded")
73
+ return app, True
74
  except Exception as e:
75
+ print(f" [ERROR] Face analysis failed: {e}")
 
 
76
  return None, False
77
 
78
 
79
  def load_depth_detector():
80
+ """Load Zoe Depth - examplewithface.py line 151"""
81
+ print("Loading Zoe Depth...")
82
  try:
83
+ zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
84
+ zoe = zoe.to("cpu") # Start on CPU
85
+ print(" [OK] Zoe Depth loaded")
86
+ return zoe, True
 
87
  except Exception as e:
88
+ print(f" [WARNING] Zoe Depth unavailable: {e}")
89
  return None, False
90
 
91
 
92
  def load_controlnets():
93
+ """Load ControlNets - examplewithface.py lines 122-126"""
94
+ print("Loading ControlNets...")
95
+
 
 
96
  identitynet = ControlNetModel.from_pretrained(
97
  "InstantX/InstantID",
98
  subfolder="ControlNetModel",
 
100
  )
101
  print(" [OK] InstantID ControlNet loaded")
102
 
 
103
  zoedepthnet = ControlNetModel.from_pretrained(
104
  "diffusers/controlnet-zoe-depth-sdxl-1.0",
105
  torch_dtype=dtype
 
111
 
112
  def load_sdxl_pipeline(controlnets):
113
  """
114
+ Load pipeline - examplewithface.py lines 128-145
115
+ KEY: Pass controlnets as LIST directly, NO wrapper
116
  """
117
+ print("Loading SDXL pipeline...")
118
 
119
+ # Load VAE (line 128)
120
+ print(" Loading VAE...")
121
+ vae = AutoencoderKL.from_pretrained(
122
+ "madebyollin/sdxl-vae-fp16-fix",
123
+ torch_dtype=dtype
124
+ )
125
+ print(" [OK] VAE loaded")
126
+
127
+ # Load pipeline (line 134) - pass controlnets as list directly!
128
+ print(" Creating pipeline...")
129
+ pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
130
+ "frankjoshua/albedobaseXL_v21",
131
+ vae=vae,
132
+ controlnet=controlnets, # Direct list [identitynet, zoedepthnet]
133
+ torch_dtype=dtype
134
+ )
135
+
136
+ # Setup LCM scheduler (USER WANTS LCM, not DPM!)
137
+ print(" Setting up LCM scheduler...")
138
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
139
+
140
+ # Load IP-Adapter (line 139)
141
+ print(" Loading IP-Adapter...")
142
+ ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
143
+ pipe.load_ip_adapter_instantid(ip_adapter_path)
144
+ pipe.set_ip_adapter_scale(0.8) # Default scale (line 140)
145
+
146
+ # Move to device
147
+ pipe = pipe.to(device)
148
+
149
+ print(" [OK] Pipeline loaded (following examplewithface.py)")
150
+ return pipe, True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
+ # Global LoRA state (examplewithface.py lines 158-159, 243)
 
154
  loaded_lora_state_dict = None
155
+ last_lora = ""
156
+ last_fused = False
157
 
158
 
159
  def load_lora(pipe):
160
  """
161
+ Load LoRA state_dict - examplewithface.py lines 72-83
162
+ KEY: Load as state_dict, NOT path!
163
  """
164
+ print("Loading LoRA state_dict...")
165
  global loaded_lora_state_dict
166
 
167
  try:
168
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
169
 
170
+ # Load state_dict (line 78)
171
  if lora_path.endswith('.safetensors'):
172
  loaded_lora_state_dict = load_file(lora_path)
173
  else:
174
  loaded_lora_state_dict = torch.load(lora_path)
175
 
176
+ print(" [OK] LoRA state_dict loaded")
177
  return True
 
178
  except Exception as e:
179
+ print(f" [WARNING] LoRA load failed: {e}")
180
  loaded_lora_state_dict = None
181
  return False
182
 
183
 
184
  def fuse_lora_with_scale(pipe, lora_scale):
185
  """
186
+ Fuse LoRA with scale - Modern diffusers API
187
+
188
+ examplewithface.py calls fuse_lora(lora_scale) but that's old API.
189
+ Modern API: load → set_adapters → fuse
190
  """
191
+ global last_fused, loaded_lora_state_dict
192
 
193
  if loaded_lora_state_dict is None:
194
+ print(" [WARNING] No LoRA state_dict available")
195
  return False
196
 
197
  try:
198
+ # Unfuse if needed
199
+ if last_fused:
200
+ print(" [LORA] Unfusing previous...")
201
+ try:
202
+ pipe.unfuse_lora()
203
+ except:
204
+ pass
205
+ try:
206
+ pipe.unload_lora_weights()
207
+ except:
208
+ pass
209
 
210
+ # Load state_dict with adapter name
211
+ print(" [LORA] Loading state_dict...")
212
+ pipe.load_lora_weights(loaded_lora_state_dict, adapter_name="pixel_lora")
213
 
214
+ # Set scale using modern API
215
+ print(f" [LORA] Setting scale to {lora_scale}...")
216
+ try:
217
+ pipe.set_adapters(["pixel_lora"], adapter_weights=[lora_scale])
218
+ except AttributeError:
219
+ # If set_adapters doesn't exist, scale will be 1.0
220
+ print(" [INFO] set_adapters not available, using scale 1.0")
221
 
222
+ # Fuse - NO scale argument
223
+ print(f" [LORA] Fusing...")
224
+ pipe.fuse_lora()
225
 
226
+ last_fused = True
227
+ print(f" [OK] LoRA fused with scale {lora_scale}")
228
  return True
229
 
230
  except Exception as e:
231
+ print(f" [ERROR] LoRA fusion failed: {e}")
232
  import traceback
233
  traceback.print_exc()
234
  return False
235
 
236
 
237
  def setup_compel(pipe):
238
+ """Setup Compel - examplewithface.py line 145"""
239
+ print("Setting up Compel...")
 
 
 
240
  try:
241
  compel = Compel(
242
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
 
244
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
245
  requires_pooled=[False, True]
246
  )
247
+ print(" [OK] Compel loaded")
248
  return compel, True
249
  except Exception as e:
250
+ print(f" [WARNING] Compel unavailable: {e}")
251
  return None, False
252
 
253
 
254
  def setup_scheduler(pipe):
255
+ """Already done in load_sdxl_pipeline"""
256
  pass
257
 
258
 
259
  def optimize_pipeline(pipe):
260
+ """Apply optimizations"""
261
  if device == "cuda":
262
  try:
263
  pipe.enable_xformers_memory_efficient_attention()
264
  print(" [OK] xformers enabled")
265
+ except:
266
+ pass
267
 
 
268
  if hasattr(pipe, 'enable_vae_slicing'):
269
  pipe.enable_vae_slicing()
 
 
270
  if hasattr(pipe, 'enable_vae_tiling'):
271
  pipe.enable_vae_tiling()
 
272
 
273
 
274
  def load_caption_model():
275
+ """Load caption model"""
276
  print("Loading caption model...")
 
277
  try:
278
  from transformers import AutoProcessor, AutoModelForCausalLM
279
+ processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
280
+ model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco", torch_dtype=dtype).to("cpu")
281
+ print(" [OK] GIT-Large loaded")
282
+ return processor, model, True, 'git'
283
+ except:
 
 
 
 
 
 
 
284
  try:
285
  from transformers import BlipProcessor, BlipForConditionalGeneration
286
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
287
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=dtype).to("cpu")
288
+ print(" [OK] BLIP loaded")
289
+ return processor, model, True, 'blip'
290
+ except:
 
 
 
 
 
 
291
  return None, None, False, 'none'
292
 
293
 
294
  def set_clip_skip(pipe):
295
+ """Set CLIP skip"""
296
  if hasattr(pipe, 'text_encoder'):
297
  print(f" [OK] CLIP skip set to {CLIP_SKIP}")
298
 
299
 
300
+ __all__ = ['draw_kps', 'fuse_lora_with_scale', 'loaded_lora_state_dict', 'last_fused']
 
301
 
302
+ print("[OK] Models ready (examplewithface.py pattern + modern diffusers API)")