|
|
import gradio as gr |
|
|
import os |
|
|
import time |
|
|
import requests |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet handed to gr.Blocks(css=...) when the page is built. The #ids match
# elem_id values set on components in the layout: the garment/person/result
# columns (centered, at most 420px wide), the three image components (3:4 boxes
# no taller than 560px, scaled with object-fit: contain), the examples column
# (at most 1000px wide, thumbnails kept at 3:4) and the RUN button colors.
tryon_css="""


#col-garment {


    margin: 0 auto;


    max-width: 420px;


}


#garm_img {


    aspect-ratio: 3 / 4;


    width: 100%;


    max-height: 560px;


    object-fit: contain;


}


#col-person {


    margin: 0 auto;


    max-width: 420px;


}


#person_img {


    aspect-ratio: 3 / 4;


    width: 100%;


    max-height: 560px;


    object-fit: contain;


}


#col-result {


    margin: 0 auto;


    max-width: 420px;


}


#result_img {


    aspect-ratio: 3 / 4;


    width: 100%;


    max-height: 560px;


    object-fit: contain;


}


#col-examples {


    margin: 0 auto;


    max-width: 1000px;


}


#col-examples img {


    aspect-ratio: 3 / 4;


    object-fit: contain;


}


#button {


    background-color: #A47764;


    color: white;


}


"""
|
|
|
|
|
|
|
|
# Root of the bundled example assets, resolved relative to this file rather
# than the current working directory.
example_path = os.path.join(os.path.dirname(__file__), 'data')


def _list_example_files(subdir):
    """Return the sorted names of the visible regular files in ``example_path/subdir``.

    Sorting makes the gallery order deterministic (``os.listdir`` order is
    arbitrary). Skipping dot-files and sub-directories keeps stray entries such
    as ``.DS_Store`` out of the gallery.
    """
    folder = os.path.join(example_path, subdir)
    return sorted(
        name for name in os.listdir(folder)
        if not name.startswith('.') and os.path.isfile(os.path.join(folder, name))
    )


def _category_from_filename(filename):
    """Return the garment category encoded in *filename*, or ``None``.

    The category is the third ``_``-separated field of the file stem,
    capitalized (e.g. ``a_b_top_1.png`` -> ``"Top"``, ``a_b_top.png`` -> ``"Top"``).
    ``None`` is returned when the stem has fewer than three fields, instead of
    raising IndexError at import time.
    """
    fields = os.path.splitext(filename)[0].split("_")
    if len(fields) < 3:
        return None
    return fields[2].capitalize()


def _garment_categories(names):
    """Map each name in *names* that encodes a category to that category.

    Names without a category are omitted, so lookups for them fall back to the
    caller's default.
    """
    mapping = {}
    for name in names:
        garment_category = _category_from_filename(name)
        if garment_category is not None:
            mapping[name] = garment_category
    return mapping


garm_list = _list_example_files("garment")
garm_list_path = [os.path.join(example_path, "garment", garm) for garm in garm_list]

person_list = _list_example_files("person")
person_list_path = [os.path.join(example_path, "person", person) for person in person_list]

# Garment basename -> category, consumed by update_category().
garm_img_category_mapping = _garment_categories(garm_list)
|
|
|
|
|
|
|
|
def load_header(header_file):
    """Return the entire contents of *header_file*, decoded as UTF-8 text.

    The result is injected as raw HTML at the top of the page.
    """
    with open(header_file, encoding='utf-8') as handle:
        return handle.read()
|
|
|
|
|
def preprocess_img(img_path, max_size=1024):
    """Shrink the image at *img_path* in place so neither side exceeds *max_size*.

    Aspect ratio is preserved (``Image.thumbnail``). Images already within the
    limit are left untouched on disk.

    Args:
        img_path: Path to an image file, or ``None`` when the component is empty.
        max_size: Maximum allowed width and height in pixels.

    Returns:
        The same ``img_path`` (or ``None``), so the function can feed its result
        straight back into the component that triggered it.
    """
    if img_path is None:
        return None

    # A context manager releases the file handle deterministically. Previously,
    # an image that needed no resizing was never loaded, so its handle stayed
    # open until garbage collection.
    with Image.open(img_path) as img:
        if max(img.size) > max_size:
            img.thumbnail((max_size, max_size))
            img.save(img_path)

    return img_path
|
|
|
|
|
def update_category(selected_garm_file):
    """Build a Dropdown update selecting the category of the chosen garment file.

    The category is looked up by the file's basename in
    ``garm_img_category_mapping``. Files absent from the mapping select
    "Fullbody".
    """
    file_name = os.path.basename(selected_garm_file)
    if file_name in garm_img_category_mapping:
        category_value = garm_img_category_mapping[file_name]
    else:
        category_value = "Fullbody"
    return gr.update(value=category_value)
|
|
|
|
|
def call_tryon_api(person_file, garm_file, category, model_type='SD_V1', timeout=60):
    """Submit a try-on job and return ``(job_id, status)`` from the API response.

    Both images are sent as multipart uploads to ``$API_ENDPOINT/tryon/v1``,
    authenticated with the ``x-api-key`` header taken from ``$API_KEY``.

    Args:
        person_file: Path to the person/model image.
        garm_file: Path to the garment image.
        category: Garment type sent as ``garment_type`` (e.g. "Top").
        model_type: Backend model identifier sent as ``model_type``.
        timeout: Seconds passed to ``requests.post`` so a stalled server cannot
            block the caller indefinitely.

    Raises:
        gr.Error: if the request raises, the server replies with a non-OK
            status, or the JSON body lacks ``job_id``/``status``.
        KeyError: if ``API_ENDPOINT`` or ``API_KEY`` is not set.
    """
    tryon_url = os.environ['API_ENDPOINT'] + "/tryon/v1"
    payload = {'garment_type': category, 'model_type': model_type, 'repaint_other_garment': 'false'}
    headers = {
        'x-api-key': os.environ['API_KEY']
    }

    # Context managers guarantee both upload handles are closed whether the
    # request succeeds, fails, or raises.
    with open(garm_file, 'rb') as garm_fh, open(person_file, 'rb') as person_fh:
        files = {
            'image_garment_file': garm_fh,
            'image_model_file': person_fh,
        }
        try:
            response = requests.post(tryon_url, headers=headers, data=payload, files=files, timeout=timeout)
            if response.ok:
                data = response.json()
                return data['job_id'], data['status']
            print(response.content)
        except Exception as e:
            print(f"call tryon api error: {e}")

    raise gr.Error("Over heated, please try again later")
|
|
|
|
|
def get_tryon_result(job_id, timeout=30):
    """Query the status of a try-on job once.

    Sends ``GET $API_ENDPOINT/requests/v1?job_id=<job_id>`` with the
    ``x-api-key`` header from ``$API_KEY``.

    Args:
        job_id: Identifier returned by ``call_tryon_api``.
        timeout: Seconds passed to ``requests.get`` so one poll cannot hang.

    Returns:
        ``(image_url, status)``. ``image_url`` is the first output's URL when
        ``status == "completed"`` and ``None`` otherwise. ``(None, None)`` is
        returned when the request fails, the server replies non-OK, or the
        body cannot be parsed.
    """
    result_url = os.environ['API_ENDPOINT'] + "/requests/v1"
    headers = {
        'x-api-key': os.environ['API_KEY']
    }

    try:
        # params= URL-encodes job_id instead of pasting it into the query string.
        response = requests.get(result_url, headers=headers, params={'job_id': job_id}, timeout=timeout)
        if not response.ok:
            print(response.content)
            return None, None

        data = response.json()
        status = data['status']
        if status == "completed":
            return data['output'][0]['image_url'], status
        return None, status
    except Exception as e:
        print(f"get tryon result error: {e}")

    return None, None
|
|
|
|
|
def run_turbo(person_img, garm_img, category="Top"):
    """Run one try-on request end to end for the RUN button.

    Submits the job, waits briefly, then polls for the result a bounded number
    of times.

    Args:
        person_img: Path to the person image, or ``None``.
        garm_img: Path to the garment image, or ``None``.
        category: Garment type forwarded to the API.

    Returns:
        ``(result_image_url, info)``. ``result_image_url`` is ``None`` when
        inputs are missing, the job failed, or polling gave up. In those cases
        a warning toast is shown.
    """
    if person_img is None or garm_img is None:
        gr.Warning("input image is missing")
        return None, "No input image"

    info = ""

    job_id, status = call_tryon_api(person_img, garm_img, category, model_type=os.environ['MODEL_TYPE'])

    # Give the backend a head start before the first poll.
    time.sleep(8)

    # Poll at most max_retry times (about 60 s). Without this bound, a job that
    # never reaches a terminal state, or a result endpoint that keeps erroring,
    # would pin this worker forever.
    max_retry = 40
    poll_interval = 1.5
    if status != "failed":
        for _ in range(max_retry):
            try:
                result_image_url, status = get_tryon_result(job_id)
            except Exception as e:
                print(f"poll tryon result error: {e}")
            else:
                if result_image_url is not None:
                    return result_image_url, info
                if status in ("completed", "failed"):
                    # Terminal state reached without an image; stop polling.
                    break
            time.sleep(poll_interval)

    gr.Warning("Over heated, please try again later")
    return None, info
|
|
|
|
|
# Page layout. All bundled assets (header HTML and showcase images) are
# resolved from example_path, the same base directory the garment/person
# galleries use, so the page no longer depends on the current working directory.
with gr.Blocks(css=tryon_css) as Huhu_Turbo:
    gr.HTML(load_header(os.path.join(example_path, "header.html")))

    # Row 1: a caption above each of the three columns.
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            gr.HTML("""


            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">


                <div>


                    Upload your garment image 🧥


                </div>


            </div>


            """)
        with gr.Column(elem_id="col-person"):
            gr.HTML("""


            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">


                <div>


                    Select a model image 🧍


                </div>


            </div>


            """)
        with gr.Column(elem_id="col-result"):
            gr.HTML("""


            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">


                <div>


                    “RUN” to get results 🪄


                </div>


            </div>


            """)

    # Row 2: garment input + category, person input, and result output.
    with gr.Row():
        with gr.Column(elem_id="col-garment"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="filepath", elem_id="garm_img")
            category = gr.Dropdown(label="Garment type", choices=['Top', 'Bottom', 'Fullbody'], value="Top")
            garm_example = gr.Examples(
                inputs=garm_img,
                examples_per_page=10,
                examples=garm_list_path,
                cache_examples=False,
            )
        with gr.Column(elem_id="col-person"):
            person_img = gr.Image(label="Person image", sources='upload', type="filepath", elem_id="person_img")
            person_example = gr.Examples(
                inputs=person_img,
                examples_per_page=10,
                examples=person_list_path,
            )
        with gr.Column(elem_id="col-result"):
            result_img = gr.Image(label="Result", show_share_button=False, elem_id="result_img")
            with gr.Row():
                result_info = gr.Text(label="Tryon inference runtime", visible=False)
            generate_button = gr.Button(value="RUN", elem_id="button")

    # Selecting a garment example also pre-selects the category from its filename.
    garm_example.load_input_event.then(
        fn=update_category,
        inputs=[garm_img],
        outputs=[category],
    )

    # Downscale oversized inputs in place whenever an image changes.
    garm_img.change(fn=preprocess_img, inputs=[garm_img], outputs=[garm_img])
    person_img.change(fn=preprocess_img, inputs=[person_img], outputs=[person_img])

    generate_button.click(
        fn=run_turbo,
        inputs=[person_img, garm_img, category],
        outputs=[result_img, result_info],
        api_name=False,
        concurrency_limit=30,
    )

    # Static showcase rows: (person, garment, category, precomputed result).
    showcase_dir = os.path.join(example_path, "examples")
    showcase_categories = ["Top", "Top", "Top", "Fullbody", "Top"]
    with gr.Column(elem_id="col-examples"):
        gr.HTML("""


            <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">


                <div> </div>


                <br>


                <div>


                    Huhu Try-on Turbo examples in pairs of garment and model images


                </div>


            </div>


        """)
        show_case = gr.Examples(
            examples=[
                [
                    os.path.join(showcase_dir, f"person_example_{i}.png"),
                    os.path.join(showcase_dir, f"garment_example_{i}.png"),
                    showcase_category,
                    os.path.join(showcase_dir, f"result_example_{i}.png"),
                ]
                for i, showcase_category in enumerate(showcase_categories, start=1)
            ],
            inputs=[person_img, garm_img, category, result_img],
            label=None,
        )
|
|
|
|
|
# Enable request queuing with the queue's HTTP API closed, then start the
# server with the API docs hidden. queue() configures the Blocks in place.
Huhu_Turbo.queue(api_open=False)
Huhu_Turbo.launch(show_api=False)
|
|
|