File size: 3,497 Bytes
4569902
 
 
 
 
 
 
 
 
 
564a1f1
 
 
4569902
564a1f1
 
 
4569902
564a1f1
 
 
 
4569902
 
564a1f1
4569902
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

import gradio as gr
import base64
import requests
from io import BytesIO
from PIL import Image
import os

def _to_png_buffer(image):
    """Serialize a PIL image to an in-memory PNG file rewound to its start."""
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    return buffer


def call_predict_api(person, cloth):
    """Upload a person image and a garment image to the try-on API.

    Both images are encoded as in-memory PNG files and sent as a multipart
    upload (fields ``src_image`` and ``ref_image``) to the URL stored in the
    ``endpoint`` environment variable. The request is authenticated with an
    ``X-API-Key`` header read from the ``api_key`` environment variable.

    Args:
        person: PIL image of the person; uploaded as ``src_image``.
        cloth: PIL image of the garment; uploaded as ``ref_image``.

    Returns:
        The PIL image decoded from the base64 ``gen_image`` field of the JSON
        response.

    Raises:
        ValueError: If either image is missing (``None``), e.g. the button was
            clicked before both images were uploaded.
        KeyError: If the ``api_key`` / ``endpoint`` environment variable is
            not set, or the response JSON has no ``gen_image`` field.
        Exception: If the API answers with a status code other than 200.
    """
    # Fail early with a readable message instead of an AttributeError on
    # ``None.save`` when an image slot is still empty.
    if person is None or cloth is None:
        raise ValueError(
            "Both a person image and a garment image are required."
        )

    # (connect, read) timeout in seconds. requests waits forever by default,
    # which would block the worker if the backend stalls. The read timeout is
    # deliberately generous because generation time is not known here.
    request_timeout = (10, 300)

    # Prepare files for upload
    files = {
        "src_image": ("src_image.png", _to_png_buffer(person), "image/png"),
        "ref_image": ("ref_image.png", _to_png_buffer(cloth), "image/png"),
    }
    headers = {"X-API-Key": os.environ["api_key"]}
    response = requests.post(
        os.environ["endpoint"],
        files=files,
        headers=headers,
        timeout=request_timeout,
    )
    if response.status_code != 200:
        raise Exception(f"API Error: {response.text}")

    result = response.json()
    return Image.open(BytesIO(base64.b64decode(result["gen_image"])))

if __name__ == "__main__":
    # Static page text shown at the top of the demo.
    heading_md = "## Faster Try-On"
    intro_md = (
        "This is a Gradio interface for the 'Faster Try-on' project, focusing on the upper body. "
        "(We will release a version for the lower body after a few updates).  "
        "The application allows users to virtually try on various types of clothing such as shirts.  "
        "Experience a quick and intuitive way to visualize your fashion style."
    )

    # Bundled sample images offered below each upload widget.
    person_examples = [
        "images/00019_00.jpg",
        "images/00089_00.jpg",
        "images/image_1.jpg",
    ]
    garment_examples = [
        "images/00000_00.jpg",
        "images/00044_00.jpg",
        "images/00113_00.jpg",
        "images/goods_474419_sub14_3x4.jpg",
        "images/vngoods_41_481275002_3x4.jpg",
        "images/vngoods_474419_sub7_3x4.jpg",
    ]

    pink_red_theme = gr.themes.Default(
        primary_hue=gr.themes.colors.pink,
        secondary_hue=gr.themes.colors.red,
    )

    with gr.Blocks(theme=pink_red_theme).queue() as demo:
        gr.Markdown(heading_md)
        gr.Markdown(intro_md)

        # Three side-by-side columns: person input, garment input, result.
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Person Image")
                person_input = gr.Image(
                    label="Person Image",
                    type="pil",
                    sources=["upload"],
                    width=512,
                    height=512,
                )
                gr.Examples(
                    examples=person_examples,
                    inputs=person_input,
                    examples_per_page=5,
                )

            with gr.Column():
                gr.Markdown("#### Garment Image")
                garment_input = gr.Image(
                    label="Garment Image",
                    type="pil",
                    sources=["upload"],
                    width=512,
                    height=512,
                )
                gr.Examples(
                    examples=garment_examples,
                    inputs=garment_input,
                    examples_per_page=10,
                )

            with gr.Column():
                gr.Markdown("#### Generated Image")
                result_image = gr.Image(
                    label="Generated Image",
                    width=512,
                    height=512,
                )
                with gr.Row():
                    generate_button = gr.Button("Generate")

            # Clicking "Generate" sends both uploads to the remote API and
            # shows the returned image in the result column.
            generate_button.click(
                fn=call_predict_api,
                inputs=[person_input, garment_input],
                outputs=[result_image],
            )

        # "images" is allow-listed so the example files can be served.
        demo.launch(allowed_paths=["images"])