RojaKatta commited on
Commit
69d3136
Β·
verified Β·
1 Parent(s): b68a73e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -1
app.py CHANGED
@@ -108,4 +108,199 @@ def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity, meta):
108
 
109
  img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
110
 
111
- kpt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
110
 
111
+ kpts = detect_face_keypoints(img_bgr)
112
+ if kpts is None:
113
+ return image, "No face detected. Try a brighter, front-facing photo."
114
+
115
+ hair = load_hair_png(hairstyle)
116
+ hair_pts = hair_reference_points(hair, hairstyle, meta)
117
+
118
+ # Destination points (with user nudges)
119
+ dst = kpts.copy()
120
+ dst[:, 0] += dx
121
+ dst[:, 1] += dy
122
+
123
+ # Scale + rotate around hair anchor centroid
124
+ center = hair_pts.mean(axis=0)
125
+ theta = np.deg2rad(rot_deg)
126
+ s = max(0.5, scale_pct / 100.0)
127
+ R = np.array([[np.cos(theta), -np.sin(theta)],
128
+ [np.sin(theta), np.cos(theta)]], dtype=np.float32)
129
+ hair_pts_adj = (hair_pts - center) @ R.T * s + center
130
+
131
+ M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
132
+ if M is None:
133
+ return image, "Could not compute alignment for this image/style."
134
+
135
+ out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
136
+
137
+ # Restrict to head region for cleaner look
138
+ head = person_mask(img_bgr)
139
+ head3 = head[..., None]
140
+ out = (head3 * out + (1 - head3) * img_bgr).astype(np.uint8)
141
+
142
+ out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
143
+ return out_rgb, "OK"
144
+
145
+ def save_png_to_tmp(img, filename="output_tryon.png"):
146
+ """DownloadButton needs a file path string. We save to /tmp and return that path."""
147
+ if img is None:
148
+ raise gr.Error("No image to download. Click Apply first.")
149
+ out_path = os.path.join(tempfile.gettempdir(), filename)
150
+ if isinstance(img, np.ndarray):
151
+ Image.fromarray(img).save(out_path)
152
+ else:
153
+ img.save(out_path)
154
+ return out_path
155
+
156
+ # ----- thumbnails on checkerboard for the gallery -----
157
+ def thumb_on_checker(hair_bgra, max_h=220):
158
+ h, w = hair_bgra.shape[:2]
159
+ scale = min(1.0, max_h / h)
160
+ hair_bgra = cv2.resize(hair_bgra, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_LINEAR)
161
+ h, w = hair_bgra.shape[:2]
162
+ tile = 12
163
+ bg = np.kron(((np.indices((h//tile+1, w//tile+1)).sum(axis=0) % 2) * 64 + 192).astype(np.uint8),
164
+ np.ones((tile, tile), np.uint8))[:h, :w]
165
+ bg_rgb = np.dstack([bg, bg, bg])
166
+ a = (hair_bgra[:, :, 3:4] / 255.0)
167
+ comp = (a * hair_bgra[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
168
+ return cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
169
+
170
+ def build_gallery_items(files):
171
+ items = []
172
+ for fname in files:
173
+ try:
174
+ img = load_hair_png(fname)
175
+ items.append((thumb_on_checker(img), fname)) # (image, caption)
176
+ except Exception:
177
+ continue
178
+ return items
179
+
180
+ # ===================== UI =====================
181
+ def build_ui():
182
+ META = load_meta()
183
+ HAIR_FILES = load_hairstyles()
184
+
185
+ with gr.Blocks(title="Salon Hairstyle Virtual Try-On β€” Pro Demo", css="""
186
+ .gradio-container {max-width: 1100px; margin:auto;}
187
+ @media (max-width: 768px){ .gradio-container {padding: 8px;} }
188
+ """) as demo:
189
+ gr.Markdown("Upload a photo or use webcam. Put transparent **PNGs** in the **`hair/`** folder, then click **Refresh**.")
190
+
191
+ files_state = gr.State(HAIR_FILES) # keep filenames
192
+ meta_state = gr.State(META)
193
+
194
+ with gr.Tabs():
195
+ # ---------------- Photo Tab ----------------
196
+ with gr.Tab("πŸ“· Photo (Upload)"):
197
+ with gr.Row():
198
+ in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
199
+ out_img = gr.Image(label="Preview", height=360)
200
+ with gr.Row():
201
+ hair_sel = gr.Dropdown(
202
+ choices=HAIR_FILES,
203
+ value=(HAIR_FILES[0] if HAIR_FILES else None),
204
+ label="Selected hairstyle",
205
+ interactive=True
206
+ )
207
+ apply_btn = gr.Button("✨ Apply (Align & Overlay)")
208
+ download_btn = gr.DownloadButton("⬇️ Download") # NOTE: no file_name arg
209
+ status = gr.Markdown()
210
+
211
+ with gr.Row():
212
+ refresh = gr.Button("πŸ”„ Refresh")
213
+ gallery = gr.Gallery(
214
+ label="Hairstyles (click to choose)",
215
+ value=build_gallery_items(HAIR_FILES),
216
+ columns=6, rows=2, height=320,
217
+ allow_preview=False, object_fit="contain", show_label=True
218
+ )
219
+
220
+ with gr.Accordion("Fine-tune placement", open=True):
221
+ with gr.Row():
222
+ scale = gr.Slider(50, 200, 100, 1, label="Scale (β‰ˆ temple distance %)")
223
+ rot = gr.Slider(-30, 30, 0, 1, label="Extra rotation (Β°)")
224
+ with gr.Row():
225
+ dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right shift (px)")
226
+ dy = gr.Slider(-200, 200, 0, 1, label="Up ↕ Down shift (px)")
227
+ opacity = gr.Slider(0.2, 1.0, 1.0, 0.05, label="Hair opacity")
228
+
229
+ # --- Callbacks ---
230
+ def do_apply(im, hfile, s, r, dxv, dyv, op, meta):
231
+ return apply_tryon(im, hfile, s, r, dxv, dyv, op, meta)
232
+
233
+ apply_btn.click(
234
+ fn=do_apply,
235
+ inputs=[in_img, hair_sel, scale, rot, dx, dy, opacity, meta_state],
236
+ outputs=[out_img, status]
237
+ )
238
+
239
+ # Return a *file path* so the browser downloads with that name
240
+ download_btn.click(
241
+ fn=lambda im: save_png_to_tmp(im, "output_tryon.png"),
242
+ inputs=[out_img],
243
+ outputs=[download_btn]
244
+ )
245
+
246
+ def do_refresh():
247
+ files = load_hairstyles()
248
+ items = build_gallery_items(files)
249
+ return items, gr.update(choices=files, value=(files[0] if files else None)), files
250
+
251
+ refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, hair_sel, files_state])
252
+
253
+ # Clicking a tile sets the dropdown to that filename
254
+ def on_gallery_select(evt, files):
255
+ idx = getattr(evt, "index", None)
256
+ if idx is None or not files:
257
+ return gr.update()
258
+ if idx >= len(files):
259
+ idx = len(files) - 1
260
+ return gr.update(value=files[idx])
261
+
262
+ gallery.select(on_gallery_select, inputs=[files_state], outputs=[hair_sel])
263
+
264
+ # ---------------- Webcam Tab ----------------
265
+ with gr.Tab("πŸ“Ή Webcam (Live Beta)"):
266
+ cam = gr.Image(sources=["webcam"], streaming=True, type="pil", label="Enable camera")
267
+ hair2 = gr.Dropdown(choices=HAIR_FILES, value=(HAIR_FILES[0] if HAIR_FILES else None), label="Selected hairstyle")
268
+ with gr.Row():
269
+ scale2 = gr.Slider(50, 200, 100, 1, label="Scale %")
270
+ rot2 = gr.Slider(-25, 25, 0, 1, label="Rotate (Β°)")
271
+ with gr.Row():
272
+ dx2 = gr.Slider(-150, 150, 0, 1, label="Left ↔ Right (px)")
273
+ dy2 = gr.Slider(-150, 150, 0, 1, label="Up ↕ Down (px)")
274
+ opacity2 = gr.Slider(0.2, 1.0, 0.95, 0.05, label="Hair opacity")
275
+ out2 = gr.Image(label="Live result", height=360)
276
+ state_live = gr.State(None)
277
+ snap = gr.Button("πŸ“Έ Snapshot")
278
+ save_live = gr.DownloadButton("⬇️ Download Snapshot")
279
+
280
+ def live(im, h, s, r, dxv, dyv, op, meta):
281
+ res, _ = apply_tryon(im, h, s, r, dxv, dyv, op, meta)
282
+ return res, res
283
+
284
+ cam.stream(
285
+ fn=live,
286
+ inputs=[cam, hair2, scale2, rot2, dx2, dy2, opacity2, meta_state],
287
+ outputs=[out2, state_live]
288
+ )
289
+
290
+ snap.click(lambda x: x, inputs=[state_live], outputs=[out2])
291
+
292
+ save_live.click(
293
+ fn=lambda im: save_png_to_tmp(im, "tryon_webcam.png"),
294
+ inputs=[state_live],
295
+ outputs=[save_live]
296
+ )
297
+
298
+ return demo
299
+
300
+ # Export for Spaces autostart
301
+ app = build_ui()
302
+ demo = app
303
+
304
+ # Local dev
305
+ if __name__ == "__main__":
306
+ app.launch()