File size: 3,347 Bytes
8c40219
edfb9ce
a466d6e
6e4a2a6
459eb54
6042875
 
8c40219
6e4a2a6
8c40219
 
 
9d9631e
 
 
8c40219
b793cb4
 
8c40219
b793cb4
8c40219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a100145
b793cb4
8c40219
b793cb4
9d9631e
8c40219
b793cb4
8c40219
 
6e4a2a6
 
8c40219
 
 
 
 
 
6e4a2a6
8c40219
6e4a2a6
 
a466d6e
6e4a2a6
a466d6e
c3b4694
 
 
 
 
 
 
 
a466d6e
 
 
 
6e4a2a6
a466d6e
d82b800
a466d6e
d82b800
a466d6e
 
 
 
 
 
 
 
 
 
 
 
 
51a8997
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import gradio as gr
import cv2
import numpy as np
from gradio_client import Client

# -----------------------------
# Load Hugging Face API key safely
# -----------------------------
API_KEY = os.environ.get("HF_TOKEN")  # None unless HF_TOKEN is set (e.g. as a Space secret)


def _load_hunyuan_client(token):
    """Connect to the ``tencent/Hunyuan3D-2.1`` Space via ``gradio_client``.

    Any exception raised while constructing the client is printed to stdout
    and swallowed, so the app can still start when the connection fails.

    Returns:
        The connected ``Client``, or ``None`` if construction failed.
    """
    try:
        return Client("tencent/Hunyuan3D-2.1", hf_token=token)
    except Exception as exc:
        print(f"Hunyuan3D client not loaded: {exc}")
        return None


hunyuan_client = _load_hunyuan_client(API_KEY)

# -----------------------------
# Local "try-on" function
# -----------------------------
def tryon_local(person_img, garment_img, seed, randomize_seed):
    """Composite a garment image onto a person image as a local preview.

    The garment is resized with ``cv2.resize`` so that its width equals the
    person image's width (aspect ratio preserved) and pasted at the top of
    the person image. Garment rows that would fall below the bottom of the
    person image are cropped. A garment with an alpha channel is
    alpha-blended; an opaque garment overwrites the pixels it covers.

    Args:
        person_img: Person image as a NumPy array, either 2-D grayscale or
            (H, W, C); ``None`` when not provided.
        garment_img: Garment image, same conventions as ``person_img``.
        seed: Returned unchanged.
        randomize_seed: Not used by this local simulation; accepted so the
            signature matches the UI inputs.

    Returns:
        ``(overlay, seed, info)``: the composited image (a new array; the
        inputs are not modified), the seed, and a status string. If either
        image is missing or has zero size, returns
        ``(None, None, "Empty image")``.
    """
    if person_img is None or garment_img is None:
        return None, None, "Empty image"

    h_person, w_person = person_img.shape[:2]
    h_garment, w_garment = garment_img.shape[:2]
    if 0 in (h_person, w_person, h_garment, w_garment):
        # Nothing to composite, and a zero garment width would divide by zero.
        return None, None, "Empty image"

    # Fit the garment to the person's width, preserving aspect ratio. Use
    # w_person directly: int(w_garment * scale) can round down by one pixel.
    # Clamp the height to >= 1 because cv2.resize rejects a zero-sized target.
    scale = w_person / w_garment
    new_w = w_person
    new_h = max(1, int(h_garment * scale))
    garment_resized = cv2.resize(garment_img, (new_w, new_h))
    color, alpha = _split_color_alpha(garment_resized)

    overlay = person_img.copy()
    if overlay.ndim == 2:
        overlay = overlay[:, :, np.newaxis]
    if overlay.shape[2] < 3:
        # Promote grayscale (or grayscale+alpha) person images to 3 channels
        # so a colour garment can be pasted into them.
        overlay = np.repeat(overlay[:, :, :1], 3, axis=2)

    # Crop garment rows that extend past the bottom of the person image;
    # without this, tall garments raised a broadcasting ValueError.
    visible_h = min(new_h, h_person)
    color = color[:visible_h]
    # Basic slicing returns a view, so writes to `region` land in `overlay`.
    # Only the first three channels are painted; a person alpha channel, if
    # any, is left as-is.
    region = overlay[:visible_h, :, :3]
    if alpha is None:
        region[...] = color
    else:
        a = alpha[:visible_h, :, np.newaxis]
        region[...] = a * color + (1.0 - a) * region

    info = "Success (local simulation)"
    return overlay, seed, info


def _split_color_alpha(img):
    """Split an image array into a 3-channel colour array and an optional alpha mask.

    Accepts 2-D grayscale, (H, W, 1) grayscale, (H, W, 2) grayscale+alpha,
    (H, W, 3) colour and (H, W, 4) colour+alpha arrays. Grayscale colour is
    replicated to three channels. The alpha mask, when present, is
    ``alpha / 255.0`` (i.e. it assumes 8-bit input) with shape (H, W);
    otherwise ``None``.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    alpha = None
    if img.shape[2] in (2, 4):
        # The last channel is alpha (LA or RGBA layout).
        alpha = img[:, :, -1] / 255.0
        img = img[:, :, :-1]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    return img[:, :, :3], alpha

# -----------------------------
# Try-on via Hunyuan3D
# -----------------------------
def tryon_to_3d(person_img, garment_img, seed, randomize_seed):
    """Build the local try-on overlay and send it to the Hunyuan3D Space.

    Args:
        person_img: Person image from the UI (NumPy array in Gradio's RGB/RGBA
            channel order) or ``None``.
        garment_img: Garment image, same conventions as ``person_img``.
        seed: Forwarded to ``tryon_local``.
        randomize_seed: Forwarded to ``tryon_local``.

    Returns:
        ``(result, info)``. On success ``result`` is whatever
        ``hunyuan_client.predict`` returns. On failure it is the local overlay,
        or ``None`` when an input image is missing, and ``info`` explains what
        went wrong.
    """
    import tempfile

    try:
        from gradio_client import handle_file
    except ImportError:
        # Older gradio_client releases accepted plain file paths directly.
        def handle_file(path):
            return path

    overlay_img, _seed_used, info = tryon_local(person_img, garment_img, seed, randomize_seed)
    if overlay_img is None:
        # Report the real problem (e.g. "Empty image") rather than a
        # misleading client/API error.
        return None, info

    if hunyuan_client is None:
        return overlay_img, "Hunyuan3D client not loaded. Please check your API key."

    try:
        # Gradio supplies RGB(A) arrays but OpenCV encodes them as BGR(A);
        # convert first so the PNG sent to the API has the correct colours.
        if overlay_img.ndim == 3 and overlay_img.shape[2] == 4:
            bgr_img = cv2.cvtColor(overlay_img, cv2.COLOR_RGBA2BGRA)
        elif overlay_img.ndim == 3 and overlay_img.shape[2] == 3:
            bgr_img = cv2.cvtColor(overlay_img, cv2.COLOR_RGB2BGR)
        else:
            bgr_img = overlay_img  # single channel: no channel order to fix

        ok, buffer = cv2.imencode(".png", bgr_img)
        if not ok:
            return overlay_img, "Hunyuan3D API error: could not encode image as PNG"

        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = os.path.join(tmp_dir, "tryon.png")
            with open(img_path, "wb") as fh:
                fh.write(buffer.tobytes())
            # gradio_client takes file inputs as handle_file(path_or_url);
            # raw bytes are not a documented input type. predict() blocks
            # until the upload and call finish, so the temp file can be
            # removed afterwards.
            # NOTE(review): no api_name is passed, so this relies on the Space
            # exposing a single default endpoint. The result may also be a 3D
            # model file rather than an image. Confirm both against the
            # Space's "Use via API" page.
            result = hunyuan_client.predict(handle_file(img_path))
        return result, "3D try-on completed!"
    except Exception as e:
        return overlay_img, f"Hunyuan3D API error: {e}"

# -----------------------------
# Gradio UI
# -----------------------------
# Stylesheet passed to gr.Blocks: centres the three UI columns and caps their
# width at 430px.
# NOTE(review): no component in this file sets elem_id="col-showcase" or
# elem_id="button", so those two rules appear to be unused.
css = "\n".join(
    [
        "",
        "#col-left, #col-mid, #col-right {",
        "    margin: 0 auto;",
        "    max-width: 430px;",
        "}",
        "#col-showcase {",
        "    margin: 0 auto;",
        "    max-width: 1100px;",
        "}",
        "#button { color: blue; }",
        "",
    ]
)

# Gradio UI: one row with three columns (person input, garment input, and
# result + controls). Layout follows the order in which components are
# created inside the context managers below.
with gr.Blocks(css=css) as app:
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            person_input = gr.Image(label="Person Image", type="numpy")
        with gr.Column(elem_id="col-mid"):
            garment_input = gr.Image(label="Garment Image", type="numpy")
        with gr.Column(elem_id="col-right"):
            # NOTE(review): on success tryon_to_3d returns whatever
            # hunyuan_client.predict yields, which may not be an image
            # (e.g. a 3D model file). Confirm gr.Image is the right
            # output component.
            output_img = gr.Image(label="3D Result")
            result_info = gr.Text(label="Info")
            # The seed controls are passed as inputs only. They are not listed
            # in `outputs`, so the UI never shows the seed that was used.
            seed = gr.Slider(0, 999999, value=0, step=1, label="Seed")
            randomize_seed = gr.Checkbox(label="Random seed", value=True)
            run_btn = gr.Button("Run")

    # On click, run the pipeline; tryon_to_3d returns (result, info text),
    # which map onto the two output components in order.
    run_btn.click(
        fn=tryon_to_3d,
        inputs=[person_input, garment_input, seed, randomize_seed],
        outputs=[output_img, result_info]
    )

# Launches the server at module level, i.e. whenever this file is run or
# imported (there is no `if __name__ == "__main__":` guard).
app.launch()