Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import time | |
| import requests | |
| from PIL import Image | |
# UX format: custom CSS handed to gr.Blocks(css=...).
# - #col-garment / #col-person / #col-result: centered columns capped at 420px.
# - #garm_img / #person_img / #result_img: image widgets held at a 3:4 aspect
#   ratio, full column width, at most 560px tall, scaled with object-fit: contain.
# - #col-examples: centered showcase column (max 1000px) whose images are also 3:4.
# - #button: brown background with white text (the RUN button uses elem_id="button").
tryon_css="""
#col-garment {
margin: 0 auto;
max-width: 420px;
}
#garm_img {
aspect-ratio: 3 / 4;
width: 100%;
max-height: 560px;
object-fit: contain;
}
#col-person {
margin: 0 auto;
max-width: 420px;
}
#person_img {
aspect-ratio: 3 / 4;
width: 100%;
max-height: 560px;
object-fit: contain;
}
#col-result {
margin: 0 auto;
max-width: 420px;
}
#result_img {
aspect-ratio: 3 / 4;
width: 100%;
max-height: 560px;
object-fit: contain;
}
#col-examples {
margin: 0 auto;
max-width: 1000px;
}
#col-examples img {
aspect-ratio: 3 / 4;
object-fit: contain;
}
#button {
background-color: #A47764;
color: white;
}
"""
# assets loading
def _list_example_files(folder):
    """Return the names of the visible regular files in ``folder``, sorted.

    ``os.listdir`` returns entries in arbitrary order and also reports hidden
    files and sub-directories; sorting makes the example galleries stable
    between runs and the filter keeps non-image clutter (e.g. ``.DS_Store``)
    out of them.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
    )


def _category_from_filename(file_name, default="Fullbody"):
    """Return the capitalized third ``_``-separated field of ``file_name``.

    File names with fewer than three fields yield ``default`` instead of
    raising IndexError at import time. The default matches the fallback that
    ``update_category`` uses for unknown files.
    """
    fields = file_name.split("_")
    if len(fields) < 3:
        return default
    return fields[2].capitalize()


# Bundled example assets live in the "data" directory next to this file.
example_path = os.path.join(os.path.dirname(__file__), 'data')
garm_list = _list_example_files(os.path.join(example_path, "garment"))
garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_list]
person_list = _list_example_files(os.path.join(example_path, "person"))
person_list_path = [os.path.join(example_path, "person", person) for person in person_list]
# Maps garment example file name -> garment type derived from that name.
garm_img_category_mapping = {garm: _category_from_filename(garm) for garm in garm_list}
def load_header(header_file):
    """Read ``header_file`` as UTF-8 text and return its entire contents."""
    with open(header_file, mode='r', encoding='utf-8') as handle:
        return handle.read()
def preprocess_img(img_path, max_size=1024):
    """Shrink the image file at ``img_path`` if its longest side exceeds ``max_size``.

    Args:
        img_path: Path of the image file, or None when no image is set.
        max_size: Largest allowed width/height in pixels.

    Returns:
        ``img_path`` unchanged (the file itself is overwritten when it was
        resized), or None when ``img_path`` is None.
    """
    if img_path is None:
        return None
    # The context manager releases the file handle that a bare Image.open()
    # would otherwise keep open.
    with Image.open(img_path) as img:
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size))
            # Overwrite the source file so later steps upload the smaller image.
            img.save(img_path)
    return img_path
def update_category(selected_garm_file):
    """Build a Gradio update that sets the garment-type value for a garment file.

    The file's base name is looked up in ``garm_img_category_mapping``; names
    that are not in the mapping fall back to "Fullbody".
    """
    file_name = os.path.basename(selected_garm_file)
    if file_name in garm_img_category_mapping:
        category_value = garm_img_category_mapping[file_name]
    else:
        category_value = "Fullbody"
    return gr.update(value=category_value)
def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job to the remote API.

    The endpoint and key are read from the ``API_ENDPOINT`` and ``API_KEY``
    environment variables.

    Args:
        person_file: Path of the person/model image to upload.
        garm_file: Path of the garment image to upload.
        category: Garment type sent as ``garment_type``.
        model_type: Value sent as ``model_type``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(job_id, status)`` taken from the JSON response.

    Raises:
        gr.Error: If the request fails, times out, returns a non-OK status,
            or the response lacks the expected fields.
    """
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type, 'repaint_other_garment': 'false'}
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    # Open both uploads in a with-block so the handles are closed on every
    # exit path (success, HTTP error, or exception).
    with open(garm_file, 'rb') as garment_fp, open(person_file, 'rb') as person_fp:
        files = {
            'image_garment_file': garment_fp,
            'image_model_file': person_fp,
        }
        try:
            # Without a timeout, requests can block indefinitely on a stalled server.
            response = requests.post(tryon_url, headers=headers, data=payload, files=files, timeout=timeout)
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            print(f"call tryon api error: {e}")
    # if the API call fails, return pop up error
    raise gr.Error("Over heated, please try again later")
def get_tryon_result(job_id, timeout=30):
    """Fetch the current state of a try-on job.

    Args:
        job_id: Identifier returned by ``call_tryon_api``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(image_url, status)`` when the job status is "completed",
        ``(None, status)`` for any other reported status, and
        ``(None, None)`` when the request fails or the response cannot be read.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }
    try:
        # Let requests build (and URL-encode) the query string, and bound the
        # wait so a stalled server cannot block the polling loop.
        response = requests.get(result_url, headers=headers, params={'job_id': job_id}, timeout=timeout)
        if response.ok:
            data = response.json()
            if data["status"] == "completed":
                image_url = data['output'][0]['image_url']
                return image_url, data['status']
            return None, data['status']
        # Log non-OK responses the same way call_tryon_api does.
        print(response.content)
    except Exception as e:
        print(f"get tryon result error: {e}")
    return None, None
def run_turbo(person_img, garm_img, category="Top"):
    """Run one try-on job and wait for its result image.

    Args:
        person_img: Path of the person image, or None if not provided.
        garm_img: Path of the garment image, or None if not provided.
        category: Garment type forwarded to the try-on API.

    Returns:
        ``(result_image_url, info)`` on success. ``(None, "No input image")``
        when an input is missing. ``(None, info)`` when the job fails or does
        not finish within the retry budget.

    Raises:
        gr.Error: Propagated from ``call_tryon_api`` when submission fails.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"
    info = ""  # placeholder for now
    job_id, status = call_tryon_api(person_img, garm_img, category, model_type=os.environ['MODEL_TYPE'])
    time.sleep(8)  # wait before fetching the result
    # check the status of the job
    max_retry = 40  # 40 x 1.5s = 60s polling budget for a single job run
    # Bounded loop: the retry budget is enforced so a job that never reaches
    # a terminal status cannot keep this worker polling forever.
    for _ in range(max_retry):
        if status in ("completed", "failed"):
            break
        try:
            result_image_url, status = get_tryon_result(job_id)
        except Exception as e:
            # Keep polling on a transient error, but leave a trace of it.
            print(f"poll tryon result error: {e}")
        else:
            if result_image_url is not None:
                return result_image_url, info
        time.sleep(1.5)  # Wait before retrying
    gr.Warning("Over heated, please try again later")
    return None, info
# Build the Gradio UI: a header, a row of three column titles, a row with the
# garment / person / result widgets, the event wiring, a note, and a showcase
# of example pairs.
# NOTE(review): block nesting below was reconstructed from a flattened listing;
# confirm the placement of the inner Row and the RUN button against the live layout.
with gr.Blocks(css=tryon_css) as Huhu_Turbo:
    # Page header HTML is read from a path relative to the working directory.
    gr.HTML(load_header("data/header.html"))
    # Title row: one caption per column.
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Upload your garment image 🧥
</div>
</div>
""")
        with gr.Column(elem_id = "col-person"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
Select a model image 🧍
</div>
</div>
""")
        with gr.Column(elem_id = "col-result"):
            gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div>
“RUN” to get results 🪄
</div>
</div>
""")
    # Widget row: inputs on the left and middle, result on the right.
    with gr.Row():
        with gr.Column(elem_id = "col-garment"):
            # type="filepath": handlers receive the image as a file path.
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath", elem_id="garm_img")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
            # Clickable garment examples taken from the bundled garment folder.
            garm_example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path,
                cache_examples=False
            )
        with gr.Column(elem_id = "col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath", elem_id="person_img")
            # Clickable person examples taken from the bundled person folder.
            person_example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path
            )
        with gr.Column(elem_id = "col-result"):
            result_img = gr.Image(label="Result", show_share_button=False, elem_id="result_img")
            with gr.Row():
                # Hidden text box that receives run_turbo's second return value.
                result_info = gr.Text(label="Tryon inference runtime", visible=False)
            generate_button = gr.Button(value="RUN", elem_id="button")
    # After a garment example is loaded, set the dropdown from its file name.
    garm_example.load_input_event.then(
        fn=update_category,
        inputs=[garm_img],
        outputs=[category]
    )
    # Whenever an input image changes, run it through preprocess_img and write
    # the returned path back to the same component.
    garm_img.change(fn=preprocess_img, inputs=[garm_img], outputs=[garm_img])
    person_img.change(fn=preprocess_img, inputs=[person_img], outputs=[person_img])
    # RUN submits the try-on job; api_name=False and concurrency_limit=30 are
    # passed through to Gradio's event listener.
    generate_button.click(fn=run_turbo, inputs=[person_img, garm_img, category], outputs=[result_img, result_info], api_name=False, concurrency_limit=30)
    gr.HTML(load_header("data/note.html"))
    # Showcase: rows of (person, garment, category, result) that fill all four
    # components at once when clicked.
    with gr.Column(elem_id = "col-examples"):
        gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
<div> </div>
<br>
<div>
Huhu Try-on Turbo examples in pairs of garment and model images
</div>
</div>
""")
        show_case = gr.Examples(
            examples=[
                ["data/examples/person_example_1.png", "data/examples/garment_example_1.png", "Top", "data/examples/result_example_1.png"],
                ["data/examples/person_example_2.png", "data/examples/garment_example_2.png", "Top", "data/examples/result_example_2.png"],
                ["data/examples/person_example_3.png", "data/examples/garment_example_3.png", "Top", "data/examples/result_example_3.png"],
                ["data/examples/person_example_4.png", "data/examples/garment_example_4.png", "Fullbody", "data/examples/result_example_4.png"],
                ["data/examples/person_example_5.png", "data/examples/garment_example_5.png", "Top", "data/examples/result_example_5.png"],
            ],
            inputs=[person_img, garm_img, category, result_img],
            label=None
        )
# Enable the request queue and start the app with api_open=False / show_api=False.
Huhu_Turbo.queue(api_open=False).launch(show_api=False)