import os, json, tempfile, re

import cv2, numpy as np, gradio as gr
from PIL import Image

# -------------------- Paths --------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CANDIDATES = [
    os.path.join(BASE_DIR, "hair"),                  # your folder
    os.path.join(BASE_DIR, "assets", "hairstyles"),
    os.path.join(BASE_DIR, "assets", "Hairstyles"),
    os.path.join(BASE_DIR, "hairstyles"),
]
HAIR_DIR = next((p for p in CANDIDATES if os.path.isdir(p)), None)
if HAIR_DIR is None:
    HAIR_DIR = os.path.join(BASE_DIR, "hair")
    os.makedirs(HAIR_DIR, exist_ok=True)

META_PATH = os.path.join(HAIR_DIR, "meta.json")  # optional per-style anchors
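# Illustrative meta.json layout (a sketch; the filename and coordinates below are
# made-up examples, not shipped assets). Each entry maps a hair PNG to three
# [x, y] pixel anchors in that PNG (left temple, right temple, mid-forehead),
# in the same order as the face keypoints returned by detect_face_keypoints():
#   {
#     "photo1.png": [[120, 300], [360, 300], [240, 180]]
#   }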
# -------------------- Deps --------------------
try:
    import mediapipe as mp
except Exception as e:
    raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}")
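# The "requirements pins" above refer to the Space's requirements.txt. As a rough
# guide (package names only; exact version pins are deployment-specific and not
# asserted here), this file imports: gradio, mediapipe, numpy, Pillow, and OpenCV
# (opencv-python, or opencv-python-headless on a server without a display).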
mp_face_mesh = mp.solutions.face_mesh
mp_selfie_seg = mp.solutions.selfie_segmentation  # optional (off by default)

LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
# -------------------- Helpers --------------------
def natural_key(s: str):
    # sorts photo1, photo2, ... photo10 in numeric order
    return [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]

def load_hairstyles():
    try:
        files = [f for f in os.listdir(HAIR_DIR) if f.lower().endswith(".png")]
    except FileNotFoundError:
        files = []
    files.sort(key=natural_key)
    return files

def load_meta():
    if os.path.exists(META_PATH):
        try:
            with open(META_PATH, "r") as f:
                m = json.load(f)
            return m if isinstance(m, dict) else {}
        except Exception:
            return {}
    return {}
def premultiply_alpha(bgra):
    """Reduce gray/white halos on edges for nicer blending."""
    bgr = bgra[:, :, :3].astype(np.float32) / 255.0
    a = bgra[:, :, 3:4].astype(np.float32) / 255.0
    bgr_pm = (bgr * a * 255.0).astype(np.uint8)
    return np.dstack([bgr_pm, bgra[:, :, 3]])
def load_hair_png(name):
    path = os.path.join(HAIR_DIR, name)
    hair = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGRA
    # Reject missing files and grayscale/opaque PNGs (no 4th channel).
    if hair is None or hair.ndim != 3 or hair.shape[2] != 4:
        raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
    return premultiply_alpha(hair)
def detect_face_keypoints(img_bgr):
    h, w = img_bgr.shape[:2]
    with mp_face_mesh.FaceMesh(
        static_image_mode=True, max_num_faces=1, refine_landmarks=True,
        min_detection_confidence=0.6
    ) as fm:
        res = fm.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    if not res.multi_face_landmarks:
        return None
    lm = res.multi_face_landmarks[0].landmark
    def xy(i): return np.array([lm[i].x * w, lm[i].y * h], dtype=np.float32)
    return np.stack([xy(LM["left_eye_outer"]), xy(LM["right_eye_outer"]), xy(LM["mid_forehead"])])
def person_mask(img_bgr, expand_px=20):
    """Optional head mask (OFF by default). We expand+blur to avoid 'neck lines'."""
    with mp_selfie_seg.SelfieSegmentation(model_selection=1) as seg:
        rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        m = seg.process(rgb).segmentation_mask
    mask = (m > 0.5).astype(np.uint8)
    if expand_px > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * expand_px + 1, 2 * expand_px + 1))
        mask = cv2.dilate(mask, k, iterations=1)
    mask = cv2.GaussianBlur(mask.astype(np.float32), (41, 41), 0)
    return mask
def hair_reference_points(hair_bgra, filename, meta):
    h, w = hair_bgra.shape[:2]
    if filename in meta:
        pts = np.array(meta[filename], dtype=np.float32)
        if pts.shape == (3, 2):
            return pts
    # Defaults (ok for many styles). For perfect fit, add 3 points per file to meta.json.
    pL = np.array([0.30 * w, 0.60 * h], dtype=np.float32)
    pR = np.array([0.70 * w, 0.60 * h], dtype=np.float32)
    pM = np.array([0.50 * w, 0.40 * h], dtype=np.float32)
    return np.stack([pL, pR, pM], axis=0)
def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
    H, W = base_bgr.shape[:2]
    hair_rgb = hair_bgra[:, :, :3]
    hair_a = hair_bgra[:, :, 3].astype(np.float32) / 255.0  # float32 for warpAffine
    # BORDER_CONSTANT keeps pixels outside the warped hair fully transparent;
    # BORDER_TRANSPARENT would leave the freshly allocated output uninitialized there.
    hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR,
                               borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR,
                            borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    a = np.clip(a_warp * opacity, 0, 1)[..., None]
    out = (a * hair_warp + (1 - a) * base_bgr).astype(np.uint8)
    return out
def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity, meta,
                limit_head=False, expand_pct=3.0):
    """
    limit_head=False by default to avoid 'missing hair' and neck lines.
    If True, we use an expanded soft head mask.
    """
    if image is None:
        return None, "Upload a photo or enable webcam."
    if not hairstyle:
        return np.array(image), "Pick a hairstyle first."
    img_bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)  # drop alpha from RGBA uploads
    kpts = detect_face_keypoints(img_bgr)
    if kpts is None:
        return image, "No face detected. Try a brighter, front-facing photo."
    hair = load_hair_png(hairstyle)
    hair_pts = hair_reference_points(hair, hairstyle, meta)
    # Destination points (with user nudges)
    dst = kpts.copy()
    dst[:, 0] += dx
    dst[:, 1] += dy
    # Scale + rotate around hair anchor centroid
    center = hair_pts.mean(axis=0)
    theta = np.deg2rad(rot_deg)
    s = max(0.5, scale_pct / 100.0)
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
    hair_pts_adj = (hair_pts - center) @ R.T * s + center
    M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
    if M is None:
        return image, "Could not compute alignment for this image/style."
    out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
    if limit_head:
        H, W = img_bgr.shape[:2]
        expand_px = max(8, int(min(H, W) * (expand_pct / 100.0)))  # soft expansion
        head = person_mask(img_bgr, expand_px=expand_px)           # soft & expanded
        head3 = head[..., None]
        out = (head3 * out + (1 - head3) * img_bgr).astype(np.uint8)
    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, "OK"
def save_png_to_tmp(img, filename="output_tryon.png"):
    """Create a file in /tmp and return the path (used by the Save button)."""
    if img is None:
        raise gr.Error("No image to save. Click Apply first.")
    out_path = os.path.join(tempfile.gettempdir(), filename)
    if isinstance(img, np.ndarray):
        Image.fromarray(img).save(out_path)
    else:
        img.save(out_path)
    return out_path
# ---------- WHITE background thumbnails (shows filename number) ----------
def thumb_on_white(hair_bgra, max_h=220):
    h, w = hair_bgra.shape[:2]
    scale = min(1.0, max_h / h)
    hair_bgra = cv2.resize(hair_bgra, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
    h, w = hair_bgra.shape[:2]
    bg_rgb = np.full((h, w, 3), 255, dtype=np.uint8)  # white background
    a = hair_bgra[:, :, 3:4] / 255.0
    comp = (a * hair_bgra[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
    return cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)

def build_gallery_items(files):
    items = []
    for idx, fname in enumerate(files, start=1):
        try:
            img = load_hair_png(fname)
            items.append((thumb_on_white(img), f"{idx}. {fname}"))  # caption shows number & filename
        except Exception:
            continue
    return items
# -------------------- UI --------------------
def build_ui():
    META = load_meta()
    HAIR_FILES = load_hairstyles()

    with gr.Blocks(title="Salon Hairstyle Virtual Try-On", css="""
        .gradio-container {max-width: 1200px; margin:auto;}
        @media (max-width: 768px){ .gradio-container {padding: 8px;} }
    """) as demo:
        gr.Markdown("Upload a photo or use webcam. Put transparent **PNGs** in **`hair/`**, then click **Refresh**.")
        files_state = gr.State(HAIR_FILES)  # filenames (natural order)
        meta_state = gr.State(META)

        with gr.Tabs():
            # -------- Photo Tab --------
            with gr.Tab("📷 Photo (Upload)"):
                with gr.Row():
                    in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
                    out_img = gr.Image(label="Preview", height=360)
                with gr.Row():
                    hair_sel = gr.Dropdown(
                        choices=HAIR_FILES,
                        value=(HAIR_FILES[0] if HAIR_FILES else None),
                        label="Selected hairstyle",
                        interactive=True
                    )
                    apply_btn = gr.Button("✨ Apply (Align & Overlay)")
                    # SAVE (replaces Download)
                    save_btn = gr.Button("💾 Save result")
                save_file = gr.File(label="Saved file", visible=False)
                status = gr.Markdown()
                with gr.Row():
                    refresh = gr.Button("🔄 Refresh")
                    count_md = gr.Markdown(f"Found {len(HAIR_FILES)} hairstyles.")
                gallery = gr.Gallery(
                    label="Hairstyles (click to choose)",
                    value=build_gallery_items(HAIR_FILES),
                    columns=6, rows=3, height=520,  # up to 18 tiles visible; all 11 will show
                    allow_preview=False, object_fit="contain", show_label=True
                )
                with gr.Accordion("Fine-tune placement", open=True):
                    with gr.Row():
                        scale = gr.Slider(50, 200, value=100, step=1, label="Scale (≈ temple distance %)")
                        rot = gr.Slider(-30, 30, value=0, step=1, label="Extra rotation (°)")
                    with gr.Row():
                        dx = gr.Slider(-200, 200, value=0, step=1, label="Left ↔ Right shift (px)")
                        dy = gr.Slider(-200, 200, value=0, step=1, label="Up ↕ Down shift (px)")
                    opacity = gr.Slider(0.2, 1.0, value=1.0, step=0.05, label="Hair opacity")
                    limit_head = gr.Checkbox(label="Limit overlay to head (avoid spill)", value=False)
                    expand = gr.Slider(0.0, 10.0, value=3.0, step=0.5, label="Head-mask expansion (%) – only if enabled")
                # --- Callbacks ---
                def do_apply(im, hfile, s, r, dxv, dyv, op, meta, lh, ex):
                    return apply_tryon(im, hfile, s, r, dxv, dyv, op, meta, limit_head=lh, expand_pct=ex)

                apply_btn.click(
                    fn=do_apply,
                    inputs=[in_img, hair_sel, scale, rot, dx, dy, opacity, meta_state, limit_head, expand],
                    outputs=[out_img, status]
                )

                def do_save(im):
                    path = save_png_to_tmp(im, "output_tryon.png")
                    return gr.update(value=path, visible=True)  # component .update() helpers were removed in Gradio 4

                save_btn.click(fn=do_save, inputs=[out_img], outputs=[save_file])

                def do_refresh():
                    files = load_hairstyles()
                    items = build_gallery_items(files)
                    msg = f"Found {len(files)} hairstyles."
                    return items, gr.update(choices=files, value=(files[0] if files else None)), files, msg

                refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, hair_sel, files_state, count_md])

                # Gallery click -> set dropdown to that filename
                def on_gallery_select(evt: gr.SelectData, files):  # typed arg so Gradio injects the select event
                    idx = getattr(evt, "index", None)
                    if idx is None or not files:
                        return gr.update()
                    # our captions start at 1., map index to filename directly
                    idx = max(0, min(idx, len(files) - 1))
                    return gr.update(value=files[idx])

                gallery.select(on_gallery_select, inputs=[files_state], outputs=[hair_sel])
            # -------- Webcam Tab (unchanged except 'Save Snapshot') --------
            with gr.Tab("📹 Webcam (Live Beta)"):
                cam = gr.Image(sources=["webcam"], streaming=True, type="pil", label="Enable camera")
                hair2 = gr.Dropdown(choices=HAIR_FILES, value=(HAIR_FILES[0] if HAIR_FILES else None), label="Selected hairstyle")
                with gr.Row():
                    scale2 = gr.Slider(50, 200, value=100, step=1, label="Scale %")
                    rot2 = gr.Slider(-25, 25, value=0, step=1, label="Rotate (°)")
                with gr.Row():
                    dx2 = gr.Slider(-150, 150, value=0, step=1, label="Left ↔ Right (px)")
                    dy2 = gr.Slider(-150, 150, value=0, step=1, label="Up ↕ Down (px)")
                opacity2 = gr.Slider(0.2, 1.0, value=0.95, step=0.05, label="Hair opacity")
                limit_head2 = gr.Checkbox(label="Limit overlay to head", value=False)
                expand2 = gr.Slider(0.0, 10.0, value=3.0, step=0.5, label="Head-mask expansion (%)", visible=True)
                out2 = gr.Image(label="Live result", height=360)
                state_live = gr.State(None)
                snap = gr.Button("📸 Snapshot")
                save_live_btn = gr.Button("💾 Save snapshot")
                save_live_file = gr.File(label="snapshot", visible=False)

                def live(im, h, s, r, dxv, dyv, op, meta, lh, ex):
                    res, _ = apply_tryon(im, h, s, r, dxv, dyv, op, meta, limit_head=lh, expand_pct=ex)
                    return res, res

                cam.stream(
                    fn=live,
                    inputs=[cam, hair2, scale2, rot2, dx2, dy2, opacity2, meta_state, limit_head2, expand2],
                    outputs=[out2, state_live]
                )

                snap.click(lambda x: x, inputs=[state_live], outputs=[out2])

                def save_snap(im):
                    path = save_png_to_tmp(im, "tryon_webcam.png")
                    return gr.update(value=path, visible=True)  # same fix as do_save above

                save_live_btn.click(fn=save_snap, inputs=[state_live], outputs=[save_live_file])

    return demo
# Export for Spaces autostart
app = build_ui()
demo = app

if __name__ == "__main__":
    app.launch()
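# Local run (outside Spaces): `python app.py` starts the Gradio server,
# by default at http://127.0.0.1:7860.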