Nad54 committed on
Commit
5080fa8
·
verified ·
1 Parent(s): a1aa9b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -46
app.py CHANGED
@@ -1,9 +1,6 @@
1
  import sys, os
2
  sys.path.append("../")
3
 
4
- # ---- anti-fragmentation VRAM ----
5
- os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
6
-
7
  import spaces
8
  import torch
9
  import random
@@ -65,16 +62,6 @@ FEMALE_PROMPT = (
65
  # --------------------------------------------
66
  pipe = InstantCharacterFluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
67
  pipe.to(device)
68
-
69
- try:
70
- if hasattr(pipe, "enable_sequential_cpu_offload"):
71
- pipe.enable_sequential_cpu_offload()
72
- if hasattr(pipe, "vae"):
73
- pipe.vae.enable_slicing()
74
- pipe.vae.enable_tiling()
75
- except Exception:
76
- pass
77
-
78
  pipe.init_adapter(
79
  image_encoder_path=image_encoder_path,
80
  image_encoder_2_path=image_encoder_2_path,
@@ -85,7 +72,7 @@ pipe.init_adapter(
85
  # Background remover
86
  # --------------------------------------------
87
  birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path, trust_remote_code=True)
88
- birefnet.to("cpu")
89
  birefnet.eval()
90
  birefnet_transform = transforms.Compose([
91
  transforms.Resize((1024, 1024)),
@@ -95,41 +82,32 @@ birefnet_transform = transforms.Compose([
95
 
96
  def remove_bkg(subject_image):
97
  def infer_matting(img_pil):
98
- run_dev = device if torch.cuda.is_available() else "cpu"
99
- birefnet.to(run_dev)
100
- inp = birefnet_transform(img_pil).unsqueeze(0).to(run_dev)
101
  with torch.no_grad():
102
  preds = birefnet(inp)[-1].sigmoid().cpu()
103
  pred = preds[0].squeeze()
104
  mask = transforms.ToPILImage()(pred).resize(img_pil.size)
105
- birefnet.to("cpu")
106
- if torch.cuda.is_available():
107
- torch.cuda.empty_cache()
108
  return np.array(mask)[..., None]
109
 
110
- def pad_to_ratio(image, target_w=1024, target_h=768, pad_value=255):
111
  H, W = image.shape[:2]
112
- aspect_target = target_w / target_h
113
- aspect = W / H
114
- if abs(aspect - aspect_target) < 1e-3:
115
- # déjà bon ratio
116
- resized = Image.fromarray(image.astype(np.uint8)).resize((target_w, target_h), Image.LANCZOS)
117
- return np.array(resized)
118
- # centrer et crop/pad selon le ratio
119
- img = Image.fromarray(image.astype(np.uint8))
120
- img = img.resize((target_w, target_h), Image.LANCZOS)
121
- return np.array(img)
122
 
123
  mask = infer_matting(subject_image)[..., 0]
124
  subject_np = np.array(subject_image)
125
  mask = (mask > 128).astype(np.uint8) * 255
126
  sample_mask = np.stack([mask] * 3, axis=-1)
127
  obj = sample_mask / 255 * subject_np + (1 - sample_mask / 255) * 255
128
- fixed = pad_to_ratio(obj, 1024, 768)
129
- return Image.fromarray(fixed.astype(np.uint8))
130
 
131
  # --------------------------------------------
132
- # Gender detector
133
  # --------------------------------------------
134
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
135
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
@@ -140,8 +118,8 @@ def detect_gender(img_pil: Image.Image) -> str:
140
  texts = ["a portrait photo of a man", "a portrait photo of a woman"]
141
  inputs = clip_processor(text=texts, images=img_pil.convert("RGB"), return_tensors="pt", padding=True).to(device)
142
  outputs = clip_model(**inputs)
143
- logits = outputs.logits_per_image.squeeze(0)
144
- idx = int(torch.argmax(logits).item())
145
  return "male" if idx == 0 else "female"
146
 
147
  # --------------------------------------------
@@ -152,9 +130,6 @@ def randomize_seed(seed, randomize):
152
 
153
  @spaces.GPU
154
  def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps, seed, style_mode, negative_prompt=""):
155
- if torch.cuda.is_available():
156
- torch.cuda.empty_cache()
157
-
158
  input_image = remove_bkg(input_image)
159
 
160
  if style_mode == "Makoto Shinkai style":
@@ -172,7 +147,7 @@ def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps
172
  negative_prompt=negative_prompt,
173
  num_inference_steps=num_inference_steps,
174
  guidance_scale=guidance_scale,
175
- width=1024, height=768, # <<< résolution fixe
176
  subject_image=input_image,
177
  subject_scale=scale,
178
  generator=generator,
@@ -182,13 +157,10 @@ def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps
182
  result = pipe.with_style_lora(lora_file_path=lora_path, trigger=trigger, **common_args)
183
  else:
184
  result = pipe(**common_args)
185
-
186
- if torch.cuda.is_available():
187
- torch.cuda.empty_cache()
188
  return result.images
189
 
190
  # --------------------------------------------
191
- # UI definition
192
  # --------------------------------------------
193
  def generate_fn(image, prompt, scale, style, guidance, steps, seed, randomize, negative_prompt, auto_prompt):
194
  if auto_prompt and image is not None:
@@ -200,8 +172,8 @@ def generate_fn(image, prompt, scale, style, guidance, steps, seed, randomize, n
200
 
201
  title = "🎨 InstantCharacter + One Piece LoRA"
202
  description = (
203
- "Upload your photo and generate yourself as a One Piece character (output always 1024×768). "
204
- "Tick **Auto One Piece Prompt** for gender-aware templates."
205
  )
206
 
207
  demo = gr.Interface(
 
1
  import sys, os
2
  sys.path.append("../")
3
 
 
 
 
4
  import spaces
5
  import torch
6
  import random
 
62
  # --------------------------------------------
63
  pipe = InstantCharacterFluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
64
  pipe.to(device)
 
 
 
 
 
 
 
 
 
 
65
  pipe.init_adapter(
66
  image_encoder_path=image_encoder_path,
67
  image_encoder_2_path=image_encoder_2_path,
 
72
  # Background remover
73
  # --------------------------------------------
74
  birefnet = AutoModelForImageSegmentation.from_pretrained(birefnet_path, trust_remote_code=True)
75
+ birefnet.to(device)
76
  birefnet.eval()
77
  birefnet_transform = transforms.Compose([
78
  transforms.Resize((1024, 1024)),
 
82
 
83
def remove_bkg(subject_image):
    """Matte the subject out of *subject_image*, composite it onto a white
    background, and pad the result to a square canvas.

    Parameters
    ----------
    subject_image : PIL.Image.Image
        Input photo. Converted to RGB up front so RGBA/palette uploads do
        not crash the matting transform or the H×W×3 mask broadcast below.

    Returns
    -------
    PIL.Image.Image
        Square RGB image: binarized subject over white, centered by padding.
    """
    # Normalize mode first: birefnet_transform and the 3-channel mask
    # arithmetic below both assume an RGB image.
    subject_image = subject_image.convert("RGB")

    def infer_matting(img_pil):
        # Run the BiRefNet matting model; returns an HxWx1 uint8 alpha map
        # resized back to the original image size.
        inp = birefnet_transform(img_pil).unsqueeze(0).to(device)
        with torch.no_grad():
            preds = birefnet(inp)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(img_pil.size)
        return np.array(mask)[..., None]

    def pad_to_square(image, pad_value=255):
        # Pad the shorter axis evenly on both sides (white by default) so
        # the output is HxH or WxW, whichever is larger.
        H, W = image.shape[:2]
        if H == W:
            return image
        pad = abs(H - W)
        pad1, pad2 = pad // 2, pad - pad // 2
        pad_param = ((0, 0), (pad1, pad2), (0, 0)) if H > W else ((pad1, pad2), (0, 0), (0, 0))
        return np.pad(image, pad_param, "constant", constant_values=pad_value)

    mask = infer_matting(subject_image)[..., 0]
    subject_np = np.array(subject_image)
    # Binarize the soft matte, then alpha-composite subject over white.
    mask = (mask > 128).astype(np.uint8) * 255
    sample_mask = np.stack([mask] * 3, axis=-1)
    alpha = sample_mask / 255  # hoisted: used by both foreground and background terms
    obj = alpha * subject_np + (1 - alpha) * 255
    cropped = pad_to_square(obj, 255)
    return Image.fromarray(cropped.astype(np.uint8))
108
 
109
  # --------------------------------------------
110
+ # Simple gender detector (CLIP zero-shot)
111
  # --------------------------------------------
112
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
113
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
118
  texts = ["a portrait photo of a man", "a portrait photo of a woman"]
119
  inputs = clip_processor(text=texts, images=img_pil.convert("RGB"), return_tensors="pt", padding=True).to(device)
120
  outputs = clip_model(**inputs)
121
+ logits_per_image = outputs.logits_per_image.squeeze(0)
122
+ idx = int(torch.argmax(logits_per_image).item())
123
  return "male" if idx == 0 else "female"
124
 
125
  # --------------------------------------------
 
130
 
131
  @spaces.GPU
132
  def create_image(input_image, prompt, scale, guidance_scale, num_inference_steps, seed, style_mode, negative_prompt=""):
 
 
 
133
  input_image = remove_bkg(input_image)
134
 
135
  if style_mode == "Makoto Shinkai style":
 
147
  negative_prompt=negative_prompt,
148
  num_inference_steps=num_inference_steps,
149
  guidance_scale=guidance_scale,
150
+ width=1024, height=780,
151
  subject_image=input_image,
152
  subject_scale=scale,
153
  generator=generator,
 
157
  result = pipe.with_style_lora(lora_file_path=lora_path, trigger=trigger, **common_args)
158
  else:
159
  result = pipe(**common_args)
 
 
 
160
  return result.images
161
 
162
  # --------------------------------------------
163
+ # UI definition (Gradio 5)
164
  # --------------------------------------------
165
  def generate_fn(image, prompt, scale, style, guidance, steps, seed, randomize, negative_prompt, auto_prompt):
166
  if auto_prompt and image is not None:
 
172
 
173
  title = "🎨 InstantCharacter + One Piece LoRA"
174
  description = (
175
+ "Upload your photo, describe your scene, or tick **Auto One Piece Prompt** to auto-pick a gender-aware template. "
176
+ "Choose **One Piece style** to apply the LoRA."
177
  )
178
 
179
  demo = gr.Interface(