Spaces:
Running
Running
| # bald_processor_clean.py | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from PIL import Image, UnidentifiedImageError | |
| from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation | |
| import logging | |
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pick the torch device for inference: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"Using device: {device}")
# Hugging Face Hub id of the SegFormer face-parsing checkpoint used below.
_MODEL_ID = "jonathandinu/face-parsing"

# Load the processor and model once at import time; the module is unusable
# without them, so a failure aborts the import with a RuntimeError.
try:
    logger.info("Loading SegFormer face-parsing model...")
    processor = SegformerImageProcessor.from_pretrained(_MODEL_ID)
    model = SegformerForSemanticSegmentation.from_pretrained(_MODEL_ID)
    model.to(device)
    model.eval()  # inference mode (affects dropout / normalization layers)
    logger.info("Model loaded successfully!")
except Exception as e:
    # logger.exception records the traceback; `from e` keeps the root cause
    # attached to the RuntimeError that callers see.
    logger.exception("Failed to load model: %s", e)
    raise RuntimeError("SegFormer model load failed!") from e
def _label_ids(names, fallback):
    """Return the sorted class ids of the loaded model whose label is in *names*.

    Labels are matched case-insensitively against ``model.config.id2label``.
    If the config has no label map, or none of *names* is present in it,
    *fallback* is returned as a list instead.
    """
    id2label = getattr(model.config, "id2label", None) or {}
    wanted = {name.lower() for name in names}
    found = sorted(
        int(idx) for idx, label in id2label.items() if str(label).lower() in wanted
    )
    return found or list(fallback)


# Face-parsing class id treated as "hair" (the pixels that get inpainted).
hair_class_id = _label_ids(("hair",), (13,))[0]

# Class ids removed from the hair mask so ears are never inpainted.
# NOTE(review): the jonathandinu/face-parsing model card appears to list
# l_ear=8 and r_ear=9 (with 7 = r_brow), so the previously hard-coded [7, 8]
# likely missed one ear. Ids are therefore resolved by label name; the
# fallback keeps the old values if the labels cannot be found — confirm
# against the model card.
ear_class_ids = _label_ids(("l_ear", "r_ear"), (7, 8))
# Longest side (px) processed by the model and the inpainter; larger inputs
# are downscaled to this size and the inpainted region is upscaled back.
_MAX_WORKING_DIM = 2048
# Hair masks smaller than this many pixels are treated as "no hair found".
_MIN_HAIR_PIXELS = 50
# Above this many mask pixels, TELEA with a wider radius is used instead of NS.
_LARGE_MASK_PIXELS = 220000
# Elliptical kernel shared by the closing and dilation steps of the mask cleanup.
_MASK_KERNEL = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))


class NoHairDetectedError(ValueError, RuntimeError):
    """Raised by make_realistic_bald when (almost) no hair is segmented.

    Subclasses both ValueError and RuntimeError so callers may catch it as
    either type; its message is "NO_HAIR_DETECTED".
    """


def _segment(rgb: np.ndarray) -> np.ndarray:
    """Return the face-parsing label map for an RGB uint8 image.

    The model's logits are bilinearly upsampled to the image's own height and
    width before the per-pixel argmax, so the result has shape ``(h, w)``.
    """
    h, w = rgb.shape[:2]
    inputs = processor(images=Image.fromarray(rgb), return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        upsampled = torch.nn.functional.interpolate(
            logits, size=(h, w), mode="bilinear", align_corners=False
        )
    return upsampled.argmax(dim=1).squeeze(0).cpu().numpy()


def _hair_mask(parsing: np.ndarray) -> np.ndarray:
    """Build a cleaned 0/1 uint8 hair mask from a label map, excluding ears."""
    ears = np.isin(parsing, ear_class_ids)
    mask = (parsing == hair_class_id).astype(np.uint8)
    mask[ears] = 0  # keep ear pixels from seeding the morphology below
    # Fill small gaps inside the hair region, then grow it outward.
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, _MASK_KERNEL, iterations=2)
    mask = cv2.dilate(mask, _MASK_KERNEL, iterations=1)
    # Smooth the outline; the low 0.28 threshold also widens it slightly.
    mask = (cv2.GaussianBlur(mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)
    # Closing, dilation and blur spread the mask back over the ears, so
    # exclude them again to guarantee ears are never inpainted.
    mask[ears] = 0
    return mask


def _inpaint(bgr: np.ndarray, mask: np.ndarray, hair_pixels: int) -> np.ndarray:
    """Inpaint the masked pixels of a BGR image (no extra noise or blur).

    Large masks use INPAINT_TELEA with radius 15; smaller ones INPAINT_NS
    with radius 10.
    """
    large = hair_pixels > _LARGE_MASK_PIXELS
    radius = 15 if large else 10
    flag = cv2.INPAINT_TELEA if large else cv2.INPAINT_NS
    return cv2.inpaint(bgr, mask * 255, inpaintRadius=radius, flags=flag)


def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    """Remove the hair from a portrait and return the resulting image.

    Hair is located with the SegFormer face-parsing model, turned into a
    cleaned binary mask (ears excluded) and filled by OpenCV inpainting.
    Only pixels inside the hair mask are replaced; all other pixels are
    copied unchanged from the input. Images whose longest side exceeds
    2048 px are processed at reduced size and the inpainted region is
    upscaled back onto the full-resolution original.

    Args:
        input_image: Source image; any PIL mode is converted to RGB first.

    Returns:
        An RGB image with the same width and height as ``input_image``.

    Raises:
        ValueError: ``input_image`` is None, or the image cannot be identified.
        NoHairDetectedError: fewer than 50 hair pixels were found (also a
            ValueError and a RuntimeError).
        RuntimeError: any other failure while processing.
    """
    if input_image is None:
        raise ValueError("No input image provided!")
    try:
        # Normalize the mode: RGBA / L / P / 16-bit inputs would otherwise
        # produce arrays whose channel count breaks cvtColor or the composite.
        original_np = np.array(input_image.convert("RGB"))
        orig_h, orig_w = original_np.shape[:2]

        scale_factor = 1.0
        working_np = original_np
        if max(orig_w, orig_h) > _MAX_WORKING_DIM:
            scale_factor = _MAX_WORKING_DIM / max(orig_w, orig_h)
            # max(1, ...) guards very thin images against a zero-size resize.
            working_w = max(1, int(orig_w * scale_factor))
            working_h = max(1, int(orig_h * scale_factor))
            working_np = cv2.resize(
                original_np, (working_w, working_h), interpolation=cv2.INTER_AREA
            )

        hair_mask = _hair_mask(_segment(working_np))
        hair_pixels = int(hair_mask.sum())
        if hair_pixels < _MIN_HAIR_PIXELS:
            raise NoHairDetectedError("NO_HAIR_DETECTED")

        working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
        inpainted_rgb = cv2.cvtColor(
            _inpaint(working_bgr, hair_mask, hair_pixels), cv2.COLOR_BGR2RGB
        )

        # Bring the inpainted image and its mask back to the original size.
        if scale_factor < 1.0:
            bald_full = cv2.resize(
                inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4
            )
            mask_full = cv2.resize(
                hair_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST
            )
        else:
            bald_full, mask_full = inpainted_rgb, hair_mask

        # Composite only the hair region; everything else stays untouched.
        region = mask_full.astype(bool)
        result = original_np.copy()
        result[region] = bald_full[region]
        return Image.fromarray(result)
    except NoHairDetectedError:
        # Expected outcome for hairless inputs: propagate as-is, no error log.
        raise
    except UnidentifiedImageError as e:
        raise ValueError("Invalid image format or corrupt image!") from e
    except Exception as e:
        logger.exception("Bald processing failed: %s", e)
        raise RuntimeError(f"Bald processing failed: {str(e)}") from e