primerz committed on
Commit
4a72459
·
verified ·
1 Parent(s): b225fb0

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -29
model.py CHANGED
@@ -15,14 +15,18 @@ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInst
15
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
- from processor import Processor
 
 
19
 
20
  class ModelHandler:
21
  def __init__(self):
22
  self.pipeline = None
23
  self.app = None # InsightFace
24
- self.zoe_detector = None
25
- self.lineart_detector = None
 
 
26
  self.face_analysis_loaded = False
27
 
28
  def load_face_analysis(self):
@@ -79,30 +83,14 @@ class ModelHandler:
79
  )
80
  print(" [OK] Loaded InstantID ControlNet.")
81
 
82
- print("Loading Depth and Edge ControlNets...")
83
-
84
- # --- FIX: Load the actual ControlNetModels, not the 'controlnet-aux' Processors ---
85
- # The 'Processor' class is for pre-processing (creating the maps),
86
- # not for the model itself. Your Config.DEPTH ("leres++") is an ID
87
- # for the pre-processor, which is correctly used at the end of this file.
88
-
89
- cn_depth = ControlNetModel.from_pretrained(
90
- Config.DEPTH_REPO,
91
- torch_dtype=Config.DTYPE
92
- )
93
- print(f" [OK] Loaded Depth ControlNet ({DEPTH_MODEL_REPO}).")
94
-
95
- cn_edge = ControlNetModel.from_pretrained(
96
- Config.EDGE_REPO,
97
- torch_dtype=Config.DTYPE
98
- )
99
- print(f" [OK] Loaded Edge/LineArt ControlNet ({EDGE_MODEL_REPO}).")
100
-
101
- # --- End fix ---
102
 
103
  # --- Manually wrap the list of models in a MultiControlNetModel ---
104
  print("Wrapping ControlNets in MultiControlNetModel...")
105
- controlnet_list = [cn_instantid, cn_depth, cn_edge]
106
  controlnet = MultiControlNetModel(controlnet_list)
107
  # --- End wrapping ---
108
 
@@ -170,11 +158,11 @@ class ModelHandler:
170
  # print(f" [WARNING] torch.compile failed: {e}. Running without compilation.")
171
 
172
  # 6. Load Preprocessors
173
- # This is the CORRECT place to use the 'Processor' class
174
- # Config.DEPTH ("leres++") and Config.EDGE are the annotator IDs
175
- print("Loading Preprocessors (Zoe, LineArt)...")
176
- self.depth_detector = Processor(Config.DEPTH)
177
- self.edge_detector = Processor(Config.EDGE)
178
 
179
  print("--- All models loaded successfully ---")
180
 
@@ -185,6 +173,7 @@ class ModelHandler:
185
 
186
  try:
187
  # Convert PIL to CV2
 
188
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
189
  faces = self.app.get(cv2_img)
190
 
 
15
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
+ # --- MODIFIED: Import new detectors ---
19
+ from controlnet_aux import LeresDetector, LineartAnimeDetector
20
+ # --- END MODIFIED ---
21
 
22
  class ModelHandler:
23
  def __init__(self):
24
  self.pipeline = None
25
  self.app = None # InsightFace
26
+ # --- MODIFIED: Rename detectors ---
27
+ self.leres_detector = None
28
+ self.lineart_anime_detector = None
29
+ # --- END MODIFIED ---
30
  self.face_analysis_loaded = False
31
 
32
  def load_face_analysis(self):
 
83
  )
84
  print(" [OK] Loaded InstantID ControlNet.")
85
 
86
+ # Load other ControlNets normally
87
+ print("Loading Zoe and LineArt ControlNets...")
88
+ cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
89
+ cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  # --- Manually wrap the list of models in a MultiControlNetModel ---
92
  print("Wrapping ControlNets in MultiControlNetModel...")
93
+ controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
94
  controlnet = MultiControlNetModel(controlnet_list)
95
  # --- End wrapping ---
96
 
 
158
  # print(f" [WARNING] torch.compile failed: {e}. Running without compilation.")
159
 
160
  # 6. Load Preprocessors
161
+ # --- MODIFIED: Load new detectors ---
162
+ print("Loading Preprocessors (LeReS, LineArtAnime)...")
163
+ self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
164
+ self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
165
+ # --- END MODIFIED ---
166
 
167
  print("--- All models loaded successfully ---")
168
 
 
173
 
174
  try:
175
  # Convert PIL to CV2
176
+ # --- FIX: Corrected OpenCV attribute ---
177
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
178
  faces = self.app.get(cv2_img)
179