primerz commited on
Commit
65a7aea
·
verified ·
1 Parent(s): 050e6f9

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +15 -24
model.py CHANGED
@@ -15,8 +15,7 @@ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInst
15
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
-
19
- # --- MODIFIED: Removed ColorDetector ---
20
  from controlnet_aux import LeresDetector, LineartAnimeDetector
21
  # --- END MODIFIED ---
22
 
@@ -24,10 +23,8 @@ class ModelHandler:
24
  def __init__(self):
25
  self.pipeline = None
26
  self.app = None # InsightFace
27
- # --- MODIFIED: Removed color_detector ---
28
  self.leres_detector = None
29
  self.lineart_anime_detector = None
30
- # --- END MODIFIED ---
31
  self.face_analysis_loaded = False
32
 
33
  def load_face_analysis(self):
@@ -38,8 +35,6 @@ class ModelHandler:
38
  """
39
  print("Loading face analysis model...")
40
 
41
- # insightface expects models in '{root}/models/{name}'
42
- # Since our root='.' and name='antelopev2', the expected path is './models/antelopev2'
43
  model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
44
 
45
  if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
@@ -54,11 +49,10 @@ class ModelHandler:
54
  return False
55
 
56
  try:
57
- # Initialize with root='.' and name='antelopev2'
58
  self.app = FaceAnalysis(
59
  name=Config.ANTELOPEV2_NAME,
60
  root=Config.ANTELOPEV2_ROOT,
61
- providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
62
  )
63
  self.app.prepare(ctx_id=0, det_size=(640, 640))
64
  print(f" [OK] Face analysis model loaded successfully.")
@@ -73,33 +67,36 @@ class ModelHandler:
73
  self.face_analysis_loaded = self.load_face_analysis()
74
 
75
  # 2. Load ControlNets
76
- print("Loading ControlNets (InstantID, Zoe, LineArt, Color)...")
 
77
 
78
  # Load the InstantID ControlNet from the correct subfolder
79
  print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
80
  cn_instantid = ControlNetModel.from_pretrained(
81
- Config.INSTANTID_REPO, # "InstantX/InstantID"
82
- subfolder="ControlNetModel", # Correct casing
83
  torch_dtype=Config.DTYPE
84
  )
85
  print(" [OK] Loaded InstantID ControlNet.")
86
 
87
  # Load other ControlNets normally
88
- print("Loading Zoe, LineArt, and Color ControlNets...")
 
89
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
90
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
91
- cn_color = ControlNetModel.from_pretrained(Config.CN_COLOR_REPO, torch_dtype=Config.DTYPE) # This is still correct
 
92
 
93
  # --- Manually wrap the list of models in a MultiControlNetModel ---
94
  print("Wrapping ControlNets in MultiControlNetModel...")
95
- controlnet_list = [cn_instantid, cn_zoe, cn_lineart, cn_color] # This is still correct
 
96
  controlnet = MultiControlNetModel(controlnet_list)
97
  # --- End wrapping ---
98
 
99
  # 3. Load SDXL Pipeline
100
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
101
 
102
- # Manually download the checkpoint file first.
103
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
104
  if not os.path.exists(checkpoint_local_path):
105
  print(f"Downloading checkpoint to {checkpoint_local_path}...")
@@ -110,11 +107,10 @@ class ModelHandler:
110
  local_dir_use_symlinks=False
111
  )
112
 
113
- # Use the custom Img2Img pipeline class you provided, loading from the LOCAL FILE
114
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
115
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
116
- checkpoint_local_path, # Pass the local path
117
- controlnet=controlnet, # Pass the single, wrapped object
118
  torch_dtype=Config.DTYPE,
119
  use_safetensors=True
120
  )
@@ -135,7 +131,6 @@ class ModelHandler:
135
  # 5. Load Adapters (IP-Adapter & LoRA)
136
  print("Loading Adapters (IP-Adapter & LoRA)...")
137
 
138
- # Download the ip-adapter.bin file and pass its local path
139
  ip_adapter_filename = "ip-adapter.bin"
140
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
141
 
@@ -149,12 +144,11 @@ class ModelHandler:
149
  )
150
 
151
  print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
152
- self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path) # Pass local path
153
 
154
  print("Loading LoRA weights...")
155
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
156
 
157
- # --- NEW: Fuse LoRA at build time with fixed strength ---
158
  print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
159
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
160
  print(" [OK] LoRA fused.")
@@ -174,17 +168,14 @@ class ModelHandler:
174
  return None
175
 
176
  try:
177
- # Convert PIL to CV2
178
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
179
  faces = self.app.get(cv2_img)
180
 
181
  if len(faces) == 0:
182
  return None
183
 
184
- # Sort by size (width * height) to find the main character
185
  faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
186
 
187
- # Return the largest face
188
  return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
189
  except Exception as e:
190
  print(f"Face embedding extraction failed: {e}")
 
15
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
+ # --- MODIFIED: Removed ColorDetector import ---
 
19
  from controlnet_aux import LeresDetector, LineartAnimeDetector
20
  # --- END MODIFIED ---
21
 
 
23
  def __init__(self):
24
  self.pipeline = None
25
  self.app = None # InsightFace
 
26
  self.leres_detector = None
27
  self.lineart_anime_detector = None
 
28
  self.face_analysis_loaded = False
29
 
30
  def load_face_analysis(self):
 
35
  """
36
  print("Loading face analysis model...")
37
 
 
 
38
  model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
39
 
40
  if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
 
49
  return False
50
 
51
  try:
 
52
  self.app = FaceAnalysis(
53
  name=Config.ANTELOPEV2_NAME,
54
  root=Config.ANTELOPEV2_ROOT,
55
+ providers=['CPUExecutionProvider']
56
  )
57
  self.app.prepare(ctx_id=0, det_size=(640, 640))
58
  print(f" [OK] Face analysis model loaded successfully.")
 
67
  self.face_analysis_loaded = self.load_face_analysis()
68
 
69
  # 2. Load ControlNets
70
+ # --- MODIFIED: Updated print ---
71
+ print("Loading ControlNets (InstantID, Zoe, LineArt, Tile)...")
72
 
73
  # Load the InstantID ControlNet from the correct subfolder
74
  print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
75
  cn_instantid = ControlNetModel.from_pretrained(
76
+ Config.INSTANTID_REPO,
77
+ subfolder="ControlNetModel",
78
  torch_dtype=Config.DTYPE
79
  )
80
  print(" [OK] Loaded InstantID ControlNet.")
81
 
82
  # Load other ControlNets normally
83
+ # --- MODIFIED: Load Tile CN ---
84
+ print("Loading Zoe, LineArt, and Tile ControlNets...")
85
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
86
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
87
+ cn_tile = ControlNetModel.from_pretrained(Config.CN_TILE_REPO, torch_dtype=Config.DTYPE)
88
+ # --- END MODIFIED ---
89
 
90
  # --- Manually wrap the list of models in a MultiControlNetModel ---
91
  print("Wrapping ControlNets in MultiControlNetModel...")
92
+ # --- MODIFIED: Add Tile CN to list ---
93
+ controlnet_list = [cn_instantid, cn_zoe, cn_lineart, cn_tile]
94
  controlnet = MultiControlNetModel(controlnet_list)
95
  # --- End wrapping ---
96
 
97
  # 3. Load SDXL Pipeline
98
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
99
 
 
100
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
101
  if not os.path.exists(checkpoint_local_path):
102
  print(f"Downloading checkpoint to {checkpoint_local_path}...")
 
107
  local_dir_use_symlinks=False
108
  )
109
 
 
110
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
111
  self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
112
+ checkpoint_local_path,
113
+ controlnet=controlnet,
114
  torch_dtype=Config.DTYPE,
115
  use_safetensors=True
116
  )
 
131
  # 5. Load Adapters (IP-Adapter & LoRA)
132
  print("Loading Adapters (IP-Adapter & LoRA)...")
133
 
 
134
  ip_adapter_filename = "ip-adapter.bin"
135
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
136
 
 
144
  )
145
 
146
  print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
147
+ self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
148
 
149
  print("Loading LoRA weights...")
150
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
151
 
 
152
  print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
153
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
154
  print(" [OK] LoRA fused.")
 
168
  return None
169
 
170
  try:
 
171
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
172
  faces = self.app.get(cv2_img)
173
 
174
  if len(faces) == 0:
175
  return None
176
 
 
177
  faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
178
 
 
179
  return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
180
  except Exception as e:
181
  print(f"Face embedding extraction failed: {e}")