primerz committed on
Commit
c19d329
·
verified ·
1 Parent(s): 8ac26c4

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +90 -33
model.py CHANGED
@@ -6,8 +6,7 @@ from config import Config
6
 
7
  from diffusers import (
8
  ControlNetModel,
9
- TCDScheduler,
10
- AutoencoderKL # <-- ADDED: Import AutoencoderKL
11
  )
12
  from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
13
 
@@ -16,15 +15,17 @@ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInst
16
 
17
  from huggingface_hub import snapshot_download, hf_hub_download
18
  from insightface.app import FaceAnalysis
19
- from controlnet_aux import LeresDetector, LineartAnimeDetector
20
 
21
  class ModelHandler:
22
  def __init__(self):
23
  self.pipeline = None
24
- self.app = None # InsightFace
25
  self.leres_detector = None
26
  self.lineart_anime_detector = None
 
27
  self.face_analysis_loaded = False
 
28
 
29
  def load_face_analysis(self):
30
  """
@@ -40,7 +41,7 @@ class ModelHandler:
40
  try:
41
  snapshot_download(
42
  repo_id=Config.ANTELOPEV2_REPO,
43
- local_dir=model_path, # Download to the correct expected path
44
  )
45
  except Exception as e:
46
  print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
@@ -60,35 +61,65 @@ class ModelHandler:
60
  print(f" [WARNING] Face detection system failed to initialize: {e}")
61
  return False
62
 
63
- def load_models(self):
 
 
 
 
 
 
 
 
64
  # 1. Load Face Analysis
65
  self.face_analysis_loaded = self.load_face_analysis()
66
 
67
- # 2. Load ControlNets
68
- print("Loading ControlNets (InstantID, Zoe, LineArt)...")
69
  cn_instantid = ControlNetModel.from_pretrained(
70
  Config.INSTANTID_REPO,
71
  subfolder="ControlNetModel",
72
  torch_dtype=Config.DTYPE
73
  )
74
- cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
75
- cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  print("Wrapping ControlNets in MultiControlNetModel...")
78
- controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
79
  controlnet = MultiControlNetModel(controlnet_list)
80
 
81
- # --- ADDED: Load FP16 Fixed VAE ---
82
- # This prevents NaN errors and black images in SDXL fp16 mode
83
- vae_repo = getattr(Config, "VAE_REPO", "madebyollin/sdxl-vae-fp16-fix")
84
- print(f"Loading VAE ({vae_repo})...")
85
- vae = AutoencoderKL.from_pretrained(
86
- vae_repo,
87
- torch_dtype=Config.DTYPE
88
- )
89
- # ----------------------------------
90
-
91
- # 3. Load SDXL Pipeline (Now from 'reality.safetensors')
92
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
93
 
94
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
@@ -104,7 +135,6 @@ class ModelHandler:
104
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
105
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
106
  checkpoint_local_path,
107
- vae=vae, # <-- ADDED: Inject the fixed VAE here
108
  controlnet=controlnet,
109
  torch_dtype=Config.DTYPE,
110
  use_safetensors=True
@@ -118,15 +148,15 @@ class ModelHandler:
118
  except Exception as e:
119
  print(f" [WARNING] Failed to enable xFormers: {e}")
120
 
121
- # 4. Set TCD Scheduler (Sanitized Config)
122
  print("Configuring TCDScheduler...")
123
  self.pipeline.scheduler = TCDScheduler.from_config(self.pipeline.scheduler.config)
124
- print(" [OK] TCDScheduler loaded (Forced SDXL Defaults + Karras + Trailing).")
125
 
126
  # 5. Load Adapters
127
  print("Loading Adapters...")
128
 
129
- # 5b. Load and Fuse Style LoRA (lucasart/retroart)
130
  print(f"Loading and Fusing Style LoRA ({Config.LORA_FILENAME})...")
131
  style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
132
  if not os.path.exists(style_lora_path):
@@ -140,7 +170,7 @@ class ModelHandler:
140
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
141
  print(" [OK] Style LoRA fused.")
142
 
143
- # 5c. Load IP-Adapter (for InstantID) - *Must be loaded AFTER fusing*
144
  ip_adapter_filename = "ip-adapter.bin"
145
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
146
  if not os.path.exists(ip_adapter_local_path):
@@ -151,12 +181,19 @@ class ModelHandler:
151
  local_dir_use_symlinks=False
152
  )
153
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
154
- print(" [OK] IP-Adapter loaded.")
155
 
156
- # 7. Load Preprocessors
157
- print("Loading Preprocessors (LeReS, LineArtAnime)...")
158
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
159
- self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
 
 
 
 
 
 
 
160
 
161
  print("--- All models loaded successfully ---")
162
 
@@ -169,8 +206,28 @@ class ModelHandler:
169
  faces = self.app.get(cv2_img)
170
  if len(faces) == 0:
171
  return None
172
- faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
 
 
 
 
173
  return faces[0]
174
  except Exception as e:
175
  print(f"Face embedding extraction failed: {e}")
176
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from diffusers import (
8
  ControlNetModel,
9
+ TCDScheduler,
 
10
  )
11
  from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
12
 
 
15
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
+ from controlnet_aux import LeresDetector, LineartAnimeDetector, CannyDetector
19
 
20
  class ModelHandler:
21
  def __init__(self):
22
  self.pipeline = None
23
+ self.app = None # InsightFace
24
  self.leres_detector = None
25
  self.lineart_anime_detector = None
26
+ self.canny_detector = None
27
  self.face_analysis_loaded = False
28
+ self.edge_type = Config.DEFAULT_EDGE_TYPE
29
 
30
  def load_face_analysis(self):
31
  """
 
41
  try:
42
  snapshot_download(
43
  repo_id=Config.ANTELOPEV2_REPO,
44
+ local_dir=model_path,
45
  )
46
  except Exception as e:
47
  print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
 
61
  print(f" [WARNING] Face detection system failed to initialize: {e}")
62
  return False
63
 
64
+ def load_models(self, edge_type="canny"):
65
+ """
66
+ Load all models with support for different edge detection types.
67
+
68
+ Args:
69
+ edge_type: "canny", "lineart", or "both"
70
+ """
71
+ self.edge_type = edge_type
72
+
73
  # 1. Load Face Analysis
74
  self.face_analysis_loaded = self.load_face_analysis()
75
 
76
+ # 2. Load ControlNets based on edge_type
77
+ print(f"Loading ControlNets (InstantID, Zoe, {edge_type.upper()})...")
78
  cn_instantid = ControlNetModel.from_pretrained(
79
  Config.INSTANTID_REPO,
80
  subfolder="ControlNetModel",
81
  torch_dtype=Config.DTYPE
82
  )
83
+ cn_zoe = ControlNetModel.from_pretrained(
84
+ Config.CN_ZOE_REPO,
85
+ torch_dtype=Config.DTYPE
86
+ )
87
+
88
+ # Load edge ControlNet(s)
89
+ controlnet_list = [cn_instantid, cn_zoe]
90
+
91
+ if edge_type == "canny":
92
+ cn_canny = ControlNetModel.from_pretrained(
93
+ Config.CN_CANNY_REPO,
94
+ torch_dtype=Config.DTYPE
95
+ )
96
+ controlnet_list.append(cn_canny)
97
+ print(" [OK] Loaded Canny ControlNet")
98
+
99
+ elif edge_type == "lineart":
100
+ cn_lineart = ControlNetModel.from_pretrained(
101
+ Config.CN_LINEART_REPO,
102
+ torch_dtype=Config.DTYPE
103
+ )
104
+ controlnet_list.append(cn_lineart)
105
+ print(" [OK] Loaded LineArt ControlNet")
106
+
107
+ elif edge_type == "both":
108
+ cn_canny = ControlNetModel.from_pretrained(
109
+ Config.CN_CANNY_REPO,
110
+ torch_dtype=Config.DTYPE
111
+ )
112
+ cn_lineart = ControlNetModel.from_pretrained(
113
+ Config.CN_LINEART_REPO,
114
+ torch_dtype=Config.DTYPE
115
+ )
116
+ controlnet_list.extend([cn_canny, cn_lineart])
117
+ print(" [OK] Loaded both Canny and LineArt ControlNets")
118
 
119
  print("Wrapping ControlNets in MultiControlNetModel...")
 
120
  controlnet = MultiControlNetModel(controlnet_list)
121
 
122
+ # 3. Load SDXL Pipeline
 
 
 
 
 
 
 
 
 
 
123
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
124
 
125
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
 
135
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
136
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
137
  checkpoint_local_path,
 
138
  controlnet=controlnet,
139
  torch_dtype=Config.DTYPE,
140
  use_safetensors=True
 
148
  except Exception as e:
149
  print(f" [WARNING] Failed to enable xFormers: {e}")
150
 
151
+ # 4. Set TCD Scheduler
152
  print("Configuring TCDScheduler...")
153
  self.pipeline.scheduler = TCDScheduler.from_config(self.pipeline.scheduler.config)
154
+ print(" [OK] TCDScheduler loaded.")
155
 
156
  # 5. Load Adapters
157
  print("Loading Adapters...")
158
 
159
+ # 5a. Load and Fuse Style LoRA
160
  print(f"Loading and Fusing Style LoRA ({Config.LORA_FILENAME})...")
161
  style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
162
  if not os.path.exists(style_lora_path):
 
170
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
171
  print(" [OK] Style LoRA fused.")
172
 
173
+ # 5b. Load IP-Adapter for InstantID
174
  ip_adapter_filename = "ip-adapter.bin"
175
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
176
  if not os.path.exists(ip_adapter_local_path):
 
181
  local_dir_use_symlinks=False
182
  )
183
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
184
+ print(" [OK] InstantID IP-Adapter loaded.")
185
 
186
+ # 6. Load Preprocessors
187
+ print("Loading Preprocessors...")
188
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
189
+
190
+ if edge_type in ["canny", "both"]:
191
+ self.canny_detector = CannyDetector()
192
+ print(" [OK] Canny detector loaded")
193
+
194
+ if edge_type in ["lineart", "both"]:
195
+ self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
196
+ print(" [OK] LineArt detector loaded")
197
 
198
  print("--- All models loaded successfully ---")
199
 
 
206
  faces = self.app.get(cv2_img)
207
  if len(faces) == 0:
208
  return None
209
+ faces = sorted(
210
+ faces,
211
+ key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]),
212
+ reverse=True
213
+ )
214
  return faces[0]
215
  except Exception as e:
216
  print(f"Face embedding extraction failed: {e}")
217
+ return None
218
+
219
+ def extract_depth(self, image):
220
+ """Extract depth map using LeReS detector"""
221
+ return self.leres_detector(image)
222
+
223
+ def extract_canny(self, image, low_threshold=100, high_threshold=200):
224
+ """Extract Canny edges"""
225
+ if self.canny_detector is None:
226
+ raise ValueError("Canny detector not loaded. Initialize with edge_type='canny' or 'both'")
227
+ return self.canny_detector(image, low_threshold=low_threshold, high_threshold=high_threshold)
228
+
229
+ def extract_lineart(self, image):
230
+ """Extract LineArt edges"""
231
+ if self.lineart_anime_detector is None:
232
+ raise ValueError("LineArt detector not loaded. Initialize with edge_type='lineart' or 'both'")
233
+ return self.lineart_anime_detector(image)