venbab committed on
Commit
481de3e
·
verified ·
1 Parent(s): b4d02b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -110
app.py CHANGED
@@ -26,9 +26,6 @@ from preprocess.openpose.run_openpose import OpenPose
26
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
27
  from torchvision.transforms.functional import to_pil_image
28
 
29
- # --- FastAPI for the /tryon REST route ---
30
- from fastapi import FastAPI, UploadFile, File, Response
31
-
32
  # ------------------------------------------------------------------------------------
33
  # Helpers
34
  # ------------------------------------------------------------------------------------
@@ -39,11 +36,10 @@ def pil_to_binary_mask(pil_image, threshold=0):
39
  mask = np.zeros(binary_mask.shape, dtype=np.uint8)
40
  for i in range(binary_mask.shape[0]):
41
  for j in range(binary_mask.shape[1]):
42
- if binary_mask[i, j] is True:
43
  mask[i, j] = 1
44
  mask = (mask * 255).astype(np.uint8)
45
- output_mask = Image.fromarray(mask)
46
- return output_mask
47
 
48
  # ------------------------------------------------------------------------------------
49
  # Load models / pipeline
@@ -100,12 +96,8 @@ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
100
  parsing_model = Parsing(0)
101
  openpose_model = OpenPose(0)
102
 
103
- UNet_Encoder.requires_grad_(False)
104
- image_encoder.requires_grad_(False)
105
- vae.requires_grad_(False)
106
- unet.requires_grad_(False)
107
- text_encoder_one.requires_grad_(False)
108
- text_encoder_two.requires_grad_(False)
109
 
110
  tensor_transfrom = transforms.Compose(
111
  [
@@ -130,7 +122,7 @@ pipe = TryonPipeline.from_pretrained(
130
  pipe.unet_encoder = UNet_Encoder
131
 
132
  # ------------------------------------------------------------------------------------
133
- # Core try-on function used by both Gradio UI and REST
134
  # ------------------------------------------------------------------------------------
135
  def _tryon_core(
136
  human_img: Image.Image,
@@ -167,15 +159,11 @@ def _tryon_core(
167
  if auto_mask:
168
  keypoints = openpose_model(human_img_used.resize((384, 512)))
169
  model_parse, _ = parsing_model(human_img_used.resize((384, 512)))
170
- mask, mask_gray = get_mask_location("hd", "upper_body", model_parse, keypoints)
171
  mask = mask.resize((768, 1024))
172
  else:
173
- # fallback: no-draw mask (full body) – rarely used in REST path
174
  mask = pil_to_binary_mask(Image.new("L", (768, 1024), 255))
175
 
176
- mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img_used)
177
- mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
178
-
179
  # DensePose
180
  human_img_arg = _apply_exif_orientation(human_img_used.resize((384, 512)))
181
  human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
@@ -195,60 +183,54 @@ def _tryon_core(
195
  pose_img = pose_img[:, :, ::-1]
196
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
197
 
198
- with torch.no_grad():
199
- with torch.cuda.amp.autocast():
200
- with torch.no_grad():
201
- prompt = "model is wearing " + (garment_des or "a garment")
202
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
203
- with torch.inference_mode():
204
- (
205
- prompt_embeds,
206
- negative_prompt_embeds,
207
- pooled_prompt_embeds,
208
- negative_pooled_prompt_embeds,
209
- ) = pipe.encode_prompt(
210
- prompt,
211
- num_images_per_prompt=1,
212
- do_classifier_free_guidance=True,
213
- negative_prompt=negative_prompt,
214
- )
215
-
216
- prompt_c = "a photo of " + (garment_des or "a garment")
217
- negative_prompt_c = negative_prompt
218
- if not isinstance(prompt_c, List):
219
- prompt_c = [prompt_c] * 1
220
- if not isinstance(negative_prompt_c, List):
221
- negative_prompt_c = [negative_prompt_c] * 1
222
- with torch.inference_mode():
223
- (prompt_embeds_c, _, _, _,) = pipe.encode_prompt(
224
- prompt_c,
225
- num_images_per_prompt=1,
226
- do_classifier_free_guidance=False,
227
- negative_prompt=negative_prompt_c,
228
- )
229
 
230
- pose_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
231
- garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
232
- generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
 
 
 
 
 
 
233
 
234
- images = pipe(
235
- prompt_embeds=prompt_embeds.to(device, torch.float16),
236
- negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
237
- pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
238
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
239
- num_inference_steps=int(denoise_steps),
240
- generator=generator,
241
- strength=1.0,
242
- pose_img=pose_tensor,
243
- text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
244
- cloth=garm_tensor,
245
- mask_image=mask,
246
- image=human_img_used,
247
- height=1024,
248
- width=768,
249
- ip_adapter_image=garm_img.resize((768, 1024)),
250
- guidance_scale=2.0,
251
- )[0]
 
 
 
 
252
 
253
  if auto_crop:
254
  out_img = images[0].resize(crop_size)
@@ -258,7 +240,7 @@ def _tryon_core(
258
  return images[0]
259
 
260
  # ------------------------------------------------------------------------------------
261
- # Gradio UI (original) – unchanged logic except we call the same core function
262
  # ------------------------------------------------------------------------------------
263
  garm_list = os.listdir(os.path.join(example_path, "cloth"))
264
  garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
@@ -268,16 +250,12 @@ human_list_path = [os.path.join(example_path, "human", human) for human in human
268
 
269
  human_ex_list = []
270
  for ex_human in human_list_path:
271
- ex_dict = {}
272
- ex_dict["background"] = ex_human
273
- ex_dict["layers"] = None
274
- ex_dict["composite"] = None
275
  human_ex_list.append(ex_dict)
276
 
277
  @spaces.GPU
278
- def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
279
- # Keep compatibility with the existing Gradio workflow
280
- human_img = dict["background"].convert("RGB")
281
  out_img = _tryon_core(
282
  human_img=human_img,
283
  garm_img=garm_img,
@@ -287,12 +265,10 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
287
  denoise_steps=int(denoise_steps),
288
  seed=int(seed) if seed is not None else None,
289
  )
290
- # Also return the mask preview (approx) by recomputing lightweight gray
291
  mask_gray = pil_to_binary_mask(out_img.convert("L"))
292
  return out_img, mask_gray
293
 
294
- image_blocks = gr.Blocks().queue()
295
- with image_blocks as demo:
296
  gr.Markdown("## IDM-VTON πŸ‘•πŸ‘”πŸ‘š")
297
  gr.Markdown(
298
  "Virtual Try-on with your image and garment image. Check out the "
@@ -306,7 +282,7 @@ with image_blocks as demo:
306
  is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)", value=True)
307
  with gr.Row():
308
  is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing", value=False)
309
- _ = gr.Examples(inputs=imgs, examples_per_page=10, examples=human_ex_list)
310
 
311
  with gr.Column():
312
  garm_img = gr.Image(label="Garment", sources="upload", type="pil")
@@ -317,7 +293,7 @@ with image_blocks as demo:
317
  show_label=False,
318
  elem_id="prompt",
319
  )
320
- _ = gr.Examples(inputs=garm_img, examples_per_page=8, examples=garm_list_path)
321
 
322
  with gr.Column():
323
  masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
@@ -335,32 +311,8 @@ with image_blocks as demo:
335
  fn=start_tryon,
336
  inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed],
337
  outputs=[image_out, masked_img],
338
- api_name="tryon",
339
- )
340
-
341
- # ------------------------------------------------------------------------------------
342
- # FastAPI route and mount
343
- # ------------------------------------------------------------------------------------
344
- app = FastAPI()
345
-
346
- @app.post("/tryon")
347
- async def tryon(person: UploadFile = File(...), garment: UploadFile = File(...)):
348
- p_bytes = await person.read()
349
- g_bytes = await garment.read()
350
- human_img = Image.open(io.BytesIO(p_bytes)).convert("RGB")
351
- garment_img = Image.open(io.BytesIO(g_bytes)).convert("RGBA")
352
- out = _tryon_core(
353
- human_img=human_img,
354
- garm_img=garment_img,
355
- garment_des="", # optional: you can add a text box in Flutter later
356
- auto_mask=True,
357
- auto_crop=False,
358
- denoise_steps=30,
359
- seed=42,
360
  )
361
- buf = io.BytesIO()
362
- out.save(buf, format="JPEG", quality=92)
363
- return Response(content=buf.getvalue(), media_type="image/jpeg")
364
 
365
- # Mount Gradio UI on root path
366
- app = gr.mount_gradio_app(app, image_blocks, path="/")
 
26
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
27
  from torchvision.transforms.functional import to_pil_image
28
 
 
 
 
29
  # ------------------------------------------------------------------------------------
30
  # Helpers
31
  # ------------------------------------------------------------------------------------
 
36
  mask = np.zeros(binary_mask.shape, dtype=np.uint8)
37
  for i in range(binary_mask.shape[0]):
38
  for j in range(binary_mask.shape[1]):
39
+ if binary_mask[i, j]:
40
  mask[i, j] = 1
41
  mask = (mask * 255).astype(np.uint8)
42
+ return Image.fromarray(mask)
 
43
 
44
  # ------------------------------------------------------------------------------------
45
  # Load models / pipeline
 
96
  parsing_model = Parsing(0)
97
  openpose_model = OpenPose(0)
98
 
99
+ for m in (UNet_Encoder, image_encoder, vae, unet, text_encoder_one, text_encoder_two):
100
+ m.requires_grad_(False)
 
 
 
 
101
 
102
  tensor_transfrom = transforms.Compose(
103
  [
 
122
  pipe.unet_encoder = UNet_Encoder
123
 
124
  # ------------------------------------------------------------------------------------
125
+ # Core try-on function
126
  # ------------------------------------------------------------------------------------
127
  def _tryon_core(
128
  human_img: Image.Image,
 
159
  if auto_mask:
160
  keypoints = openpose_model(human_img_used.resize((384, 512)))
161
  model_parse, _ = parsing_model(human_img_used.resize((384, 512)))
162
+ mask, _ = get_mask_location("hd", "upper_body", model_parse, keypoints)
163
  mask = mask.resize((768, 1024))
164
  else:
 
165
  mask = pil_to_binary_mask(Image.new("L", (768, 1024), 255))
166
 
 
 
 
167
  # DensePose
168
  human_img_arg = _apply_exif_orientation(human_img_used.resize((384, 512)))
169
  human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
 
183
  pose_img = pose_img[:, :, ::-1]
184
  pose_img = Image.fromarray(pose_img).resize((768, 1024))
185
 
186
+ # Run pipeline
187
+ with torch.no_grad(), torch.cuda.amp.autocast():
188
+ prompt = "model is wearing " + (garment_des or "a garment")
189
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
190
+ (
191
+ prompt_embeds,
192
+ negative_prompt_embeds,
193
+ pooled_prompt_embeds,
194
+ negative_pooled_prompt_embeds,
195
+ ) = pipe.encode_prompt(
196
+ prompt,
197
+ num_images_per_prompt=1,
198
+ do_classifier_free_guidance=True,
199
+ negative_prompt=negative_prompt,
200
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ prompt_c = "a photo of " + (garment_des or "a garment")
203
+ if not isinstance(prompt_c, List):
204
+ prompt_c = [prompt_c]
205
+ (prompt_embeds_c, _, _, _,) = pipe.encode_prompt(
206
+ prompt_c,
207
+ num_images_per_prompt=1,
208
+ do_classifier_free_guidance=False,
209
+ negative_prompt=negative_prompt,
210
+ )
211
 
212
+ pose_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
213
+ garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
214
+ generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
215
+
216
+ images = pipe(
217
+ prompt_embeds=prompt_embeds.to(device, torch.float16),
218
+ negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
219
+ pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
220
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
221
+ num_inference_steps=int(denoise_steps),
222
+ generator=generator,
223
+ strength=1.0,
224
+ pose_img=pose_tensor,
225
+ text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
226
+ cloth=garm_tensor,
227
+ mask_image=mask,
228
+ image=human_img_used,
229
+ height=1024,
230
+ width=768,
231
+ ip_adapter_image=garm_img.resize((768, 1024)),
232
+ guidance_scale=2.0,
233
+ )[0]
234
 
235
  if auto_crop:
236
  out_img = images[0].resize(crop_size)
 
240
  return images[0]
241
 
242
  # ------------------------------------------------------------------------------------
243
+ # Gradio UI (and HTTP function endpoint via /run/tryon)
244
  # ------------------------------------------------------------------------------------
245
  garm_list = os.listdir(os.path.join(example_path, "cloth"))
246
  garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
 
250
 
251
  human_ex_list = []
252
  for ex_human in human_list_path:
253
+ ex_dict = {"background": ex_human, "layers": None, "composite": None}
 
 
 
254
  human_ex_list.append(ex_dict)
255
 
256
  @spaces.GPU
257
+ def start_tryon(dict_img, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
258
+ human_img = dict_img["background"].convert("RGB")
 
259
  out_img = _tryon_core(
260
  human_img=human_img,
261
  garm_img=garm_img,
 
265
  denoise_steps=int(denoise_steps),
266
  seed=int(seed) if seed is not None else None,
267
  )
 
268
  mask_gray = pil_to_binary_mask(out_img.convert("L"))
269
  return out_img, mask_gray
270
 
271
+ with gr.Blocks() as image_blocks:
 
272
  gr.Markdown("## IDM-VTON πŸ‘•πŸ‘”πŸ‘š")
273
  gr.Markdown(
274
  "Virtual Try-on with your image and garment image. Check out the "
 
282
  is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)", value=True)
283
  with gr.Row():
284
  is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing", value=False)
285
+ gr.Examples(inputs=imgs, examples_per_page=10, examples=human_ex_list)
286
 
287
  with gr.Column():
288
  garm_img = gr.Image(label="Garment", sources="upload", type="pil")
 
293
  show_label=False,
294
  elem_id="prompt",
295
  )
296
+ gr.Examples(inputs=garm_img, examples_per_page=8, examples=garm_list_path)
297
 
298
  with gr.Column():
299
  masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
 
311
  fn=start_tryon,
312
  inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed],
313
  outputs=[image_out, masked_img],
314
+ api_name="tryon", # <-- HTTP: POST /run/tryon
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  )
 
 
 
316
 
317
+ # IMPORTANT: expose a top-level `demo` for Gradio Spaces
318
+ demo = image_blocks