import base64
import os
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

# (connect, read) timeout in seconds for the try-on service. Without a timeout a
# stalled endpoint would block the Gradio worker forever. The read timeout is the
# longest gap allowed between bytes from the server, so it is kept generous to
# leave room for image generation.
REQUEST_TIMEOUT = (10, 300)


def _to_png_buffer(image):
    """Encode a PIL image as PNG into an in-memory buffer rewound to offset 0."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)  # rewind so requests uploads the buffer from the beginning
    return buffer


def call_predict_api(person, cloth):
    """Send a person image and a garment image to the try-on API.

    Both images are PNG-encoded and uploaded as multipart form fields
    ``src_image`` (person) and ``ref_image`` (garment). The URL comes from the
    ``endpoint`` environment variable, and the ``api_key`` environment variable
    is sent as the ``X-API-Key`` header.

    Args:
        person: PIL image of the person, or None if nothing was uploaded.
        cloth: PIL image of the garment, or None if nothing was uploaded.

    Returns:
        The generated PIL image, decoded from the base64 ``gen_image`` field
        of the JSON response.

    Raises:
        gr.Error: If either input image is missing, or the API returns a
            non-200 status. Gradio shows the message to the user.
        KeyError: If the ``api_key``/``endpoint`` environment variables or
            the ``gen_image`` response field are missing.
        requests.RequestException: On connection failures or timeouts.
    """
    # Gradio passes None when the user clicks Generate before uploading both
    # images; fail with a readable message instead of an AttributeError.
    if person is None or cloth is None:
        raise gr.Error("Please provide both a person image and a garment image.")

    files = {
        "src_image": ("src_image.png", _to_png_buffer(person), "image/png"),
        "ref_image": ("ref_image.png", _to_png_buffer(cloth), "image/png"),
    }
    headers = {"X-API-Key": os.environ["api_key"]}
    response = requests.post(
        os.environ["endpoint"],
        files=files,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    if response.status_code != 200:
        # gr.Error (an Exception subclass) makes Gradio display the API's
        # message in the UI; a plain Exception would only show a generic error.
        raise gr.Error(f"API Error: {response.text}")

    result = response.json()
    return Image.open(BytesIO(base64.b64decode(result["gen_image"])))


if __name__ == "__main__":
    title = "## Faster Try-On"
    description = "This is a Gradio interface for the 'Faster Try-on' project, focusing on the upper body. (We will release a version for the lower body after a few updates). The application allows users to virtually try on various types of clothing such as shirts. Experience a quick and intuitive way to visualize your fashion style."

    with gr.Blocks(
        theme=gr.themes.Default(
            primary_hue=gr.themes.colors.pink,
            secondary_hue=gr.themes.colors.red,
        )
    ).queue() as demo:
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Person Image")
                person = gr.Image(
                    sources=["upload"],
                    type="pil",
                    label="Person Image",
                    width=512,
                    height=512,
                )
                gr.Examples(
                    inputs=person,
                    examples_per_page=5,
                    examples=[
                        "images/00019_00.jpg",
                        "images/00089_00.jpg",
                        "images/image_1.jpg",
                    ],
                )

            with gr.Column():
                gr.Markdown("#### Garment Image")
                garment = gr.Image(
                    sources=["upload"],
                    type="pil",
                    label="Garment Image",
                    width=512,
                    height=512,
                )
                gr.Examples(
                    inputs=garment,
                    examples_per_page=10,
                    examples=[
                        "images/00000_00.jpg",
                        "images/00044_00.jpg",
                        "images/00113_00.jpg",
                        "images/goods_474419_sub14_3x4.jpg",
                        "images/vngoods_41_481275002_3x4.jpg",
                        "images/vngoods_474419_sub7_3x4.jpg",
                    ],
                )

            with gr.Column():
                gr.Markdown("#### Generated Image")
                gen_image = gr.Image(
                    label="Generated Image",
                    width=512,
                    height=512,
                )

        with gr.Row():
            vt_gen_button = gr.Button("Generate")

        # Generation runs remotely: the button forwards both uploaded images
        # to call_predict_api and shows the returned image.
        vt_gen_button.click(
            fn=call_predict_api,
            inputs=[person, garment],
            outputs=[gen_image],
        )

    # allowed_paths lets Gradio serve the example files under ./images.
    demo.launch(allowed_paths=["images"])