RojaKatta committed on
Commit
03edb30
Β·
verified Β·
1 Parent(s): 97c7a01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +302 -243
app.py CHANGED
@@ -1,256 +1,315 @@
1
- import os, glob, math
2
- import numpy as np
3
  from PIL import Image
4
- import gradio as gr
5
 
6
- # Your hairstyles live here:
7
- HAIR_DIR = "hair"
8
-
9
- # Optional heavy deps (installed via requirements.txt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  try:
11
- import cv2
12
  import mediapipe as mp
13
- MP_AVAILABLE = True
14
- except Exception:
15
- MP_AVAILABLE = False
16
-
17
- APP_TITLE = "Salon Hairstyle Virtual Try-On β€” Pro Demo"
18
-
19
- # --------------------------
20
- # Asset loading / utilities
21
- # --------------------------
22
- def list_hairstyle_files():
23
- files = sorted(glob.glob(os.path.join(HAIR_DIR, "*.png")))
24
- labels = [os.path.basename(f).replace(".png","").replace("_"," ").title() for f in files]
25
- thumbs = [f for f in files] # Gradio Gallery accepts path strings
26
- return labels, files, thumbs
27
-
28
- def ensure_rgba(im: Image.Image) -> Image.Image:
29
- return im if im.mode == "RGBA" else im.convert("RGBA")
30
-
31
- def overlay_rgba(base: Image.Image, overlay: Image.Image, x: int, y: int) -> Image.Image:
32
- base = ensure_rgba(base)
33
- overlay = ensure_rgba(overlay)
34
- bw, bh = base.size
35
- ow, oh = overlay.size
36
- if ow <= 0 or oh <= 0:
37
- return base
38
- x1, y1 = max(0, x), max(0, y)
39
- x2, y2 = min(bw, x + ow), min(bh, y + oh)
40
- if x1 >= x2 or y1 >= y2:
41
- return base
42
- crop = overlay.crop((x1 - x, y1 - y, x2 - x, y2 - y))
43
- region = base.crop((x1, y1, x2, y2))
44
- region = Image.alpha_composite(region, crop)
45
- base.paste(region, (x1, y1))
46
- return base
47
-
48
- # --------------------------
49
- # Face detection / anchoring
50
- # --------------------------
51
- def detect_face_keypoints(np_img):
52
- """
53
- Returns dict:
54
- bbox: (x1,y1,x2,y2)
55
- temples: (lx,ly, rx,ry)
56
- forehead_y: int
57
- """
58
- h, w = np_img.shape[:2]
59
-
60
- def fallback():
61
- cx, cy = w // 2, int(h * 0.42)
62
- bw, bh = int(w * 0.4), int(h * 0.45)
63
- x1, y1 = max(0, cx - bw // 2), max(0, cy - bh // 2)
64
- x2, y2 = min(w, x1 + bw), min(h, y1 + bh)
65
- return {
66
- "bbox": (x1, y1, x2, y2),
67
- "temples": (x1, (y1 + y2) // 2, x2, (y1 + y2) // 2),
68
- "forehead_y": max(0, int(y1 - 0.1 * (y2 - y1))),
69
- }
70
-
71
- if not MP_AVAILABLE:
72
- return fallback()
73
-
74
- mpfm = mp.solutions.face_mesh
75
- with mpfm.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=False) as fm:
76
- rgb = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) if np_img.shape[2] == 3 else np_img[..., :3]
77
- res = fm.process(rgb)
78
  if not res.multi_face_landmarks:
79
- return fallback()
80
-
81
  lm = res.multi_face_landmarks[0].landmark
82
- xs = np.array([p.x for p in lm]) * w
83
- ys = np.array([p.y for p in lm]) * h
84
- x1, y1, x2, y2 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
85
-
86
- def safe_idx(i):
87
- i = int(i)
88
- i = max(0, min(len(lm) - 1, i))
89
- return int(xs[i]), int(ys[i])
90
-
91
- lx, ly = safe_idx(127) # approx left temple
92
- rx, ry = safe_idx(356) # approx right temple
93
- _, fy = safe_idx(10) # approx forehead
94
- y1 = max(0, int(y1 - 0.12 * (y2 - y1))) # include hairline
95
- return {
96
- "bbox": (x1, y1, x2, y2),
97
- "temples": (lx, ly, rx, ry),
98
- "forehead_y": fy,
99
- }
100
-
101
- def place_and_render(base_np, hair_path, scale, lr_shift_pct, ud_shift_pct, rotation_deg):
102
- if base_np is None:
103
- raise gr.Error("Please upload or capture a photo first.")
104
- base = Image.fromarray(base_np.astype("uint8")).convert("RGBA")
105
- key = detect_face_keypoints(np.array(base.convert("RGB")))
106
- x1, y1, x2, y2 = key["bbox"]
107
- lx, ly, rx, ry = key["temples"]
108
- forehead_y = key["forehead_y"]
109
-
110
- hair = Image.open(hair_path).convert("RGBA")
111
-
112
- # Scale from temple distance
113
- temple_dx = rx - lx
114
- temple_dy = ry - ly
115
- temple_dist = max(1, (temple_dx ** 2 + temple_dy ** 2) ** 0.5)
116
- target_w = max(1, int(temple_dist * 2.0 * scale))
117
- ratio = target_w / hair.width
118
- target_h = max(1, int(hair.height * ratio))
119
- hair_resized = hair.resize((target_w, target_h), Image.LANCZOS)
120
-
121
- # Auto-rotation + manual extra rotation
122
- auto_deg = math.degrees(math.atan2(temple_dy, temple_dx))
123
- rot_total = auto_deg + rotation_deg
124
- hair_resized = hair_resized.rotate(rot_total, expand=True)
125
-
126
- # Anchor at temple midpoint, above forehead
127
- midx = int((lx + rx) / 2)
128
- anchor_x = int(midx - hair_resized.width / 2)
129
- anchor_y = int(forehead_y - hair_resized.height * 0.45)
130
-
131
- # Apply shifts (percent of image size)
132
- img_w, img_h = base.size
133
- anchor_x += int(lr_shift_pct * img_w / 100.0) # +right, -left
134
- anchor_y += int(ud_shift_pct * img_h / 100.0) # +down, -up
135
-
136
- out = overlay_rgba(base, hair_resized, anchor_x, anchor_y).convert("RGB")
137
  return out
138
 
139
- # --------------------------
140
- # Gradio callbacks
141
- # --------------------------
142
- def refresh_assets():
143
- labels, files, thumbs = list_hairstyle_files()
144
- if not files:
145
- return (
146
- gr.update(value=None, choices=[]),
147
- gr.update(value=None),
148
- "No PNGs found in the 'hair/' folder.",
149
- )
150
- dd = gr.update(choices=labels, value=labels[0])
151
- gallery = gr.update(value=[[p, l] for p, l in zip(thumbs, labels)])
152
- return dd, gallery, f"Found {len(files)} hairstyles."
153
-
154
- def refresh_dropdown_only():
155
- labels, files, _ = list_hairstyle_files()
156
- if not files:
157
- return gr.update(value=None, choices=[])
158
- return gr.update(choices=labels, value=labels[0])
159
-
160
- def pick_from_gallery(gallery_select, current_dropdown):
161
- # When user clicks a gallery item, we set the dropdown to match
162
- labels, _, _ = list_hairstyle_files()
163
- if not labels:
164
- raise gr.Error("No hairstyles available.")
165
- if gallery_select is None:
166
- return current_dropdown
167
- idx = gallery_select[0]
168
- idx = max(0, min(len(labels) - 1, idx))
169
- return labels[idx]
170
-
171
- def apply_from_dropdown(image, hairstyle_label, scale, lr, ud, rot):
172
- labels, files, _ = list_hairstyle_files()
173
- if not labels:
174
- raise gr.Error("No hairstyle assets found.")
175
- if hairstyle_label not in labels:
176
- raise gr.Error("Choose a hairstyle first.")
177
- path = files[labels.index(hairstyle_label)]
178
- return place_and_render(image, path, scale, lr, ud, rot)
179
-
180
- def save_image(img_np):
181
- if img_np is None:
182
- raise gr.Error("Nothing to save. Generate a preview first.")
183
- out_path = "output_tryon.png"
184
- Image.fromarray(img_np).save(out_path)
 
 
 
 
 
185
  return out_path
186
 
187
- # --------------------------
188
- # UI
189
- # --------------------------
190
- with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
191
- gr.Markdown(f"# {APP_TITLE}")
192
- gr.Markdown("Upload a photo or use webcam. Put **transparent PNGs** in the hair/ folder, then click **Refresh**.")
193
-
194
- with gr.Tabs():
195
- # -------- Upload tab --------
196
- with gr.Tab("πŸ“· Photo (Upload)"):
197
- with gr.Row():
198
- with gr.Column():
199
- img = gr.Image(label="Photo", sources=["upload"], type="numpy", height=420)
200
-
201
- hair_dd = gr.Dropdown(label="Hairstyle", choices=[], interactive=True)
202
- refresh = gr.Button("πŸ”„ Refresh")
203
- status = gr.Markdown("Drop PNGs into hair/ and press Refresh.")
204
- gallery = gr.Gallery(
205
- label="Hairstyles (click to choose)",
206
- columns=4,
207
- height=220,
208
- allow_preview=False,
209
- interactive=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  )
 
 
 
 
 
211
 
212
- with gr.Accordion("Fine-tune placement", open=False):
213
- scale = gr.Slider(0.6, 2.2, value=1.3, step=0.01, label="Scale (Γ— temple distance)")
214
- lr = gr.Slider(-30, 30, value=0, step=0.5, label="Left ↔ Right shift (%)")
215
- ud = gr.Slider(-30, 30, value=0, step=0.5, label="Up ↕ Down shift (%)")
216
- rot = gr.Slider(-30, 30, value=0, step=0.5, label="Extra rotation (Β°)")
217
-
218
- run_dd = gr.Button("✨ Apply")
219
-
220
- with gr.Column():
221
- out = gr.Image(label="Preview", height=480)
222
- save_btn = gr.Button("πŸ’Ύ Save result")
223
- file_out = gr.File(label="Download")
224
-
225
- # Wiring
226
- refresh.click(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])
227
- demo.load(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])
228
-
229
- # Clicking gallery picks the dropdown value; one Apply button only
230
- gallery.select(fn=pick_from_gallery, inputs=[gallery, hair_dd], outputs=hair_dd)
231
- run_dd.click(fn=apply_from_dropdown, inputs=[img, hair_dd, scale, lr, ud, rot], outputs=out)
232
- save_btn.click(fn=save_image, inputs=out, outputs=file_out)
233
-
234
- # -------- Webcam tab --------
235
- with gr.Tab("πŸŽ₯ Webcam (Live Beta)"):
236
- gr.Markdown("Live mode processes frames continuously. Keep resolution small on CPU Spaces.")
237
- cam = gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Webcam")
238
- hair_dd2 = gr.Dropdown(label="Hairstyle", choices=[], interactive=True)
239
- scale2 = gr.Slider(0.6, 2.2, value=1.25, step=0.01, label="Scale")
240
- lr2 = gr.Slider(-30, 30, value=0, step=0.5, label="Left ↔ Right shift (%)")
241
- ud2 = gr.Slider(-30, 30, value=0, step=0.5, label="Up ↕ Down shift (%)")
242
- rot2 = gr.Slider(-30, 30, value=0, step=0.5, label="Rotation (Β°)")
243
- out_live = gr.Image(label="Live Preview", interactive=False, height=420)
244
-
245
- def live_process(frame, label, s, lrp, udp, r):
246
- labels, files, _ = list_hairstyle_files()
247
- if frame is None or not labels or label not in labels:
248
- return frame
249
- path = files[labels.index(label)]
250
- return place_and_render(frame, path, s, lrp, udp, r)
251
-
252
- cam.stream(fn=live_process, inputs=[cam, hair_dd2, scale2, lr2, ud2, rot2], outputs=out_live)
253
- demo.load(fn=refresh_dropdown_only, inputs=None, outputs=hair_dd2)
254
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  if __name__ == "__main__":
256
- demo.launch()
 
1
+ import os, json, tempfile
2
+ import cv2, numpy as np, gradio as gr
3
  from PIL import Image
 
4
 
5
# ===================== Paths & Assets =====================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CANDIDATES = [
    os.path.join(BASE_DIR, "hair"),                  # your folder
    os.path.join(BASE_DIR, "assets", "hairstyles"),
    os.path.join(BASE_DIR, "assets", "Hairstyles"),
    os.path.join(BASE_DIR, "hairstyles"),
]
# First existing candidate wins; otherwise fall back to ./hair and create it.
HAIR_DIR = next((cand for cand in CANDIDATES if os.path.isdir(cand)), None)
if HAIR_DIR is None:
    HAIR_DIR = os.path.join(BASE_DIR, "hair")
    os.makedirs(HAIR_DIR, exist_ok=True)

META_PATH = os.path.join(HAIR_DIR, "meta.json")  # optional per-style anchors
23
+
24
# ===================== Dependencies =====================
try:
    import mediapipe as mp
except Exception as e:
    # Chain the original exception so the real import error stays visible
    # in the traceback (PEP 3134) instead of "during handling of ...".
    raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}") from e

mp_face_mesh = mp.solutions.face_mesh
mp_selfie_seg = mp.solutions.selfie_segmentation
# FaceMesh landmark indices used as alignment anchors (eye outer corners + forehead).
LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
33
+
34
+ # ===================== Utilities =====================
35
def load_hairstyles():
    """Return the sorted list of PNG filenames available in HAIR_DIR."""
    try:
        entries = os.listdir(HAIR_DIR)
    except FileNotFoundError:
        # Folder may have been deleted at runtime; treat as "no assets".
        entries = []
    return sorted(name for name in entries if name.lower().endswith(".png"))
42
+
43
def load_meta():
    """Load optional per-style anchor metadata from META_PATH.

    Returns a dict; any missing, unreadable, or non-dict file yields {}.
    """
    if not os.path.exists(META_PATH):
        return {}
    try:
        with open(META_PATH, "r") as fh:
            data = json.load(fh)
    except Exception:
        # Corrupt/unreadable meta.json: silently fall back to defaults.
        return {}
    return data if isinstance(data, dict) else {}
52
+
53
def premultiply_alpha(bgra):
    """Scale the BGR channels of a BGRA image by its alpha (premultiplied
    alpha), which avoids gray halos when the asset is later blended."""
    color = bgra[:, :, :3].astype(np.float32) / 255.0
    alpha = bgra[:, :, 3:4].astype(np.float32) / 255.0
    premult = (color * alpha * 255.0).astype(np.uint8)
    # Keep the original alpha channel untouched alongside the scaled color.
    return np.dstack([premult, bgra[:, :, 3]])
59
+
60
def load_hair_png(name):
    """Load a hairstyle PNG (BGRA) from HAIR_DIR and premultiply its alpha.

    Raises ValueError when the file is missing, unreadable, or has no alpha
    channel.
    """
    path = os.path.join(HAIR_DIR, name)
    hair = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGRA expected
    # BUGFIX: a grayscale or RGB PNG loads as a 2-D / 3-channel array, so the
    # old `hair.shape[2] != 4` check raised IndexError instead of the intended
    # ValueError. Guard ndim first.
    if hair is None or hair.ndim != 3 or hair.shape[2] != 4:
        raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
    return premultiply_alpha(hair)
66
+
67
def detect_face_keypoints(img_bgr):
    """Detect alignment anchors on the face in a BGR image.

    Returns a (3, 2) float32 array of pixel coordinates in the order
    [left eye outer, right eye outer, mid forehead], or None when no face
    is found.
    """
    h, w = img_bgr.shape[:2]
    with mp_face_mesh.FaceMesh(
        static_image_mode=True,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.6,
    ) as fm:
        res = fm.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    if not res.multi_face_landmarks:
        return None
    lm = res.multi_face_landmarks[0].landmark

    def to_px(idx):
        # Landmarks are normalized [0, 1]; convert to pixel space.
        return np.array([lm[idx].x * w, lm[idx].y * h], dtype=np.float32)

    return np.stack([
        to_px(LM["left_eye_outer"]),
        to_px(LM["right_eye_outer"]),
        to_px(LM["mid_forehead"]),
    ])
79
+
80
def person_mask(img_bgr):
    """Soft foreground-person mask in [0, 1] for a BGR image.

    Thresholds the MediaPipe selfie-segmentation mask at 0.5, then blurs
    it so the composite edge is feathered rather than hard.
    """
    with mp_selfie_seg.SelfieSegmentation(model_selection=1) as segmenter:
        result = segmenter.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    hard = (result.segmentation_mask > 0.5).astype(np.float32)
    return cv2.GaussianBlur(hard, (35, 35), 0)
87
+
88
def hair_reference_points(hair_bgra, filename, meta):
    """Three (x, y) anchor points on the hair asset, as a (3, 2) float32
    array in the order [left, right, mid].

    Uses meta[filename] when it holds a valid (3, 2) point set; otherwise
    falls back to fixed fractional defaults that fit many styles. Add three
    points per file to meta.json for a perfect fit.
    """
    h, w = hair_bgra.shape[:2]
    if filename in meta:
        custom = np.asarray(meta[filename], dtype=np.float32)
        if custom.shape == (3, 2):
            return custom
    defaults = [
        (0.30 * w, 0.60 * h),  # left temple
        (0.70 * w, 0.60 * h),  # right temple
        (0.50 * w, 0.40 * h),  # mid forehead
    ]
    return np.array(defaults, dtype=np.float32)
99
+
100
def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
    """Warp the hair asset with affine matrix M and composite it over base_bgr.

    hair_bgra carries alpha-premultiplied color (see premultiply_alpha), so
    compositing must use the premultiplied "over" formula
        out = opacity * color_pm + (1 - opacity * a) * base.
    BUGFIX: the previous straight-alpha formula (a * color_pm + ...)
    multiplied the color by alpha a second time, darkening every soft edge —
    exactly the halo premultiplication is meant to remove. Also switched the
    warps to BORDER_CONSTANT (zero fill): BORDER_TRANSPARENT leaves
    out-of-source destination pixels undefined when no dst is supplied.
    """
    H, W = base_bgr.shape[:2]
    hair_rgb = hair_bgra[:, :, :3]
    hair_a = hair_bgra[:, :, 3] / 255.0
    hair_warp = cv2.warpAffine(
        hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT, borderValue=0,
    )
    a_warp = cv2.warpAffine(
        hair_a, M, (W, H), flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT, borderValue=0,
    )
    a = np.clip(a_warp * opacity, 0, 1)[..., None]
    # Premultiplied-over: the warped color already carries its alpha.
    out = hair_warp.astype(np.float32) * opacity + (1 - a) * base_bgr
    return np.clip(out, 0, 255).astype(np.uint8)
109
 
110
def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity, meta):
    """Align the selected hairstyle PNG onto the face in `image`.

    Returns (rgb_image_or_None, status_message). `image` is a PIL image (or
    ndarray convertible); `meta` holds optional per-style anchor points.
    """
    if image is None:
        return None, "Upload a photo or enable webcam."
    if not hairstyle:
        return np.array(image), "Pick a hairstyle first."

    img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    kpts = detect_face_keypoints(img_bgr)
    if kpts is None:
        return image, "No face detected. Try a brighter, front-facing photo."

    hair = load_hair_png(hairstyle)
    hair_pts = hair_reference_points(hair, hairstyle, meta)

    # Destination anchors, nudged by the user's pixel offsets.
    dst = kpts.copy()
    dst[:, 0] += dx
    dst[:, 1] += dy

    # Scale + rotate the hair anchors around their centroid before solving
    # the similarity transform, so the sliders act in hair space.
    center = hair_pts.mean(axis=0)
    theta = np.deg2rad(rot_deg)
    s = max(0.5, scale_pct / 100.0)
    rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta), np.cos(theta)]], dtype=np.float32)
    adjusted = (hair_pts - center) @ rot_mat.T * s + center

    M, _ = cv2.estimateAffinePartial2D(adjusted, dst, method=cv2.LMEDS)
    if M is None:
        return image, "Could not compute alignment for this image/style."

    out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)

    # Restrict the composite to the segmented person region for a cleaner look.
    head = person_mask(img_bgr)[..., None]
    out = (head * out + (1 - head) * img_bgr).astype(np.uint8)

    return cv2.cvtColor(out, cv2.COLOR_BGR2RGB), "OK"
151
+
152
def save_png_to_tmp(img, filename="output_tryon.png"):
    """Write `img` (ndarray or PIL.Image) as a PNG under the system temp
    directory and return the file path.

    Raises gr.Error when there is nothing to save yet.
    """
    if img is None:
        raise gr.Error("No image to download. Click Apply first.")
    out_path = os.path.join(tempfile.gettempdir(), filename)
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    pil_img.save(out_path)
    return out_path
162
 
163
# ----- WHITE background thumbnails (no checkerboard) -----
def thumb_on_white(hair_bgra, max_h=220):
    """Composite a BGRA hair asset over solid white and return an RGB
    thumbnail no taller than max_h pixels."""
    h, w = hair_bgra.shape[:2]
    ratio = min(1.0, max_h / h)
    hair_bgra = cv2.resize(hair_bgra, (int(w * ratio), int(h * ratio)),
                           interpolation=cv2.INTER_LINEAR)
    h, w = hair_bgra.shape[:2]
    white = np.full((h, w, 3), 255, dtype=np.uint8)  # white background
    alpha = hair_bgra[:, :, 3:4] / 255.0
    composite = (alpha * hair_bgra[:, :, :3] + (1 - alpha) * white).astype(np.uint8)
    return cv2.cvtColor(composite, cv2.COLOR_BGR2RGB)
173
+
174
def build_gallery_items(files):
    """Build (thumbnail, caption) pairs for the gallery.

    Assets that fail to load or render are skipped silently so one bad PNG
    does not break the whole gallery.
    """
    items = []
    for fname in files:
        try:
            bgra = load_hair_png(fname)
            items.append((thumb_on_white(bgra), fname))  # (image, caption)
        except Exception:
            continue
    return items
183
+
184
# ===================== UI =====================
def build_ui():
    """Construct the Gradio Blocks app (photo try-on tab + live webcam tab).

    Returns the gr.Blocks instance. `files_state` tracks the current
    hairstyle filenames and `meta_state` the optional per-style anchors,
    so callbacks never re-read module globals mid-session.
    """
    META = load_meta()
    HAIR_FILES = load_hairstyles()

    with gr.Blocks(title="Salon Hairstyle Virtual Try-On", css="""
    .gradio-container {max-width: 1200px; margin:auto;} /* wider so more tiles fit */
    @media (max-width: 768px){ .gradio-container {padding: 8px;} }
    """) as demo:
        gr.Markdown("Upload a photo or use webcam. Put transparent **PNGs** in the **`hair/`** folder, then click **Refresh**.")

        files_state = gr.State(HAIR_FILES)  # keep filenames
        meta_state = gr.State(META)

        with gr.Tabs():
            # ---------------- Photo Tab ----------------
            with gr.Tab("📷 Photo (Upload)"):
                with gr.Row():
                    in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
                    out_img = gr.Image(label="Preview", height=360)
                with gr.Row():
                    hair_sel = gr.Dropdown(
                        choices=HAIR_FILES,
                        value=(HAIR_FILES[0] if HAIR_FILES else None),
                        label="Selected hairstyle",
                        interactive=True
                    )
                    apply_btn = gr.Button("✨ Apply (Align & Overlay)")
                    # One-click download + visible link fallback
                    download_btn = gr.DownloadButton("⬇️ Download")
                    download_file = gr.File(label="download (fallback)", visible=False)
                status = gr.Markdown()

                with gr.Row():
                    refresh = gr.Button("🔄 Refresh")
                    count_md = gr.Markdown(f"Found {len(HAIR_FILES)} hairstyles.")
                gallery = gr.Gallery(
                    label="Hairstyles (click to choose)",
                    value=build_gallery_items(HAIR_FILES),
                    columns=6, rows=2, height=340,  # up to 12 visible at once
                    allow_preview=False, object_fit="contain", show_label=True
                )

                with gr.Accordion("Fine-tune placement", open=True):
                    with gr.Row():
                        scale = gr.Slider(50, 200, 100, 1, label="Scale (≈ temple distance %)")
                        rot = gr.Slider(-30, 30, 0, 1, label="Extra rotation (°)")
                    with gr.Row():
                        dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right shift (px)")
                        dy = gr.Slider(-200, 200, 0, 1, label="Up ↕ Down shift (px)")
                    opacity = gr.Slider(0.2, 1.0, 1.0, 0.05, label="Hair opacity")

                # --- Callbacks ---
                def do_apply(im, hfile, s, r, dxv, dyv, op, meta):
                    return apply_tryon(im, hfile, s, r, dxv, dyv, op, meta)

                apply_btn.click(
                    fn=do_apply,
                    inputs=[in_img, hair_sel, scale, rot, dx, dy, opacity, meta_state],
                    outputs=[out_img, status]
                )

                # Return path for one-click download AND show fallback file link.
                # BUGFIX: gr.File.update() was removed in Gradio 4; component
                # updates are expressed with gr.update(...).
                def prepare_download_dual(im):
                    path = save_png_to_tmp(im, "output_tryon.png")
                    return path, gr.update(value=path, visible=True)

                download_btn.click(fn=prepare_download_dual, inputs=[out_img], outputs=[download_btn, download_file])

                def do_refresh():
                    files = load_hairstyles()
                    items = build_gallery_items(files)
                    msg = f"Found {len(files)} hairstyles."
                    return items, gr.update(choices=files, value=(files[0] if files else None)), files, msg

                refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, hair_sel, files_state, count_md])

                # Clicking a tile sets the dropdown to that filename.
                # BUGFIX: Gradio injects SelectData only through a
                # type-annotated parameter. The old (evt, files) signature
                # received the files list in `evt` and left `files` unbound,
                # crashing with a TypeError on every gallery click.
                def on_gallery_select(files, evt: gr.SelectData):
                    idx = getattr(evt, "index", None)
                    if idx is None or not files:
                        return gr.update()
                    if idx >= len(files):
                        idx = len(files) - 1
                    return gr.update(value=files[idx])

                gallery.select(on_gallery_select, inputs=[files_state], outputs=[hair_sel])

            # ---------------- Webcam Tab ----------------
            with gr.Tab("📹 Webcam (Live Beta)"):
                cam = gr.Image(sources=["webcam"], streaming=True, type="pil", label="Enable camera")
                hair2 = gr.Dropdown(choices=HAIR_FILES, value=(HAIR_FILES[0] if HAIR_FILES else None), label="Selected hairstyle")
                with gr.Row():
                    scale2 = gr.Slider(50, 200, 100, 1, label="Scale %")
                    rot2 = gr.Slider(-25, 25, 0, 1, label="Rotate (°)")
                with gr.Row():
                    dx2 = gr.Slider(-150, 150, 0, 1, label="Left ↔ Right (px)")
                    dy2 = gr.Slider(-150, 150, 0, 1, label="Up ↕ Down (px)")
                opacity2 = gr.Slider(0.2, 1.0, 0.95, 0.05, label="Hair opacity")
                out2 = gr.Image(label="Live result", height=360)
                state_live = gr.State(None)  # last rendered frame, for snapshot/download
                snap = gr.Button("📸 Snapshot")
                save_live_btn = gr.DownloadButton("⬇️ Download Snapshot")
                save_live_file = gr.File(label="snapshot (fallback)", visible=False)

                def live(im, h, s, r, dxv, dyv, op, meta):
                    res, _ = apply_tryon(im, h, s, r, dxv, dyv, op, meta)
                    return res, res

                cam.stream(
                    fn=live,
                    inputs=[cam, hair2, scale2, rot2, dx2, dy2, opacity2, meta_state],
                    outputs=[out2, state_live]
                )

                snap.click(lambda x: x, inputs=[state_live], outputs=[out2])

                # BUGFIX: same gr.File.update -> gr.update migration as the
                # photo tab's download handler.
                def prepare_webcam_download_dual(im):
                    path = save_png_to_tmp(im, "tryon_webcam.png")
                    return path, gr.update(value=path, visible=True)

                save_live_btn.click(fn=prepare_webcam_download_dual, inputs=[state_live], outputs=[save_live_btn, save_live_file])

    return demo
308
+
309
# Build once at import time so Hugging Face Spaces can autostart the app.
app = build_ui()
demo = app  # Spaces also looks for a module-level `demo`

if __name__ == "__main__":
    # Local development entry point.
    demo.launch()