Nad54 commited on
Commit
42e4f65
·
verified ·
1 Parent(s): 633210e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -38
app.py CHANGED
@@ -1,78 +1,96 @@
1
- # app.py — InstantID SDXL (Option 1, robuste aux erreurs OMP + hub)
2
- # - Tu uploades SEULEMENT la pipeline .py (texte) dans ./instantid/
3
- # - Les poids sont téléchargés auto depuis InstantX/InstantID (repo Model)
4
-
5
  import os, traceback, importlib.util
6
 
7
- # --- Sécuriser OMP_NUM_THREADS (éviter libgomp error) ---
8
  val = os.environ.get("OMP_NUM_THREADS", "")
9
  try:
10
  if val == "" or int(val) <= 0:
11
  os.environ["OMP_NUM_THREADS"] = "1"
12
  except Exception:
13
  os.environ["OMP_NUM_THREADS"] = "1"
14
-
15
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
16
 
17
  import torch, gradio as gr
18
- from PIL import Image, ImageOps
19
- from huggingface_hub import hf_hub_download # <- pas de HfHubHTTPError (compat large)
20
  from diffusers.models import ControlNetModel
21
 
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
24
 
25
- # -------- Repo Model qui contient les poids (public) --------
26
  ASSETS_REPO = "InstantX/InstantID"
27
 
28
- # -------- Chemins locaux attendus dans le Space --------
29
  PIPE_CANDIDATES = [
30
  "./instantid/pipeline_stable_diffusion_xl_instantid.py",
31
  "./instantid/pipeline_stable_diffusion_xl_instantid_full.py",
32
  ]
33
- CHECKPOINTS_DIR = "./checkpoints"
34
- CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
35
- IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
36
 
37
- # -------- Utilitaires --------
 
 
 
 
38
  def import_pipeline_or_fail():
39
  pipeline_file = next((p for p in PIPE_CANDIDATES if os.path.exists(p)), None)
40
  if pipeline_file is None:
41
  raise RuntimeError(
42
  "Pipeline InstantID introuvable.\n"
43
- "➡️ Uploade l’un de ces fichiers (TEXTE uniquement) dans ton Space :\n"
44
- " - instantid/pipeline_stable_diffusion_xl_instantid.py\n"
45
- " - instantid/pipeline_stable_diffusion_xl_instantid_full.py\n"
46
- "(Ne mets PAS de .safetensors dans le Space — ils seront téléchargés automatiquement)."
47
  )
48
  spec = importlib.util.spec_from_file_location("instantid_pipeline", pipeline_file)
49
  mod = importlib.util.module_from_spec(spec)
50
  spec.loader.exec_module(mod)
51
- SDXLInstantID = getattr(mod, "StableDiffusionXLInstantIDPipeline", None)
52
- draw_kps = getattr(mod, "draw_kps", None)
53
- if SDXLInstantID is None or draw_kps is None:
54
- raise RuntimeError("Le fichier pipeline ne contient pas StableDiffusionXLInstantIDPipeline/draw_kps.")
55
- return SDXLInstantID, draw_kps
56
 
 
 
 
 
 
 
 
 
 
57
  def ensure_assets_or_download():
58
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
59
  os.makedirs(CN_LOCAL_DIR, exist_ok=True)
60
- # ControlNet IdentityNet
61
  if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "config.json")):
62
  hf_hub_download(ASSETS_REPO, "ControlNetModel/config.json", local_dir=CHECKPOINTS_DIR)
63
  if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "diffusion_pytorch_model.safetensors")):
64
  hf_hub_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINTS_DIR)
65
- # ip-adapter
66
  if not os.path.isfile(IP_ADAPTER_LOCAL):
67
  hf_hub_download(ASSETS_REPO, "ip-adapter.bin", local_dir=CHECKPOINTS_DIR)
68
 
69
- # -------- Chargement pipeline --------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  load_logs = []
71
  try:
72
- SDXLInstantID, draw_kps = import_pipeline_or_fail()
73
  ensure_assets_or_download()
74
 
75
- BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # tu peux remplacer par un SDXL plus “anime
76
 
77
  load_logs.append("Chargement ControlNet IdentityNet…")
78
  controlnet_identitynet = ControlNetModel.from_pretrained(CN_LOCAL_DIR, torch_dtype=DTYPE)
@@ -86,11 +104,10 @@ try:
86
  feature_extractor=None,
87
  ).to(DEVICE)
88
 
89
- # ip-adapter d’InstantID
90
  if hasattr(pipe, "load_ip_adapter_instantid"):
91
  pipe.load_ip_adapter_instantid(IP_ADAPTER_LOCAL)
92
  else:
93
- raise RuntimeError("La méthode load_ip_adapter_instantid est absente de cette pipeline.")
94
 
95
  if DEVICE == "cuda":
96
  if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
@@ -104,21 +121,21 @@ except Exception:
104
  if pipe is None:
105
  raise RuntimeError("Échec chargement pipeline InstantID.\n" + "\n".join(load_logs))
106
 
107
- # -------- InsightFace (landmarks) --------
108
  from insightface.app import FaceAnalysis
109
  fa = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
110
  fa.prepare(ctx_id=0, det_size=(640, 640))
111
 
112
- def extract_kps_image(pil_img: Image.Image):
113
  import numpy as np, cv2
114
  img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
115
  faces = fa.get(img_cv2)
116
  if not faces:
117
- raise ValueError("Aucun visage détecté. Utilise un portrait net (visage centré).")
118
  face = faces[-1]
119
- return draw_kps(pil_img, face["kps"])
120
 
121
- # -------- Inference --------
122
  def generate(face_image, prompt, negative_prompt, identity_strength, adapter_strength, steps, cfg, width, height, seed):
123
  try:
124
  if face_image is None:
@@ -130,7 +147,7 @@ def generate(face_image, prompt, negative_prompt, identity_strength, adapter_str
130
  ms = min(face.size); x=(face.width-ms)//2; y=(face.height-ms)//2
131
  face_sq = face.crop((x, y, x+ms, y+ms)).resize((512,512), Image.Resampling.LANCZOS)
132
 
133
- kps_img = extract_kps_image(face_sq)
134
 
135
  if hasattr(pipe, "set_ip_adapter_scale"):
136
  pipe.set_ip_adapter_scale(float(adapter_strength))
@@ -154,7 +171,7 @@ def generate(face_image, prompt, negative_prompt, identity_strength, adapter_str
154
  except Exception:
155
  return None, "Erreur:\n"+traceback.format_exc(), "\n".join(load_logs)
156
 
157
- # -------- UI --------
158
  EX_PROMPT = (
159
  "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, straw hat, "
160
  "clean lineart, cel shading, vibrant colors, expressive eyes, symmetrical face, looking at camera, "
@@ -166,7 +183,7 @@ EX_NEG = (
166
  )
167
 
168
  with gr.Blocks(css="footer{display:none !important}") as demo:
169
- gr.Markdown("# 🏴‍☠️ One Piece — InstantID (SDXL) — Poids auto depuis HF")
170
 
171
  with gr.Row():
172
  with gr.Column():
 
1
+ # app.py — InstantID SDXL (import local résilient + draw_kps local)
 
 
 
2
  import os, traceback, importlib.util
3
 
4
+ # OMP robuste
5
  val = os.environ.get("OMP_NUM_THREADS", "")
6
  try:
7
  if val == "" or int(val) <= 0:
8
  os.environ["OMP_NUM_THREADS"] = "1"
9
  except Exception:
10
  os.environ["OMP_NUM_THREADS"] = "1"
 
11
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
12
 
13
  import torch, gradio as gr
14
+ from PIL import Image, ImageOps, ImageDraw
15
+ from huggingface_hub import hf_hub_download
16
  from diffusers.models import ControlNetModel
17
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
20
 
21
+ # -------- Repo MODEL pour télécharger les poids --------
22
  ASSETS_REPO = "InstantX/InstantID"
23
 
24
+ # -------- Fichiers attendus localement (TEXTE uniquement) --------
25
  PIPE_CANDIDATES = [
26
  "./instantid/pipeline_stable_diffusion_xl_instantid.py",
27
  "./instantid/pipeline_stable_diffusion_xl_instantid_full.py",
28
  ]
 
 
 
29
 
30
+ CHECKPOINTS_DIR = "./checkpoints"
31
+ CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
32
+ IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
33
+
34
+ # ---------- import dynamique tolérant ----------
35
  def import_pipeline_or_fail():
36
  pipeline_file = next((p for p in PIPE_CANDIDATES if os.path.exists(p)), None)
37
  if pipeline_file is None:
38
  raise RuntimeError(
39
  "Pipeline InstantID introuvable.\n"
40
+ "➡️ Place un des fichiers (texte) dans ./instantid/ :\n"
41
+ " - pipeline_stable_diffusion_xl_instantid.py\n"
42
+ " - pipeline_stable_diffusion_xl_instantid_full.py\n"
 
43
  )
44
  spec = importlib.util.spec_from_file_location("instantid_pipeline", pipeline_file)
45
  mod = importlib.util.module_from_spec(spec)
46
  spec.loader.exec_module(mod)
 
 
 
 
 
47
 
48
+ # Cherche une classe dont le nom contient 'InstantID' et qui expose from_pretrained
49
+ for name, obj in vars(mod).items():
50
+ if isinstance(obj, type) and "InstantID" in name and hasattr(obj, "from_pretrained"):
51
+ return obj # classe pipeline
52
+ # Si rien trouvé, affiche les classes disponibles
53
+ avail = [n for n, o in vars(mod).items() if isinstance(o, type)]
54
+ raise RuntimeError("Aucune classe pipeline 'InstantID' trouvée. Classes dispo: " + ", ".join(avail))
55
+
56
+ # ---------- téléchargement des poids depuis le repo MODEL ----------
57
  def ensure_assets_or_download():
58
  os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
59
  os.makedirs(CN_LOCAL_DIR, exist_ok=True)
 
60
  if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "config.json")):
61
  hf_hub_download(ASSETS_REPO, "ControlNetModel/config.json", local_dir=CHECKPOINTS_DIR)
62
  if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "diffusion_pytorch_model.safetensors")):
63
  hf_hub_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINTS_DIR)
 
64
  if not os.path.isfile(IP_ADAPTER_LOCAL):
65
  hf_hub_download(ASSETS_REPO, "ip-adapter.bin", local_dir=CHECKPOINTS_DIR)
66
 
67
+ # ---------- draw_kps local (remplace la dépendance du fichier) ----------
68
+ def draw_kps_local(img_pil: Image.Image, kps):
69
+ """
70
+ kps: ndarray shape (5,2) insightface 'antelopev2' (yeux G/D, nez, bouche G/D).
71
+ On dessine de petits ronds noirs sur fond blanc (format attendu par IdentityNet).
72
+ """
73
+ w, h = img_pil.size
74
+ out = Image.new("RGB", (w, h), "white")
75
+ d = ImageDraw.Draw(out)
76
+ r = max(2, min(w, h) // 100) # rayon adaptatif
77
+ try:
78
+ import numpy as np
79
+ pts = kps if isinstance(kps, (list, tuple)) else np.array(kps)
80
+ except Exception:
81
+ pts = kps
82
+ for p in pts:
83
+ x, y = float(p[0]), float(p[1])
84
+ d.ellipse((x - r, y - r, x + r, y + r), fill="black")
85
+ return out
86
+
87
+ # ---------- Chargement pipeline ----------
88
  load_logs = []
89
  try:
90
+ SDXLInstantID = import_pipeline_or_fail()
91
  ensure_assets_or_download()
92
 
93
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # changeable vers un SDXL anime
94
 
95
  load_logs.append("Chargement ControlNet IdentityNet…")
96
  controlnet_identitynet = ControlNetModel.from_pretrained(CN_LOCAL_DIR, torch_dtype=DTYPE)
 
104
  feature_extractor=None,
105
  ).to(DEVICE)
106
 
 
107
  if hasattr(pipe, "load_ip_adapter_instantid"):
108
  pipe.load_ip_adapter_instantid(IP_ADAPTER_LOCAL)
109
  else:
110
+ raise RuntimeError("Cette pipeline ne fournit pas load_ip_adapter_instantid().")
111
 
112
  if DEVICE == "cuda":
113
  if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
 
121
  if pipe is None:
122
  raise RuntimeError("Échec chargement pipeline InstantID.\n" + "\n".join(load_logs))
123
 
124
+ # ---------- InsightFace pour landmarks ----------
125
  from insightface.app import FaceAnalysis
126
  fa = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
127
  fa.prepare(ctx_id=0, det_size=(640, 640))
128
 
129
+ def kps_image_from_face(pil_img: Image.Image):
130
  import numpy as np, cv2
131
  img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
132
  faces = fa.get(img_cv2)
133
  if not faces:
134
+ raise ValueError("Aucun visage détecté. Utilise une photo portrait nette (visage centré).")
135
  face = faces[-1]
136
+ return draw_kps_local(pil_img, face["kps"])
137
 
138
+ # ---------- Inference ----------
139
  def generate(face_image, prompt, negative_prompt, identity_strength, adapter_strength, steps, cfg, width, height, seed):
140
  try:
141
  if face_image is None:
 
147
  ms = min(face.size); x=(face.width-ms)//2; y=(face.height-ms)//2
148
  face_sq = face.crop((x, y, x+ms, y+ms)).resize((512,512), Image.Resampling.LANCZOS)
149
 
150
+ kps_img = kps_image_from_face(face_sq)
151
 
152
  if hasattr(pipe, "set_ip_adapter_scale"):
153
  pipe.set_ip_adapter_scale(float(adapter_strength))
 
171
  except Exception:
172
  return None, "Erreur:\n"+traceback.format_exc(), "\n".join(load_logs)
173
 
174
+ # ---------- UI ----------
175
  EX_PROMPT = (
176
  "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, straw hat, "
177
  "clean lineart, cel shading, vibrant colors, expressive eyes, symmetrical face, looking at camera, "
 
183
  )
184
 
185
  with gr.Blocks(css="footer{display:none !important}") as demo:
186
+ gr.Markdown("# 🏴‍☠️ One Piece — InstantID (SDXL) — Import local robuste")
187
 
188
  with gr.Row():
189
  with gr.Column():