Nad54 commited on
Commit
3a7fbf1
·
verified ·
1 Parent(s): 0102606

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -88
app.py CHANGED
@@ -1,6 +1,9 @@
1
- # app.py — InstantID (SDXL) résilient: essaie plusieurs chemins de pipeline dans InstantX/InstantID
 
 
2
  import os, traceback, importlib.util
3
 
 
4
  os.environ.setdefault("OMP_NUM_THREADS", "4")
5
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
6
 
@@ -9,95 +12,86 @@ from PIL import Image, ImageOps
9
  from huggingface_hub import hf_hub_download, HfHubHTTPError
10
  from diffusers.models import ControlNetModel
11
 
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- dtype = torch.float16 if device == "cuda" else torch.float32
14
 
15
- REPO = "InstantX/InstantID"
 
16
 
17
- # --- candidats de fichiers possibles dans le repo (les noms varient selon les versions)
18
  PIPE_CANDIDATES = [
19
- "pipeline_stable_diffusion_xl_instantid_full.py",
20
- "pipeline_stable_diffusion_xl_instantid.py",
21
- "pipelines/pipeline_stable_diffusion_xl_instantid_full.py",
22
- "pipelines/pipeline_stable_diffusion_xl_instantid.py",
23
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
25
  load_logs = []
 
 
 
26
 
27
- def _download_first_existing(repo_id: str, candidates: list[str], local_dir: str) -> str | None:
28
- for fname in candidates:
29
- try:
30
- p = hf_hub_download(repo_id=repo_id, filename=fname, local_dir=local_dir)
31
- load_logs.append(f"✅ Pipeline trouvée: {fname}")
32
- return p
33
- except HfHubHTTPError as e:
34
- load_logs.append(f"… {fname} introuvable ({e.__class__.__name__})")
35
- except Exception as e:
36
- load_logs.append(f"… {fname} erreur: {e}")
37
- return None
38
-
39
- # 1) Télécharger la pipeline InstantID (un des fichiers ci-dessus)
40
- PIPE_LOCAL = _download_first_existing(REPO, PIPE_CANDIDATES, "./instantid")
41
- if PIPE_LOCAL is None:
42
- # Abort propre avec aide utilisateur
43
- msg = (
44
- "Aucun fichier pipeline *.py trouvé dans InstantX/InstantID.\n"
45
- "Solutions:\n"
46
- " - Uploade manuellement dans /instantid/ un fichier pipeline nommé, par ex:\n"
47
- " pipeline_stable_diffusion_xl_instantid_full.py\n"
48
- " (copie-le depuis le repo InstantX/InstantID)\n"
49
- " - Ou change la variable PIPE_CANDIDATES pour matcher le nom exact dans le repo.\n"
50
- )
51
- raise RuntimeError(msg + "\n" + "\n".join(load_logs))
52
 
53
- # 2) Télécharger les poids IdentityNet (ControlNet) + ip-adapter.bin
54
- try:
55
- cn_cfg = hf_hub_download(repo_id=REPO, filename="ControlNetModel/config.json", local_dir="./checkpoints")
56
- cn_dir = os.path.dirname(cn_cfg) # ./checkpoints/ControlNetModel
57
- hf_hub_download(repo_id=REPO, filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
58
- ip_adapter_path = hf_hub_download(repo_id=REPO, filename="ip-adapter.bin", local_dir="./checkpoints")
59
- load_logs.append("✅ IdentityNet + ip-adapter téléchargés.")
60
- except Exception as e:
61
- raise RuntimeError(f"Echec téléchargement IdentityNet/ip-adapter: {e}")
62
-
63
- # 3) Import dynamique de la pipeline
64
- import importlib.util
65
- spec = importlib.util.spec_from_file_location("instantid_pipeline", PIPE_LOCAL)
66
- mod = importlib.util.module_from_spec(spec)
67
- spec.loader.exec_module(mod)
68
-
69
- # Ces symboles existent dans les implémentations InstantID SDXL
70
- StableDiffusionXLInstantIDPipeline = getattr(mod, "StableDiffusionXLInstantIDPipeline", None)
71
- draw_kps = getattr(mod, "draw_kps", None)
72
- if StableDiffusionXLInstantIDPipeline is None or draw_kps is None:
73
- raise RuntimeError("La pipeline importée ne contient pas StableDiffusionXLInstantIDPipeline/draw_kps.")
74
-
75
- # 4) Modèle de base SDXL — prends un SDXL stylé anime ou le base officiel
76
- BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
77
- # Astuce: si tu as un SDXL anime (ex: YamerMIX), mets-le ici pour rendu plus manga:
78
- # BASE_MODEL = "wangqixun/YamerMIX_v8"
79
-
80
- # 5) Charger IdentityNet + pipeline
81
- try:
82
  load_logs.append("Chargement ControlNet IdentityNet…")
83
- controlnet_identitynet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype)
84
 
85
  load_logs.append(f"Chargement pipeline InstantID (base={BASE_MODEL})…")
86
- pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
87
  BASE_MODEL,
88
  controlnet=[controlnet_identitynet],
89
- torch_dtype=dtype,
90
  safety_checker=None,
91
  feature_extractor=None,
92
- ).to(device)
93
 
94
- # ip-adapter d’InstantID
95
  if hasattr(pipe, "load_ip_adapter_instantid"):
96
- pipe.load_ip_adapter_instantid(ip_adapter_path)
97
  else:
98
- raise RuntimeError("La méthode pipe.load_ip_adapter_instantid est absente dans cette pipeline.")
99
 
100
- if device == "cuda":
101
  if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
102
  if hasattr(pipe, "unet"): pipe.unet.to("cuda")
103
 
@@ -107,36 +101,35 @@ except Exception:
107
  pipe = None
108
 
109
  if pipe is None:
110
- raise RuntimeError("Échec de chargement de la pipeline InstantID.\n" + "\n".join(load_logs))
111
 
112
- # 6) InsightFace pour landmarks
113
  from insightface.app import FaceAnalysis
114
  fa = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
115
  fa.prepare(ctx_id=0, det_size=(640, 640))
116
 
117
- def extract_face_info(pil_img: Image.Image):
118
  import numpy as np, cv2
119
  img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
120
  faces = fa.get(img_cv2)
121
  if not faces:
122
- raise ValueError("Aucun visage détecté. Utilise un portrait net (buste, visage bien centré).")
123
- face = faces[-1]
124
- kps_image = draw_kps(pil_img, face["kps"])
125
- return face, kps_image
126
 
127
- # 7) Inference
128
  def generate(face_image, prompt, negative_prompt, identity_strength, adapter_strength, steps, cfg, width, height, seed):
129
  try:
130
  if face_image is None:
131
  return None, "Merci d'ajouter une photo visage.", "\n".join(load_logs)
132
 
133
- gen = None if seed is None or int(seed) < 0 else torch.Generator(device=device).manual_seed(int(seed))
134
 
135
  face = ImageOps.exif_transpose(face_image).convert("RGB")
136
  ms = min(face.size); x=(face.width-ms)//2; y=(face.height-ms)//2
137
  face_sq = face.crop((x, y, x+ms, y+ms)).resize((512,512), Image.Resampling.LANCZOS)
138
 
139
- face_info, face_kps = extract_face_info(face_sq)
140
 
141
  if hasattr(pipe, "set_ip_adapter_scale"):
142
  pipe.set_ip_adapter_scale(float(adapter_strength))
@@ -144,7 +137,7 @@ def generate(face_image, prompt, negative_prompt, identity_strength, adapter_str
144
  images = pipe(
145
  prompt=prompt.strip(),
146
  negative_prompt=(negative_prompt or "").strip(),
147
- image=face_kps,
148
  controlnet_conditioning_scale=float(identity_strength),
149
  num_inference_steps=int(steps),
150
  guidance_scale=float(cfg),
@@ -154,14 +147,13 @@ def generate(face_image, prompt, negative_prompt, identity_strength, adapter_str
154
  ).images
155
 
156
  return images[0], "", "\n".join(load_logs)
157
-
158
  except torch.cuda.OutOfMemoryError as oom:
159
- msg = "CUDA OOM: baisse résolution (ex: 640×768 → 576×704), steps 24–28, CFG 5–7."
160
  return None, f"{msg}\n{oom}", "\n".join(load_logs)
161
  except Exception:
162
  return None, "Erreur:\n"+traceback.format_exc(), "\n".join(load_logs)
163
 
164
- # 8) UI
165
  EX_PROMPT = (
166
  "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, straw hat, "
167
  "clean lineart, cel shading, vibrant colors, expressive eyes, symmetrical face, looking at camera, "
@@ -173,7 +165,7 @@ EX_NEG = (
173
  )
174
 
175
  with gr.Blocks(css="footer{display:none !important}") as demo:
176
- gr.Markdown("# 🏴‍☠️ One Piece — InstantID (SDXL)")
177
 
178
  with gr.Row():
179
  with gr.Column():
@@ -181,7 +173,7 @@ with gr.Blocks(css="footer{display:none !important}") as demo:
181
  prompt = gr.Textbox(label="Prompt", value=EX_PROMPT, lines=3)
182
  negative = gr.Textbox(label="Negative Prompt", value=EX_NEG, lines=3)
183
 
184
- identity_strength = gr.Slider(0.2, 1.5, value=0.90, step=0.05, label="IdentityNet strength (fidélité)")
185
  adapter_strength = gr.Slider(0.2, 1.5, value=0.85, step=0.05, label="Adapter strength (détails)")
186
  steps = gr.Slider(10, 60, value=30, step=1, label="Steps")
187
  cfg = gr.Slider(0.1, 12.0, value=5.5, step=0.1, label="CFG")
@@ -208,4 +200,3 @@ with gr.Blocks(css="footer{display:none !important}") as demo:
208
  demo.queue()
209
  if __name__ == "__main__":
210
  demo.launch(ssr_mode=False, server_name="0.0.0.0", server_port=7860)
211
-
 
1
+ # app.py — InstantID SDXL (Option 1: téléchargements auto des poids depuis un repo Model)
2
+ # - Tu uploades localement SEULEMENT la pipeline .py (texte) dans ./instantid/
3
+ # - Les poids (safetensors/bin) sont téléchargés au runtime depuis InstantX/InstantID (repo Model)
4
  import os, traceback, importlib.util
5
 
6
+ # Eviter l'erreur libgomp
7
  os.environ.setdefault("OMP_NUM_THREADS", "4")
8
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
9
 
 
12
  from huggingface_hub import hf_hub_download, HfHubHTTPError
13
  from diffusers.models import ControlNetModel
14
 
15
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
17
 
18
+ # -------- Références Hub (repo MODEL public qui contient les poids) --------
19
+ ASSETS_REPO = "InstantX/InstantID" # tu peux le remplacer par ton propre repo MODEL si besoin
20
 
21
+ # -------- Chemins locaux attendus dans le Space --------
22
  PIPE_CANDIDATES = [
23
+ "./instantid/pipeline_stable_diffusion_xl_instantid.py",
24
+ "./instantid/pipeline_stable_diffusion_xl_instantid_full.py",
 
 
25
  ]
26
+ CHECKPOINTS_DIR = "./checkpoints"
27
+ CN_LOCAL_DIR = os.path.join(CHECKPOINTS_DIR, "ControlNetModel")
28
+ IP_ADAPTER_LOCAL = os.path.join(CHECKPOINTS_DIR, "ip-adapter.bin")
29
+
30
+ # -------- Utilitaires --------
31
+ def import_pipeline_or_fail():
32
+ pipeline_file = next((p for p in PIPE_CANDIDATES if os.path.exists(p)), None)
33
+ if pipeline_file is None:
34
+ raise RuntimeError(
35
+ "Pipeline InstantID introuvable.\n"
36
+ "➡️ Uploade l’un de ces fichiers (texte) dans ton Space :\n"
37
+ " - instantid/pipeline_stable_diffusion_xl_instantid.py\n"
38
+ " - instantid/pipeline_stable_diffusion_xl_instantid_full.py\n"
39
+ "Les poids seront téléchargés automatiquement."
40
+ )
41
+ spec = importlib.util.spec_from_file_location("instantid_pipeline", pipeline_file)
42
+ mod = importlib.util.module_from_spec(spec)
43
+ spec.loader.exec_module(mod)
44
+ SDXLInstantID = getattr(mod, "StableDiffusionXLInstantIDPipeline", None)
45
+ draw_kps = getattr(mod, "draw_kps", None)
46
+ if SDXLInstantID is None or draw_kps is None:
47
+ raise RuntimeError("Le fichier pipeline ne contient pas StableDiffusionXLInstantIDPipeline/draw_kps.")
48
+ return SDXLInstantID, draw_kps
49
+
50
+ def ensure_assets_or_download():
51
+ os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
52
+ os.makedirs(CN_LOCAL_DIR, exist_ok=True)
53
+ # Télécharge/valide ControlNet IdentityNet
54
+ try:
55
+ if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "config.json")):
56
+ hf_hub_download(ASSETS_REPO, "ControlNetModel/config.json", local_dir=CHECKPOINTS_DIR)
57
+ if not os.path.isfile(os.path.join(CN_LOCAL_DIR, "diffusion_pytorch_model.safetensors")):
58
+ hf_hub_download(ASSETS_REPO, "ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINTS_DIR)
59
+ except HfHubHTTPError as e:
60
+ raise RuntimeError(f"Echec téléchargement IdentityNet depuis {ASSETS_REPO} : {e}")
61
+
62
+ # Télécharge/valide ip-adapter.bin
63
+ try:
64
+ if not os.path.isfile(IP_ADAPTER_LOCAL):
65
+ hf_hub_download(ASSETS_REPO, "ip-adapter.bin", local_dir=CHECKPOINTS_DIR)
66
+ except HfHubHTTPError as e:
67
+ raise RuntimeError(f"Echec téléchargement ip-adapter.bin depuis {ASSETS_REPO} : {e}")
68
 
69
+ # -------- Chargement pipeline --------
70
  load_logs = []
71
+ try:
72
+ SDXLInstantID, draw_kps = import_pipeline_or_fail()
73
+ ensure_assets_or_download()
74
 
75
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # remplaçable par un SDXL style anime si tu en as
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  load_logs.append("Chargement ControlNet IdentityNet…")
78
+ controlnet_identitynet = ControlNetModel.from_pretrained(CN_LOCAL_DIR, torch_dtype=DTYPE)
79
 
80
  load_logs.append(f"Chargement pipeline InstantID (base={BASE_MODEL})…")
81
+ pipe = SDXLInstantID.from_pretrained(
82
  BASE_MODEL,
83
  controlnet=[controlnet_identitynet],
84
+ torch_dtype=DTYPE,
85
  safety_checker=None,
86
  feature_extractor=None,
87
+ ).to(DEVICE)
88
 
 
89
  if hasattr(pipe, "load_ip_adapter_instantid"):
90
+ pipe.load_ip_adapter_instantid(IP_ADAPTER_LOCAL)
91
  else:
92
+ raise RuntimeError("La méthode load_ip_adapter_instantid est absente de cette pipeline.")
93
 
94
+ if DEVICE == "cuda":
95
  if hasattr(pipe, "image_proj_model"): pipe.image_proj_model.to("cuda")
96
  if hasattr(pipe, "unet"): pipe.unet.to("cuda")
97
 
 
101
  pipe = None
102
 
103
  if pipe is None:
104
+ raise RuntimeError("Échec chargement pipeline InstantID.\n" + "\n".join(load_logs))
105
 
106
+ # -------- InsightFace pour landmarks --------
107
  from insightface.app import FaceAnalysis
108
  fa = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
109
  fa.prepare(ctx_id=0, det_size=(640, 640))
110
 
111
+ def extract_kps_image(pil_img: Image.Image):
112
  import numpy as np, cv2
113
  img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
114
  faces = fa.get(img_cv2)
115
  if not faces:
116
+ raise ValueError("Aucun visage détecté. Utilise un portrait net (visage centré).")
117
+ face = faces[-1] # visage principal
118
+ return draw_kps(pil_img, face["kps"])
 
119
 
120
+ # -------- Inference --------
121
  def generate(face_image, prompt, negative_prompt, identity_strength, adapter_strength, steps, cfg, width, height, seed):
122
  try:
123
  if face_image is None:
124
  return None, "Merci d'ajouter une photo visage.", "\n".join(load_logs)
125
 
126
+ gen = None if seed is None or int(seed) < 0 else torch.Generator(device=DEVICE).manual_seed(int(seed))
127
 
128
  face = ImageOps.exif_transpose(face_image).convert("RGB")
129
  ms = min(face.size); x=(face.width-ms)//2; y=(face.height-ms)//2
130
  face_sq = face.crop((x, y, x+ms, y+ms)).resize((512,512), Image.Resampling.LANCZOS)
131
 
132
+ kps_img = extract_kps_image(face_sq)
133
 
134
  if hasattr(pipe, "set_ip_adapter_scale"):
135
  pipe.set_ip_adapter_scale(float(adapter_strength))
 
137
  images = pipe(
138
  prompt=prompt.strip(),
139
  negative_prompt=(negative_prompt or "").strip(),
140
+ image=kps_img, # landmarks pour IdentityNet
141
  controlnet_conditioning_scale=float(identity_strength),
142
  num_inference_steps=int(steps),
143
  guidance_scale=float(cfg),
 
147
  ).images
148
 
149
  return images[0], "", "\n".join(load_logs)
 
150
  except torch.cuda.OutOfMemoryError as oom:
151
+ msg = "CUDA OOM: baisse la résolution (ex: 640×768 → 576×704), steps 24–28, CFG 5–6."
152
  return None, f"{msg}\n{oom}", "\n".join(load_logs)
153
  except Exception:
154
  return None, "Erreur:\n"+traceback.format_exc(), "\n".join(load_logs)
155
 
156
+ # -------- UI --------
157
  EX_PROMPT = (
158
  "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, straw hat, "
159
  "clean lineart, cel shading, vibrant colors, expressive eyes, symmetrical face, looking at camera, "
 
165
  )
166
 
167
  with gr.Blocks(css="footer{display:none !important}") as demo:
168
+ gr.Markdown("# 🏴‍☠️ One Piece — InstantID (SDXL) — Poids auto depuis HF")
169
 
170
  with gr.Row():
171
  with gr.Column():
 
173
  prompt = gr.Textbox(label="Prompt", value=EX_PROMPT, lines=3)
174
  negative = gr.Textbox(label="Negative Prompt", value=EX_NEG, lines=3)
175
 
176
+ identity_strength = gr.Slider(0.2, 1.5, value=0.95, step=0.05, label="IdentityNet strength (fidélité)")
177
  adapter_strength = gr.Slider(0.2, 1.5, value=0.85, step=0.05, label="Adapter strength (détails)")
178
  steps = gr.Slider(10, 60, value=30, step=1, label="Steps")
179
  cfg = gr.Slider(0.1, 12.0, value=5.5, step=0.1, label="CFG")
 
200
  demo.queue()
201
  if __name__ == "__main__":
202
  demo.launch(ssr_mode=False, server_name="0.0.0.0", server_port=7860)