RojaKatta commited on
Commit
c78d7f0
·
verified ·
1 Parent(s): 4a9e43b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -268
app.py CHANGED
@@ -1,271 +1,59 @@
1
- import os, glob, math
2
- import numpy as np
3
- from PIL import Image
4
  import gradio as gr
5
-
6
- # Optional heavy deps (installed via requirements.txt)
7
- try:
8
- import cv2
9
- import mediapipe as mp
10
- MP_AVAILABLE = True
11
- except Exception:
12
- MP_AVAILABLE = False
13
-
14
- APP_TITLE = "Salon Hairstyle Virtual Try-On — Pro Demo"
15
- HAIR_DIR = "assets/hairstyles"
16
-
17
- # --------------------------
18
- # Asset loading / utilities
19
- # --------------------------
20
- def list_hairstyle_files():
21
- files = sorted(glob.glob(os.path.join(HAIR_DIR, "*.png")))
22
- labels = [os.path.basename(f).replace(".png","").replace("_"," ").title() for f in files]
23
- thumbs = []
24
- for f in files:
25
- thumbs.append(f) # Gradio Gallery can take path strings
26
- return labels, files, thumbs
27
-
28
- def ensure_rgba(im: Image.Image) -> Image.Image:
29
- return im if im.mode == "RGBA" else im.convert("RGBA")
30
-
31
- def overlay_rgba(base: Image.Image, overlay: Image.Image, x: int, y: int) -> Image.Image:
32
- base = ensure_rgba(base)
33
- overlay = ensure_rgba(overlay)
34
- bw, bh = base.size
35
- ow, oh = overlay.size
36
- if ow <= 0 or oh <= 0:
37
- return base
38
- x1, y1 = max(0, x), max(0, y)
39
- x2, y2 = min(bw, x + ow), min(bh, y + oh)
40
- if x1 >= x2 or y1 >= y2:
41
- return base
42
- crop = overlay.crop((x1 - x, y1 - y, x2 - x, y2 - y))
43
- region = base.crop((x1, y1, x2, y2))
44
- region = Image.alpha_composite(region, crop)
45
- base.paste(region, (x1, y1))
46
- return base
47
-
48
- # --------------------------
49
- # Face detection / anchoring
50
- # --------------------------
51
- def detect_face_keypoints(np_img):
52
- """
53
- Returns dict with keys:
54
- bbox: (x1,y1,x2,y2)
55
- temples: (lx,ly, rx,ry) approximate
56
- forehead_y: int (approx y of forehead landmark)
57
- Uses MediaPipe FaceMesh when available; otherwise returns center-based heuristic.
58
- """
59
- h, w = np_img.shape[:2]
60
-
61
- def fallback():
62
- cx, cy = w // 2, int(h * 0.42)
63
- bw, bh = int(w * 0.4), int(h * 0.45)
64
- x1, y1 = max(0, cx - bw // 2), max(0, cy - bh // 2)
65
- x2, y2 = min(w, x1 + bw), min(h, y1 + bh)
66
- return {
67
- "bbox": (x1, y1, x2, y2),
68
- "temples": (x1, (y1 + y2) // 2, x2, (y1 + y2) // 2),
69
- "forehead_y": max(0, int(y1 - 0.1 * (y2 - y1))),
70
- }
71
-
72
- if not MP_AVAILABLE:
73
- return fallback()
74
-
75
- mpfm = mp.solutions.face_mesh
76
- with mpfm.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=False) as fm:
77
- rgb = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) if np_img.shape[2] == 3 else np_img[..., :3]
78
- res = fm.process(rgb)
79
- if not res.multi_face_landmarks:
80
- return fallback()
81
-
82
- lm = res.multi_face_landmarks[0].landmark
83
- xs = np.array([p.x for p in lm]) * w
84
- ys = np.array([p.y for p in lm]) * h
85
- x1, y1, x2, y2 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
86
-
87
- def safe_idx(i):
88
- i = int(i)
89
- i = max(0, min(len(lm) - 1, i))
90
- return int(xs[i]), int(ys[i])
91
-
92
- lx, ly = safe_idx(127) # approx left temple
93
- rx, ry = safe_idx(356) # approx right temple
94
- _, fy = safe_idx(10) # approx forehead
95
- y1 = max(0, int(y1 - 0.12 * (y2 - y1))) # expand up a bit
96
-
97
- return {
98
- "bbox": (x1, y1, x2, y2),
99
- "temples": (lx, ly, rx, ry),
100
- "forehead_y": fy,
101
- }
102
-
103
- def place_and_render(base_np, hair_path, scale, x_shift_pct, y_shift_pct, rotation_deg):
104
- if base_np is None:
105
- raise gr.Error("Please upload or capture a photo first.")
106
- base = Image.fromarray(base_np.astype("uint8")).convert("RGBA")
107
- key = detect_face_keypoints(np.array(base.convert("RGB")))
108
- x1, y1, x2, y2 = key["bbox"]
109
- lx, ly, rx, ry = key["temples"]
110
- forehead_y = key["forehead_y"]
111
-
112
- hair = Image.open(hair_path).convert("RGBA")
113
-
114
- # Derive scale from temple distance
115
- temple_dx = rx - lx
116
- temple_dy = ry - ly
117
- temple_dist = max(1, (temple_dx ** 2 + temple_dy ** 2) ** 0.5)
118
- target_w = max(1, int(temple_dist * 2.0 * scale)) # widen beyond temples
119
- ratio = target_w / hair.width
120
- target_h = max(1, int(hair.height * ratio))
121
- hair_resized = hair.resize((target_w, target_h), Image.LANCZOS)
122
-
123
- # Auto-rotate with temple slope + extra manual rotation
124
- auto_deg = math.degrees(math.atan2(temple_dy, temple_dx))
125
- rot_total = auto_deg + rotation_deg
126
- hair_resized = hair_resized.rotate(rot_total, expand=True)
127
-
128
- # Anchor: horizontally center at temple midpoint; vertically above forehead
129
- midx = int((lx + rx) / 2)
130
- anchor_x = int(midx - hair_resized.width / 2)
131
- anchor_y = int(forehead_y - hair_resized.height * 0.45)
132
-
133
- # Manual shifts (percent of image size)
134
- img_w, img_h = base.size
135
- anchor_x += int(x_shift_pct * img_w / 100.0)
136
- anchor_y += int(y_shift_pct * img_h / 100.0)
137
-
138
- out = overlay_rgba(base, hair_resized, anchor_x, anchor_y).convert("RGB")
139
- return out
140
-
141
- # --------------------------
142
- # Gradio callbacks
143
- # --------------------------
144
- def refresh_assets():
145
- labels, files, thumbs = list_hairstyle_files()
146
- if not files:
147
- return (
148
- gr.update(value=None, choices=[]), # dropdown
149
- gr.update(value=None), # gallery
150
- "No PNGs in assets/hairstyles. Add some and press Refresh.",
151
- )
152
- dd = gr.update(choices=labels, value=labels[0])
153
- gallery = gr.update(value=[[p, l] for p, l in zip(thumbs, labels)])
154
- return dd, gallery, f"Found {len(files)} hairstyles."
155
-
156
- def refresh_dropdown_only():
157
- labels, files, _ = list_hairstyle_files()
158
- if not files:
159
- return gr.update(value=None, choices=[])
160
- return gr.update(choices=labels, value=labels[0])
161
-
162
- def pick_from_gallery(gallery_select, current_dropdown):
163
- # gallery_select: (index, value)
164
- labels, _, _ = list_hairstyle_files()
165
- if not labels:
166
- raise gr.Error("No hairstyles available.")
167
- if gallery_select is None:
168
- return current_dropdown
169
- idx = gallery_select[0]
170
- idx = max(0, min(len(labels) - 1, idx))
171
- return labels[idx]
172
-
173
- def apply_from_dropdown(image, hairstyle_label, scale, xs, ys, rot):
174
- labels, files, _ = list_hairstyle_files()
175
- if not labels:
176
- raise gr.Error("No hairstyle assets found.")
177
- if hairstyle_label not in labels:
178
- raise gr.Error("Choose a hairstyle first.")
179
- path = files[labels.index(hairstyle_label)]
180
- return place_and_render(image, path, scale, xs, ys, rot)
181
-
182
- def apply_from_gallery(image, gallery_select, scale, xs, ys, rot):
183
- labels, files, _ = list_hairstyle_files()
184
- if not labels:
185
- raise gr.Error("No hairstyle assets found.")
186
- if gallery_select is None:
187
- raise gr.Error("Select a hairstyle from the gallery.")
188
- idx = gallery_select[0]
189
- idx = max(0, min(len(files) - 1, idx))
190
- return place_and_render(image, files[idx], scale, xs, ys, rot)
191
-
192
- def save_image(img_np):
193
- if img_np is None:
194
- raise gr.Error("Nothing to save. Generate a preview first.")
195
- out_path = "output_tryon.png"
196
- Image.fromarray(img_np).save(out_path)
197
- return out_path
198
-
199
- # --------------------------
200
- # UI
201
- # --------------------------
202
- with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
203
- gr.Markdown(f"# {APP_TITLE}")
204
- gr.Markdown(
205
- "Upload a photo or use webcam. Put **transparent PNG** hairstyles in `assets/hairstyles/`, then **Refresh**."
206
  )
207
 
208
- with gr.Tabs():
209
- # ------------------ Upload tab ------------------
210
- with gr.Tab("📷 Photo (Upload)"):
211
- with gr.Row():
212
- with gr.Column():
213
- img = gr.Image(label="Photo", sources=["upload"], type="numpy", height=420)
214
- hair_dd = gr.Dropdown(label="Hairstyle (Dropdown)", choices=[], interactive=True)
215
- refresh = gr.Button("🔄 Refresh")
216
- status = gr.Markdown("Add PNGs to `assets/hairstyles/` and press Refresh.")
217
- gallery = gr.Gallery(
218
- label="Hairstyles (click to choose)",
219
- columns=4,
220
- height=220,
221
- allow_preview=False,
222
- interactive=True,
223
- )
224
-
225
- with gr.Accordion("Fine-tune placement", open=False):
226
- scale = gr.Slider(0.6, 2.2, value=1.3, step=0.01, label="Scale (× temple distance)")
227
- xs = gr.Slider(-30, 30, value=0, step=0.5, label="Horizontal shift (% image width)")
228
- ys = gr.Slider(-30, 30, value=-2, step=0.5, label="Vertical shift (% image height)")
229
- rot = gr.Slider(-30, 30, value=0, step=0.5, label="Extra rotation (°)")
230
-
231
- run_dd = gr.Button("✨ Apply (Dropdown)")
232
- run_ga = gr.Button("✨ Apply (Gallery selection)")
233
-
234
- with gr.Column():
235
- out = gr.Image(label="Preview", height=480)
236
- save_btn = gr.Button("💾 Save result")
237
- file_out = gr.File(label="Download")
238
-
239
- # Wiring
240
- refresh.click(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])
241
- demo.load(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])
242
-
243
- gallery.select(fn=pick_from_gallery, inputs=[gallery, hair_dd], outputs=hair_dd)
244
- run_dd.click(fn=apply_from_dropdown, inputs=[img, hair_dd, scale, xs, ys, rot], outputs=out)
245
- run_ga.click(fn=apply_from_gallery, inputs=[img, gallery, scale, xs, ys, rot], outputs=out)
246
- save_btn.click(fn=save_image, inputs=out, outputs=file_out)
247
-
248
- # ------------------ Webcam tab ------------------
249
- with gr.Tab("🎥 Webcam (Live Beta)"):
250
- gr.Markdown("Live mode processes frames continuously. For CPU Spaces, keep webcam resolution modest.")
251
- cam = gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Webcam")
252
- hair_dd2 = gr.Dropdown(label="Hairstyle", choices=[], interactive=True)
253
- scale2 = gr.Slider(0.6, 2.2, value=1.25, step=0.01, label="Scale")
254
- xs2 = gr.Slider(-30, 30, value=0, step=0.5, label="X shift (%)")
255
- ys2 = gr.Slider(-30, 30, value=-2, step=0.5, label="Y shift (%)")
256
- rot2 = gr.Slider(-30, 30, value=0, step=0.5, label="Rotation (°)")
257
- out_live = gr.Image(label="Live Preview", interactive=False, height=420)
258
-
259
- def live_process(frame, label, s, x, y, r):
260
- labels, files, _ = list_hairstyle_files()
261
- if frame is None or not labels or label not in labels:
262
- return frame
263
- path = files[labels.index(label)]
264
- return place_and_render(frame, path, s, x, y, r)
265
-
266
- cam.stream(fn=live_process, inputs=[cam, hair_dd2, scale2, xs2, ys2, rot2], outputs=out_live)
267
- # Populate only the dropdown on load in this tab
268
- demo.load(fn=refresh_dropdown_only, inputs=None, outputs=hair_dd2)
269
-
270
- if __name__ == "__main__":
271
- demo.launch()
 
 
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+
5
+ def load_hairstyles():
6
+ folder = "hairstyles"
7
+ if not os.path.exists(folder):
8
+ return []
9
+ return [
10
+ Image.open(os.path.join(folder, f)).convert("RGBA")
11
+ for f in sorted(os.listdir(folder)) if f.endswith(".png")
12
+ ]
13
+
14
+ hairstyles = load_hairstyles()
15
+
16
+ def apply_hairstyle(user_img, style_index, x_offset, y_offset, scale):
17
+ if user_img is None or not hairstyles:
18
+ return None
19
+
20
+ user_img = user_img.convert("RGBA")
21
+ base_w, base_h = user_img.size
22
+
23
+ hairstyle = hairstyles[style_index]
24
+
25
+ # Resize the hairstyle based on scale
26
+ new_size = (int(base_w * scale), int(hairstyle.height * (base_w * scale / hairstyle.width)))
27
+ hairstyle = hairstyle.resize(new_size)
28
+
29
+ # Create a blank transparent image to position the hairstyle
30
+ composite = Image.new("RGBA", user_img.size)
31
+ paste_x = int((base_w - new_size[0]) / 2 + x_offset)
32
+ paste_y = int(y_offset)
33
+ composite.paste(hairstyle, (paste_x, paste_y), hairstyle)
34
+
35
+ # Overlay it
36
+ result = Image.alpha_composite(user_img, composite)
37
+ return result.convert("RGB")
38
+
39
+ with gr.Blocks() as demo:
40
+ gr.Markdown("## 💇 Salon Virtual Hairstyle Try-On (Adjustable)")
41
+
42
+ with gr.Row():
43
+ with gr.Column():
44
+ image_input = gr.Image(type="pil", label="📷 Upload an Image")
45
+ style_slider = gr.Slider(0, max(len(hairstyles)-1, 0), step=1, label="🎨 Select Hairstyle")
46
+ x_offset = gr.Slider(-200, 200, value=0, step=1, label="⬅️➡️ Move Left / Right")
47
+ y_offset = gr.Slider(-200, 200, value=0, step=1, label="⬆️⬇️ Move Up / Down")
48
+ scale = gr.Slider(0.3, 2.0, value=1.0, step=0.05, label="📏 Scale Hairstyle")
49
+ apply_btn = gr.Button("✨ Apply Hairstyle")
50
+ with gr.Column():
51
+ result_output = gr.Image(label="🔍 Result Preview")
52
+
53
+ apply_btn.click(
54
+ fn=apply_hairstyle,
55
+ inputs=[image_input, style_slider, x_offset, y_offset, scale],
56
+ outputs=result_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  )
58
 
59
+ demo.launch()