File size: 13,060 Bytes
9cb49df
1d343f6
 
9cb49df
7829fea
 
 
1d343f6
7829fea
 
42e4f65
7829fea
9cb49df
 
7829fea
 
719b317
9cb49df
42e4f65
 
 
 
0da1dec
 
 
22a4eb5
9cb49df
 
22a4eb5
ae2593e
0da1dec
ae2593e
1d343f6
 
ae2593e
 
 
 
22a4eb5
ae2593e
7829fea
22a4eb5
ae2593e
 
 
1d343f6
fc9e0f9
ae2593e
3a7fbf1
 
 
 
1d343f6
ae2593e
9cb49df
 
7829fea
 
089483f
 
0da1dec
089483f
 
 
9cb49df
ad5043a
9cb49df
7829fea
 
 
 
 
 
 
ad5043a
 
7829fea
 
 
 
 
 
 
 
 
 
 
0da1dec
7829fea
 
 
 
 
 
 
fc9e0f9
7829fea
 
 
 
 
 
22a4eb5
 
 
 
 
 
 
 
b42cd10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5917b6
38fd38e
b42cd10
38fd38e
b42cd10
 
b5917b6
38fd38e
b42cd10
b5917b6
b42cd10
 
 
 
 
 
 
 
b5917b6
b42cd10
 
 
 
38fd38e
 
b42cd10
38fd38e
b42cd10
38fd38e
b42cd10
38fd38e
b5917b6
 
 
 
38fd38e
 
 
b5917b6
 
 
 
 
 
 
b42cd10
 
 
 
 
 
38fd38e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b42cd10
 
 
 
 
38fd38e
b42cd10
 
b5917b6
38fd38e
b42cd10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# app.py — InstantID SDXL (official) + optional IP-Adapter Style (2D rendering)

import os, sys
# Cap OpenMP fan-out and enable the fast hf_transfer download path
# before any heavy library gets imported.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Make the vendored InstantID pipeline files importable (see import_pipeline_or_fail).
sys.path.insert(0, os.path.abspath("./instantid"))

import traceback, importlib.util
import torch, gradio as gr
from PIL import Image, ImageOps, ImageDraw
from huggingface_hub import hf_hub_download
from diffusers.models import ControlNetModel
from insightface.app import FaceAnalysis

# GPU with fp16 when available; otherwise CPU with fp32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.float16 if DEVICE == "cuda" else torch.float32

# InstantID assets: IdentityNet ControlNet + the InstantID IP-Adapter weights.
ASSETS_REPO = "InstantX/InstantID"
CHECKPOINTS_DIR  = "./checkpoints"
CN_LOCAL_DIR     = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")

# Optional style IP-Adapter (generic SDXL weights from h94/IP-Adapter).
IP_STYLE_REPO      = "h94/IP-Adapter"
IP_STYLE_SUBFOLDER = "sdxl_models"
IP_STYLE_WEIGHT    = "ip-adapter_sdxl.bin"

BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

def safe_download(repo, filename, local_dir, min_bytes, label, subfolder=None):
    """Download one file from the HF Hub into `local_dir` and validate its size.

    A file already present but smaller than `min_bytes` is treated as a
    truncated download: it is deleted and re-fetched.

    Args:
        repo: HF repo id (e.g. "InstantX/InstantID").
        filename: file path inside the repo (may contain a directory part).
        local_dir: local root directory to mirror the repo layout into.
        min_bytes: minimum plausible size; anything smaller is rejected.
        label: human-readable name used in log/error messages.
        subfolder: optional repo subfolder passed through to hf_hub_download.

    Returns:
        Absolute path of the validated local file.

    Raises:
        RuntimeError: if the downloaded file is still smaller than `min_bytes`.
    """
    os.makedirs(local_dir, exist_ok=True)
    # BUGFIX: hf_hub_download(local_dir=...) mirrors the repo layout, writing to
    # local_dir/[subfolder/]filename — NOT local_dir/basename(filename). The old
    # basename-based path never existed for nested files, so force_download was
    # always True and the assets were re-downloaded on every start.
    rel_path = os.path.join(subfolder, filename) if subfolder else filename
    local_path = os.path.join(local_dir, rel_path)
    if os.path.exists(local_path) and os.path.getsize(local_path) < min_bytes:
        # Leftover truncated file from a previous run: remove so it is re-fetched.
        try:
            os.remove(local_path)
        except OSError:
            pass
    path = hf_hub_download(
        repo_id=repo,
        filename=filename,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
        # Only force a fresh download when no (valid-sized) local copy exists.
        force_download=not os.path.exists(local_path),
        subfolder=subfolder,
    )
    size = os.path.getsize(path)
    if size < min_bytes:
        raise RuntimeError(f"Téléchargement incomplet de {label} (taille: {size} bytes).")
    print(f"✅ {label} téléchargé ({size/1e6:.1f} MB)")
    return path

def ensure_assets_or_download():
    """Make sure every model asset exists locally, downloading what is missing.

    Size validation (and re-download of truncated files) is delegated to
    safe_download; any download failure propagates to the caller.
    """
    for directory in (CHECKPOINTS_DIR, CN_LOCAL_DIR):
        os.makedirs(directory, exist_ok=True)
    # IdentityNet ControlNet: config + weights from the InstantID repo.
    safe_download(ASSETS_REPO, "ControlNetModel/config.json", CHECKPOINTS_DIR, 1_000, "IdentityNet config")
    safe_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", CHECKPOINTS_DIR, 100_000_000, "IdentityNet weights")
    # InstantID's own IP-Adapter weights.
    safe_download(ASSETS_REPO, "ip-adapter.bin", CHECKPOINTS_DIR, 100_000_000, "IP-Adapter (InstantID)")
    # Optional generic SDXL style IP-Adapter.
    safe_download(IP_STYLE_REPO, IP_STYLE_WEIGHT, CHECKPOINTS_DIR, 20_000_000, "IP-Adapter Style (SDXL)", subfolder=IP_STYLE_SUBFOLDER)

def import_pipeline_or_fail():
    """Load the vendored InstantID SDXL pipeline module and return its class.

    Looks for the official pipeline file under ./instantid/ (preferring the
    "_full" variant), imports it by path, and returns the first class whose
    name contains "InstantID" and that exposes `from_pretrained`.

    Raises:
        RuntimeError: when no pipeline file is found, the file looks empty,
            or no matching pipeline class is defined in the module.
    """
    search_paths = (
        "./instantid/pipeline_stable_diffusion_xl_instantid_full.py",
        "./instantid/pipeline_stable_diffusion_xl_instantid.py",
    )
    pipeline_file = None
    for candidate in search_paths:
        if os.path.exists(candidate):
            pipeline_file = candidate
            break
    if pipeline_file is None:
        raise RuntimeError("❌ Pipeline manquante. Place `pipeline_stable_diffusion_xl_instantid_full.py` dans ./instantid/")
    # A file under 1 KiB cannot be the real pipeline — likely an empty placeholder.
    if os.path.getsize(pipeline_file) < 1024:
        raise RuntimeError("❌ Pipeline trop petite (vide ?). Utilise la version SDXL officielle.")

    # Import the module directly from its file path.
    spec = importlib.util.spec_from_file_location("instantid_pipeline", pipeline_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    for name, attr in vars(module).items():
        is_pipeline_class = (
            isinstance(attr, type)
            and "InstantID" in name
            and hasattr(attr, "from_pretrained")
        )
        if is_pipeline_class:
            print(f"✅ Pipeline trouvée : {name}")
            return attr

    avail = [n for n, o in vars(module).items() if isinstance(o, type)]
    raise RuntimeError("❌ Aucune classe pipeline InstantID trouvée. Classes dispo: " + ", ".join(avail))

def draw_kps_local(img_pil, kps):
    """Render facial keypoints as black dots on a white canvas.

    Args:
        img_pil: PIL image whose size defines the canvas dimensions.
        kps: iterable of (x, y) keypoint coordinates in pixel space.

    Returns:
        A new RGB PIL image (white background, black dots) used as the
        IdentityNet ControlNet conditioning image.
    """
    width, height = img_pil.size
    canvas = Image.new("RGB", (width, height), "white")
    painter = ImageDraw.Draw(canvas)
    # Dot radius scales with image size but never drops below 2 px.
    radius = max(2, min(width, height) // 100)
    for x, y in kps:
        bbox = (x - radius, y - radius, x + radius, y + radius)
        painter.ellipse(bbox, fill="black")
    return canvas

# Load everything at import time so the app starts with a ready pipeline.
# All progress/errors are collected in load_logs and surfaced in the UI.
load_logs = []
HAS_STYLE_ADAPTER = False
try:
    SDXLInstantID = import_pipeline_or_fail()
    ensure_assets_or_download()

    # IdentityNet: the InstantID ControlNet conditioned on facial keypoints.
    controlnet_identitynet = ControlNetModel.from_pretrained(CN_LOCAL_DIR, torch_dtype=DTYPE)
    pipe = SDXLInstantID.from_pretrained(
        BASE_MODEL,
        controlnet=controlnet_identitynet,
        torch_dtype=DTYPE,
        safety_checker=None,
        feature_extractor=None,
    ).to(DEVICE)

    # InstantID's own IP-Adapter (injects the identity embedding).
    pipe.load_ip_adapter_instantid(IP_ADAPTER_LOCAL)

    # Optional second IP-Adapter used to push a 2D/anime style. Failure here
    # is non-fatal: the app simply runs without the style adapter.
    try:
        pipe.load_ip_adapter(
            IP_STYLE_REPO,
            subfolder=IP_STYLE_SUBFOLDER,
            weight_name=IP_STYLE_WEIGHT,
            adapter_name="style",
        )
        load_logs.append("✅ IP-Adapter Style (SDXL) chargé (adapter_name='style').")
        HAS_STYLE_ADAPTER = True
    except Exception as e:
        load_logs.append(f"ℹ️ IP-Adapter Style non chargé: {e}")

    if DEVICE == "cuda":
        # Belt and braces: sub-modules attached after .to(DEVICE) above
        # (e.g. by load_ip_adapter_instantid) may still live on CPU.
        if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
        if hasattr(pipe, "unet"):             pipe.unet.to("cuda")

    load_logs.append("✅ InstantID prêt.")
except Exception:
    # Keep the full traceback in the logs so it can be shown in the UI / raised below.
    load_logs += ["❌ ERREUR au chargement:", traceback.format_exc()]
    pipe = None

# Fail fast: without a pipeline the app is useless, so abort startup loudly.
if pipe is None:
    raise RuntimeError("Échec de chargement du pipeline.\n" + "\n".join(load_logs))

def load_face_analyser():
    """Instantiate an InsightFace FaceAnalysis model on CPU.

    Tries "antelopev2" first (the model InstantID was built with), then falls
    back to "buffalo_l". The first model that prepares successfully is returned.

    Returns:
        A prepared FaceAnalysis instance (det_size 640×640).

    Raises:
        RuntimeError: if every candidate model fails to load, with the
            individual failure reasons joined into the message.
    """
    errors = []
    for name in ("antelopev2", "buffalo_l"):
        try:
            fa = FaceAnalysis(name=name, root="./models", providers=["CPUExecutionProvider"])
            fa.prepare(ctx_id=0, det_size=(640, 640))
            print(f"✅ InsightFace chargé: {name}")
            return fa
        except Exception as e:
            errors.append(f"{name}: {e}")
            # Fixed log format: the original f-string ran the model name and the
            # exception together with no separator ("{name}{e}").
            print(f"⚠️ InsightFace échec {name}: {e}")
    raise RuntimeError("Echec chargement InsightFace. Détails: " + " | ".join(errors))

# Global face analyser shared by every request (loaded once at startup).
fa = load_face_analyser()

def extract_face_embed_and_kps(pil_img):
    """Detect a face and return its identity embedding plus a keypoint image.

    Args:
        pil_img: input PIL image (any mode; converted to RGB internally).

    Returns:
        Tuple of (embedding, kps_image) where embedding is a torch tensor of
        shape (1, D) on DEVICE/DTYPE and kps_image is the white-background
        keypoint rendering produced by draw_kps_local.

    Raises:
        ValueError: when no face is detected.
    """
    import numpy as np, cv2

    # InsightFace expects a BGR numpy array.
    bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    detections = fa.get(bgr)
    if not detections:
        raise ValueError("Aucun visage détecté dans la photo.")
    # NOTE(review): uses the LAST detection — presumably fine for single-face
    # portraits, but confirm the intended choice for multi-face inputs.
    chosen = detections[-1]

    embedding = chosen["embedding"]
    if not isinstance(embedding, np.ndarray):
        embedding = np.asarray(embedding, dtype="float32")
    if embedding.ndim == 1:
        embedding = embedding[None, ...]  # promote (D,) → (1, D)
    # Move onto the pipeline's device/dtype so it can be fed straight to the UNet.
    embed_tensor = torch.from_numpy(embedding).to(device=DEVICE, dtype=DTYPE)

    kps_image = draw_kps_local(pil_img, chosen["kps"])
    return embed_tensor, kps_image

def generate(face_image, style_image, prompt, negative_prompt,
             identity_strength, adapter_strength, style_strength,
             steps, cfg, width, height, seed):
    """Run one InstantID generation.

    Returns a 3-tuple consumed by the Gradio callback:
    (output PIL image or None, error text or "", newline-joined load logs).
    All failures are caught and reported through the error slot, never raised.
    """
    try:
        if face_image is None:
            return None, "Merci d'ajouter une photo visage.", "\n".join(load_logs)

        # seed < 0 (or None) means "random": no generator is passed to the pipe.
        gen = None if seed is None or int(seed) < 0 else torch.Generator(device=DEVICE).manual_seed(int(seed))

        # Face → EXIF-corrected, center-cropped square, resized to 512×512
        # for stable detection.
        from PIL import ImageOps
        face = ImageOps.exif_transpose(face_image).convert("RGB")
        ms = min(face.size); x = (face.width - ms) // 2; y = (face.height - ms) // 2
        face_sq = face.crop((x, y, x + ms, y + ms)).resize((512, 512), Image.Resampling.LANCZOS)

        # InsightFace: identity embedding (torch [1, D] on DEVICE/DTYPE) + keypoint image.
        face_emb, kps_img = extract_face_embed_and_kps(face_sq)

        # IP-Adapter scales: dict form when both adapters are active,
        # plain float when only InstantID's adapter is in play.
        try:
            if HAS_STYLE_ADAPTER and style_image is not None:
                pipe.set_ip_adapter_scale({"instantid": float(adapter_strength), "style": float(style_strength)})
            else:
                pipe.set_ip_adapter_scale(float(adapter_strength))
        except Exception as e:
            print(f"ℹ️ set_ip_adapter_scale ignoré: {e}")

        # Multi-ControlNet compatibility (even though only one is configured):
        # some pipeline variants expect lists for image/scale arguments.
        cn = getattr(pipe, "controlnet", None)
        if isinstance(cn, (list, tuple)):
            n_cn = len(cn)
        else:
            try: n_cn = len(cn)
            except Exception: n_cn = 1

        image_arg = [kps_img] * n_cn if n_cn > 1 else ([kps_img] if isinstance(cn, (list, tuple)) else kps_img)
        scale_val = float(identity_strength)
        scale_arg = [scale_val] * n_cn if n_cn > 1 else ([scale_val] if isinstance(cn, (list, tuple)) else scale_val)

        # Inference kwargs. The identity embedding is passed under several names
        # so that whichever kwarg this diffusers/pipeline version propagates wins.
        gen_kwargs = dict(
            prompt=(prompt or "").strip(),
            negative_prompt=(negative_prompt or "").strip(),
            image=image_arg,
            image_embeds=face_emb,                                # pipeline compat
            added_conditions={"image_embeds": face_emb},          # diffusers ≥ 0.30.x (if propagated)
            added_cond_kwargs={"image_embeds": face_emb},         # diffusers 0.29.x (if propagated)
            controlnet_conditioning_scale=scale_arg,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=int(width),
            height=int(height),
            generator=gen,
        )
        if HAS_STYLE_ADAPTER and style_image is not None:
            try:
                gen_kwargs["ip_adapter_image"] = ImageOps.exif_transpose(style_image).convert("RGB")
            except Exception as e:
                print(f"ℹ️ ip_adapter_image ignoré: {e}")

        # 🔧 MONKEY-PATCH: force image_embeds into UNet.forward, in case the
        # pipeline does not propagate the kwargs above down to the UNet.
        orig_forward = pipe.unet.forward

        def forward_patch(*args, **kwargs):
            # Merge instead of replacing so existing conditions are preserved.
            ac = kwargs.get("added_conditions")
            if ac is None:
                ac = {}
            else:
                ac = dict(ac)
            ac["image_embeds"] = face_emb
            kwargs["added_conditions"] = ac
            # Compat for diffusers 0.29.x kwarg name.
            kwargs["added_cond_kwargs"] = ac
            return orig_forward(*args, **kwargs)

        pipe.unet.forward = forward_patch

        try:
            images = pipe(**gen_kwargs).images
        finally:
            # Always restore the original forward, even if inference fails.
            pipe.unet.forward = orig_forward

        return images[0], "", "\n".join(load_logs)

    except torch.cuda.OutOfMemoryError:
        return None, "CUDA OOM: baisse la résolution ou les steps.", "\n".join(load_logs)
    except Exception:
        import traceback
        return None, "Erreur:\n" + traceback.format_exc(), "\n".join(load_logs)



# Default prompt preloaded in the UI: One Piece / flat-anime look.
EX_PROMPT = (
    "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, "
    "clean lineart, cel shading, vibrant colors, expressive eyes, dynamic composition, simple background"
)
# Default negative prompt: steer away from photorealism and common artifacts.
EX_NEG = (
    "realistic, photo, photorealistic, skin pores, complex lighting, "
    "low quality, worst quality, lowres, blurry, noisy, watermark, text, logo, jpeg artifacts, "
    "bad anatomy, deformed, multiple faces, nsfw"
)

# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks(css="footer{display:none !important}") as demo:
    gr.Markdown("# 🏴‍☠️ InstantID SDXL + IP-Adapter Style (2D) — visage → perso One Piece")
    with gr.Row():
        # Left column: inputs and generation parameters.
        with gr.Column():
            face_image  = gr.Image(type="pil", label="Photo visage (obligatoire)", height=260)
            style_image = gr.Image(type="pil", label="Image de style (optionnel)", height=260)
            gr.Markdown("Astuce : poster/planche One Piece → rendu 2D renforcé via IP-Adapter Style.")
            prompt   = gr.Textbox(label="Prompt", value=EX_PROMPT, lines=3)
            negative = gr.Textbox(label="Negative Prompt", value=EX_NEG, lines=3)
            with gr.Row():
                identity_strength = gr.Slider(0.2, 1.5, 0.95, 0.05, label="Fidélité visage (IdentityNet)")
                adapter_strength  = gr.Slider(0.1, 1.5, 0.85, 0.05, label="Détails anime (InstantID)")
            style_strength = gr.Slider(0.1, 1.5, 0.95, 0.05, label="Force style (IP-Adapter Style)")
            steps  = gr.Slider(10, 60, 30, 1, label="Steps")
            cfg    = gr.Slider(0.1, 12.0, 6.5, 0.1, label="CFG")
            width  = gr.Dropdown(choices=[576, 640, 704, 768, 896], value=704, label="Largeur")
            height = gr.Dropdown(choices=[704, 768, 896, 1024], value=896, label="Hauteur")
            seed   = gr.Number(value=-1, label="Seed (-1 aléatoire)")
            btn    = gr.Button("🎨 Générer", variant="primary")
        # Right column: result, error display, and startup logs.
        with gr.Column():
            out_image = gr.Image(label="Résultat", interactive=False)
            err_box   = gr.Textbox(label="Erreurs", visible=False)
            log_box   = gr.Textbox(label="Logs", value="\n".join(load_logs), lines=12)

    def wrap(*args):
        """Adapt generate()'s (image, error, logs) tuple to the UI: the error
        box is only made visible when there is an error to display."""
        img, err, logs = generate(*args)
        return img, gr.update(visible=bool(err), value=err), gr.update(value=logs)

    btn.click(
        wrap,
        inputs=[face_image, style_image, prompt, negative,
                identity_strength, adapter_strength, style_strength,
                steps, cfg, width, height, seed],
        outputs=[out_image, err_box, log_box],
    )

# Queue requests (API surface disabled); bind to all interfaces for container use.
demo.queue(api_open=False)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)