Nad54 commited on
Commit
7829fea
·
verified ·
1 Parent(s): ae2593e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -16
app.py CHANGED
@@ -1,43 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from huggingface_hub import hf_hub_download
2
- import os
 
 
 
3
 
 
4
  ASSETS_REPO = "InstantX/InstantID"
5
  CHECKPOINTS_DIR = "./checkpoints"
6
  CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
7
  IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
8
 
 
 
 
9
  def safe_download(repo, filename, local_dir, min_bytes, label):
10
- """Télécharge avec reprise, puis vérifie la taille. Re-télécharge si besoin."""
11
  os.makedirs(local_dir, exist_ok=True)
12
- # Si fichier présent mais trop petit => le supprimer (corrompu)
13
  local_path = os.path.join(local_dir, filename)
 
14
  if os.path.exists(local_path) and os.path.getsize(local_path) < min_bytes:
15
- try:
16
- os.remove(local_path)
17
- except Exception:
18
- pass
19
- # Téléchargement (avec reprise)
20
  path = hf_hub_download(
21
  repo_id=repo,
22
  filename=filename,
23
  local_dir=local_dir,
24
  resume_download=True,
25
- force_download=not os.path.exists(os.path.join(local_dir, filename)),
26
  )
27
- # Vérification de taille
28
  size = os.path.getsize(path)
 
29
  if size < min_bytes:
30
- raise RuntimeError(
31
- f"Téléchargement incomplet de {label} ({size} bytes). "
32
- f"Relance le Space; si ça persiste, vérifie la connectivité HF."
33
- )
34
  return path
35
 
36
  def ensure_assets_or_download():
37
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
38
  os.makedirs(CN_LOCAL_DIR, exist_ok=True)
39
- # ControlNet IdentityNet (~2.5 Go)
40
  safe_download(ASSETS_REPO, "ControlNetModel/config.json", CHECKPOINTS_DIR, 1000, "IdentityNet config")
41
  safe_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", CHECKPOINTS_DIR, 100_000_000, "IdentityNet weights")
42
- # ip-adapter (~1.6 Go)
43
  safe_download(ASSETS_REPO, "ip-adapter.bin", CHECKPOINTS_DIR, 100_000_000, "ip-adapter")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — InstantID SDXL (version stable Hugging Face)
2
+ # --------------------------------------------
3
+ # ⚙️ Nécessite :
4
+ # - dossier ./instantid/ avec :
5
+ # pipeline_stable_diffusion_xl_instantid.py
6
+ # ip_adapter/ (avec les fichiers du repo GitHub)
7
+ # - requirements.txt avec numpy==1.26.4, onnxruntime==1.16.3
8
+ # --------------------------------------------
9
+
10
+ import os, sys, traceback, importlib.util
11
+
12
+ # --- Sécuriser l'environnement ---
13
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
14
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
15
+
16
+ # --- S'assurer que Python trouve ./instantid/ip_adapter ---
17
+ sys.path.insert(0, os.path.abspath("./instantid"))
18
+
19
+ import torch, gradio as gr
20
+ from PIL import Image, ImageOps, ImageDraw
21
  from huggingface_hub import hf_hub_download
22
+ from diffusers.models import ControlNetModel
23
+
24
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
26
 
27
+ # --- Répertoires / fichiers ---
28
  ASSETS_REPO = "InstantX/InstantID"
29
  CHECKPOINTS_DIR = "./checkpoints"
30
  CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
31
  IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
32
 
33
+ # ------------------------------------------------------------
34
+ # Téléchargement robuste (détecte les fichiers vides)
35
+ # ------------------------------------------------------------
36
  def safe_download(repo, filename, local_dir, min_bytes, label):
 
37
  os.makedirs(local_dir, exist_ok=True)
 
38
  local_path = os.path.join(local_dir, filename)
39
+ # Supprimer fichier incomplet
40
  if os.path.exists(local_path) and os.path.getsize(local_path) < min_bytes:
41
+ print(f"⚠️ {label} corrompu ({os.path.getsize(local_path)} bytes) → suppression")
42
+ os.remove(local_path)
43
+ # Télécharger / reprendre
 
 
44
  path = hf_hub_download(
45
  repo_id=repo,
46
  filename=filename,
47
  local_dir=local_dir,
48
  resume_download=True,
49
+ force_download=not os.path.exists(local_path),
50
  )
 
51
  size = os.path.getsize(path)
52
+ print(f"✅ {label} téléchargé ({size/1e6:.1f} MB)")
53
  if size < min_bytes:
54
+ raise RuntimeError(f"Téléchargement incomplet de {label}.")
 
 
 
55
  return path
56
 
57
  def ensure_assets_or_download():
58
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
59
  os.makedirs(CN_LOCAL_DIR, exist_ok=True)
 
60
  safe_download(ASSETS_REPO, "ControlNetModel/config.json", CHECKPOINTS_DIR, 1000, "IdentityNet config")
61
  safe_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", CHECKPOINTS_DIR, 100_000_000, "IdentityNet weights")
 
62
  safe_download(ASSETS_REPO, "ip-adapter.bin", CHECKPOINTS_DIR, 100_000_000, "ip-adapter")
63
+
64
+ # ------------------------------------------------------------
65
+ # Import dynamique de la pipeline InstantID
66
+ # ------------------------------------------------------------
67
+ def import_pipeline_or_fail():
68
+ pipeline_file = "./instantid/pipeline_stable_diffusion_xl_instantid.py"
69
+ if not os.path.exists(pipeline_file):
70
+ raise RuntimeError("❌ Fichier pipeline manquant dans ./instantid/")
71
+ spec = importlib.util.spec_from_file_location("instantid_pipeline", pipeline_file)
72
+ mod = importlib.util.module_from_spec(spec)
73
+ spec.loader.exec_module(mod)
74
+ # Chercher la classe InstantID
75
+ for name, obj in vars(mod).items():
76
+ if isinstance(obj, type) and "InstantID" in name and hasattr(obj, "from_pretrained"):
77
+ print(f"✅ Pipeline trouvée : {name}")
78
+ return obj
79
+ raise RuntimeError("❌ Aucune classe pipeline InstantID trouvée dans le fichier.")
80
+
81
+ # ------------------------------------------------------------
82
+ # draw_kps local (remplace la dépendance d'origine)
83
+ # ------------------------------------------------------------
84
+ def draw_kps_local(img_pil, kps):
85
+ w, h = img_pil.size
86
+ out = Image.new("RGB", (w, h), "white")
87
+ d = ImageDraw.Draw(out)
88
+ r = max(2, min(w, h)//100)
89
+ for (x, y) in kps:
90
+ d.ellipse((x - r, y - r, x + r, y + r), fill="black")
91
+ return out
92
+
93
+ # ------------------------------------------------------------
94
+ # Chargement du modèle complet
95
+ # ------------------------------------------------------------
96
+ load_logs = []
97
+ try:
98
+ SDXLInstantID = import_pipeline_or_fail()
99
+ ensure_assets_or_download()
100
+
101
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
102
+ load_logs.append(f"Chargement base: {BASE_MODEL}")
103
+
104
+ controlnet_identitynet = ControlNetModel.from_pretrained(CN_LOCAL_DIR, torch_dtype=DTYPE)
105
+ pipe = SDXLInstantID.from_pretrained(
106
+ BASE_MODEL,
107
+ controlnet=[controlnet_identitynet],
108
+ torch_dtype=DTYPE,
109
+ safety_checker=None,
110
+ feature_extractor=None,
111
+ ).to(DEVICE)
112
+
113
+ pipe.load_ip_adapter_instantid(IP_ADAPTER_LOCAL)
114
+ if DEVICE == "cuda":
115
+ if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
116
+ if hasattr(pipe, "unet"): pipe.unet.to("cuda")
117
+
118
+ load_logs.append("✅ InstantID prêt.")
119
+ except Exception:
120
+ load_logs += ["❌ ERREUR au chargement:", traceback.format_exc()]
121
+ pipe = None
122
+
123
+ if pipe is None:
124
+ raise RuntimeError("Échec de chargement du pipeline.\n" + "\n".join(load_logs))
125
+
126
+ # ------------------------------------------------------------
127
+ # Détection visage (InsightFace)
128
+ # ------------------------------------------------------------
129
+ from insightface.app import FaceAnalysis
130
+ fa = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
131
+ fa.prepare(ctx_id=0, det_size=(640, 640))
132
+
133
+ def extract_kps_image(pil_img):
134
+ import numpy as np, cv2
135
+ img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
136
+ faces = fa.get(img_cv2)
137
+ if not faces:
138
+ raise ValueError("Aucun visage détecté.")
139
+ face = faces[-1]
140
+ return draw_kps_local(pil_img, face["kps"])
141
+
142
+ # ------------------------------------------------------------
143
+ # Inference
144
+ # ------------------------------------------------------------
145
+ def generate(face_image, prompt, negative_prompt, identity_strength, adapter_strength, steps, cfg, width, height, seed):
146
+ try:
147
+ if face_image is None:
148
+ return None, "Merci d'ajouter une photo visage.", "\n".join(load_logs)
149
+ gen = None if seed is None or int(seed) < 0 else torch.Generator(device=DEVICE).manual_seed(int(seed))
150
+ face = ImageOps.exif_transpose(face_image).convert("RGB")
151
+ ms = min(face.size); x=(face.width-ms)//2; y=(face.height-ms)//2
152
+ face_sq = face.crop((x, y, x+ms, y+ms)).resize((512,512), Image.Resampling.LANCZOS)
153
+ kps_img = extract_kps_image(face_sq)
154
+ if hasattr(pipe, "set_ip_adapter_scale"):
155
+ pipe.set_ip_adapter_scale(float(adapter_strength))
156
+ images = pipe(
157
+ prompt=prompt.strip(),
158
+ negative_prompt=(negative_prompt or "").strip(),
159
+ image=kps_img,
160
+ controlnet_conditioning_scale=float(identity_strength),
161
+ num_inference_steps=int(steps),
162
+ guidance_scale=float(cfg),
163
+ width=int(width),
164
+ height=int(height),
165
+ generator=gen,
166
+ ).images
167
+ return images[0], "", "\n".join(load_logs)
168
+ except torch.cuda.OutOfMemoryError as oom:
169
+ return None, "CUDA OOM: baisse la résolution.", "\n".join(load_logs)
170
+ except Exception:
171
+ return None, "Erreur:\n"+traceback.format_exc(), "\n".join(load_logs)
172
+
173
+ # ------------------------------------------------------------
174
+ # Interface Gradio
175
+ # ------------------------------------------------------------
176
+ EX_PROMPT = (
177
+ "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, "
178
+ "clean lineart, cel shading, vibrant colors, expressive eyes, symmetrical face, "
179
+ "dynamic lighting, simple background, high detail"
180
+ )
181
+ EX_NEG = (
182
+ "low quality, worst quality, lowres, blurry, noisy, watermark, text, logo, jpeg artifacts, "
183
+ "bad anatomy, distorted eyes, deformed, multiple faces, nsfw"
184
+ )
185
+
186
+ with gr.Blocks(css="footer{display:none !important}") as demo:
187
+ gr.Markdown("# 🏴‍☠️ One Piece – InstantID SDXL")
188
+
189
+ with gr.Row():
190
+ with gr.Column():
191
+ face_image = gr.Image(type="pil", label="Photo visage", height=360)
192
+ prompt = gr.Textbox(label="Prompt", value=EX_PROMPT)
193
+ negative = gr.Textbox(label="Negative Prompt", value=EX_NEG)
194
+ identity_strength = gr.Slider(0.2, 1.5, 0.95, 0.05, label="Fidélité visage")
195
+ adapter_strength = gr.Slider(0.2, 1.5, 0.85, 0.05, label="Détails anime")
196
+ steps = gr.Slider(10, 60, 30, 1, label="Steps")
197
+ cfg = gr.Slider(0.1, 12.0, 5.5, 0.1, label="CFG")
198
+ width = gr.Dropdown(choices=[576, 640, 704, 768, 896], value=704, label="Largeur")
199
+ height = gr.Dropdown(choices=[704, 768, 896, 1024], value=896, label="Hauteur")
200
+ seed = gr.Number(value=-1, label="Seed (-1 aléatoire)")
201
+ btn = gr.Button("🎨 Générer", variant="primary")
202
+
203
+ with gr.Column():
204
+ out_image = gr.Image(label="Résultat", interactive=False)
205
+ err_box = gr.Textbox(label="Erreurs", visible=False)
206
+ log_box = gr.Textbox(label="Logs", value="\n".join(load_logs), lines=10)
207
+
208
+ def wrap(*args):
209
+ img, err, logs = generate(*args)
210
+ return img, gr.update(visible=bool(err), value=err), gr.update(value=logs)
211
+
212
+ btn.click(
213
+ wrap,
214
+ inputs=[face_image, prompt, negative, identity_strength, adapter_strength, steps, cfg, width, height, seed],
215
+ outputs=[out_image, err_box, log_box],
216
+ )
217
+
218
+ demo.queue()
219
+ if __name__ == "__main__":
220
+ demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)