RojaKatta commited on
Commit
b792b3d
Β·
verified Β·
1 Parent(s): 072d76f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -130
app.py CHANGED
@@ -1,11 +1,11 @@
1
- import os, json, tempfile, re, traceback
2
  import cv2, numpy as np, gradio as gr
3
  from PIL import Image
4
 
5
- # -------- Paths --------
6
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
  CANDIDATES = [
8
- os.path.join(BASE_DIR, "hair"),
9
  os.path.join(BASE_DIR, "assets", "hairstyles"),
10
  os.path.join(BASE_DIR, "assets", "Hairstyles"),
11
  os.path.join(BASE_DIR, "hairstyles"),
@@ -17,18 +17,19 @@ if HAIR_DIR is None:
17
 
18
  META_PATH = os.path.join(HAIR_DIR, "meta.json") # optional per-style anchors
19
 
20
- # -------- Deps --------
21
  try:
22
  import mediapipe as mp
23
  except Exception as e:
24
  raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}")
25
 
26
  mp_face_mesh = mp.solutions.face_mesh
 
27
  LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
28
 
29
- # -------- Helpers --------
30
  def natural_key(s: str):
31
- """Sort 'photo1.png'...'photo11.png' in numeric order."""
32
  return [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]
33
 
34
  def load_hairstyles():
@@ -50,7 +51,7 @@ def load_meta():
50
  return {}
51
 
52
  def premultiply_alpha(bgra):
53
- """Removes gray/white halos on edges."""
54
  bgr = bgra[:, :, :3].astype(np.float32) / 255.0
55
  a = (bgra[:, :, 3:4].astype(np.float32) / 255.0)
56
  bgr_pm = (bgr * a * 255.0).astype(np.uint8)
@@ -59,8 +60,8 @@ def premultiply_alpha(bgra):
59
  def load_hair_png(name):
60
  path = os.path.join(HAIR_DIR, name)
61
  hair = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGRA
62
- if hair is None or hair.ndim != 3 or hair.shape[2] != 4:
63
- raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG in {HAIR_DIR})")
64
  return premultiply_alpha(hair)
65
 
66
  def detect_face_keypoints(img_bgr):
@@ -76,14 +77,25 @@ def detect_face_keypoints(img_bgr):
76
  def xy(i): return np.array([lm[i].x*w, lm[i].y*h], dtype=np.float32)
77
  return np.stack([xy(LM["left_eye_outer"]), xy(LM["right_eye_outer"]), xy(LM["mid_forehead"])])
78
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def hair_reference_points(hair_bgra, filename, meta):
80
- """Return 3 anchor points (L-eye, R-eye, mid-forehead) in hair PNG space."""
81
  h, w = hair_bgra.shape[:2]
82
  if filename in meta:
83
  pts = np.array(meta[filename], dtype=np.float32)
84
  if pts.shape == (3, 2):
85
  return pts
86
- # Defaults (OK for many styles). For perfection, put 3 points per file in hair/meta.json.
87
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
88
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
89
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
@@ -92,56 +104,66 @@ def hair_reference_points(hair_bgra, filename, meta):
92
  def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
93
  H, W = base_bgr.shape[:2]
94
  hair_rgb = hair_bgra[:, :, :3]
95
- hair_a = hair_bgra[:, :, 3] / 255.0
96
- # CONSTANT borders avoid clipping/edge artifacts
97
- hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR,
98
- borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
99
- a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR,
100
- borderMode=cv2.BORDER_CONSTANT, borderValue=0)
101
  a = np.clip(a_warp * opacity, 0, 1)[..., None]
102
  out = (a * hair_warp + (1 - a) * base_bgr).astype(np.uint8)
103
  return out
104
 
105
- def apply_tryon(image, hairstyle, scale_pct, dx, dy, opacity, meta):
106
- """Simple engine: NO head-mask (prevents neck lines + lost hair)."""
107
- try:
108
- if image is None:
109
- return None, "Upload a photo first."
110
- if not hairstyle:
111
- return np.array(image), "Pick a hairstyle."
112
-
113
- img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
114
- kpts = detect_face_keypoints(img_bgr)
115
- if kpts is None:
116
- return image, "No face detected. Use a brighter, front-facing photo."
117
-
118
- hair = load_hair_png(hairstyle)
119
- hair_pts = hair_reference_points(hair, hairstyle, meta)
120
-
121
- # destination landmarks + nudges
122
- dst = kpts.copy()
123
- dst[:, 0] += dx
124
- dst[:, 1] += dy
125
-
126
- # scale (no rotation β†’ simpler UI)
127
- center = hair_pts.mean(axis=0)
128
- s = max(0.5, scale_pct / 100.0)
129
- hair_pts_adj = (hair_pts - center) * s + center
130
-
131
- M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
132
- if M is None:
133
- return image, "Alignment failed for this image/style."
134
-
135
- out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
136
- return cv2.cvtColor(out, cv2.COLOR_BGR2RGB), "OK"
137
- except Exception as e:
138
- print("apply_tryon error:", e)
139
- traceback.print_exc()
140
- return np.array(image) if image is not None else None, "Error applying style."
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  def save_png_to_tmp(img, filename="output_tryon.png"):
 
143
  if img is None:
144
- raise gr.Error("No image to save. Click a hairstyle or 'Apply' first.")
145
  out_path = os.path.join(tempfile.gettempdir(), filename)
146
  if isinstance(img, np.ndarray):
147
  Image.fromarray(img).save(out_path)
@@ -149,13 +171,13 @@ def save_png_to_tmp(img, filename="output_tryon.png"):
149
  img.save(out_path)
150
  return out_path
151
 
152
- # White thumbnails with captions (number + filename)
153
  def thumb_on_white(hair_bgra, max_h=220):
154
  h, w = hair_bgra.shape[:2]
155
  scale = min(1.0, max_h / h)
156
  hair_bgra = cv2.resize(hair_bgra, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_LINEAR)
157
  h, w = hair_bgra.shape[:2]
158
- bg_rgb = np.full((h, w, 3), 255, dtype=np.uint8)
159
  a = (hair_bgra[:, :, 3:4] / 255.0)
160
  comp = (a * hair_bgra[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
161
  return cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
@@ -165,127 +187,140 @@ def build_gallery_items(files):
165
  for idx, fname in enumerate(files, start=1):
166
  try:
167
  img = load_hair_png(fname)
168
- items.append((thumb_on_white(img), f"{idx}. {fname}"))
169
  except Exception:
170
  continue
171
  return items
172
 
173
- def event_to_filename(evt, files):
174
- """Robust across Gradio versions."""
175
- idx = getattr(evt, "index", None)
176
- if isinstance(idx, int) and 0 <= idx < len(files):
177
- return files[idx]
178
- val = getattr(evt, "value", None)
179
- cap = None
180
- if isinstance(val, str):
181
- cap = val
182
- elif isinstance(val, (tuple, list)) and len(val) >= 2 and isinstance(val[1], str):
183
- cap = val[1]
184
- elif isinstance(val, dict) and "caption" in val and isinstance(val["caption"], str):
185
- cap = val["caption"]
186
- if cap:
187
- m = re.search(r"\b(\d+)\.\s*(.+)$", cap.strip())
188
- if m:
189
- i = int(m.group(1)) - 1
190
- if 0 <= i < len(files):
191
- return files[i]
192
- for f in files:
193
- if f.lower() == cap.lower():
194
- return f
195
- return files[0] if files else None
196
-
197
- # -------- UI --------
198
  def build_ui():
199
  META = load_meta()
200
  HAIR_FILES = load_hairstyles()
201
 
202
- with gr.Blocks(title="Salon Hairstyle Virtual Try-On β€” Simple & Stable") as demo:
203
- gr.Markdown("Upload a photo, then **click a hairstyle**. Adjust size/position if needed. Use **Save result** to export.")
 
 
 
204
 
205
- selected_file = gr.State(None)
206
- meta_state = gr.State(META)
207
- files_state = gr.State(HAIR_FILES)
208
 
209
  with gr.Tabs():
 
210
  with gr.Tab("πŸ“· Photo (Upload)"):
211
  with gr.Row():
212
  in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
213
  out_img = gr.Image(label="Preview", height=360)
214
-
215
  with gr.Row():
216
- apply_btn = gr.Button("✨ Apply (optional)")
217
- save_btn = gr.Button("πŸ’Ύ Save result")
 
 
 
 
 
 
 
218
  save_file = gr.File(label="Saved file", visible=False)
 
219
 
220
  with gr.Row():
221
- refresh = gr.Button("πŸ”„ Refresh styles")
222
-
223
  count_md = gr.Markdown(f"Found {len(HAIR_FILES)} hairstyles.")
224
  gallery = gr.Gallery(
225
- label="Hairstyles (click to apply)",
226
  value=build_gallery_items(HAIR_FILES),
227
- columns=6, rows=3, height=520,
228
  allow_preview=False, object_fit="contain", show_label=True
229
  )
230
 
231
- with gr.Accordion("Fine-tune (simple)", open=True):
232
  with gr.Row():
233
- scale = gr.Slider(50, 200, 100, 1, label="Scale (temple distance %)")
234
- opacity = gr.Slider(0.4, 1.0, 1.0, 0.05, label="Hair opacity")
235
  with gr.Row():
236
- dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right (px)")
237
- dy = gr.Slider(-200, 200, 0, 1, label="Up ↕ Down (px)")
238
-
239
- status = gr.Markdown("")
240
-
241
- def do_apply(im, hairfile, s, dxv, dyv, op, meta):
242
- return apply_tryon(im, hairfile, s, dxv, dyv, op, meta)
243
-
244
- # Click a tile β†’ apply (robust to Gradio variants)
245
- def on_gallery_select(evt, files, im, s, dxv, dyv, op, meta):
246
- try:
247
- hairfile = event_to_filename(evt, files)
248
- if hairfile is None:
249
- return None, im, "No styles found."
250
- out, msg = do_apply(im, hairfile, s, dxv, dyv, op, meta)
251
- return hairfile, out, msg
252
- except Exception as e:
253
- print("gallery.select error:", e)
254
- traceback.print_exc()
255
- return None, im, "Error applying style."
256
-
257
- gallery.select(
258
- on_gallery_select,
259
- inputs=[files_state, in_img, scale, dx, dy, opacity, meta_state],
260
- outputs=[selected_file, out_img, status]
261
- )
262
 
263
- # Apply after slider tweaks
264
  apply_btn.click(
265
  fn=do_apply,
266
- inputs=[in_img, selected_file, scale, dx, dy, opacity, meta_state],
267
  outputs=[out_img, status]
268
  )
269
 
270
- # Save to a real file (shows a link you can click to download)
271
  def do_save(im):
272
  path = save_png_to_tmp(im, "output_tryon.png")
273
  return gr.File.update(value=path, visible=True)
274
 
275
  save_btn.click(fn=do_save, inputs=[out_img], outputs=[save_file])
276
 
277
- # Refresh styles list
278
  def do_refresh():
279
  files = load_hairstyles()
280
  items = build_gallery_items(files)
281
  msg = f"Found {len(files)} hairstyles."
282
- return items, files, msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
- refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, files_state, count_md])
285
 
286
  return demo
287
 
288
- # Spaces autostart
289
  app = build_ui()
290
  demo = app
291
 
 
1
+ import os, json, tempfile, re
2
  import cv2, numpy as np, gradio as gr
3
  from PIL import Image
4
 
5
+ # -------------------- Paths --------------------
6
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
  CANDIDATES = [
8
+ os.path.join(BASE_DIR, "hair"), # your folder
9
  os.path.join(BASE_DIR, "assets", "hairstyles"),
10
  os.path.join(BASE_DIR, "assets", "Hairstyles"),
11
  os.path.join(BASE_DIR, "hairstyles"),
 
17
 
18
  META_PATH = os.path.join(HAIR_DIR, "meta.json") # optional per-style anchors
19
 
20
+ # -------------------- Deps --------------------
21
  try:
22
  import mediapipe as mp
23
  except Exception as e:
24
  raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}")
25
 
26
  mp_face_mesh = mp.solutions.face_mesh
27
+ mp_selfie_seg = mp.solutions.selfie_segmentation # optional (off by default)
28
  LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
29
 
30
+ # -------------------- Helpers --------------------
31
  def natural_key(s: str):
32
+ # sorts photo1, photo2, ... photo10 in numeric order
33
  return [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]
34
 
35
  def load_hairstyles():
 
51
  return {}
52
 
53
  def premultiply_alpha(bgra):
54
+ """Reduce gray/white halos on edges for nicer blending."""
55
  bgr = bgra[:, :, :3].astype(np.float32) / 255.0
56
  a = (bgra[:, :, 3:4].astype(np.float32) / 255.0)
57
  bgr_pm = (bgr * a * 255.0).astype(np.uint8)
 
60
  def load_hair_png(name):
61
  path = os.path.join(HAIR_DIR, name)
62
  hair = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGRA
63
+ if hair is None or hair.shape[2] != 4:
64
+ raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
65
  return premultiply_alpha(hair)
66
 
67
  def detect_face_keypoints(img_bgr):
 
77
  def xy(i): return np.array([lm[i].x*w, lm[i].y*h], dtype=np.float32)
78
  return np.stack([xy(LM["left_eye_outer"]), xy(LM["right_eye_outer"]), xy(LM["mid_forehead"])])
79
 
80
+ def person_mask(img_bgr, expand_px=20):
81
+ """Optional head mask (OFF by default). We expand+blur to avoid 'neck lines'."""
82
+ with mp_selfie_seg.SelfieSegmentation(model_selection=1) as seg:
83
+ rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
84
+ m = seg.process(rgb).segmentation_mask
85
+ mask = (m > 0.5).astype(np.uint8)
86
+ if expand_px > 0:
87
+ k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*expand_px+1, 2*expand_px+1))
88
+ mask = cv2.dilate(mask, k, iterations=1)
89
+ mask = cv2.GaussianBlur(mask.astype(np.float32), (41, 41), 0)
90
+ return mask
91
+
92
  def hair_reference_points(hair_bgra, filename, meta):
 
93
  h, w = hair_bgra.shape[:2]
94
  if filename in meta:
95
  pts = np.array(meta[filename], dtype=np.float32)
96
  if pts.shape == (3, 2):
97
  return pts
98
+ # Defaults (ok for many styles). For perfect fit, add 3 points per file to meta.json.
99
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
100
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
101
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
 
104
  def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
105
  H, W = base_bgr.shape[:2]
106
  hair_rgb = hair_bgra[:, :, :3]
107
+ hair_a = hair_bgra[:, :, 3] / 255.0
108
+ hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
109
+ a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
 
 
 
110
  a = np.clip(a_warp * opacity, 0, 1)[..., None]
111
  out = (a * hair_warp + (1 - a) * base_bgr).astype(np.uint8)
112
  return out
113
 
114
+ def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity, meta,
115
+ limit_head=False, expand_pct=3.0):
116
+ """
117
+ limit_head=False by default to avoid 'missing hair' and neck lines.
118
+ If True, we use an expanded soft head mask.
119
+ """
120
+ if image is None:
121
+ return None, "Upload a photo or enable webcam."
122
+ if not hairstyle:
123
+ return np.array(image), "Pick a hairstyle first."
124
+
125
+ img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
126
+
127
+ kpts = detect_face_keypoints(img_bgr)
128
+ if kpts is None:
129
+ return image, "No face detected. Try a brighter, front-facing photo."
130
+
131
+ hair = load_hair_png(hairstyle)
132
+ hair_pts = hair_reference_points(hair, hairstyle, meta)
133
+
134
+ # Destination points (with user nudges)
135
+ dst = kpts.copy()
136
+ dst[:, 0] += dx
137
+ dst[:, 1] += dy
138
+
139
+ # Scale + rotate around hair anchor centroid
140
+ center = hair_pts.mean(axis=0)
141
+ theta = np.deg2rad(rot_deg)
142
+ s = max(0.5, scale_pct / 100.0)
143
+ R = np.array([[np.cos(theta), -np.sin(theta)],
144
+ [np.sin(theta), np.cos(theta)]], dtype=np.float32)
145
+ hair_pts_adj = (hair_pts - center) @ R.T * s + center
146
+
147
+ M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
148
+ if M is None:
149
+ return image, "Could not compute alignment for this image/style."
150
+
151
+ out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
152
+
153
+ if limit_head:
154
+ H, W = img_bgr.shape[:2]
155
+ expand_px = max(8, int(min(H, W) * (expand_pct / 100.0))) # soft expansion
156
+ head = person_mask(img_bgr, expand_px=expand_px) # soft & expanded
157
+ head3 = head[..., None]
158
+ out = (head3 * out + (1 - head3) * img_bgr).astype(np.uint8)
159
+
160
+ out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
161
+ return out_rgb, "OK"
162
 
163
  def save_png_to_tmp(img, filename="output_tryon.png"):
164
+ """Create a file in /tmp and return the path (used by the Save button)."""
165
  if img is None:
166
+ raise gr.Error("No image to save. Click Apply first.")
167
  out_path = os.path.join(tempfile.gettempdir(), filename)
168
  if isinstance(img, np.ndarray):
169
  Image.fromarray(img).save(out_path)
 
171
  img.save(out_path)
172
  return out_path
173
 
174
+ # ---------- WHITE background thumbnails (shows filename number) ----------
175
  def thumb_on_white(hair_bgra, max_h=220):
176
  h, w = hair_bgra.shape[:2]
177
  scale = min(1.0, max_h / h)
178
  hair_bgra = cv2.resize(hair_bgra, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_LINEAR)
179
  h, w = hair_bgra.shape[:2]
180
+ bg_rgb = np.full((h, w, 3), 255, dtype=np.uint8) # white background
181
  a = (hair_bgra[:, :, 3:4] / 255.0)
182
  comp = (a * hair_bgra[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
183
  return cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
 
187
  for idx, fname in enumerate(files, start=1):
188
  try:
189
  img = load_hair_png(fname)
190
+ items.append((thumb_on_white(img), f"{idx}. {fname}")) # caption shows number & filename
191
  except Exception:
192
  continue
193
  return items
194
 
195
+ # -------------------- UI --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  def build_ui():
197
  META = load_meta()
198
  HAIR_FILES = load_hairstyles()
199
 
200
+ with gr.Blocks(title="Salon Hairstyle Virtual Try-On", css="""
201
+ .gradio-container {max-width: 1200px; margin:auto;}
202
+ @media (max-width: 768px){ .gradio-container {padding: 8px;} }
203
+ """) as demo:
204
+ gr.Markdown("Upload a photo or use webcam. Put transparent **PNGs** in **`hair/`**, then click **Refresh**.")
205
 
206
+ files_state = gr.State(HAIR_FILES) # filenames (natural order)
207
+ meta_state = gr.State(META)
 
208
 
209
  with gr.Tabs():
210
+ # -------- Photo Tab --------
211
  with gr.Tab("πŸ“· Photo (Upload)"):
212
  with gr.Row():
213
  in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
214
  out_img = gr.Image(label="Preview", height=360)
 
215
  with gr.Row():
216
+ hair_sel = gr.Dropdown(
217
+ choices=HAIR_FILES,
218
+ value=(HAIR_FILES[0] if HAIR_FILES else None),
219
+ label="Selected hairstyle",
220
+ interactive=True
221
+ )
222
+ apply_btn = gr.Button("✨ Apply (Align & Overlay)")
223
+ # SAVE (replaces Download)
224
+ save_btn = gr.Button("πŸ’Ύ Save result")
225
  save_file = gr.File(label="Saved file", visible=False)
226
+ status = gr.Markdown()
227
 
228
  with gr.Row():
229
+ refresh = gr.Button("πŸ”„ Refresh")
 
230
  count_md = gr.Markdown(f"Found {len(HAIR_FILES)} hairstyles.")
231
  gallery = gr.Gallery(
232
+ label="Hairstyles (click to choose)",
233
  value=build_gallery_items(HAIR_FILES),
234
+ columns=6, rows=3, height=520, # up to 18 tiles visible; all 11 will show
235
  allow_preview=False, object_fit="contain", show_label=True
236
  )
237
 
238
+ with gr.Accordion("Fine-tune placement", open=True):
239
  with gr.Row():
240
+ scale = gr.Slider(50, 200, 100, 1, label="Scale (β‰ˆ temple distance %)")
241
+ rot = gr.Slider(-30, 30, 0, 1, label="Extra rotation (Β°)")
242
  with gr.Row():
243
+ dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right shift (px)")
244
+ dy = gr.Slider(-200, 200, 0, 1, label="Up ↕ Down shift (px)")
245
+ opacity = gr.Slider(0.2, 1.0, 1.0, 0.05, label="Hair opacity")
246
+ limit_head = gr.Checkbox(label="Limit overlay to head (avoid spill)", value=False)
247
+ expand = gr.Slider(0.0, 10.0, 3.0, 0.5, label="Head-mask expansion (%) β€” only if enabled")
248
+
249
+ # --- Callbacks ---
250
+ def do_apply(im, hfile, s, r, dxv, dyv, op, meta, lh, ex):
251
+ return apply_tryon(im, hfile, s, r, dxv, dyv, op, meta, limit_head=lh, expand_pct=ex)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
 
253
  apply_btn.click(
254
  fn=do_apply,
255
+ inputs=[in_img, hair_sel, scale, rot, dx, dy, opacity, meta_state, limit_head, expand],
256
  outputs=[out_img, status]
257
  )
258
 
 
259
  def do_save(im):
260
  path = save_png_to_tmp(im, "output_tryon.png")
261
  return gr.File.update(value=path, visible=True)
262
 
263
  save_btn.click(fn=do_save, inputs=[out_img], outputs=[save_file])
264
 
 
265
  def do_refresh():
266
  files = load_hairstyles()
267
  items = build_gallery_items(files)
268
  msg = f"Found {len(files)} hairstyles."
269
+ return items, gr.update(choices=files, value=(files[0] if files else None)), files, msg
270
+
271
+ refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, hair_sel, files_state, count_md])
272
+
273
+ # Gallery click -> set dropdown to that filename
274
+ def on_gallery_select(evt, files):
275
+ idx = getattr(evt, "index", None)
276
+ if idx is None or not files:
277
+ return gr.update()
278
+ # our captions start at 1., map index to filename directly
279
+ idx = max(0, min(idx, len(files)-1))
280
+ return gr.update(value=files[idx])
281
+
282
+ gallery.select(on_gallery_select, inputs=[files_state], outputs=[hair_sel])
283
+
284
+ # -------- Webcam Tab (unchanged except 'Save Snapshot') --------
285
+ with gr.Tab("πŸ“Ή Webcam (Live Beta)"):
286
+ cam = gr.Image(sources=["webcam"], streaming=True, type="pil", label="Enable camera")
287
+ hair2 = gr.Dropdown(choices=HAIR_FILES, value=(HAIR_FILES[0] if HAIR_FILES else None), label="Selected hairstyle")
288
+ with gr.Row():
289
+ scale2 = gr.Slider(50, 200, 100, 1, label="Scale %")
290
+ rot2 = gr.Slider(-25, 25, 0, 1, label="Rotate (Β°)")
291
+ with gr.Row():
292
+ dx2 = gr.Slider(-150, 150, 0, 1, label="Left ↔ Right (px)")
293
+ dy2 = gr.Slider(-150, 150, 0, 1, label="Up ↕ Down (px)")
294
+ opacity2 = gr.Slider(0.2, 1.0, 0.95, 0.05, label="Hair opacity")
295
+ limit_head2 = gr.Checkbox(label="Limit overlay to head", value=False)
296
+ expand2 = gr.Slider(0.0, 10.0, 3.0, 0.5, label="Head-mask expansion (%)", visible=True)
297
+ out2 = gr.Image(label="Live result", height=360)
298
+ state_live = gr.State(None)
299
+ snap = gr.Button("πŸ“Έ Snapshot")
300
+ save_live_btn = gr.Button("πŸ’Ύ Save snapshot")
301
+ save_live_file = gr.File(label="snapshot", visible=False)
302
+
303
+ def live(im, h, s, r, dxv, dyv, op, meta, lh, ex):
304
+ res, _ = apply_tryon(im, h, s, r, dxv, dyv, op, meta, limit_head=lh, expand_pct=ex)
305
+ return res, res
306
+
307
+ cam.stream(
308
+ fn=live,
309
+ inputs=[cam, hair2, scale2, rot2, dx2, dy2, opacity2, meta_state, limit_head2, expand2],
310
+ outputs=[out2, state_live]
311
+ )
312
+
313
+ snap.click(lambda x: x, inputs=[state_live], outputs=[out2])
314
+
315
+ def save_snap(im):
316
+ path = save_png_to_tmp(im, "tryon_webcam.png")
317
+ return gr.File.update(value=path, visible=True)
318
 
319
+ save_live_btn.click(fn=save_snap, inputs=[state_live], outputs=[save_live_file])
320
 
321
  return demo
322
 
323
+ # Export for Spaces autostart
324
  app = build_ui()
325
  demo = app
326