RojaKatta commited on
Commit
d86fafa
·
verified ·
1 Parent(s): 80bcb01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -21
app.py CHANGED
@@ -2,7 +2,7 @@ import os, json, tempfile, re
2
  import cv2, numpy as np, gradio as gr
3
  from PIL import Image
4
 
5
- # =============== Paths ===============
6
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
  CANDIDATES = [
8
  os.path.join(BASE_DIR, "hair"),
@@ -17,7 +17,7 @@ if HAIR_DIR is None:
17
 
18
  META_PATH = os.path.join(HAIR_DIR, "meta.json") # optional per-style anchors
19
 
20
- # =============== Dependencies ===============
21
  try:
22
  import mediapipe as mp
23
  except Exception as e:
@@ -26,8 +26,9 @@ except Exception as e:
26
  mp_face_mesh = mp.solutions.face_mesh
27
  LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
28
 
29
- # =============== Helpers ===============
30
  def natural_key(s: str):
 
31
  return [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]
32
 
33
  def load_hairstyles():
@@ -76,12 +77,13 @@ def detect_face_keypoints(img_bgr):
76
  return np.stack([xy(LM["left_eye_outer"]), xy(LM["right_eye_outer"]), xy(LM["mid_forehead"])])
77
 
78
  def hair_reference_points(hair_bgra, filename, meta):
 
79
  h, w = hair_bgra.shape[:2]
80
  if filename in meta:
81
  pts = np.array(meta[filename], dtype=np.float32)
82
  if pts.shape == (3, 2):
83
  return pts
84
- # Defaults (OK for many styles). For pixel-perfect fit, add 3 points to meta.json.
85
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
86
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
87
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
@@ -91,7 +93,7 @@ def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
91
  H, W = base_bgr.shape[:2]
92
  hair_rgb = hair_bgra[:, :, :3]
93
  hair_a = hair_bgra[:, :, 3] / 255.0
94
- # borderMode CONSTANT avoids odd edge artifacts; value black (transparent)
95
  hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR,
96
  borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
97
  a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR,
@@ -101,7 +103,7 @@ def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
101
  return out
102
 
103
  def apply_tryon(image, hairstyle, scale_pct, dx, dy, opacity, meta):
104
- """No head-mask (prevents neck lines & cropping)."""
105
  if image is None:
106
  return None, "Upload a photo first."
107
  if not hairstyle:
@@ -120,7 +122,7 @@ def apply_tryon(image, hairstyle, scale_pct, dx, dy, opacity, meta):
120
  dst[:, 0] += dx
121
  dst[:, 1] += dy
122
 
123
- # Scale hair anchors around their centroid (no rotation for simplicity)
124
  center = hair_pts.mean(axis=0)
125
  s = max(0.5, scale_pct / 100.0)
126
  hair_pts_adj = (hair_pts - center) * s + center
@@ -143,7 +145,7 @@ def save_png_to_tmp(img, filename="output_tryon.png"):
143
  img.save(out_path)
144
  return out_path
145
 
146
- # ---- white thumbnails with labels ----
147
  def thumb_on_white(hair_bgra, max_h=220):
148
  h, w = hair_bgra.shape[:2]
149
  scale = min(1.0, max_h / h)
@@ -159,20 +161,48 @@ def build_gallery_items(files):
159
  for idx, fname in enumerate(files, start=1):
160
  try:
161
  img = load_hair_png(fname)
162
- items.append((thumb_on_white(img), f"{idx}. {fname}")) # show number + filename
163
  except Exception:
164
  continue
165
  return items
166
 
167
- # =============== UI ===============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def build_ui():
169
  META = load_meta()
170
  HAIR_FILES = load_hairstyles()
171
 
172
  with gr.Blocks(title="Salon Hairstyle Virtual Try-On (Simple)") as demo:
173
- gr.Markdown("Upload a photo, then **click a hairstyle** below. Use a few sliders if needed, then **Save result**.")
174
 
175
- selected_file = gr.State(None) # currently selected hairstyle filename
176
  meta_state = gr.State(META)
177
  files_state = gr.State(HAIR_FILES)
178
 
@@ -200,7 +230,7 @@ def build_ui():
200
 
201
  with gr.Accordion("Fine-tune (simple)", open=True):
202
  with gr.Row():
203
- scale = gr.Slider(50, 200, 100, 1, label="Scale (temple distance %)") # main size
204
  opacity = gr.Slider(0.4, 1.0, 1.0, 0.05, label="Hair opacity")
205
  with gr.Row():
206
  dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right (px)")
@@ -212,13 +242,11 @@ def build_ui():
212
  def do_apply(im, hairfile, s, dxv, dyv, op, meta):
213
  return apply_tryon(im, hairfile, s, dxv, dyv, op, meta)
214
 
215
- # 1) click a tile -> set selected file AND auto-apply
216
  def on_gallery_select(evt, files, im, s, dxv, dyv, op, meta):
217
- idx = getattr(evt, "index", None)
218
- if idx is None or not files:
219
- return None, gr.update(), None
220
- idx = max(0, min(idx, len(files)-1))
221
- hairfile = files[idx]
222
  out, msg = do_apply(im, hairfile, s, dxv, dyv, op, meta)
223
  return hairfile, out, msg
224
 
@@ -228,7 +256,7 @@ def build_ui():
228
  outputs=[selected_file, out_img, status]
229
  )
230
 
231
- # 2) Apply button (useful after slider tweaks)
232
  apply_btn.click(
233
  fn=do_apply,
234
  inputs=[in_img, selected_file, scale, dx, dy, opacity, meta_state],
@@ -247,7 +275,6 @@ def build_ui():
247
  files = load_hairstyles()
248
  items = build_gallery_items(files)
249
  msg = f"Found {len(files)} hairstyles."
250
- # Keep selection if name still exists
251
  return items, files, msg
252
 
253
  refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, files_state, count_md])
 
2
  import cv2, numpy as np, gradio as gr
3
  from PIL import Image
4
 
5
+ # ================== Paths ==================
6
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
  CANDIDATES = [
8
  os.path.join(BASE_DIR, "hair"),
 
17
 
18
  META_PATH = os.path.join(HAIR_DIR, "meta.json") # optional per-style anchors
19
 
20
+ # ================== Dependencies ==================
21
  try:
22
  import mediapipe as mp
23
  except Exception as e:
 
26
  mp_face_mesh = mp.solutions.face_mesh
27
  LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
28
 
29
+ # ================== Helpers ==================
30
  def natural_key(s: str):
31
+ """Sort 'photo1.png'...'photo11.png' in numeric order."""
32
  return [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]
33
 
34
  def load_hairstyles():
 
77
  return np.stack([xy(LM["left_eye_outer"]), xy(LM["right_eye_outer"]), xy(LM["mid_forehead"])])
78
 
79
  def hair_reference_points(hair_bgra, filename, meta):
80
+ """Return 3 reference points (L-eye, R-eye, mid-forehead) in hair PNG space."""
81
  h, w = hair_bgra.shape[:2]
82
  if filename in meta:
83
  pts = np.array(meta[filename], dtype=np.float32)
84
  if pts.shape == (3, 2):
85
  return pts
86
+ # Generic default; add per-style anchors in hair/meta.json for perfection.
87
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
88
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
89
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
 
93
  H, W = base_bgr.shape[:2]
94
  hair_rgb = hair_bgra[:, :, :3]
95
  hair_a = hair_bgra[:, :, 3] / 255.0
96
+ # CONSTANT border avoids edge artifacts; value=(0,0,0) is transparent with alpha=0
97
  hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR,
98
  borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
99
  a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR,
 
103
  return out
104
 
105
  def apply_tryon(image, hairstyle, scale_pct, dx, dy, opacity, meta):
106
+ """Simplified engine: no head mask (prevents neck lines/cropping)."""
107
  if image is None:
108
  return None, "Upload a photo first."
109
  if not hairstyle:
 
122
  dst[:, 0] += dx
123
  dst[:, 1] += dy
124
 
125
+ # Scale hair anchors around their centroid (no rotation by default)
126
  center = hair_pts.mean(axis=0)
127
  s = max(0.5, scale_pct / 100.0)
128
  hair_pts_adj = (hair_pts - center) * s + center
 
145
  img.save(out_path)
146
  return out_path
147
 
148
+ # ---------- white thumbnails with captions ----------
149
  def thumb_on_white(hair_bgra, max_h=220):
150
  h, w = hair_bgra.shape[:2]
151
  scale = min(1.0, max_h / h)
 
161
  for idx, fname in enumerate(files, start=1):
162
  try:
163
  img = load_hair_png(fname)
164
+ items.append((thumb_on_white(img), f"{idx}. {fname}"))
165
  except Exception:
166
  continue
167
  return items
168
 
169
+ # Robustly extract a filename from a Gallery select event
170
+ def _event_to_filename(evt, files):
171
+ """
172
+ Works across Gradio versions:
173
+ - Prefer evt.index (int)
174
+ - Else parse evt.value (like '3. photo3.png')
175
+ - Else fall back to first file
176
+ """
177
+ # 1) index path
178
+ idx = getattr(evt, "index", None)
179
+ if isinstance(idx, int) and 0 <= idx < len(files):
180
+ return files[idx]
181
+ # 2) value path
182
+ val = getattr(evt, "value", None)
183
+ if isinstance(val, (str, tuple, list)):
184
+ cap = val if isinstance(val, str) else (val[1] if len(val) > 1 else "")
185
+ if isinstance(cap, str):
186
+ m = re.search(r"\b(\d+)\.\s*(.+)$", cap.strip())
187
+ if m:
188
+ i = int(m.group(1)) - 1
189
+ if 0 <= i < len(files):
190
+ return files[i]
191
+ # try raw filename
192
+ if cap in files:
193
+ return cap
194
+ # 3) fallback
195
+ return files[0] if files else None
196
+
197
+ # ================== UI ==================
198
  def build_ui():
199
  META = load_meta()
200
  HAIR_FILES = load_hairstyles()
201
 
202
  with gr.Blocks(title="Salon Hairstyle Virtual Try-On (Simple)") as demo:
203
+ gr.Markdown("Upload a photo, then **click a hairstyle**. Adjust a little if needed, then **Save result**.")
204
 
205
+ selected_file = gr.State(None) # current hairstyle filename
206
  meta_state = gr.State(META)
207
  files_state = gr.State(HAIR_FILES)
208
 
 
230
 
231
  with gr.Accordion("Fine-tune (simple)", open=True):
232
  with gr.Row():
233
+ scale = gr.Slider(50, 200, 100, 1, label="Scale (temple distance %)")
234
  opacity = gr.Slider(0.4, 1.0, 1.0, 0.05, label="Hair opacity")
235
  with gr.Row():
236
  dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right (px)")
 
242
  def do_apply(im, hairfile, s, dxv, dyv, op, meta):
243
  return apply_tryon(im, hairfile, s, dxv, dyv, op, meta)
244
 
245
+ # 1) gallery click robust filename auto-apply
246
  def on_gallery_select(evt, files, im, s, dxv, dyv, op, meta):
247
+ hairfile = _event_to_filename(evt, files)
248
+ if hairfile is None:
249
+ return None, im, "No styles found."
 
 
250
  out, msg = do_apply(im, hairfile, s, dxv, dyv, op, meta)
251
  return hairfile, out, msg
252
 
 
256
  outputs=[selected_file, out_img, status]
257
  )
258
 
259
+ # 2) Apply (use after slider tweaks)
260
  apply_btn.click(
261
  fn=do_apply,
262
  inputs=[in_img, selected_file, scale, dx, dy, opacity, meta_state],
 
275
  files = load_hairstyles()
276
  items = build_gallery_items(files)
277
  msg = f"Found {len(files)} hairstyles."
 
278
  return items, files, msg
279
 
280
  refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, files_state, count_md])