Nad54 commited on
Commit
ae2593e
·
verified ·
1 Parent(s): 91aae5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -209
app.py CHANGED
@@ -1,222 +1,43 @@
1
- # app.py — InstantID SDXL (import local résilient + draw_kps local)
2
- import os, sys, traceback, importlib.util
3
-
4
- # OMP robuste
5
- val = os.environ.get("OMP_NUM_THREADS", "")
6
- try:
7
- if val == "" or int(val) <= 0:
8
- os.environ["OMP_NUM_THREADS"] = "1"
9
- except Exception:
10
- os.environ["OMP_NUM_THREADS"] = "1"
11
- os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
12
-
13
- sys.path.insert(0, os.path.abspath("./instantid"))
14
-
15
- import torch, gradio as gr
16
- from PIL import Image, ImageOps, ImageDraw
17
  from huggingface_hub import hf_hub_download
18
- from diffusers.models import ControlNetModel
19
 
20
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
22
-
23
- # -------- Repo MODEL pour télécharger les poids --------
24
  ASSETS_REPO = "InstantX/InstantID"
25
-
26
- # -------- Fichiers attendus localement (TEXTE uniquement) --------
27
- PIPE_CANDIDATES = [
28
- "./instantid/pipeline_stable_diffusion_xl_instantid.py",
29
- "./instantid/pipeline_stable_diffusion_xl_instantid_full.py",
30
- ]
31
-
32
  CHECKPOINTS_DIR = "./checkpoints"
33
  CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
34
  IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
35
 
36
- # ---------- import dynamique tolérant ----------
37
- def import_pipeline_or_fail():
38
- pipeline_file = next((p for p in PIPE_CANDIDATES if os.path.exists(p)), None)
39
- if pipeline_file is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  raise RuntimeError(
41
- "Pipeline InstantID introuvable.\n"
42
- "➡️ Place un des fichiers (texte) dans ./instantid/ :\n"
43
- " - pipeline_stable_diffusion_xl_instantid.py\n"
44
- " - pipeline_stable_diffusion_xl_instantid_full.py\n"
45
  )
46
- spec = importlib.util.spec_from_file_location("instantid_pipeline", pipeline_file)
47
- mod = importlib.util.module_from_spec(spec)
48
- spec.loader.exec_module(mod)
49
 
50
- # Cherche une classe dont le nom contient 'InstantID' et qui expose from_pretrained
51
- for name, obj in vars(mod).items():
52
- if isinstance(obj, type) and "InstantID" in name and hasattr(obj, "from_pretrained"):
53
- return obj # classe pipeline
54
- # Si rien trouvé, affiche les classes disponibles
55
- avail = [n for n, o in vars(mod).items() if isinstance(o, type)]
56
- raise RuntimeError("Aucune classe pipeline 'InstantID' trouvée. Classes dispo: " + ", ".join(avail))
57
-
58
- # ---------- téléchargement des poids depuis le repo MODEL ----------
59
  def ensure_assets_or_download():
60
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
61
  os.makedirs(CN_LOCAL_DIR, exist_ok=True)
62
- if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "config.json")):
63
- hf_hub_download(ASSETS_REPO, "ControlNetModel/config.json", local_dir=CHECKPOINTS_DIR)
64
- if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "diffusion_pytorch_model.safetensors")):
65
- hf_hub_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINTS_DIR)
66
- if not os.path.isfile(IP_ADAPTER_LOCAL):
67
- hf_hub_download(ASSETS_REPO, "ip-adapter.bin", local_dir=CHECKPOINTS_DIR)
68
-
69
- # ---------- draw_kps local (remplace la dépendance du fichier) ----------
70
- def draw_kps_local(img_pil: Image.Image, kps):
71
- """
72
- kps: ndarray shape (5,2) insightface 'antelopev2' (yeux G/D, nez, bouche G/D).
73
- On dessine de petits ronds noirs sur fond blanc (format attendu par IdentityNet).
74
- """
75
- w, h = img_pil.size
76
- out = Image.new("RGB", (w, h), "white")
77
- d = ImageDraw.Draw(out)
78
- r = max(2, min(w, h) // 100) # rayon adaptatif
79
- try:
80
- import numpy as np
81
- pts = kps if isinstance(kps, (list, tuple)) else np.array(kps)
82
- except Exception:
83
- pts = kps
84
- for p in pts:
85
- x, y = float(p[0]), float(p[1])
86
- d.ellipse((x - r, y - r, x + r, y + r), fill="black")
87
- return out
88
-
89
- # ---------- Chargement pipeline ----------
90
- load_logs = []
91
- try:
92
- SDXLInstantID = import_pipeline_or_fail()
93
- ensure_assets_or_download()
94
-
95
- BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # changeable vers un SDXL anime
96
-
97
- load_logs.append("Chargement ControlNet IdentityNet…")
98
- controlnet_identitynet = ControlNetModel.from_pretrained(CN_LOCAL_DIR, torch_dtype=DTYPE)
99
-
100
- load_logs.append(f"Chargement pipeline InstantID (base={BASE_MODEL})…")
101
- pipe = SDXLInstantID.from_pretrained(
102
- BASE_MODEL,
103
- controlnet=[controlnet_identitynet],
104
- torch_dtype=DTYPE,
105
- safety_checker=None,
106
- feature_extractor=None,
107
- ).to(DEVICE)
108
-
109
- if hasattr(pipe, "load_ip_adapter_instantid"):
110
- pipe.load_ip_adapter_instantid(IP_ADAPTER_LOCAL)
111
- else:
112
- raise RuntimeError("Cette pipeline ne fournit pas load_ip_adapter_instantid().")
113
-
114
- if DEVICE == "cuda":
115
- if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
116
- if hasattr(pipe, "unet"): pipe.unet.to("cuda")
117
-
118
- load_logs.append("✅ InstantID prêt.")
119
- except Exception:
120
- load_logs += ["❌ ERREUR au chargement:", traceback.format_exc()]
121
- pipe = None
122
-
123
- if pipe is None:
124
- raise RuntimeError("Échec chargement pipeline InstantID.\n" + "\n".join(load_logs))
125
-
126
- # ---------- InsightFace pour landmarks ----------
127
- from insightface.app import FaceAnalysis
128
- fa = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
129
- fa.prepare(ctx_id=0, det_size=(640, 640))
130
-
131
- def kps_image_from_face(pil_img: Image.Image):
132
- import numpy as np, cv2
133
- img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
134
- faces = fa.get(img_cv2)
135
- if not faces:
136
- raise ValueError("Aucun visage détecté. Utilise une photo portrait nette (visage centré).")
137
- face = faces[-1]
138
- return draw_kps_local(pil_img, face["kps"])
139
-
140
- # ---------- Inference ----------
141
- def generate(face_image, prompt, negative_prompt, identity_strength, adapter_strength, steps, cfg, width, height, seed):
142
- try:
143
- if face_image is None:
144
- return None, "Merci d'ajouter une photo visage.", "\n".join(load_logs)
145
-
146
- gen = None if seed is None or int(seed) < 0 else torch.Generator(device=DEVICE).manual_seed(int(seed))
147
-
148
- face = ImageOps.exif_transpose(face_image).convert("RGB")
149
- ms = min(face.size); x=(face.width-ms)//2; y=(face.height-ms)//2
150
- face_sq = face.crop((x, y, x+ms, y+ms)).resize((512,512), Image.Resampling.LANCZOS)
151
-
152
- kps_img = kps_image_from_face(face_sq)
153
-
154
- if hasattr(pipe, "set_ip_adapter_scale"):
155
- pipe.set_ip_adapter_scale(float(adapter_strength))
156
-
157
- images = pipe(
158
- prompt=prompt.strip(),
159
- negative_prompt=(negative_prompt or "").strip(),
160
- image=kps_img, # landmarks pour IdentityNet
161
- controlnet_conditioning_scale=float(identity_strength),
162
- num_inference_steps=int(steps),
163
- guidance_scale=float(cfg),
164
- width=int(width),
165
- height=int(height),
166
- generator=gen,
167
- ).images
168
-
169
- return images[0], "", "\n".join(load_logs)
170
- except torch.cuda.OutOfMemoryError as oom:
171
- msg = "CUDA OOM: baisse résolution (ex: 640×768 → 576×704), steps 24–28, CFG 5–6."
172
- return None, f"{msg}\n{oom}", "\n".join(load_logs)
173
- except Exception:
174
- return None, "Erreur:\n"+traceback.format_exc(), "\n".join(load_logs)
175
-
176
- # ---------- UI ----------
177
- EX_PROMPT = (
178
- "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, straw hat, "
179
- "clean lineart, cel shading, vibrant colors, expressive eyes, symmetrical face, looking at camera, "
180
- "dynamic lighting, simple background, high detail"
181
- )
182
- EX_NEG = (
183
- "low quality, worst quality, lowres, blurry, noisy, watermark, text, logo, jpeg artifacts, "
184
- "bad anatomy, distorted eyes, cross-eye, asymmetrical eyes, deformed, multiple faces, nsfw"
185
- )
186
-
187
- with gr.Blocks(css="footer{display:none !important}") as demo:
188
- gr.Markdown("# 🏴‍☠️ One Piece — InstantID (SDXL) — Import local robuste")
189
-
190
- with gr.Row():
191
- with gr.Column():
192
- face_image = gr.Image(type="pil", label="Photo visage", height=360)
193
- prompt = gr.Textbox(label="Prompt", value=EX_PROMPT, lines=3)
194
- negative = gr.Textbox(label="Negative Prompt", value=EX_NEG, lines=3)
195
-
196
- identity_strength = gr.Slider(0.2, 1.5, value=0.95, step=0.05, label="IdentityNet strength (fidélité)")
197
- adapter_strength = gr.Slider(0.2, 1.5, value=0.85, step=0.05, label="Adapter strength (détails)")
198
- steps = gr.Slider(10, 60, value=30, step=1, label="Steps")
199
- cfg = gr.Slider(0.1, 12.0, value=5.5, step=0.1, label="CFG")
200
- width = gr.Dropdown(choices=[576, 640, 704, 768, 896], value=704, label="Largeur")
201
- height = gr.Dropdown(choices=[704, 768, 896, 1024], value=896, label="Hauteur")
202
- seed = gr.Number(value=-1, label="Seed (-1 aléatoire)")
203
- btn = gr.Button("🎨 Générer", variant="primary")
204
-
205
- with gr.Column():
206
- out_image = gr.Image(label="Résultat", interactive=False)
207
- err_box = gr.Textbox(label="Erreurs", visible=False)
208
- log_box = gr.Textbox(label="Logs", value="\n".join(load_logs), lines=10)
209
-
210
- def wrap(*args):
211
- img, err, logs = generate(*args)
212
- return img, gr.update(visible=bool(err), value=err), gr.update(value=logs)
213
-
214
- btn.click(
215
- wrap,
216
- inputs=[face_image, prompt, negative, identity_strength, adapter_strength, steps, cfg, width, height, seed],
217
- outputs=[out_image, err_box, log_box],
218
- )
219
-
220
- demo.queue()
221
- if __name__ == "__main__":
222
- demo.launch(ssr_mode=False, server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from huggingface_hub import hf_hub_download
2
+ import os
3
 
 
 
 
 
4
  ASSETS_REPO = "InstantX/InstantID"
 
 
 
 
 
 
 
5
  CHECKPOINTS_DIR = "./checkpoints"
6
  CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
7
  IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
8
 
9
+ def safe_download(repo, filename, local_dir, min_bytes, label):
10
+ """Télécharge avec reprise, puis vérifie la taille. Re-télécharge si besoin."""
11
+ os.makedirs(local_dir, exist_ok=True)
12
+ # Si fichier présent mais trop petit => le supprimer (corrompu)
13
+ local_path = os.path.join(local_dir, filename)
14
+ if os.path.exists(local_path) and os.path.getsize(local_path) < min_bytes:
15
+ try:
16
+ os.remove(local_path)
17
+ except Exception:
18
+ pass
19
+ # Téléchargement (avec reprise)
20
+ path = hf_hub_download(
21
+ repo_id=repo,
22
+ filename=filename,
23
+ local_dir=local_dir,
24
+ resume_download=True,
25
+ force_download=not os.path.exists(os.path.join(local_dir, filename)),
26
+ )
27
+ # Vérification de taille
28
+ size = os.path.getsize(path)
29
+ if size < min_bytes:
30
  raise RuntimeError(
31
+ f"Téléchargement incomplet de {label} ({size} bytes). "
32
+ f"Relance le Space; si ça persiste, vérifie la connectivité HF."
 
 
33
  )
34
+ return path
 
 
35
 
 
 
 
 
 
 
 
 
 
36
  def ensure_assets_or_download():
37
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
38
  os.makedirs(CN_LOCAL_DIR, exist_ok=True)
39
+ # ControlNet IdentityNet (~2.5 Go)
40
+ safe_download(ASSETS_REPO, "ControlNetModel/config.json", CHECKPOINTS_DIR, 1000, "IdentityNet config")
41
+ safe_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", CHECKPOINTS_DIR, 100_000_000, "IdentityNet weights")
42
+ # ip-adapter (~1.6 Go)
43
+ safe_download(ASSETS_REPO, "ip-adapter.bin", CHECKPOINTS_DIR, 100_000_000, "ip-adapter")