from typing import Tuple, Optional, List, Dict import cv2 import gradio as gr import numpy as np from PIL import Image import torch from functools import lru_cache from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation import mediapipe as mp # MediaPipe is mandatory import warnings warnings.filterwarnings('ignore') def _ensure_rgb_uint8(image: np.ndarray) -> np.ndarray: """Convert an input image array to RGB uint8 format.""" if image is None: raise ValueError("No image provided") if isinstance(image, Image.Image): image = np.array(image.convert("RGB")) elif image.dtype != np.uint8: image = image.astype(np.uint8) if image.ndim == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) return image def _preprocess_image(image: np.ndarray) -> np.ndarray: """Preprocess image to improve face detection.""" rgb = _ensure_rgb_uint8(image) # Resize if image is too large or too small h, w = rgb.shape[:2] # If too large, resize down max_dim = 1024 if max(h, w) > max_dim: scale = max_dim / max(h, w) new_w = int(w * scale) new_h = int(h * scale) rgb = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_AREA) # If too small, resize up min_dim = 200 if min(h, w) < min_dim: scale = min_dim / min(h, w) new_w = int(w * scale) new_h = int(h * scale) rgb = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC) # Apply contrast enhancement if image is dark gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) if np.mean(gray) < 50: # Too dark lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB) l, a, b = cv2.split(lab) clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) l = clahe.apply(l) lab = cv2.merge((l, a, b)) rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) return rgb def _central_crop_bbox(width: int, height: int, frac: float = 0.6) -> Tuple[int, int, int, int]: """Return a central crop bounding box (x1, y1, x2, y2) covering `frac` of width/height.""" frac = float(np.clip(frac, 0.2, 1.0)) crop_w = int(width * frac) crop_h = int(height * frac) x1 = (width - crop_w) // 2 y1 = (height - crop_h) // 2 x2 = x1 + crop_w y2 = y1 + crop_h return x1, y1, x2, y2 def _detect_face_bbox_mediapipe(image_rgb: np.ndarray) -> Optional[Tuple[int, int, int, int]]: """Detect a face bounding box using MediaPipe Face Detection and return (x1, y1, x2, y2).""" try: height, width = image_rgb.shape[:2] # Initialize MediaPipe Face Detection mp_face_detection = mp.solutions.face_detection face_detection = mp_face_detection.FaceDetection( model_selection=1, # 1 for front-facing, 2 for full-range min_detection_confidence=0.3 # Lower confidence for better detection ) # Convert to BGR for MediaPipe (MediaPipe expects BGR) image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) results = face_detection.process(image_bgr) face_detection.close() if not results.detections: return None # Get all detections detections = [] for detection in results.detections: bbox = detection.location_data.relative_bounding_box confidence = detection.score[0] # Convert normalized coordinates to pixel coordinates x = int(bbox.xmin * width) y = int(bbox.ymin * height) w = int(bbox.width * width) h = int(bbox.height * height) # Ensure coordinates are within image bounds x = max(0, x) y = max(0, y) w = min(width - x, w) h = min(height - y, h) if w > 0 and h > 0: detections.append({ 'bbox': (x, y, w, h), 'confidence': confidence }) if not detections: return None # Sort by confidence and pick the best detections.sort(key=lambda d: d['confidence'], reverse=True) best = detections[0] x, y, w, h = best['bbox'] # Expand the bounding box to include more context expand_x = int(w * 0.15) expand_y = int(h * 0.20) x1 = max(0, x - expand_x) y1 = max(0, y - expand_y) x2 = min(width, x + w + expand_x) y2 = min(height, y + h + expand_y) # Ensure minimum size if (x2 - x1) < 50 or (y2 - y1) < 50: # If too small, use central crop instead return None return x1, y1, x2, y2 except Exception as e: print(f"MediaPipe error: {e}") return None def _detect_face_bbox_opencv(image_rgb: np.ndarray) -> Optional[Tuple[int, int, int, int]]: """Fallback face detection using OpenCV Haar cascades.""" try: gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY) # Load pre-trained Haar cascade cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml' face_cascade = cv2.CascadeClassifier(cascade_path) if face_cascade.empty(): print("Haar cascade not loaded properly") return None # Detect faces faces = face_cascade.detectMultiScale( gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30), flags=cv2.CASCADE_SCALE_IMAGE ) if len(faces) == 0: return None # Get the largest face faces = sorted(faces, key=lambda f: f[2] * f[3], reverse=True) x, y, w, h = faces[0] # Expand bounding box expand_x = int(w * 0.15) expand_y = int(h * 0.20) height, width = image_rgb.shape[:2] x1 = max(0, x - expand_x) y1 = max(0, y - expand_y) x2 = min(width, x + w + expand_x) y2 = min(height, y + h + expand_y) return x1, y1, x2, y2 except Exception as e: print(f"OpenCV face detection error: {e}") return None def _binary_open_close(mask: np.ndarray, kernel_size: int = 5, iterations: int = 1) -> np.ndarray: """Apply morphological open then close to clean the binary mask.""" kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=iterations) closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel, iterations=iterations) return closed @lru_cache(maxsize=1) def _load_face_parsing_model(): """Load face-parsing model and processor from the Hugging Face Hub (cached).""" model_id = "jonathandinu/face-parsing" processor = AutoImageProcessor.from_pretrained(model_id) model = AutoModelForSemanticSegmentation.from_pretrained(model_id) model.eval() id2label: Dict[int, str] = model.config.id2label label2id: Dict[str, int] = model.config.label2id return processor, model, id2label, label2id def _segment_face_labels(image_rgb: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]: """Run face-parsing segmentation on an RGB crop. Returns (labels HxW int, id2label).""" processor, model, id2label, _ = _load_face_parsing_model() pil_img = Image.fromarray(image_rgb) # Resize if too large for the model max_size = 512 if max(pil_img.size) > max_size: scale = max_size / max(pil_img.size) new_size = (int(pil_img.size[0] * scale), int(pil_img.size[1] * scale)) pil_img = pil_img.resize(new_size, Image.Resampling.LANCZOS) inputs = processor(images=pil_img, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits # Upsample to original image size upsampled = torch.nn.functional.interpolate( logits, size=pil_img.size[::-1], # (H, W) mode="bilinear", align_corners=False, ) labels = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.int32) return labels, id2label def _skin_indices_from_id2label(id2label: Dict[int, str]) -> List[int]: skin_indices: List[int] = [] for idx, name in id2label.items(): name_l = name.lower() if "skin" in name_l: skin_indices.append(int(idx)) elif "face" in name_l and "skin" not in name_l and "hair" not in name_l: skin_indices.append(int(idx)) # Default fallback indices (common in face-parsing models) if not skin_indices: # Try common skin class indices common_skin_indices = [1, 13, 14, 15] # These vary by model for idx in common_skin_indices: if idx in id2label: skin_indices.append(idx) return skin_indices def _compute_skin_color_hex(image_rgb: np.ndarray, mask: np.ndarray) -> Tuple[str, np.ndarray]: """Compute a robust representative skin color as a hex string and return also the RGB color.""" if mask is None or mask.size == 0: raise ValueError("Invalid mask for skin color computation") # boolean mask for indexing mask_bool = mask.astype(bool) if not np.any(mask_bool): raise ValueError("No skin pixels detected") skin_pixels = image_rgb[mask_bool] # Use median for robustness median_color = np.median(skin_pixels, axis=0) median_color = np.clip(median_color, 0, 255).astype(np.uint8) # Also compute mean for comparison mean_color = np.mean(skin_pixels, axis=0) mean_color = np.clip(mean_color, 0, 255).astype(np.uint8) # Use median as primary, but fall back to mean if median seems off if np.std(median_color) > 100: # If median has high variance color_rgb = mean_color else: color_rgb = median_color r, g, b = int(color_rgb[0]), int(color_rgb[1]), int(color_rgb[2]) hex_code = f"#{r:02X}{g:02X}{b:02X}" return hex_code, color_rgb def _solid_color_image(color_rgb: np.ndarray, size: Tuple[int, int] = (160, 160)) -> np.ndarray: swatch = np.zeros((size[1], size[0], 3), dtype=np.uint8) swatch[:, :] = color_rgb return swatch def detect_skin_tone(image: np.ndarray) -> Tuple[str, np.ndarray, np.ndarray]: """Main pipeline: returns (hex_code, color_swatch_image, debug_mask_overlay).""" try: # Preprocess image rgb = _preprocess_image(image) height, width = rgb.shape[:2] # Create debug image debug_img = rgb.copy() # Try multiple face detection methods face_bbox = None detection_method = "" # Method 1: MediaPipe (primary) face_bbox = _detect_face_bbox_mediapipe(rgb) if face_bbox is not None: detection_method = "MediaPipe" # Method 2: OpenCV Haar Cascade (fallback) if face_bbox is None: face_bbox = _detect_face_bbox_opencv(rgb) if face_bbox is not None: detection_method = "OpenCV Haar" # Method 3: Central crop (last resort) if face_bbox is None: face_bbox = _central_crop_bbox(width, height, frac=0.5) detection_method = "Central Crop" print(f"Warning: Using central crop as fallback") x1, y1, x2, y2 = face_bbox # Ensure bbox is valid and not too small if x2 <= x1 or y2 <= y1: raise ValueError("Invalid bounding box coordinates") if (x2 - x1) < 20 or (y2 - y1) < 20: raise ValueError("Face region too small") # Crop face region face_crop = rgb[y1:y2, x1:x2] if face_crop.size == 0: raise ValueError("Empty face crop") # Draw detection box on debug image color = (0, 255, 0) if detection_method != "Central Crop" else (255, 0, 0) cv2.rectangle(debug_img, (x1, y1), (x2, y2), color, 2) cv2.putText(debug_img, detection_method, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) # Face parsing segmentation to get skin mask try: labels, id2label = _segment_face_labels(face_crop) skin_indices = _skin_indices_from_id2label(id2label) if not skin_indices: # Create a simple central mask as fallback h, w = face_crop.shape[:2] skin_mask = np.zeros((h, w), dtype=np.uint8) center_y, center_x = h // 2, w // 2 mask_size = min(h, w) // 3 cv2.ellipse(skin_mask, (center_x, center_y), (mask_size, mask_size // 2), 0, 0, 360, 255, -1) else: skin_mask = np.isin(labels, np.array(skin_indices, dtype=np.int32)).astype(np.uint8) * 255 # Clean up the mask skin_mask = _binary_open_close(skin_mask, kernel_size=3, iterations=1) except Exception as e: print(f"Face parsing error: {e}") # Create a simple elliptical mask h, w = face_crop.shape[:2] skin_mask = np.zeros((h, w), dtype=np.uint8) center_y, center_x = h // 2, w // 2 mask_size = min(h, w) // 3 cv2.ellipse(skin_mask, (center_x, center_y), (mask_size, mask_size // 2), 0, 0, 360, 255, -1) # Ensure we have some skin pixels if np.sum(skin_mask) == 0: # Use entire face crop as fallback skin_mask = np.ones((face_crop.shape[0], face_crop.shape[1]), dtype=np.uint8) * 255 # Compute skin color hex_code, color_rgb = _compute_skin_color_hex(face_crop, skin_mask) # Prepare swatch swatch = _solid_color_image(color_rgb) # Create mask overlay for debug full_mask = np.zeros((height, width), dtype=np.uint8) full_mask[y1:y2, x1:x2] = skin_mask # Create colored mask color_mask = np.zeros_like(rgb) color_mask[:, :, 0] = 0 # Red channel color_mask[:, :, 1] = 255 # Green channel for skin mask color_mask[:, :, 2] = 0 # Blue channel # Apply mask mask_3d = np.stack([full_mask] * 3, axis=2) / 255.0 overlay = (rgb * (1 - mask_3d) + color_mask * mask_3d).astype(np.uint8) # Add hex code to debug image cv2.putText(debug_img, hex_code, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) cv2.putText(debug_img, hex_code, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1) return hex_code, swatch, debug_img except Exception as e: error_msg = f"Error: {str(e)}" print(error_msg) # Return error state error_color = np.array([255, 0, 0], dtype=np.uint8) # Red for error error_hex = "#FF0000" error_swatch = _solid_color_image(error_color) # Create error debug image if 'rgb' in locals(): error_debug = rgb.copy() else: error_debug = np.zeros((300, 300, 3), dtype=np.uint8) error_debug[:] = [100, 100, 100] cv2.putText(error_debug, "ERROR", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3) cv2.putText(error_debug, error_msg[:30], (50, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2) return error_hex, error_swatch, error_debug def _hex_html(hex_code: str) -> str: style = ( "display:flex;align-items:center;gap:12px;padding:8px 0;" ) swatch_style = ( f"width:24px;height:24px;border-radius:4px;background:{hex_code};" "border:2px solid #333;box-shadow:2px 2px 5px rgba(0,0,0,0.2);" ) return ( f"
{str(e)[:100]}...
Please try a different image.