Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import os | |
| import math | |
| import uuid | |
| import json | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Tuple | |
| import gradio as gr | |
| from PIL import Image, ImageOps | |
| import pandas as pd | |
| import numpy as np | |
| import mediapipe as mp | |
# ----------------------------
# Globals & configuration
# ----------------------------
# Paths; overridable via the HAIRSTYLE_DATASET, HAIRSTYLE_FOLDER and RESULTS_DIR
# environment variables respectively.
DATASET_PATH = os.environ.get("HAIRSTYLE_DATASET", "data/enhanced_full_hairstyle_dataset.csv")
HAIRSTYLE_FOLDER = os.environ.get("HAIRSTYLE_FOLDER", "hairstyles")
RESULTS_DIR = os.environ.get("RESULTS_DIR", "generated_results")
os.makedirs(RESULTS_DIR, exist_ok=True)  # saved previews are written here

# Default placement nudges; tune if overlays tend to sit too high/low.
DEFAULT_VERT_OFFSET_PCT = -0.25  # fraction of the style's (scaled) forehead height
DEFAULT_HORIZ_OFFSET_PX = 0

# MediaPipe FaceMesh landmark indices used for head roll and forehead measurement.
LM_LEFT_EYE_OUTER = 33
LM_RIGHT_EYE_OUTER = 263
LM_FOREHEAD_TOP = 10
LM_FOREHEAD_LEFT = 103
LM_FOREHEAD_RIGHT = 332
# Build one FaceMesh detector at import time; get_face_landmarks reuses it for
# every call. It is shared across requests, so the app keeps Gradio's queue at
# a concurrency of 1 around it.
mp_face_mesh = mp.solutions.face_mesh
_FACE_MESH_OPTIONS = dict(
    static_image_mode=True,  # treat every input as an independent still image
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
)
FACE_MESH = mp_face_mesh.FaceMesh(**_FACE_MESH_OPTIONS)
@dataclass
class Style:
    """One hairstyle overlay described by a dataset row.

    Instances are built with keyword arguments (see ``_load_styles``), which
    requires the ``@dataclass``-generated ``__init__``.
    """

    name: str  # caption shown in the gallery
    gender: str  # tag compared against the gallery gender filter (defaults to "All")
    img_path: str  # path of the overlay image on disk
    img_rgba: Optional[Image.Image]  # decoded RGBA overlay, or None if decoding failed
    style_forehead_w: int  # forehead width in px (dataset column forehead_width_px); scale reference
    style_forehead_h: int  # forehead height in px (dataset column forehead_height_px); placement reference
| def _safe_read_dataset(path: str) -> pd.DataFrame: | |
| if not os.path.exists(path): | |
| # Create an empty frame with expected columns to avoid crashes | |
| cols = ["name", "gender", "forehead_width_px", "forehead_height_px", "image_file"] | |
| return pd.DataFrame(columns=cols) | |
| df = pd.read_csv(path) | |
| # Normalize columns and fill NaNs | |
| for col in ["name", "gender", "image_file"]: | |
| if col not in df.columns: | |
| df[col] = "" | |
| df[col] = df[col].fillna("") | |
| for col in ["forehead_width_px", "forehead_height_px"]: | |
| if col not in df.columns: | |
| df[col] = 0 | |
| df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int) | |
| return df | |
def _load_styles(df: pd.DataFrame) -> List[Style]:
    """Build a ``Style`` for every dataset row whose image file is present.

    Rows are skipped when ``image_file`` is blank or no such file exists under
    ``HAIRSTYLE_FOLDER``. A file that exists but cannot be decoded is still
    kept, with ``img_rgba=None``, so the UI can report it by name.

    Args:
        df: Dataset frame. The ``name``, ``gender``, ``image_file``,
            ``forehead_width_px`` and ``forehead_height_px`` columns are read
            when present.

    Returns:
        Styles in dataset row order. The list is empty if ``HAIRSTYLE_FOLDER``
        is not a directory.
    """
    styles: List[Style] = []
    if not os.path.isdir(HAIRSTYLE_FOLDER):
        return styles
    for _, row in df.iterrows():
        raw_file = row.get("image_file", "")
        # str() tolerates non-text cells (e.g. a numeric file stem parsed by pandas);
        # isna() stops a NaN cell from becoming the literal file name "nan".
        img_file = "" if pd.isna(raw_file) else str(raw_file).strip()
        if not img_file:
            continue
        path = os.path.join(HAIRSTYLE_FOLDER, img_file)
        if not os.path.isfile(path):
            continue
        try:
            # The context manager closes the underlying file deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(path) as src:
                img = src.convert("RGBA")
        except Exception:
            # Best effort: one unreadable asset must not stop the app from starting.
            img = None
        styles.append(
            Style(
                name=str(row.get("name", "Style")).strip() or "Style",
                gender=str(row.get("gender", "All")).strip() or "All",
                img_path=path,
                img_rgba=img,
                style_forehead_w=int(row.get("forehead_width_px", 0) or 0),
                style_forehead_h=int(row.get("forehead_height_px", 0) or 0),
            )
        )
    return styles
| def _to_rgb(image: Image.Image) -> Image.Image: | |
| return image.convert("RGB") if image.mode != "RGB" else image | |
def get_face_landmarks(img_rgb: Image.Image):
    """Run the shared FaceMesh model on an RGB PIL image.

    Returns the landmark set of the first detected face, or ``None`` when no
    face is found.
    """
    detections = FACE_MESH.process(np.array(img_rgb)).multi_face_landmarks
    return detections[0] if detections else None
def _rotation_angle_rad(landmarks, w: int, h: int) -> float:
    """Estimate in-plane head roll, in radians, from the outer eye corners.

    Landmark x/y are scaled by the image width/height to get pixel positions,
    then the angle of the line from the left-eye landmark to the right-eye
    landmark is taken. Image y grows downward, so a positive angle means that
    line slopes down to the right on screen (a clockwise tilt as displayed).
    0 means the eyes are level.
    """
    points = landmarks.landmark
    eye_a = points[LM_LEFT_EYE_OUTER]
    eye_b = points[LM_RIGHT_EYE_OUTER]
    dx = eye_b.x * w - eye_a.x * w
    dy = eye_b.y * h - eye_a.y * h
    return math.atan2(dy, dx)
def _compute_forehead_metrics(landmarks, w: int, h: int) -> Tuple[int, Tuple[int, int]]:
    """Measure the forehead in pixels.

    Returns ``(width, (x, y))``:

    - *width* is the absolute horizontal distance between the forehead side
      landmarks.
    - ``(x, y)`` is the pixel position of the forehead-top landmark.

    All values are truncated to int.
    """
    points = landmarks.landmark
    span = (points[LM_FOREHEAD_RIGHT].x - points[LM_FOREHEAD_LEFT].x) * w
    apex = points[LM_FOREHEAD_TOP]
    return int(abs(span)), (int(apex.x * w), int(apex.y * h))
def _paste_rgba(base: Image.Image, overlay: Image.Image, pos: Tuple[int, int]) -> Image.Image:
    """Alpha-composite *overlay* over a copy of *base*, with its top-left corner at *pos*.

    The overlay is first pasted, using itself as the mask, onto a fully
    transparent layer the size of *base*. That layer is then composited over
    an RGBA copy of *base*. *base* itself is not modified, and the result is
    RGBA.
    """
    left, top = pos
    background = base.copy().convert("RGBA")
    layer = Image.new("RGBA", background.size, (0, 0, 0, 0))
    layer.paste(overlay, (left, top), overlay)
    return Image.alpha_composite(background, layer)
def apply_hairstyle_impl(
    upload_img: Optional[Image.Image],
    webcam_img: Optional[Image.Image],
    input_source: str,
    style_index: Optional[int],
    scale_tweak: float,
    vert_offset: int,
    horiz_offset: int,
    opacity: float,
) -> Tuple[Optional[Image.Image], str]:
    """Overlay the selected hairstyle on the user's photo.

    Args:
        upload_img: Image from the upload widget; used when *input_source* is "Upload".
        webcam_img: Image from the webcam widget; used for any other *input_source*.
        input_source: Which widget to read ("Upload" or "Webcam").
        style_index: Index into ``STYLES``. ``gr.Number`` hands callbacks a
            float (e.g. ``3.0``), so integral floats are accepted.
        scale_tweak: Multiplier applied on top of the forehead-width-based scale.
        vert_offset: Extra vertical shift in pixels (positive moves the hair down).
        horiz_offset: Extra horizontal shift in pixels (positive moves the hair right).
        opacity: Alpha multiplier. Values in [0, 1) fade the hair; other values
            leave it unchanged.

    Returns:
        ``(image, status_message)``. *image* is ``None`` only when the selected
        source has no image. In every other case an RGB image is returned: the
        composite on success, otherwise the original photo plus a message
        explaining why.
    """
    user_img = upload_img if input_source == "Upload" else webcam_img
    if user_img is None:
        return None, "❌ No image from selected source."

    if style_index is None or style_index < 0 or style_index >= len(STYLES):
        return _to_rgb(user_img), "ℹ️ Select a hairstyle from the gallery."
    # int(): list indexing rejects floats such as 3.0 coming from gr.Number.
    style = STYLES[int(style_index)]
    if style.img_rgba is None:
        return _to_rgb(user_img), f"⚠️ Could not load image for: {style.name}"

    try:
        img_rgb = _to_rgb(user_img)
        w, h = img_rgb.size
        lms = get_face_landmarks(img_rgb)
        if lms is None:
            return img_rgb, "⚠️ No face detected. Showing original image. Try a clearer, front‑facing photo."

        # Size: match the user's forehead width to the style's reference forehead width.
        angle_rad = _rotation_angle_rad(lms, w, h)
        forehead_w_px, (top_x, top_y) = _compute_forehead_metrics(lms, w, h)
        style_fw = max(style.style_forehead_w, 1)  # guard against 0 from missing dataset values
        style_fh = max(style.style_forehead_h, 1)
        scale_ratio = (forehead_w_px / style_fw) * float(scale_tweak)
        new_w = max(int(style.img_rgba.width * scale_ratio), 1)
        new_h = max(int(style.img_rgba.height * scale_ratio), 1)

        # Resize, then rotate to follow the head roll; expand=True keeps the corners.
        hair = style.img_rgba.resize((new_w, new_h), resample=Image.LANCZOS)
        angle_deg = math.degrees(angle_rad)
        hair = hair.rotate(angle=-angle_deg, expand=True, resample=Image.BICUBIC)

        # Placement: lift the hair one scaled forehead height above the forehead
        # top, apply the default nudge, then the user's offsets.
        attach_y = top_y - int(style_fh * scale_ratio)
        attach_y += int(DEFAULT_VERT_OFFSET_PCT * style_fh * scale_ratio)
        attach_y += int(vert_offset)
        attach_x = top_x - hair.width // 2 + int(horiz_offset) + int(DEFAULT_HORIZ_OFFSET_PX)
        # x may go negative (partial paste); y is clamped to the top edge.
        attach_y = max(0, attach_y)

        if 0 <= opacity < 1:
            # Scale only the alpha channel, so opacity is a pure fade. Stretching
            # the alpha histogram first would alter the hair mask itself.
            r, g, b, alpha = hair.split()
            alpha = alpha.point(lambda px: int(px * opacity))
            hair = Image.merge("RGBA", (r, g, b, alpha))

        composed = _paste_rgba(img_rgb, hair, (attach_x, attach_y)).convert("RGB")
        return composed, "✅ Success! Tip: fine‑tune scale/offsets if needed."
    except Exception as e:
        # UI boundary: surface the failure in the status box rather than raising into Gradio.
        return _to_rgb(user_img), f"❌ Error: {str(e)}"
# ----------------------------
# Load data once
# ----------------------------
DATASET_DF = _safe_read_dataset(DATASET_PATH)
STYLES: List[Style] = _load_styles(DATASET_DF)
# (thumbnail copy, caption) for every style whose image decoded successfully.
# NOTE(review): appears unused in the visible code; update_gallery builds the
# gallery contents itself.
GALLERY_ITEMS: List[Tuple[Image.Image, str]] = [
    (style.img_rgba.copy(), style.name) for style in STYLES if style.img_rgba is not None
]
# ----------------------------
# Gradio helpers
# ----------------------------
def update_gallery(gender: str):
    """Return ``(gallery_items, indices)`` for a gender filter.

    *gallery_items* holds ``(image, caption)`` pairs for the styles that match
    the filter and have a decoded image. *indices* holds their positions in
    ``STYLES``. The two lists are aligned, so ``indices[k]`` is the style shown
    as the k-th thumbnail; ``select_hairstyle`` relies on this to map a click
    back to a style. "All" (or an empty value) disables the filter. Otherwise
    styles are matched case-insensitively on their ``gender`` field.
    """
    show_all = not gender or gender == "All"
    wanted = "" if show_all else gender.lower()
    items = []
    indices = []
    for i, style in enumerate(STYLES):
        if style.img_rgba is None:
            # Not shown in the gallery, so it must not take a slot in indices either.
            continue
        if not show_all and style.gender.lower() != wanted:
            continue
        items.append((style.img_rgba, style.name))
        indices.append(i)
    return items, indices
def select_hairstyle(evt: gr.SelectData, filtered_inds: List[int]):
    """Translate a gallery click into an index into ``STYLES``.

    *filtered_inds* maps each gallery position to its ``STYLES`` index.
    Returns ``None`` when nothing is shown or the clicked position is out of range.
    """
    if not filtered_inds:
        return None
    position = evt.index
    if not 0 <= position < len(filtered_inds):
        return None
    return int(filtered_inds[position])
def update_source(source: str):
    """Return visibility updates for the (upload, webcam) columns.

    Each column is visible only when *source* names it.
    """
    show_upload = source == "Upload"
    show_webcam = source == "Webcam"
    return gr.update(visible=show_upload), gr.update(visible=show_webcam)
def on_apply(upload_img, webcam_img, input_source, selected_index, scale_tweak, vert_offset, horiz_offset, opacity):
    """Gradio callback: render the hairstyle preview and return ``(image, status)``."""
    preview, message = apply_hairstyle_impl(
        upload_img=upload_img,
        webcam_img=webcam_img,
        input_source=input_source,
        style_index=selected_index,
        scale_tweak=scale_tweak,
        vert_offset=vert_offset,
        horiz_offset=horiz_offset,
        opacity=opacity,
    )
    return preview, message
def on_random(filtered_indices: List[int]):
    """Pick a random ``STYLES`` index from the current gallery filter.

    Returns ``(index, status)``. *index* is ``None`` when the filter shows no styles.
    """
    if filtered_indices:
        import random

        return int(random.choice(filtered_indices)), "🎲 Random style selected!"
    return None, "ℹ️ No styles available for current filter."
def on_save(result_img: Optional[Image.Image]):
    """Write the preview to ``RESULTS_DIR`` as a uniquely named PNG.

    Accepts a PIL image or a NumPy array. ``gr.Image`` passes NumPy arrays to
    callbacks unless it was created with ``type="pil"``.

    Returns:
        ``(file_path, status)``. *file_path* is ``None`` when there is nothing to save.
    """
    if result_img is None:
        return None, "⚠️ Generate a preview first."
    if isinstance(result_img, np.ndarray):
        # ndarrays have no .save(); wrap them as a PIL image first.
        result_img = Image.fromarray(result_img)
    file_path = os.path.join(RESULTS_DIR, f"hairstyle_{uuid.uuid4().hex}.png")
    result_img.save(file_path, format="PNG")
    return file_path, "💾 Saved! Use the button below to download."
# ----------------------------
# UI
# ----------------------------
with gr.Blocks(theme=gr.themes.Soft(), css=".small-hint{font-size:12px;opacity:.8}") as demo:
    gr.Markdown("## 💇 Virtual Hairstyle Try‑On")
    gr.Markdown(
        "Upload a front‑facing photo or use your webcam. Click a hairstyle to select it, then fine‑tune using the controls."
    )
    status = gr.Textbox(label="Status", interactive=False)
    # STYLES indices of the thumbnails currently shown, in gallery order
    # (written by update_gallery, read by select_hairstyle and on_random).
    filtered_indices = gr.State([])

    with gr.Row():
        # Left column: input source, style gallery and selection.
        with gr.Column(scale=1):
            input_source = gr.Radio(["Upload", "Webcam"], value="Upload", label="Input Source")
            upload_col = gr.Column(visible=True)
            with upload_col:
                upload_img = gr.Image(sources=["upload"], type="pil", label="📷 Upload Your Photo (front‑facing)")
            webcam_col = gr.Column(visible=False)
            with webcam_col:
                webcam_img = gr.Image(sources=["webcam"], type="pil", label="📹 Live Webcam", streaming=True)
            gender_filter = gr.Dropdown(choices=["All", "Male", "Female"], value="All", label="🎭 Filter by Gender")
            hairstyle_gallery = gr.Gallery(
                label="🎨 Available Hairstyles (click to select)", columns=4, height=380, object_fit="contain"
            )
            # Hidden holder for the selected STYLES index.
            selected_index = gr.Number(value=None, visible=False)
            selected_label = gr.Markdown("*No style selected*", elem_classes=["small-hint"])
            random_btn = gr.Button("🎲 Random Style")
        # Right column: preview and fine-tuning controls.
        with gr.Column(scale=2):
            # type="pil" so callbacks that read the preview back (Save) receive a PIL
            # image; gr.Image otherwise passes inputs to callbacks as NumPy arrays.
            result_output = gr.Image(label="🔍 Preview Result", height=520, type="pil")
            with gr.Row():
                scale_tweak = gr.Slider(0.7, 1.4, value=1.0, step=0.01, label="Scale tweak")
                opacity = gr.Slider(0.6, 1.0, value=1.0, step=0.01, label="Opacity")
            with gr.Row():
                vert_offset = gr.Slider(-150, 150, value=0, step=1, label="Vertical offset (px)")
                horiz_offset = gr.Slider(-150, 150, value=0, step=1, label="Horizontal offset (px)")
            with gr.Row():
                apply_btn = gr.Button("✨ Apply Hairstyle", variant="primary")
                save_btn = gr.Button("💾 Save Preview")
                # gr.DownloadButton has no file_name parameter; the download uses the saved file's name.
                dl = gr.DownloadButton("⬇️ Download PNG")

    # Visibility switching
    input_source.change(update_source, inputs=input_source, outputs=[upload_col, webcam_col])

    # Gallery filtering / selection
    def _update_label(i):
        """Return the Markdown caption for selected STYLES index *i*, or a placeholder."""
        if i is None or not isinstance(i, (int, float)):
            return "*No style selected*"
        idx = int(i)
        if 0 <= idx < len(STYLES):
            return f"**Selected:** {STYLES[idx].name}"
        return "*No style selected*"

    gender_filter.change(update_gallery, inputs=gender_filter, outputs=[hairstyle_gallery, filtered_indices])
    hairstyle_gallery.select(select_hairstyle, inputs=filtered_indices, outputs=selected_index)
    selected_index.change(_update_label, inputs=selected_index, outputs=selected_label)
    random_btn.click(on_random, inputs=filtered_indices, outputs=[selected_index, status])

    # Apply + live preview
    apply_inputs = [upload_img, webcam_img, input_source, selected_index, scale_tweak, vert_offset, horiz_offset, opacity]
    apply_btn.click(on_apply, inputs=apply_inputs, outputs=[result_output, status])
    # A streaming webcam image delivers its frames through the .stream event;
    # run the live preview on each frame there.
    webcam_img.stream(on_apply, inputs=apply_inputs, outputs=[result_output, status])

    # Save & download
    def _save_and_link(img):
        """Save the preview, then point the download button at the new file."""
        path, msg = on_save(img)
        return msg, gr.update(value=path)

    save_btn.click(_save_and_link, inputs=[result_output], outputs=[status, dl])

    # Populate the gallery on page load using the default filter.
    demo.load(update_gallery, inputs=gender_filter, outputs=[hairstyle_gallery, filtered_indices])
# The module-level MediaPipe FACE_MESH instance is shared by every request, so
# queued events run one at a time. Gradio 4 replaced queue(concurrency_count=...)
# with default_concurrency_limit; the old argument is rejected at startup.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1)
    demo.launch()