Stylique committed on
Commit
2f949e9
·
verified ·
1 Parent(s): 5198051

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +265 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Optional, List, Dict
2
+
3
+ import cv2
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ from functools import lru_cache
9
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
10
+
11
+ import mediapipe as mp # MediaPipe is mandatory
12
# Always True here: the mediapipe import above is unconditional. The flag is
# still consulted by _detect_face_bbox_mediapipe before it runs detection.
HAS_MEDIAPIPE = True
13
+
14
+
15
+ def _ensure_rgb_uint8(image: np.ndarray) -> np.ndarray:
16
+ """Convert an input image array to RGB uint8 format.
17
+
18
+ Gradio provides images as numpy arrays in RGB order with dtype uint8 by default,
19
+ but we defensively normalize here in case inputs vary.
20
+ """
21
+ if image is None:
22
+ raise ValueError("No image provided")
23
+
24
+ if isinstance(image, Image.Image):
25
+ image = np.array(image.convert("RGB"))
26
+ elif image.dtype != np.uint8:
27
+ image = image.astype(np.uint8)
28
+
29
+ if image.ndim == 2:
30
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
31
+ elif image.shape[2] == 4:
32
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
33
+ return image
34
+
35
+
36
+ def _central_crop_bbox(width: int, height: int, frac: float = 0.6) -> Tuple[int, int, int, int]:
37
+ """Return a central crop bounding box (x1, y1, x2, y2) covering `frac` of width/height."""
38
+ frac = float(np.clip(frac, 0.2, 1.0))
39
+ crop_w = int(width * frac)
40
+ crop_h = int(height * frac)
41
+ x1 = (width - crop_w) // 2
42
+ y1 = (height - crop_h) // 2
43
+ x2 = x1 + crop_w
44
+ y2 = y1 + crop_h
45
+ return x1, y1, x2, y2
46
+
47
+
48
def _detect_face_bbox_mediapipe(image_rgb: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """Detect a face with MediaPipe Face Detection and return a padded pixel box.

    When several faces are detected, the one with the largest relative
    bounding-box area is used. The box is expanded by 8% of the image width
    horizontally and 12% of the image height vertically, then clamped to the
    image bounds.

    Args:
        image_rgb: Image array; its first two dimensions are read as
            (height, width) and it is passed unchanged to the detector.

    Returns:
        (x1, y1, x2, y2) in pixels, or None when mediapipe is disabled, no
        face is found, the padded box is smaller than 10 px on either side,
        or detection raises. Failures are logged rather than silently
        discarded, so a broken detector is distinguishable from "no face".
    """
    import logging  # local import keeps this block self-contained

    if not HAS_MEDIAPIPE:
        return None
    height, width = image_rgb.shape[:2]
    try:
        with mp.solutions.face_detection.FaceDetection(
            model_selection=1, min_detection_confidence=0.5
        ) as detector:
            results = detector.process(image_rgb)
            detections = results.detections or []
            if not detections:
                return None

            # Pick the largest bbox
            def bbox_area(det):
                bbox = det.location_data.relative_bounding_box
                return max(0.0, bbox.width) * max(0.0, bbox.height)

            best = max(detections, key=bbox_area)
            rb = best.location_data.relative_bounding_box
            # Relative coordinates can fall outside [0, 1]; clamp after scaling.
            x1 = int(np.clip(rb.xmin * width, 0, width - 1))
            y1 = int(np.clip(rb.ymin * height, 0, height - 1))
            x2 = int(np.clip((rb.xmin + rb.width) * width, 0, width))
            y2 = int(np.clip((rb.ymin + rb.height) * height, 0, height))

            # Expand a bit to include cheeks/forehead
            pad_x = int(0.08 * width)
            pad_y = int(0.12 * height)
            x1 = int(np.clip(x1 - pad_x, 0, width - 1))
            y1 = int(np.clip(y1 - pad_y, 0, height - 1))
            x2 = int(np.clip(x2 + pad_x, 0, width))
            y2 = int(np.clip(y2 + pad_y, 0, height))

            # Reject degenerate boxes that are too small to segment.
            if x2 - x1 < 10 or y2 - y1 < 10:
                return None
            return x1, y1, x2, y2
    except Exception:
        # Best-effort: callers treat None as "no face", but record the cause.
        logging.getLogger(__name__).exception("MediaPipe face detection failed")
        return None
87
+
88
+
89
def _binary_open_close(mask: np.ndarray, kernel_size: int = 5, iterations: int = 1) -> np.ndarray:
    """Clean a binary mask with a morphological opening followed by a closing.

    Both passes share one elliptical structuring element of size
    `kernel_size` x `kernel_size` and the same `iterations` count.
    """
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    result = mask
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        result = cv2.morphologyEx(result, operation, ellipse, iterations=iterations)
    return result
95
+
96
+
97
@lru_cache(maxsize=1)
def _load_face_parsing_model():
    """Fetch the face-parsing checkpoint from the Hugging Face Hub, memoized after the first call.

    Returns:
        A tuple (processor, model, id2label, label2id). The model is switched
        to eval mode and both label mappings are read from its config.
    """
    repo = "jonathandinu/face-parsing"
    image_processor = AutoImageProcessor.from_pretrained(repo)
    segmenter = AutoModelForSemanticSegmentation.from_pretrained(repo)
    segmenter.eval()
    config = segmenter.config
    index_to_name: Dict[int, str] = config.id2label
    name_to_index: Dict[str, int] = config.label2id
    return image_processor, segmenter, index_to_name, name_to_index
107
+
108
+
109
def _segment_face_labels(image_rgb: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]:
    """Segment an RGB crop with the face-parsing model.

    Returns:
        (labels, id2label): `labels` is an int32 array with one class id per
        pixel at the crop's resolution; `id2label` is the model's id-to-name map.
    """
    processor, model, id2label, _label2id = _load_face_parsing_model()
    crop = Image.fromarray(image_rgb)
    crop_w, crop_h = crop.size  # PIL reports (width, height)

    batch = processor(images=crop, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits  # (1, num_labels, h', w')

    # The logits come out at a different resolution; resize them back to the
    # crop so every input pixel receives a class id.
    full_res = torch.nn.functional.interpolate(
        logits,
        size=(crop_h, crop_w),
        mode="bilinear",
        align_corners=False,
    )
    per_pixel = full_res.argmax(dim=1)[0]
    return per_pixel.cpu().numpy().astype(np.int32), id2label
127
+
128
+
129
+ def _skin_indices_from_id2label(id2label: Dict[int, str]) -> List[int]:
130
+ skin_indices: List[int] = []
131
+ for idx, name in id2label.items():
132
+ name_l = name.lower()
133
+ if "skin" in name_l:
134
+ skin_indices.append(int(idx))
135
+ # Fallback: some models may label general face region as 'face'
136
+ if not skin_indices:
137
+ for idx, name in id2label.items():
138
+ if "face" in name.lower():
139
+ skin_indices.append(int(idx))
140
+ return skin_indices
141
+
142
+
143
+ def _compute_skin_color_hex(image_rgb: np.ndarray, mask: np.ndarray) -> Tuple[str, np.ndarray]:
144
+ """Compute a robust representative skin color as a hex string and return also the RGB color.
145
+
146
+ Uses median across masked pixels to reduce influence of highlights/shadows.
147
+ """
148
+ if mask is None or mask.size == 0:
149
+ raise ValueError("Invalid mask for skin color computation")
150
+
151
+ # boolean mask for indexing
152
+ mask_bool = mask.astype(bool)
153
+ if not np.any(mask_bool):
154
+ raise ValueError("No skin pixels detected")
155
+
156
+ skin_pixels = image_rgb[mask_bool]
157
+
158
+ # Robust median to mitigate outliers
159
+ median_color = np.median(skin_pixels, axis=0)
160
+ median_color = np.clip(median_color, 0, 255).astype(np.uint8)
161
+
162
+ r, g, b = int(median_color[0]), int(median_color[1]), int(median_color[2])
163
+ hex_code = f"#{r:02X}{g:02X}{b:02X}"
164
+ return hex_code, median_color
165
+
166
+
167
+ def _solid_color_image(color_rgb: np.ndarray, size: Tuple[int, int] = (160, 160)) -> np.ndarray:
168
+ swatch = np.zeros((size[1], size[0], 3), dtype=np.uint8)
169
+ swatch[:, :] = color_rgb
170
+ return swatch
171
+
172
+
173
def detect_skin_tone(image: np.ndarray) -> Tuple[str, np.ndarray, np.ndarray]:
    """Run the full pipeline and return (hex_code, color_swatch_image, debug_mask_overlay).

    Steps: normalize the input to RGB uint8, locate a face box with MediaPipe,
    run face parsing on that crop, take the representative color of the
    skin-labelled pixels, then build a swatch and a full-size mask overlay.

    Args:
        image: Input image; normalized by `_ensure_rgb_uint8` before use.

    Raises:
        ValueError: When no face box is found or no skin class id can be
            derived from the parsing model's labels. Errors raised by the
            helper functions propagate unchanged.
    """
    frame = _ensure_rgb_uint8(image)
    frame_h, frame_w = frame.shape[:2]

    # A face box is mandatory; there is no whole-image fallback.
    box = _detect_face_bbox_mediapipe(frame)
    if box is None:
        raise ValueError("No face detected. Please upload an image with a clear frontal face.")
    left, top, right, bottom = box
    face_crop = frame[top:bottom, left:right]

    # Per-pixel class ids for the crop, then keep only the skin classes.
    label_map, id2label = _segment_face_labels(face_crop)
    skin_ids = _skin_indices_from_id2label(id2label)
    if not skin_ids:
        raise ValueError("Face parsing model did not expose a skin class.")

    is_skin = np.isin(label_map, np.array(skin_ids, dtype=np.int32))
    skin_mask = is_skin.astype(np.uint8) * 255

    hex_code, color_rgb = _compute_skin_color_hex(face_crop, skin_mask)
    swatch = _solid_color_image(color_rgb)

    # Paste the crop-sized mask into a full-size canvas so the overlay lines
    # up with the original image.
    canvas_mask = np.zeros((frame_h, frame_w), dtype=np.uint8)
    canvas_mask[top:bottom, left:right] = skin_mask
    mask_rgb = cv2.cvtColor(canvas_mask, cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(frame, 0.8, mask_rgb, 0.2, 0)

    return hex_code, swatch, overlay
210
+
211
+
212
+ def _hex_html(hex_code: str) -> str:
213
+ style = (
214
+ "display:flex;align-items:center;gap:12px;padding:8px 0;"
215
+ )
216
+ swatch_style = (
217
+ f"width:20px;height:20px;border-radius:4px;background:{hex_code};"
218
+ "border:1px solid #ccc;"
219
+ )
220
+ return (
221
+ f"<div style='{style}'>"
222
+ f"<div style='{swatch_style}'></div>"
223
+ f"<span style='font-family:monospace;font-size:16px'>{hex_code}</span>"
224
+ "</div>"
225
+ )
226
+
227
+
228
# Gradio UI: the left column takes the upload, the right column shows the hex
# code, a color swatch and the mask overlay.
with gr.Blocks(title="Skin Tone Detector") as demo:
    gr.Markdown(
        """
        ### Skin Tone Hex Detector
        Upload a face image. The app estimates a representative skin tone and returns a HEX color.
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload face image",
                type="numpy",
                image_mode="RGB",
                height=360,
            )
            run_btn = gr.Button("Detect Skin Tone", variant="primary")

        with gr.Column():
            hex_output = gr.HTML(label="HEX Color")
            swatch_output = gr.Image(label="Color Swatch", type="numpy")
            debug_output = gr.Image(label="Mask Overlay", type="numpy")
            gr.Markdown("MediaPipe face detection and a face-parsing model are used to isolate skin pixels.")

    def _run(image: Optional[np.ndarray]):
        """Event handler: return (hex HTML, swatch image, overlay image) for the upload.

        With no image, placeholder outputs are returned (black swatch, no
        overlay). Pipeline ValueErrors such as "No face detected..." are
        re-raised as gr.Error so the message is shown to the user.
        """
        if image is None:
            return _hex_html("#000000"), np.zeros((160, 160, 3), dtype=np.uint8), None
        try:
            hex_code, swatch, debug = detect_skin_tone(image)
        except ValueError as exc:
            raise gr.Error(str(exc)) from exc
        return _hex_html(hex_code), swatch, debug

    # Run on an explicit click and whenever the uploaded image changes.
    run_btn.click(_run, inputs=[input_image], outputs=[hex_output, swatch_output, debug_output])
    input_image.change(_run, inputs=[input_image], outputs=[hex_output, swatch_output, debug_output])


if __name__ == "__main__":
    demo.launch()
264
+
265
+
requirements.txt CHANGED
@@ -3,4 +3,6 @@ opencv-python-headless>=4.10.0.84
3
  numpy>=1.26.0
4
  Pillow>=10.3.0
5
  mediapipe>=0.10.14
 
 
6
 
 
3
  numpy>=1.26.0
4
  Pillow>=10.3.0
5
  mediapipe>=0.10.14
6
+ torch>=2.2.0
7
+ transformers>=4.42.0
8