primerz committed
Commit 22858c3 · verified · 1 Parent(s): fc3355b

Update models.py

Files changed (1):
  1. models.py +62 -179

models.py CHANGED
@@ -1,40 +1,33 @@
 """
 Model loading and initialization for Pixagram AI Pixel Art Generator
-FIXED VERSION with proper IP-Adapter and BLIP-2 support
+FIXED VERSION - Uses correct InstantID pipeline and Compel encoder
 """
 import torch
 import time
 import os
-import shutil
 from diffusers import (
-    StableDiffusionXLControlNetImg2ImgPipeline,
     ControlNetModel,
     AutoencoderKL,
     LCMScheduler
 )
-from diffusers.models.attention_processor import AttnProcessor2_0
-from transformers import (
-    CLIPVisionModelWithProjection, CLIPTokenizer,
-    CLIPTextModel, CLIPTextModelWithProjection
-)
+from transformers import CLIPVisionModelWithProjection, CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection  # text classes are still used in load_sdxl_pipeline
 from insightface.app import FaceAnalysis
 from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MediapipeFaceDetector
 from huggingface_hub import hf_hub_download, snapshot_download
 
-# --- START FIX: Import our new Cappella module ---
-from cappella import Cappella
+# --- START FIX: Import correct pipeline and Compel ---
+from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
+from compel import Compel, ReturnedEmbeddingsType
 # --- END FIX ---
 
-# Use reference implementation's attention processor
-from attention_processor import IPAttnProcessor2_0, AttnProcessor
-from resampler import Resampler
-
 from config import (
     device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
     FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
 )
 
-
+# (We keep download_model_with_retry, load_face_analysis, load_depth_detector,
+# load_openpose_detector, and load_mediapipe_face_detector as they were)
+# ... (Keep all original functions from line 25 down to line 180) ...
 def download_model_with_retry(repo_id, filename, max_retries=None, **kwargs):
     """Download model with retry logic and proper token handling."""
     if max_retries is None:
  if max_retries is None:
@@ -200,93 +193,67 @@ def load_controlnets():
     # Return models, indicating InstantID failure
     return controlnet_depth, None, controlnet_openpose, False
 
-
-def load_image_encoder():
-    """Load CLIP Image Encoder for IP-Adapter."""
-    print("Loading CLIP Image Encoder for IP-Adapter...")
-    try:
-        # --- FIX: Load core models on GPU ---
-        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-            "h94/IP-Adapter",
-            subfolder="models/image_encoder",
-            torch_dtype=dtype
-        ).to(device)
-        print(" [OK] CLIP Image Encoder loaded successfully (on GPU)")
-        return image_encoder
-    except Exception as e:
-        print(f" [ERROR] Could not load image encoder: {e}")
-        return None
-
+# --- START: REMOVED load_image_encoder ---
+# (The new pipeline handles this internally)
+# --- END: REMOVED load_image_encoder ---
 
 def load_sdxl_pipeline(controlnets):
     """Load SDXL checkpoint from HuggingFace Hub."""
     print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
-
-    # --- START FIX ---
-    # Load tokenizers and text encoders from the base model first
-    # This guarantees they exist, even if the single file doesn't have them
+
+    # --- START FIX: Load base text models for Compel (from previous fix) ---
     print(" Loading base tokenizers and text encoders...")
     BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
-
-    try:
-        tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
-        tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")
-
-        text_encoder = CLIPTextModel.from_pretrained(
-            BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
-        ).to(device)
-
-        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
-            BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
-        ).to(device)
-        print(" [OK] Base text/token models loaded")
-
-    except Exception as e:
-        print(f" [ERROR] Could not load base text models: {e}")
-        print(" Pipeline will likely fail. Check HF connection/model access.")
-        # Allow it to continue, but it will likely fail below
-        tokenizer = None
-        tokenizer_2 = None
-        text_encoder = None
-        text_encoder_2 = None
+    tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
+    tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")
+    text_encoder = CLIPTextModel.from_pretrained(
+        BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
+    ).to(device)
+    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+        BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
+    ).to(device)
+    print(" [OK] Base text/token models loaded")
     # --- END FIX ---
 
     try:
         model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")
 
-        # --- START FIX ---
-        # Pass the pre-loaded models to from_single_file
-        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
+        # --- START FIX: Load the CORRECT pipeline ---
+        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
             model_path,
             controlnet=controlnets,
             torch_dtype=dtype,
             use_safetensors=True,
-
-            # Explicitly provide the models
+            # Pass components
             tokenizer=tokenizer,
             tokenizer_2=tokenizer_2,
             text_encoder=text_encoder,
             text_encoder_2=text_encoder_2,
-
-        ).to(device)  # This main pipe MUST be on device
+        ).to(device)
         # --- END FIX ---
-
+
         print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
         return pipe, True
-
     except Exception as e:
         print(f" [WARNING] Could not load custom checkpoint: {e}")
         print(" Using default SDXL base model")
 
-        # The fallback logic is already correct
-        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+        # --- START FIX: Fallback to the CORRECT pipeline ---
+        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0",
             controlnet=controlnets,
             torch_dtype=dtype,
-            use_safetensors=True
-        ).to(device)  # This main pipe MUST be on device
+            use_safetensors=True,
+            # Pass components
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+        ).to(device)
+        # --- END FIX ---
         return pipe, False
 
+
 def load_loras(pipe):
     """Load all LORAs from HuggingFace Hub."""
     print("Loading all LORAs from HuggingFace Hub...")
@@ -320,14 +287,11 @@ def load_loras(pipe):
     return loaded_loras, success
 
 
-def setup_ip_adapter(pipe, image_encoder):
+# --- START FIX: Replace setup_ip_adapter ---
+def setup_ip_adapter(pipe):
     """
-    Setup IP-Adapter for InstantID face embeddings.
-    This is CRITICAL for face preservation.
+    Setup IP-Adapter for InstantID face embeddings using the pipeline's method.
     """
-    if image_encoder is None:
-        return None, False
-
     print("Setting up IP-Adapter for InstantID face embeddings...")
     try:
         # Download InstantID weights
@@ -337,110 +301,35 @@ def setup_ip_adapter(pipe, image_encoder):
             repo_type="model"
         )
 
-        # Load full state dict
-        state_dict = torch.load(ip_adapter_path, map_location="cpu")
+        # Use the pipeline's built-in loader
+        pipe.load_ip_adapter_instantid(ip_adapter_path)
 
-        # Extract image_proj and ip_adapter weights
-        image_proj_state_dict = {}
-        ip_adapter_state_dict = {}
-
-        for key, value in state_dict.items():
-            if key.startswith("image_proj."):
-                image_proj_state_dict[key.replace("image_proj.", "")] = value
-            elif key.startswith("ip_adapter."):
-                ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
-
-        # Create Resampler with CORRECT parameters
-        print("Creating Resampler (Perceiver architecture)...")
-        image_proj_model = Resampler(
-            dim=1280,
-            depth=4,
-            dim_head=64,
-            heads=20,
-            num_queries=16,
-            embedding_dim=512,  # CRITICAL: Must match InsightFace embedding size
-            output_dim=pipe.unet.config.cross_attention_dim,
-            ff_mult=4
-        )
-
-        image_proj_model.eval()
-        image_proj_model = image_proj_model.to(device, dtype=dtype)
-
-        # Load image_proj weights
-        if image_proj_state_dict:
-            try:
-                image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
-                print(" [OK] Resampler loaded with pretrained weights")
-            except Exception as e:
-                print(f" [WARNING] Could not load Resampler weights: {e}")
-
-        # Setup IP-Adapter attention processors
-        print("Setting up IP-Adapter attention processors...")
-        attn_procs = {}
-        num_tokens = 16
-
-        for name in pipe.unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
-
-            if name.startswith("mid_block"):
-                hidden_size = pipe.unet.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = pipe.unet.config.block_out_channels[block_id]
-            else:
-                hidden_size = pipe.unet.config.block_out_channels[-1]
-
-            if cross_attention_dim is None:
-                attn_procs[name] = AttnProcessor2_0()
-            else:
-                attn_procs[name] = IPAttnProcessor2_0(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    num_tokens=num_tokens
-                ).to(device, dtype=dtype)
-
-        # Set attention processors
-        pipe.unet.set_attn_processor(attn_procs)
-
-        # Load IP-Adapter weights
-        if ip_adapter_state_dict:
-            try:
-                ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
-                ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
-                print(" [OK] IP-Adapter attention weights loaded")
-            except Exception as e:
-                print(f" [WARNING] Could not load IP-Adapter weights: {e}")
-
-        # Store image encoder
-        pipe.image_encoder = image_encoder
-
-        print(" [OK] IP-Adapter fully loaded with InstantID architecture")
-        print(" - Resampler: 4 layers, 20 heads, 16 output tokens")
-        print(f" - Face embeddings: 512D -> 16x{pipe.unet.config.cross_attention_dim}D")
-
-        return image_proj_model, True
+        print(" [OK] IP-Adapter fully loaded via pipeline")
+        return None, True  # We don't need to return a model
 
     except Exception as e:
         print(f" [ERROR] Could not setup IP-Adapter: {e}")
         import traceback
         traceback.print_exc()
         return None, False
+# --- END FIX ---
 
 
-# --- START FIX: Use our new Cappella module ---
-def setup_cappella(pipe):
-    """Setup Cappella for our custom prompt encoding."""
-    print("Setting up Cappella (custom prompt encoder)...")
+# --- START FIX: Replace setup_cappella with setup_compel ---
+def setup_compel(pipe):
+    """Setup Compel for robust prompt encoding."""
+    print("Setting up Compel (prompt encoder)...")
     try:
-        cappella = Cappella(pipe, device)
-        print(" [OK] Cappella loaded successfully.")
-        return cappella, True
+        compel = Compel(
+            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+            requires_pooled=[False, True]
+        )
+        print(" [OK] Compel loaded successfully.")
+        return compel, True
     except Exception as e:
-        print(f" [WARNING] Cappella not available: {e}")
+        print(f" [WARNING] Compel not available: {e}")
         return None, False
 # --- END FIX ---
 
@@ -454,10 +343,6 @@ def setup_scheduler(pipe):
 
 def optimize_pipeline(pipe):
     """Apply optimizations to pipeline."""
-
-    # --- FIX: Removed enable_model_cpu_offload() ---
-
-    # Try to enable xformers
     if device == "cuda":
         try:
             pipe.enable_xformers_memory_efficient_attention()
@@ -479,11 +364,10 @@ def load_caption_model():
 
         print(" Attempting GIT-Large (recommended)...")
         caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
-        # --- FIX: Load on CPU ---
         caption_model = AutoModelForCausalLM.from_pretrained(
             "microsoft/git-large-coco",
             torch_dtype=dtype
-        )  # .to(device) removed
+        )
         print(" [OK] GIT-Large model loaded (produces detailed captions, on CPU)")
         return caption_processor, caption_model, True, 'git'
     except Exception as e1:
@@ -495,11 +379,10 @@ def load_caption_model():
 
         print(" Attempting BLIP base (fallback)...")
         caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-        # --- FIX: Load on CPU ---
         caption_model = BlipForConditionalGeneration.from_pretrained(
             "Salesforce/blip-image-captioning-base",
             torch_dtype=dtype
-        )  # .to(device) removed
+        )
         print(" [OK] BLIP base model loaded (standard captions, on CPU)")
         return caption_processor, caption_model, True, 'blip'
     except Exception as e2:
@@ -514,4 +397,4 @@ def set_clip_skip(pipe):
     print(f" [OK] CLIP skip set to {CLIP_SKIP}")
 
 
-print("[OK] Model loading functions ready")
+print("[OK] Model loading functions ready")
 