Nad54 committed on
Commit
a1aa9b4
·
verified ·
1 Parent(s): c6d931e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -36
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import sys, os
2
  sys.path.append("../")
3
 
4
- # ---- anti-fragmentation VRAM, à définir AVANT toute init CUDA ----
5
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
6
 
7
  import spaces
@@ -66,7 +66,6 @@ FEMALE_PROMPT = (
66
  pipe = InstantCharacterFluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
67
  pipe.to(device)
68
 
69
- # Offload/slicing/tiling pour réduire les pics VRAM
70
  try:
71
  if hasattr(pipe, "enable_sequential_cpu_offload"):
72
  pipe.enable_sequential_cpu_offload()
@@ -85,7 +84,6 @@ pipe.init_adapter(
85
  # --------------------------------------------
86
  # Background remover
87
  # --------------------------------------------
88
- # On charge BiRefNet sur CPU; on le montera sur GPU juste pour l'inférence puis retour CPU.
89
  birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path, trust_remote_code=True)
90
  birefnet.to("cpu")
91
  birefnet.eval()
@@ -97,49 +95,41 @@ birefnet_transform = transforms.Compose([
97
 
98
  def remove_bkg(subject_image):
99
  def infer_matting(img_pil):
100
- # move temporairement sur GPU si dispo
101
  run_dev = device if torch.cuda.is_available() else "cpu"
102
- try:
103
- birefnet.to(run_dev)
104
- except Exception:
105
- run_dev = "cpu"
106
- birefnet.to("cpu")
107
-
108
  inp = birefnet_transform(img_pil).unsqueeze(0).to(run_dev)
109
  with torch.no_grad():
110
  preds = birefnet(inp)[-1].sigmoid().cpu()
111
  pred = preds[0].squeeze()
112
  mask = transforms.ToPILImage()(pred).resize(img_pil.size)
113
-
114
- # libère VRAM : retour CPU + vidage cache
115
- try:
116
- birefnet.to("cpu")
117
- except Exception:
118
- pass
119
  if torch.cuda.is_available():
120
  torch.cuda.empty_cache()
121
-
122
  return np.array(mask)[..., None]
123
 
124
- def pad_to_square(image, pad_value=255):
125
  H, W = image.shape[:2]
126
- if H == W:
127
- return image
128
- pad = abs(H - W)
129
- pad1, pad2 = pad // 2, pad - pad // 2
130
- pad_param = ((0, 0), (pad1, pad2), (0, 0)) if H > W else ((pad1, pad2), (0, 0), (0, 0))
131
- return np.pad(image, pad_param, "constant", constant_values=pad_value)
 
 
 
 
132
 
133
  mask = infer_matting(subject_image)[..., 0]
134
  subject_np = np.array(subject_image)
135
  mask = (mask > 128).astype(np.uint8) * 255
136
  sample_mask = np.stack([mask] * 3, axis=-1)
137
  obj = sample_mask / 255 * subject_np + (1 - sample_mask / 255) * 255
138
- cropped = pad_to_square(obj, 255)
139
- return Image.fromarray(cropped.astype(np.uint8))
140
 
141
  # --------------------------------------------
142
- # Simple gender detector (CLIP zero-shot)
143
  # --------------------------------------------
144
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
145
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
@@ -150,8 +140,8 @@ def detect_gender(img_pil: Image.Image) -> str:
150
  texts = ["a portrait photo of a man", "a portrait photo of a woman"]
151
  inputs = clip_processor(text=texts, images=img_pil.convert("RGB"), return_tensors="pt", padding=True).to(device)
152
  outputs = clip_model(**inputs)
153
- logits_per_image = outputs.logits_per_image.squeeze(0)
154
- idx = int(torch.argmax(logits_per_image).item())
155
  return "male" if idx == 0 else "female"
156
 
157
  # --------------------------------------------
@@ -162,7 +152,6 @@ def randomize_seed(seed, randomize):
162
 
163
  @spaces.GPU
164
  def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps, seed, style_mode, negative_prompt=""):
165
- # purge VRAM avant d'attaquer
166
  if torch.cuda.is_available():
167
  torch.cuda.empty_cache()
168
 
@@ -183,7 +172,7 @@ def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps
183
  negative_prompt=negative_prompt,
184
  num_inference_steps=num_inference_steps,
185
  guidance_scale=guidance_scale,
186
- width=1024, height=1024, # si OOM persiste, passe à 896 ou 768
187
  subject_image=input_image,
188
  subject_scale=scale,
189
  generator=generator,
@@ -194,13 +183,12 @@ def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps
194
  else:
195
  result = pipe(**common_args)
196
 
197
- # purge VRAM après génération
198
  if torch.cuda.is_available():
199
  torch.cuda.empty_cache()
200
  return result.images
201
 
202
  # --------------------------------------------
203
- # UI definition (Gradio 5)
204
  # --------------------------------------------
205
  def generate_fn(image, prompt, scale, style, guidance, steps, seed, randomize, negative_prompt, auto_prompt):
206
  if auto_prompt and image is not None:
@@ -212,15 +200,15 @@ def generate_fn(image, prompt, scale, style, guidance, steps, seed, randomize, n
212
 
213
  title = "🎨 InstantCharacter + One Piece LoRA"
214
  description = (
215
- "Upload your photo, describe your scene, or tick **Auto One Piece Prompt** to auto-pick a gender-aware template. "
216
- "Choose **One Piece style** to apply the LoRA."
217
  )
218
 
219
  demo = gr.Interface(
220
  fn=generate_fn,
221
  inputs=[
222
  gr.Image(label="Source Image", type="pil"),
223
- gr.Textbox(label="Prompt", value=f", {ONEPIECE_TRIGGER}"),
224
  gr.Slider(0, 1.5, value=1.0, step=0.01, label="Scale"),
225
  gr.Dropdown(choices=[None, "Makoto Shinkai style", "Ghibli style", "One Piece style"],
226
  value="One Piece style", label="Style"),
 
1
  import sys, os
2
  sys.path.append("../")
3
 
4
+ # ---- anti-fragmentation VRAM ----
5
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
6
 
7
  import spaces
 
66
  pipe = InstantCharacterFluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
67
  pipe.to(device)
68
 
 
69
  try:
70
  if hasattr(pipe, "enable_sequential_cpu_offload"):
71
  pipe.enable_sequential_cpu_offload()
 
84
  # --------------------------------------------
85
  # Background remover
86
  # --------------------------------------------
 
87
  birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path, trust_remote_code=True)
88
  birefnet.to("cpu")
89
  birefnet.eval()
 
95
 
96
def remove_bkg(subject_image):
    """Matte the subject out of *subject_image* and return it on a white
    1024x768 canvas.

    Parameters
    ----------
    subject_image : PIL.Image.Image
        Input photo containing the subject.

    Returns
    -------
    PIL.Image.Image
        RGB image, always 1024x768, subject composited on white with its
        original aspect ratio preserved (white letterbox padding as needed).
    """

    def infer_matting(img_pil):
        # Run BiRefNet on GPU when available; the model rests on CPU so the
        # diffusion pipeline keeps the VRAM.
        run_dev = device if torch.cuda.is_available() else "cpu"
        birefnet.to(run_dev)
        inp = birefnet_transform(img_pil).unsqueeze(0).to(run_dev)
        with torch.no_grad():
            preds = birefnet(inp)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(img_pil.size)
        # Park the model back on CPU and release cached VRAM.
        birefnet.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return np.array(mask)[..., None]

    def pad_to_ratio(image, target_w=1024, target_h=768, pad_value=255):
        # FIX: the previous version resized straight to (target_w, target_h)
        # in both branches, distorting any image whose aspect ratio differed
        # from the target (the ratio check was dead code). Pad symmetrically
        # with pad_value to reach the target aspect ratio first, then resize.
        H, W = image.shape[:2]
        aspect_target = target_w / target_h
        aspect = W / H
        if abs(aspect - aspect_target) >= 1e-3:
            if aspect < aspect_target:
                # Too narrow for the target ratio: pad width symmetrically.
                pad = int(round(H * aspect_target)) - W
                p1, p2 = pad // 2, pad - pad // 2
                pad_param = ((0, 0), (p1, p2), (0, 0))
            else:
                # Too wide for the target ratio: pad height symmetrically.
                pad = int(round(W / aspect_target)) - H
                p1, p2 = pad // 2, pad - pad // 2
                pad_param = ((p1, p2), (0, 0), (0, 0))
            image = np.pad(image, pad_param, "constant", constant_values=pad_value)
        resized = Image.fromarray(image.astype(np.uint8)).resize(
            (target_w, target_h), Image.LANCZOS
        )
        return np.array(resized)

    mask = infer_matting(subject_image)[..., 0]
    subject_np = np.array(subject_image)
    # Binarize the soft matte, then composite the subject on white.
    mask = (mask > 128).astype(np.uint8) * 255
    sample_mask = np.stack([mask] * 3, axis=-1)
    obj = sample_mask / 255 * subject_np + (1 - sample_mask / 255) * 255
    fixed = pad_to_ratio(obj, 1024, 768)
    return Image.fromarray(fixed.astype(np.uint8))
130
 
131
  # --------------------------------------------
132
+ # Gender detector
133
  # --------------------------------------------
134
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
135
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
140
  texts = ["a portrait photo of a man", "a portrait photo of a woman"]
141
  inputs = clip_processor(text=texts, images=img_pil.convert("RGB"), return_tensors="pt", padding=True).to(device)
142
  outputs = clip_model(**inputs)
143
+ logits = outputs.logits_per_image.squeeze(0)
144
+ idx = int(torch.argmax(logits).item())
145
  return "male" if idx == 0 else "female"
146
 
147
  # --------------------------------------------
 
152
 
153
  @spaces.GPU
154
  def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps, seed, style_mode, negative_prompt=""):
 
155
  if torch.cuda.is_available():
156
  torch.cuda.empty_cache()
157
 
 
172
  negative_prompt=negative_prompt,
173
  num_inference_steps=num_inference_steps,
174
  guidance_scale=guidance_scale,
175
+ width=1024, height=768, # <<< résolution fixe
176
  subject_image=input_image,
177
  subject_scale=scale,
178
  generator=generator,
 
183
  else:
184
  result = pipe(**common_args)
185
 
 
186
  if torch.cuda.is_available():
187
  torch.cuda.empty_cache()
188
  return result.images
189
 
190
  # --------------------------------------------
191
+ # UI definition
192
  # --------------------------------------------
193
  def generate_fn(image, prompt, scale, style, guidance, steps, seed, randomize, negative_prompt, auto_prompt):
194
  if auto_prompt and image is not None:
 
200
 
201
  title = "🎨 InstantCharacter + One Piece LoRA"
202
  description = (
203
+ "Upload your photo and generate yourself as a One Piece character (output always 1024×768). "
204
+ "Tick **Auto One Piece Prompt** for gender-aware templates."
205
  )
206
 
207
  demo = gr.Interface(
208
  fn=generate_fn,
209
  inputs=[
210
  gr.Image(label="Source Image", type="pil"),
211
+ gr.Textbox(label="Prompt", value=f"a character is riding a bike in snow, {ONEPIECE_TRIGGER}"),
212
  gr.Slider(0, 1.5, value=1.0, step=0.01, label="Scale"),
213
  gr.Dropdown(choices=[None, "Makoto Shinkai style", "Ghibli style", "One Piece style"],
214
  value="One Piece style", label="Style"),