Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,927 Bytes
2dddf31 6d5987b 2dddf31 f5bcb07 2dddf31 cc0ae1f 61380fb acd970e a92aea4 2dddf31 91e168e 2dddf31 4a72459 6d5987b 948869c 6d5987b 2dddf31 948869c 6d5987b e01d167 4089031 6d5987b 948869c 3ae9ca7 6d5987b 948869c 6d5987b 2dddf31 6d5987b 2dddf31 6d5987b f6d97e6 e01d167 064a79e aecb45b 65a7aea 2dddf31 aecb45b e01d167 4a72459 f6d97e6 4a72459 8333ca9 cc0ae1f 61380fb f6d97e6 61380fb cc0ae1f 76e0564 6d5987b acd970e 4089031 94327de 4089031 acd970e 65a7aea 2dddf31 6d5987b e01d167 4089031 e01d167 2dddf31 545e006 16dc50a 0117fa7 16dc50a 545e006 2d5b51b b96a00f 2d5b51b b96a00f 2d5b51b 545e006 94327de 545e006 94327de 65a7aea 94327de 2d5b51b 545e006 b96a00f f5bcb07 545e006 f5bcb07 545e006 f5bcb07 545e006 b96a00f 545e006 2dddf31 ff641c2 545e006 d8780fa 545e006 d8780fa 6d5987b 91e168e 4a72459 2dddf31 6d5987b e4dd0ff 6d5987b 2dddf31 6d5987b 76e0564 6d5987b e4dd0ff 0117fa7 6d5987b e4dd0ff 6d5987b 0a54687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import torch
import os
import cv2
import numpy as np
from config import Config
from diffusers import (
ControlNetModel,
TCDScheduler,
)
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
# Import the custom pipeline from your local file
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
from huggingface_hub import snapshot_download, hf_hub_download
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector, LineartAnimeDetector
class ModelHandler:
    """Owns the InstantID SDXL img2img pipeline and its auxiliary models.

    Responsibilities:
      * face detection / identity embedding via InsightFace (AntelopeV2),
      * a three-ControlNet (InstantID + Zoe depth + LineArt) SDXL pipeline,
      * TCD scheduler configuration and LoRA fusing,
      * LeReS / LineArtAnime preprocessors for ControlNet conditioning.

    Call ``load_models()`` once before using ``get_face_info()`` or
    ``self.pipeline``.
    """

    def __init__(self):
        self.pipeline = None               # StableDiffusionXLInstantIDImg2ImgPipeline
        self.app = None                    # InsightFace FaceAnalysis instance
        self.leres_detector = None         # LeReS depth preprocessor
        self.lineart_anime_detector = None # LineArtAnime preprocessor
        self.face_analysis_loaded = False  # True once InsightFace is usable

    @staticmethod
    def _ensure_hub_file(repo_id, filename, local_dir="./models"):
        """Download ``filename`` from ``repo_id`` into ``local_dir`` unless it
        already exists locally.

        Returns:
            str: local path of the (possibly just-downloaded) file.
        """
        local_path = os.path.join(local_dir, filename)
        if not os.path.exists(local_path):
            print(f"Downloading {filename} to {local_path}...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
                # NOTE(review): deprecated and ignored on recent huggingface_hub;
                # kept for older versions where it forces real files over symlinks.
                local_dir_use_symlinks=False,
            )
        return local_path

    def load_face_analysis(self):
        """Load the InsightFace AntelopeV2 face-analysis model.

        Downloads the ONNX models from HF Hub into the directory layout
        insightface expects (``<root>/models/<name>/``) when missing.

        Returns:
            bool: True when the detector is ready; False on download or
            initialization failure (callers treat faces as unavailable).
        """
        print("Loading face analysis model...")
        # insightface resolves models as <root>/models/<name>/*.onnx
        model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
        if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
            print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
            try:
                snapshot_download(
                    repo_id=Config.ANTELOPEV2_REPO,
                    local_dir=model_path,  # download straight into the expected path
                )
            except Exception as e:
                print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
                return False
        try:
            self.app = FaceAnalysis(
                name=Config.ANTELOPEV2_NAME,
                root=Config.ANTELOPEV2_ROOT,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.app.prepare(ctx_id=0, det_size=(640, 640))
            print(" [OK] Face analysis model loaded successfully.")
            return True
        except Exception as e:
            print(f" [WARNING] Face detection system failed to initialize: {e}")
            return False

    def load_models(self):
        """Load and wire every model component.

        Order is significant: the pipeline must exist before the scheduler is
        swapped, and the TCD LoRA is fused before the style LoRA so each
        ``fuse_lora()`` call applies its own scale independently.
        """
        # 1. Face analysis (optional — generation proceeds without it).
        self.face_analysis_loaded = self.load_face_analysis()

        # 2. ControlNets. Conditioning images elsewhere must follow the same
        #    order as this list: [InstantID, Zoe depth, LineArt].
        print("Loading ControlNets (InstantID, Zoe, LineArt)...")
        print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
        cn_instantid = ControlNetModel.from_pretrained(
            Config.INSTANTID_REPO,
            subfolder="ControlNetModel",  # InstantID repo nests its ControlNet
            torch_dtype=Config.DTYPE,
        )
        print(" [OK] Loaded InstantID ControlNet.")
        print("Loading Zoe and LineArt ControlNets...")
        cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
        cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)

        # Wrap explicitly so the custom pipeline always receives a
        # MultiControlNetModel rather than a bare Python list.
        print("Wrapping ControlNets in MultiControlNetModel...")
        controlnet = MultiControlNetModel([cn_instantid, cn_zoe, cn_lineart])

        # 3. SDXL base pipeline from a single-file checkpoint.
        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
        checkpoint_local_path = self._ensure_hub_file(Config.REPO_ID, Config.CHECKPOINT_FILENAME)
        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            checkpoint_local_path,
            controlnet=controlnet,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )
        self.pipeline.to(Config.DEVICE)

        # xFormers is optional; on failure we fall back to default attention.
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
            print(" [OK] xFormers memory efficient attention enabled.")
        except Exception as e:
            print(f" [WARNING] Failed to enable xFormers: {e}")

        # 4. TCD scheduler. timestep_spacing="trailing" is required for
        #    proper few-step distilled sampling.
        print("Configuring TCDScheduler...")
        self.pipeline.scheduler = TCDScheduler.from_config(
            self.pipeline.scheduler.config,
            use_karras_sigmas=True,
            timestep_spacing="trailing",
        )
        print(" [OK] TCDScheduler loaded (Karras + Trailing Spacing).")

        # 5. Adapters (IP-Adapter, TCD LoRA, style LoRA).
        print("Loading Adapters...")

        # 5a. InstantID IP-Adapter — injects the face-identity embedding.
        ip_adapter_local_path = self._ensure_hub_file(Config.INSTANTID_REPO, "ip-adapter.bin")
        print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
        self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)

        # 5b. TCD LoRA (few-step sampling), fused at full strength.
        print("Loading TCD-SDXL-LoRA...")
        tcd_lora_filename = "pytorch_lora_weights.safetensors"
        self._ensure_hub_file("h1t/TCD-SDXL-LoRA", tcd_lora_filename)
        self.pipeline.load_lora_weights("./models", weight_name=tcd_lora_filename)
        self.pipeline.fuse_lora(lora_scale=1.0)
        print(" [OK] TCD LoRA fused.")

        # 5c. Style LoRA, fused at the configured strength.
        print("Loading Style LoRA weights...")
        self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
        print(f"Fusing Style LoRA with scale {Config.LORA_STRENGTH}...")
        self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
        print(" [OK] Style LoRA fused.")

        # 6. ControlNet preprocessors (depth + anime line art).
        print("Loading Preprocessors (LeReS, LineArtAnime)...")
        self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
        self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
        print("--- All models loaded successfully ---")

    def get_face_info(self, image):
        """Return the largest detected face in ``image``.

        Args:
            image: RGB image convertible via ``np.array`` (e.g. a PIL image).

        Returns:
            The insightface result object for the largest face by bounding-box
            area, or None when face analysis is unavailable, no face is
            detected, or extraction fails.
        """
        if not self.face_analysis_loaded:
            return None
        try:
            # insightface works on OpenCV-style BGR arrays, hence the conversion.
            cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.app.get(cv2_img)
            if not faces:
                return None
            # Largest bbox area = the main subject; max() avoids a full sort.
            return max(
                faces,
                key=lambda f: (f['bbox'][2] - f['bbox'][0]) * (f['bbox'][3] - f['bbox'][1]),
            )
        except Exception as e:
            print(f"Face embedding extraction failed: {e}")
            return None