primerz commited on
Commit
f4d10c2
·
verified ·
1 Parent(s): a6caf7b

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +100 -54
models.py CHANGED
@@ -5,6 +5,7 @@ FIXED VERSION with proper IP-Adapter and BLIP-2 support
5
  import torch
6
  import time
7
  import os
 
8
  from diffusers import (
9
  StableDiffusionXLControlNetImg2ImgPipeline,
10
  ControlNetModel,
@@ -14,8 +15,8 @@ from diffusers import (
14
  from diffusers.models.attention_processor import AttnProcessor2_0
15
  from transformers import CLIPVisionModelWithProjection
16
  from insightface.app import FaceAnalysis
17
- from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MediapipeFaceDetector
18
- from huggingface_hub import hf_hub_download, snapshot_download
19
  from compel import Compel, ReturnedEmbeddingsType
20
 
21
  # Use reference implementation's attention processor
@@ -28,26 +29,24 @@ from config import (
28
  )
29
 
30
 
31
- def download_model_with_retry(repo_id, filename, max_retries=None):
32
  """Download model with retry logic and proper token handling."""
33
  if max_retries is None:
34
  max_retries = DOWNLOAD_CONFIG['max_retries']
35
 
 
 
 
 
36
  for attempt in range(max_retries):
37
  try:
38
  print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
39
 
40
- kwargs = {"repo_type": "model"}
41
- if HUGGINGFACE_TOKEN:
42
- kwargs["token"] = HUGGINGFACE_TOKEN
43
-
44
- path = hf_hub_download(
45
  repo_id=repo_id,
46
  filename=filename,
47
  **kwargs
48
  )
49
- print(f" [OK] Downloaded: {filename}")
50
- return path
51
 
52
  except Exception as e:
53
  print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
@@ -63,46 +62,90 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
63
 
64
 
65
  def load_face_analysis():
66
- """Load face analysis model with proper error handling."""
 
 
 
 
67
  print("Loading face analysis model...")
68
 
69
- model_name = FACE_DETECTION_CONFIG['model_name'] # "antelopev2"
70
- local_model_root = '.' # We want files to be in ./antelopev2
 
71
  local_model_path = os.path.join(local_model_root, model_name)
72
 
73
  try:
74
- # --- NEW: Download models from HF Hub ---
75
- print(f" Ensuring insightface models are downloaded to {local_model_path}...")
 
76
 
77
- # Check if files exist before downloading
78
- required_files = ["1k3d68.onnx", "2d106det.onnx", "genderage.onnx", "glintr100.onnx", "scrfd_10g_bnkps.onnx"]
79
- files_exist = all(os.path.exists(os.path.join(local_model_path, f)) for f in required_files)
 
 
 
 
80
 
81
- if not files_exist:
82
- print(f" Models not found locally, downloading from {MODEL_REPO}...")
83
- try:
84
- snapshot_kwargs = {
85
- "repo_id": MODEL_REPO,
86
- "allow_patterns": [f"{model_name}/*.onnx"],
87
- "local_dir": local_model_root,
88
- "local_dir_use_symlinks": False,
89
- "repo_type": "model"
90
- }
91
- if HUGGINGFACE_TOKEN:
92
- snapshot_kwargs["token"] = HUGGINGFACE_TOKEN
93
 
94
- snapshot_download(**snapshot_kwargs)
95
- print(" [OK] Downloaded insightface models.")
96
- except Exception as e:
97
- print(f" [ERROR] Failed to download insightface models: {e}")
98
- raise # Re-raise exception to be caught by the outer try-except
99
- else:
100
- print(" [OK] Insightface models found locally.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # --- END NEW ---
102
 
103
  face_app = FaceAnalysis(
104
- name=model_name,
105
- root=local_model_root, # This will look in ./antelopev2
106
  providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
107
  )
108
  face_app.prepare(
@@ -114,6 +157,8 @@ def load_face_analysis():
114
 
115
  except Exception as e:
116
  print(f" [WARNING] Face detection not available: {e}")
 
 
117
  return None, False
118
 
119
 
@@ -242,11 +287,11 @@ def load_sdxl_pipeline(controlnets):
242
  """Load SDXL checkpoint from HuggingFace Hub."""
243
  print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
244
  try:
245
- model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
246
 
247
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
248
  model_path,
249
- controlnet=controlnets, # Pass the list of 3 controlnets
250
  torch_dtype=dtype,
251
  use_safetensors=True
252
  ).to(device)
@@ -257,7 +302,7 @@ def load_sdxl_pipeline(controlnets):
257
  print(" Using default SDXL base model")
258
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
259
  "stabilityai/stable-diffusion-xl-base-1.0",
260
- controlnet=controlnets, # Pass the list of 3 controlnets
261
  torch_dtype=dtype,
262
  use_safetensors=True
263
  ).to(device)
@@ -282,7 +327,7 @@ def load_loras(pipe):
282
  continue
283
 
284
  try:
285
- lora_path = download_model_with_retry(MODEL_REPO, filename)
286
  pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
287
  print(f" [OK] LORA loaded successfully: {filename} as '{adapter_name}'")
288
  loaded_loras[adapter_name] = True
@@ -310,7 +355,8 @@ def setup_ip_adapter(pipe, image_encoder):
310
  # Download InstantID weights
311
  ip_adapter_path = download_model_with_retry(
312
  "InstantX/InstantID",
313
- "ip-adapter.bin"
 
314
  )
315
 
316
  # Load full state dict
@@ -329,14 +375,14 @@ def setup_ip_adapter(pipe, image_encoder):
329
  # Create Resampler (image projection model) with CORRECT parameters from reference
330
  print("Creating Resampler (Perceiver architecture)...")
331
  image_proj_model = Resampler(
332
- dim=1280, # Hidden dimension
333
- depth=4, # IMPORTANT: 4 layers (not 8!)
334
- dim_head=64, # Dimension per head
335
- heads=20, # Number of heads
336
- num_queries=16, # Number of output tokens
337
- embedding_dim=512, # InsightFace embedding dim
338
- output_dim=pipe.unet.config.cross_attention_dim, # SDXL cross-attention dim (2048)
339
- ff_mult=4 # Feedforward multiplier
340
  )
341
 
342
  image_proj_model.eval()
@@ -356,7 +402,7 @@ def setup_ip_adapter(pipe, image_encoder):
356
  # Setup IP-Adapter attention processors
357
  print("Setting up IP-Adapter attention processors...")
358
  attn_procs = {}
359
- num_tokens = 16 # Match Resampler num_queries
360
 
361
  for name in pipe.unet.attn_processors.keys():
362
  cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
@@ -444,7 +490,7 @@ def optimize_pipeline(pipe):
444
  pipe.enable_xformers_memory_efficient_attention()
445
  print(" [OK] xformers enabled")
446
  except Exception as e:
447
- print(f" [INFO] xformers not available: {e}")
448
 
449
 
450
  def load_caption_model():
 
5
  import torch
6
  import time
7
  import os
8
+ import shutil
9
  from diffusers import (
10
  StableDiffusionXLControlNetImg2ImgPipeline,
11
  ControlNetModel,
 
15
  from diffusers.models.attention_processor import AttnProcessor2_0
16
  from transformers import CLIPVisionModelWithProjection
17
  from insightface.app import FaceAnalysis
18
+ from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MedipeFaceDetector
19
+ from huggingface_hub import hf_hub_download, HfHubDownloadConfig
20
  from compel import Compel, ReturnedEmbeddingsType
21
 
22
  # Use reference implementation's attention processor
 
29
  )
30
 
31
 
32
+ def download_model_with_retry(repo_id, filename, max_retries=None, **kwargs):
33
  """Download model with retry logic and proper token handling."""
34
  if max_retries is None:
35
  max_retries = DOWNLOAD_CONFIG['max_retries']
36
 
37
+ # Ensure token is passed if available
38
+ if HUGGINGFACE_TOKEN and "token" not in kwargs:
39
+ kwargs["token"] = HUGGINGFACE_TOKEN
40
+
41
  for attempt in range(max_retries):
42
  try:
43
  print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
44
 
45
+ return hf_hub_download(
 
 
 
 
46
  repo_id=repo_id,
47
  filename=filename,
48
  **kwargs
49
  )
 
 
50
 
51
  except Exception as e:
52
  print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
 
62
 
63
 
64
  def load_face_analysis():
65
+ """
66
+ Load face analysis model with proper error handling.
67
+ This version downloads files manually to a custom folder
68
+ to bypass the insightface hard-coded zip download.
69
+ """
70
  print("Loading face analysis model...")
71
 
72
+ # Use a custom model name to prevent insightface from auto-downloading a zip
73
+ model_name = FACE_DETECTION_CONFIG['model_name'] # "pixagram_face_models"
74
+ local_model_root = '.' # We want files to be in ./pixagram_face_models
75
  local_model_path = os.path.join(local_model_root, model_name)
76
 
77
  try:
78
+ # --- NEW: Manual download logic ---
79
+ print(f" Ensuring insightface models are present in {local_model_path}...")
80
+ os.makedirs(local_model_path, exist_ok=True)
81
 
82
+ required_files = [
83
+ "1k3d68.onnx",
84
+ "2d106det.onnx",
85
+ "genderage.onnx",
86
+ "glintr100.onnx",
87
+ "scrfd_10g_bnkps.onnx"
88
+ ]
89
 
90
+ # Download config to control download location
91
+ download_config = HfHubDownloadConfig(
92
+ local_dir=local_model_path,
93
+ local_dir_use_symlinks=False,
94
+ resume_download=True
95
+ )
96
+
97
+ for file_name in required_files:
98
+ local_file_path = os.path.join(local_model_path, file_name)
99
+
100
+ if not os.path.exists(local_file_path):
101
+ print(f" Downloading {file_name}...")
102
 
103
+ # Path to the file in the HF model repo
104
+ repo_file_path = f"antelopev2/{file_name}"
105
+
106
+ try:
107
+ # Download the file directly into our target folder
108
+ downloaded_path = download_model_with_retry(
109
+ repo_id=MODEL_REPO,
110
+ filename=repo_file_path,
111
+ local_dir=local_model_path,
112
+ local_dir_use_symlinks=False,
113
+ resume_download=True,
114
+ repo_type="model"
115
+ )
116
+
117
+ # hf_hub_download *might* preserve folder structure,
118
+ # e.g., saving to ./pixagram_face_models/antelopev2/genderage.onnx
119
+ # We must move it if that happens.
120
+
121
+ expected_download_path = os.path.join(local_model_path, *repo_file_path.split('/'))
122
+
123
+ if os.path.exists(expected_download_path) and expected_download_path != local_file_path:
124
+ print(f" Moving {expected_download_path} to {local_file_path}")
125
+ shutil.move(expected_download_path, local_file_path)
126
+
127
+ # Clean up empty antelopev2 folder if it was created
128
+ try:
129
+ os.rmdir(os.path.dirname(expected_download_path))
130
+ except OSError:
131
+ pass # Not empty, which is fine
132
+
133
+ elif not os.path.exists(local_file_path):
134
+ # Fallback in case logic is wrong, just check the returned path
135
+ if downloaded_path != local_file_path:
136
+ print(f" Moving {downloaded_path} to {local_file_path}")
137
+ shutil.move(downloaded_path, local_file_path)
138
+
139
+ except Exception as e:
140
+ print(f" [ERROR] Failed to download {file_name}: {e}")
141
+ raise # Re-raise to stop startup
142
+
143
+ print(" [OK] All insightface models are present locally.")
144
  # --- END NEW ---
145
 
146
  face_app = FaceAnalysis(
147
+ name=model_name, # "pixagram_face_models" (custom name)
148
+ root=local_model_root, # "." (looks in ./pixagram_face_models)
149
  providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
150
  )
151
  face_app.prepare(
 
157
 
158
  except Exception as e:
159
  print(f" [WARNING] Face detection not available: {e}")
160
+ import traceback
161
+ traceback.print_exc()
162
  return None, False
163
 
164
 
 
287
  """Load SDXL checkpoint from HuggingFace Hub."""
288
  print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
289
  try:
290
+ model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")
291
 
292
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
293
  model_path,
294
+ controlnet=controlnets,
295
  torch_dtype=dtype,
296
  use_safetensors=True
297
  ).to(device)
 
302
  print(" Using default SDXL base model")
303
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
304
  "stabilityai/stable-diffusion-xl-base-1.0",
305
+ controlnet=controlnets,
306
  torch_dtype=dtype,
307
  use_safetensors=True
308
  ).to(device)
 
327
  continue
328
 
329
  try:
330
+ lora_path = download_model_with_retry(MODEL_REPO, filename, repo_type="model")
331
  pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
332
  print(f" [OK] LORA loaded successfully: {filename} as '{adapter_name}'")
333
  loaded_loras[adapter_name] = True
 
355
  # Download InstantID weights
356
  ip_adapter_path = download_model_with_retry(
357
  "InstantX/InstantID",
358
+ "ip-adapter.bin",
359
+ repo_type="model"
360
  )
361
 
362
  # Load full state dict
 
375
  # Create Resampler (image projection model) with CORRECT parameters from reference
376
  print("Creating Resampler (Perceiver architecture)...")
377
  image_proj_model = Resampler(
378
+ dim=1280,
379
+ depth=4,
380
+ dim_head=64,
381
+ heads=20,
382
+ num_queries=16,
383
+ embedding_dim=512,
384
+ output_dim=pipe.unet.config.cross_attention_dim,
385
+ ff_mult=4
386
  )
387
 
388
  image_proj_model.eval()
 
402
  # Setup IP-Adapter attention processors
403
  print("Setting up IP-Adapter attention processors...")
404
  attn_procs = {}
405
+ num_tokens = 16
406
 
407
  for name in pipe.unet.attn_processors.keys():
408
  cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
 
490
  pipe.enable_xformers_memory_efficient_attention()
491
  print(" [OK] xformers enabled")
492
  except Exception as e:
493
+ print(f" [INFO] xformformers not available: {e}")
494
 
495
 
496
  def load_caption_model():