Commit ·
cb225ac
unverified ·
0
Parent(s):
Add hair color changer with segmentation model
Browse filesImplement hair color changing functionality using semantic segmentation.
app.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
|
| 10 |
+
|
| 11 |
+
# =========================
|
| 12 |
+
# CONFIG
|
| 13 |
+
# =========================
|
| 14 |
+
MODEL_ID = "jonathandinu/face-parsing"
|
| 15 |
+
HAIR_ID = 13 # clase "hair" en este modelo
|
| 16 |
+
|
| 17 |
+
# Presets de color (puedes editar)
|
| 18 |
+
COLOR_PRESETS = {
|
| 19 |
+
"Personalizado (picker)": None,
|
| 20 |
+
"Negro": "#121212",
|
| 21 |
+
"Castaño": "#4b2e1f",
|
| 22 |
+
"Rubio": "#d8c27a",
|
| 23 |
+
"Platinado": "#d9d9d9",
|
| 24 |
+
"Rojo": "#c1121f",
|
| 25 |
+
"Azul": "#0077b6",
|
| 26 |
+
"Verde": "#2a9d8f",
|
| 27 |
+
"Morado": "#7209b7",
|
| 28 |
+
"Rosa": "#ff4d8d",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
def get_device():
|
| 32 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
+
|
| 34 |
+
DEVICE = get_device()
|
| 35 |
+
|
| 36 |
+
# Recomendado en Spaces para evitar timeouts raros al bajar modelos
|
| 37 |
+
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
|
| 38 |
+
os.environ.setdefault("HF_HUB_READ_TIMEOUT", "60")
|
| 39 |
+
os.environ.setdefault("HF_HUB_CONNECT_TIMEOUT", "30")
|
| 40 |
+
|
| 41 |
+
# Limita threads en CPU (opcional, mejora estabilidad)
|
| 42 |
+
try:
|
| 43 |
+
torch.set_num_threads(min(4, os.cpu_count() or 1))
|
| 44 |
+
except Exception:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
# =========================
|
| 48 |
+
# LOAD MODEL (una sola vez)
|
| 49 |
+
# =========================
|
| 50 |
+
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
|
| 51 |
+
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID).to(DEVICE)
|
| 52 |
+
model.eval()
|
| 53 |
+
|
| 54 |
+
# =========================
|
| 55 |
+
# UTIL: color parsing robusto
|
| 56 |
+
# =========================
|
| 57 |
+
def parse_color_to_rgb(color):
|
| 58 |
+
"""
|
| 59 |
+
Acepta:
|
| 60 |
+
- "#RRGGBB"
|
| 61 |
+
- "#RRGGBBAA" (ignora AA)
|
| 62 |
+
- "#RGB"
|
| 63 |
+
- "rgb(r,g,b)" / "rgba(r,g,b,a)"
|
| 64 |
+
- (r,g,b) o [r,g,b]
|
| 65 |
+
- dict con {"hex": "..."} (por si acaso)
|
| 66 |
+
Devuelve (r,g,b) en 0..255
|
| 67 |
+
"""
|
| 68 |
+
if color is None:
|
| 69 |
+
return (255, 0, 0)
|
| 70 |
+
|
| 71 |
+
if isinstance(color, dict):
|
| 72 |
+
color = color.get("hex") or color.get("value") or color.get("color")
|
| 73 |
+
|
| 74 |
+
if isinstance(color, (tuple, list)) and len(color) >= 3:
|
| 75 |
+
return (int(color[0]), int(color[1]), int(color[2]))
|
| 76 |
+
|
| 77 |
+
if not isinstance(color, str):
|
| 78 |
+
raise ValueError(f"Formato de color no soportado: {type(color)} -> {color}")
|
| 79 |
+
|
| 80 |
+
s = color.strip()
|
| 81 |
+
|
| 82 |
+
# rgb/rgba(...)
|
| 83 |
+
m = re.match(r"rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", s.lower())
|
| 84 |
+
if m:
|
| 85 |
+
r, g, b = map(int, m.groups())
|
| 86 |
+
r = max(0, min(255, r))
|
| 87 |
+
g = max(0, min(255, g))
|
| 88 |
+
b = max(0, min(255, b))
|
| 89 |
+
return (r, g, b)
|
| 90 |
+
|
| 91 |
+
# hex
|
| 92 |
+
if s.startswith("#"):
|
| 93 |
+
h = s[1:]
|
| 94 |
+
if len(h) == 3: # #RGB -> #RRGGBB
|
| 95 |
+
h = "".join([c * 2 for c in h])
|
| 96 |
+
if len(h) == 8: # #RRGGBBAA -> ignora AA
|
| 97 |
+
h = h[:6]
|
| 98 |
+
if len(h) != 6:
|
| 99 |
+
raise ValueError(f"HEX inválido: {s} (usa #RRGGBB)")
|
| 100 |
+
r = int(h[0:2], 16)
|
| 101 |
+
g = int(h[2:4], 16)
|
| 102 |
+
b = int(h[4:6], 16)
|
| 103 |
+
return (r, g, b)
|
| 104 |
+
|
| 105 |
+
raise ValueError(f"Color inválido: {color}")
|
| 106 |
+
|
| 107 |
+
# =========================
|
| 108 |
+
# IMAGE UTILS
|
| 109 |
+
# =========================
|
| 110 |
+
def resize_keep_aspect(pil: Image.Image, max_side: int) -> Image.Image:
|
| 111 |
+
w, h = pil.size
|
| 112 |
+
m = max(w, h)
|
| 113 |
+
if m <= max_side:
|
| 114 |
+
return pil
|
| 115 |
+
scale = max_side / float(m)
|
| 116 |
+
nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
|
| 117 |
+
return pil.resize((nw, nh), Image.BILINEAR)
|
| 118 |
+
|
| 119 |
+
@torch.inference_mode()
|
| 120 |
+
def get_hair_mask(image: Image.Image, max_side: int = 640) -> Image.Image:
|
| 121 |
+
"""
|
| 122 |
+
Devuelve una máscara L (0..255) del cabello, al tamaño original.
|
| 123 |
+
"""
|
| 124 |
+
image = image.convert("RGB")
|
| 125 |
+
ow, oh = image.size
|
| 126 |
+
|
| 127 |
+
infer_img = resize_keep_aspect(image, max_side=max_side)
|
| 128 |
+
|
| 129 |
+
inputs = processor(images=infer_img, return_tensors="pt")
|
| 130 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 131 |
+
|
| 132 |
+
outputs = model(**inputs)
|
| 133 |
+
logits = outputs.logits # (B,C,h,w)
|
| 134 |
+
|
| 135 |
+
up = F.interpolate(
|
| 136 |
+
logits,
|
| 137 |
+
size=infer_img.size[::-1], # (H,W)
|
| 138 |
+
mode="bilinear",
|
| 139 |
+
align_corners=False,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
labels = up.argmax(dim=1)[0] # (H,W)
|
| 143 |
+
hair = (labels == HAIR_ID).to(torch.uint8).cpu().numpy() * 255
|
| 144 |
+
|
| 145 |
+
mask = Image.fromarray(hair, mode="L")
|
| 146 |
+
|
| 147 |
+
if mask.size != (ow, oh):
|
| 148 |
+
mask = mask.resize((ow, oh), Image.NEAREST)
|
| 149 |
+
|
| 150 |
+
return mask
|
| 151 |
+
|
| 152 |
+
def refine_mask(mask: Image.Image, close_kernel: int = 9, feather: int = 9) -> Image.Image:
|
| 153 |
+
m = np.array(mask.convert("L"))
|
| 154 |
+
|
| 155 |
+
# binariza
|
| 156 |
+
_, mb = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
|
| 157 |
+
|
| 158 |
+
# close
|
| 159 |
+
k = max(3, int(close_kernel) | 1) # impar
|
| 160 |
+
kernel = np.ones((k, k), np.uint8)
|
| 161 |
+
mb = cv2.morphologyEx(mb, cv2.MORPH_CLOSE, kernel, iterations=1)
|
| 162 |
+
|
| 163 |
+
# feather (blur)
|
| 164 |
+
f = max(1, int(feather))
|
| 165 |
+
if f % 2 == 0:
|
| 166 |
+
f += 1
|
| 167 |
+
mb = cv2.GaussianBlur(mb, (f, f), 0)
|
| 168 |
+
|
| 169 |
+
return Image.fromarray(mb, mode="L")
|
| 170 |
+
|
| 171 |
+
def recolor_hair_lab(
|
| 172 |
+
image: Image.Image,
|
| 173 |
+
mask: Image.Image,
|
| 174 |
+
color_input,
|
| 175 |
+
strength: float = 0.85,
|
| 176 |
+
brighten: float = 0.0,
|
| 177 |
+
) -> Image.Image:
|
| 178 |
+
"""
|
| 179 |
+
Recolor en LAB para mantener sombras/luces.
|
| 180 |
+
strength: 0..1 confirmando cuánto entra el color
|
| 181 |
+
brighten: -0.3..0.3 (opcional, solo en cabello)
|
| 182 |
+
"""
|
| 183 |
+
image_rgb = np.array(image.convert("RGB"))
|
| 184 |
+
mask_f = np.array(mask.convert("L")).astype(np.float32) / 255.0
|
| 185 |
+
alpha = np.clip(mask_f * float(strength), 0.0, 1.0)[..., None] # (H,W,1)
|
| 186 |
+
|
| 187 |
+
# RGB -> BGR -> LAB
|
| 188 |
+
bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
|
| 189 |
+
lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 190 |
+
|
| 191 |
+
# color objetivo -> LAB
|
| 192 |
+
r, g, b = parse_color_to_rgb(color_input)
|
| 193 |
+
target_bgr = np.array([[[b, g, r]]], dtype=np.uint8)
|
| 194 |
+
target_lab = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)[0, 0]
|
| 195 |
+
|
| 196 |
+
# Mezcla a/b hacia el objetivo
|
| 197 |
+
lab[:, :, 1] = lab[:, :, 1] * (1.0 - alpha[:, :, 0]) + target_lab[1] * alpha[:, :, 0]
|
| 198 |
+
lab[:, :, 2] = lab[:, :, 2] * (1.0 - alpha[:, :, 0]) + target_lab[2] * alpha[:, :, 0]
|
| 199 |
+
|
| 200 |
+
# Ajuste de brillo en cabello
|
| 201 |
+
if abs(brighten) > 1e-6:
|
| 202 |
+
lab[:, :, 0] = np.clip(lab[:, :, 0] + (brighten * 255.0) * alpha[:, :, 0], 0, 255)
|
| 203 |
+
|
| 204 |
+
lab_u8 = np.clip(lab, 0, 255).astype(np.uint8)
|
| 205 |
+
out_bgr = cv2.cvtColor(lab_u8, cv2.COLOR_LAB2BGR)
|
| 206 |
+
out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
|
| 207 |
+
|
| 208 |
+
return Image.fromarray(out_rgb)
|
| 209 |
+
|
| 210 |
+
# =========================
|
| 211 |
+
# GRADIO RUN
|
| 212 |
+
# =========================
|
| 213 |
+
def run(image, preset, picked_color, strength, brighten, max_side, close_kernel, feather):
|
| 214 |
+
try:
|
| 215 |
+
if image is None:
|
| 216 |
+
return None, None, "Sube una imagen primero."
|
| 217 |
+
|
| 218 |
+
# color final
|
| 219 |
+
preset_hex = COLOR_PRESETS.get(preset)
|
| 220 |
+
final_color = (picked_color or "#ff0000") if preset_hex is None else preset_hex
|
| 221 |
+
|
| 222 |
+
# máscara
|
| 223 |
+
raw_mask = get_hair_mask(image, max_side=int(max_side))
|
| 224 |
+
mask = refine_mask(raw_mask, close_kernel=int(close_kernel), feather=int(feather))
|
| 225 |
+
|
| 226 |
+
# si la máscara salió vacía
|
| 227 |
+
if np.mean(np.array(mask)) < 2.0:
|
| 228 |
+
return image, mask, "No detecté cabello en esta foto. Prueba otra (mejor luz/frente)."
|
| 229 |
+
|
| 230 |
+
# recolor
|
| 231 |
+
result = recolor_hair_lab(
|
| 232 |
+
image=image,
|
| 233 |
+
mask=mask,
|
| 234 |
+
color_input=final_color,
|
| 235 |
+
strength=float(strength),
|
| 236 |
+
brighten=float(brighten),
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
return result, mask, f"OK ✅ Color aplicado: {final_color}"
|
| 240 |
+
|
| 241 |
+
except Exception as e:
|
| 242 |
+
# devuelve el error visible en la app
|
| 243 |
+
return None, None, f"ERROR: {type(e).__name__}: {e}"
|
| 244 |
+
|
| 245 |
+
DESCRIPTION = """
|
| 246 |
+
Sube una foto y cambia el color del cabello.
|
| 247 |
+
- Segmentación de cabello (hair mask)
|
| 248 |
+
- Recolor en LAB para conservar sombras/luces
|
| 249 |
+
"""
|
| 250 |
+
|
| 251 |
+
with gr.Blocks() as demo:
|
| 252 |
+
gr.Markdown("# 🎨 Cambiar color de cabello")
|
| 253 |
+
gr.Markdown(DESCRIPTION)
|
| 254 |
+
|
| 255 |
+
with gr.Row():
|
| 256 |
+
inp = gr.Image(label="Tu foto", type="pil")
|
| 257 |
+
out = gr.Image(label="Resultado", type="pil")
|
| 258 |
+
|
| 259 |
+
with gr.Accordion("Controles", open=True):
|
| 260 |
+
preset = gr.Dropdown(
|
| 261 |
+
label="Preset",
|
| 262 |
+
choices=list(COLOR_PRESETS.keys()),
|
| 263 |
+
value="Personalizado (picker)",
|
| 264 |
+
)
|
| 265 |
+
picked_color = gr.ColorPicker(label="Color personalizado", value="#ff0000")
|
| 266 |
+
strength = gr.Slider(0.0, 1.0, value=0.85, step=0.05, label="Intensidad")
|
| 267 |
+
brighten = gr.Slider(-0.3, 0.3, value=0.0, step=0.05, label="Brillo cabello (opcional)")
|
| 268 |
+
max_side = gr.Slider(384, 1024, value=640, step=64, label="Resolución segmentación")
|
| 269 |
+
close_kernel = gr.Slider(3, 21, value=9, step=2, label="Cerrar huecos (máscara)")
|
| 270 |
+
feather = gr.Slider(1, 31, value=9, step=2, label="Suavizado bordes (máscara)")
|
| 271 |
+
|
| 272 |
+
btn = gr.Button("Aplicar")
|
| 273 |
+
mask_out = gr.Image(label="Máscara (debug)", type="pil")
|
| 274 |
+
status = gr.Textbox(label="Estado", value="Listo.")
|
| 275 |
+
|
| 276 |
+
btn.click(
|
| 277 |
+
fn=run,
|
| 278 |
+
inputs=[inp, preset, picked_color, strength, brighten, max_side, close_kernel, feather],
|
| 279 |
+
outputs=[out, mask_out, status],
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
if __name__ == "__main__":
|
| 283 |
+
demo.launch()
|