Nad54 committed on
Commit
0afc625
·
verified ·
1 Parent(s): a32b4dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -97
app.py CHANGED
@@ -16,77 +16,83 @@ from transformers import AutoModelForImageSegmentation
16
  from torchvision import transforms
17
  from pipeline import InstantCharacterFluxPipeline
18
 
19
- # =========================
20
- # CONFIG
21
- # =========================
22
  MAX_SEED = np.iinfo(np.int32).max
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- dtype = torch.float16 # L4: FP16 OK
25
-
26
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 
27
  def need_token_guard():
28
  if HF_TOKEN is None:
29
  raise gr.Error(
30
  "⚠️ Token manquant : ajoute un secret 'HF_TOKEN' (Settings → Repository secrets) "
31
- "avec accès à black-forest-labs/FLUX.1-dev."
32
  )
33
 
34
- # =========================
35
- # PATHS / WEIGHTS
36
- # =========================
37
  base_model = "black-forest-labs/FLUX.1-dev"
38
  image_encoder_path = "google/siglip-so400m-patch14-384"
39
- # 🔻 On supprime l'encodeur 2 (DINOv2-giant) pour sauver ~10 Go RAM
40
- image_encoder_2_path = None
41
  birefnet_path = "ZhengPeng7/BiRefNet"
42
 
43
- # Ton LoRA One Piece (local)
44
- onepiece_flux_lora_path = "./onepiece_flux_v2.safetensors"
45
- onepiece_flux_trigger = "onepiece style"
46
-
47
  def _dl(repo_id, filename, token=None):
48
  return hf_hub_download(repo_id=repo_id, filename=filename, token=token)
49
 
50
  need_token_guard()
51
- # Uniquement l'IP-Adapter (nécessaire à l'identité) — 5.6 Go
52
  ip_adapter_path = _dl("tencent/InstantCharacter", "instantcharacter_ip-adapter.bin", HF_TOKEN)
53
 
54
- # =========================
55
- # PIPELINE (GPU only, low RAM peak)
56
- # =========================
 
 
 
 
 
 
 
 
 
57
  pipe = InstantCharacterFluxPipeline.from_pretrained(
58
- base_model,
59
- torch_dtype=dtype,
60
- token=HF_TOKEN,
61
- low_cpu_mem_usage=True, # ↓ pic RAM à l'init
62
  )
63
- pipe.to(device)
64
 
65
  try:
66
  pipe.enable_xformers_memory_efficient_attention()
67
  except Exception:
68
  pass
 
 
 
 
 
 
 
69
 
70
  pipe.set_progress_bar_config(disable=True)
71
  if hasattr(pipe, "vae"):
72
- if hasattr(pipe.vae, "enable_slicing"): pipe.vae.enable_slicing()
73
- if hasattr(pipe.vae, "enable_tiling"): pipe.vae.enable_tiling()
 
 
74
 
75
- # 🔻 Init de l'adapter: 1 seul image encoder (SigLIP) + moins de tokens
76
- adapter_kwargs = dict(
77
  image_encoder_path=image_encoder_path,
78
- subject_ipadapter_cfg=dict(subject_ip_adapter_path=ip_adapter_path, nb_token=512), # 1024 -> 512 pour baisser mémoire
 
79
  )
80
- # N'ajoute image_encoder_2_path que s'il existe
81
- if image_encoder_2_path:
82
- adapter_kwargs["image_encoder_2_path"] = image_encoder_2_path
83
 
84
- pipe.init_adapter(**adapter_kwargs)
 
 
 
 
 
 
85
 
86
- # =========================
87
- # BiRefNet : lazy-load sur CPU
88
- # =========================
89
- birefnet = None
90
  birefnet_transform_image = transforms.Compose([
91
  transforms.Resize((1024, 1024)),
92
  transforms.ToTensor(),
@@ -94,13 +100,6 @@ birefnet_transform_image = transforms.Compose([
94
  ])
95
 
96
  def remove_bkg(subject_image):
97
- global birefnet
98
- if birefnet is None:
99
- birefnet = AutoModelForImageSegmentation.from_pretrained(
100
- birefnet_path, trust_remote_code=True, token=HF_TOKEN
101
- )
102
- birefnet.to("cpu").eval()
103
-
104
  def infer_matting(img_pil):
105
  imgs = birefnet_transform_image(img_pil).unsqueeze(0).to("cpu")
106
  with torch.no_grad():
@@ -122,8 +121,10 @@ def remove_bkg(subject_image):
122
  if H == W: return image
123
  pad = abs(H - W)
124
  pad1, pad2 = pad // 2, pad - pad // 2
125
- if H > W: pad_param = ((0,0),(pad1,pad2),(0,0))
126
- else: pad_param = ((pad1,pad2),(0,0),(0,0))
 
 
127
  return np.pad(image, pad_param, "constant", constant_values=pad_value)
128
 
129
  mask = infer_matting(subject_image)[..., 0]
@@ -137,15 +138,15 @@ def remove_bkg(subject_image):
137
  crop = pad_to_square(crop)
138
  return Image.fromarray(crop.astype(np.uint8))
139
 
140
- # =========================
141
- # UTILS
142
- # =========================
143
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
144
  return random.randint(0, MAX_SEED) if randomize_seed else seed
145
 
146
- # =========================
147
- # INFERENCE
148
- # =========================
149
  @spaces.GPU
150
  def create_image(
151
  input_image,
@@ -154,54 +155,55 @@ def create_image(
154
  guidance_scale,
155
  num_inference_steps,
156
  seed,
157
- use_onepiece_lora=True,
158
  lora_strength=0.85,
159
- width=768,
160
- height=768,
161
  ):
162
  if input_image is None:
163
  raise gr.Error("Merci d'uploader une image de visage.")
164
- if use_onepiece_lora and not os.path.exists(onepiece_flux_lora_path):
165
  raise gr.Error(f"Fichier LoRA manquant : {onepiece_flux_lora_path}")
166
 
167
  input_image = remove_bkg(input_image)
168
- generator = torch.Generator(device=device).manual_seed(int(seed))
169
-
170
- if use_onepiece_lora:
171
- images = pipe.with_style_lora(
172
- lora_file_path=onepiece_flux_lora_path,
173
- trigger=onepiece_flux_trigger,
174
- prompt=prompt,
175
- num_inference_steps=int(num_inference_steps),
176
- guidance_scale=float(guidance_scale),
177
- width=int(width),
178
- height=int(height),
179
- subject_image=input_image,
180
- subject_scale=float(scale),
181
- lora_scale=float(lora_strength),
182
- generator=generator,
183
- ).images
184
  else:
185
- images = pipe(
186
- prompt=prompt,
187
- num_inference_steps=int(num_inference_steps),
188
- guidance_scale=float(guidance_scale),
189
- width=int(width),
190
- height=int(height),
191
- subject_image=input_image,
192
- subject_scale=float(scale),
193
- generator=generator,
194
- ).images
195
-
 
 
 
 
 
 
 
 
 
196
  return images
197
 
198
- # =========================
199
- # UI
200
- # =========================
201
- title = "<h1 align='center'>InstantCharacter (FLUX.1-dev) + One Piece (FLUX LoRA) — single encoder</h1>"
202
  description = (
203
- "GPU-only (FP16), low_cpu_mem_usage=True, **sans DINOv2-giant** pour éviter la limite RAM 30 Go. "
204
- "Départ en 768×768, tu peux monter à 896→1024 si stable."
205
  )
206
 
207
  block = gr.Blocks(css="footer {visibility: hidden}").queue(concurrency_count=1, max_size=5, api_open=False)
@@ -216,17 +218,17 @@ with block:
216
  value="onepiece style, a pirate character standing on a ship deck, shonen manga, strong black line art, cel shading, expressive eyes, dynamic pose, clean linework"
217
  )
218
  scale = gr.Slider(0.0, 1.5, 1.0, 0.01, label="Scale (face strength)")
219
- use_onepiece_lora = gr.Checkbox(value=True, label="Use One Piece (FLUX LoRA)")
220
- lora_strength = gr.Slider(0.0, 1.5, 0.85, 0.05, label="LoRA strength")
221
-
 
 
 
222
  with gr.Accordion("Advanced Options", open=False):
223
  guidance_scale = gr.Slider(1.0, 7.0, 3.5, 0.1, label="Guidance scale")
224
  num_inference_steps = gr.Slider(5, 50, 28, 1, label="Inference steps")
225
  seed = gr.Slider(-MAX_SEED, MAX_SEED, 123456, 1, label="Seed")
226
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
227
- width = gr.Slider(640, 1152, 768, 32, label="Width")
228
- height = gr.Slider(640, 1152, 768, 32, label="Height")
229
-
230
  generate_button = gr.Button("Generate Image", variant="primary")
231
 
232
  with gr.Column():
@@ -240,7 +242,7 @@ with block:
240
  ).then(
241
  fn=create_image,
242
  inputs=[image_pil, prompt, scale, guidance_scale, num_inference_steps,
243
- seed, use_onepiece_lora, lora_strength, width, height],
244
  outputs=output_gallery,
245
  )
246
 
 
16
  from torchvision import transforms
17
  from pipeline import InstantCharacterFluxPipeline
18
 
19
+ # =====================================================
20
+ # CONFIG GÉNÉRALE
21
+ # =====================================================
22
  MAX_SEED = np.iinfo(np.int32).max
23
+ dtype = torch.float16 # parfait sur L4 (24 Go)
 
 
24
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
25
+
26
  def need_token_guard():
27
  if HF_TOKEN is None:
28
  raise gr.Error(
29
  "⚠️ Token manquant : ajoute un secret 'HF_TOKEN' (Settings → Repository secrets) "
30
+ "avec ton token Hugging Face ayant accès à black-forest-labs/FLUX.1-dev."
31
  )
32
 
33
+ # =====================================================
34
+ # TÉLÉCHARGEMENT DES POIDS ET MODÈLES
35
+ # =====================================================
36
  base_model = "black-forest-labs/FLUX.1-dev"
37
  image_encoder_path = "google/siglip-so400m-patch14-384"
38
+ image_encoder_2_path = "facebook/dinov2-giant"
 
39
  birefnet_path = "ZhengPeng7/BiRefNet"
40
 
 
 
 
 
41
  def _dl(repo_id, filename, token=None):
42
  return hf_hub_download(repo_id=repo_id, filename=filename, token=token)
43
 
44
  need_token_guard()
 
45
  ip_adapter_path = _dl("tencent/InstantCharacter", "instantcharacter_ip-adapter.bin", HF_TOKEN)
46
 
47
+ makoto_style_lora_path = _dl("InstantX/FLUX.1-dev-LoRA-Makoto-Shinkai",
48
+ "Makoto_Shinkai_style.safetensors", HF_TOKEN)
49
+ ghibli_style_lora_path = _dl("InstantX/FLUX.1-dev-LoRA-Ghibli",
50
+ "ghibli_style.safetensors", HF_TOKEN)
51
+
52
+ # >>> Ton LoRA One Piece (FLUX) <<<
53
+ onepiece_flux_lora_path = "./onepiece_flux_v2.safetensors"
54
+ onepiece_flux_trigger = "onepiece style"
55
+
56
+ # =====================================================
57
+ # INITIALISATION DU PIPELINE (optimisée VRAM)
58
+ # =====================================================
59
  pipe = InstantCharacterFluxPipeline.from_pretrained(
60
+ base_model, torch_dtype=dtype, token=HF_TOKEN
 
 
 
61
  )
 
62
 
63
  try:
64
  pipe.enable_xformers_memory_efficient_attention()
65
  except Exception:
66
  pass
67
+ try:
68
+ pipe.enable_model_cpu_offload() # offload auto GPU/CPU
69
+ except Exception:
70
+ try:
71
+ pipe.enable_sequential_cpu_offload()
72
+ except Exception:
73
+ pass
74
 
75
  pipe.set_progress_bar_config(disable=True)
76
  if hasattr(pipe, "vae"):
77
+ if hasattr(pipe.vae, "enable_slicing"):
78
+ pipe.vae.enable_slicing()
79
+ if hasattr(pipe.vae, "enable_tiling"):
80
+ pipe.vae.enable_tiling()
81
 
82
+ pipe.init_adapter(
 
83
  image_encoder_path=image_encoder_path,
84
+ image_encoder_2_path=image_encoder_2_path,
85
+ subject_ipadapter_cfg=dict(subject_ip_adapter_path=ip_adapter_path, nb_token=1024),
86
  )
 
 
 
87
 
88
+ # =====================================================
89
+ # MATTEUR (BiRefNet) – SUR CPU POUR ÉCONOMISER LA VRAM
90
+ # =====================================================
91
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
92
+ birefnet_path, trust_remote_code=True, token=HF_TOKEN
93
+ )
94
+ birefnet.to("cpu").eval()
95
 
 
 
 
 
96
  birefnet_transform_image = transforms.Compose([
97
  transforms.Resize((1024, 1024)),
98
  transforms.ToTensor(),
 
100
  ])
101
 
102
  def remove_bkg(subject_image):
 
 
 
 
 
 
 
103
  def infer_matting(img_pil):
104
  imgs = birefnet_transform_image(img_pil).unsqueeze(0).to("cpu")
105
  with torch.no_grad():
 
121
  if H == W: return image
122
  pad = abs(H - W)
123
  pad1, pad2 = pad // 2, pad - pad // 2
124
+ if H > W:
125
+ pad_param = ((0, 0), (pad1, pad2), (0, 0))
126
+ else:
127
+ pad_param = ((pad1, pad2), (0, 0), (0, 0))
128
  return np.pad(image, pad_param, "constant", constant_values=pad_value)
129
 
130
  mask = infer_matting(subject_image)[..., 0]
 
138
  crop = pad_to_square(crop)
139
  return Image.fromarray(crop.astype(np.uint8))
140
 
141
+ # =====================================================
142
+ # OUTILS
143
+ # =====================================================
144
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
145
  return random.randint(0, MAX_SEED) if randomize_seed else seed
146
 
147
+ # =====================================================
148
+ # GÉNÉRATION D'IMAGE
149
+ # =====================================================
150
  @spaces.GPU
151
  def create_image(
152
  input_image,
 
155
  guidance_scale,
156
  num_inference_steps,
157
  seed,
158
+ style_mode=None,
159
  lora_strength=0.85,
160
+ width=896,
161
+ height=896,
162
  ):
163
  if input_image is None:
164
  raise gr.Error("Merci d'uploader une image de visage.")
165
+ if style_mode == "One Piece (FLUX LoRA)" and not os.path.exists(onepiece_flux_lora_path):
166
  raise gr.Error(f"Fichier LoRA manquant : {onepiece_flux_lora_path}")
167
 
168
  input_image = remove_bkg(input_image)
169
+ generator = None # évite conflits avec offload auto
170
+
171
+ if style_mode == "Makoto Shinkai style":
172
+ lora_file_path, trigger = makoto_style_lora_path, "Makoto Shinkai style"
173
+ elif style_mode == "Ghibli style":
174
+ lora_file_path, trigger = ghibli_style_lora_path, "ghibli style"
175
+ elif style_mode == "One Piece (FLUX LoRA)":
176
+ lora_file_path, trigger = onepiece_flux_lora_path, onepiece_flux_trigger
 
 
 
 
 
 
 
 
177
  else:
178
+ lora_file_path, trigger = None, None
179
+
180
+ fn = pipe.with_style_lora if lora_file_path else pipe
181
+ kwargs = dict(
182
+ prompt=prompt,
183
+ num_inference_steps=int(num_inference_steps),
184
+ guidance_scale=float(guidance_scale),
185
+ width=int(width),
186
+ height=int(height),
187
+ subject_image=input_image,
188
+ subject_scale=float(scale),
189
+ generator=generator,
190
+ )
191
+ if lora_file_path:
192
+ kwargs.update(dict(
193
+ lora_file_path=lora_file_path,
194
+ trigger=trigger,
195
+ lora_scale=float(lora_strength),
196
+ ))
197
+ images = fn(**kwargs).images
198
  return images
199
 
200
+ # =====================================================
201
+ # INTERFACE GRADIO
202
+ # =====================================================
203
+ title = "<h1 align='center'>InstantCharacter (FLUX.1-dev) + One Piece (FLUX LoRA)</h1>"
204
  description = (
205
+ "<b>GPU :</b> Nvidia L4 24 Go configuration optimisée VRAM.<br>"
206
+ "Résolution par défaut : 896 × 896 (monte à 1024 si stable)."
207
  )
208
 
209
  block = gr.Blocks(css="footer {visibility: hidden}").queue(concurrency_count=1, max_size=5, api_open=False)
 
218
  value="onepiece style, a pirate character standing on a ship deck, shonen manga, strong black line art, cel shading, expressive eyes, dynamic pose, clean linework"
219
  )
220
  scale = gr.Slider(0.0, 1.5, 1.0, 0.01, label="Scale (face strength)")
221
+ style_mode = gr.Dropdown(
222
+ ["None", "Makoto Shinkai style", "Ghibli style", "One Piece (FLUX LoRA)"],
223
+ value="One Piece (FLUX LoRA)",
224
+ label="Style",
225
+ )
226
+ lora_strength = gr.Slider(0.0, 1.5, 0.85, 0.05, label="LoRA strength (One Piece)")
227
  with gr.Accordion("Advanced Options", open=False):
228
  guidance_scale = gr.Slider(1.0, 7.0, 3.5, 0.1, label="Guidance scale")
229
  num_inference_steps = gr.Slider(5, 50, 28, 1, label="Inference steps")
230
  seed = gr.Slider(-MAX_SEED, MAX_SEED, 123456, 1, label="Seed")
231
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
 
 
 
232
  generate_button = gr.Button("Generate Image", variant="primary")
233
 
234
  with gr.Column():
 
242
  ).then(
243
  fn=create_image,
244
  inputs=[image_pil, prompt, scale, guidance_scale, num_inference_steps,
245
+ seed, style_mode, lora_strength],
246
  outputs=output_gallery,
247
  )
248