# Faster-Try-On / app.py
# Phuc Vo
# update
# 564a1f1
import gradio as gr
import base64
import requests
from io import BytesIO
from PIL import Image
import os
def _to_png_buffer(image):
    """Encode *image* as PNG into an in-memory buffer rewound to the start."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)  # rewind so requests uploads the data from byte 0
    return buffer


def call_predict_api(person, cloth, *, timeout: float = 120.0):
    """Send a person image and a garment image to the try-on API; return the result.

    The endpoint URL is read from the ``endpoint`` environment variable and the
    key (sent as the ``X-API-Key`` header) from ``api_key``. Both images are
    uploaded as PNG multipart fields: ``src_image`` (person) and ``ref_image``
    (garment). The API is expected to answer HTTP 200 with a JSON body whose
    ``gen_image`` field holds the base64-encoded generated image.

    Args:
        person: PIL image of the person. Gradio passes ``None`` when nothing
            was uploaded.
        cloth: PIL image of the garment. Gradio passes ``None`` when nothing
            was uploaded.
        timeout: Seconds ``requests`` waits on the API before giving up. It is
            keyword-only so Gradio keeps passing exactly the two image inputs
            positionally.

    Returns:
        The generated image, opened with ``PIL.Image.open`` and fully loaded.

    Raises:
        gr.Error: An input image is missing, the request fails or times out,
            the API answers with a non-200 status, or the response cannot be
            decoded into an image. Gradio displays ``gr.Error`` messages to
            the user instead of a generic "Error".
        KeyError: The ``api_key`` or ``endpoint`` environment variable is unset.
    """
    # Fail fast with a user-visible message instead of an AttributeError on .save().
    if person is None:
        raise gr.Error("Please upload a person image.")
    if cloth is None:
        raise gr.Error("Please upload a garment image.")

    files = {
        "src_image": ("src_image.png", _to_png_buffer(person), "image/png"),
        "ref_image": ("ref_image.png", _to_png_buffer(cloth), "image/png"),
    }
    headers = {"X-API-Key": os.environ["api_key"]}
    endpoint = os.environ["endpoint"]

    try:
        # Without a timeout a stalled API would block this worker indefinitely.
        response = requests.post(endpoint, files=files, headers=headers, timeout=timeout)
    except requests.exceptions.RequestException as err:
        raise gr.Error(f"API request failed: {err}") from err

    if response.status_code != 200:
        raise gr.Error(f"API Error: {response.text}")

    try:
        encoded = response.json()["gen_image"]
        image = Image.open(BytesIO(base64.b64decode(encoded)))
        # Image.open is lazy; force decoding here so a corrupt payload is
        # reported now rather than later inside Gradio's output handling.
        image.load()
    except (ValueError, KeyError, TypeError, OSError) as err:
        raise gr.Error("API returned an unexpected response.") from err
    return image
if __name__ == "__main__":
    # Page copy rendered above the image columns.
    title = "## Faster Try-On"
    description = (
        "This is a Gradio interface for the 'Faster Try-on' project, "
        "focusing on the upper body. "
        "(We will release a version for the lower body after a few updates). "
        "The application allows users to virtually try on various types of "
        "clothing such as shirts. "
        "Experience a quick and intuitive way to visualize your fashion style."
    )

    # Sample images offered under each upload widget. The paths are relative,
    # and the "images" directory is whitelisted in launch(allowed_paths=...).
    person_examples = [
        "images/00019_00.jpg",
        "images/00089_00.jpg",
        "images/image_1.jpg",
    ]
    garment_examples = [
        "images/00000_00.jpg",
        "images/00044_00.jpg",
        "images/00113_00.jpg",
        "images/goods_474419_sub14_3x4.jpg",
        "images/vngoods_41_481275002_3x4.jpg",
        "images/vngoods_474419_sub7_3x4.jpg",
    ]

    # All three image panels share the same fixed display size.
    panel_size = {"width": 512, "height": 512}
    theme = gr.themes.Default(
        primary_hue=gr.themes.colors.pink,
        secondary_hue=gr.themes.colors.red,
    )

    with gr.Blocks(theme=theme).queue() as demo:
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            # Column 1: the person to dress (uploaded as a PIL image).
            with gr.Column():
                gr.Markdown("#### Person Image")
                person = gr.Image(
                    label="Person Image",
                    type="pil",
                    sources=["upload"],
                    **panel_size,
                )
                gr.Examples(
                    examples=person_examples,
                    inputs=person,
                    examples_per_page=5,
                )

            # Column 2: the garment to try on (uploaded as a PIL image).
            with gr.Column():
                gr.Markdown("#### Garment Image")
                garment = gr.Image(
                    label="Garment Image",
                    type="pil",
                    sources=["upload"],
                    **panel_size,
                )
                gr.Examples(
                    examples=garment_examples,
                    inputs=garment,
                    examples_per_page=10,
                )

            # Column 3: the try-on result, filled in by the Generate button.
            with gr.Column():
                gr.Markdown("#### Generated Image")
                gen_image = gr.Image(label="Generated Image", **panel_size)

        with gr.Row():
            vt_gen_button = gr.Button("Generate")

        # Clicking Generate sends both uploads to call_predict_api and shows
        # the image it returns in the third column.
        vt_gen_button.click(
            fn=call_predict_api,
            inputs=[person, garment],
            outputs=[gen_image],
        )

    demo.launch(allowed_paths=["images"])