jeysshon committed on
Commit cb225ac · unverified · 0 Parent(s)

Add hair color changer with segmentation model


Implement hair color changing functionality using semantic segmentation.

Files changed (1)
  1. app.py +283 -0
app.py ADDED
@@ -0,0 +1,283 @@
+ import os
+ import re
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import gradio as gr
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
+
+ # =========================
+ # CONFIG
+ # =========================
+ MODEL_ID = "jonathandinu/face-parsing"
+ HAIR_ID = 13  # id of the "hair" class in this model
+
+ # Color presets (edit as you like)
+ COLOR_PRESETS = {
+     "Custom (picker)": None,
+     "Black": "#121212",
+     "Brown": "#4b2e1f",
+     "Blonde": "#d8c27a",
+     "Platinum": "#d9d9d9",
+     "Red": "#c1121f",
+     "Blue": "#0077b6",
+     "Green": "#2a9d8f",
+     "Purple": "#7209b7",
+     "Pink": "#ff4d8d",
+ }
+
+ def get_device():
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+ DEVICE = get_device()
+
+ # Recommended on Spaces to avoid flaky timeouts when downloading models
+ os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
+ os.environ.setdefault("HF_HUB_READ_TIMEOUT", "60")
+ os.environ.setdefault("HF_HUB_CONNECT_TIMEOUT", "30")
+
+ # Limit CPU threads (optional; improves stability)
+ try:
+     torch.set_num_threads(min(4, os.cpu_count() or 1))
+ except Exception:
+     pass
+
+ # =========================
+ # LOAD MODEL (load once)
+ # =========================
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
+ model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID).to(DEVICE)
+ model.eval()
+
+ # =========================
+ # UTIL: robust color parsing
+ # =========================
+ def parse_color_to_rgb(color):
+     """
+     Accepts:
+       - "#RRGGBB"
+       - "#RRGGBBAA" (AA is ignored)
+       - "#RGB"
+       - "rgb(r,g,b)" / "rgba(r,g,b,a)"
+       - (r,g,b) or [r,g,b]
+       - a dict with {"hex": "..."} (just in case)
+     Returns (r,g,b) in 0..255.
+     """
+     if color is None:
+         return (255, 0, 0)
+
+     if isinstance(color, dict):
+         color = color.get("hex") or color.get("value") or color.get("color")
+
+     if isinstance(color, (tuple, list)) and len(color) >= 3:
+         return (int(color[0]), int(color[1]), int(color[2]))
+
+     if not isinstance(color, str):
+         raise ValueError(f"Unsupported color format: {type(color)} -> {color}")
+
+     s = color.strip()
+
+     # rgb/rgba(...)
+     m = re.match(r"rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", s.lower())
+     if m:
+         r, g, b = map(int, m.groups())
+         r = max(0, min(255, r))
+         g = max(0, min(255, g))
+         b = max(0, min(255, b))
+         return (r, g, b)
+
+     # hex
+     if s.startswith("#"):
+         h = s[1:]
+         if len(h) == 3:  # #RGB -> #RRGGBB
+             h = "".join([c * 2 for c in h])
+         if len(h) == 8:  # #RRGGBBAA -> drop AA
+             h = h[:6]
+         if len(h) != 6:
+             raise ValueError(f"Invalid HEX: {s} (use #RRGGBB)")
+         r = int(h[0:2], 16)
+         g = int(h[2:4], 16)
+         b = int(h[4:6], 16)
+         return (r, g, b)
+
+     raise ValueError(f"Invalid color: {color}")
+
+ # =========================
+ # IMAGE UTILS
+ # =========================
+ def resize_keep_aspect(pil: Image.Image, max_side: int) -> Image.Image:
+     w, h = pil.size
+     m = max(w, h)
+     if m <= max_side:
+         return pil
+     scale = max_side / float(m)
+     nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
+     return pil.resize((nw, nh), Image.BILINEAR)
+
+ @torch.inference_mode()
+ def get_hair_mask(image: Image.Image, max_side: int = 640) -> Image.Image:
+     """
+     Returns an L-mode mask (0..255) of the hair, at the original image size.
+     """
+     image = image.convert("RGB")
+     ow, oh = image.size
+
+     infer_img = resize_keep_aspect(image, max_side=max_side)
+
+     inputs = processor(images=infer_img, return_tensors="pt")
+     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+
+     outputs = model(**inputs)
+     logits = outputs.logits  # (B,C,h,w)
+
+     up = F.interpolate(
+         logits,
+         size=infer_img.size[::-1],  # (H,W)
+         mode="bilinear",
+         align_corners=False,
+     )
+
+     labels = up.argmax(dim=1)[0]  # (H,W)
+     hair = (labels == HAIR_ID).to(torch.uint8).cpu().numpy() * 255
+
+     mask = Image.fromarray(hair, mode="L")
+
+     if mask.size != (ow, oh):
+         mask = mask.resize((ow, oh), Image.NEAREST)
+
+     return mask
+
+ def refine_mask(mask: Image.Image, close_kernel: int = 9, feather: int = 9) -> Image.Image:
+     m = np.array(mask.convert("L"))
+
+     # binarize
+     _, mb = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
+
+     # morphological close (fills small holes)
+     k = max(3, int(close_kernel) | 1)  # force odd
+     kernel = np.ones((k, k), np.uint8)
+     mb = cv2.morphologyEx(mb, cv2.MORPH_CLOSE, kernel, iterations=1)
+
+     # feather edges (Gaussian blur)
+     f = max(1, int(feather))
+     if f % 2 == 0:
+         f += 1
+     mb = cv2.GaussianBlur(mb, (f, f), 0)
+
+     return Image.fromarray(mb, mode="L")
+
+ def recolor_hair_lab(
+     image: Image.Image,
+     mask: Image.Image,
+     color_input,
+     strength: float = 0.85,
+     brighten: float = 0.0,
+ ) -> Image.Image:
+     """
+     Recolor in LAB space to preserve shadows/highlights.
+     strength: 0..1, controls how much of the target color is applied
+     brighten: -0.3..0.3 (optional, affects hair only)
+     """
+     image_rgb = np.array(image.convert("RGB"))
+     mask_f = np.array(mask.convert("L")).astype(np.float32) / 255.0
+     alpha = np.clip(mask_f * float(strength), 0.0, 1.0)[..., None]  # (H,W,1)
+
+     # RGB -> BGR -> LAB
+     bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
+     lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
+
+     # target color -> LAB
+     r, g, b = parse_color_to_rgb(color_input)
+     target_bgr = np.array([[[b, g, r]]], dtype=np.uint8)
+     target_lab = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)[0, 0]
+
+     # blend the a/b (chroma) channels toward the target
+     lab[:, :, 1] = lab[:, :, 1] * (1.0 - alpha[:, :, 0]) + target_lab[1] * alpha[:, :, 0]
+     lab[:, :, 2] = lab[:, :, 2] * (1.0 - alpha[:, :, 0]) + target_lab[2] * alpha[:, :, 0]
+
+     # optional lightness adjustment, hair only
+     if abs(brighten) > 1e-6:
+         lab[:, :, 0] = np.clip(lab[:, :, 0] + (brighten * 255.0) * alpha[:, :, 0], 0, 255)
+
+     lab_u8 = np.clip(lab, 0, 255).astype(np.uint8)
+     out_bgr = cv2.cvtColor(lab_u8, cv2.COLOR_LAB2BGR)
+     out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
+
+     return Image.fromarray(out_rgb)
+
+ # =========================
+ # GRADIO RUN
+ # =========================
+ def run(image, preset, picked_color, strength, brighten, max_side, close_kernel, feather):
+     try:
+         if image is None:
+             return None, None, "Upload an image first."
+
+         # resolve final color
+         preset_hex = COLOR_PRESETS.get(preset)
+         final_color = (picked_color or "#ff0000") if preset_hex is None else preset_hex
+
+         # hair mask
+         raw_mask = get_hair_mask(image, max_side=int(max_side))
+         mask = refine_mask(raw_mask, close_kernel=int(close_kernel), feather=int(feather))
+
+         # bail out if the mask came back (nearly) empty
+         if np.mean(np.array(mask)) < 2.0:
+             return image, mask, "No hair detected in this photo. Try another one (better lighting, face the camera)."
+
+         # recolor
+         result = recolor_hair_lab(
+             image=image,
+             mask=mask,
+             color_input=final_color,
+             strength=float(strength),
+             brighten=float(brighten),
+         )
+
+         return result, mask, f"OK ✅ Color applied: {final_color}"
+
+     except Exception as e:
+         # surface the error in the app
+         return None, None, f"ERROR: {type(e).__name__}: {e}"
+
+ DESCRIPTION = """
+ Upload a photo and change the hair color.
+ - Hair segmentation (hair mask)
+ - LAB recoloring to preserve shadows/highlights
+ """
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🎨 Hair color changer")
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         inp = gr.Image(label="Your photo", type="pil")
+         out = gr.Image(label="Result", type="pil")
+
+     with gr.Accordion("Controls", open=True):
+         preset = gr.Dropdown(
+             label="Preset",
+             choices=list(COLOR_PRESETS.keys()),
+             value="Custom (picker)",
+         )
+         picked_color = gr.ColorPicker(label="Custom color", value="#ff0000")
+         strength = gr.Slider(0.0, 1.0, value=0.85, step=0.05, label="Intensity")
+         brighten = gr.Slider(-0.3, 0.3, value=0.0, step=0.05, label="Hair brightness (optional)")
+         max_side = gr.Slider(384, 1024, value=640, step=64, label="Segmentation resolution")
+         close_kernel = gr.Slider(3, 21, value=9, step=2, label="Close mask holes")
+         feather = gr.Slider(1, 31, value=9, step=2, label="Feather mask edges")
+
+     btn = gr.Button("Apply")
+     mask_out = gr.Image(label="Mask (debug)", type="pil")
+     status = gr.Textbox(label="Status", value="Ready.")
+
+     btn.click(
+         fn=run,
+         inputs=[inp, preset, picked_color, strength, brighten, max_side, close_kernel, feather],
+         outputs=[out, mask_out, status],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
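
For anyone who wants to exercise this pipeline without the Gradio UI, here is a minimal sketch driving the same functions headlessly. It assumes the file above is saved as `app.py` on the Python path; `portrait.jpg` and the output filenames are hypothetical.

```python
# Headless usage sketch. Importing app builds the Gradio Blocks but does not
# launch the server (demo.launch() is guarded by __main__); the segmentation
# model is downloaded and loaded once at import time.
from PIL import Image

from app import get_hair_mask, refine_mask, recolor_hair_lab

img = Image.open("portrait.jpg")  # hypothetical test photo

raw_mask = get_hair_mask(img, max_side=640)              # segment hair on a <=640px copy
mask = refine_mask(raw_mask, close_kernel=9, feather=9)  # fill holes, soften edges
result = recolor_hair_lab(img, mask, "#7209b7", strength=0.85)  # the "Purple" preset

result.save("portrait_recolored.jpg")
mask.save("portrait_mask.png")
```

Because only the a/b (chroma) channels of LAB are blended toward the target, the original L (lightness) channel survives the recolor, which is what keeps the hair's shadows and highlights intact.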