primerz commited on
Commit
61380fb
·
verified ·
1 Parent(s): c82ccd6

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +14 -4
model.py CHANGED
@@ -8,6 +8,11 @@ from diffusers import (
8
  ControlNetModel,
9
  LCMScheduler
10
  )
 
 
 
 
 
11
  # Import the custom pipeline from your local file
12
  from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
13
 
@@ -82,6 +87,13 @@ class ModelHandler:
82
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
83
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
84
 
 
 
 
 
 
 
 
85
  # 3. Load SDXL Pipeline
86
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
87
 
@@ -100,7 +112,7 @@ class ModelHandler:
100
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
101
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
102
  checkpoint_local_path, # Pass the local path
103
- controlnet=[cn_instantid, cn_zoe, cn_lineart],
104
  torch_dtype=Config.DTYPE,
105
  use_safetensors=True
106
  )
@@ -147,9 +159,7 @@ class ModelHandler:
147
 
148
  try:
149
  # Convert PIL to CV2
150
- # --- START FIX ---
151
- cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # <-- Corrected typo
152
- # --- END FIX ---
153
  faces = self.app.get(cv2_img)
154
 
155
  if len(faces) == 0:
 
8
  ControlNetModel,
9
  LCMScheduler
10
  )
11
+ # --- START FIX ---
12
+ # Import the MultiControlNetModel wrapper
13
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
14
+ # --- END FIX ---
15
+
16
  # Import the custom pipeline from your local file
17
  from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
18
 
 
87
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
88
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
89
 
90
+ # --- START FIX for AssertionError ---
91
+ # Manually wrap the list of models in a MultiControlNetModel
92
+ print("Wrapping ControlNets in MultiControlNetModel...")
93
+ controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
94
+ controlnet = MultiControlNetModel(controlnet_list)
95
+ # --- END FIX ---
96
+
97
  # 3. Load SDXL Pipeline
98
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
99
 
 
112
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
113
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
114
  checkpoint_local_path, # Pass the local path
115
+ controlnet=controlnet, # <-- Pass the single, wrapped object
116
  torch_dtype=Config.DTYPE,
117
  use_safetensors=True
118
  )
 
159
 
160
  try:
161
  # Convert PIL to CV2
162
+ cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
 
163
  faces = self.app.get(cv2_img)
164
 
165
  if len(faces) == 0: