primerz committed on
Commit 75da5ce · verified · 1 Parent(s): 867d605

Update models.py

Files changed (1)
  1. models.py +168 -122
models.py CHANGED
--- models.py (before)
@@ -1,34 +1,25 @@
 """
 Model loading and initialization for Pixagram AI Pixel Art Generator
-HYBRID VERSION - Supports both local files and HuggingFace repos
-MODIFIED for IP-Adapter-FaceIDXL (non-plus) and LCM Scheduler
 """
 import torch
 import time
-import os
 from diffusers import (
     ControlNetModel,
     AutoencoderKL,
-    LCMScheduler,  # Changed back to LCM
-    StableDiffusionXLControlNetImg2ImgPipeline
 )
 from diffusers.models.attention_processor import AttnProcessor2_0
-from transformers import CLIPVisionModelWithProjection, pipeline
 from insightface.app import FaceAnalysis
-from controlnet_aux import LeresDetector, CannyDetector
 from huggingface_hub import hf_hub_download
 from compel import Compel, ReturnedEmbeddingsType
 
-# Import the IP-Adapter wrapper classes
-try:
-    # Import base class and the specific SDXL class
-    from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDXL
-except ImportError:
-    print("="*80)
-    print("[FATAL ERROR] `ip_adapter` library not found.")
-    print("Please install it: pip install ip-adapter")
-    print("="*80)
-    raise
 
 from config import (
     device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
@@ -71,19 +62,19 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
 
 
 def load_face_analysis():
-    """Load face analysis model (buffalo_l) with proper error handling."""
-    print("Loading face analysis model (buffalo_l)...")
     try:
         face_app = FaceAnalysis(
-            name='buffalo_l',  # Changed from antelopev2
-            root='/data',
             providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
         )
         face_app.prepare(
-            ctx_id=0,
-            det_size=(640, 640)
         )
-        print(" [OK] Face analysis model (buffalo_l) loaded successfully")
         return face_app, True
     except Exception as e:
         print(f" [WARNING] Face detection not available: {e}")
@@ -91,122 +82,89 @@ def load_face_analysis():
 
 
 def load_depth_detector():
-    """Load LeReS++ Depth detector."""
-    print("Loading LeReS++ detector...")
     try:
-        leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
-        leres.to(device)
-        print(" [OK] LeReS++ loaded successfully")
-        return leres, True
     except Exception as e:
-        print(f" [WARNING] LeReS++ not available: {e}")
-        return None, False
-
-def load_canny_detector():
-    """Load Canny detector."""
-    print("Loading Canny detector...")
-    try:
-        canny = CannyDetector()
-        print(" [OK] Canny loaded successfully")
-        return canny, True
-    except Exception as e:
-        print(f" [WARNING] Canny detector not available: {e}")
         return None, False
 
 
 def load_controlnets():
-    """Load ControlNet models for Depth and Canny."""
-    print("Loading ControlNet Depth model...")
     controlnet_depth = ControlNetModel.from_pretrained(
-        "diffusers/controlnet-depth-sdxl-1.0",  # Standard depth model
         torch_dtype=dtype
     ).to(device)
     print(" [OK] ControlNet Depth loaded")
 
-    print("Loading ControlNet Canny model...")
     try:
-        controlnet_canny = ControlNetModel.from_pretrained(
-            "diffusers/controlnet-canny-sdxl-1.0",
             torch_dtype=dtype
         ).to(device)
-        print(" [OK] ControlNet Canny loaded successfully")
-        return controlnet_depth, controlnet_canny, True
     except Exception as e:
-        print(f" [WARNING] ControlNet Canny not available: {e}")
         return controlnet_depth, None, False
 
 
 def load_image_encoder():
-    """
-    [DEPRECATED] This function is no longer needed by IPAdapterFaceIDXL,
-    but we keep it here in case other components need it.
-    It will not be called by the generator.
-    """
-    print("Loading CLIP Image Encoder [SKIPPED - Not required by IPAdapterFaceIDXL]")
-    return None
 
 
 def load_sdxl_pipeline(controlnets):
-    """
-    Load SDXL checkpoint - MODIFIED for LCM and built-in VAE.
-    """
-
-    # --- VAE LOADING REMOVED ---
-    # We are using the VAE built into the "horizon" checkpoint.
-    print("Loading SDXL checkpoint (using built-in VAE)...")
-
-    pipeline_kwargs = {
-        "controlnet": controlnets,
-        "torch_dtype": dtype,
-        "use_safetensors": True,
-        # "vae": None,  # <--- This line was correctly removed
-    }
-
-    # ATTEMPT 1: Try loading from local file (This should be your "horizon" checkpoint)
-    if MODEL_FILES.get('checkpoint'):
-        try:
-            print(f" [Attempt 1] Loading from local file: {MODEL_FILES['checkpoint']}...")
-            model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
-
-            if model_path and os.path.exists(model_path) and model_path.endswith('.safetensors'):
-                pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
-                    model_path,
-                    **pipeline_kwargs
-                ).to(device)
-                print(f" [OK] Checkpoint loaded from local file: {model_path}")
-                return pipe, True
-            else:
-                print(f" [INFO] Local file not found or invalid...")
-        except Exception as e:
-            print(f" [WARNING] from_single_file failed: {e}")
-
-    # ATTEMPT 2: Try loading from HuggingFace repo
     try:
-        print(f" [Attempt 2] Loading from HuggingFace repo: {MODEL_REPO}...")
-        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
-            MODEL_REPO,
-            **pipeline_kwargs
         ).to(device)
-        print(f" [OK] Checkpoint loaded from HuggingFace repo: {MODEL_REPO}")
         return pipe, True
     except Exception as e:
-        print(f" [WARNING] from_pretrained failed: {e}")
-
-    # ATTEMPT 3: Fallback (Base SDXL)
-    print(f" [Attempt 3] Loading base SDXL model...")
-    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
-        "stabilityai/stable-diffusion-xl-base-1.0",
-        **pipeline_kwargs
-    ).to(device)
-    print(" [OK] Base SDXL model loaded")
-    return pipe, False
 
 
 def load_lora(pipe):
-    """Load LORA (retroart) from HuggingFace Hub."""
     print("Loading LORA (retroart) from HuggingFace Hub...")
     try:
         lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
         pipe.load_lora_weights(lora_path, adapter_name="retroart")
         print(f" [OK] LORA loaded successfully")
         return True
@@ -215,25 +173,113 @@ def load_lora(pipe):
         return False
 
 
-def setup_ip_adapter(pipe):
     """
-    Setup IP-Adapter-FaceIDXL wrapper.
-    [FIXED] Does not take image_encoder_path.
     """
-    print("Setting up IP-Adapter-FaceIDXL...")
     try:
-        # Download the SDXL non-plus FaceID model
-        ip_ckpt_path = hf_hub_download(
-            repo_id="h94/IP-Adapter-FaceID",
-            filename="ip-adapter-faceid_sdxl.bin",
-            token=HUGGINGFACE_TOKEN
         )
 
-        # --- [FIX] Instantiate without image_encoder_path ---
-        ip_model = IPAdapterFaceIDXL(pipe, ip_ckpt_path, device)
 
-        print(" [OK] IPAdapterFaceIDXL wrapper initialized successfully.")
-        return ip_model, True
 
     except Exception as e:
         print(f" [ERROR] Could not setup IP-Adapter: {e}")
@@ -323,4 +369,4 @@ def set_clip_skip(pipe):
     print(f" [OK] CLIP skip set to {CLIP_SKIP}")
 
 
-print("[OK] Model loading functions ready (IP-Adapter-FaceIDXL / LCM VERSION)")
 
+++ models.py (after)
@@ -1,34 +1,25 @@
 """
 Model loading and initialization for Pixagram AI Pixel Art Generator
+FIXED VERSION with proper IP-Adapter and BLIP-2 support
 """
 import torch
 import time
 from diffusers import (
+    StableDiffusionXLControlNetImg2ImgPipeline,
     ControlNetModel,
     AutoencoderKL,
+    LCMScheduler
 )
 from diffusers.models.attention_processor import AttnProcessor2_0
+from transformers import CLIPVisionModelWithProjection
 from insightface.app import FaceAnalysis
+from controlnet_aux import ZoeDetector
 from huggingface_hub import hf_hub_download
 from compel import Compel, ReturnedEmbeddingsType
 
+# Use reference implementation's attention processor
+from attention_processor import IPAttnProcessor2_0, AttnProcessor
+from resampler import Resampler
 
 from config import (
     device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
 
@@ -71,19 +62,19 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
 
 
 def load_face_analysis():
+    """Load face analysis model with proper error handling."""
+    print("Loading face analysis model...")
     try:
         face_app = FaceAnalysis(
+            name=FACE_DETECTION_CONFIG['model_name'],
+            root='./models/insightface',
             providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
         face_app.prepare(
+            ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
+            det_size=FACE_DETECTION_CONFIG['det_size']
         )
+        print(" [OK] Face analysis model loaded successfully")
         return face_app, True
     except Exception as e:
         print(f" [WARNING] Face detection not available: {e}")
 
@@ -91,122 +82,89 @@ def load_face_analysis():
 
 
 def load_depth_detector():
+    """Load Zoe Depth detector."""
+    print("Loading Zoe Depth detector...")
     try:
+        zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+        zoe_depth.to(device)
+        print(" [OK] Zoe Depth loaded successfully")
+        return zoe_depth, True
     except Exception as e:
+        print(f" [WARNING] Zoe Depth not available: {e}")
         return None, False
 
 
 def load_controlnets():
+    """Load ControlNet models."""
+    print("Loading ControlNet Zoe Depth model...")
     controlnet_depth = ControlNetModel.from_pretrained(
+        "diffusers/controlnet-zoe-depth-sdxl-1.0",
         torch_dtype=dtype
     ).to(device)
     print(" [OK] ControlNet Depth loaded")
 
+    print("Loading InstantID ControlNet...")
     try:
+        controlnet_instantid = ControlNetModel.from_pretrained(
+            "InstantX/InstantID",
+            subfolder="ControlNetModel",
             torch_dtype=dtype
         ).to(device)
+        print(" [OK] InstantID ControlNet loaded successfully")
+        return controlnet_depth, controlnet_instantid, True
     except Exception as e:
+        print(f" [WARNING] InstantID ControlNet not available: {e}")
         return controlnet_depth, None, False
 
 
 def load_image_encoder():
+    """Load CLIP Image Encoder for IP-Adapter."""
+    print("Loading CLIP Image Encoder for IP-Adapter...")
+    try:
+        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+            "h94/IP-Adapter",
+            subfolder="models/image_encoder",
+            torch_dtype=dtype
+        ).to(device)
+        print(" [OK] CLIP Image Encoder loaded successfully")
+        return image_encoder
+    except Exception as e:
+        print(f" [ERROR] Could not load image encoder: {e}")
+        return None
 
 
 def load_sdxl_pipeline(controlnets):
+    """Load SDXL checkpoint from HuggingFace Hub."""
+    print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
     try:
+        model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
+
+        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
+            model_path,
+            controlnet=controlnets,
+            torch_dtype=dtype,
+            use_safetensors=True
         ).to(device)
+        print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
         return pipe, True
     except Exception as e:
+        print(f" [WARNING] Could not load custom checkpoint: {e}")
+        print(" Using default SDXL base model")
+        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            controlnet=controlnets,
+            torch_dtype=dtype,
+            use_safetensors=True
+        ).to(device)
+        return pipe, False
 
 
 def load_lora(pipe):
+    """Load LORA from HuggingFace Hub."""
     print("Loading LORA (retroart) from HuggingFace Hub...")
     try:
         lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
+        # FIX 2: Add adapter_name="retroart"
         pipe.load_lora_weights(lora_path, adapter_name="retroart")
         print(f" [OK] LORA loaded successfully")
         return True
 
@@ -215,25 +173,113 @@ def load_lora(pipe):
         return False
 
 
+def setup_ip_adapter(pipe, image_encoder):
     """
+    Setup IP-Adapter for InstantID face embeddings - PROPER IMPLEMENTATION.
+    Based on the reference InstantID pipeline.
     """
+    if image_encoder is None:
+        return None, False
+
+    print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
     try:
+        # Download InstantID weights
+        ip_adapter_path = download_model_with_retry(
+            "InstantX/InstantID",
+            "ip-adapter.bin"
+        )
+
+        # Load full state dict
+        state_dict = torch.load(ip_adapter_path, map_location="cpu")
+
+        # Extract image_proj and ip_adapter weights
+        image_proj_state_dict = {}
+        ip_adapter_state_dict = {}
+
+        for key, value in state_dict.items():
+            if key.startswith("image_proj."):
+                image_proj_state_dict[key.replace("image_proj.", "")] = value
+            elif key.startswith("ip_adapter."):
+                ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
+
+        # Create Resampler (image projection model) with CORRECT parameters from reference
+        print("Creating Resampler (Perceiver architecture)...")
+        image_proj_model = Resampler(
+            dim=1280,           # Hidden dimension
+            depth=4,            # IMPORTANT: 4 layers (not 8!)
+            dim_head=64,        # Dimension per head
+            heads=20,           # Number of heads
+            num_queries=16,     # Number of output tokens
+            embedding_dim=512,  # InsightFace embedding dim
+            output_dim=pipe.unet.config.cross_attention_dim,  # SDXL cross-attention dim (2048)
+            ff_mult=4           # Feedforward multiplier
         )
 
+        image_proj_model.eval()
+        image_proj_model = image_proj_model.to(device, dtype=dtype)
+
+        # Load image_proj weights
+        if image_proj_state_dict:
+            try:
+                image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
+                print(" [OK] Resampler loaded with pretrained weights")
+            except Exception as e:
+                print(f" [WARNING] Could not load Resampler weights: {e}")
+                print(" Using randomly initialized Resampler")
+        else:
+            print(" [WARNING] No image_proj weights found, using random initialization")
+
+        # Setup IP-Adapter attention processors
+        print("Setting up IP-Adapter attention processors...")
+        attn_procs = {}
+        num_tokens = 16  # Match Resampler num_queries
+
+        for name in pipe.unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
+
+            if name.startswith("mid_block"):
+                hidden_size = pipe.unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = pipe.unet.config.block_out_channels[block_id]
+            else:
+                hidden_size = pipe.unet.config.block_out_channels[-1]
+
+            if cross_attention_dim is None:
+                attn_procs[name] = AttnProcessor2_0()
+            else:
+                attn_procs[name] = IPAttnProcessor2_0(
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
+                    num_tokens=num_tokens
+                ).to(device, dtype=dtype)
+
+        # Set attention processors
+        pipe.unet.set_attn_processor(attn_procs)
+
+        # Load IP-Adapter weights into attention processors
+        if ip_adapter_state_dict:
+            try:
+                ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
+                ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
+                print(" [OK] IP-Adapter attention weights loaded")
+            except Exception as e:
+                print(f" [WARNING] Could not load IP-Adapter weights: {e}")
+        else:
+            print(" [WARNING] No ip_adapter weights found")
+
+        # Store image encoder and projection model
+        pipe.image_encoder = image_encoder
+
+        print(" [OK] IP-Adapter fully loaded with InstantID architecture")
+        print(f"   - Resampler: 4 layers, 20 heads, 16 output tokens")
+        print(f"   - Face embeddings: 512D → 16x2048D")
 
+        return image_proj_model, True
 
     except Exception as e:
         print(f" [ERROR] Could not setup IP-Adapter: {e}")
 
@@ -323,4 +369,4 @@ def set_clip_skip(pipe):
     print(f" [OK] CLIP skip set to {CLIP_SKIP}")
 
 
+print("[OK] Model loading functions ready")