RojaKatta commited on
Commit
30cc0b3
·
verified ·
1 Parent(s): 69d3136

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -15
app.py CHANGED
@@ -50,11 +50,20 @@ def load_meta():
50
  return {}
51
  return {}
52
 
 
 
 
 
 
 
 
53
  def load_hair_png(name):
54
  path = os.path.join(HAIR_DIR, name)
55
  hair = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGRA
56
  if hair is None or hair.shape[2] != 4:
57
  raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
 
 
58
  return hair
59
 
60
  def detect_face_keypoints(img_bgr):
@@ -84,7 +93,7 @@ def hair_reference_points(hair_bgra, filename, meta):
84
  pts = np.array(meta[filename], dtype=np.float32)
85
  if pts.shape == (3, 2):
86
  return pts
87
- # Defaults (ok for many styles). For perfect fit, put 3 points per file in meta.json.
88
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
89
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
90
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
@@ -143,7 +152,7 @@ def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity, meta):
143
  return out_rgb, "OK"
144
 
145
  def save_png_to_tmp(img, filename="output_tryon.png"):
146
- """DownloadButton needs a file path string. We save to /tmp and return that path."""
147
  if img is None:
148
  raise gr.Error("No image to download. Click Apply first.")
149
  out_path = os.path.join(tempfile.gettempdir(), filename)
@@ -205,7 +214,9 @@ def build_ui():
205
  interactive=True
206
  )
207
  apply_btn = gr.Button("✨ Apply (Align & Overlay)")
208
- download_btn = gr.DownloadButton("⬇️ Download") # NOTE: no file_name arg
 
 
209
  status = gr.Markdown()
210
 
211
  with gr.Row():
@@ -236,12 +247,11 @@ def build_ui():
236
  outputs=[out_img, status]
237
  )
238
 
239
- # Return a *file path* so the browser downloads with that name
240
- download_btn.click(
241
- fn=lambda im: save_png_to_tmp(im, "output_tryon.png"),
242
- inputs=[out_img],
243
- outputs=[download_btn]
244
- )
245
 
246
  def do_refresh():
247
  files = load_hairstyles()
@@ -275,7 +285,8 @@ def build_ui():
275
  out2 = gr.Image(label="Live result", height=360)
276
  state_live = gr.State(None)
277
  snap = gr.Button("📸 Snapshot")
278
- save_live = gr.DownloadButton("⬇️ Download Snapshot")
 
279
 
280
  def live(im, h, s, r, dxv, dyv, op, meta):
281
  res, _ = apply_tryon(im, h, s, r, dxv, dyv, op, meta)
@@ -289,11 +300,11 @@ def build_ui():
289
 
290
  snap.click(lambda x: x, inputs=[state_live], outputs=[out2])
291
 
292
- save_live.click(
293
- fn=lambda im: save_png_to_tmp(im, "tryon_webcam.png"),
294
- inputs=[state_live],
295
- outputs=[save_live]
296
- )
297
 
298
  return demo
299
 
 
50
  return {}
51
  return {}
52
 
53
+ def premultiply_alpha(bgra):
54
+ """Reduce gray/white halos on edges for nicer blending."""
55
+ bgr = bgra[:, :, :3].astype(np.float32) / 255.0
56
+ a = (bgra[:, :, 3:4].astype(np.float32) / 255.0)
57
+ bgr_pm = (bgr * a * 255.0).astype(np.uint8)
58
+ return np.dstack([bgr_pm, bgra[:, :, 3]])
59
+
60
  def load_hair_png(name):
61
  path = os.path.join(HAIR_DIR, name)
62
  hair = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGRA
63
  if hair is None or hair.shape[2] != 4:
64
  raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
65
+ # Improve visual quality
66
+ hair = premultiply_alpha(hair)
67
  return hair
68
 
69
  def detect_face_keypoints(img_bgr):
 
93
  pts = np.array(meta[filename], dtype=np.float32)
94
  if pts.shape == (3, 2):
95
  return pts
96
+ # Defaults (ok for many styles). For perfect fit, add 3 points per file to meta.json.
97
  pL = np.array([0.30*w, 0.60*h], dtype=np.float32)
98
  pR = np.array([0.70*w, 0.60*h], dtype=np.float32)
99
  pM = np.array([0.50*w, 0.40*h], dtype=np.float32)
 
152
  return out_rgb, "OK"
153
 
154
  def save_png_to_tmp(img, filename="output_tryon.png"):
155
+ """Create a file in /tmp and return the path (works reliably on Spaces)."""
156
  if img is None:
157
  raise gr.Error("No image to download. Click Apply first.")
158
  out_path = os.path.join(tempfile.gettempdir(), filename)
 
214
  interactive=True
215
  )
216
  apply_btn = gr.Button("✨ Apply (Align & Overlay)")
217
+ # Reliable download: a button that fills a File link.
218
+ download_btn = gr.Button("⬇️ Download")
219
+ download_file = gr.File(label="download", visible=False)
220
  status = gr.Markdown()
221
 
222
  with gr.Row():
 
247
  outputs=[out_img, status]
248
  )
249
 
250
+ def prepare_download(im):
251
+ path = save_png_to_tmp(im, "output_tryon.png")
252
+ return gr.File.update(value=path, visible=True)
253
+
254
+ download_btn.click(fn=prepare_download, inputs=[out_img], outputs=[download_file])
 
255
 
256
  def do_refresh():
257
  files = load_hairstyles()
 
285
  out2 = gr.Image(label="Live result", height=360)
286
  state_live = gr.State(None)
287
  snap = gr.Button("📸 Snapshot")
288
+ save_live_btn = gr.Button("⬇️ Download Snapshot")
289
+ save_live_file = gr.File(label="snapshot", visible=False)
290
 
291
  def live(im, h, s, r, dxv, dyv, op, meta):
292
  res, _ = apply_tryon(im, h, s, r, dxv, dyv, op, meta)
 
300
 
301
  snap.click(lambda x: x, inputs=[state_live], outputs=[out2])
302
 
303
+ def prepare_webcam_download(im):
304
+ path = save_png_to_tmp(im, "tryon_webcam.png")
305
+ return gr.File.update(value=path, visible=True)
306
+
307
+ save_live_btn.click(fn=prepare_webcam_download, inputs=[state_live], outputs=[save_live_file])
308
 
309
  return demo
310