primerz committed
Commit bfd74f2 · verified · 1 Parent(s): 432102c

Update models.py

Files changed (1):
  1. models.py +215 -214
models.py CHANGED
@@ -1,30 +1,24 @@
  """
- Models.py - Following examplewithface.py EXACTLY
- NO MultiControlNetModel wrapper!
- Using diffusers LoRA system (examplewithface.py lines 266-267)
  """
  import torch
- torch.jit.script = lambda f: f  # Critical: Disable JIT for compatibility
  import time
  import os
  from diffusers import (
-     ControlNetModel,
-     AutoencoderKL,
-     DPMSolverMultistepScheduler,
-     LCMScheduler,
-     UNet2DConditionModel
  )
  from insightface.app import FaceAnalysis
  from controlnet_aux import ZoeDetector
- from huggingface_hub import hf_hub_download, snapshot_download
- from safetensors.torch import load_file
  from compel import Compel, ReturnedEmbeddingsType

- from pipeline_stable_diffusion_xl_instantid_img2img import (
-     StableDiffusionXLInstantIDImg2ImgPipeline,
-     draw_kps
- )
- from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler

  from config import (
      device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
@@ -33,206 +27,237 @@ from config import (
  )


  def download_model_with_retry(repo_id, filename, max_retries=None):
      if max_retries is None:
          max_retries = DOWNLOAD_CONFIG['max_retries']

      for attempt in range(max_retries):
          try:
              kwargs = {"repo_type": "model"}
              if HUGGINGFACE_TOKEN:
                  kwargs["token"] = HUGGINGFACE_TOKEN

-             path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
              return path
          except Exception as e:
              if attempt < max_retries - 1:
                  time.sleep(DOWNLOAD_CONFIG['retry_delay'])
              else:
                  raise
      return None


  def load_face_analysis():
-     """examplewithface.py line 113"""
-     print("Loading face analysis...")
      try:
-         # Download antelopev2 model
-         snapshot_download(
-             repo_id="DIAMONIK7777/antelopev2",
-             local_dir="/data/models/antelopev2"
          )
-
-         # examplewithface.py line 113 pattern
-         app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
-         app.prepare(ctx_id=0, det_size=(640, 640))
-
-         print(" [OK] Face analysis loaded")
-         return app, True
      except Exception as e:
-         print(f" [ERROR] Face analysis failed: {e}")
-         import traceback
-         traceback.print_exc()
          return None, False


  def load_depth_detector():
-     """examplewithface.py line 151-155"""
-     print("Loading Zoe Depth...")
      try:
-         zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
-         zoe.to(device)  # examplewithface.py line 155
-         print(" [OK] Zoe Depth loaded")
-         return zoe, True
      except Exception as e:
-         print(f" [WARNING] Zoe unavailable: {e}")
          return None, False


  def load_controlnets():
-     """examplewithface.py lines 122-126"""
-     print("Loading ControlNets...")
-
-     # Load but don't move to device yet - pipe.to(device) will handle it
-     identitynet = ControlNetModel.from_pretrained(
-         "InstantX/InstantID",
-         subfolder="ControlNetModel",
-         torch_dtype=dtype
-     )
-     print(" [OK] InstantID ControlNet")
-
-     zoedepthnet = ControlNetModel.from_pretrained(
          "diffusers/controlnet-zoe-depth-sdxl-1.0",
          torch_dtype=dtype
-     )
-     print(" [OK] Zoe Depth ControlNet")

-     return identitynet, zoedepthnet


  def load_sdxl_pipeline(controlnets):
      """
-     examplewithface.py lines 128-145
-     CRITICAL: Pass controlnets as LIST - NO MultiControlNetModel!
      """
-     print("Loading pipeline...")
-
-     model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
-
-     pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
-         model_path,
-         controlnet=controlnets,
-         torch_dtype=dtype,
-         use_safetensors=True
-     )
-
-     print(" [OK] Pipeline created with direct controlnet list")
-
-     # LCM scheduler
-     pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
-     print(" [OK] LCM scheduler")

-     # IP-Adapter (line 139)
-     ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
-     pipe.load_ip_adapter_instantid(ip_adapter_path)
-     pipe.set_ip_adapter_scale(0.8)
-     print(" [OK] IP-Adapter loaded")
-
-     # DEBUG: Check UNet configuration
-     print(f" [DEBUG] UNet cross_attention_dim: {pipe.unet.config.cross_attention_dim}")
-     if hasattr(pipe, 'image_proj_model'):
-         print(f" [DEBUG] Resampler output_dim: {pipe.image_proj_model.proj_out.out_features}")
-     else:
-         print(f" [DEBUG] WARNING: No image_proj_model found!")

-     pipe = pipe.to(device)

-     # DEBUG: Check text_encoder type
-     print(f" [DEBUG] type(pipe.text_encoder): {type(pipe.text_encoder)}")
-     print(f" [DEBUG] isinstance(pipe.text_encoder, list): {isinstance(pipe.text_encoder, list)}")
-     if hasattr(pipe, 'text_encoder_2'):
-         print(f" [DEBUG] type(pipe.text_encoder_2): {type(pipe.text_encoder_2)}")

-     print(" [OK] Pipeline ready (following examplewithface.py EXACTLY)")
-     return pipe, True
-
-
- # Global LoRA state
- lora_path_cached = None


  def load_lora(pipe):
-     """Download and store LoRA path - actual loading will be done by Kohya loader"""
-     print("Downloading LoRA...")
-     global lora_path_cached
-
      try:
          lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
-         lora_path_cached = lora_path
-
-         print(" [OK] LoRA path stored (will be loaded with Kohya loader during generation)")
          return True
      except Exception as e:
-         print(f" [WARNING] LoRA download failed: {e}")
-         import traceback
-         traceback.print_exc()
          return False


- def fuse_lora_with_scale(pipe, lora_scale):
      """
-     Following examplewithface.py lines 266-267 EXACTLY:
-         pipe.load_lora_weights(loaded_state_dict)
-         pipe.fuse_lora(lora_scale)
-
-     Uses DIFFUSERS built-in LoRA (NOT Kohya lora.py!)
      """
-     global lora_path_cached
-
-     if lora_path_cached is None:
-         return False
-
      try:
-         # Unfuse previous LoRA (example line 259)
-         try:
-             pipe.unfuse_lora()
-         except:
-             pass
-
-         # Unload previous LoRA (example line 260)
-         try:
-             pipe.unload_lora_weights()
-         except:
-             pass

-         print(" [LORA] Loading state dict from file...")
-         # Load state dict like example (lines 75-78)
-         if lora_path_cached.endswith('.safetensors'):
-             from safetensors.torch import load_file
-             state_dict = load_file(lora_path_cached)
-         else:
-             state_dict = torch.load(lora_path_cached, map_location="cpu")

-         print(" [LORA] Loading weights into pipeline...")
-         # examplewithface.py line 266
-         pipe.load_lora_weights(state_dict)

-         # examplewithface.py line 267
-         print(f" [LORA] Fusing with scale {lora_scale}...")
-         pipe.fuse_lora(lora_scale)

-         print(" [OK] LoRA fused into model (diffusers method)")
          return True

      except Exception as e:
-         print(f" [ERROR] LoRA fusion failed: {e}")
          import traceback
          traceback.print_exc()
          return False


  def setup_compel(pipe):
-     """examplewithface.py line 145"""
-     print("Setting up Compel...")
      try:
          compel = Compel(
              tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
@@ -240,99 +265,75 @@ def setup_compel(pipe):
              returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
              requires_pooled=[False, True]
          )
-         print(" [OK] Compel ready")
          return compel, True
      except Exception as e:
-         print(f" [WARNING] Compel unavailable: {e}")
          return None, False


  def setup_scheduler(pipe):
-     pass


  def optimize_pipeline(pipe):
      if device == "cuda":
          try:
              pipe.enable_xformers_memory_efficient_attention()
              print(" [OK] xformers enabled")
-         except:
-             pass
-
-     if hasattr(pipe, 'enable_vae_slicing'):
-         pipe.enable_vae_slicing()
-     if hasattr(pipe, 'enable_vae_tiling'):
-         pipe.enable_vae_tiling()


  def load_caption_model():
      print("Loading caption model...")
      try:
          from transformers import AutoProcessor, AutoModelForCausalLM
-         processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
-         model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco", torch_dtype=dtype).to("cpu")
-         print(" [OK] GIT-Large")
-         return processor, model, True, 'git'
-     except:
          try:
              from transformers import BlipProcessor, BlipForConditionalGeneration
-             processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-             model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=dtype).to("cpu")
-             print(" [OK] BLIP")
-             return processor, model, True, 'blip'
-         except:
              return None, None, False, 'none'


  def set_clip_skip(pipe):
      if hasattr(pipe, 'text_encoder'):
-         print(f" [OK] CLIP skip {CLIP_SKIP}")
-
-
- def load_image_encoder():
-     """Load CLIP Image Encoder for IP-Adapter."""
-     print("Loading CLIP Image Encoder for IP-Adapter...")
-     try:
-         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-             "h94/IP-Adapter",
-             subfolder="models/image_encoder",
-             torch_dtype=dtype
-         ).to(device)
-         print(" [OK] CLIP Image Encoder loaded successfully")
-         return image_encoder
-     except Exception as e:
-         print(f" [ERROR] Could not load image encoder: {e}")
-         return None
-
- def setup_ip_adapter(pipe):
-     """
-     Setup IP-Adapter for InstantID - SIMPLIFIED VERSION.
-     Uses the pipeline's built-in method like exampleapp.py.
-     """
-     print("Setting up IP-Adapter for InstantID face embeddings...")
-     try:
-         # Download InstantID weights
-         face_adapter_path = download_model_with_retry(
-             "InstantX/InstantID",
-             "ip-adapter.bin"
-         )
-
-         # Use the pipeline's built-in method (like exampleapp.py line 139)
-         pipe.load_ip_adapter_instantid(face_adapter_path)
-
-         # Set initial scale (like exampleapp.py line 140)
-         pipe.set_ip_adapter_scale(0.8)
-
-         print(" [OK] IP-Adapter loaded successfully with built-in method")
-         return True
-
-     except Exception as e:
-         print(f" [ERROR] Could not setup IP-Adapter: {e}")
-         import traceback
-         traceback.print_exc()
-         return False
-

- __all__ = ['draw_kps', 'fuse_lora_with_scale', 'load_image_encoder', 'setup_ip_adapter']

- print("[OK] models.py ready - NO MultiControlNetModel, following examplewithface.py")
 
  """
+ Model loading and initialization for Pixagram AI Pixel Art Generator
+ HYBRID VERSION - Supports both local files and HuggingFace repos
  """
  import torch
  import time
  import os
  from diffusers import (
+     ControlNetModel,
+     AutoencoderKL,
+     LCMScheduler
  )
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from transformers import CLIPVisionModelWithProjection
  from insightface.app import FaceAnalysis
  from controlnet_aux import ZoeDetector
+ from huggingface_hub import hf_hub_download
  from compel import Compel, ReturnedEmbeddingsType

+ # Import the custom pipeline that has the load_ip_adapter_instantid method
+ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline

  from config import (
      device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,


  def download_model_with_retry(repo_id, filename, max_retries=None):
+     """Download model with retry logic and proper token handling."""
      if max_retries is None:
          max_retries = DOWNLOAD_CONFIG['max_retries']

      for attempt in range(max_retries):
          try:
+             print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
+
              kwargs = {"repo_type": "model"}
              if HUGGINGFACE_TOKEN:
                  kwargs["token"] = HUGGINGFACE_TOKEN

+             path = hf_hub_download(
+                 repo_id=repo_id,
+                 filename=filename,
+                 **kwargs
+             )
+             print(f" [OK] Downloaded: {filename}")
              return path
+
          except Exception as e:
+             print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
+
              if attempt < max_retries - 1:
+                 print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
                  time.sleep(DOWNLOAD_CONFIG['retry_delay'])
              else:
+                 print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
                  raise
+
      return None
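Note: `download_model_with_retry` depends on `DOWNLOAD_CONFIG` from `config.py`, which is outside this diff. A minimal sketch of the shape it must have, inferred from the lookups above (the values are illustrative assumptions, not from this commit):

```python
# config.py (sketch) - field names inferred from the calls above; values illustrative
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

DOWNLOAD_CONFIG = {
    "max_retries": 3,   # attempts made by download_model_with_retry()
    "retry_delay": 5,   # seconds slept between failed attempts
}
```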


  def load_face_analysis():
+     """Load face analysis model with proper error handling."""
+     print("Loading face analysis model...")
      try:
+         face_app = FaceAnalysis(
+             name=FACE_DETECTION_CONFIG['model_name'],
+             root='./models/insightface',
+             providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
          )
+         face_app.prepare(
+             ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
+             det_size=FACE_DETECTION_CONFIG['det_size']
+         )
+         print(" [OK] Face analysis model loaded successfully")
+         return face_app, True
      except Exception as e:
+         print(f" [WARNING] Face detection not available: {e}")
          return None, False
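Note: callers get back a `(face_app, ok)` pair. A sketch of how the app is typically consumed for InstantID-style pipelines (standard insightface usage; the file name and variable names are illustrative):

```python
import cv2

face_app, ok = load_face_analysis()
if ok:
    img = cv2.imread("portrait.jpg")  # BGR numpy array, as insightface expects
    faces = face_app.get(img)
    if faces:
        # keep the largest detected face
        face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
        face_emb = face.embedding  # identity embedding fed to the IP-Adapter
        face_kps = face.kps        # 5-point landmarks used to draw the keypoint image
```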


  def load_depth_detector():
+     """Load Zoe Depth detector."""
+     print("Loading Zoe Depth detector...")
      try:
+         zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+         zoe_depth.to(device)
+         print(" [OK] Zoe Depth loaded successfully")
+         return zoe_depth, True
      except Exception as e:
+         print(f" [WARNING] Zoe Depth not available: {e}")
          return None, False


  def load_controlnets():
+     """Load ControlNet models."""
+     print("Loading ControlNet Zoe Depth model...")
+     controlnet_depth = ControlNetModel.from_pretrained(
          "diffusers/controlnet-zoe-depth-sdxl-1.0",
          torch_dtype=dtype
+     ).to(device)
+     print(" [OK] ControlNet Depth loaded")

+     print("Loading InstantID ControlNet...")
+     try:
+         controlnet_instantid = ControlNetModel.from_pretrained(
+             "InstantX/InstantID",
+             subfolder="ControlNetModel",
+             torch_dtype=dtype
+         ).to(device)
+         print(" [OK] InstantID ControlNet loaded successfully")
+         return controlnet_depth, controlnet_instantid, True
+     except Exception as e:
+         print(f" [WARNING] InstantID ControlNet not available: {e}")
+         return controlnet_depth, None, False
+
+
+ def load_image_encoder():
+     """Load CLIP Image Encoder for IP-Adapter."""
+     print("Loading CLIP Image Encoder for IP-Adapter...")
+     try:
+         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+             "h94/IP-Adapter",
+             subfolder="models/image_encoder",
+             torch_dtype=dtype
+         ).to(device)
+         print(" [OK] CLIP Image Encoder loaded successfully")
+         return image_encoder
+     except Exception as e:
+         print(f" [ERROR] Could not load image encoder: {e}")
+         return None
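Note: `load_controlnets()` now returns a `(controlnet_depth, controlnet_instantid, ok)` 3-tuple instead of the old 2-tuple, so existing callers need updating. A sketch of one way to wire the result into `load_sdxl_pipeline()` (the list order is an assumption and must match the order of control images and conditioning scales used at generation time):

```python
controlnet_depth, controlnet_instantid, has_instantid = load_controlnets()

if has_instantid:
    # identity net first, depth second - must match the control images later
    controlnets = [controlnet_instantid, controlnet_depth]
else:
    controlnets = [controlnet_depth]  # degrade gracefully without InstantID

pipe, loaded_configured = load_sdxl_pipeline(controlnets)
```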


  def load_sdxl_pipeline(controlnets):
      """
+     Load SDXL checkpoint - HYBRID APPROACH.
+     Tries in order:
+         1. Local file via from_single_file (like examplemodels.py)
+         2. HuggingFace repo via from_pretrained (like exampleapp.py)
+         3. Fallback to a known working checkpoint
+         4. Last resort: SDXL base
      """
+     print("Loading SDXL checkpoint (hybrid approach)...")

+     # ATTEMPT 1: Try loading from a local file using from_single_file
+     # This is the examplemodels.py approach
+     if MODEL_FILES.get('checkpoint'):
+         try:
+             print(" [Attempt 1] Loading from local file via from_single_file...")
+             model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
+
+             # Check that the file exists and is a safetensors file
+             if model_path and os.path.exists(model_path) and model_path.endswith('.safetensors'):
+                 pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
+                     model_path,
+                     controlnet=controlnets,
+                     torch_dtype=dtype,
+                     use_safetensors=True
+                 ).to(device)
+                 print(f" [OK] Checkpoint loaded from local file: {model_path}")
+                 return pipe, True
+             else:
+                 print(" [INFO] Local file not found or invalid, trying next method...")
+         except Exception as e:
+             print(f" [WARNING] from_single_file failed: {e}")
+             print(" [INFO] Trying from_pretrained approach...")

+     # ATTEMPT 2: Try loading from a HuggingFace repo using from_pretrained
+     # This is the exampleapp.py approach
+     try:
+         print(" [Attempt 2] Loading from HuggingFace repo via from_pretrained...")
+         pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
+             MODEL_REPO,
+             controlnet=controlnets,
+             torch_dtype=dtype,
+             use_safetensors=True
+         ).to(device)
+         print(f" [OK] Checkpoint loaded from HuggingFace repo: {MODEL_REPO}")
+         return pipe, True
+     except Exception as e:
+         print(f" [WARNING] from_pretrained failed: {e}")
+         print(" [INFO] Trying fallback checkpoint...")

+     # ATTEMPT 3: Fallback to a known working checkpoint
+     try:
+         print(" [Attempt 3] Loading fallback: frankjoshua/albedobaseXL_v21...")
+         pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
+             "frankjoshua/albedobaseXL_v21",
+             controlnet=controlnets,
+             torch_dtype=dtype,
+             use_safetensors=True
+         ).to(device)
+         print(" [OK] Fallback checkpoint loaded successfully")
+         return pipe, False
+     except Exception as e:
+         print(f" [WARNING] Fallback also failed: {e}")
+         print(" [INFO] Trying SDXL base model...")

+     # ATTEMPT 4: Last resort - SDXL base
+     print(" [Attempt 4] Loading base SDXL model...")
+     pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         controlnet=controlnets,
+         torch_dtype=dtype,
+         use_safetensors=True
+     ).to(device)
+     print(" [OK] Base SDXL model loaded")
+     return pipe, False
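Note: the second element of the return value is `True` only when the configured checkpoint loaded, and `False` for every fallback, so callers can react to a substituted base model. A sketch of how a caller outside this diff might use that flag (illustrative wiring, not part of the commit):

```python
pipe, using_configured_checkpoint = load_sdxl_pipeline(controlnets)
setup_scheduler(pipe)    # swap in LCMScheduler (defined further down)
optimize_pipeline(pipe)  # xformers, if available
if using_configured_checkpoint:
    load_lora(pipe)      # only apply the style LoRA on the checkpoint it targets
```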


  def load_lora(pipe):
+     """Load LoRA from HuggingFace Hub."""
+     print("Loading LoRA (retroart) from HuggingFace Hub...")
      try:
          lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
+         pipe.load_lora_weights(lora_path, adapter_name="retroart")
+         print(" [OK] LoRA loaded successfully")
          return True
      except Exception as e:
+         print(f" [WARNING] Could not load LoRA: {e}")
          return False
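Note: loading with `adapter_name="retroart"` moves the code from the removed fuse-based helper to diffusers' named-adapter LoRA API, so the strength is applied at runtime rather than baked in up front. A sketch (standard diffusers calls; the scale value is illustrative):

```python
lora_scale = 0.8  # illustrative strength

# activate the named adapter at the desired strength
pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])

# optional: bake the weights in for slightly faster inference
# pipe.fuse_lora(lora_scale=lora_scale)   # pipe.unfuse_lora() restores the base weights
```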


+ def setup_ip_adapter(pipe):
      """
+     Setup IP-Adapter for InstantID - SIMPLIFIED VERSION.
+     Uses the pipeline's built-in method (like exampleapp.py lines 139-140).
+     This is much simpler and more reliable than manual Resampler setup.
      """
+     print("Setting up IP-Adapter for InstantID face embeddings...")
      try:
+         # Download InstantID IP-Adapter weights
+         face_adapter_path = download_model_with_retry(
+             "InstantX/InstantID",
+             "ip-adapter.bin"
+         )

+         # Use the pipeline's built-in method
+         # This handles all the complex Resampler setup automatically
+         pipe.load_ip_adapter_instantid(face_adapter_path)

+         # Set initial scale (can be adjusted later during generation)
+         pipe.set_ip_adapter_scale(0.8)

+         print(" [OK] IP-Adapter loaded successfully with built-in method")
+         print(" - Pipeline handles Resampler and attention processors automatically")
+         print(" - Face embeddings will be properly integrated during generation")

          return True

      except Exception as e:
+         print(f" [ERROR] Could not setup IP-Adapter: {e}")
          import traceback
          traceback.print_exc()
          return False
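Note: the 0.8 set here is only an initial value; `set_ip_adapter_scale` can be called again before each generation to trade identity strength against stylization (illustrative values):

```python
pipe.set_ip_adapter_scale(0.5)  # weaker identity, more freedom for the pixel-art style
pipe.set_ip_adapter_scale(1.0)  # stronger likeness to the reference face
```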


  def setup_compel(pipe):
+     """Setup Compel for better SDXL prompt handling."""
+     print("Setting up Compel for enhanced prompt processing...")
      try:
          compel = Compel(
              tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
              returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
              requires_pooled=[False, True]
          )
+         print(" [OK] Compel loaded successfully")
          return compel, True
      except Exception as e:
+         print(f" [WARNING] Compel not available: {e}")
          return None, False
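Note: with `requires_pooled=[False, True]`, calling the Compel object returns both the token embeddings and the pooled embeddings that SDXL needs. A sketch of standard Compel usage with this setup (the prompt and surrounding call are illustrative):

```python
compel, ok = setup_compel(pipe)
if ok:
    # "++" up-weights a term in Compel's prompt-weighting syntax
    conditioning, pooled = compel("pixel art++ portrait, retro game style")
    result = pipe(
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        # ... plus the usual image / controlnet arguments ...
    )
```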


  def setup_scheduler(pipe):
+     """Setup LCM scheduler."""
+     print("Setting up LCM scheduler...")
+     pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+     print(" [OK] LCM scheduler configured")


  def optimize_pipeline(pipe):
+     """Apply optimizations to pipeline."""
+     # Try to enable xformers
      if device == "cuda":
          try:
              pipe.enable_xformers_memory_efficient_attention()
              print(" [OK] xformers enabled")
+         except Exception as e:
+             print(f" [INFO] xformers not available: {e}")


  def load_caption_model():
+     """
+     Load caption model with proper error handling.
+     Tries multiple models in order of quality.
+     """
      print("Loading caption model...")
+
+     # Try GIT-Large first (good balance of quality and compatibility)
      try:
          from transformers import AutoProcessor, AutoModelForCausalLM
+
+         print(" Attempting GIT-Large (recommended)...")
+         caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+         caption_model = AutoModelForCausalLM.from_pretrained(
+             "microsoft/git-large-coco",
+             torch_dtype=dtype
+         ).to(device)
+         print(" [OK] GIT-Large model loaded (produces detailed captions)")
+         return caption_processor, caption_model, True, 'git'
+     except Exception as e1:
+         print(f" [INFO] GIT-Large not available: {e1}")
+
+     # Try BLIP base as fallback
      try:
          from transformers import BlipProcessor, BlipForConditionalGeneration
+
+         print(" Attempting BLIP base (fallback)...")
+         caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+         caption_model = BlipForConditionalGeneration.from_pretrained(
+             "Salesforce/blip-image-captioning-base",
+             torch_dtype=dtype
+         ).to(device)
+         print(" [OK] BLIP base model loaded (standard captions)")
+         return caption_processor, caption_model, True, 'blip'
+     except Exception as e2:
+         print(f" [WARNING] Caption models not available: {e2}")
+         print(" Caption generation will be disabled")
      return None, None, False, 'none'
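Note: both caption backends share the same processor/generate/decode flow in transformers. A sketch of consuming the returned tuple (the file name and generation length are illustrative):

```python
from PIL import Image

processor, model, ok, kind = load_caption_model()
if ok:
    image = Image.open("input.png").convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device, dtype)  # match the model's device/dtype
    ids = model.generate(pixel_values=pixel_values, max_length=50)
    caption = processor.batch_decode(ids, skip_special_tokens=True)[0]
```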


  def set_clip_skip(pipe):
+     """Set CLIP skip value."""
      if hasattr(pipe, 'text_encoder'):
+         print(f" [OK] CLIP skip set to {CLIP_SKIP}")


+ print("[OK] Model loading functions ready (HYBRID VERSION)")