# VirtualTryonFR1 / app.py
# NOTE: the original upload carried Hugging Face web-UI residue here
# ("RojaKatta's picture / Update app.py / 9c49035 verified / raw /
# history blame / 10.3 kB") — page chrome, not Python source; kept only
# as this comment so the file parses.
import os, json, tempfile, re
import cv2, numpy as np, gradio as gr
from PIL import Image
# =============== Paths ===============
# Directory containing this script; all asset lookups are relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Possible hairstyle-asset folders, probed in priority order.
CANDIDATES = [
    os.path.join(BASE_DIR, "hair"),
    os.path.join(BASE_DIR, "assets", "hairstyles"),
    os.path.join(BASE_DIR, "assets", "Hairstyles"),
    os.path.join(BASE_DIR, "hairstyles"),
]
# First existing candidate wins; otherwise fall back to ./hair and create it
# so later os.listdir calls don't fail on a fresh checkout.
HAIR_DIR = next((p for p in CANDIDATES if os.path.isdir(p)), None)
if HAIR_DIR is None:
    HAIR_DIR = os.path.join(BASE_DIR, "hair")
    os.makedirs(HAIR_DIR, exist_ok=True)
META_PATH = os.path.join(HAIR_DIR, "meta.json")  # optional per-style anchors
# =============== Dependencies ===============
try:
    import mediapipe as mp
except Exception as e:
    # Fail fast at import time with a pointer to the requirements pins instead
    # of surfacing a confusing error on the first inference call.
    raise RuntimeError(f"Mediapipe import failed. Check requirements pins. Details: {e}")
mp_face_mesh = mp.solutions.face_mesh
# MediaPipe FaceMesh landmark indices used as alignment anchors:
# 33 = left eye outer corner, 263 = right eye outer corner, 10 = mid-forehead.
LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
# =============== Helpers ===============
def natural_key(s: str):
    """Return a mixed str/int sort key so filenames order naturally ("hair2" < "hair10")."""
    key = []
    for token in re.split(r"(\d+)", s):
        key.append(int(token) if token.isdigit() else token.lower())
    return key
def load_hairstyles():
    """List the PNG filenames in HAIR_DIR, natural-sorted; [] if the dir vanished."""
    try:
        entries = os.listdir(HAIR_DIR)
    except FileNotFoundError:
        return []
    pngs = [name for name in entries if name.lower().endswith(".png")]
    return sorted(pngs, key=natural_key)
def load_meta():
    """Read the optional meta.json of per-style anchor points; {} on any problem."""
    if not os.path.exists(META_PATH):
        return {}
    try:
        with open(META_PATH, "r") as fh:
            data = json.load(fh)
    except Exception:
        # Corrupt or unreadable meta.json degrades to "no custom anchors".
        return {}
    return data if isinstance(data, dict) else {}
def premultiply_alpha(bgra):
    """Premultiply the BGR channels by alpha to suppress gray/white edge halos.

    Returns a BGRA uint8 array; the alpha channel is passed through unchanged.
    """
    color = bgra[:, :, :3].astype(np.float32) / 255.0
    alpha = bgra[:, :, 3:4].astype(np.float32) / 255.0
    premul = (color * alpha * 255.0).astype(np.uint8)
    return np.dstack([premul, bgra[:, :, 3]])
def load_hair_png(name):
    """Load a hairstyle PNG (BGRA) from HAIR_DIR and premultiply its alpha.

    Raises ValueError when the file is missing/unreadable or has no alpha
    channel. BUG FIX: a grayscale PNG loads as a 2-D array, so the original
    `hair.shape[2]` raised IndexError before the intended ValueError — guard
    `ndim` first so every bad asset yields the same clear error.
    """
    path = os.path.join(HAIR_DIR, name)
    hair = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGRA only when the PNG has alpha
    if hair is None or hair.ndim != 3 or hair.shape[2] != 4:
        raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
    return premultiply_alpha(hair)
def detect_face_keypoints(img_bgr):
    """Detect the three alignment anchors on a BGR photo.

    Returns a (3, 2) float32 array of pixel coordinates in the order
    (left eye outer, right eye outer, mid-forehead), or None if no face found.
    """
    height, width = img_bgr.shape[:2]
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    with mp_face_mesh.FaceMesh(
        static_image_mode=True, max_num_faces=1, refine_landmarks=True,
        min_detection_confidence=0.6
    ) as mesh:
        result = mesh.process(rgb)
    if not result.multi_face_landmarks:
        return None
    landmarks = result.multi_face_landmarks[0].landmark

    def to_pixels(index):
        # Landmarks are normalized [0,1]; convert to absolute pixel coords.
        return np.array([landmarks[index].x * width, landmarks[index].y * height],
                        dtype=np.float32)

    names = ("left_eye_outer", "right_eye_outer", "mid_forehead")
    return np.stack([to_pixels(LM[n]) for n in names])
def hair_reference_points(hair_bgra, filename, meta):
    """Anchor points (left temple, right temple, mid-top) on the hair image.

    Uses the 3x2 point list from meta.json when one exists for this file;
    otherwise falls back to fixed fractions of the image size that fit many
    styles acceptably (add entries to meta.json for a pixel-perfect fit).
    """
    h, w = hair_bgra.shape[:2]
    if filename in meta:
        custom = np.array(meta[filename], dtype=np.float32)
        if custom.shape == (3, 2):
            return custom
    defaults = [
        (0.30 * w, 0.60 * h),  # left temple
        (0.70 * w, 0.60 * h),  # right temple
        (0.50 * w, 0.40 * h),  # mid-top
    ]
    return np.array(defaults, dtype=np.float32)
def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
    """Affine-warp the hair layer into the photo frame and alpha-composite it.

    M is a 2x3 affine matrix mapping hair coordinates to photo coordinates;
    opacity globally scales the hair's alpha. Returns a uint8 BGR image.
    """
    out_h, out_w = base_bgr.shape[:2]
    color = hair_bgra[:, :, :3]
    alpha = hair_bgra[:, :, 3] / 255.0
    # BORDER_CONSTANT with black/zero keeps everything outside the warped hair
    # fully transparent instead of smearing edge pixels.
    warped_color = cv2.warpAffine(
        color, M, (out_w, out_h), flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0))
    warped_alpha = cv2.warpAffine(
        alpha, M, (out_w, out_h), flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    weight = np.clip(warped_alpha * opacity, 0, 1)[..., None]
    blended = weight * warped_color + (1 - weight) * base_bgr
    return blended.astype(np.uint8)
def apply_tryon(image, hairstyle, scale_pct, dx, dy, opacity, meta):
    """Composite the chosen hairstyle onto the uploaded photo.

    Returns (rgb_image_or_None, status_message). Deliberately skips any
    head-mask/segmentation step so the neck is never cut by hard mask lines.
    """
    if image is None:
        return None, "Upload a photo first."
    if not hairstyle:
        return np.array(image), "Pick a hairstyle."

    img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    face_pts = detect_face_keypoints(img_bgr)
    if face_pts is None:
        return image, "No face detected. Use a brighter, front-facing photo."

    hair = load_hair_png(hairstyle)
    src_pts = hair_reference_points(hair, hairstyle, meta)

    # Destination anchors = detected facial points shifted by the user's nudges.
    dst_pts = face_pts.copy()
    dst_pts[:, 0] += dx
    dst_pts[:, 1] += dy

    # Scale the hair anchors about their centroid (no explicit rotation —
    # the similarity solver handles the rest).
    centroid = src_pts.mean(axis=0)
    factor = max(0.5, scale_pct / 100.0)
    scaled_src = (src_pts - centroid) * factor + centroid

    M, _ = cv2.estimateAffinePartial2D(scaled_src, dst_pts, method=cv2.LMEDS)
    if M is None:
        return image, "Alignment failed for this image/style."

    composed = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)
    return cv2.cvtColor(composed, cv2.COLOR_BGR2RGB), "OK"
def save_png_to_tmp(img, filename="output_tryon.png"):
    """Write the preview image to the system temp dir and return its path.

    Accepts either a numpy array or a PIL image; raises gr.Error when there is
    nothing to save so Gradio surfaces the message in the UI.
    """
    if img is None:
        raise gr.Error("No image to save. Click a hairstyle or 'Apply' first.")
    out_path = os.path.join(tempfile.gettempdir(), filename)
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    pil_img.save(out_path)
    return out_path
# ---- white thumbnails with labels ----
def thumb_on_white(hair_bgra, max_h=220):
h, w = hair_bgra.shape[:2]
scale = min(1.0, max_h / h)
hair_bgra = cv2.resize(hair_bgra, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_LINEAR)
h, w = hair_bgra.shape[:2]
bg_rgb = np.full((h, w, 3), 255, dtype=np.uint8)
a = (hair_bgra[:, :, 3:4] / 255.0)
comp = (a * hair_bgra[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
return cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
def build_gallery_items(files):
    """Build (thumbnail, caption) pairs for the gallery, numbering each style.

    Assets that fail to load or render are silently skipped so one bad PNG
    cannot break the whole gallery.
    """
    items = []
    for number, fname in enumerate(files, start=1):
        try:
            thumb = thumb_on_white(load_hair_png(fname))
        except Exception:
            continue
        items.append((thumb, f"{number}. {fname}"))  # show number + filename
    return items
# =============== UI ===============
def build_ui():
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Returns the Blocks instance (launched by the caller / Spaces runtime).
    """
    META = load_meta()
    HAIR_FILES = load_hairstyles()
    with gr.Blocks(title="Salon Hairstyle Virtual Try-On (Simple)") as demo:
        gr.Markdown("Upload a photo, then **click a hairstyle** below. Use a few sliders if needed, then **Save result**.")
        selected_file = gr.State(None)  # currently selected hairstyle filename
        meta_state = gr.State(META)
        files_state = gr.State(HAIR_FILES)
        with gr.Tabs():
            with gr.Tab("📷 Photo (Upload)"):
                with gr.Row():
                    in_img = gr.Image(label="Input photo (JPEG/PNG)", type="pil", height=360, sources=["upload"])
                    out_img = gr.Image(label="Preview", height=360)
                with gr.Row():
                    apply_btn = gr.Button("✨ Apply (optional)")
                    save_btn = gr.Button("💾 Save result")
                save_file = gr.File(label="Saved file", visible=False)
                with gr.Row():
                    refresh = gr.Button("🔄 Refresh styles")
                    count_md = gr.Markdown(f"Found {len(HAIR_FILES)} hairstyles.")
                gallery = gr.Gallery(
                    label="Hairstyles (click to apply)",
                    value=build_gallery_items(HAIR_FILES),
                    columns=6, rows=3, height=520,
                    allow_preview=False, object_fit="contain", show_label=True
                )
                with gr.Accordion("Fine-tune (simple)", open=True):
                    with gr.Row():
                        scale = gr.Slider(50, 200, 100, 1, label="Scale (temple distance %)")  # main size
                        opacity = gr.Slider(0.4, 1.0, 1.0, 0.05, label="Hair opacity")
                    with gr.Row():
                        dx = gr.Slider(-200, 200, 0, 1, label="Left ↔ Right (px)")
                        dy = gr.Slider(-200, 200, 0, 1, label="Up ↕ Down (px)")
                status = gr.Markdown("")

        # ----- actions -----
        def do_apply(im, hairfile, s, dxv, dyv, op, meta):
            return apply_tryon(im, hairfile, s, dxv, dyv, op, meta)

        # 1) click a tile -> set selected file AND auto-apply
        # BUG FIX: Gradio only injects event data into a parameter annotated
        # with gr.SelectData; without the annotation the handler is called with
        # the 7 declared inputs against 8 parameters and fails.
        def on_gallery_select(evt: gr.SelectData, files, im, s, dxv, dyv, op, meta):
            idx = getattr(evt, "index", None)
            if idx is None or not files:
                return None, gr.update(), None
            idx = max(0, min(idx, len(files) - 1))
            hairfile = files[idx]
            out, msg = do_apply(im, hairfile, s, dxv, dyv, op, meta)
            return hairfile, out, msg

        gallery.select(
            on_gallery_select,
            inputs=[files_state, in_img, scale, dx, dy, opacity, meta_state],
            outputs=[selected_file, out_img, status]
        )

        # 2) Apply button (useful after slider tweaks)
        apply_btn.click(
            fn=do_apply,
            inputs=[in_img, selected_file, scale, dx, dy, opacity, meta_state],
            outputs=[out_img, status]
        )

        # 3) Save
        def do_save(im):
            path = save_png_to_tmp(im, "output_tryon.png")
            # BUG FIX: gr.File.update() was removed in Gradio 4;
            # gr.update() is the version-stable replacement.
            return gr.update(value=path, visible=True)

        save_btn.click(fn=do_save, inputs=[out_img], outputs=[save_file])

        # 4) Refresh styles
        def do_refresh():
            files = load_hairstyles()
            items = build_gallery_items(files)
            msg = f"Found {len(files)} hairstyles."
            # Keep selection if name still exists
            return items, files, msg

        refresh.click(fn=do_refresh, inputs=[], outputs=[gallery, files_state, count_md])
    return demo
# export for Spaces
# Hugging Face Spaces discovers a module-level Blocks object; expose it under
# both conventional names (`app` and `demo`) pointing at the same instance.
app = build_ui()
demo = app
if __name__ == "__main__":
    # Local development entry point; Spaces launches the exported object itself.
    app.launch()