# app.py — InstantID SDXL (official) + IP-Adapter Style (optional, for 2D rendering)
import os, sys
os.environ["OMP_NUM_THREADS"] = "4"  # cap OpenMP threads to avoid CPU oversubscription
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")  # accelerated HF downloads when hf_transfer is installed
sys.path.insert(0, os.path.abspath("./instantid"))  # vendored InstantID pipeline files live here
import traceback, importlib.util
import torch, gradio as gr
from PIL import Image, ImageOps, ImageDraw
from huggingface_hub import hf_hub_download
from diffusers.models import ControlNetModel
from insightface.app import FaceAnalysis

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32  # fp16 only makes sense on CUDA

# InstantID assets (IdentityNet ControlNet + its dedicated IP-Adapter).
ASSETS_REPO = "InstantX/InstantID"
CHECKPOINTS_DIR = "./checkpoints"
CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
# Optional generic SDXL IP-Adapter used for style transfer.
IP_STYLE_REPO = "h94/IP-Adapter"
IP_STYLE_SUBFOLDER = "sdxl_models"
IP_STYLE_WEIGHT = "ip-adapter_sdxl.bin"
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
def safe_download(repo, filename, local_dir, min_bytes, label, subfolder=None):
    """Download one file from the Hub into ``local_dir`` and validate its size.

    A pre-existing partial/truncated copy (smaller than ``min_bytes``) is
    removed before downloading; a RuntimeError is raised if the downloaded
    file is still too small.

    Args:
        repo: Hugging Face repo id.
        filename: file path inside the repo (may contain ``/``).
        local_dir: destination directory.
        min_bytes: minimum plausible size of a complete file.
        label: human-readable name used in log and error messages.
        subfolder: optional repo subfolder forwarded to ``hf_hub_download``.

    Returns:
        Path of the downloaded file.

    Raises:
        RuntimeError: if the downloaded file is smaller than ``min_bytes``.
    """
    os.makedirs(local_dir, exist_ok=True)
    # BUGFIX: hf_hub_download(local_dir=...) preserves the repo-relative path
    # (subfolder + filename), so the stale-file check must look there. The old
    # os.path.basename(filename) pointed at the wrong location for nested
    # files, which both missed truncated downloads and made
    # force_download=True on every launch (re-fetching multi-GB weights).
    rel_parts = ([subfolder] if subfolder else []) + filename.split("/")
    local_path = os.path.join(local_dir, *rel_parts)
    if os.path.exists(local_path) and os.path.getsize(local_path) < min_bytes:
        try:
            os.remove(local_path)  # drop the truncated remnant so it is re-fetched
        except Exception:
            pass  # best-effort: hf_hub_download will overwrite anyway
    path = hf_hub_download(
        repo_id=repo,
        filename=filename,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
        # Bypass the cache only when the local file is absent (or was just purged).
        force_download=not os.path.exists(local_path),
        subfolder=subfolder,
    )
    size = os.path.getsize(path)
    if size < min_bytes:
        raise RuntimeError(f"Téléchargement incomplet de {label} (taille: {size} bytes).")
    print(f"✅ {label} téléchargé ({size/1e6:.1f} MB)")
    return path
def ensure_assets_or_download():
    """Fetch every required InstantID/IP-Adapter asset into ./checkpoints.

    Each download is size-validated by safe_download(); files already present
    and complete are left untouched.
    """
    for directory in (CHECKPOINTS_DIR, CN_LOCAL_DIR):
        os.makedirs(directory, exist_ok=True)
    # (repo, filename, minimum plausible size, log label, subfolder)
    required_assets = (
        (ASSETS_REPO, "ControlNetModel/config.json", 1_000, "IdentityNet config", None),
        (ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", 100_000_000, "IdentityNet weights", None),
        (ASSETS_REPO, "ip-adapter.bin", 100_000_000, "IP-Adapter (InstantID)", None),
        (IP_STYLE_REPO, IP_STYLE_WEIGHT, 20_000_000, "IP-Adapter Style (SDXL)", IP_STYLE_SUBFOLDER),
    )
    for repo, fname, min_size, label, sub in required_assets:
        safe_download(repo, fname, CHECKPOINTS_DIR, min_size, label, subfolder=sub)
def import_pipeline_or_fail():
    """Import the vendored InstantID SDXL pipeline and return its pipeline class.

    Looks for the known pipeline filenames under ./instantid/, imports the
    first one found as a module, then scans it for a class whose name contains
    "InstantID" and that exposes ``from_pretrained``.

    Raises:
        RuntimeError: when no pipeline file exists, the file looks empty,
            or no suitable pipeline class is defined in it.
    """
    candidates = (
        "./instantid/pipeline_stable_diffusion_xl_instantid_full.py",
        "./instantid/pipeline_stable_diffusion_xl_instantid.py",
    )
    pipeline_file = None
    for candidate in candidates:
        if os.path.exists(candidate):
            pipeline_file = candidate
            break
    if pipeline_file is None:
        raise RuntimeError("❌ Pipeline manquante. Place `pipeline_stable_diffusion_xl_instantid_full.py` dans ./instantid/")
    if os.path.getsize(pipeline_file) < 1024:
        raise RuntimeError("❌ Pipeline trop petite (vide ?). Utilise la version SDXL officielle.")
    spec = importlib.util.spec_from_file_location("instantid_pipeline", pipeline_file)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    # Pick the first class that looks like the InstantID pipeline.
    for name, obj in vars(mod).items():
        if not isinstance(obj, type):
            continue
        if "InstantID" in name and hasattr(obj, "from_pretrained"):
            print(f"✅ Pipeline trouvée : {name}")
            return obj
    avail = [n for n, o in vars(mod).items() if isinstance(o, type)]
    raise RuntimeError("❌ Aucune classe pipeline InstantID trouvée. Classes dispo: " + ", ".join(avail))
def draw_kps_local(img_pil, kps):
    """Render facial keypoints as black dots on a white canvas.

    Args:
        img_pil: reference PIL image; only its size is used.
        kps: iterable of (x, y) keypoint coordinates.

    Returns:
        A new RGB PIL image of the same size with the keypoints drawn.
    """
    width, height = img_pil.size
    canvas = Image.new("RGB", (width, height), "white")
    painter = ImageDraw.Draw(canvas)
    # Dot radius scales with image size, never smaller than 2 px.
    radius = max(2, min(width, height) // 100)
    for px, py in kps:
        painter.ellipse((px - radius, py - radius, px + radius, py + radius), fill="black")
    return canvas
# --- Pipeline loading (runs once at import time) ------------------------------
load_logs = []  # human-readable status lines shown in the UI "Logs" box
HAS_STYLE_ADAPTER = False  # flipped to True only if the optional style adapter loads
try:
    SDXLInstantID = import_pipeline_or_fail()
    ensure_assets_or_download()
    # IdentityNet: InstantID's ControlNet conditioned on facial keypoints.
    controlnet_identitynet = ControlNetModel.from_pretrained(CN_LOCAL_DIR, torch_dtype=DTYPE)
    pipe = SDXLInstantID.from_pretrained(
        BASE_MODEL,
        controlnet=controlnet_identitynet,
        torch_dtype=DTYPE,
        safety_checker=None,
        feature_extractor=None,
    ).to(DEVICE)
    # InstantID's own IP-Adapter (injects the identity embedding).
    pipe.load_ip_adapter_instantid(IP_ADAPTER_LOCAL)
    # Optional second IP-Adapter for style transfer; failure is non-fatal.
    try:
        pipe.load_ip_adapter(
            IP_STYLE_REPO,
            subfolder=IP_STYLE_SUBFOLDER,
            weight_name=IP_STYLE_WEIGHT,
            adapter_name="style",
        )
        load_logs.append("✅ IP-Adapter Style (SDXL) chargé (adapter_name='style').")
        HAS_STYLE_ADAPTER = True
    except Exception as e:
        load_logs.append(f"ℹ️ IP-Adapter Style non chargé: {e}")
    if DEVICE == "cuda":
        # Defensive: make sure these sub-modules really ended up on the GPU.
        if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
        if hasattr(pipe, "unet"): pipe.unet.to("cuda")
    load_logs.append("✅ InstantID prêt.")
except Exception:
    # Broad catch at this top-level boundary: record the traceback for the UI.
    load_logs += ["❌ ERREUR au chargement:", traceback.format_exc()]
    pipe = None
# Fail fast: the app is unusable without the pipeline.
if pipe is None:
    raise RuntimeError("Échec de chargement du pipeline.\n" + "\n".join(load_logs))
def load_face_analyser():
    """Initialise InsightFace, preferring antelopev2 with buffalo_l as fallback.

    Returns:
        A prepared FaceAnalysis instance (CPU provider, 640x640 detector).

    Raises:
        RuntimeError: when neither model pack can be loaded.
    """
    failures = []
    for model_name in ("antelopev2", "buffalo_l"):
        try:
            analyser = FaceAnalysis(name=model_name, root="./models", providers=["CPUExecutionProvider"])
            analyser.prepare(ctx_id=0, det_size=(640, 640))
        except Exception as exc:
            failures.append(f"{model_name}: {exc}")
            print(f"⚠️ InsightFace échec {model_name} → {exc}")
            continue
        print(f"✅ InsightFace chargé: {model_name}")
        return analyser
    raise RuntimeError("Echec chargement InsightFace. Détails: " + " | ".join(failures))
# Module-level singleton face analyser, shared by every generation request.
fa = load_face_analyser()
def extract_face_embed_and_kps(pil_img):
    """Detect a face and return its identity embedding plus a keypoint image.

    Args:
        pil_img: PIL image to analyse.

    Returns:
        Tuple of (face_emb, kps_img) where face_emb is a torch.Tensor of
        shape [1, D] on DEVICE/DTYPE and kps_img is the keypoint control
        image produced by draw_kps_local().

    Raises:
        ValueError: when no face is detected.
    """
    import numpy as np, cv2
    img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    faces = fa.get(img_cv2)
    if not faces:
        raise ValueError("Aucun visage détecté dans la photo.")
    # FIX: pick the LARGEST detected face (by bbox area), matching the official
    # InstantID demo, instead of faces[-1] which relies on the detector's
    # arbitrary ordering when several faces are present.
    face = max(faces, key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]))
    emb_np = face["embedding"]
    if not isinstance(emb_np, np.ndarray):
        emb_np = np.asarray(emb_np, dtype="float32")
    if emb_np.ndim == 1:
        emb_np = emb_np[None, ...]  # promote to (1, D) for the pipeline
    # Tensor [1, D] on the pipeline's device/dtype.
    face_emb = torch.from_numpy(emb_np).to(device=DEVICE, dtype=DTYPE)
    kps_img = draw_kps_local(pil_img, face["kps"])
    return face_emb, kps_img
def generate(face_image, style_image, prompt, negative_prompt,
             identity_strength, adapter_strength, style_strength,
             steps, cfg, width, height, seed):
    """Run the InstantID SDXL pipeline on one face photo.

    Args:
        face_image: PIL image containing the face (required).
        style_image: optional PIL image fed to the style IP-Adapter.
        prompt / negative_prompt: text conditioning.
        identity_strength: IdentityNet (ControlNet) conditioning scale.
        adapter_strength: InstantID IP-Adapter scale.
        style_strength: style IP-Adapter scale (only used with a style image).
        steps, cfg, width, height: diffusion sampling parameters.
        seed: RNG seed; None or a negative value means non-deterministic.

    Returns:
        (image_or_None, error_message, logs) — errors are returned as text
        rather than raised so the Gradio UI can display them.
    """
    try:
        if face_image is None:
            return None, "Merci d'ajouter une photo visage.", "\n".join(load_logs)
        gen = None if seed is None or int(seed) < 0 else torch.Generator(device=DEVICE).manual_seed(int(seed))
        # Face photo → centre-cropped 512×512 square for stable detection.
        # (FIX: removed redundant local `from PIL import ImageOps`; it is
        # already imported at module level.)
        face = ImageOps.exif_transpose(face_image).convert("RGB")
        ms = min(face.size); x = (face.width - ms) // 2; y = (face.height - ms) // 2
        face_sq = face.crop((x, y, x + ms, y + ms)).resize((512, 512), Image.Resampling.LANCZOS)
        # InsightFace: identity embedding (torch [1, D]) + keypoint image.
        face_emb, kps_img = extract_face_embed_and_kps(face_sq)
        # IP-Adapter scales: dict form when both adapters are active.
        try:
            if HAS_STYLE_ADAPTER and style_image is not None:
                pipe.set_ip_adapter_scale({"instantid": float(adapter_strength), "style": float(style_strength)})
            else:
                pipe.set_ip_adapter_scale(float(adapter_strength))
        except Exception as e:
            print(f"ℹ️ set_ip_adapter_scale ignoré: {e}")
        # Multi-ControlNet compat: image/scale args must be lists when the
        # pipeline holds several ControlNets (even though we only have one).
        cn = getattr(pipe, "controlnet", None)
        if isinstance(cn, (list, tuple)):
            n_cn = len(cn)
        else:
            try: n_cn = len(cn)
            except Exception: n_cn = 1
        image_arg = [kps_img] * n_cn if n_cn > 1 else ([kps_img] if isinstance(cn, (list, tuple)) else kps_img)
        scale_val = float(identity_strength)
        scale_arg = [scale_val] * n_cn if n_cn > 1 else ([scale_val] if isinstance(cn, (list, tuple)) else scale_val)
        # Inference kwargs; image_embeds is passed under several names for
        # cross-version compatibility (harmless when the pipeline ignores them).
        gen_kwargs = dict(
            prompt=(prompt or "").strip(),
            negative_prompt=(negative_prompt or "").strip(),
            image=image_arg,
            image_embeds=face_emb,                         # pipeline-level compat
            added_conditions={"image_embeds": face_emb},   # diffusers >= 0.30.x (if propagated)
            added_cond_kwargs={"image_embeds": face_emb},  # diffusers 0.29.x (if propagated)
            controlnet_conditioning_scale=scale_arg,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=int(width),
            height=int(height),
            generator=gen,
        )
        if HAS_STYLE_ADAPTER and style_image is not None:
            try:
                gen_kwargs["ip_adapter_image"] = ImageOps.exif_transpose(style_image).convert("RGB")
            except Exception as e:
                print(f"ℹ️ ip_adapter_image ignoré: {e}")
        # 🔧 MONKEY-PATCH: inject image_embeds at the UNet.forward level so the
        # identity embedding reaches the adapter even if the pipeline does not
        # forward our kwargs itself.
        orig_forward = pipe.unet.forward
        def forward_patch(*args, **kwargs):
            # Merge rather than overwrite, so existing conditions survive.
            ac = kwargs.get("added_conditions")
            ac = {} if ac is None else dict(ac)
            ac["image_embeds"] = face_emb
            kwargs["added_conditions"] = ac
            kwargs["added_cond_kwargs"] = ac  # compat for diffusers 0.29.x
            return orig_forward(*args, **kwargs)
        pipe.unet.forward = forward_patch
        try:
            images = pipe(**gen_kwargs).images
        finally:
            # Always restore the original forward, even on failure.
            pipe.unet.forward = orig_forward
        return images[0], "", "\n".join(load_logs)
    except torch.cuda.OutOfMemoryError:
        return None, "CUDA OOM: baisse la résolution ou les steps.", "\n".join(load_logs)
    except Exception:
        # FIX: removed redundant local `import traceback` (module-level import exists).
        return None, "Erreur:\n" + traceback.format_exc(), "\n".join(load_logs)
# Default example prompt: "One Piece" anime-style portrait.
EX_PROMPT = (
    "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, "
    "clean lineart, cel shading, vibrant colors, expressive eyes, dynamic composition, simple background"
)
# Default negative prompt: steer away from photorealism and common artefacts.
EX_NEG = (
    "realistic, photo, photorealistic, skin pores, complex lighting, "
    "low quality, worst quality, lowres, blurry, noisy, watermark, text, logo, jpeg artifacts, "
    "bad anatomy, deformed, multiple faces, nsfw"
)
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(css="footer{display:none !important}") as demo:
    gr.Markdown("# 🏴☠️ InstantID SDXL + IP-Adapter Style (2D) — visage → perso One Piece")
    with gr.Row():
        with gr.Column():
            # Inputs: face photo is mandatory, style image optional.
            face_image = gr.Image(type="pil", label="Photo visage (obligatoire)", height=260)
            style_image = gr.Image(type="pil", label="Image de style (optionnel)", height=260)
            gr.Markdown("Astuce : poster/planche One Piece → rendu 2D renforcé via IP-Adapter Style.")
            prompt = gr.Textbox(label="Prompt", value=EX_PROMPT, lines=3)
            negative = gr.Textbox(label="Negative Prompt", value=EX_NEG, lines=3)
            with gr.Row():
                # The three conditioning strengths map to the scales used in generate().
                identity_strength = gr.Slider(0.2, 1.5, 0.95, 0.05, label="Fidélité visage (IdentityNet)")
                adapter_strength = gr.Slider(0.1, 1.5, 0.85, 0.05, label="Détails anime (InstantID)")
                style_strength = gr.Slider(0.1, 1.5, 0.95, 0.05, label="Force style (IP-Adapter Style)")
            steps = gr.Slider(10, 60, 30, 1, label="Steps")
            cfg = gr.Slider(0.1, 12.0, 6.5, 0.1, label="CFG")
            width = gr.Dropdown(choices=[576, 640, 704, 768, 896], value=704, label="Largeur")
            height = gr.Dropdown(choices=[704, 768, 896, 1024], value=896, label="Hauteur")
            seed = gr.Number(value=-1, label="Seed (-1 aléatoire)")
            btn = gr.Button("🎨 Générer", variant="primary")
        with gr.Column():
            # Outputs: result image, error box (hidden unless needed), load logs.
            out_image = gr.Image(label="Résultat", interactive=False)
            err_box = gr.Textbox(label="Erreurs", visible=False)
            log_box = gr.Textbox(label="Logs", value="\n".join(load_logs), lines=12)

    def wrap(*args):
        # Adapt generate()'s (image, error, logs) tuple to Gradio updates;
        # the error box becomes visible only when there is an error message.
        img, err, logs = generate(*args)
        return img, gr.update(visible=bool(err), value=err), gr.update(value=logs)

    btn.click(
        wrap,
        inputs=[face_image, style_image, prompt, negative,
                identity_strength, adapter_strength, style_strength,
                steps, cfg, width, height, seed],
        outputs=[out_image, err_box, log_box],
    )

demo.queue(api_open=False)  # serialize requests; keep the API endpoint private

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)