Spaces:
Running on Zero
Update model.py
Browse files
model.py
CHANGED
|
@@ -15,8 +15,7 @@ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInst
|
|
| 15 |
|
| 16 |
from huggingface_hub import snapshot_download, hf_hub_download
|
| 17 |
from insightface.app import FaceAnalysis
|
| 18 |
-
|
| 19 |
-
# --- MODIFIED: Removed ColorDetector ---
|
| 20 |
from controlnet_aux import LeresDetector, LineartAnimeDetector
|
| 21 |
# --- END MODIFIED ---
|
| 22 |
|
|
@@ -24,10 +23,8 @@ class ModelHandler:
|
|
| 24 |
def __init__(self):
|
| 25 |
self.pipeline = None
|
| 26 |
self.app = None # InsightFace
|
| 27 |
-
# --- MODIFIED: Removed color_detector ---
|
| 28 |
self.leres_detector = None
|
| 29 |
self.lineart_anime_detector = None
|
| 30 |
-
# --- END MODIFIED ---
|
| 31 |
self.face_analysis_loaded = False
|
| 32 |
|
| 33 |
def load_face_analysis(self):
|
|
@@ -38,8 +35,6 @@ class ModelHandler:
|
|
| 38 |
"""
|
| 39 |
print("Loading face analysis model...")
|
| 40 |
|
| 41 |
-
# insightface expects models in '{root}/models/{name}'
|
| 42 |
-
# Since our root='.' and name='antelopev2', the expected path is './models/antelopev2'
|
| 43 |
model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
|
| 44 |
|
| 45 |
if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
|
|
@@ -54,11 +49,10 @@ class ModelHandler:
|
|
| 54 |
return False
|
| 55 |
|
| 56 |
try:
|
| 57 |
-
# Initialize with root='.' and name='antelopev2'
|
| 58 |
self.app = FaceAnalysis(
|
| 59 |
name=Config.ANTELOPEV2_NAME,
|
| 60 |
root=Config.ANTELOPEV2_ROOT,
|
| 61 |
-
providers=['
|
| 62 |
)
|
| 63 |
self.app.prepare(ctx_id=0, det_size=(640, 640))
|
| 64 |
print(f" [OK] Face analysis model loaded successfully.")
|
|
@@ -73,33 +67,36 @@ class ModelHandler:
|
|
| 73 |
self.face_analysis_loaded = self.load_face_analysis()
|
| 74 |
|
| 75 |
# 2. Load ControlNets
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
# Load the InstantID ControlNet from the correct subfolder
|
| 79 |
print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
|
| 80 |
cn_instantid = ControlNetModel.from_pretrained(
|
| 81 |
-
Config.INSTANTID_REPO,
|
| 82 |
-
subfolder="ControlNetModel",
|
| 83 |
torch_dtype=Config.DTYPE
|
| 84 |
)
|
| 85 |
print(" [OK] Loaded InstantID ControlNet.")
|
| 86 |
|
| 87 |
# Load other ControlNets normally
|
| 88 |
-
|
|
|
|
| 89 |
cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
|
| 90 |
cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
|
| 91 |
-
|
|
|
|
| 92 |
|
| 93 |
# --- Manually wrap the list of models in a MultiControlNetModel ---
|
| 94 |
print("Wrapping ControlNets in MultiControlNetModel...")
|
| 95 |
-
|
|
|
|
| 96 |
controlnet = MultiControlNetModel(controlnet_list)
|
| 97 |
# --- End wrapping ---
|
| 98 |
|
| 99 |
# 3. Load SDXL Pipeline
|
| 100 |
print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
|
| 101 |
|
| 102 |
-
# Manually download the checkpoint file first.
|
| 103 |
checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
|
| 104 |
if not os.path.exists(checkpoint_local_path):
|
| 105 |
print(f"Downloading checkpoint to {checkpoint_local_path}...")
|
|
@@ -110,11 +107,10 @@ class ModelHandler:
|
|
| 110 |
local_dir_use_symlinks=False
|
| 111 |
)
|
| 112 |
|
| 113 |
-
# Use the custom Img2Img pipeline class you provided, loading from the LOCAL FILE
|
| 114 |
print(f"Loading pipeline from local file: {checkpoint_local_path}")
|
| 115 |
self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
|
| 116 |
-
checkpoint_local_path,
|
| 117 |
-
controlnet=controlnet,
|
| 118 |
torch_dtype=Config.DTYPE,
|
| 119 |
use_safetensors=True
|
| 120 |
)
|
|
@@ -135,7 +131,6 @@ class ModelHandler:
|
|
| 135 |
# 5. Load Adapters (IP-Adapter & LoRA)
|
| 136 |
print("Loading Adapters (IP-Adapter & LoRA)...")
|
| 137 |
|
| 138 |
-
# Download the ip-adapter.bin file and pass its local path
|
| 139 |
ip_adapter_filename = "ip-adapter.bin"
|
| 140 |
ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
|
| 141 |
|
|
@@ -149,12 +144,11 @@ class ModelHandler:
|
|
| 149 |
)
|
| 150 |
|
| 151 |
print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
|
| 152 |
-
self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
|
| 153 |
|
| 154 |
print("Loading LoRA weights...")
|
| 155 |
self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
|
| 156 |
|
| 157 |
-
# --- NEW: Fuse LoRA at build time with fixed strength ---
|
| 158 |
print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
|
| 159 |
self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
|
| 160 |
print(" [OK] LoRA fused.")
|
|
@@ -174,17 +168,14 @@ class ModelHandler:
|
|
| 174 |
return None
|
| 175 |
|
| 176 |
try:
|
| 177 |
-
# Convert PIL to CV2
|
| 178 |
cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 179 |
faces = self.app.get(cv2_img)
|
| 180 |
|
| 181 |
if len(faces) == 0:
|
| 182 |
return None
|
| 183 |
|
| 184 |
-
# Sort by size (width * height) to find the main character
|
| 185 |
faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
|
| 186 |
|
| 187 |
-
# Return the largest face
|
| 188 |
return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
|
| 189 |
except Exception as e:
|
| 190 |
print(f"Face embedding extraction failed: {e}")
|
|
|
|
| 15 |
|
| 16 |
from huggingface_hub import snapshot_download, hf_hub_download
|
| 17 |
from insightface.app import FaceAnalysis
|
| 18 |
+
# --- MODIFIED: Removed ColorDetector import ---
|
|
|
|
| 19 |
from controlnet_aux import LeresDetector, LineartAnimeDetector
|
| 20 |
# --- END MODIFIED ---
|
| 21 |
|
|
|
|
| 23 |
def __init__(self):
|
| 24 |
self.pipeline = None
|
| 25 |
self.app = None # InsightFace
|
|
|
|
| 26 |
self.leres_detector = None
|
| 27 |
self.lineart_anime_detector = None
|
|
|
|
| 28 |
self.face_analysis_loaded = False
|
| 29 |
|
| 30 |
def load_face_analysis(self):
|
|
|
|
| 35 |
"""
|
| 36 |
print("Loading face analysis model...")
|
| 37 |
|
|
|
|
|
|
|
| 38 |
model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
|
| 39 |
|
| 40 |
if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
|
|
|
|
| 49 |
return False
|
| 50 |
|
| 51 |
try:
|
|
|
|
| 52 |
self.app = FaceAnalysis(
|
| 53 |
name=Config.ANTELOPEV2_NAME,
|
| 54 |
root=Config.ANTELOPEV2_ROOT,
|
| 55 |
+
providers=['CPUExecutionProvider']
|
| 56 |
)
|
| 57 |
self.app.prepare(ctx_id=0, det_size=(640, 640))
|
| 58 |
print(f" [OK] Face analysis model loaded successfully.")
|
|
|
|
| 67 |
self.face_analysis_loaded = self.load_face_analysis()
|
| 68 |
|
| 69 |
# 2. Load ControlNets
|
| 70 |
+
# --- MODIFIED: Updated print ---
|
| 71 |
+
print("Loading ControlNets (InstantID, Zoe, LineArt, Tile)...")
|
| 72 |
|
| 73 |
# Load the InstantID ControlNet from the correct subfolder
|
| 74 |
print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
|
| 75 |
cn_instantid = ControlNetModel.from_pretrained(
|
| 76 |
+
Config.INSTANTID_REPO,
|
| 77 |
+
subfolder="ControlNetModel",
|
| 78 |
torch_dtype=Config.DTYPE
|
| 79 |
)
|
| 80 |
print(" [OK] Loaded InstantID ControlNet.")
|
| 81 |
|
| 82 |
# Load other ControlNets normally
|
| 83 |
+
# --- MODIFIED: Load Tile CN ---
|
| 84 |
+
print("Loading Zoe, LineArt, and Tile ControlNets...")
|
| 85 |
cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
|
| 86 |
cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
|
| 87 |
+
cn_tile = ControlNetModel.from_pretrained(Config.CN_TILE_REPO, torch_dtype=Config.DTYPE)
|
| 88 |
+
# --- END MODIFIED ---
|
| 89 |
|
| 90 |
# --- Manually wrap the list of models in a MultiControlNetModel ---
|
| 91 |
print("Wrapping ControlNets in MultiControlNetModel...")
|
| 92 |
+
# --- MODIFIED: Add Tile CN to list ---
|
| 93 |
+
controlnet_list = [cn_instantid, cn_zoe, cn_lineart, cn_tile]
|
| 94 |
controlnet = MultiControlNetModel(controlnet_list)
|
| 95 |
# --- End wrapping ---
|
| 96 |
|
| 97 |
# 3. Load SDXL Pipeline
|
| 98 |
print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
|
| 99 |
|
|
|
|
| 100 |
checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
|
| 101 |
if not os.path.exists(checkpoint_local_path):
|
| 102 |
print(f"Downloading checkpoint to {checkpoint_local_path}...")
|
|
|
|
| 107 |
local_dir_use_symlinks=False
|
| 108 |
)
|
| 109 |
|
|
|
|
| 110 |
print(f"Loading pipeline from local file: {checkpoint_local_path}")
|
| 111 |
self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
|
| 112 |
+
checkpoint_local_path,
|
| 113 |
+
controlnet=controlnet,
|
| 114 |
torch_dtype=Config.DTYPE,
|
| 115 |
use_safetensors=True
|
| 116 |
)
|
|
|
|
| 131 |
# 5. Load Adapters (IP-Adapter & LoRA)
|
| 132 |
print("Loading Adapters (IP-Adapter & LoRA)...")
|
| 133 |
|
|
|
|
| 134 |
ip_adapter_filename = "ip-adapter.bin"
|
| 135 |
ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
|
| 136 |
|
|
|
|
| 144 |
)
|
| 145 |
|
| 146 |
print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
|
| 147 |
+
self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
|
| 148 |
|
| 149 |
print("Loading LoRA weights...")
|
| 150 |
self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
|
| 151 |
|
|
|
|
| 152 |
print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
|
| 153 |
self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
|
| 154 |
print(" [OK] LoRA fused.")
|
|
|
|
| 168 |
return None
|
| 169 |
|
| 170 |
try:
|
|
|
|
| 171 |
cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 172 |
faces = self.app.get(cv2_img)
|
| 173 |
|
| 174 |
if len(faces) == 0:
|
| 175 |
return None
|
| 176 |
|
|
|
|
| 177 |
faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
|
| 178 |
|
|
|
|
| 179 |
return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
|
| 180 |
except Exception as e:
|
| 181 |
print(f"Face embedding extraction failed: {e}")
|