primerz commited on
Commit
5d3624b
·
verified ·
1 Parent(s): fae896d

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +105 -20
models.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Model loading and initialization for Pixagram AI Pixel Art Generator
3
- UPDATED VERSION with proper InstantID pipeline support
4
  """
5
  import torch
6
  import time
@@ -12,11 +12,13 @@ from diffusers import (
12
  )
13
  from insightface.app import FaceAnalysis
14
  from controlnet_aux import ZoeDetector
15
- from huggingface_hub import hf_hub_download
16
  from compel import Compel, ReturnedEmbeddingsType
17
 
18
  # Use InstantID pipeline
19
- from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
 
 
20
 
21
  from config import (
22
  device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
@@ -59,18 +61,79 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
59
 
60
 
61
  def load_face_analysis():
62
- """Load face analysis model with intelligent provider selection."""
 
 
 
63
  print("Loading face analysis model...")
64
- face_app = FaceAnalysis(
65
- name=FACE_DETECTION_CONFIG['model_name'],
66
- root='/data',
67
- providers=['CPUExecutionProvider']
68
- )
69
- face_app.prepare(
70
- ctx_id=ctx_id,
71
- det_size=FACE_DETECTION_CONFIG['det_size']
72
- )
73
- return face_app, True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  def load_depth_detector():
76
  """Load Zoe Depth detector with optimized memory management."""
@@ -110,23 +173,44 @@ def load_controlnets():
110
 
111
 
112
  def load_sdxl_pipeline(controlnets):
113
- """Load SDXL checkpoint from HuggingFace Hub."""
114
- print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
 
 
 
115
  try:
116
  model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
117
 
 
118
  pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
119
  model_path,
120
  controlnet=controlnets,
121
  torch_dtype=dtype,
122
  use_safetensors=True
123
  ).to(device)
124
- print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
 
 
 
 
 
 
 
 
 
 
 
125
  return pipe, True
 
126
  except Exception as e:
127
- print(f" [WARNING] Could not load custom checkpoint: {e}")
128
- print(" Using default SDXL base model")
129
- pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
 
 
 
 
 
130
  "stabilityai/stable-diffusion-xl-base-1.0",
131
  controlnet=controlnets,
132
  torch_dtype=dtype,
@@ -134,6 +218,7 @@ def load_sdxl_pipeline(controlnets):
134
  ).to(device)
135
  return pipe, False
136
 
 
137
  def load_lora(pipe):
138
  """Load LORA from HuggingFace Hub."""
139
  print("Loading LORA (retroart) from HuggingFace Hub...")
 
1
  """
2
  Model loading and initialization for Pixagram AI Pixel Art Generator
3
+ CORRECTED VERSION with proper face analysis loading
4
  """
5
  import torch
6
  import time
 
12
  )
13
  from insightface.app import FaceAnalysis
14
  from controlnet_aux import ZoeDetector
15
+ from huggingface_hub import hf_hub_download, snapshot_download
16
  from compel import Compel, ReturnedEmbeddingsType
17
 
18
  # Use InstantID pipeline
19
+ from pipeline_stable_diffusion_xl_instantid_img2img import (
20
+ StableDiffusionXLInstantIDImg2ImgPipeline
21
+ )
22
 
23
  from config import (
24
  device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
 
61
 
62
 
63
  def load_face_analysis():
64
+ """
65
+ Load face analysis model using the correct approach.
66
+ Downloads antelopev2 model and initializes FaceAnalysis.
67
+ """
68
  print("Loading face analysis model...")
69
+
70
+ try:
71
+ # Download antelopev2 model using snapshot_download (like working example)
72
+ print(" Downloading antelopev2 model files...")
73
+ antelope_path = snapshot_download(
74
+ repo_id=FACE_DETECTION_CONFIG['download_repo'],
75
+ local_dir=FACE_DETECTION_CONFIG['local_dir']
76
+ )
77
+ print(f" [OK] Antelopev2 downloaded to: {antelope_path}")
78
+
79
+ # Initialize FaceAnalysis with the correct root path
80
+ # Use CPU provider for memory efficiency (can be changed in config)
81
+ providers = FACE_DETECTION_CONFIG.get('providers', ['CPUExecutionProvider'])
82
+
83
+ print(f" Initializing FaceAnalysis with providers: {providers}")
84
+ face_app = FaceAnalysis(
85
+ name=FACE_DETECTION_CONFIG['model_name'],
86
+ root=FACE_DETECTION_CONFIG['root'],
87
+ providers=providers
88
+ )
89
+
90
+ # Prepare the model
91
+ face_app.prepare(
92
+ ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
93
+ det_size=FACE_DETECTION_CONFIG['det_size']
94
+ )
95
+
96
+ # Test the model to ensure it works
97
+ import numpy as np
98
+ test_img = np.zeros((640, 640, 3), dtype=np.uint8)
99
+ _ = face_app.get(test_img)
100
+
101
+ print(f" [OK] Face analysis model loaded successfully")
102
+ print(f" [INFO] Using providers: {providers}")
103
+ return face_app, True
104
+
105
+ except Exception as e:
106
+ print(f" [ERROR] Face analysis loading failed: {e}")
107
+ import traceback
108
+ traceback.print_exc()
109
+
110
+ # Try fallback with different providers
111
+ try:
112
+ print(" [INFO] Trying fallback with auto-detect providers...")
113
+ face_app = FaceAnalysis(
114
+ name=FACE_DETECTION_CONFIG['model_name'],
115
+ root=FACE_DETECTION_CONFIG['root']
116
+ )
117
+ face_app.prepare(
118
+ ctx_id=0,
119
+ det_size=FACE_DETECTION_CONFIG['det_size']
120
+ )
121
+
122
+ # Test
123
+ import numpy as np
124
+ test_img = np.zeros((640, 640, 3), dtype=np.uint8)
125
+ _ = face_app.get(test_img)
126
+
127
+ print(" [OK] Face analysis loaded with auto-detect providers")
128
+ return face_app, True
129
+
130
+ except Exception as e2:
131
+ print(f" [WARNING] Face detection not available: {e2}")
132
+ print(" [INFO] Generation will continue without face preservation")
133
+ print(" [TIP] Check that onnxruntime is properly installed:")
134
+ print(" pip install onnxruntime --break-system-packages")
135
+ return None, False
136
+
137
 
138
  def load_depth_detector():
139
  """Load Zoe Depth detector with optimized memory management."""
 
173
 
174
 
175
  def load_sdxl_pipeline(controlnets):
176
+ """
177
+ Load SDXL pipeline with InstantID support.
178
+ controlnets MUST be a list: [identitynet, depthnet]
179
+ """
180
+ print("Loading SDXL checkpoint with InstantID pipeline...")
181
  try:
182
  model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
183
 
184
+ # CRITICAL: Use InstantID-enabled pipeline (not standard ControlNet pipeline)
185
  pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
186
  model_path,
187
  controlnet=controlnets,
188
  torch_dtype=dtype,
189
  use_safetensors=True
190
  ).to(device)
191
+
192
+ # Load IP-Adapter weights for InstantID
193
+ print("Loading IP-Adapter for InstantID...")
194
+ ip_adapter_path = download_model_with_retry(
195
+ "InstantX/InstantID",
196
+ "ip-adapter.bin"
197
+ )
198
+ pipe.load_ip_adapter_instantid(ip_adapter_path)
199
+ # Don't set default scale - will be set dynamically based on face detection
200
+ print(" [OK] IP-Adapter loaded (scale will be set dynamically)")
201
+
202
+ print(" [OK] InstantID pipeline loaded successfully")
203
  return pipe, True
204
+
205
  except Exception as e:
206
+ print(f" [ERROR] Could not load InstantID pipeline: {e}")
207
+ import traceback
208
+ traceback.print_exc()
209
+
210
+ # Fallback to standard pipeline
211
+ print(" [WARNING] Falling back to standard SDXL pipeline (no InstantID)")
212
+ from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
213
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
214
  "stabilityai/stable-diffusion-xl-base-1.0",
215
  controlnet=controlnets,
216
  torch_dtype=dtype,
 
218
  ).to(device)
219
  return pipe, False
220
 
221
+
222
  def load_lora(pipe):
223
  """Load LORA from HuggingFace Hub."""
224
  print("Loading LORA (retroart) from HuggingFace Hub...")