RojaKatta committed on
Commit
5b6030c
Β·
verified Β·
1 Parent(s): d68f8af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -56
app.py CHANGED
@@ -1,59 +1,271 @@
1
- import gradio as gr
 
2
  from PIL import Image
3
- import os
4
-
5
def load_hairstyles():
    """Load every PNG in ./hairstyles as an RGBA image, sorted by filename."""
    folder = "hairstyles"
    if not os.path.exists(folder):
        return []
    images = []
    for name in sorted(os.listdir(folder)):
        if name.endswith(".png"):
            images.append(Image.open(os.path.join(folder, name)).convert("RGBA"))
    return images

# Loaded once at import time; the UI sliders index into this list.
hairstyles = load_hairstyles()
15
-
16
def apply_hairstyle(user_img, style_index, x_offset, y_offset, scale):
    """Paste the selected hairstyle PNG over *user_img*.

    Args:
        user_img: PIL image (any mode) or None.
        style_index: index into the module-level ``hairstyles`` list; Gradio
            sliders may deliver it as a float, so it is coerced and clamped.
        x_offset, y_offset: pixel offsets for manual placement.
        scale: hairstyle width as a fraction of the photo width.

    Returns:
        RGB PIL image, or None when there is no photo or no hairstyles.
    """
    if user_img is None or not hairstyles:
        return None

    user_img = user_img.convert("RGBA")
    base_w, base_h = user_img.size

    # BUGFIX: list indexing requires an int, but slider values may arrive as
    # floats; also clamp so an out-of-range value cannot raise IndexError.
    hairstyle = hairstyles[max(0, min(len(hairstyles) - 1, int(style_index)))]

    # Resize the hairstyle based on scale (width-driven, aspect ratio kept)
    new_size = (int(base_w * scale), int(hairstyle.height * (base_w * scale / hairstyle.width)))
    hairstyle = hairstyle.resize(new_size)

    # Create a blank transparent canvas and position the hairstyle on it
    composite = Image.new("RGBA", user_img.size)
    paste_x = int((base_w - new_size[0]) / 2 + x_offset)
    paste_y = int(y_offset)
    composite.paste(hairstyle, (paste_x, paste_y), hairstyle)

    # Alpha-blend the canvas over the photo
    result = Image.alpha_composite(user_img, composite)
    return result.convert("RGB")
38
-
39
# UI: one photo input plus sliders to pick and manually place a hairstyle.
with gr.Blocks() as demo:
    gr.Markdown("## 💇 Salon Virtual Hairstyle Try-On (Adjustable)")

    with gr.Row():
        with gr.Column():
            # Inputs: photo plus placement controls for the overlay.
            image_input = gr.Image(type="pil", label="📷 Upload an Image")
            # Slider max falls back to 0 when no hairstyle PNGs were found.
            style_slider = gr.Slider(0, max(len(hairstyles)-1, 0), step=1, label="🎨 Select Hairstyle")
            x_offset = gr.Slider(-200, 200, value=0, step=1, label="⬅️➡️ Move Left / Right")
            y_offset = gr.Slider(-200, 200, value=0, step=1, label="⬆️⬇️ Move Up / Down")
            scale = gr.Slider(0.3, 2.0, value=1.0, step=0.05, label="📏 Scale Hairstyle")
            apply_btn = gr.Button("✨ Apply Hairstyle")
        with gr.Column():
            result_output = gr.Image(label="🔍 Result Preview")

    # Wire the button to the compositing function.
    apply_btn.click(
        fn=apply_hairstyle,
        inputs=[image_input, style_slider, x_offset, y_offset, scale],
        outputs=result_output
    )

demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob, math
2
+ import numpy as np
3
  from PIL import Image
4
+ import gradio as gr
5
+
6
# Optional heavy deps (installed via requirements.txt)
# cv2/mediapipe enable landmark-based placement; when missing, the app falls
# back to a center-of-image heuristic instead (MP_AVAILABLE gates this).
try:
    import cv2
    import mediapipe as mp
    MP_AVAILABLE = True
except Exception:
    MP_AVAILABLE = False

# App-wide constants: page title and the folder scanned for hairstyle PNGs.
APP_TITLE = "Salon Hairstyle Virtual Try-On — Pro Demo"
HAIR_DIR = "assets/hairstyles"
16
+
17
# --------------------------
# Asset loading / utilities
# --------------------------
def list_hairstyle_files(folder=None):
    """Scan *folder* (default: HAIR_DIR) for hairstyle PNGs.

    Args:
        folder: directory to scan; None means the module-level HAIR_DIR.

    Returns:
        labels: display names derived from filenames ("bob_cut.png" -> "Bob Cut")
        files:  sorted PNG paths
        thumbs: same paths (gr.Gallery accepts plain path strings)
    """
    if folder is None:
        folder = HAIR_DIR
    files = sorted(glob.glob(os.path.join(folder, "*.png")))
    # splitext instead of str.replace(".png",""): only strip the extension,
    # never an interior ".png" substring.
    labels = [os.path.splitext(os.path.basename(f))[0].replace("_", " ").title() for f in files]
    thumbs = list(files)
    return labels, files, thumbs
27
+
28
def ensure_rgba(im: Image.Image) -> Image.Image:
    """Return *im* unchanged when already RGBA, otherwise an RGBA conversion."""
    if im.mode == "RGBA":
        return im
    return im.convert("RGBA")
30
+
31
def overlay_rgba(base: Image.Image, overlay: Image.Image, x: int, y: int) -> Image.Image:
    """Alpha-composite *overlay* onto *base* at (x, y), clipped to base bounds.

    NOTE: when *base* is already RGBA the same object is modified in place
    and returned; otherwise an RGBA copy is produced by the conversion.
    """
    base = ensure_rgba(base)
    overlay = ensure_rgba(overlay)
    base_w, base_h = base.size
    ov_w, ov_h = overlay.size
    if ov_w <= 0 or ov_h <= 0:
        return base

    # Intersection of the overlay rectangle with the base canvas.
    left, top = max(0, x), max(0, y)
    right, bottom = min(base_w, x + ov_w), min(base_h, y + ov_h)
    if left >= right or top >= bottom:
        # Overlay lies fully off-canvas: nothing to draw.
        return base

    visible = overlay.crop((left - x, top - y, right - x, bottom - y))
    patch = Image.alpha_composite(base.crop((left, top, right, bottom)), visible)
    base.paste(patch, (left, top))
    return base
47
+
48
# --------------------------
# Face detection / anchoring
# --------------------------
def detect_face_keypoints(np_img):
    """
    Locate face anchor points for hairstyle placement.

    Args:
        np_img: HxWxC uint8 image in RGB channel order (callers produce it via
            ``np.array(pil_image.convert("RGB"))``); grayscale is tolerated.

    Returns dict with keys:
        bbox: (x1,y1,x2,y2)
        temples: (lx,ly, rx,ry) approximate
        forehead_y: int (approx y of forehead landmark)
    Uses MediaPipe FaceMesh when available; otherwise returns center-based heuristic.
    """
    h, w = np_img.shape[:2]

    def fallback():
        # Heuristic: assume a roughly centered face ~40% of the frame wide.
        cx, cy = w // 2, int(h * 0.42)
        bw, bh = int(w * 0.4), int(h * 0.45)
        x1, y1 = max(0, cx - bw // 2), max(0, cy - bh // 2)
        x2, y2 = min(w, x1 + bw), min(h, y1 + bh)
        return {
            "bbox": (x1, y1, x2, y2),
            "temples": (x1, (y1 + y2) // 2, x2, (y1 + y2) // 2),
            "forehead_y": max(0, int(y1 - 0.1 * (y2 - y1))),
        }

    if not MP_AVAILABLE:
        return fallback()

    mpfm = mp.solutions.face_mesh
    with mpfm.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=False) as fm:
        # BUGFIX: the input is already RGB (from PIL), and MediaPipe expects
        # RGB, so the previous cv2.COLOR_BGR2RGB call wrongly swapped channels
        # and degraded detection. Just normalize to a 3-channel RGB array.
        if np_img.ndim == 2:
            rgb = np.stack([np_img] * 3, axis=-1)  # grayscale -> 3 channels
        else:
            rgb = np.ascontiguousarray(np_img[..., :3])  # drop alpha if present
        res = fm.process(rgb)
    if not res.multi_face_landmarks:
        return fallback()

    lm = res.multi_face_landmarks[0].landmark
    xs = np.array([p.x for p in lm]) * w
    ys = np.array([p.y for p in lm]) * h
    x1, y1, x2, y2 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

    def safe_idx(i):
        # Clamp the landmark index defensively before lookup.
        i = max(0, min(len(lm) - 1, int(i)))
        return int(xs[i]), int(ys[i])

    lx, ly = safe_idx(127)  # approx left temple
    rx, ry = safe_idx(356)  # approx right temple
    _, fy = safe_idx(10)    # approx forehead
    y1 = max(0, int(y1 - 0.12 * (y2 - y1)))  # expand bbox up toward hairline

    return {
        "bbox": (x1, y1, x2, y2),
        "temples": (lx, ly, rx, ry),
        "forehead_y": fy,
    }
102
+
103
def place_and_render(base_np, hair_path, scale, x_shift_pct, y_shift_pct, rotation_deg):
    """Composite the hairstyle PNG at *hair_path* onto the photo *base_np*.

    Placement is anchored on detected temple/forehead points; *scale*,
    percentage shifts and *rotation_deg* are user fine-tuning on top of the
    automatic fit. Returns an RGB PIL image.

    Raises:
        gr.Error: when no photo was provided.
    """
    if base_np is None:
        raise gr.Error("Please upload or capture a photo first.")

    photo = Image.fromarray(base_np.astype("uint8")).convert("RGBA")
    anchors = detect_face_keypoints(np.array(photo.convert("RGB")))
    lx, ly, rx, ry = anchors["temples"]
    forehead_y = anchors["forehead_y"]

    hair = Image.open(hair_path).convert("RGBA")

    # Width follows the temple distance (widened 2x, then user scale).
    dx, dy = rx - lx, ry - ly
    temple_dist = max(1, (dx ** 2 + dy ** 2) ** 0.5)
    target_w = max(1, int(temple_dist * 2.0 * scale))
    ratio = target_w / hair.width
    target_h = max(1, int(hair.height * ratio))
    hair = hair.resize((target_w, target_h), Image.LANCZOS)

    # Follow the head tilt (temple slope) plus any extra manual rotation.
    hair = hair.rotate(math.degrees(math.atan2(dy, dx)) + rotation_deg, expand=True)

    # Center horizontally on the temple midpoint; sit above the forehead,
    # then apply the manual shifts expressed as a percentage of image size.
    img_w, img_h = photo.size
    midx = int((lx + rx) / 2)
    anchor_x = int(midx - hair.width / 2) + int(x_shift_pct * img_w / 100.0)
    anchor_y = int(forehead_y - hair.height * 0.45) + int(y_shift_pct * img_h / 100.0)

    return overlay_rgba(photo, hair, anchor_x, anchor_y).convert("RGB")
140
+
141
# --------------------------
# Gradio callbacks
# --------------------------
def refresh_assets():
    """Rescan the hairstyle folder; update dropdown, gallery and status text."""
    labels, files, thumbs = list_hairstyle_files()
    if not files:
        return (
            gr.update(value=None, choices=[]),  # dropdown
            gr.update(value=None),              # gallery
            "No PNGs in assets/hairstyles. Add some and press Refresh.",
        )
    pairs = [[path, label] for path, label in zip(thumbs, labels)]
    return (
        gr.update(choices=labels, value=labels[0]),
        gr.update(value=pairs),
        f"Found {len(files)} hairstyles.",
    )
155
+
156
def refresh_dropdown_only():
    """Rescan assets and refresh only the dropdown choices (webcam tab)."""
    labels, files, _ = list_hairstyle_files()
    if files:
        return gr.update(choices=labels, value=labels[0])
    return gr.update(value=None, choices=[])
161
+
162
def pick_from_gallery(gallery_select, current_dropdown):
    """Map a gallery selection to a hairstyle label for the dropdown.

    NOTE(review): this handler is wired with ``inputs=[gallery, hair_dd]``, so
    ``gallery_select`` receives the gallery *value* (a list of [path, label]
    rows), not an (index, value) selection tuple — selection events need a
    ``gr.SelectData`` event argument to carry the clicked index. Until the
    wiring is changed, resolve an index only when one is actually present and
    otherwise keep the current choice instead of crashing with a TypeError.
    """
    labels, _, _ = list_hairstyle_files()
    if not labels:
        raise gr.Error("No hairstyles available.")
    if gallery_select is None:
        return current_dropdown
    idx = gallery_select[0] if isinstance(gallery_select, (list, tuple)) else gallery_select
    if not isinstance(idx, int):
        # Payload does not carry a usable integer index (e.g. full gallery
        # value): keep whatever the dropdown currently shows.
        return current_dropdown
    return labels[max(0, min(len(labels) - 1, idx))]
172
+
173
def apply_from_dropdown(image, hairstyle_label, scale, xs, ys, rot):
    """Render the try-on using the hairstyle currently chosen in the dropdown."""
    labels, files, _ = list_hairstyle_files()
    if not labels:
        raise gr.Error("No hairstyle assets found.")
    if hairstyle_label not in labels:
        raise gr.Error("Choose a hairstyle first.")
    selected = files[labels.index(hairstyle_label)]
    return place_and_render(image, selected, scale, xs, ys, rot)
181
+
182
def apply_from_gallery(image, gallery_select, scale, xs, ys, rot):
    """Render the try-on using the hairstyle selected in the gallery.

    NOTE(review): wired with ``inputs=[..., gallery, ...]``, this receives the
    gallery *value* (list of [path, label] rows), not an (index, value)
    selection tuple; a ``gr.SelectData`` event argument would be needed to
    know the clicked index. Guard against non-integer payloads so the user
    sees a clear error instead of a TypeError from the min/max comparison.
    """
    labels, files, _ = list_hairstyle_files()
    if not labels:
        raise gr.Error("No hairstyle assets found.")
    if gallery_select is None:
        raise gr.Error("Select a hairstyle from the gallery.")
    idx = gallery_select[0] if isinstance(gallery_select, (list, tuple)) else gallery_select
    if not isinstance(idx, int):
        raise gr.Error("Select a hairstyle from the gallery.")
    idx = max(0, min(len(files) - 1, idx))
    return place_and_render(image, files[idx], scale, xs, ys, rot)
191
+
192
def save_image(img_np):
    """Persist the preview to a PNG file and return its path for download.

    Writes to a unique temporary file instead of a fixed "output_tryon.png":
    on a shared Space, concurrent users would otherwise overwrite each
    other's downloads.

    Raises:
        gr.Error: when there is no preview image to save.
    """
    if img_np is None:
        raise gr.Error("Nothing to save. Generate a preview first.")
    import tempfile  # local import keeps the module's top-level deps unchanged
    fd, out_path = tempfile.mkstemp(prefix="output_tryon_", suffix=".png")
    os.close(fd)  # mkstemp opens the fd; PIL reopens by path
    Image.fromarray(img_np).save(out_path)
    return out_path
198
+
199
# --------------------------
# UI
# --------------------------
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown(
        "Upload a photo or use webcam. Put **transparent PNG** hairstyles in `assets/hairstyles/`, then **Refresh**."
    )

    with gr.Tabs():
        # ------------------ Upload tab ------------------
        with gr.Tab("📷 Photo (Upload)"):
            with gr.Row():
                with gr.Column():
                    # Left column: photo input, hairstyle pickers, fine-tune sliders.
                    img = gr.Image(label="Photo", sources=["upload"], type="numpy", height=420)
                    hair_dd = gr.Dropdown(label="Hairstyle (Dropdown)", choices=[], interactive=True)
                    refresh = gr.Button("🔄 Refresh")
                    status = gr.Markdown("Add PNGs to `assets/hairstyles/` and press Refresh.")
                    gallery = gr.Gallery(
                        label="Hairstyles (click to choose)",
                        columns=4,
                        height=220,
                        allow_preview=False,
                        interactive=True,
                    )

                    with gr.Accordion("Fine-tune placement", open=False):
                        scale = gr.Slider(0.6, 2.2, value=1.3, step=0.01, label="Scale (× temple distance)")
                        xs = gr.Slider(-30, 30, value=0, step=0.5, label="Horizontal shift (% image width)")
                        ys = gr.Slider(-30, 30, value=-2, step=0.5, label="Vertical shift (% image height)")
                        rot = gr.Slider(-30, 30, value=0, step=0.5, label="Extra rotation (°)")

                    run_dd = gr.Button("✨ Apply (Dropdown)")
                    run_ga = gr.Button("✨ Apply (Gallery selection)")

                with gr.Column():
                    # Right column: rendered preview plus save/download.
                    out = gr.Image(label="Preview", height=480)
                    save_btn = gr.Button("💾 Save result")
                    file_out = gr.File(label="Download")

            # Wiring
            refresh.click(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])
            demo.load(fn=refresh_assets, inputs=None, outputs=[hair_dd, gallery, status])

            gallery.select(fn=pick_from_gallery, inputs=[gallery, hair_dd], outputs=hair_dd)
            run_dd.click(fn=apply_from_dropdown, inputs=[img, hair_dd, scale, xs, ys, rot], outputs=out)
            run_ga.click(fn=apply_from_gallery, inputs=[img, gallery, scale, xs, ys, rot], outputs=out)
            save_btn.click(fn=save_image, inputs=out, outputs=file_out)

        # ------------------ Webcam tab ------------------
        with gr.Tab("🎥 Webcam (Live Beta)"):
            gr.Markdown("Live mode processes frames continuously. For CPU Spaces, keep webcam resolution modest.")
            cam = gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Webcam")
            hair_dd2 = gr.Dropdown(label="Hairstyle", choices=[], interactive=True)
            scale2 = gr.Slider(0.6, 2.2, value=1.25, step=0.01, label="Scale")
            xs2 = gr.Slider(-30, 30, value=0, step=0.5, label="X shift (%)")
            ys2 = gr.Slider(-30, 30, value=-2, step=0.5, label="Y shift (%)")
            rot2 = gr.Slider(-30, 30, value=0, step=0.5, label="Rotation (°)")
            out_live = gr.Image(label="Live Preview", interactive=False, height=420)

            def live_process(frame, label, s, x, y, r):
                # Per-frame callback: overlay the chosen hairstyle, or echo
                # the raw frame when there is no frame / no valid choice yet.
                labels, files, _ = list_hairstyle_files()
                if frame is None or not labels or label not in labels:
                    return frame
                path = files[labels.index(label)]
                return place_and_render(frame, path, s, x, y, r)

            cam.stream(fn=live_process, inputs=[cam, hair_dd2, scale2, xs2, ys2, rot2], outputs=out_live)
            # Populate only the dropdown on load in this tab
            demo.load(fn=refresh_dropdown_only, inputs=None, outputs=hair_dd2)

if __name__ == "__main__":
    demo.launch()