RojaKatta commited on
Commit
73f1086
·
verified ·
1 Parent(s): f142308

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -137
app.py CHANGED
@@ -1,187 +1,196 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import gradio as gr
5
- import mediapipe as mp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- HAIR_DIR = "assets/hairstyles"
8
-
9
- # ---------- MediaPipe FaceMesh setup ----------
10
  mp_face_mesh = mp.solutions.face_mesh
11
- FACE_LANDMARKS = {
12
- "left_eye_outer": 33, # standard mesh indices
13
- "right_eye_outer": 263,
14
- "mid_forehead": 10 # glabella/forehead area (approx)
15
- }
16
 
 
17
  def load_hairstyles():
18
- files = [f for f in os.listdir(HAIR_DIR) if f.lower().endswith(".png")]
 
 
 
 
19
  files.sort()
20
  return files
21
 
22
  HAIR_FILES = load_hairstyles()
23
 
24
- def detect_keypoints(image_bgr):
25
- """Returns 3 keypoints (x,y) in image coords (float)."""
26
- h, w = image_bgr.shape[:2]
 
 
 
 
 
 
 
 
 
 
 
 
27
  with mp_face_mesh.FaceMesh(
28
- static_image_mode=True,
29
- max_num_faces=1,
30
- refine_landmarks=True,
31
  min_detection_confidence=0.6
32
- ) as face_mesh:
33
- img_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
34
- res = face_mesh.process(img_rgb)
35
  if not res.multi_face_landmarks:
36
  return None
37
  lm = res.multi_face_landmarks[0].landmark
38
- def to_xy(idx):
39
- return np.array([lm[idx].x * w, lm[idx].y * h], dtype=np.float32)
40
- p_left = to_xy(FACE_LANDMARKS["left_eye_outer"])
41
- p_right = to_xy(FACE_LANDMARKS["right_eye_outer"])
42
- p_mid = to_xy(FACE_LANDMARKS["mid_forehead"])
43
- return np.stack([p_left, p_right, p_mid], axis=0)
 
 
 
 
 
44
 
45
  def load_hair_png(name):
46
  path = os.path.join(HAIR_DIR, name)
47
  hair = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGRA
48
  if hair is None or hair.shape[2] != 4:
49
- raise ValueError(f"Invalid hair asset: {name}")
50
  return hair
51
 
52
- def hair_reference_points(hair_bgra):
53
- """
54
- Define 3 reference anchor points on the hair image (in hair coords).
55
- Tune once per style so that these points roughly correspond to
56
- left eye outer, right eye outer, and mid-forehead anchors.
57
- For a simple default, place anchors across the lower edge of hair.
58
- """
59
  h, w = hair_bgra.shape[:2]
60
- # Default guesses; adjust per asset for better alignment.
 
 
 
 
61
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
62
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
63
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
64
  return np.stack([pL, pR, pM], axis=0)
65
 
66
- def warp_and_blend(base_bgr, hair_bgra, M, opacity=1.0):
67
- h_img, w_img = base_bgr.shape[:2]
68
- # Split hair into BGR + A
69
- hair_bgr = hair_bgra[:, :, :3]
70
- hair_a = hair_bgra[:, :, 3] / 255.0
71
-
72
- # Warp both color and alpha with same affine
73
- hair_warp = cv2.warpAffine(hair_bgr, M, (w_img, h_img), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
74
- alpha_warp = cv2.warpAffine(hair_a, M, (w_img, h_img), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
75
-
76
- alpha = np.clip(alpha_warp * opacity, 0, 1)[..., None] # HxWx1
77
- out = (alpha * hair_warp + (1 - alpha) * base_bgr).astype(np.uint8)
78
  return out
79
 
80
- def tryon(
81
- image,
82
- hairstyle,
83
- scale_pct=100,
84
- rot_deg=0,
85
- dx=0, dy=0,
86
- opacity=1.0,
87
- mode="Photo"
88
- ):
89
  if image is None:
90
- return None, "Please provide an image or turn on webcam."
 
 
91
 
92
- # Convert PIL->BGR
93
  img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
94
- keypts = detect_keypoints(img_bgr)
95
- if keypts is None:
96
- return image, "No face detected. Try a clearer, front-facing photo."
 
97
 
98
  hair = load_hair_png(hairstyle)
99
- hair_pts = hair_reference_points(hair)
100
 
101
- # Compute affine (hair -> face) using three points
102
- dst_pts = keypts.copy()
103
- # Apply manual user tweaks on destination:
104
- dst_pts[:, 0] += dx
105
- dst_pts[:, 1] += dy
106
 
107
- # Scale and rotate around hair center before solving affine
108
  center = hair_pts.mean(axis=0)
109
  theta = np.deg2rad(rot_deg)
110
  s = max(0.5, scale_pct / 100.0)
111
-
112
  R = np.array([[np.cos(theta), -np.sin(theta)],
113
  [np.sin(theta), np.cos(theta)]], dtype=np.float32)
114
  hair_pts_adj = (hair_pts - center) @ R.T * s + center
115
 
116
- # Solve affine transform
117
- M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst_pts, method=cv2.LMEDS)
118
-
119
  if M is None:
120
- return image, "Could not compute alignment. Try another image/hairstyle."
 
 
 
 
 
 
 
121
 
122
- out_bgr = warp_and_blend(img_bgr, hair, M, opacity=opacity)
123
- out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
124
  return out_rgb, "OK"
125
 
126
- def build_ui():
127
- with gr.Blocks(title="Salon Hairstyle Virtual Try-On") as demo:
128
- gr.Markdown("# Salon Hairstyle Virtual Try-On\nTry different hairstyles in real-time or on a photo.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  with gr.Tabs():
 
130
  with gr.Tab("Photo"):
131
  with gr.Row():
132
- in_img = gr.Image(type="pil", label="Upload photo", sources=["upload"], height=420)
133
- gallery = gr.Radio(choices=HAIR_FILES, value=HAIR_FILES[0] if HAIR_FILES else None, label="Hairstyles")
134
- with gr.Accordion("Alignment & Style Controls", open=True):
135
- with gr.Row():
136
- scale = gr.Slider(50, 200, value=100, step=1, label="Scale %")
137
- rot = gr.Slider(-30, 30, value=0, step=1, label="Rotate (deg)")
138
- with gr.Row():
139
- dx = gr.Slider(-200, 200, value=0, step=1, label="Horizontal Nudge (px)")
140
- dy = gr.Slider(-200, 200, value=0, step=1, label="Vertical Nudge (px)")
141
- opacity = gr.Slider(0.2, 1.0, value=1.0, step=0.05, label="Hair Opacity")
142
-
143
- out_img = gr.Image(label="Result", height=420)
144
- status = gr.Markdown()
145
-
146
- run_btn = gr.Button("Apply")
147
- save_btn = gr.Button("Save Result")
148
-
149
- def on_save(img):
150
- # Gradio will let users right-click or use built-in download;
151
- # Optionally, return img so it can be saved from gallery.
152
- return img
153
-
154
- run_btn.click(
155
- fn=lambda im, h, s, r, dxv, dyv, op: tryon(im, h, s, r, dxv, dyv, op, mode="Photo"),
156
- inputs=[in_img, gallery, scale, rot, dx, dy, opacity],
157
- outputs=[out_img, status]
158
- )
159
- save_btn.click(fn=on_save, inputs=[out_img], outputs=[out_img])
160
-
161
- with gr.Tab("Webcam (Live)"):
162
- cam = gr.Image(sources=["webcam"], streaming=True, type="pil", label="Webcam")
163
- hair2 = gr.Radio(choices=HAIR_FILES, value=HAIR_FILES[0] if HAIR_FILES else None, label="Hairstyles")
164
- scale2 = gr.Slider(50, 200, value=100, step=1, label="Scale %")
165
- rot2 = gr.Slider(-25, 25, value=0, step=1, label="Rotate (deg)")
166
- dx2 = gr.Slider(-150, 150, value=0, step=1, label="Horizontal Nudge (px)")
167
- dy2 = gr.Slider(-150, 150, value=0, step=1, label="Vertical Nudge (px)")
168
- opacity2 = gr.Slider(0.2, 1.0, value=0.95, step=0.05, label="Hair Opacity")
169
-
170
- out2 = gr.Image(label="Live Result")
171
-
172
- def live_fn(im, h, s, r, dxv, dyv, op):
173
- res, _ = tryon(im, h, s, r, dxv, dyv, op, mode="Webcam")
174
- return res
175
-
176
- cam.stream(
177
- fn=live_fn,
178
- inputs=[cam, hair2, scale2, rot2, dx2, dy2, opacity2],
179
- outputs=[out2],
180
- time_limit=0.0 # continuous
181
- )
182
- return demo
183
-
184
- demo = build_ui()
185
-
186
- if __name__ == "__main__":
187
- demo.launch()
 
1
+ import os, json, tempfile
2
+ import cv2, numpy as np, gradio as gr
3
+ from PIL import Image
4
+
5
+ # ---------------------- Paths (hair/ first) ----------------------
6
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7
+ CANDIDATES = [
8
+ os.path.join(BASE_DIR, "hair"), # <- your current folder
9
+ os.path.join(BASE_DIR, "assets", "hairstyles"),
10
+ os.path.join(BASE_DIR, "assets", "Hairstyles"),
11
+ os.path.join(BASE_DIR, "hairstyles"),
12
+ ]
13
+ HAIR_DIR = None
14
+ for p in CANDIDATES:
15
+ if os.path.isdir(p):
16
+ HAIR_DIR = p
17
+ break
18
+ if HAIR_DIR is None: # create the canonical path if nothing exists yet
19
+ HAIR_DIR = os.path.join(BASE_DIR, "hair")
20
+ os.makedirs(HAIR_DIR, exist_ok=True)
21
+
22
+ META_PATH = os.path.join(HAIR_DIR, "meta.json") # optional per-style anchors
23
+
24
+ # ---------------------- Dependencies ----------------------
25
+ try:
26
+ import mediapipe as mp # FaceMesh + SelfieSeg
27
+ except Exception as e:
28
+ raise RuntimeError(
29
+ f"Failed to import mediapipe. Check requirements.txt pins. Details: {e}"
30
+ )
31
 
 
 
 
32
  mp_face_mesh = mp.solutions.face_mesh
33
+ mp_selfie_seg = mp.solutions.selfie_segmentation
34
+
35
+ LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
 
 
36
 
37
+ # ---------------------- Helpers ----------------------
38
  def load_hairstyles():
39
+ """Return sorted list of .png files in HAIR_DIR."""
40
+ try:
41
+ files = [f for f in os.listdir(HAIR_DIR) if f.lower().endswith(".png")]
42
+ except FileNotFoundError:
43
+ files = []
44
  files.sort()
45
  return files
46
 
47
  HAIR_FILES = load_hairstyles()
48
 
49
+ def load_meta():
50
+ if os.path.exists(META_PATH):
51
+ try:
52
+ with open(META_PATH, "r") as f:
53
+ m = json.load(f)
54
+ return m if isinstance(m, dict) else {}
55
+ except Exception:
56
+ return {}
57
+ return {}
58
+
59
+ META = load_meta()
60
+
61
+ def detect_face_keypoints(img_bgr):
62
+ """Return 3 keypoints (left eye outer, right eye outer, mid-forehead) or None."""
63
+ h, w = img_bgr.shape[:2]
64
  with mp_face_mesh.FaceMesh(
65
+ static_image_mode=True, max_num_faces=1, refine_landmarks=True,
 
 
66
  min_detection_confidence=0.6
67
+ ) as fm:
68
+ res = fm.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
 
69
  if not res.multi_face_landmarks:
70
  return None
71
  lm = res.multi_face_landmarks[0].landmark
72
+ def xy(i): return np.array([lm[i].x*w, lm[i].y*h], dtype=np.float32)
73
+ return np.stack([xy(LM["left_eye_outer"]), xy(LM["right_eye_outer"]), xy(LM["mid_forehead"])])
74
+
75
+ def person_mask(img_bgr):
76
+ """Rough head isolation using selfie segmentation + feathering."""
77
+ with mp_selfie_seg.SelfieSegmentation(model_selection=1) as seg:
78
+ rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
79
+ m = seg.process(rgb).segmentation_mask
80
+ mask = (m > 0.5).astype(np.float32)
81
+ mask = cv2.GaussianBlur(mask, (35, 35), 0)
82
+ return mask
83
 
84
  def load_hair_png(name):
85
  path = os.path.join(HAIR_DIR, name)
86
  hair = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGRA
87
  if hair is None or hair.shape[2] != 4:
88
+ raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
89
  return hair
90
 
91
+ def hair_reference_points(hair_bgra, filename):
92
+ """Three anchors on hair image; override via meta.json if present."""
 
 
 
 
 
93
  h, w = hair_bgra.shape[:2]
94
+ if filename in META:
95
+ pts = np.array(META[filename], dtype=np.float32)
96
+ if pts.shape == (3, 2):
97
+ return pts
98
+ # Defaults (works for many styles; refine via meta.json for perfection)
99
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
100
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
101
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
102
  return np.stack([pL, pR, pM], axis=0)
103
 
104
+ def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
105
+ H, W = base_bgr.shape[:2]
106
+ hair_rgb = hair_bgra[:, :, :3]
107
+ hair_a = hair_bgra[:, :, 3] / 255.0
108
+ hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
109
+ a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_TRANSPARENT)
110
+ a = np.clip(a_warp * opacity, 0, 1)[..., None]
111
+ out = (a * hair_warp + (1 - a) * base_bgr).astype(np.uint8)
 
 
 
 
112
  return out
113
 
114
+ def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity):
 
 
 
 
 
 
 
 
115
  if image is None:
116
+ return None, "Upload a photo or enable webcam."
117
+ if not hairstyle:
118
+ return np.array(image), "Pick a hairstyle first."
119
 
 
120
  img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
121
+
122
+ kpts = detect_face_keypoints(img_bgr)
123
+ if kpts is None:
124
+ return image, "No face detected. Try a brighter, front-facing photo."
125
 
126
  hair = load_hair_png(hairstyle)
127
+ hair_pts = hair_reference_points(hair, hairstyle)
128
 
129
+ # User nudges on destination points
130
+ dst = kpts.copy()
131
+ dst[:, 0] += dx
132
+ dst[:, 1] += dy
 
133
 
134
+ # Scale + rotate hair anchors around their centroid
135
  center = hair_pts.mean(axis=0)
136
  theta = np.deg2rad(rot_deg)
137
  s = max(0.5, scale_pct / 100.0)
 
138
  R = np.array([[np.cos(theta), -np.sin(theta)],
139
  [np.sin(theta), np.cos(theta)]], dtype=np.float32)
140
  hair_pts_adj = (hair_pts - center) @ R.T * s + center
141
 
142
+ M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
 
 
143
  if M is None:
144
+ return image, "Could not compute alignment for this image/style."
145
+
146
+ out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
147
+
148
+ # Restrict to head region for cleaner look
149
+ head = person_mask(img_bgr)
150
+ head3 = head[..., None]
151
+ out = (head3 * out + (1 - head3) * img_bgr).astype(np.uint8)
152
 
153
+ out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
 
154
  return out_rgb, "OK"
155
 
156
+ def save_png(img):
157
+ if img is None:
158
+ return None
159
+ p = os.path.join(tempfile.gettempdir(), "tryon_result.png")
160
+ Image.fromarray(img).save(p)
161
+ return p
162
+
163
+ def hair_preview(hairstyle):
164
+ if not hairstyle:
165
+ return None
166
+ # Show the raw PNG on checkerboard background for visibility
167
+ hair = load_hair_png(hairstyle)
168
+ h, w = hair.shape[:2]
169
+ # Make checkerboard
170
+ tile = 16
171
+ bg = np.kron(
172
+ ((np.indices((h//tile+1, w//tile+1)).sum(axis=0) % 2) * 64 + 192).astype(np.uint8),
173
+ np.ones((tile, tile), np.uint8)
174
+ )[:h, :w]
175
+ bg_rgb = np.dstack([bg, bg, bg])
176
+ a = (hair[:, :, 3:4] / 255.0)
177
+ comp = (a * hair[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
178
+ comp = cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
179
+ return comp
180
+
181
+ # ---------------------- UI ----------------------
182
+ def ui():
183
+ with gr.Blocks(title="Virtual Try-On (FR1–FR8)", css="""
184
+ .gradio-container {max-width: 980px; margin: auto;}
185
+ @media (max-width: 768px){ .gradio-container {padding: 8px;} }
186
+ """) as demo:
187
+ gr.Markdown("## Salon Hairstyle Virtual Try-On\nUpload or use webcam, pick a style from **Select Hairstyle**, adjust, then download.")
188
+
189
+ if not HAIR_FILES:
190
+ gr.Markdown("⚠️ **No hairstyle PNGs found.** Upload files into **`hair/`** (or `assets/hairstyles/`) and reload this Space.")
191
+
192
  with gr.Tabs():
193
+ # ---------------- Photo Tab (FR1,3–7) ----------------
194
  with gr.Tab("Photo"):
195
  with gr.Row():
196
+ i