PawanratRung commited on
Commit
0683053
Β·
verified Β·
1 Parent(s): cdb5a72

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +362 -0
app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ from huggingface_hub import snapshot_download
4
+ from leffa.transform import LeffaTransform
5
+ from leffa.model import LeffaModel
6
+ from leffa.inference import LeffaInference
7
+ from leffa_utils.garment_agnostic_mask_predictor import AutoMasker
8
+ from leffa_utils.densepose_predictor import DensePosePredictor
9
+ from leffa_utils.utils import resize_and_center, list_dir, get_agnostic_mask_hd, get_agnostic_mask_dc
10
+ from preprocess.humanparsing.run_parsing import Parsing
11
+ from preprocess.openpose.run_openpose import OpenPose
12
+
13
+ import gradio as gr
14
+
15
+ # Download checkpoints
16
+ snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
17
+
18
+
19
+ class LeffaPredictor(object):
20
+ def __init__(self):
21
+ self.mask_predictor = AutoMasker(
22
+ densepose_path="./ckpts/densepose",
23
+ schp_path="./ckpts/schp",
24
+ )
25
+
26
+ self.densepose_predictor = DensePosePredictor(
27
+ config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
28
+ weights_path="./ckpts/densepose/model_final_162be9.pkl",
29
+ )
30
+
31
+ self.parsing = Parsing(
32
+ atr_path="./ckpts/humanparsing/parsing_atr.onnx",
33
+ lip_path="./ckpts/humanparsing/parsing_lip.onnx",
34
+ )
35
+
36
+ self.openpose = OpenPose(
37
+ body_model_path="./ckpts/openpose/body_pose_model.pth",
38
+ )
39
+
40
+ vt_model_hd = LeffaModel(
41
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
42
+ pretrained_model="./ckpts/virtual_tryon.pth",
43
+ dtype="float16",
44
+ )
45
+ self.vt_inference_hd = LeffaInference(model=vt_model_hd)
46
+
47
+ vt_model_dc = LeffaModel(
48
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
49
+ pretrained_model="./ckpts/virtual_tryon_dc.pth",
50
+ dtype="float16",
51
+ )
52
+ self.vt_inference_dc = LeffaInference(model=vt_model_dc)
53
+
54
+ pt_model = LeffaModel(
55
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
56
+ pretrained_model="./ckpts/pose_transfer.pth",
57
+ dtype="float16",
58
+ )
59
+ self.pt_inference = LeffaInference(model=pt_model)
60
+
61
+ def leffa_predict(
62
+ self,
63
+ src_image_path,
64
+ ref_image_path,
65
+ control_type,
66
+ ref_acceleration=False,
67
+ step=50,
68
+ scale=2.5,
69
+ seed=42,
70
+ vt_model_type="viton_hd",
71
+ vt_garment_type="upper_body",
72
+ vt_repaint=False
73
+ ):
74
+ assert control_type in [
75
+ "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
76
+ src_image = Image.open(src_image_path)
77
+ ref_image = Image.open(ref_image_path)
78
+ src_image = resize_and_center(src_image, 768, 1024)
79
+ ref_image = resize_and_center(ref_image, 768, 1024)
80
+
81
+ src_image_array = np.array(src_image)
82
+
83
+ # Mask
84
+ if control_type == "virtual_tryon":
85
+ src_image = src_image.convert("RGB")
86
+ model_parse, _ = self.parsing(src_image.resize((384, 512)))
87
+ keypoints = self.openpose(src_image.resize((384, 512)))
88
+ if vt_model_type == "viton_hd":
89
+ mask = get_agnostic_mask_hd(
90
+ model_parse, keypoints, vt_garment_type)
91
+ elif vt_model_type == "dress_code":
92
+ mask = get_agnostic_mask_dc(
93
+ model_parse, keypoints, vt_garment_type)
94
+ mask = mask.resize((768, 1024))
95
+ # garment_type_hd = "upper" if vt_garment_type in [
96
+ # "upper_body", "dresses"] else "lower"
97
+ # mask = self.mask_predictor(src_image, garment_type_hd)["mask"]
98
+ elif control_type == "pose_transfer":
99
+ mask = Image.fromarray(np.ones_like(src_image_array) * 255)
100
+
101
+ # DensePose
102
+ if control_type == "virtual_tryon":
103
+ if vt_model_type == "viton_hd":
104
+ src_image_seg_array = self.densepose_predictor.predict_seg(
105
+ src_image_array)[:, :, ::-1]
106
+ src_image_seg = Image.fromarray(src_image_seg_array)
107
+ densepose = src_image_seg
108
+ elif vt_model_type == "dress_code":
109
+ src_image_iuv_array = self.densepose_predictor.predict_iuv(
110
+ src_image_array)
111
+ src_image_seg_array = src_image_iuv_array[:, :, 0:1]
112
+ src_image_seg_array = np.concatenate(
113
+ [src_image_seg_array] * 3, axis=-1)
114
+ src_image_seg = Image.fromarray(src_image_seg_array)
115
+ densepose = src_image_seg
116
+ elif control_type == "pose_transfer":
117
+ src_image_iuv_array = self.densepose_predictor.predict_iuv(
118
+ src_image_array)[:, :, ::-1]
119
+ src_image_iuv = Image.fromarray(src_image_iuv_array)
120
+ densepose = src_image_iuv
121
+
122
+ # Leffa
123
+ transform = LeffaTransform()
124
+
125
+ data = {
126
+ "src_image": [src_image],
127
+ "ref_image": [ref_image],
128
+ "mask": [mask],
129
+ "densepose": [densepose],
130
+ }
131
+ data = transform(data)
132
+ if control_type == "virtual_tryon":
133
+ if vt_model_type == "viton_hd":
134
+ inference = self.vt_inference_hd
135
+ elif vt_model_type == "dress_code":
136
+ inference = self.vt_inference_dc
137
+ elif control_type == "pose_transfer":
138
+ inference = self.pt_inference
139
+ output = inference(
140
+ data,
141
+ ref_acceleration=ref_acceleration,
142
+ num_inference_steps=step,
143
+ guidance_scale=scale,
144
+ seed=seed,
145
+ repaint=vt_repaint,)
146
+ gen_image = output["generated_image"][0]
147
+ # gen_image.save("gen_image.png")
148
+ return np.array(gen_image), np.array(mask), np.array(densepose)
149
+
150
+ def leffa_predict_vt(self, src_image_path, ref_image_path, ref_acceleration, step, scale, seed, vt_model_type, vt_garment_type, vt_repaint):
151
+ return self.leffa_predict(src_image_path, ref_image_path, "virtual_tryon", ref_acceleration, step, scale, seed, vt_model_type, vt_garment_type, vt_repaint)
152
+
153
+ def leffa_predict_pt(self, src_image_path, ref_image_path, ref_acceleration, step, scale, seed):
154
+ return self.leffa_predict(src_image_path, ref_image_path, "pose_transfer", ref_acceleration, step, scale, seed)
155
+
156
+
157
+ if __name__ == "__main__":
158
+
159
+ leffa_predictor = LeffaPredictor()
160
+ example_dir = "./ckpts/examples"
161
+ person1_images = list_dir(f"{example_dir}/person1")
162
+ person2_images = list_dir(f"{example_dir}/person2")
163
+ garment_images = list_dir(f"{example_dir}/garment")
164
+
165
+ title = "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation"
166
+ link = """[πŸ“š Paper](https://arxiv.org/abs/2412.08486) - [πŸ€– Code](https://github.com/franciszzj/Leffa) - [πŸ”₯ Demo](https://huggingface.co/spaces/franciszzj/Leffa) - [πŸ€— Model](https://huggingface.co/franciszzj/Leffa)
167
+
168
+ Star ⭐ us if you like it!
169
+ """
170
+ news = """## News
171
+ - 09/Jan/2025. Inference defaults to float16, generating an image in 6 seconds (on A100).
172
+ More news can be found in the [GitHub repository](https://github.com/franciszzj/Leffa).
173
+ """
174
+ description = "Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer)."
175
+ note = "Note: The models used in the demo are trained solely on academic datasets. Virtual try-on uses VITON-HD/DressCode, and pose transfer uses DeepFashion."
176
+
177
+ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)).queue() as demo:
178
+ gr.Markdown(title)
179
+ gr.Markdown(link)
180
+ gr.Markdown(news)
181
+ gr.Markdown(description)
182
+
183
+ with gr.Tab("Control Appearance (Virtual Try-on)"):
184
+ with gr.Row():
185
+ with gr.Column():
186
+ gr.Markdown("#### Person Image")
187
+ vt_src_image = gr.Image(
188
+ sources=["upload"],
189
+ type="filepath",
190
+ label="Person Image",
191
+ width=512,
192
+ height=512,
193
+ )
194
+
195
+ gr.Examples(
196
+ inputs=vt_src_image,
197
+ examples_per_page=10,
198
+ examples=person1_images,
199
+ )
200
+
201
+ with gr.Column():
202
+ gr.Markdown("#### Garment Image")
203
+ vt_ref_image = gr.Image(
204
+ sources=["upload"],
205
+ type="filepath",
206
+ label="Garment Image",
207
+ width=512,
208
+ height=512,
209
+ )
210
+
211
+ gr.Examples(
212
+ inputs=vt_ref_image,
213
+ examples_per_page=10,
214
+ examples=garment_images,
215
+ )
216
+
217
+ with gr.Column():
218
+ gr.Markdown("#### Generated Image")
219
+ vt_gen_image = gr.Image(
220
+ label="Generated Image",
221
+ width=512,
222
+ height=512,
223
+ )
224
+
225
+ with gr.Row():
226
+ vt_gen_button = gr.Button("Generate")
227
+
228
+ with gr.Accordion("Advanced Options", open=False):
229
+ vt_model_type = gr.Radio(
230
+ label="Model Type",
231
+ choices=[("VITON-HD (Recommended)", "viton_hd"),
232
+ ("DressCode (Experimental)", "dress_code")],
233
+ value="viton_hd",
234
+ )
235
+
236
+ vt_garment_type = gr.Radio(
237
+ label="Garment Type",
238
+ choices=[("Upper", "upper_body"),
239
+ ("Lower", "lower_body"),
240
+ ("Dress", "dresses")],
241
+ value="upper_body",
242
+ )
243
+
244
+ vt_ref_acceleration = gr.Radio(
245
+ label="Accelerate Reference UNet (may slightly reduce performance)",
246
+ choices=[("True", True), ("False", False)],
247
+ value=False,
248
+ )
249
+
250
+ vt_repaint = gr.Radio(
251
+ label="Repaint Mode",
252
+ choices=[("True", True), ("False", False)],
253
+ value=False,
254
+ )
255
+
256
+ vt_step = gr.Number(
257
+ label="Inference Steps", minimum=30, maximum=100, step=1, value=30)
258
+
259
+ vt_scale = gr.Number(
260
+ label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)
261
+
262
+ vt_seed = gr.Number(
263
+ label="Random Seed", minimum=-1, maximum=2147483647, step=1, value=42)
264
+
265
+ with gr.Accordion("Debug", open=False):
266
+ vt_mask = gr.Image(
267
+ label="Generated Mask",
268
+ width=256,
269
+ height=256,
270
+ )
271
+
272
+ vt_densepose = gr.Image(
273
+ label="Generated DensePose",
274
+ width=256,
275
+ height=256,
276
+ )
277
+
278
+ vt_gen_button.click(fn=leffa_predictor.leffa_predict_vt, inputs=[
279
+ vt_src_image, vt_ref_image, vt_ref_acceleration, vt_step, vt_scale, vt_seed, vt_model_type, vt_garment_type, vt_repaint], outputs=[vt_gen_image, vt_mask, vt_densepose])
280
+
281
+ with gr.Tab("Control Pose (Pose Transfer)"):
282
+ with gr.Row():
283
+ with gr.Column():
284
+ gr.Markdown("#### Person Image")
285
+ pt_ref_image = gr.Image(
286
+ sources=["upload"],
287
+ type="filepath",
288
+ label="Person Image",
289
+ width=512,
290
+ height=512,
291
+ )
292
+
293
+ gr.Examples(
294
+ inputs=pt_ref_image,
295
+ examples_per_page=10,
296
+ examples=person1_images,
297
+ )
298
+
299
+ with gr.Column():
300
+ gr.Markdown("#### Target Pose Person Image")
301
+ pt_src_image = gr.Image(
302
+ sources=["upload"],
303
+ type="filepath",
304
+ label="Target Pose Person Image",
305
+ width=512,
306
+ height=512,
307
+ )
308
+
309
+ gr.Examples(
310
+ inputs=pt_src_image,
311
+ examples_per_page=10,
312
+ examples=person2_images,
313
+ )
314
+
315
+ with gr.Column():
316
+ gr.Markdown("#### Generated Image")
317
+ pt_gen_image = gr.Image(
318
+ label="Generated Image",
319
+ width=512,
320
+ height=512,
321
+ )
322
+
323
+ with gr.Row():
324
+ pose_transfer_gen_button = gr.Button("Generate")
325
+
326
+ with gr.Accordion("Advanced Options", open=False):
327
+ pt_ref_acceleration = gr.Radio(
328
+ label="Accelerate Reference UNet",
329
+ choices=[("True", True), ("False", False)],
330
+ value=False,
331
+ )
332
+
333
+ pt_step = gr.Number(
334
+ label="Inference Steps", minimum=30, maximum=100, step=1, value=30)
335
+
336
+ pt_scale = gr.Number(
337
+ label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)
338
+
339
+ pt_seed = gr.Number(
340
+ label="Random Seed", minimum=-1, maximum=2147483647, step=1, value=42)
341
+
342
+ with gr.Accordion("Debug", open=False):
343
+ pt_mask = gr.Image(
344
+ label="Generated Mask",
345
+ width=256,
346
+ height=256,
347
+ )
348
+
349
+ pt_densepose = gr.Image(
350
+ label="Generated DensePose",
351
+ width=256,
352
+ height=256,
353
+ )
354
+
355
+ pose_transfer_gen_button.click(fn=leffa_predictor.leffa_predict_pt, inputs=[
356
+ pt_src_image, pt_ref_image, pt_ref_acceleration, pt_step, pt_scale, pt_seed], outputs=[pt_gen_image, pt_mask, pt_densepose])
357
+
358
+ gr.Markdown(note)
359
+
360
+ demo.launch(share=True, server_port=7860,
361
+ allowed_paths=["./ckpts/examples"])
362
+