primerz commited on
Commit
cc0ae1f
·
verified ·
1 Parent(s): 61380fb

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +5 -6
model.py CHANGED
@@ -9,8 +9,8 @@ from diffusers import (
9
  LCMScheduler
10
  )
11
  # --- START FIX ---
12
- # Import the MultiControlNetModel wrapper
13
- from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
14
  # --- END FIX ---
15
 
16
  # Import the custom pipeline from your local file
@@ -87,12 +87,11 @@ class ModelHandler:
87
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
88
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
89
 
90
- # --- START FIX for AssertionError ---
91
- # Manually wrap the list of models in a MultiControlNetModel
92
  print("Wrapping ControlNets in MultiControlNetModel...")
93
  controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
94
  controlnet = MultiControlNetModel(controlnet_list)
95
- # --- END FIX ---
96
 
97
  # 3. Load SDXL Pipeline
98
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
@@ -112,7 +111,7 @@ class ModelHandler:
112
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
113
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
114
  checkpoint_local_path, # Pass the local path
115
- controlnet=controlnet, # <-- Pass the single, wrapped object
116
  torch_dtype=Config.DTYPE,
117
  use_safetensors=True
118
  )
 
9
  LCMScheduler
10
  )
11
  # --- START FIX ---
12
+ # Import MultiControlNetModel from its new, correct location
13
+ from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
14
  # --- END FIX ---
15
 
16
  # Import the custom pipeline from your local file
 
87
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
88
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
89
 
90
+ # --- Manually wrap the list of models in a MultiControlNetModel ---
 
91
  print("Wrapping ControlNets in MultiControlNetModel...")
92
  controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
93
  controlnet = MultiControlNetModel(controlnet_list)
94
+ # --- End wrapping ---
95
 
96
  # 3. Load SDXL Pipeline
97
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
 
111
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
112
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
113
  checkpoint_local_path, # Pass the local path
114
+ controlnet=controlnet, # Pass the single, wrapped object
115
  torch_dtype=Config.DTYPE,
116
  use_safetensors=True
117
  )