# bald_processor_clean.py
"""Remove hair from a portrait and inpaint a realistic bald scalp.

Pipeline: segment hair with the SegFormer face-parsing model, clean the
mask with morphology, fill the hair region with OpenCV inpainting, and
composite only the inpainted area back onto the untouched original.
"""

import logging

import cv2
import numpy as np
import torch
from PIL import Image, UnidentifiedImageError
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: %s", device)

try:
    logger.info("Loading SegFormer face-parsing model...")
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}", exc_info=True)
    # Chain the cause so the original load failure stays in the traceback.
    raise RuntimeError("SegFormer model load failed!") from e

# Label indices in the jonathandinu/face-parsing class map.
# NOTE(review): 13 = hair, 7/8 = ears per that model card — confirm if the
# checkpoint is ever swapped.
hair_class_id = 13
ear_class_ids = [7, 8]

# Cap on the working resolution: segmentation and inpainting run on a
# downscaled copy; only the bald region is upscaled back and composited.
MAX_DIM = 2048


def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    """Return a copy of *input_image* with the hair region inpainted bald.

    Only pixels inside the detected hair mask are modified; every other
    pixel of the original image is preserved byte-for-byte.

    Args:
        input_image: Source portrait as a PIL image. Non-RGB modes
            (RGBA, L, P, ...) are converted to RGB first.

    Returns:
        A new RGB ``PIL.Image.Image`` at the original resolution.

    Raises:
        ValueError: if *input_image* is None, if the image is invalid, or
            with the sentinel message ``"NO_HAIR_DETECTED"`` when the
            segmented hair area is negligible (< 50 pixels) — callers can
            rely on the exception type to distinguish "no hair" from an
            internal failure.
        RuntimeError: for any other unexpected processing failure.
    """
    if input_image is None:
        raise ValueError("No input image provided!")
    try:
        orig_w, orig_h = input_image.size
        # Normalize to RGB: cvtColor(RGB2BGR) requires exactly 3 channels,
        # so RGBA / grayscale / palette inputs would otherwise crash here.
        original_np = np.array(input_image.convert("RGB"))
        original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)

        # Downscale oversized inputs to keep segmentation/inpainting cheap.
        scale_factor = 1.0
        working_np = original_np.copy()
        working_bgr = original_bgr.copy()
        working_h, working_w = orig_h, orig_w
        if max(orig_w, orig_h) > MAX_DIM:
            scale_factor = MAX_DIM / max(orig_w, orig_h)
            working_w, working_h = int(orig_w * scale_factor), int(orig_h * scale_factor)
            working_np = cv2.resize(
                original_np, (working_w, working_h), interpolation=cv2.INTER_AREA
            )
            working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # Segmentation: per-pixel face-parsing labels at working resolution.
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # Logits come out at the model's internal resolution; upsample back
        # to the working image size before taking the argmax label map.
        upsampled = torch.nn.functional.interpolate(
            logits, size=(working_h, working_w), mode="bilinear", align_corners=False
        )
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        # Hair mask, with ear pixels explicitly excluded so inpainting
        # never eats into the ears.
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_mask = np.zeros_like(hair_mask)
        for cls in ear_class_ids:
            ears_mask[parsing == cls] = 1
        hair_mask[ears_mask == 1] = 0

        # Morphology: close small holes, dilate slightly past the hairline,
        # then feather-and-threshold to smooth the mask boundary.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
        hair_mask = (
            cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28
        ).astype(np.uint8)

        hair_pixels = np.sum(hair_mask)
        if hair_pixels < 50:
            # Sentinel for callers: the subject appears to have no hair.
            raise ValueError("NO_HAIR_DETECTED")

        # Inpainting: larger radius + TELEA for big hair regions, NS for
        # small ones (no extra noise, no blur).
        radius = 15 if hair_pixels > 220000 else 10
        flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(
            working_bgr, hair_mask * 255, inpaintRadius=radius, flags=flag
        )
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Upscale the bald result (and mask) back to original size if the
        # working copy was downscaled.
        if scale_factor < 1.0:
            bald_up = cv2.resize(
                inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4
            )
            mask_up = cv2.resize(
                hair_mask.astype(np.uint8), (orig_w, orig_h),
                interpolation=cv2.INTER_NEAREST,
            )
        else:
            bald_up = inpainted_rgb
            mask_up = hair_mask

        # Composite only the bald area; every other pixel stays original.
        result = original_np.copy()
        result[mask_up == 1] = bald_up[mask_up == 1]
        return Image.fromarray(result)

    except UnidentifiedImageError:
        raise ValueError("Invalid image format or corrupt image!")
    except ValueError:
        # Let deliberate ValueErrors (e.g. "NO_HAIR_DETECTED") propagate
        # unchanged — the generic handler below would otherwise rewrap them
        # as RuntimeError and break callers that test the exception type.
        raise
    except Exception as e:
        logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
        raise RuntimeError(f"Bald processing failed: {str(e)}") from e