primerz commited on
Commit
acd970e
·
verified ·
1 Parent(s): e4dec03

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +15 -15
model.py CHANGED
@@ -5,11 +5,15 @@ import numpy as np
5
  from config import Config
6
 
7
  from diffusers import (
8
- StableDiffusionXLControlNetPipeline,
9
  ControlNetModel,
10
  LCMScheduler,
11
- StableDiffusionXLInstantIDPipeline # <-- ADD THIS IMPORT
12
  )
 
 
 
 
 
13
  from huggingface_hub import snapshot_download
14
  from insightface.app import FaceAnalysis
15
  from controlnet_aux import ZoeDetector, LineartDetector
@@ -29,10 +33,8 @@ class ModelHandler:
29
  Forces CPU to avoid ZeroGPU initialization errors.
30
  """
31
  print("Loading face analysis model...")
32
- # Path will now be './antelopev2'
33
  model_root_path = os.path.join(Config.ANTELOPEV2_ROOT, Config.ANTELOPEV2_NAME)
34
 
35
- # 1. Download from HF Hub
36
  if not os.path.exists(os.path.join(model_root_path, "scrfd_10g_bnkps.onnx")):
37
  print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO}...")
38
  try:
@@ -45,14 +47,12 @@ class ModelHandler:
45
  print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
46
  return False
47
 
48
- # 2. Initialize InsightFace on CPU
49
  try:
50
  self.app = FaceAnalysis(
51
  name=Config.ANTELOPEV2_NAME, # 'antelopev2'
52
  root=Config.ANTELOPEV2_ROOT, # '.'
53
  providers=['CPUExecutionProvider']
54
  )
55
- # This will now correctly look in './antelopev2'
56
  self.app.prepare(ctx_id=0, det_size=(640, 640))
57
  print(f" [OK] Face analysis model loaded successfully.")
58
  return True
@@ -68,9 +68,7 @@ class ModelHandler:
68
  # 2. Load ControlNets
69
  print("Loading ControlNets (InstantID, Zoe, LineArt)...")
70
 
71
- # --- START FIX for config.json error ---
72
- # We must load the InstantID ControlNet by loading its pipeline
73
- # and "stealing" the controlnet component.
74
  print("Loading InstantID pipeline to extract ControlNet...")
75
  temp_pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
76
  Config.INSTANTID_REPO,
@@ -79,7 +77,6 @@ class ModelHandler:
79
  cn_instantid = temp_pipe.controlnet
80
  del temp_pipe # Free memory
81
  print(" [OK] Extracted InstantID ControlNet.")
82
- # --- END FIX ---
83
 
84
  # Load other ControlNets normally
85
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
@@ -87,14 +84,17 @@ class ModelHandler:
87
 
88
  # 3. Load SDXL Pipeline
89
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
90
- self.pipeline = StableDiffusionXLControlNetPipeline.from_single_file(
 
 
 
 
91
  Config.REPO_ID,
92
  filename=Config.CHECKPOINT_FILENAME,
93
- controlnet=[cn_instantid, cn_zoe, cn_lineart], # Pass the *objects*
94
  torch_dtype=Config.DTYPE,
95
  use_safetensors=True
96
  )
97
- # Move pipeline (and all its components) to the device
98
  self.pipeline.to(Config.DEVICE)
99
 
100
  # 4. Set Scheduler
@@ -102,7 +102,7 @@ class ModelHandler:
102
 
103
  # 5. Load Adapters (IP-Adapter & LoRA)
104
  print("Loading Adapters (IP-Adapter & LoRA)...")
105
- # We still load the IP-Adapter separately
106
  self.pipeline.load_ip_adapter_instantid(Config.INSTANTID_REPO)
107
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
108
  self.pipeline.fuse_lora(lora_scale=1.0)
@@ -121,7 +121,7 @@ class ModelHandler:
121
 
122
  try:
123
  # Convert PIL to CV2
124
- cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # <-- Fixed typo
125
  faces = self.app.get(cv2_img)
126
 
127
  if len(faces) == 0:
 
5
  from config import Config
6
 
7
  from diffusers import (
 
8
  ControlNetModel,
9
  LCMScheduler,
10
+ StableDiffusionXLInstantIDPipeline # To "steal" the ControlNet
11
  )
12
+ # --- START FIX ---
13
+ # Import the custom pipeline from your local file
14
+ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
15
+ # --- END FIX ---
16
+
17
  from huggingface_hub import snapshot_download
18
  from insightface.app import FaceAnalysis
19
  from controlnet_aux import ZoeDetector, LineartDetector
 
33
  Forces CPU to avoid ZeroGPU initialization errors.
34
  """
35
  print("Loading face analysis model...")
 
36
  model_root_path = os.path.join(Config.ANTELOPEV2_ROOT, Config.ANTELOPEV2_NAME)
37
 
 
38
  if not os.path.exists(os.path.join(model_root_path, "scrfd_10g_bnkps.onnx")):
39
  print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO}...")
40
  try:
 
47
  print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
48
  return False
49
 
 
50
  try:
51
  self.app = FaceAnalysis(
52
  name=Config.ANTELOPEV2_NAME, # 'antelopev2'
53
  root=Config.ANTELOPEV2_ROOT, # '.'
54
  providers=['CPUExecutionProvider']
55
  )
 
56
  self.app.prepare(ctx_id=0, det_size=(640, 640))
57
  print(f" [OK] Face analysis model loaded successfully.")
58
  return True
 
68
  # 2. Load ControlNets
69
  print("Loading ControlNets (InstantID, Zoe, LineArt)...")
70
 
71
+ # Load InstantID ControlNet by "stealing" it from the base pipeline
 
 
72
  print("Loading InstantID pipeline to extract ControlNet...")
73
  temp_pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
74
  Config.INSTANTID_REPO,
 
77
  cn_instantid = temp_pipe.controlnet
78
  del temp_pipe # Free memory
79
  print(" [OK] Extracted InstantID ControlNet.")
 
80
 
81
  # Load other ControlNets normally
82
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
 
84
 
85
  # 3. Load SDXL Pipeline
86
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
87
+
88
+ # --- START FIX ---
89
+ # Use the custom Img2Img pipeline class you provided
90
+ self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
91
+ # --- END FIX ---
92
  Config.REPO_ID,
93
  filename=Config.CHECKPOINT_FILENAME,
94
+ controlnet=[cn_instantid, cn_zoe, cn_lineart], # Pass the list of all ControlNets
95
  torch_dtype=Config.DTYPE,
96
  use_safetensors=True
97
  )
 
98
  self.pipeline.to(Config.DEVICE)
99
 
100
  # 4. Set Scheduler
 
102
 
103
  # 5. Load Adapters (IP-Adapter & LoRA)
104
  print("Loading Adapters (IP-Adapter & LoRA)...")
105
+ # The custom pipeline has this method
106
  self.pipeline.load_ip_adapter_instantid(Config.INSTANTID_REPO)
107
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
108
  self.pipeline.fuse_lora(lora_scale=1.0)
 
121
 
122
  try:
123
  # Convert PIL to CV2
124
+ cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
125
  faces = self.app.get(cv2_img)
126
 
127
  if len(faces) == 0: