venbab committed
Commit 5422abd · verified · 1 Parent(s): 7095797

Update app.py

Files changed (1)
  1. app.py +93 -123
app.py CHANGED
@@ -1,3 +1,4 @@
+# app.py
 import spaces
 import gradio as gr
 from PIL import Image
@@ -13,10 +14,7 @@ from transformers import (
 )
 from diffusers import DDPMScheduler, AutoencoderKL
 from typing import List
-
-import torch
-import os
-import numpy as np
+import torch, os, io, base64, json, numpy as np
 from utils_mask import get_mask_location
 from torchvision import transforms
 import apply_net
@@ -25,52 +23,47 @@ from preprocess.openpose.run_openpose import OpenPose
 from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
 from torchvision.transforms.functional import to_pil_image
 
-# ---------------- Helpers ----------------
+# FastAPI REST
+from fastapi import FastAPI, Response
+from pydantic import BaseModel
+
+# -------------------- helpers --------------------
 def pil_to_binary_mask(pil_image, threshold=0):
     np_image = np.array(pil_image)
     grayscale_image = Image.fromarray(np_image).convert("L")
     binary_mask = np.array(grayscale_image) > threshold
     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
-    for i in range(binary_mask.shape[0]):
-        for j in range(binary_mask.shape[1]):
-            if binary_mask[i, j]:
-                mask[i, j] = 1
-    mask = (mask * 255).astype(np.uint8)
-    return Image.fromarray(mask)
-
-# ---------------- Load models / pipeline ----------------
+    mask[binary_mask] = 1
+    return Image.fromarray((mask * 255).astype(np.uint8))
+
+def _b64_to_pil(data_uri_or_b64: str) -> Image.Image:
+    # Accept both data: URI and raw base64
+    if data_uri_or_b64.startswith("data:"):
+        comma = data_uri_or_b64.find(",")
+        b64 = data_uri_or_b64[comma + 1:]
+    else:
+        b64 = data_uri_or_b64
+    return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
+
+def _pil_to_b64_jpeg(img: Image.Image) -> str:
+    buf = io.BytesIO()
+    img.save(buf, format="JPEG", quality=92)
+    return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+# -------------------- load models --------------------
 base_path = "yisol/IDM-VTON"
 example_path = os.path.join(os.path.dirname(__file__), "example")
 
-unet = UNet2DConditionModel.from_pretrained(
-    base_path, subfolder="unet", torch_dtype=torch.float16
-)
-unet.requires_grad_(False)
-
-tokenizer_one = AutoTokenizer.from_pretrained(
-    base_path, subfolder="tokenizer", revision=None, use_fast=False
-)
-tokenizer_two = AutoTokenizer.from_pretrained(
-    base_path, subfolder="tokenizer_2", revision=None, use_fast=False
-)
+unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16)
+tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False)
+tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False)
 noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
 
-text_encoder_one = CLIPTextModel.from_pretrained(
-    base_path, subfolder="text_encoder", torch_dtype=torch.float16
-)
-text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-    base_path, subfolder="text_encoder_2", torch_dtype=torch.float16
-)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-    base_path, subfolder="image_encoder", torch_dtype=torch.float16
-)
-vae = AutoencoderKL.from_pretrained(
-    base_path, subfolder="vae", torch_dtype=torch.float16
-)
-
-UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
-    base_path, subfolder="unet_encoder", torch_dtype=torch.float16
-)
+text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
+text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
+UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
 
 parsing_model = Parsing(0)
 openpose_model = OpenPose(0)
@@ -78,9 +71,7 @@ openpose_model = OpenPose(0)
 for m in (UNet_Encoder, image_encoder, vae, unet, text_encoder_one, text_encoder_two):
     m.requires_grad_(False)
 
-tensor_transfrom = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
-)
+tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
 
 pipe = TryonPipeline.from_pretrained(
     base_path,
@@ -97,54 +88,32 @@ pipe = TryonPipeline.from_pretrained(
 )
 pipe.unet_encoder = UNet_Encoder
 
-progress = gr.Progress()
-
-# ---------------- Inference ----------------
-@spaces.GPU
-def infer(person, garment, denoise_steps, seed):
-    print(f"[infer] steps={denoise_steps}, seed={seed}", flush=True)
-    progress(0, desc="Starting")
+# -------------------- core inference --------------------
+def _infer_core(person_img: Image.Image, garment_img: Image.Image, denoise_steps: int, seed: int) -> Image.Image:
     device = "cuda"
-
     openpose_model.preprocessor.body_estimation.model.to(device)
     pipe.to(device)
     pipe.unet_encoder.to(device)
 
-    personRGB = person.convert("RGB")
+    personRGB = person_img.convert("RGB")
     crop_size = personRGB.size
     human_img = personRGB.resize((768, 1024))
-    garm_img = garment.convert("RGB").resize((768, 1024))
+    garm_img = garment_img.convert("RGB").resize((768, 1024))
 
-    progress(0.1, desc="Mask generating")
     keypoints = openpose_model(human_img.resize((384, 512)))
     model_parse, _ = parsing_model(human_img.resize((384, 512)))
-    mask, _mask_gray = get_mask_location("hd", "upper_body", model_parse, keypoints)
+    mask, _ = get_mask_location("hd", "upper_body", model_parse, keypoints)
     mask = mask.resize((768, 1024))
 
-    progress(0.3, desc="DensePose processing")
     human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
     human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
-    args = apply_net.create_argument_parser().parse_args(
-        (
-            "show",
-            "./configs/densepose_rcnn_R_50_FPN_s1x.yaml",
-            "./ckpt/densepose/model_final_162be9.pkl",
-            "dp_segm",
-            "-v",
-            "--opts",
-            "MODEL.DEVICE",
-            "cuda",
-        )
-    )
-    pose_img = args.func(args, human_img_arg)
-    pose_img = Image.fromarray(pose_img[:, :, ::-1]).resize((768, 1024))
-
-    progress(0.5, desc="Image generating")
-
-    def callback(pipe_, step, timestep, callback_kwargs):
-        progress_value = 0.5 + ((step + 1.0) / int(denoise_steps)) * 0.5
-        progress(progress_value, desc=f"Image generating, {step + 1}/{int(denoise_steps)} steps")
-        return callback_kwargs
+    args = apply_net.create_argument_parser().parse_args((
+        "show", "./configs/densepose_rcnn_R_50_FPN_s1x.yaml",
+        "./ckpt/densepose/model_final_162be9.pkl", "dp_segm", "-v",
+        "--opts", "MODEL.DEVICE", "cuda"
+    ))
+    pose_img = args.func(args, human_img_arg)[:, :, ::-1]
+    pose_img = Image.fromarray(pose_img).resize((768, 1024))
 
     with torch.no_grad(), torch.cuda.amp.autocast():
        prompt = "model is wearing clothing"
@@ -162,22 +131,19 @@ def infer(person, garment, denoise_steps, seed):
        )
 
        prompt_c = "a photo of clothing"
-        if not isinstance(prompt_c, List):
-            prompt_c = [prompt_c]
-        if not isinstance(negative_prompt, List):
-            negative_prompt_c = [negative_prompt]
-        else:
-            negative_prompt_c = negative_prompt
-        (prompt_embeds_c, _, _, _,) = pipe.encode_prompt(
+        if not isinstance(prompt_c, List): prompt_c = [prompt_c]
+        (
+            prompt_embeds_c, _, _, _
+        ) = pipe.encode_prompt(
            prompt_c,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
-            negative_prompt=negative_prompt_c,
+            negative_prompt=[negative_prompt],
        )
 
        pose_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
        garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
-        generator = torch.Generator(device).manual_seed(int(seed)) if seed is not None else None
+        generator = torch.Generator(device).manual_seed(int(seed))
 
        images = pipe(
            prompt_embeds=prompt_embeds.to(device, torch.float16),
@@ -196,63 +162,67 @@ def infer(person, garment, denoise_steps, seed):
            width=768,
            ip_adapter_image=garm_img.resize((768, 1024)),
            guidance_scale=2.0,
-            callback_on_step_end=callback,
        )[0]
 
    out_img = images[0].resize(crop_size)
-    progress(1, desc="Complete")
    return out_img
 
-# ---------------- UI (no queue) ----------------
+# -------------------- Gradio UI --------------------
+progress = gr.Progress()
+
+@spaces.GPU
+def infer(person, garment, denoise_steps, seed):
+    progress(0.05, desc="Starting")
+    out = _infer_core(person, garment, int(denoise_steps), int(seed))
+    progress(1.0, desc="Done")
+    return out
+
 title = "## AI Clothes Changer"
 description = "Step into the world of AI clothes swap and unlock style possibilities."
 
-person_list = os.listdir(os.path.join(example_path, "human"))
-person_images = [os.path.join(example_path, "human", p) for p in person_list]
-
-garment_list = os.listdir(os.path.join(example_path, "cloth"))
-garment_images = [os.path.join(example_path, "cloth", g) for g in garment_list]
+person_images = [os.path.join(example_path, "human", f) for f in os.listdir(os.path.join(example_path, "human"))]
+garment_images = [os.path.join(example_path, "cloth", f) for f in os.listdir(os.path.join(example_path, "cloth"))]
 
-with gr.Blocks() as demo:  # ← NO .queue()
+with gr.Blocks().queue() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Person Image")
-            person_image = gr.Image(
-                sources=["upload"], type="pil", label="Person Image",
-                width=512, height=512, show_download_button=False, show_share_button=False
-            )
+            person_image = gr.Image(sources=["upload"], type="pil", label="Person Image", width=512, height=512,
+                                    show_download_button=False, show_share_button=False)
            gr.Examples(inputs=person_image, examples_per_page=20, examples=person_images)
-
        with gr.Column():
            gr.Markdown("#### Garment Image")
-            garment_image = gr.Image(
-                sources=["upload"], type="pil", label="Garment Image",
-                width=512, height=512, show_download_button=False, show_share_button=False
-            )
+            garment_image = gr.Image(sources=["upload"], type="pil", label="Garment Image", width=512, height=512,
+                                     show_download_button=False, show_share_button=False)
            gr.Examples(inputs=garment_image, examples_per_page=20, examples=garment_images)
-
        with gr.Column():
            gr.Markdown("#### Generated Image")
-            gen_image = gr.Image(label="Generated Image", width=512, height=512,
-                                 show_download_button=True, show_share_button=False)
-
-    with gr.Row():
-        gen_button = gr.Button("Generate")
-
+            gen_image = gr.Image(label="Generated Image", width=512, height=512, show_download_button=True, show_share_button=False)
+    with gr.Row(): gen_button = gr.Button("Generate")
    with gr.Accordion("Advanced Options", open=False):
        denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
        seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
-
-    gen_button.click(
-        fn=infer,
-        inputs=[person_image, garment_image, denoise_steps, seed],
-        outputs=[gen_image],
-        api_name="predict",  # provides /run/predict
-        queue=False          # accept direct POSTs (no queue)
-    )
-
-# For local dev only. On Spaces, Gradio auto-launches.
-if __name__ == "__main__":
-    demo.launch(show_error=True)
+    gen_button.click(fn=infer, inputs=[person_image, garment_image, denoise_steps, seed], outputs=[gen_image], api_name="predict")
+
+# -------------------- FastAPI REST (JSON, base64) --------------------
+class TryOnPayload(BaseModel):
+    person_b64: str   # data URI or raw base64
+    garment_b64: str  # data URI or raw base64
+    denoise_steps: int = 30
+    seed: int = 42
+
+fastapi_app = FastAPI()
+
+@fastapi_app.post("/tryon")
+def tryon_endpoint(payload: TryOnPayload):
+    person = _b64_to_pil(payload.person_b64)
+    garment = _b64_to_pil(payload.garment_b64)
+    out_img = _infer_core(person, garment, payload.denoise_steps, payload.seed)
+    b64 = _pil_to_b64_jpeg(out_img)
+    # return a data URI so clients can use it directly if they want
+    return {"image_data_uri": f"data:image/jpeg;base64,{b64}", "base64": b64}
+
+# Mount Gradio at root, REST at same server
+app = gr.mount_gradio_app(fastapi_app, demo, path="/")
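For API consumers, a minimal client sketch for the new `/tryon` route follows. The Space URL and image file names are placeholders, and the field names simply mirror the `TryOnPayload` model in the diff; treat this as an illustration under those assumptions, not a guaranteed contract.

```python
# Hypothetical client for the /tryon endpoint; adjust the URL to the actual Space.
import base64
import requests

def to_b64(path: str) -> str:
    """Read a local image file and return its raw base64 encoding."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "person_b64": to_b64("person.jpg"),    # raw base64; a data: URI also works
    "garment_b64": to_b64("garment.jpg"),
    "denoise_steps": 30,
    "seed": 42,
}

resp = requests.post("https://your-space.hf.space/tryon", json=payload, timeout=600)
resp.raise_for_status()

# The endpoint returns both a bare base64 string and a ready-to-use data URI.
with open("result.jpg", "wb") as f:
    f.write(base64.b64decode(resp.json()["base64"]))
```

Because the response also carries `image_data_uri`, a browser client can assign it straight to an `img` element's `src` without decoding; the Gradio UI itself remains reachable at the root path, with `api_name="predict"` naming the queued UI endpoint.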