File size: 10,310 Bytes
5886e53
03edb30
73f1086
bbc4eef
9c49035
03edb30
 
9c49035
03edb30
 
 
 
5886e53
03edb30
 
 
 
 
 
9c49035
bbc4eef
 
03edb30
 
 
 
 
 
9c49035
5886e53
 
 
03edb30
 
 
 
 
5886e53
03edb30
 
 
 
 
 
 
 
 
 
 
 
 
9c49035
03edb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbc4eef
03edb30
bbc4eef
03edb30
 
 
 
 
 
 
 
 
9c49035
03edb30
 
 
 
 
 
 
 
 
9c49035
 
 
 
 
03edb30
 
bbc4eef
 
9c49035
 
03edb30
9c49035
03edb30
9c49035
03edb30
 
 
 
9c49035
03edb30
 
 
 
9c49035
03edb30
 
 
 
9c49035
03edb30
 
9c49035
03edb30
 
 
9c49035
03edb30
 
 
 
 
 
 
9c49035
03edb30
 
 
 
 
69d3136
 
9c49035
03edb30
 
 
 
 
9c49035
03edb30
 
 
 
 
 
5886e53
03edb30
 
9c49035
03edb30
 
 
 
9c49035
03edb30
 
 
 
9c49035
 
03edb30
9c49035
 
 
03edb30
 
 
 
 
 
9c49035
03edb30
9c49035
 
5886e53
69d3136
03edb30
9c49035
 
03edb30
 
9c49035
03edb30
9c49035
03edb30
 
 
9c49035
03edb30
9c49035
 
03edb30
9c49035
 
 
 
03edb30
9c49035
 
 
03edb30
9c49035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03edb30
 
9c49035
03edb30
 
 
9c49035
5886e53
03edb30
5886e53
03edb30
5886e53
03edb30
9c49035
03edb30
 
 
 
9c49035
 
03edb30
9c49035
03edb30
 
 
9c49035
03edb30
 
 
69d3136
03edb30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import os, json, tempfile, re
import cv2, numpy as np, gradio as gr
from PIL import Image

# =============== Paths ===============
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Accept several historical asset-folder layouts; the first existing one wins.
CANDIDATES = [
    os.path.join(BASE_DIR, "hair"),
    os.path.join(BASE_DIR, "assets", "hairstyles"),
    os.path.join(BASE_DIR, "assets", "Hairstyles"),
    os.path.join(BASE_DIR, "hairstyles"),
]
HAIR_DIR = next((p for p in CANDIDATES if os.path.isdir(p)), None)
if HAIR_DIR is None:
    # No assets folder found: create the default so the app still starts
    # (the gallery will simply be empty until PNGs are added).
    HAIR_DIR = os.path.join(BASE_DIR, "hair")
    os.makedirs(HAIR_DIR, exist_ok=True)

META_PATH = os.path.join(HAIR_DIR, "meta.json")  # optional per-style anchors

# =============== Dependencies ===============
try:
    import mediapipe as mp
except Exception as e:
    # Fail fast with an actionable message — mediapipe wheels are sensitive
    # to version pins (common failure mode on hosted Spaces).
    raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}")

mp_face_mesh = mp.solutions.face_mesh
# FaceMesh landmark indices used as alignment anchors (standard mesh topology).
LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}

# =============== Helpers ===============
def natural_key(s: str):
    """Case-insensitive sort key that orders digit runs numerically (hair2 < hair10)."""
    key = []
    for token in re.split(r"(\d+)", s):
        key.append(int(token) if token.isdigit() else token.lower())
    return key

def load_hairstyles():
    """Return PNG filenames in HAIR_DIR in natural sort order ([] if the dir is gone)."""
    try:
        entries = os.listdir(HAIR_DIR)
    except FileNotFoundError:
        entries = []
    pngs = [name for name in entries if name.lower().endswith(".png")]
    return sorted(pngs, key=natural_key)

def load_meta():
    """Load optional per-style anchor metadata from META_PATH; {} on any problem."""
    if not os.path.exists(META_PATH):
        return {}
    try:
        with open(META_PATH, "r") as fh:
            data = json.load(fh)
    except Exception:
        # Corrupt/unreadable meta.json is non-fatal: fall back to defaults.
        return {}
    return data if isinstance(data, dict) else {}

def premultiply_alpha(bgra):
    """Premultiply the BGR channels by alpha to eliminate gray/white edge halos.

    Input and output are uint8 BGRA arrays; the alpha channel passes through.
    """
    alpha = bgra[:, :, 3:4].astype(np.float32) / 255.0
    color = bgra[:, :, :3].astype(np.float32) / 255.0
    premult = (color * alpha * 255.0).astype(np.uint8)
    return np.dstack([premult, bgra[:, :, 3]])

def load_hair_png(name):
    """Load a hairstyle PNG from HAIR_DIR and return it premultiplied (BGRA).

    Raises ValueError when the file is missing/unreadable or has no alpha
    channel (plain RGB or grayscale PNGs are rejected).
    """
    path = os.path.join(HAIR_DIR, name)
    hair = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGRA expected
    # ndim guard first: a grayscale PNG loads as a 2-D array, and the former
    # bare `hair.shape[2]` access raised IndexError instead of the intended
    # ValueError for such assets.
    if hair is None or hair.ndim != 3 or hair.shape[2] != 4:
        raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
    return premultiply_alpha(hair)

def detect_face_keypoints(img_bgr):
    """Detect alignment anchors on the largest face in a BGR image.

    Returns a (3, 2) float32 array of pixel coordinates for
    [left eye outer, right eye outer, mid forehead], or None if no face found.
    """
    height, width = img_bgr.shape[:2]
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB
    with mp_face_mesh.FaceMesh(
        static_image_mode=True, max_num_faces=1, refine_landmarks=True,
        min_detection_confidence=0.6
    ) as mesh:
        result = mesh.process(rgb)
    if not result.multi_face_landmarks:
        return None
    landmarks = result.multi_face_landmarks[0].landmark

    def to_px(idx):
        pt = landmarks[idx]
        return np.array([pt.x * width, pt.y * height], dtype=np.float32)

    anchors = [to_px(LM[k]) for k in ("left_eye_outer", "right_eye_outer", "mid_forehead")]
    return np.stack(anchors)

def hair_reference_points(hair_bgra, filename, meta):
    """Return (3, 2) float32 anchor points for a hairstyle image.

    Uses explicit anchors from meta[filename] when present and well-formed;
    otherwise falls back to fixed fractional positions that suit many styles.
    For a pixel-perfect fit, add the 3 points to meta.json.
    """
    h, w = hair_bgra.shape[:2]
    custom = meta.get(filename)
    if custom is not None:
        pts = np.array(custom, dtype=np.float32)
        if pts.shape == (3, 2):
            return pts
    # Defaults: temples at 30%/70% of width, 60% height; mid point at 50%/40%.
    return np.array(
        [[0.30 * w, 0.60 * h],
         [0.70 * w, 0.60 * h],
         [0.50 * w, 0.40 * h]],
        dtype=np.float32,
    )

def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
    """Warp the hair into the base frame with affine M and alpha-composite it.

    `opacity` (0..1) globally scales the hair's alpha. Returns a uint8 BGR image.
    """
    H, W = base_bgr.shape[:2]
    color = hair_bgra[:, :, :3]
    alpha = hair_bgra[:, :, 3] / 255.0
    # Constant black border == fully transparent, avoiding edge artifacts.
    color_w = cv2.warpAffine(color, M, (W, H), flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0))
    alpha_w = cv2.warpAffine(alpha, M, (W, H), flags=cv2.INTER_LINEAR,
                             borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    weight = np.clip(alpha_w * opacity, 0, 1)[..., None]
    blended = weight * color_w + (1 - weight) * base_bgr
    return blended.astype(np.uint8)

def apply_tryon(image, hairstyle, scale_pct, dx, dy, opacity, meta):
    """Composite the chosen hairstyle onto the photo.

    Returns (image_or_array, status_message). No head mask is applied on
    purpose — masking caused neck seams and cropping artifacts.
    """
    if image is None:
        return None, "Upload a photo first."
    if not hairstyle:
        return np.array(image), "Pick a hairstyle."

    img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    face_pts = detect_face_keypoints(img_bgr)
    if face_pts is None:
        return image, "No face detected. Use a brighter, front-facing photo."

    hair = load_hair_png(hairstyle)
    src_pts = hair_reference_points(hair, hairstyle, meta)

    # Destination anchors = detected facial points shifted by the user nudges.
    dst_pts = face_pts + np.array([dx, dy], dtype=np.float32)

    # Scale the hair anchors about their centroid (rotation intentionally omitted).
    centroid = src_pts.mean(axis=0)
    factor = max(0.5, scale_pct / 100.0)
    src_scaled = (src_pts - centroid) * factor + centroid

    M, _ = cv2.estimateAffinePartial2D(src_scaled, dst_pts, method=cv2.LMEDS)
    if M is None:
        return image, "Alignment failed for this image/style."

    composite = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
    return cv2.cvtColor(composite, cv2.COLOR_BGR2RGB), "OK"

def save_png_to_tmp(img, filename="output_tryon.png"):
    """Write the preview image to the system temp dir and return its path.

    Raises gr.Error (surfaced in the UI) when there is no image yet.
    """
    if img is None:
        raise gr.Error("No image to save. Click a hairstyle or 'Apply' first.")
    out_path = os.path.join(tempfile.gettempdir(), filename)
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    pil_img.save(out_path)
    return out_path

# ---- white thumbnails with labels ----
def thumb_on_white(hair_bgra, max_h=220):
    """Composite a BGRA hairstyle onto a white background and return an RGB thumbnail.

    Downscales to at most `max_h` pixels tall; never upscales.
    """
    h, w = hair_bgra.shape[:2]
    factor = min(1.0, max_h / h)
    resized = cv2.resize(hair_bgra, (int(w * factor), int(h * factor)),
                         interpolation=cv2.INTER_LINEAR)
    alpha = resized[:, :, 3:4] / 255.0
    white = np.full(resized.shape[:2] + (3,), 255, dtype=np.uint8)
    composite = (alpha * resized[:, :, :3] + (1 - alpha) * white).astype(np.uint8)
    return cv2.cvtColor(composite, cv2.COLOR_BGR2RGB)

def build_gallery_items(files):
    """Build (thumbnail, "<n>. <filename>") pairs for the gallery.

    Assets that fail to load or render are silently skipped so one bad PNG
    cannot break the whole gallery.
    """
    items = []
    for position, fname in enumerate(files, start=1):
        try:
            asset = load_hair_png(fname)
            items.append((thumb_on_white(asset), f"{position}. {fname}"))
        except Exception:
            continue
    return items

# =============== UI ===============
def build_ui():
    """Construct and return the Gradio Blocks app.

    Layout: photo upload + preview, a clickable hairstyle gallery,
    fine-tune sliders (scale / offsets / opacity), and save & refresh actions.
    """
    META = load_meta()
    HAIR_FILES = load_hairstyles()

    with gr.Blocks(title="Salon Hairstyle Virtual Try-On (Simple)") as demo:
        gr.Markdown("Upload a photo, then **click a hairstyle** below. Use a few sliders if needed, then **Save result**.")

        selected_file = gr.State(None)  # currently selected hairstyle filename
        meta_state    = gr.State(META)  # per-style anchors from meta.json
        files_state   = gr.State(HAIR_FILES)  # gallery index -> filename mapping

        with gr.Tabs():
            with gr.Tab("📷 Photo (Upload)"):
                with gr.Row():
                    in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
                    out_img = gr.Image(label="Preview", height=360)

                with gr.Row():
                    apply_btn = gr.Button("✨ Apply (optional)")
                    save_btn  = gr.Button("💾 Save result")
                    save_file = gr.File(label="Saved file", visible=False)

                with gr.Row():
                    refresh = gr.Button("🔄 Refresh styles")

                count_md = gr.Markdown(f"Found {len(HAIR_FILES)} hairstyles.")
                gallery = gr.Gallery(
                    label="Hairstyles (click to apply)",
                    value=build_gallery_items(HAIR_FILES),
                    columns=6, rows=3, height=520,
                    allow_preview=False, object_fit="contain", show_label=True
                )

                with gr.Accordion("Fine-tune (simple)", open=True):
                    with gr.Row():
                        scale = gr.Slider(50, 200, 100, 1, label="Scale (temple distance %)")  # main size control
                        opacity = gr.Slider(0.4, 1.0, 1.0, 0.05, label="Hair opacity")
                    with gr.Row():
                        dx    = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right (px)")
                        dy    = gr.Slider(-200, 200, 0, 1, label="Up ↕ Down (px)")

                status = gr.Markdown("")

                # ----- actions -----
                def do_apply(im, hairfile, s, dxv, dyv, op, meta):
                    return apply_tryon(im, hairfile, s, dxv, dyv, op, meta)

                # 1) click a tile -> set selected file AND auto-apply.
                # FIX: the `evt: gr.SelectData` annotation is required — without
                # it Gradio does not inject the select-event payload and the
                # handler's arguments are misaligned with `inputs`.
                def on_gallery_select(evt: gr.SelectData, files, im, s, dxv, dyv, op, meta):
                    idx = getattr(evt, "index", None)
                    if idx is None or not files:
                        return None, gr.update(), None
                    idx = max(0, min(idx, len(files) - 1))
                    hairfile = files[idx]
                    out, msg = do_apply(im, hairfile, s, dxv, dyv, op, meta)
                    return hairfile, out, msg

                gallery.select(
                    on_gallery_select,
                    inputs=[files_state, in_img, scale, dx, dy, opacity, meta_state],
                    outputs=[selected_file, out_img, status]
                )

                # 2) Apply button (useful after slider tweaks)
                apply_btn.click(
                    fn=do_apply,
                    inputs=[in_img, selected_file, scale, dx, dy, opacity, meta_state],
                    outputs=[out_img, status]
                )

                # 3) Save the current preview to a temp PNG and reveal the file.
                # FIX: use gr.update(...) — gr.File.update was removed in
                # Gradio 4.x; gr.update works on both 3.x and 4.x.
                def do_save(im):
                    path = save_png_to_tmp(im, "output_tryon.png")
                    return gr.update(value=path, visible=True)

                save_btn.click(fn=do_save, inputs=[out_img], outputs=[save_file])

                # 4) Re-scan HAIR_DIR and rebuild the gallery.
                def do_refresh():
                    files = load_hairstyles()
                    items = build_gallery_items(files)
                    msg = f"Found {len(files)} hairstyles."
                    # Selection is kept as-is; a stale name simply fails to apply.
                    return items, files, msg

                refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, files_state, count_md])

    return demo

# export for Spaces
# Build the UI once at import time; Hugging Face Spaces discovers the
# module-level `demo`/`app` object instead of running __main__.
app = build_ui()
demo = app  # alias expected by Spaces runners

if __name__ == "__main__":
    # Local execution: start the Gradio server directly.
    app.launch()