# app.py — hair color changer with segmentation model
# (originally published by jeysshon, commit cb225ac)
import os
import re
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
# =========================
# CONFIG
# =========================
# Face-parsing segmentation checkpoint on the Hugging Face Hub.
MODEL_ID = "jonathandinu/face-parsing"
HAIR_ID = 13  # "hair" class index in this model's label map
# Color presets (feel free to edit). Values are "#RRGGBB" hex strings;
# None means "use the custom color picker" instead of a fixed color.
COLOR_PRESETS = {
    "Personalizado (picker)": None,
    "Negro": "#121212",
    "Castaño": "#4b2e1f",
    "Rubio": "#d8c27a",
    "Platinado": "#d9d9d9",
    "Rojo": "#c1121f",
    "Azul": "#0077b6",
    "Verde": "#2a9d8f",
    "Morado": "#7209b7",
    "Rosa": "#ff4d8d",
}
def get_device() -> str:
    """Return the torch device string: "cuda" when a GPU is visible, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
DEVICE = get_device()
# Recommended on HF Spaces to avoid odd timeouts when downloading models.
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
os.environ.setdefault("HF_HUB_READ_TIMEOUT", "60")
os.environ.setdefault("HF_HUB_CONNECT_TIMEOUT", "30")
# Limit CPU threads (optional, improves stability on shared hardware).
try:
    torch.set_num_threads(min(4, os.cpu_count() or 1))
except Exception:
    # Best-effort tuning only; thread configuration is non-critical.
    pass
# =========================
# LOAD MODEL (one time only, at import)
# =========================
processor = AutoImageProcessor.from_pretrained(MODEL_ID)  # handles resize/normalize
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()  # inference mode: disables dropout / BN updates
# =========================
# UTIL: color parsing robusto
# =========================
def parse_color_to_rgb(color):
    """
    Parse a color given in any of several common formats.

    Accepts:
      - "#RRGGBB"
      - "#RRGGBBAA" (alpha ignored)
      - "#RGB" and "#RGBA" (short hex; alpha ignored)
      - "rgb(r,g,b)" / "rgba(r,g,b,a)"
      - (r, g, b) tuples or [r, g, b] lists
      - dicts with a "hex"/"value"/"color" key (some color-picker widgets)

    Returns:
        (r, g, b) tuple with each channel clamped to 0..255.

    Raises:
        ValueError: for unsupported types or malformed color strings.
    """
    if color is None:
        # Default fallback color (red) when nothing was picked.
        return (255, 0, 0)
    if isinstance(color, dict):
        color = color.get("hex") or color.get("value") or color.get("color")
    if isinstance(color, (tuple, list)) and len(color) >= 3:
        # Clamp instead of trusting caller-provided channel values.
        return tuple(max(0, min(255, int(c))) for c in color[:3])
    if not isinstance(color, str):
        raise ValueError(f"Formato de color no soportado: {type(color)} -> {color}")
    s = color.strip()
    # rgb(...) / rgba(...) — a trailing alpha component, if any, is ignored
    m = re.match(r"rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", s.lower())
    if m:
        r, g, b = (max(0, min(255, int(v))) for v in m.groups())
        return (r, g, b)
    # hex
    if s.startswith("#"):
        h = s[1:]
        if len(h) in (3, 4):  # #RGB / #RGBA -> duplicate each nibble
            h = "".join(c * 2 for c in h)
        if len(h) == 8:  # #RRGGBBAA -> drop alpha
            h = h[:6]
        if len(h) != 6:
            raise ValueError(f"HEX inválido: {s} (usa #RRGGBB)")
        r = int(h[0:2], 16)
        g = int(h[2:4], 16)
        b = int(h[4:6], 16)
        return (r, g, b)
    raise ValueError(f"Color inválido: {color}")
# =========================
# IMAGE UTILS
# =========================
def resize_keep_aspect(pil: Image.Image, max_side: int) -> Image.Image:
    """Downscale *pil* so its longest side is at most *max_side*, keeping aspect ratio.

    Images already within the limit are returned unchanged (never upscaled).
    """
    width, height = pil.size
    longest = max(width, height)
    if longest <= max_side:
        return pil
    ratio = max_side / float(longest)
    new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
    return pil.resize(new_size, Image.BILINEAR)
@torch.inference_mode()
def get_hair_mask(image: Image.Image, max_side: int = 640) -> Image.Image:
    """
    Return an "L"-mode (grayscale 0..255) hair mask at the original image size.

    255 marks hair pixels, 0 everything else. Segmentation runs at a reduced
    resolution (longest side <= max_side) for speed; the mask is resized back
    to the input size afterwards.
    """
    image = image.convert("RGB")
    ow, oh = image.size
    infer_img = resize_keep_aspect(image, max_side=max_side)
    inputs = processor(images=infer_img, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    outputs = model(**inputs)
    logits = outputs.logits  # (B, C, h, w) at the model's internal resolution
    # Upsample the logits to the inference-image size before taking argmax.
    up = F.interpolate(
        logits,
        size=infer_img.size[::-1],  # PIL size is (W, H); interpolate expects (H, W)
        mode="bilinear",
        align_corners=False,
    )
    labels = up.argmax(dim=1)[0]  # (H, W) per-pixel class indices
    # Binary mask: 255 where the predicted class is hair.
    hair = (labels == HAIR_ID).to(torch.uint8).cpu().numpy() * 255
    mask = Image.fromarray(hair, mode="L")
    if mask.size != (ow, oh):
        # NEAREST keeps the mask strictly binary (no interpolated gray values).
        mask = mask.resize((ow, oh), Image.NEAREST)
    return mask
def refine_mask(mask: Image.Image, close_kernel: int = 9, feather: int = 9) -> Image.Image:
    """Clean up a raw hair mask: binarize, close small holes, then feather the edges.

    close_kernel: morphological-close kernel size (forced odd, minimum 3).
    feather:      Gaussian-blur kernel size for soft edges (forced odd, minimum 1).
    """
    gray = np.array(mask.convert("L"))
    # Hard threshold at the midpoint -> strictly 0/255 values.
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # Morphological close fills small gaps inside the hair region.
    ksize = max(3, int(close_kernel) | 1)  # bitwise-or 1 guarantees an odd size
    kernel = np.ones((ksize, ksize), np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)
    # Feathering: blur the hard edge so the recolor blends smoothly.
    fsize = max(1, int(feather))
    fsize += 1 - (fsize % 2)  # bump even sizes up to the next odd value
    binary = cv2.GaussianBlur(binary, (fsize, fsize), 0)
    return Image.fromarray(binary, mode="L")
def recolor_hair_lab(
    image: Image.Image,
    mask: Image.Image,
    color_input,
    strength: float = 0.85,
    brighten: float = 0.0,
) -> Image.Image:
    """
    Recolor the masked (hair) region in LAB space to preserve shadows/highlights.

    Only the a/b (chroma) channels are blended toward the target color; the
    L (lightness) channel is left untouched unless *brighten* is set, which
    keeps the original shading structure of the hair.

    strength: 0..1 — how strongly the target color replaces the original chroma
    brighten: -0.3..0.3 — optional lightness offset applied only inside the mask
    """
    image_rgb = np.array(image.convert("RGB"))
    # Per-pixel blend weight: mask value (0..1) scaled by strength.
    mask_f = np.array(mask.convert("L")).astype(np.float32) / 255.0
    alpha = np.clip(mask_f * float(strength), 0.0, 1.0)[..., None]  # (H,W,1)
    # RGB -> BGR -> LAB (OpenCV color conversions operate on BGR)
    bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
    # Target color -> LAB via a single-pixel image round-trip
    r, g, b = parse_color_to_rgb(color_input)
    target_bgr = np.array([[[b, g, r]]], dtype=np.uint8)
    target_lab = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)[0, 0]
    # Blend the a/b chroma channels toward the target
    lab[:, :, 1] = lab[:, :, 1] * (1.0 - alpha[:, :, 0]) + target_lab[1] * alpha[:, :, 0]
    lab[:, :, 2] = lab[:, :, 2] * (1.0 - alpha[:, :, 0]) + target_lab[2] * alpha[:, :, 0]
    # Optional brightness adjustment, applied to hair pixels only
    if abs(brighten) > 1e-6:
        lab[:, :, 0] = np.clip(lab[:, :, 0] + (brighten * 255.0) * alpha[:, :, 0], 0, 255)
    lab_u8 = np.clip(lab, 0, 255).astype(np.uint8)
    out_bgr = cv2.cvtColor(lab_u8, cv2.COLOR_LAB2BGR)
    out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(out_rgb)
# =========================
# GRADIO RUN
# =========================
def run(image, preset, picked_color, strength, brighten, max_side, close_kernel, feather):
    """Gradio callback: segment the hair, refine the mask, and recolor.

    Returns (result_image, debug_mask, status_message). Any failure is
    surfaced through the status string instead of raising into the UI.
    """
    try:
        if image is None:
            return None, None, "Sube una imagen primero."
        # Resolve the final color: a concrete preset overrides the picker.
        preset_hex = COLOR_PRESETS.get(preset)
        if preset_hex is None:
            final_color = picked_color or "#ff0000"
        else:
            final_color = preset_hex
        # Build and clean the hair mask.
        mask = refine_mask(
            get_hair_mask(image, max_side=int(max_side)),
            close_kernel=int(close_kernel),
            feather=int(feather),
        )
        # Bail out early when the mask came back (nearly) empty.
        if np.mean(np.array(mask)) < 2.0:
            return image, mask, "No detecté cabello en esta foto. Prueba otra (mejor luz/frente)."
        result = recolor_hair_lab(
            image=image,
            mask=mask,
            color_input=final_color,
            strength=float(strength),
            brighten=float(brighten),
        )
        return result, mask, f"OK ✅ Color aplicado: {final_color}"
    except Exception as e:
        # Show the error inside the app rather than crashing the worker.
        return None, None, f"ERROR: {type(e).__name__}: {e}"
# User-facing description (Spanish) rendered under the page title.
DESCRIPTION = """
Sube una foto y cambia el color del cabello.
- Segmentación de cabello (hair mask)
- Recolor en LAB para conservar sombras/luces
"""
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Cambiar color de cabello")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        inp = gr.Image(label="Tu foto", type="pil")
        out = gr.Image(label="Resultado", type="pil")
    with gr.Accordion("Controles", open=True):
        preset = gr.Dropdown(
            label="Preset",
            choices=list(COLOR_PRESETS.keys()),
            value="Personalizado (picker)",
        )
        # Picker only takes effect with the "Personalizado (picker)" preset.
        picked_color = gr.ColorPicker(label="Color personalizado", value="#ff0000")
        strength = gr.Slider(0.0, 1.0, value=0.85, step=0.05, label="Intensidad")
        brighten = gr.Slider(-0.3, 0.3, value=0.0, step=0.05, label="Brillo cabello (opcional)")
        # Lower segmentation resolution = faster inference, coarser mask.
        max_side = gr.Slider(384, 1024, value=640, step=64, label="Resolución segmentación")
        close_kernel = gr.Slider(3, 21, value=9, step=2, label="Cerrar huecos (máscara)")
        feather = gr.Slider(1, 31, value=9, step=2, label="Suavizado bordes (máscara)")
    btn = gr.Button("Aplicar")
    mask_out = gr.Image(label="Máscara (debug)", type="pil")
    status = gr.Textbox(label="Estado", value="Listo.")
    # Wire the button to run(); outputs map to (result, debug mask, status).
    btn.click(
        fn=run,
        inputs=[inp, preset, picked_color, strength, brighten, max_side, close_kernel, feather],
        outputs=[out, mask_out, status],
    )
if __name__ == "__main__":
    demo.launch()