Nad54 committed on
Commit
719b317
·
verified ·
1 Parent(s): cf0441d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -137
app.py CHANGED
@@ -1,181 +1,163 @@
1
- # app.py — InstantID (custom pipeline) + LoRA One Piece (Text-to-Image)
2
- # Compatible avec diffusers 0.29.x
3
- import os, traceback
4
 
5
- # Évite l'erreur libgomp: donne une valeur sûre si tu veux fixer OMP
6
  os.environ.setdefault("OMP_NUM_THREADS", "4")
7
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
8
 
9
  import torch, gradio as gr
10
  from PIL import Image, ImageOps
11
- from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
12
- from safetensors.torch import load_file
13
-
14
- # ============== Config ==============
15
- DEVICE = "cuda"
16
- DTYPE = torch.float16
17
-
18
- # Modèle de base: pipeline InstantID officielle (chargée comme "custom_pipeline")
19
- INSTANTID_REPO = "InstantX/InstantID"
20
- CUSTOM_PIPE = "instantid"
21
-
22
- # Ton LoRA de style One Piece (fichier local)
23
- LORA_PATH = "./wanostyle_2_offset.safetensors"
24
- LORA_NAME = "wanostyle" # nom d'adapter utilisé dans set_adapters
25
 
26
- # ============== Utils ===============
27
- def preflight():
28
- s = [f"torch: {torch.__version__}", f"cuda: {torch.cuda.is_available()}"]
29
- if torch.cuda.is_available():
30
- s += [f"gpu: {torch.cuda.get_device_name(0)}", f"cap: {torch.cuda.get_device_capability(0)}"]
31
- return "\n".join(s)
32
 
33
- def is_lora_file(path: str) -> bool:
34
- try:
35
- sd = load_file(path)
36
- return any("lora_down.weight" in k for k in sd.keys())
37
- except Exception:
38
- return False
39
 
40
- print("=== PREFLIGHT ===")
41
- print(preflight())
 
42
 
43
- # ============ Load pipeline =========
44
- pipe = None
45
  load_logs = []
46
  try:
47
- load_logs.append("Chargement InstantID (custom pipeline)…")
48
- pipe = DiffusionPipeline.from_pretrained(
49
- INSTANTID_REPO,
50
- custom_pipeline=CUSTOM_PIPE, # << clé: charge la pipeline InstantID
51
- torch_dtype=DTYPE,
52
- use_safetensors=True,
53
- safety_checker=None, # remets-le si Space public strict
54
- )
55
- # planificateur stable
56
- if hasattr(pipe, "scheduler"):
57
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
58
-
59
- # Optimisations VRAM usuelles
60
- if hasattr(pipe, "enable_attention_slicing"): pipe.enable_attention_slicing()
61
- if hasattr(pipe, "enable_vae_slicing"): pipe.enable_vae_slicing()
62
- if hasattr(pipe, "enable_vae_tiling"): pipe.enable_vae_tiling()
63
-
64
- # Charger ton LoRA de style One Piece
65
- if os.path.exists(LORA_PATH) and is_lora_file(LORA_PATH):
66
- pipe.load_lora_weights(LORA_PATH, adapter_name=LORA_NAME, use_safetensors=True)
67
- load_logs.append("✅ LoRA (One Piece) chargé.")
68
- else:
69
- load_logs.append("⚠️ LoRA introuvable ou non standard — vérifie le .safetensors.")
70
-
71
- pipe.to(DEVICE)
72
- load_logs.append("✅ Pipeline InstantID prête.")
73
-
74
  except Exception:
75
  load_logs += ["❌ ERREUR au chargement:", traceback.format_exc()]
 
76
  print("\n".join(load_logs))
77
 
78
  if pipe is None:
79
- raise RuntimeError("Échec de chargement du pipeline. Voir logs.")
80
-
81
- # ============== Inference ==============
82
- def generate(
83
- ref_face, # photo utilisateur (obligatoire)
84
- prompt,
85
- negative_prompt,
86
- id_strength=0.85, # force d'identité InstantID (0.7–0.95)
87
- lora_scale=1.05, # force du style One Piece
88
- cfg=7.0,
89
- steps=30,
90
- width=640,
91
- height=768,
92
- seed=-1,
93
- ):
94
- run_logs = []
 
 
 
 
95
  try:
96
- if ref_face is None:
97
- return None, "Merci d'ajouter ta photo (portrait/visage).", "\n".join(load_logs)
98
-
99
- run_logs.append(preflight())
100
-
101
- # Seed
102
- gen = None if seed is None or int(seed) < 0 else torch.Generator(DEVICE).manual_seed(int(seed))
103
-
104
- # Préparer visage (carré conseillé pour l'embedding)
105
- face = ImageOps.exif_transpose(ref_face).convert("RGB")
106
- min_side = min(face.size)
107
- x = (face.width - min_side)//2; y = (face.height - min_side)//2
108
- face_sq = face.crop((x, y, x+min_side, y+min_side)).resize((512, 512), Image.Resampling.LANCZOS)
109
-
110
- # Appliquer LoRA (intensité via set_adapters)
111
- ca_kwargs = None
112
- if LORA_NAME in getattr(pipe, "loaded_adapters", [LORA_NAME]):
113
- try:
114
- pipe.set_adapters([LORA_NAME], adapter_weights=[float(lora_scale)])
115
- ca_kwargs = {"scale": float(lora_scale)}
116
- run_logs.append(f"✅ LoRA actif (scale={float(lora_scale)})")
117
- except Exception as e:
118
- run_logs.append(f"⚠️ set_adapters erreur: {e}")
119
- else:
120
- run_logs.append("ℹ️ LoRA non chargé.")
121
-
122
- # Appel InstantID — la pipeline attend 'image' = référence visage
123
- # Génération T2I (image neuve)
124
- result = pipe(
125
  prompt=prompt.strip(),
126
  negative_prompt=(negative_prompt or "").strip(),
127
- image=face_sq, # <<< identité
128
- id_strength=float(id_strength), # <<< verrouillage identité
 
 
129
  width=int(width),
130
  height=int(height),
131
- guidance_scale=float(cfg),
132
- num_inference_steps=int(steps),
133
  generator=gen,
134
- cross_attention_kwargs=ca_kwargs, # LoRA intensity
135
- )
136
-
137
- return result.images[0], "", "\n".join(load_logs + run_logs)
138
 
 
139
  except torch.cuda.OutOfMemoryError as oom:
140
- msg = "CUDA OOM: baisse résolution (ex: 576×704), steps (24–28), CFG (6–7)."
141
- return None, f"{msg}\n{oom}", "\n".join(load_logs + run_logs)
142
- except Exception:
143
- return None, "Erreur:\n" + traceback.format_exc(), "\n".join(load_logs + run_logs)
144
 
145
- # ===================== UI =====================
146
  EX_PROMPT = (
147
- "one piece style, Eiichiro Oda style, anime portrait, upper body, pirate outfit, straw hat, "
148
- "cel shading, clean lineart, vibrant colors, expressive eyes, symmetrical face, looking at camera, "
149
- "dynamic lighting, detailed face, simple background, high detail"
150
  )
151
  EX_NEG = (
152
- "worst quality, low quality, bad anatomy, lowres, blurry, noisy, text, watermark, logo, cropped, "
153
- "jpeg artifacts, 3d render, photorealistic, realistic skin, bad proportions, deformed face, distorted eyes, "
154
- "cross-eye, asymmetrical eyes, extra limbs, extra fingers, fused fingers, multiple faces, mutated, signature, username, nsfw"
155
  )
156
 
157
  with gr.Blocks(css="footer{display:none !important}") as demo:
158
- gr.Markdown("# 🏴‍☠️ One Piece — InstantID (custom) + LoRA")
159
 
160
  with gr.Row():
161
  with gr.Column():
162
- ref_face = gr.Image(type="pil", label="Photo visage (référence)", value=None)
163
  prompt = gr.Textbox(label="Prompt", value=EX_PROMPT, lines=3)
164
  negative = gr.Textbox(label="Negative Prompt", value=EX_NEG, lines=3)
165
 
166
- id_strength = gr.Slider(0.5, 1.0, value=0.85, step=0.05, label="Force identité (InstantID)")
167
- lora_scale = gr.Slider(0.0, 1.5, value=1.05, step=0.05, label="Force du style One Piece")
168
- cfg = gr.Slider(1, 12, value=7.0, step=0.5, label="CFG (guidance)")
169
- steps = gr.Slider(10, 60, value=30, step=1, label="Steps")
170
- width = gr.Dropdown(choices=[512, 576, 640, 704, 768], value=640, label="Largeur")
171
- height = gr.Dropdown(choices=[640, 704, 768], value=768, label="Hauteur")
172
- seed = gr.Number(value=-1, label="Seed (-1 aléatoire)")
173
- btn = gr.Button("🎨 Générer", variant="primary")
174
 
175
  with gr.Column():
176
  out_image = gr.Image(label="Résultat", interactive=False)
177
  err_box = gr.Textbox(label="Erreurs", visible=False)
178
- log_box = gr.Textbox(label="Logs", value="\n".join(load_logs), lines=10)
179
 
180
  def wrap(*args):
181
  img, err, logs = generate(*args)
@@ -183,7 +165,7 @@ with gr.Blocks(css="footer{display:none !important}") as demo:
183
 
184
  btn.click(
185
  wrap,
186
- inputs=[ref_face, prompt, negative, id_strength, lora_scale, cfg, steps, width, height, seed],
187
  outputs=[out_image, err_box, log_box],
188
  )
189
 
 
# app.py — minimal InstantID (SDXL) with a simple Gradio UI.
# Downloads the custom pipeline from the official InstantID Space and runs
# it with IdentityNet only.
import os, traceback, importlib.util

# Work around the libgomp error; must be set before torch is imported.
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

import torch, gradio as gr
from PIL import Image, ImageOps
from huggingface_hub import hf_hub_download
from diffusers.models import ControlNetModel

# GPU + fp16 when CUDA is available, otherwise CPU + fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
16
+
17
# --------- Assets fetched from the official InstantID Space ---------
_INSTANTID_REPO = "InstantX/InstantID"


def _fetch(filename, local_dir):
    # One-liner wrapper: download a single file from the InstantID repo
    # and return its local path.
    return hf_hub_download(repo_id=_INSTANTID_REPO, filename=filename, local_dir=local_dir)


# 1) The custom pipeline source file (.py).
PIPE_FILENAME = "pipeline_stable_diffusion_xl_instantid_full.py"
local_pipeline_path = _fetch(PIPE_FILENAME, "./instantid")

# 2) IdentityNet ControlNet weights + the IP-Adapter checkpoint.
_cn_config_path = _fetch("ControlNetModel/config.json", "./checkpoints")
cn_dir = os.path.dirname(_cn_config_path)  # ./checkpoints/ControlNetModel
_fetch("ControlNetModel/diffusion_pytorch_model.safetensors", "./checkpoints")
ip_adapter_path = _fetch("ip-adapter.bin", "./checkpoints")
 
31
 
32
# --------- Dynamically import the pipeline class from the downloaded file ---------
_spec = importlib.util.spec_from_file_location("instantid_pipeline", local_pipeline_path)
_module = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_module)

StableDiffusionXLInstantIDPipeline = _module.StableDiffusionXLInstantIDPipeline
# draw_kps renders the 5 facial keypoints into the control image used by IdentityNet.
draw_kps = _module.draw_kps

# --------- SDXL base model (more stylised than vanilla SDXL for anime) ---------
# Swap in "stabilityai/stable-diffusion-xl-base-1.0" for a neutral look.
BASE_MODEL = "wangqixun/YamerMIX_v8"
42
 
43
# --------- Build the InstantID pipeline (IdentityNet only) ---------
load_logs = []


def _build_pipeline():
    # Load the IdentityNet ControlNet, assemble the SDXL InstantID pipeline,
    # attach the IP-Adapter weights and return the ready-to-use pipe.
    load_logs.append("Chargement ControlNet IdentityNet…")
    identitynet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype)

    load_logs.append(f"Chargement pipeline InstantID (base={BASE_MODEL})…")
    p = StableDiffusionXLInstantIDPipeline.from_pretrained(
        BASE_MODEL,
        controlnet=[identitynet],  # IdentityNet only, no extra ControlNets
        torch_dtype=dtype,
        safety_checker=None,
        feature_extractor=None,
    ).to(device)

    # InstantID adapter weights (ip-adapter.bin).
    p.load_ip_adapter_instantid(ip_adapter_path)
    if device == "cuda":
        p.image_proj_model.to("cuda")
        p.unet.to("cuda")

    load_logs.append("✅ InstantID prêt.")
    return p


try:
    pipe = _build_pipeline()
except Exception:
    load_logs += ["❌ ERREUR au chargement:", traceback.format_exc()]
    pipe = None

print("\n".join(load_logs))

if pipe is None:
    raise RuntimeError("Échec de chargement du pipeline. Voir logs container.")
72
+
73
# --------- Face encoder (InsightFace): landmarks + identity embedding ---------
from insightface.app import FaceAnalysis

# antelopev2 bundles detection + recognition; CPU provider is enough here
# since only the diffusion model needs the GPU.
fa = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
fa.prepare(ctx_id=0, det_size=(640, 640))
77
+
78
def extract_face_info(pil_img: Image.Image):
    """Detect a face in *pil_img* and return ``(face_info, kps_image)``.

    ``face_info`` is the InsightFace record (embedding, bbox, kps, …) of the
    largest detected face; ``kps_image`` is the 5-point landmark drawing used
    as the IdentityNet control image.

    Raises:
        ValueError: when no face is detected.
    """
    import numpy as np, cv2

    # InsightFace expects BGR uint8 arrays.
    img_cv2 = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    faces = fa.get(img_cv2)
    if not faces:
        raise ValueError("Aucun visage détecté. Utilise une photo portrait nette.")

    # Bug fix: the comment promised the largest face but the code took
    # faces[-1] (merely the last detection). Select by bounding-box area,
    # as the official InstantID demo does.
    face = max(
        faces,
        key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]),
    )
    kps_image = draw_kps(pil_img, face["kps"])
    return face, kps_image
88
+
89
# --------- Inference ---------
def generate(face_image, prompt, negative_prompt, identity_strength, adapter_strength, steps, cfg, width, height, seed):
    """Run one InstantID generation.

    Returns a ``(image_or_None, error_message, log_text)`` triple so the UI
    can always display something meaningful instead of crashing.
    """
    try:
        if face_image is None:
            return None, "Merci d'ajouter une photo visage.", "\n".join(load_logs)

        # Negative seed means "random": pass no generator at all.
        gen = None if seed is None or int(seed) < 0 else torch.Generator(device=device).manual_seed(int(seed))

        # Center-square crop + 512x512 resize: more stable detection/embedding.
        face = ImageOps.exif_transpose(face_image).convert("RGB")
        ms = min(face.size)
        x = (face.width - ms) // 2
        y = (face.height - ms) // 2
        face_cropped = face.crop((x, y, x + ms, y + ms)).resize((512, 512), Image.Resampling.LANCZOS)

        # Landmarks (spatial control) + identity embedding from InsightFace.
        face_info, face_kps = extract_face_info(face_cropped)

        pipe.set_ip_adapter_scale(float(adapter_strength))

        images = pipe(
            prompt=prompt.strip(),
            negative_prompt=(negative_prompt or "").strip(),
            # Bug fix: the embedding was computed but never handed to the
            # pipeline, so identity conditioning silently never happened.
            # The InstantID pipeline takes it via `image_embeds`.
            image_embeds=face_info["embedding"],
            image=face_kps,  # landmark control image for IdentityNet
            controlnet_conditioning_scale=float(identity_strength),
            num_inference_steps=int(steps),
            guidance_scale=float(cfg),
            width=int(width),
            height=int(height),
            generator=gen,
        ).images

        return images[0], "", "\n".join(load_logs)
    except torch.cuda.OutOfMemoryError as oom:
        msg = "CUDA OOM: baisse la résolution (ex: 704×896 → 576×704), steps 24–28, CFG 5–7."
        return None, f"{msg}\n{oom}", "\n".join(load_logs)
    except Exception:
        return None, "Erreur:\n" + traceback.format_exc(), "\n".join(load_logs)
127
 
128
# --------- Default example prompts shown in the UI ---------
EX_PROMPT = (
    "one piece style, Eiichiro Oda style, anime portrait, upper body, straw hat, pirate outfit, "
    "clean lineart, cel shading, vibrant colors, expressive eyes, symmetrical face, looking at camera, "
    "dynamic lighting, simple background, high detail"
)
EX_NEG = (
    "low quality, worst quality, lowres, blurry, noisy, watermark, text, logo, jpeg artifacts, "
    "bad anatomy, distorted eyes, cross-eye, asymmetrical eyes, deformed, multiple faces, nsfw"
)
138
 
139
  with gr.Blocks(css="footer{display:none !important}") as demo:
140
+ gr.Markdown("# 🏴‍☠️ One Piece — InstantID (SDXL)")
141
 
142
  with gr.Row():
143
  with gr.Column():
144
+ face_image = gr.Image(type="pil", label="Photo visage", height=360)
145
  prompt = gr.Textbox(label="Prompt", value=EX_PROMPT, lines=3)
146
  negative = gr.Textbox(label="Negative Prompt", value=EX_NEG, lines=3)
147
 
148
+ identity_strength = gr.Slider(0.2, 1.5, value=0.85, step=0.05, label="IdentityNet strength (fidélité)")
149
+ adapter_strength = gr.Slider(0.2, 1.5, value=0.85, step=0.05, label="Adapter strength (détails)")
150
+ steps = gr.Slider(10, 60, value=30, step=1, label="Steps")
151
+ cfg = gr.Slider(0.1, 12.0, value=5.0, step=0.1, label="CFG")
152
+ width = gr.Dropdown(choices=[576, 640, 704, 768, 896], value=704, label="Largeur")
153
+ height = gr.Dropdown(choices=[704, 768, 896, 1024], value=896, label="Hauteur")
154
+ seed = gr.Number(value=-1, label="Seed (-1 aléatoire)")
155
+ btn = gr.Button("🎨 Générer", variant="primary")
156
 
157
  with gr.Column():
158
  out_image = gr.Image(label="Résultat", interactive=False)
159
  err_box = gr.Textbox(label="Erreurs", visible=False)
160
+ log_box = gr.Textbox(label="Logs", value="\n".join(load_logs), lines=8)
161
 
162
  def wrap(*args):
163
  img, err, logs = generate(*args)
 
165
 
166
  btn.click(
167
  wrap,
168
+ inputs=[face_image, prompt, negative, identity_strength, adapter_strength, steps, cfg, width, height, seed],
169
  outputs=[out_image, err_box, log_box],
170
  )
171