RojaKatta committed on
Commit
95477cb
·
verified ·
1 Parent(s): 30cc0b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -304
app.py CHANGED
@@ -1,317 +1,256 @@
1
- import os, json, tempfile
2
- import cv2, numpy as np, gradio as gr
3
  from PIL import Image
 
4
 
5
- # ===================== Paths & Assets =====================
6
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
- CANDIDATES = [
8
- os.path.join(BASE_DIR, "hair"), # your folder
9
- os.path.join(BASE_DIR, "assets", "hairstyles"),
10
- os.path.join(BASE_DIR, "assets", "Hairstyles"),
11
- os.path.join(BASE_DIR, "hairstyles"),
12
- ]
13
- HAIR_DIR = None
14
- for p in CANDIDATES:
15
- if os.path.isdir(p):
16
- HAIR_DIR = p
17
- break
18
- if HAIR_DIR is None:
19
- HAIR_DIR = os.path.join(BASE_DIR, "hair")
20
- os.makedirs(HAIR_DIR, exist_ok=True)
21
-
22
- META_PATH = os.path.join(HAIR_DIR, "meta.json") # optional per-style anchors
23
-
24
- # ===================== Dependencies =====================
25
  try:
 
26
  import mediapipe as mp
27
- except Exception as e:
28
- raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}")
29
-
30
- mp_face_mesh = mp.solutions.face_mesh
31
- mp_selfie_seg = mp.solutions.selfie_segmentation
32
- LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
33
-
34
- # ===================== Utilities =====================
35
- def load_hairstyles():
36
- try:
37
- files = [f for f in os.listdir(HAIR_DIR) if f.lower().endswith(".png")]
38
- except FileNotFoundError:
39
- files = []
40
- files.sort()
41
- return files
42
-
43
- def load_meta():
44
- if os.path.exists(META_PATH):
45
- try:
46
- with open(META_PATH, "r") as f:
47
- m = json.load(f)
48
- return m if isinstance(m, dict) else {}
49
- except Exception:
50
- return {}
51
- return {}
52
-
53
- def premultiply_alpha(bgra):
54
- """Reduce gray/white halos on edges for nicer blending."""
55
- bgr = bgra[:, :, :3].astype(np.float32) / 255.0
56
- a = (bgra[:, :, 3:4].astype(np.float32) / 255.0)
57
- bgr_pm = (bgr * a * 255.0).astype(np.uint8)
58
- return np.dstack([bgr_pm, bgra[:, :, 3]])
59
-
60
- def load_hair_png(name):
61
- path = os.path.join(HAIR_DIR, name)
62
- hair = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGRA
63
- if hair is None or hair.shape[2] != 4:
64
- raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
65
- # Improve visual quality
66
- hair = premultiply_alpha(hair)
67
- return hair
68
-
69
- def detect_face_keypoints(img_bgr):
70
- h, w = img_bgr.shape[:2]
71
- with mp_face_mesh.FaceMesh(
72
- static_image_mode=True, max_num_faces=1, refine_landmarks=True,
73
- min_detection_confidence=0.6
74
- ) as fm:
75
- res = fm.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  if not res.multi_face_landmarks:
77
- return None
 
78
  lm = res.multi_face_landmarks[0].landmark
79
- def xy(i): return np.array([lm[i].x*w, lm[i].y*h], dtype=np.float32)
80
- return np.stack([xy(LM["left_eye_outer"]), xy(LM["right_eye_outer"]), xy(LM["mid_forehead"])])
81
-
82
- def person_mask(img_bgr):
83
- with mp_selfie_seg.SelfieSegmentation(model_selection=1) as seg:
84
- rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
85
- m = seg.process(rgb).segmentation_mask
86
- mask = (m > 0.5).astype(np.float32)
87
- mask = cv2.GaussianBlur(mask, (35, 35), 0)
88
- return mask
89
-
90
- def hair_reference_points(hair_bgra, filename, meta):
91
- h, w = hair_bgra.shape[:2]
92
- if filename in meta:
93
- pts = np.array(meta[filename], dtype=np.float32)
94
- if pts.shape == (3, 2):
95
- return pts
96
- # Defaults (ok for many styles). For perfect fit, add 3 points per file to meta.json.
97
- pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
98
- pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
99
- pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
100
- return np.stack([pL, pR, pM], axis=0)
101
-
102
- def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
103
- H, W = base_bgr.shape[:2]
104
- hair_rgb = hair_bgra[:, :, :3]
105
- hair_a = hair_bgra[:, :, 3] / 255.0
106
- hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
107
- a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
108
- a = np.clip(a_warp * opacity, 0, 1)[..., None]
109
- out = (a * hair_warp + (1 - a) * base_bgr).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  return out
111
 
112
- def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity, meta):
113
- if image is None:
114
- return None, "Upload a photo or enable webcam."
115
- if not hairstyle:
116
- return np.array(image), "Pick a hairstyle first."
117
-
118
- img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
119
-
120
- kpts = detect_face_keypoints(img_bgr)
121
- if kpts is None:
122
- return image, "No face detected. Try a brighter, front-facing photo."
123
-
124
- hair = load_hair_png(hairstyle)
125
- hair_pts = hair_reference_points(hair, hairstyle, meta)
126
-
127
- # Destination points (with user nudges)
128
- dst = kpts.copy()
129
- dst[:, 0] += dx
130
- dst[:, 1] += dy
131
-
132
- # Scale + rotate around hair anchor centroid
133
- center = hair_pts.mean(axis=0)
134
- theta = np.deg2rad(rot_deg)
135
- s = max(0.5, scale_pct / 100.0)
136
- R = np.array([[np.cos(theta), -np.sin(theta)],
137
- [np.sin(theta), np.cos(theta)]], dtype=np.float32)
138
- hair_pts_adj = (hair_pts - center) @ R.T * s + center
139
-
140
- M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
141
- if M is None:
142
- return image, "Could not compute alignment for this image/style."
143
-
144
- out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
145
-
146
- # Restrict to head region for cleaner look
147
- head = person_mask(img_bgr)
148
- head3 = head[..., None]
149
- out = (head3 * out + (1 - head3) * img_bgr).astype(np.uint8)
150
-
151
- out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
152
- return out_rgb, "OK"
153
-
154
- def save_png_to_tmp(img, filename="output_tryon.png"):
155
- """Create a file in /tmp and return the path (works reliably on Spaces)."""
156
- if img is None:
157
- raise gr.Error("No image to download. Click Apply first.")
158
- out_path = os.path.join(tempfile.gettempdir(), filename)
159
- if isinstance(img, np.ndarray):
160
- Image.fromarray(img).save(out_path)
161
- else:
162
- img.save(out_path)
163
  return out_path
164
 
165
- # ----- thumbnails on checkerboard for the gallery -----
166
- def thumb_on_checker(hair_bgra, max_h=220):
167
- h, w = hair_bgra.shape[:2]
168
- scale = min(1.0, max_h / h)
169
- hair_bgra = cv2.resize(hair_bgra, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_LINEAR)
170
- h, w = hair_bgra.shape[:2]
171
- tile = 12
172
- bg = np.kron(((np.indices((h//tile+1, w//tile+1)).sum(axis=0) % 2) * 64 + 192).astype(np.uint8),
173
- np.ones((tile, tile), np.uint8))[:h, :w]
174
- bg_rgb = np.dstack([bg, bg, bg])
175
- a = (hair_bgra[:, :, 3:4] / 255.0)
176
- comp = (a * hair_bgra[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
177
- return cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
178
-
179
- def build_gallery_items(files):
180
- items = []
181
- for fname in files:
182
- try:
183
- img = load_hair_png(fname)
184
- items.append((thumb_on_checker(img), fname)) # (image, caption)
185
- except Exception:
186
- continue
187
- return items
188
-
189
- # ===================== UI =====================
190
- def build_ui():
191
- META = load_meta()
192
- HAIR_FILES = load_hairstyles()
193
-
194
- with gr.Blocks(title="Salon Hairstyle Virtual Try-On β€” Pro Demo", css="""
195
- .gradio-container {max-width: 1100px; margin:auto;}
196
- @media (max-width: 768px){ .gradio-container {padding: 8px;} }
197
- """) as demo:
198
- gr.Markdown("Upload a photo or use webcam. Put transparent **PNGs** in the **`hair/`** folder, then click **Refresh**.")
199
-
200
- files_state = gr.State(HAIR_FILES) # keep filenames
201
- meta_state = gr.State(META)
202
-
203
- with gr.Tabs():
204
- # ---------------- Photo Tab ----------------
205
- with gr.Tab("πŸ“· Photo (Upload)"):
206
- with gr.Row():
207
- in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
208
- out_img = gr.Image(label="Preview", height=360)
209
- with gr.Row():
210
- hair_sel = gr.Dropdown(
211
- choices=HAIR_FILES,
212
- value=(HAIR_FILES[0] if HAIR_FILES else None),
213
- label="Selected hairstyle",
214
- interactive=True
215
  )
216
- apply_btn = gr.Button("✨ Apply (Align & Overlay)")
217
- # Reliable download: a button that fills a File link.
218
- download_btn = gr.Button("⬇️ Download")
219
- download_file = gr.File(label="download", visible=False)
220
- status = gr.Markdown()
221
 
222
- with gr.Row():
223
- refresh = gr.Button("πŸ”„ Refresh")
224
- gallery = gr.Gallery(
225
- label="Hairstyles (click to choose)",
226
- value=build_gallery_items(HAIR_FILES),
227
- columns=6, rows=2, height=320,
228
- allow_preview=False, object_fit="contain", show_label=True
229
- )
230
-
231
- with gr.Accordion("Fine-tune placement", open=True):
232
- with gr.Row():
233
- scale = gr.Slider(50, 200, 100, 1, label="Scale (β‰ˆ temple distance %)")
234
- rot = gr.Slider(-30, 30, 0, 1, label="Extra rotation (Β°)")
235
- with gr.Row():
236
- dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right shift (px)")
237
- dy = gr.Slider(-200, 200, 0, 1, label="Up ↕ Down shift (px)")
238
- opacity = gr.Slider(0.2, 1.0, 1.0, 0.05, label="Hair opacity")
239
-
240
- # --- Callbacks ---
241
- def do_apply(im, hfile, s, r, dxv, dyv, op, meta):
242
- return apply_tryon(im, hfile, s, r, dxv, dyv, op, meta)
243
-
244
- apply_btn.click(
245
- fn=do_apply,
246
- inputs=[in_img, hair_sel, scale, rot, dx, dy, opacity, meta_state],
247
- outputs=[out_img, status]
248
- )
249
-
250
- def prepare_download(im):
251
- path = save_png_to_tmp(im, "output_tryon.png")
252
- return gr.File.update(value=path, visible=True)
253
-
254
- download_btn.click(fn=prepare_download, inputs=[out_img], outputs=[download_file])
255
-
256
- def do_refresh():
257
- files = load_hairstyles()
258
- items = build_gallery_items(files)
259
- return items, gr.update(choices=files, value=(files[0] if files else None)), files
260
-
261
- refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, hair_sel, files_state])
262
-
263
- # Clicking a tile sets the dropdown to that filename
264
- def on_gallery_select(evt, files):
265
- idx = getattr(evt, "index", None)
266
- if idx is None or not files:
267
- return gr.update()
268
- if idx >= len(files):
269
- idx = len(files) - 1
270
- return gr.update(value=files[idx])
271
-
272
- gallery.select(on_gallery_select, inputs=[files_state], outputs=[hair_sel])
273
-
274
- # ---------------- Webcam Tab ----------------
275
- with gr.Tab("πŸ“Ή Webcam (Live Beta)"):
276
- cam = gr.Image(sources=["webcam"], streaming=True, type="pil", label="Enable camera")
277
- hair2 = gr.Dropdown(choices=HAIR_FILES, value=(HAIR_FILES[0] if HAIR_FILES else None), label="Selected hairstyle")
278
- with gr.Row():
279
- scale2 = gr.Slider(50, 200, 100, 1, label="Scale %")
280
- rot2 = gr.Slider(-25, 25, 0, 1, label="Rotate (Β°)")
281
- with gr.Row():
282
- dx2 = gr.Slider(-150, 150, 0, 1, label="Left ↔ Right (px)")
283
- dy2 = gr.Slider(-150, 150, 0, 1, label="Up ↕ Down (px)")
284
- opacity2 = gr.Slider(0.2, 1.0, 0.95, 0.05, label="Hair opacity")
285
- out2 = gr.Image(label="Live result", height=360)
286
- state_live = gr.State(None)
287
- snap = gr.Button("πŸ“Έ Snapshot")
288
- save_live_btn = gr.Button("⬇️ Download Snapshot")
289
- save_live_file = gr.File(label="snapshot", visible=False)
290
-
291
- def live(im, h, s, r, dxv, dyv, op, meta):
292
- res, _ = apply_tryon(im, h, s, r, dxv, dyv, op, meta)
293
- return res, res
294
-
295
- cam.stream(
296
- fn=live,
297
- inputs=[cam, hair2, scale2, rot2, dx2, dy2, opacity2, meta_state],
298
- outputs=[out2, state_live]
299
- )
300
-
301
- snap.click(lambda x: x, inputs=[state_live], outputs=[out2])
302
-
303
- def prepare_webcam_download(im):
304
- path = save_png_to_tmp(im, "tryon_webcam.png")
305
- return gr.File.update(value=path, visible=True)
306
-
307
- save_live_btn.click(fn=prepare_webcam_download, inputs=[state_live], outputs=[save_live_file])
308
-
309
- return demo
310
-
311
- # Export for Spaces autostart
312
- app = build_ui()
313
- demo = app
314
-
315
- # Local dev
316
  if __name__ == "__main__":
317
- app.launch()
 
1
+ import os, glob, math
2
+ import numpy as np
3
  from PIL import Image
4
+ import gradio as gr
5
 
6
# Your hairstyles live here:
HAIR_DIR = "hair"

# Optional heavy deps (installed via requirements.txt)
# cv2 + mediapipe enable real face-landmark detection; when either import
# fails the app still runs, using the geometric fallback in
# detect_face_keypoints().
try:
    import cv2
    import mediapipe as mp
    MP_AVAILABLE = True  # landmark-based anchoring available
except Exception:
    # Deliberate best-effort: any import failure simply disables mediapipe.
    MP_AVAILABLE = False

APP_TITLE = "Salon Hairstyle Virtual Try-On β€” Pro Demo"
18
+
19
+ # --------------------------
20
+ # Asset loading / utilities
21
+ # --------------------------
22
def list_hairstyle_files(hair_dir=None):
    """Discover hairstyle PNG assets.

    Parameters:
        hair_dir: directory to scan; defaults to the module-level HAIR_DIR.
                  (Backward-compatible addition — existing callers pass nothing.)

    Returns:
        (labels, files, thumbs): human-readable labels, sorted file paths,
        and gallery thumbnail entries (Gradio's Gallery accepts path strings).
    """
    if hair_dir is None:
        hair_dir = HAIR_DIR
    files = sorted(glob.glob(os.path.join(hair_dir, "*.png")))
    # Strip only the final extension: the previous str.replace(".png", "")
    # removed *every* ".png" occurrence, mangling names like "logo.png_v2.png".
    labels = [
        os.path.splitext(os.path.basename(f))[0].replace("_", " ").title()
        for f in files
    ]
    thumbs = list(files)  # plain copy, no comprehension needed
    return labels, files, thumbs
27
+
28
def ensure_rgba(im: Image.Image) -> Image.Image:
    """Return *im* unchanged when already RGBA, otherwise an RGBA copy."""
    if im.mode == "RGBA":
        return im
    return im.convert("RGBA")
30
+
31
def overlay_rgba(base: Image.Image, overlay: Image.Image, x: int, y: int) -> Image.Image:
    """Alpha-composite ``overlay`` onto ``base`` with its top-left at (x, y).

    The overlay is clipped to the base image's bounds; when the two do not
    intersect (or the overlay is degenerate) the base is returned untouched.
    Mutates and returns ``base`` (converted to RGBA first if needed).
    """
    canvas = ensure_rgba(base)
    layer = ensure_rgba(overlay)
    cw, ch = canvas.size
    lw, lh = layer.size
    if lw <= 0 or lh <= 0:
        return canvas
    # Visible rectangle of the layer, expressed in canvas coordinates.
    left, top = max(0, x), max(0, y)
    right, bottom = min(cw, x + lw), min(ch, y + lh)
    if left >= right or top >= bottom:
        return canvas  # no overlap at all
    visible = layer.crop((left - x, top - y, right - x, bottom - y))
    patch = Image.alpha_composite(canvas.crop((left, top, right, bottom)), visible)
    canvas.paste(patch, (left, top))
    return canvas
47
+
48
+ # --------------------------
49
+ # Face detection / anchoring
50
+ # --------------------------
51
def detect_face_keypoints(np_img):
    """
    Locate the face anchors used for hair placement.

    Returns dict:
        bbox: (x1,y1,x2,y2)
        temples: (lx,ly, rx,ry)
        forehead_y: int
    """
    h, w = np_img.shape[:2]

    def fallback():
        # Heuristic anchors (centered box in the upper-middle of the frame),
        # used when mediapipe is unavailable or no face is found.
        cx, cy = w // 2, int(h * 0.42)
        bw, bh = int(w * 0.4), int(h * 0.45)
        x1, y1 = max(0, cx - bw // 2), max(0, cy - bh // 2)
        x2, y2 = min(w, x1 + bw), min(h, y1 + bh)
        return {
            "bbox": (x1, y1, x2, y2),
            "temples": (x1, (y1 + y2) // 2, x2, (y1 + y2) // 2),
            "forehead_y": max(0, int(y1 - 0.1 * (y2 - y1))),
        }

    if not MP_AVAILABLE:
        return fallback()

    mpfm = mp.solutions.face_mesh
    with mpfm.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=False) as fm:
        # NOTE(review): this branch assumes a 3-channel input is BGR, but the
        # only visible caller (place_and_render) passes RGB — the channel swap
        # is likely unintended; confirm against mediapipe's RGB expectation.
        rgb = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) if np_img.shape[2] == 3 else np_img[..., :3]
        res = fm.process(rgb)
        if not res.multi_face_landmarks:
            return fallback()

    lm = res.multi_face_landmarks[0].landmark
    # Normalized landmark coords -> pixel coords.
    xs = np.array([p.x for p in lm]) * w
    ys = np.array([p.y for p in lm]) * h
    x1, y1, x2, y2 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

    def safe_idx(i):
        # Clamp the landmark index so a bad constant can never raise IndexError.
        i = int(i)
        i = max(0, min(len(lm) - 1, i))
        return int(xs[i]), int(ys[i])

    lx, ly = safe_idx(127)  # approx left temple
    rx, ry = safe_idx(356)  # approx right temple
    _, fy = safe_idx(10)    # approx forehead
    y1 = max(0, int(y1 - 0.12 * (y2 - y1)))  # include hairline
    return {
        "bbox": (x1, y1, x2, y2),
        "temples": (lx, ly, rx, ry),
        "forehead_y": fy,
    }
100
+
101
def place_and_render(base_np, hair_path, scale, lr_shift_pct, ud_shift_pct, rotation_deg):
    """Composite a hairstyle PNG onto a photo.

    Parameters:
        base_np: HxWx3 numpy image (uint8-compatible) or None.
        hair_path: path to an RGBA hairstyle PNG.
        scale: multiplier on the temple distance used to size the hair.
        lr_shift_pct / ud_shift_pct: manual shifts, percent of image size.
        rotation_deg: extra rotation on top of the automatic tilt correction.

    Returns:
        A PIL RGB image with the hair overlaid.

    Raises:
        gr.Error: when no photo has been provided.
    """
    if base_np is None:
        raise gr.Error("Please upload or capture a photo first.")
    base = Image.fromarray(base_np.astype("uint8")).convert("RGBA")
    key = detect_face_keypoints(np.array(base.convert("RGB")))
    # (Fix: the bbox was previously unpacked into x1..y2 but never used.)
    lx, ly, rx, ry = key["temples"]
    forehead_y = key["forehead_y"]

    hair = Image.open(hair_path).convert("RGBA")

    # Scale from temple distance (guard against a degenerate zero distance).
    temple_dx = rx - lx
    temple_dy = ry - ly
    temple_dist = max(1, (temple_dx ** 2 + temple_dy ** 2) ** 0.5)
    target_w = max(1, int(temple_dist * 2.0 * scale))
    ratio = target_w / hair.width
    target_h = max(1, int(hair.height * ratio))
    hair_resized = hair.resize((target_w, target_h), Image.LANCZOS)

    # Auto-rotation (head tilt from the temple line) + manual extra rotation.
    auto_deg = math.degrees(math.atan2(temple_dy, temple_dx))
    rot_total = auto_deg + rotation_deg
    hair_resized = hair_resized.rotate(rot_total, expand=True)

    # Anchor horizontally at the temple midpoint, vertically above the forehead.
    midx = int((lx + rx) / 2)
    anchor_x = int(midx - hair_resized.width / 2)
    anchor_y = int(forehead_y - hair_resized.height * 0.45)

    # Apply manual shifts (percent of image size).
    img_w, img_h = base.size
    anchor_x += int(lr_shift_pct * img_w / 100.0)  # +right, -left
    anchor_y += int(ud_shift_pct * img_h / 100.0)  # +down, -up

    out = overlay_rgba(base, hair_resized, anchor_x, anchor_y).convert("RGB")
    return out
138
 
139
+ # --------------------------
140
+ # Gradio callbacks
141
+ # --------------------------
142
def refresh_assets():
    """Rescan hair/ and produce fresh updates for the dropdown, gallery, and status line."""
    labels, files, thumbs = list_hairstyle_files()
    if not files:
        empty_dd = gr.update(value=None, choices=[])
        empty_gallery = gr.update(value=None)
        return empty_dd, empty_gallery, "No PNGs found in the 'hair/' folder."
    dropdown = gr.update(choices=labels, value=labels[0])
    items = [[path, caption] for path, caption in zip(thumbs, labels)]
    return dropdown, gr.update(value=items), f"Found {len(files)} hairstyles."
153
+
154
def refresh_dropdown_only():
    """Rescan hair/ and return an update for a hairstyle dropdown alone."""
    labels, files, _ = list_hairstyle_files()
    if files:
        return gr.update(choices=labels, value=labels[0])
    return gr.update(value=None, choices=[])
159
+
160
def pick_from_gallery(gallery_select, current_dropdown, evt: gr.SelectData = None):
    """Map a gallery click onto the hairstyle dropdown.

    Bug fix: the previous code read ``gallery_select[0]`` — the first *item*
    of the Gallery's value (a [path, label] pair), not the clicked index — so
    clamping/indexing with it was wrong. Gradio delivers the clicked index via
    an injected ``gr.SelectData`` argument; the new trailing parameter is
    defaulted, so the existing ``gallery.select(...)`` wiring keeps working.

    Returns the label to select, or the current dropdown value unchanged when
    no reliable selection is available.
    """
    labels, _, _ = list_hairstyle_files()
    if not labels:
        raise gr.Error("No hairstyles available.")
    if evt is None or evt.index is None:
        return current_dropdown  # nothing reliably selected; keep as-is
    idx = max(0, min(len(labels) - 1, int(evt.index)))
    return labels[idx]
170
+
171
def apply_from_dropdown(image, hairstyle_label, scale, lr, ud, rot):
    """Resolve the dropdown label to its PNG path and render the try-on."""
    labels, files, _ = list_hairstyle_files()
    if not labels:
        raise gr.Error("No hairstyle assets found.")
    try:
        position = labels.index(hairstyle_label)
    except ValueError:
        raise gr.Error("Choose a hairstyle first.")
    return place_and_render(image, files[position], scale, lr, ud, rot)
179
+
180
def save_image(img_np):
    """Persist the preview to a PNG and return its path for the gr.File output.

    Writes into the system temp directory instead of the app's working
    directory, which may be read-only on hosted deployments (the pre-rewrite
    version of this app used /tmp for exactly that reason).

    Raises:
        gr.Error: when there is no rendered preview to save.
    """
    import tempfile  # local import: only needed on this save path

    if img_np is None:
        raise gr.Error("Nothing to save. Generate a preview first.")
    out_path = os.path.join(tempfile.gettempdir(), "output_tryon.png")
    Image.fromarray(img_np).save(out_path)
    return out_path
186
 
187
+ # --------------------------
188
+ # UI
189
+ # --------------------------
190
# Top-level Gradio app: two tabs (upload + webcam) sharing the same
# hairstyle assets and placement pipeline.
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown("Upload a photo or use webcam. Put **transparent PNGs** in the hair/ folder, then click **Refresh**.")

    with gr.Tabs():
        # -------- Upload tab --------
        with gr.Tab("πŸ“· Photo (Upload)"):
            with gr.Row():
                with gr.Column():
                    # Left column: input photo, hairstyle picker, placement controls.
                    img = gr.Image(label="Photo", sources=["upload"], type="numpy", height=420)

                    # Dropdown starts empty; demo.load() below populates it from hair/.
                    hair_dd = gr.Dropdown(label="Hairstyle", choices=[], interactive=True)
                    refresh = gr.Button("πŸ”„ Refresh")
                    status = gr.Markdown("Drop PNGs into hair/ and press Refresh.")
                    gallery = gr.Gallery(
                        label="Hairstyles (click to choose)",
                        columns=4,
                        height=220,
                        allow_preview=False,
                        interactive=True,
                    )

                    with gr.Accordion("Fine-tune placement", open=False):
                        scale = gr.Slider(0.6, 2.2, value=1.3, step=0.01, label="Scale (Γ— temple distance)")
                        lr = gr.Slider(-30, 30, value=0, step=0.5, label="Left ↔ Right shift (%)")
                        ud = gr.Slider(-30, 30, value=0, step=0.5, label="Up ↕ Down shift (%)")
                        rot = gr.Slider(-30, 30, value=0, step=0.5, label="Extra rotation (Β°)")

                    run_dd = gr.Button("✨ Apply")

                with gr.Column():
                    # Right column: rendered preview + download.
                    out = gr.Image(label="Preview", height=480)
                    save_btn = gr.Button("πŸ’Ύ Save result")
                    file_out = gr.File(label="Download")

            # Wiring
            refresh.click(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])
            demo.load(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])

            # Clicking gallery picks the dropdown value; one Apply button only
            gallery.select(fn=pick_from_gallery, inputs=[gallery, hair_dd], outputs=hair_dd)
            run_dd.click(fn=apply_from_dropdown, inputs=[img, hair_dd, scale, lr, ud, rot], outputs=out)
            save_btn.click(fn=save_image, inputs=out, outputs=file_out)

        # -------- Webcam tab --------
        with gr.Tab("πŸŽ₯ Webcam (Live Beta)"):
            gr.Markdown("Live mode processes frames continuously. Keep resolution small on CPU Spaces.")
            cam = gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Webcam")
            hair_dd2 = gr.Dropdown(label="Hairstyle", choices=[], interactive=True)
            scale2 = gr.Slider(0.6, 2.2, value=1.25, step=0.01, label="Scale")
            lr2 = gr.Slider(-30, 30, value=0, step=0.5, label="Left ↔ Right shift (%)")
            ud2 = gr.Slider(-30, 30, value=0, step=0.5, label="Up ↕ Down shift (%)")
            rot2 = gr.Slider(-30, 30, value=0, step=0.5, label="Rotation (Β°)")
            out_live = gr.Image(label="Live Preview", interactive=False, height=420)

            def live_process(frame, label, s, lrp, udp, r):
                # Per-frame handler: pass the frame through untouched when there
                # is no frame or no valid hairstyle selection.
                labels, files, _ = list_hairstyle_files()
                if frame is None or not labels or label not in labels:
                    return frame
                path = files[labels.index(label)]
                return place_and_render(frame, path, s, lrp, udp, r)

            cam.stream(fn=live_process, inputs=[cam, hair_dd2, scale2, lr2, ud2, rot2], outputs=out_live)
            demo.load(fn=refresh_dropdown_only, inputs=None, outputs=hair_dd2)


if __name__ == "__main__":
    demo.launch()