primerz commited on
Commit
16dc50a
·
verified ·
1 Parent(s): a882f77

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +23 -13
model.py CHANGED
@@ -16,16 +16,17 @@ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInst
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
  # --- MODIFIED: Import new detectors ---
19
- from controlnet_aux import LeresDetector, LineartAnimeDetector
20
  # --- END MODIFIED ---
21
 
22
  class ModelHandler:
23
  def __init__(self):
24
  self.pipeline = None
25
  self.app = None # InsightFace
26
- # --- MODIFIED: Rename detectors ---
27
  self.leres_detector = None
28
  self.lineart_anime_detector = None
 
29
  # --- END MODIFIED ---
30
  self.face_analysis_loaded = False
31
 
@@ -72,7 +73,8 @@ class ModelHandler:
72
  self.face_analysis_loaded = self.load_face_analysis()
73
 
74
  # 2. Load ControlNets
75
- print("Loading ControlNets (InstantID, Zoe, LineArt)...")
 
76
 
77
  # Load the InstantID ControlNet from the correct subfolder
78
  print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
@@ -84,13 +86,17 @@ class ModelHandler:
84
  print(" [OK] Loaded InstantID ControlNet.")
85
 
86
  # Load other ControlNets normally
87
- print("Loading Zoe and LineArt ControlNets...")
 
88
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
89
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
 
 
90
 
91
  # --- Manually wrap the list of models in a MultiControlNetModel ---
92
  print("Wrapping ControlNets in MultiControlNetModel...")
93
- controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
 
94
  controlnet = MultiControlNetModel(controlnet_list)
95
  # --- End wrapping ---
96
 
@@ -119,6 +125,14 @@ class ModelHandler:
119
 
120
  self.pipeline.to(Config.DEVICE)
121
 
 
 
 
 
 
 
 
 
122
  # 4. Set Scheduler
123
  self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
124
 
@@ -150,18 +164,13 @@ class ModelHandler:
150
  print(" [OK] LoRA fused.")
151
 
152
  # --- DISABLED torch.compile due to runtime errors ---
153
- # try:
154
- # print("Compiling UNet with torch.compile...")
155
- # self.pipeline.unet = torch.compile(self.pipeline.unet, mode="reduce-overhead", fullgraph=True)
156
- # print(" [OK] UNet compiled.")
157
- # except Exception as e:
158
- # print(f" [WARNING] torch.compile failed: {e}. Running without compilation.")
159
 
160
  # 6. Load Preprocessors
161
  # --- MODIFIED: Load new detectors ---
162
- print("Loading Preprocessors (LeReS, LineArtAnime)...")
163
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
164
  self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
 
165
  # --- END MODIFIED ---
166
 
167
  print("--- All models loaded successfully ---")
@@ -187,4 +196,5 @@ class ModelHandler:
187
  return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
188
  except Exception as e:
189
  print(f"Face embedding extraction failed: {e}")
190
- return None
 
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
  # --- MODIFIED: Import new detectors ---
19
+ from controlnet_aux import LeresDetector, LineartAnimeDetector, ColorDetector
20
  # --- END MODIFIED ---
21
 
22
  class ModelHandler:
23
  def __init__(self):
24
  self.pipeline = None
25
  self.app = None # InsightFace
26
+ # --- MODIFIED: Add new detector ---
27
  self.leres_detector = None
28
  self.lineart_anime_detector = None
29
+ self.color_detector = None
30
  # --- END MODIFIED ---
31
  self.face_analysis_loaded = False
32
 
 
73
  self.face_analysis_loaded = self.load_face_analysis()
74
 
75
  # 2. Load ControlNets
76
+ # --- MODIFIED: Updated print ---
77
+ print("Loading ControlNets (InstantID, Zoe, LineArt, Color)...")
78
 
79
  # Load the InstantID ControlNet from the correct subfolder
80
  print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
 
86
  print(" [OK] Loaded InstantID ControlNet.")
87
 
88
  # Load other ControlNets normally
89
+ # --- MODIFIED: Load Color CN ---
90
+ print("Loading Zoe, LineArt, and Color ControlNets...")
91
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
92
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
93
+ cn_color = ControlNetModel.from_pretrained(Config.CN_COLOR_REPO, torch_dtype=Config.DTYPE)
94
+ # --- END MODIFIED ---
95
 
96
  # --- Manually wrap the list of models in a MultiControlNetModel ---
97
  print("Wrapping ControlNets in MultiControlNetModel...")
98
+ # --- MODIFIED: Add Color CN to list ---
99
+ controlnet_list = [cn_instantid, cn_zoe, cn_lineart, cn_color]
100
  controlnet = MultiControlNetModel(controlnet_list)
101
  # --- End wrapping ---
102
 
 
125
 
126
  self.pipeline.to(Config.DEVICE)
127
 
128
+ # --- NEW: Enable xFormers ---
129
+ try:
130
+ self.pipeline.enable_xformers_memory_efficient_attention()
131
+ print(" [OK] xFormers memory efficient attention enabled.")
132
+ except Exception as e:
133
+ print(f" [WARNING] Failed to enable xFormers: {e}")
134
+ # --- END NEW ---
135
+
136
  # 4. Set Scheduler
137
  self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
138
 
 
164
  print(" [OK] LoRA fused.")
165
 
166
  # --- DISABLED torch.compile due to runtime errors ---
 
 
 
 
 
 
167
 
168
  # 6. Load Preprocessors
169
  # --- MODIFIED: Load new detectors ---
170
+ print("Loading Preprocessors (LeReS, LineArtAnime, Color)...")
171
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
172
  self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
173
+ self.color_detector = ColorDetector()
174
  # --- END MODIFIED ---
175
 
176
  print("--- All models loaded successfully ---")
 
196
  return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
197
  except Exception as e:
198
  print(f"Face embedding extraction failed: {e}")
199
+ return None
200
+ }