primerz commited on
Commit
b305aed
·
verified ·
1 Parent(s): 9e16fa1

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +45 -2
models.py CHANGED
@@ -218,20 +218,64 @@ def load_image_encoder():
218
  def load_sdxl_pipeline(controlnets):
219
  """Load SDXL checkpoint from HuggingFace Hub."""
220
  print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  try:
222
  model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")
223
 
 
 
224
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
225
  model_path,
226
  controlnet=controlnets,
227
  torch_dtype=dtype,
228
- use_safetensors=True
 
 
 
 
 
 
 
229
  ).to(device) # This main pipe MUST be on device
 
 
230
  print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
231
  return pipe, True
 
232
  except Exception as e:
233
  print(f" [WARNING] Could not load custom checkpoint: {e}")
234
  print(" Using default SDXL base model")
 
 
235
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
236
  "stabilityai/stable-diffusion-xl-base-1.0",
237
  controlnet=controlnets,
@@ -240,7 +284,6 @@ def load_sdxl_pipeline(controlnets):
240
  ).to(device) # This main pipe MUST be on device
241
  return pipe, False
242
 
243
-
244
  def load_loras(pipe):
245
  """Load all LORAs from HuggingFace Hub."""
246
  print("Loading all LORAs from HuggingFace Hub...")
 
218
  def load_sdxl_pipeline(controlnets):
219
  """Load SDXL checkpoint from HuggingFace Hub."""
220
  print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
221
+
222
+ # --- START FIX ---
223
+ # Load tokenizers and text encoders from the base model first
224
+ # This guarantees they exist, even if the single file doesn't have them
225
+ print(" Loading base tokenizers and text encoders...")
226
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
227
+
228
+ try:
229
+ tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
230
+ tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")
231
+
232
+ text_encoder = CLIPTextModel.from_pretrained(
233
+ BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
234
+ ).to(device)
235
+
236
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
237
+ BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
238
+ ).to(device)
239
+ print(" [OK] Base text/token models loaded")
240
+
241
+ except Exception as e:
242
+ print(f" [ERROR] Could not load base text models: {e}")
243
+ print(" Pipeline will likely fail. Check HF connection/model access.")
244
+ # Allow it to continue, but it will likely fail below
245
+ tokenizer = None
246
+ tokenizer_2 = None
247
+ text_encoder = None
248
+ text_encoder_2 = None
249
+ # --- END FIX ---
250
+
251
  try:
252
  model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")
253
 
254
+ # --- START FIX ---
255
+ # Pass the pre-loaded models to from_single_file
256
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
257
  model_path,
258
  controlnet=controlnets,
259
  torch_dtype=dtype,
260
+ use_safetensors=True,
261
+
262
+ # Explicitly provide the models
263
+ tokenizer=tokenizer,
264
+ tokenizer_2=tokenizer_2,
265
+ text_encoder=text_encoder,
266
+ text_encoder_2=text_encoder_2,
267
+
268
  ).to(device) # This main pipe MUST be on device
269
+ # --- END FIX ---
270
+
271
  print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
272
  return pipe, True
273
+
274
  except Exception as e:
275
  print(f" [WARNING] Could not load custom checkpoint: {e}")
276
  print(" Using default SDXL base model")
277
+
278
+ # The fallback logic is already correct
279
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
280
  "stabilityai/stable-diffusion-xl-base-1.0",
281
  controlnet=controlnets,
 
284
  ).to(device) # This main pipe MUST be on device
285
  return pipe, False
286
 
 
287
  def load_loras(pipe):
288
  """Load all LORAs from HuggingFace Hub."""
289
  print("Loading all LORAs from HuggingFace Hub...")