primerz committed on
Commit
70a37ed
·
verified ·
1 Parent(s): 82f7fe1

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +122 -168
models.py CHANGED
@@ -1,25 +1,34 @@
1
  """
2
  Model loading and initialization for Pixagram AI Pixel Art Generator
3
- FIXED VERSION with proper IP-Adapter and BLIP-2 support
 
4
  """
5
  import torch
6
  import time
 
7
  from diffusers import (
8
- StableDiffusionXLControlNetImg2ImgPipeline,
9
  ControlNetModel,
10
  AutoencoderKL,
11
- LCMScheduler
 
12
  )
13
  from diffusers.models.attention_processor import AttnProcessor2_0
14
- from transformers import CLIPVisionModelWithProjection
15
  from insightface.app import FaceAnalysis
16
- from controlnet_aux import ZoeDetector
17
  from huggingface_hub import hf_hub_download
18
  from compel import Compel, ReturnedEmbeddingsType
19
 
20
- # Use reference implementation's attention processor
21
- from attention_processor import IPAttnProcessor2_0, AttnProcessor
22
- from resampler import Resampler
 
 
 
 
 
 
 
23
 
24
  from config import (
25
  device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
@@ -62,19 +71,19 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
62
 
63
 
64
  def load_face_analysis():
65
- """Load face analysis model with proper error handling."""
66
- print("Loading face analysis model...")
67
  try:
68
  face_app = FaceAnalysis(
69
- name=FACE_DETECTION_CONFIG['model_name'],
70
- root='./models/insightface',
71
  providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
72
  )
73
  face_app.prepare(
74
- ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
75
- det_size=FACE_DETECTION_CONFIG['det_size']
76
  )
77
- print(" [OK] Face analysis model loaded successfully")
78
  return face_app, True
79
  except Exception as e:
80
  print(f" [WARNING] Face detection not available: {e}")
@@ -82,89 +91,122 @@ def load_face_analysis():
82
 
83
 
84
  def load_depth_detector():
85
- """Load Zoe Depth detector."""
86
- print("Loading Zoe Depth detector...")
87
  try:
88
- zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
89
- zoe_depth.to(device)
90
- print(" [OK] Zoe Depth loaded successfully")
91
- return zoe_depth, True
92
  except Exception as e:
93
- print(f" [WARNING] Zoe Depth not available: {e}")
 
 
 
 
 
 
 
 
 
 
 
94
  return None, False
95
 
96
 
97
  def load_controlnets():
98
- """Load ControlNet models."""
99
- print("Loading ControlNet Zoe Depth model...")
100
  controlnet_depth = ControlNetModel.from_pretrained(
101
- "diffusers/controlnet-zoe-depth-sdxl-1.0",
102
  torch_dtype=dtype
103
  ).to(device)
104
  print(" [OK] ControlNet Depth loaded")
105
 
106
- print("Loading InstantID ControlNet...")
107
  try:
108
- controlnet_instantid = ControlNetModel.from_pretrained(
109
- "InstantX/InstantID",
110
- subfolder="ControlNetModel",
111
  torch_dtype=dtype
112
  ).to(device)
113
- print(" [OK] InstantID ControlNet loaded successfully")
114
- return controlnet_depth, controlnet_instantid, True
115
  except Exception as e:
116
- print(f" [WARNING] InstantID ControlNet not available: {e}")
117
  return controlnet_depth, None, False
118
 
119
 
120
  def load_image_encoder():
121
- """Load CLIP Image Encoder for IP-Adapter."""
122
- print("Loading CLIP Image Encoder for IP-Adapter...")
123
- try:
124
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
125
- "h94/IP-Adapter",
126
- subfolder="models/image_encoder",
127
- torch_dtype=dtype
128
- ).to(device)
129
- print(" [OK] CLIP Image Encoder loaded successfully")
130
- return image_encoder
131
- except Exception as e:
132
- print(f" [ERROR] Could not load image encoder: {e}")
133
- return None
134
 
135
 
136
  def load_sdxl_pipeline(controlnets):
137
- """Load SDXL checkpoint from HuggingFace Hub."""
138
- print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  try:
140
- model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
141
-
142
- pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
143
- model_path,
144
- controlnet=controlnets,
145
- torch_dtype=dtype,
146
- use_safetensors=True
147
  ).to(device)
148
- print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
149
  return pipe, True
150
  except Exception as e:
151
- print(f" [WARNING] Could not load custom checkpoint: {e}")
152
- print(" Using default SDXL base model")
153
- pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
154
- "stabilityai/stable-diffusion-xl-base-1.0",
155
- controlnet=controlnets,
156
- torch_dtype=dtype,
157
- use_safetensors=True
158
- ).to(device)
159
- return pipe, False
 
160
 
161
 
162
  def load_lora(pipe):
163
- """Load LORA from HuggingFace Hub."""
164
  print("Loading LORA (retroart) from HuggingFace Hub...")
165
  try:
166
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
167
- # **FIX 2: Add adapter_name="retroart"**
168
  pipe.load_lora_weights(lora_path, adapter_name="retroart")
169
  print(f" [OK] LORA loaded successfully")
170
  return True
@@ -173,113 +215,25 @@ def load_lora(pipe):
173
  return False
174
 
175
 
176
- def setup_ip_adapter(pipe, image_encoder):
177
  """
178
- Setup IP-Adapter for InstantID face embeddings - PROPER IMPLEMENTATION.
179
- Based on the reference InstantID pipeline.
180
  """
181
- if image_encoder is None:
182
- return None, False
183
-
184
- print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
185
  try:
186
- # Download InstantID weights
187
- ip_adapter_path = download_model_with_retry(
188
- "InstantX/InstantID",
189
- "ip-adapter.bin"
190
- )
191
-
192
- # Load full state dict
193
- state_dict = torch.load(ip_adapter_path, map_location="cpu")
194
-
195
- # Extract image_proj and ip_adapter weights
196
- image_proj_state_dict = {}
197
- ip_adapter_state_dict = {}
198
-
199
- for key, value in state_dict.items():
200
- if key.startswith("image_proj."):
201
- image_proj_state_dict[key.replace("image_proj.", "")] = value
202
- elif key.startswith("ip_adapter."):
203
- ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
204
-
205
- # Create Resampler (image projection model) with CORRECT parameters from reference
206
- print("Creating Resampler (Perceiver architecture)...")
207
- image_proj_model = Resampler(
208
- dim=1280, # Hidden dimension
209
- depth=4, # IMPORTANT: 4 layers (not 8!)
210
- dim_head=64, # Dimension per head
211
- heads=20, # Number of heads
212
- num_queries=16, # Number of output tokens
213
- embedding_dim=512, # InsightFace embedding dim
214
- output_dim=pipe.unet.config.cross_attention_dim, # SDXL cross-attention dim (2048)
215
- ff_mult=4 # Feedforward multiplier
216
  )
217
 
218
- image_proj_model.eval()
219
- image_proj_model = image_proj_model.to(device, dtype=dtype)
220
-
221
- # Load image_proj weights
222
- if image_proj_state_dict:
223
- try:
224
- image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
225
- print(" [OK] Resampler loaded with pretrained weights")
226
- except Exception as e:
227
- print(f" [WARNING] Could not load Resampler weights: {e}")
228
- print(" Using randomly initialized Resampler")
229
- else:
230
- print(" [WARNING] No image_proj weights found, using random initialization")
231
-
232
- # Setup IP-Adapter attention processors
233
- print("Setting up IP-Adapter attention processors...")
234
- attn_procs = {}
235
- num_tokens = 16 # Match Resampler num_queries
236
-
237
- for name in pipe.unet.attn_processors.keys():
238
- cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
239
-
240
- if name.startswith("mid_block"):
241
- hidden_size = pipe.unet.config.block_out_channels[-1]
242
- elif name.startswith("up_blocks"):
243
- block_id = int(name[len("up_blocks.")])
244
- hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
245
- elif name.startswith("down_blocks"):
246
- block_id = int(name[len("down_blocks.")])
247
- hidden_size = pipe.unet.config.block_out_channels[block_id]
248
- else:
249
- hidden_size = pipe.unet.config.block_out_channels[-1]
250
-
251
- if cross_attention_dim is None:
252
- attn_procs[name] = AttnProcessor2_0()
253
- else:
254
- attn_procs[name] = IPAttnProcessor2_0(
255
- hidden_size=hidden_size,
256
- cross_attention_dim=cross_attention_dim,
257
- scale=1.0,
258
- num_tokens=num_tokens
259
- ).to(device, dtype=dtype)
260
-
261
- # Set attention processors
262
- pipe.unet.set_attn_processor(attn_procs)
263
-
264
- # Load IP-Adapter weights into attention processors
265
- if ip_adapter_state_dict:
266
- try:
267
- ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
268
- ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
269
- print(" [OK] IP-Adapter attention weights loaded")
270
- except Exception as e:
271
- print(f" [WARNING] Could not load IP-Adapter weights: {e}")
272
- else:
273
- print(" [WARNING] No ip_adapter weights found")
274
-
275
- # Store image encoder and projection model
276
- pipe.image_encoder = image_encoder
277
-
278
- print(" [OK] IP-Adapter fully loaded with InstantID architecture")
279
- print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
280
- print(f" - Face embeddings: 512D → 16x2048D")
281
 
282
- return image_proj_model, True
 
283
 
284
  except Exception as e:
285
  print(f" [ERROR] Could not setup IP-Adapter: {e}")
@@ -369,4 +323,4 @@ def set_clip_skip(pipe):
369
  print(f" [OK] CLIP skip set to {CLIP_SKIP}")
370
 
371
 
372
- print("[OK] Model loading functions ready")
 
1
  """
2
  Model loading and initialization for Pixagram AI Pixel Art Generator
3
+ HYBRID VERSION - Supports both local files and HuggingFace repos
4
+ MODIFIED for IP-Adapter-FaceIDXL (non-plus) and LCM Scheduler
5
  """
6
  import torch
7
  import time
8
+ import os
9
  from diffusers import (
 
10
  ControlNetModel,
11
  AutoencoderKL,
12
+ LCMScheduler, # Changed back to LCM
13
+ StableDiffusionXLControlNetImg2ImgPipeline
14
  )
15
  from diffusers.models.attention_processor import AttnProcessor2_0
16
+ from transformers import CLIPVisionModelWithProjection, pipeline
17
  from insightface.app import FaceAnalysis
18
+ from controlnet_aux import LeresDetector, CannyDetector
19
  from huggingface_hub import hf_hub_download
20
  from compel import Compel, ReturnedEmbeddingsType
21
 
22
+ # Import the IP-Adapter wrapper classes
23
+ try:
24
+ # Import base class and the specific SDXL class
25
+ from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDXL
26
+ except ImportError:
27
+ print("="*80)
28
+ print("[FATAL ERROR] `ip_adapter` library not found.")
29
+ print("Please install it: pip install ip-adapter")
30
+ print("="*80)
31
+ raise
32
 
33
  from config import (
34
  device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
 
71
 
72
 
73
  def load_face_analysis():
74
+ """Load face analysis model (buffalo_l) with proper error handling."""
75
+ print("Loading face analysis model (buffalo_l)...")
76
  try:
77
  face_app = FaceAnalysis(
78
+ name='buffalo_l', # Changed from antelopev2
79
+ root='/data',
80
  providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
81
  )
82
  face_app.prepare(
83
+ ctx_id=0,
84
+ det_size=(640, 640)
85
  )
86
+ print(" [OK] Face analysis model (buffalo_l) loaded successfully")
87
  return face_app, True
88
  except Exception as e:
89
  print(f" [WARNING] Face detection not available: {e}")
 
91
 
92
 
93
  def load_depth_detector():
94
+ """Load LeReS++ Depth detector."""
95
+ print("Loading LeReS++ detector...")
96
  try:
97
+ leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
98
+ leres.to(device)
99
+ print(" [OK] LeReS++ loaded successfully")
100
+ return leres, True
101
  except Exception as e:
102
+ print(f" [WARNING] LeReS++ not available: {e}")
103
+ return None, False
104
+
105
+ def load_canny_detector():
106
+ """Load Canny detector."""
107
+ print("Loading Canny detector...")
108
+ try:
109
+ canny = CannyDetector()
110
+ print(" [OK] Canny loaded successfully")
111
+ return canny, True
112
+ except Exception as e:
113
+ print(f" [WARNING] Canny detector not available: {e}")
114
  return None, False
115
 
116
 
117
  def load_controlnets():
118
+ """Load ControlNet models for Depth and Canny."""
119
+ print("Loading ControlNet Depth model...")
120
  controlnet_depth = ControlNetModel.from_pretrained(
121
+ "diffusers/controlnet-depth-sdxl-1.0", # Standard depth model
122
  torch_dtype=dtype
123
  ).to(device)
124
  print(" [OK] ControlNet Depth loaded")
125
 
126
+ print("Loading ControlNet Canny model...")
127
  try:
128
+ controlnet_canny = ControlNetModel.from_pretrained(
129
+ "diffusers/controlnet-canny-sdxl-1.0",
 
130
  torch_dtype=dtype
131
  ).to(device)
132
+ print(" [OK] ControlNet Canny loaded successfully")
133
+ return controlnet_depth, controlnet_canny, True
134
  except Exception as e:
135
+ print(f" [WARNING] ControlNet Canny not available: {e}")
136
  return controlnet_depth, None, False
137
 
138
 
139
  def load_image_encoder():
140
+ """
141
+ [DEPRECATED] This function is no longer needed by IPAdapterFaceIDXL,
142
+ but we keep it here in case other components need it.
143
+ It will not be called by the generator.
144
+ """
145
+ print("Loading CLIP Image Encoder [SKIPPED - Not required by IPAdapterFaceIDXL]")
146
+ return None
 
 
 
 
 
 
147
 
148
 
149
  def load_sdxl_pipeline(controlnets):
150
+ """
151
+ Load SDXL checkpoint - MODIFIED for LCM and built-in VAE.
152
+ """
153
+
154
+ # --- VAE LOADING REMOVED ---
155
+ # We are using the VAE built into the "horizon" checkpoint.
156
+ print("Loading SDXL checkpoint (using built-in VAE)...")
157
+
158
+ pipeline_kwargs = {
159
+ "controlnet": controlnets,
160
+ "torch_dtype": dtype,
161
+ "use_safetensors": True,
162
+ # "vae": None, # <--- This line was correctly removed
163
+ }
164
+
165
+ # ATTEMPT 1: Try loading from local file (This should be your "horizon" checkpoint)
166
+ if MODEL_FILES.get('checkpoint'):
167
+ try:
168
+ print(f" [Attempt 1] Loading from local file: {MODEL_FILES['checkpoint']}...")
169
+ model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
170
+
171
+ if model_path and os.path.exists(model_path) and model_path.endswith('.safetensors'):
172
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
173
+ model_path,
174
+ **pipeline_kwargs
175
+ ).to(device)
176
+ print(f" [OK] Checkpoint loaded from local file: {model_path}")
177
+ return pipe, True
178
+ else:
179
+ print(f" [INFO] Local file not found or invalid...")
180
+ except Exception as e:
181
+ print(f" [WARNING] from_single_file failed: {e}")
182
+
183
+ # ATTEMPT 2: Try loading from HuggingFace repo
184
  try:
185
+ print(f" [Attempt 2] Loading from HuggingFace repo: {MODEL_REPO}...")
186
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
187
+ MODEL_REPO,
188
+ **pipeline_kwargs
 
 
 
189
  ).to(device)
190
+ print(f" [OK] Checkpoint loaded from HuggingFace repo: {MODEL_REPO}")
191
  return pipe, True
192
  except Exception as e:
193
+ print(f" [WARNING] from_pretrained failed: {e}")
194
+
195
+ # ATTEMPT 3: Fallback (Base SDXL)
196
+ print(f" [Attempt 3] Loading base SDXL model...")
197
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
198
+ "stabilityai/stable-diffusion-xl-base-1.0",
199
+ **pipeline_kwargs
200
+ ).to(device)
201
+ print(" [OK] Base SDXL model loaded")
202
+ return pipe, False
203
 
204
 
205
  def load_lora(pipe):
206
+ """Load LORA (retroart) from HuggingFace Hub."""
207
  print("Loading LORA (retroart) from HuggingFace Hub...")
208
  try:
209
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
 
210
  pipe.load_lora_weights(lora_path, adapter_name="retroart")
211
  print(f" [OK] LORA loaded successfully")
212
  return True
 
215
  return False
216
 
217
 
218
+ def setup_ip_adapter(pipe):
219
  """
220
+ Setup IP-Adapter-FaceIDXL wrapper.
221
+ [FIXED] Does not take image_encoder_path.
222
  """
223
+ print("Setting up IP-Adapter-FaceIDXL...")
 
 
 
224
  try:
225
+ # Download the SDXL non-plus FaceID model
226
+ ip_ckpt_path = hf_hub_download(
227
+ repo_id="h94/IP-Adapter-FaceID",
228
+ filename="ip-adapter-faceid_sdxl.bin",
229
+ token=HUGGINGFACE_TOKEN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  )
231
 
232
+ # --- [FIX] Instantiate without image_encoder_path ---
233
+ ip_model = IPAdapterFaceIDXL(pipe, ip_ckpt_path, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ print(" [OK] IPAdapterFaceIDXL wrapper initialized successfully.")
236
+ return ip_model, True
237
 
238
  except Exception as e:
239
  print(f" [ERROR] Could not setup IP-Adapter: {e}")
 
323
  print(f" [OK] CLIP skip set to {CLIP_SKIP}")
324
 
325
 
326
+ print("[OK] Model loading functions ready (IP-Adapter-FaceIDXL / LCM VERSION)")