"""Salon hairstyle virtual try-on: overlay a hairstyle PNG on an uploaded photo (Gradio app)."""
import json
import logging
import os
import re
import tempfile

import cv2
import gradio as gr
import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)

# =============== Paths ===============
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Hairstyle folders are probed in this order; the first existing one wins.
CANDIDATES = [
    os.path.join(BASE_DIR, "hair"),
    os.path.join(BASE_DIR, "assets", "hairstyles"),
    os.path.join(BASE_DIR, "assets", "Hairstyles"),
    os.path.join(BASE_DIR, "hairstyles"),
]
HAIR_DIR = next((p for p in CANDIDATES if os.path.isdir(p)), None)
if HAIR_DIR is None:
    # None found: create the default folder so the app still starts (empty gallery).
    HAIR_DIR = os.path.join(BASE_DIR, "hair")
    os.makedirs(HAIR_DIR, exist_ok=True)
META_PATH = os.path.join(HAIR_DIR, "meta.json")  # optional per-style anchors

# =============== Dependencies ===============
try:
    import mediapipe as mp
except Exception as e:
    raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}") from e

mp_face_mesh = mp.solutions.face_mesh
# Face Mesh landmark indices used as the three alignment anchors (order matters:
# it must match the order of the hair anchor points).
LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}


# =============== Helpers ===============
def natural_key(s: str):
    """Sort key that orders embedded numbers numerically ("style2" < "style10")."""
    # re.split with a capturing group alternates text / \d+ chunks, so ints and
    # strs always sit at the same positions in two keys. isdecimal() matches \d
    # exactly (and what int() accepts); isdigit() would also accept e.g. "²",
    # which int() rejects.
    return [int(t) if t.isdecimal() else t.lower() for t in re.split(r"(\d+)", s)]


def load_hairstyles():
    """Return the ``.png`` filenames in HAIR_DIR, naturally sorted ([] if unreadable)."""
    try:
        files = [f for f in os.listdir(HAIR_DIR) if f.lower().endswith(".png")]
    except OSError:
        files = []
    return sorted(files, key=natural_key)


def load_meta():
    """Load optional per-style anchor points from META_PATH.

    Returns:
        A dict mapping hair filename -> three [x, y] points in hair-image pixels,
        or {} when the file is missing, unreadable, not valid JSON, or not an object.
    """
    try:
        with open(META_PATH, "r", encoding="utf-8") as f:
            m = json.load(f)
    except (OSError, ValueError):  # missing/unreadable file, bad JSON, bad encoding
        return {}
    return m if isinstance(m, dict) else {}


def premultiply_alpha(bgra):
    """Return an 8-bit BGRA copy whose colour channels are multiplied by alpha.

    Premultiplied colour can be resized/warped with linear interpolation without
    transparent pixels bleeding their (often white/gray) colour into the edges,
    which eliminates halos. Composite the result with
    ``out = color_pm + (1 - alpha) * background`` (NOT ``alpha * color_pm + ...``).
    """
    bgr = bgra[:, :, :3].astype(np.float32)
    a = bgra[:, :, 3:4].astype(np.float32) / 255.0
    bgr_pm = np.clip(np.rint(bgr * a), 0, 255).astype(np.uint8)
    return np.dstack([bgr_pm, bgra[:, :, 3]])


def load_hair_png(name):
    """Load hairstyle ``name`` from HAIR_DIR as a premultiplied 8-bit BGRA array.

    Raises:
        ValueError: if the file cannot be decoded or has no alpha channel.
    """
    path = os.path.join(HAIR_DIR, name)
    hair = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGRA when the PNG has alpha
    # Grayscale PNGs decode to 2-D arrays: check ndim before touching shape[2].
    if hair is None or hair.ndim != 3 or hair.shape[2] != 4:
        raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
    if hair.dtype == np.uint16:
        # 16-bit PNG: bring to 8 bits so all later "/ 255" maths is correct.
        hair = np.rint(hair.astype(np.float32) / 257.0).astype(np.uint8)
    return premultiply_alpha(hair)


def detect_face_keypoints(img_bgr):
    """Locate the three alignment anchors on the first detected face.

    Returns:
        A (3, 2) float32 array of pixel coordinates ordered left_eye_outer,
        right_eye_outer, mid_forehead (see LM), or None if no face is found.
    """
    h, w = img_bgr.shape[:2]
    with mp_face_mesh.FaceMesh(
        static_image_mode=True,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.6,
    ) as fm:
        res = fm.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    if not res.multi_face_landmarks:
        return None
    lm = res.multi_face_landmarks[0].landmark
    # Landmark x/y are relative to image width/height; convert to pixels.
    keys = ("left_eye_outer", "right_eye_outer", "mid_forehead")
    return np.array([[lm[LM[k]].x * w, lm[LM[k]].y * h] for k in keys], dtype=np.float32)


def hair_reference_points(hair_bgra, filename, meta):
    """Return the (3, 2) float32 anchor points on the hair image for ``filename``.

    The points correspond, in order, to the face anchors from
    detect_face_keypoints. A well-formed ``meta[filename]`` entry wins; otherwise
    defaults expressed as fractions of the hair image size are used.
    """
    h, w = hair_bgra.shape[:2]
    raw = meta.get(filename) if isinstance(meta, dict) else None
    if raw is not None:
        try:
            pts = np.asarray(raw, dtype=np.float32)
        except (TypeError, ValueError):
            pts = None  # malformed entry (ragged list, strings, null...): use defaults
        if pts is not None and pts.shape == (3, 2) and np.isfinite(pts).all():
            return pts
    # Defaults (OK for many styles). For pixel-perfect fit, add 3 points to meta.json.
    return np.array(
        [[0.30 * w, 0.60 * h], [0.70 * w, 0.60 * h], [0.50 * w, 0.40 * h]],
        dtype=np.float32,
    )


def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
    """Warp a premultiplied BGRA overlay with 2x3 affine ``M`` and composite it over ``base_bgr``.

    Args:
        base_bgr: uint8 BGR background; defines the output size.
        hair_bgra: premultiplied BGRA overlay (as returned by load_hair_png).
        M: 2x3 affine matrix mapping overlay pixels to background pixels.
        opacity: extra overlay opacity, clipped to [0, 1].

    Returns:
        uint8 BGR composite of the same size as ``base_bgr``.
    """
    H, W = base_bgr.shape[:2]
    hair_pm = hair_bgra[:, :, :3].astype(np.float32)
    hair_a = hair_bgra[:, :, 3].astype(np.float32) / 255.0
    # BORDER_CONSTANT with 0 = fully transparent (premultiplied black), so areas
    # outside the warped overlay leave the photo untouched.
    warp_kw = dict(flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
    hair_warp = cv2.warpAffine(hair_pm, M, (W, H), borderValue=(0, 0, 0), **warp_kw)
    a_warp = cv2.warpAffine(hair_a, M, (W, H), borderValue=0, **warp_kw)

    op = float(np.clip(opacity, 0.0, 1.0))
    a = (np.clip(a_warp, 0.0, 1.0) * op)[..., None]
    # Premultiplied "over": colour already carries alpha, so it is scaled only by
    # opacity; multiplying by alpha again would darken soft edges.
    out = hair_warp * op + (1.0 - a) * base_bgr.astype(np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def _to_bgr(image):
    """Convert a PIL image or an RGB/RGBA/grayscale uint8 ndarray to 3-channel BGR."""
    if isinstance(image, Image.Image):
        return cv2.cvtColor(np.asarray(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    arr = np.asarray(image)
    if arr.ndim == 2:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    if arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)


def apply_tryon(image, hairstyle, scale_pct, dx, dy, opacity, meta):
    """Overlay ``hairstyle`` on the face in ``image``.

    No head-mask (prevents neck lines & cropping).

    Args:
        image: uploaded photo (PIL image or RGB ndarray), or None.
        hairstyle: hair PNG filename inside HAIR_DIR, or None/"".
        scale_pct: hair size in percent (100 = fitted size, larger = bigger hair).
        dx, dy: pixel offsets added to the facial anchors.
        opacity: overlay opacity in [0, 1].
        meta: per-style anchor dict (see load_meta).

    Returns:
        (preview, status_message). On failure the preview is the input image.
    """
    if image is None:
        return None, "Upload a photo first."
    if not hairstyle:
        return np.array(image), "Pick a hairstyle."

    img_bgr = _to_bgr(image)
    kpts = detect_face_keypoints(img_bgr)
    if kpts is None:
        return image, "No face detected. Use a brighter, front-facing photo."

    try:
        hair = load_hair_png(hairstyle)
    except ValueError as e:
        # e.g. the file was removed or replaced after the gallery was built
        return image, str(e)
    hair_pts = hair_reference_points(hair, hairstyle, meta)

    # Target points = facial anchors + user nudges
    dst = kpts + np.array([dx, dy], dtype=np.float32)

    # Scale hair anchors around their centroid (no rotation for simplicity).
    # Pulling the anchors *closer* makes the fitted transform enlarge the hair,
    # so divide: 200% -> twice as large, 50% -> half as large.
    s = max(0.5, scale_pct / 100.0)
    center = hair_pts.mean(axis=0)
    hair_pts_adj = (hair_pts - center) / s + center

    M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
    if M is None:
        return image, "Alignment failed for this image/style."

    out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
    return cv2.cvtColor(out, cv2.COLOR_BGR2RGB), "OK"


def save_png_to_tmp(img, filename="output_tryon.png"):
    """Save ``img`` (ndarray or PIL image) as PNG in a fresh temp directory.

    A new directory per call keeps concurrent users from overwriting each
    other's download.

    Returns:
        The path of the written file.

    Raises:
        gr.Error: if ``img`` is None.
    """
    if img is None:
        raise gr.Error("No image to save. Click a hairstyle or 'Apply' first.")
    out_path = os.path.join(tempfile.mkdtemp(prefix="tryon_"), filename)
    pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    pil.save(out_path)
    return out_path


# ---- white thumbnails with labels ----
def thumb_on_white(hair_bgra, max_h=220):
    """Composite a premultiplied BGRA hair image onto white, at most ``max_h`` px tall.

    Returns:
        RGB uint8 thumbnail for the gallery.
    """
    h, w = hair_bgra.shape[:2]
    scale = min(1.0, max_h / h)
    size = (max(1, int(w * scale)), max(1, int(h * scale)))
    hair_bgra = cv2.resize(hair_bgra, size, interpolation=cv2.INTER_LINEAR)
    color_pm = hair_bgra[:, :, :3].astype(np.float32)
    a = hair_bgra[:, :, 3:4].astype(np.float32) / 255.0
    # Colour is premultiplied, so it is added as-is (not multiplied by alpha again).
    comp = np.clip(np.rint(color_pm + (1.0 - a) * 255.0), 0, 255).astype(np.uint8)
    return cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)


def _build_gallery(files):
    """Return (gallery_items, shown_files) with unreadable assets skipped.

    ``shown_files[i]`` is the filename behind gallery tile ``i``, so a click
    index can be mapped back to the right file even when assets were skipped.
    """
    items, shown = [], []
    for fname in files:
        try:
            img = load_hair_png(fname)
            thumb = thumb_on_white(img)
        except (ValueError, cv2.error) as e:
            logger.warning("Skipping hairstyle %s: %s", fname, e)
            continue
        shown.append(fname)
        items.append((thumb, f"{len(shown)}. {fname}"))  # show number + filename
    return items, shown


def build_gallery_items(files):
    """Return (thumbnail, "<n>. <filename>") gallery items; unreadable assets are skipped."""
    return _build_gallery(files)[0]


# =============== UI ===============
def build_ui():
    """Create and return the Gradio Blocks app."""
    meta = load_meta()
    gallery_items, hair_files = _build_gallery(load_hairstyles())

    with gr.Blocks(title="Salon Hairstyle Virtual Try-On (Simple)") as demo:
        gr.Markdown(
            "Upload a photo, then **click a hairstyle** below. "
            "Use a few sliders if needed, then **Save result**."
        )
        selected_file = gr.State(None)  # currently selected hairstyle filename
        meta_state = gr.State(meta)
        # Filenames in gallery order: gallery click index i -> files_state[i].
        files_state = gr.State(hair_files)

        with gr.Tabs():
            with gr.Tab("📷 Photo (Upload)"):
                with gr.Row():
                    in_img = gr.Image(
                        label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"]
                    )
                    out_img = gr.Image(label="Preview", height=360)
                with gr.Row():
                    apply_btn = gr.Button("✨ Apply (optional)")
                    save_btn = gr.Button("💾 Save result")
                save_file = gr.File(label="Saved file", visible=False)
                with gr.Row():
                    refresh = gr.Button("🔄 Refresh styles")
                    count_md = gr.Markdown(f"Found {len(hair_files)} hairstyles.")
                gallery = gr.Gallery(
                    label="Hairstyles (click to apply)",
                    value=gallery_items,
                    columns=6,
                    rows=3,
                    height=520,
                    allow_preview=False,
                    object_fit="contain",
                    show_label=True,
                )
                with gr.Accordion("Fine-tune (simple)", open=True):
                    # step must be a keyword: gr.Slider's params after `value` are keyword-only.
                    with gr.Row():
                        scale = gr.Slider(50, 200, 100, step=1, label="Scale (temple distance %)")  # main size
                        opacity = gr.Slider(0.4, 1.0, 1.0, step=0.05, label="Hair opacity")
                    with gr.Row():
                        dx = gr.Slider(-200, 200, 0, step=1, label="Left ↔ Right (px)")
                        dy = gr.Slider(-200, 200, 0, step=1, label="Up ↕ Down (px)")
                status = gr.Markdown("")

        # ----- actions -----
        def do_apply(im, hairfile, s, dxv, dyv, op, meta_d):
            return apply_tryon(im, hairfile, s, dxv, dyv, op, meta_d)

        # 1) click a tile -> set selected file AND auto-apply.
        # Gradio injects the click event only into a parameter annotated gr.SelectData.
        def on_gallery_select(evt: gr.SelectData, files, im, s, dxv, dyv, op, meta_d):
            idx = getattr(evt, "index", None)
            if idx is None or not files:
                return None, gr.update(), None
            idx = max(0, min(int(idx), len(files) - 1))
            hairfile = files[idx]
            out, msg = do_apply(im, hairfile, s, dxv, dyv, op, meta_d)
            return hairfile, out, msg

        gallery.select(
            on_gallery_select,
            inputs=[files_state, in_img, scale, dx, dy, opacity, meta_state],
            outputs=[selected_file, out_img, status],
        )

        # 2) Apply button (useful after slider tweaks)
        apply_btn.click(
            fn=do_apply,
            inputs=[in_img, selected_file, scale, dx, dy, opacity, meta_state],
            outputs=[out_img, status],
        )

        # 3) Save
        def do_save(im):
            path = save_png_to_tmp(im, "output_tryon.png")
            # gr.update works on Gradio 3 and 4 (Component.update() was removed in 4.x).
            return gr.update(value=path, visible=True)

        save_btn.click(fn=do_save, inputs=[out_img], outputs=[save_file])

        # 4) Refresh styles (and anchors, so newly added meta.json entries apply)
        def do_refresh(current):
            items, files = _build_gallery(load_hairstyles())
            msg = f"Found {len(files)} hairstyles."
            # Keep selection if name still exists
            keep = current if current in files else None
            return items, files, msg, keep, load_meta()

        refresh.click(
            fn=do_refresh,
            inputs=[selected_file],
            outputs=[gallery, files_state, count_md, selected_file, meta_state],
        )

    return demo


# export for Spaces
app = build_ui()
demo = app

if __name__ == "__main__":
    app.launch()