"""Remove hair from a portrait photo.

Pipeline: SegFormer face parsing -> hair/ear masks -> OpenCV inpainting ->
synthetic skin texture -> JPEG bytes.
"""

import io
import os
import traceback

import cv2
import numpy as np
import torch
from PIL import Image, UnidentifiedImageError
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# Globals for lazy loading (no global load at import time)
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = None
model = None

_MODEL_ID = "jonathandinu/face-parsing"

# Longest image side (in pixels) processed at full resolution; larger inputs
# are downscaled for segmentation/inpainting and upscaled again at the end.
_MAX_PROCESS_DIM = 2048


def load_model():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    The module-level globals are only assigned once the model has been loaded,
    moved to ``device`` and switched to eval mode, so a failure part-way
    through never leaves a half-initialised model cached for later calls.

    Raises:
        RuntimeError: if downloading/loading/moving the model fails.
    """
    global processor, model
    if model is None:
        print(f"Using device: {device} | CUDA available: {torch.cuda.is_available()}")
        print("Loading SegFormer face-parsing model...")
        try:
            loaded_processor = SegformerImageProcessor.from_pretrained(_MODEL_ID)
            loaded_model = SegformerForSemanticSegmentation.from_pretrained(_MODEL_ID)
            loaded_model.to(device)
            loaded_model.eval()
        except Exception as e:
            print("CRITICAL: Model loading failed!")
            traceback.print_exc()
            raise RuntimeError(f"Model loading failed: {str(e)}") from e
        processor, model = loaded_processor, loaded_model
        print("Model loaded successfully!")
    return processor, model


# Class ids of the face-parsing label map used below.
hair_class_id = 13
ear_class_ids = [8, 9]  # l_ear=8, r_ear=9
skin_class_id = 1
nose_class_id = 2


def _decode_rgb(image_bytes):
    """Decode raw bytes into an RGB PIL image, raising ValueError on failure."""
    try:
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except UnidentifiedImageError as err:
        raise ValueError("Invalid image format or corrupt bytes") from err
    except Exception as err:
        raise ValueError(f"Image open failed: {str(err)}") from err


def _segment(image_rgb, seg_processor, seg_model):
    """Return a per-pixel class-id map with the same height/width as *image_rgb*."""
    height, width = image_rgb.shape[:2]
    inputs = seg_processor(
        images=Image.fromarray(image_rgb), return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        logits = seg_model(**inputs).logits
    # The model's logits are upsampled back to the working resolution before
    # taking the per-pixel argmax.
    upsampled = torch.nn.functional.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    return upsampled.argmax(dim=1).squeeze(0).cpu().numpy()


def _log_skin_color(image_rgb, parsing):
    """Estimate the skin tone and print it (debug output only).

    *image_rgb* must have the same height/width as *parsing*; sampling the
    full-resolution image with masks from a downscaled parsing map reads the
    wrong pixels (or fails on the boolean-mask shape mismatch).

    The estimate is not used for compositing; it is only printed. Any error is
    reported and the neutral fallback colour is used instead.
    """
    skin_mask = (parsing == skin_class_id).astype(np.uint8)
    mean_color_rgb = np.array([210, 185, 170])  # Lighter neutral fallback

    try:
        h, w = parsing.shape
        # Fixed central window where the forehead is expected for a centred face.
        y0, y1 = max(0, int(h * 0.25)), min(h, int(h * 0.38))
        x0, x1 = max(0, int(w * 0.38)), min(w, int(w * 0.62))
        forehead_region = image_rgb[y0:y1, x0:x1]
        forehead_skin_mask = skin_mask[y0:y1, x0:x1]

        if forehead_region.size > 0 and np.sum(forehead_skin_mask) > 80:
            skin_pixels = forehead_region[forehead_skin_mask == 1]
            # Prefer the brightest 30% of forehead skin pixels when there are
            # enough of them; otherwise average all forehead skin pixels.
            brightness = np.mean(skin_pixels.astype(float), axis=1)
            thresh = np.percentile(brightness, 70)
            bright_pixels = skin_pixels[brightness > thresh]
            if len(bright_pixels) > 20:
                mean_color_rgb = np.mean(bright_pixels, axis=0).astype(int)
            else:
                mean_color_rgb = np.mean(skin_pixels, axis=0).astype(int)
        else:
            # Fallback 1: Nose
            nose_pixels = image_rgb[parsing == nose_class_id]
            if len(nose_pixels) > 50:
                mean_color_rgb = np.mean(nose_pixels, axis=0).astype(int)
            else:
                # Fallback 2: Full skin
                skin_pixels_full = image_rgb[skin_mask == 1]
                if len(skin_pixels_full) > 100:
                    mean_color_rgb = np.mean(skin_pixels_full, axis=0).astype(int)
    except Exception as skin_err:
        # Deliberately best-effort: the colour is diagnostic only.
        print("Skin detection error (fallback used): " + str(skin_err))

    # Make detected skin color 30% brighter
    brightness_factor = 1.30
    mean_color_rgb = np.clip(
        np.array(mean_color_rgb, dtype=float) * brightness_factor, 0, 255
    ).astype(int)

    # Print adjusted color (optional debug)
    hex_color = '#%02x%02x%02x' % tuple(mean_color_rgb)
    print("Adjusted (30% brighter) skin color → RGB: " + str(mean_color_rgb.tolist()) + " | Hex: " + hex_color)


def _build_masks(parsing):
    """Build the inpainting mask and the ear-protection mask from *parsing*.

    Returns:
        ``(final_mask, ears_protected, hair_pixels, use_extended_mask)`` where
        both masks are uint8 arrays of 0/1 with the shape of *parsing*,
        ``hair_pixels`` is the pixel count of the refined hair mask and
        ``use_extended_mask`` tells whether the enlarged mask was selected.
    """
    height = parsing.shape[0]

    hair_mask = (parsing == hair_class_id).astype(np.uint8)
    ears_mask = np.isin(parsing, ear_class_ids).astype(np.uint8)

    # Slightly dilated ear region that must keep its original pixels.
    ears_protected = np.zeros_like(hair_mask, dtype=np.uint8)
    ear_rows, _ = np.nonzero(ears_mask)
    if len(ear_rows) > 0:
        ear_top_y = ear_rows.min()
        kernel_protect = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 9))
        ears_protected = cv2.dilate(ears_mask, kernel_protect, iterations=1)
        if ear_top_y > 10:
            # Do not let the protection reach above the top of the ears.
            ears_protected[:ear_top_y - 8, :] = 0

    hair_mask_final = hair_mask.copy()
    hair_mask_final[ears_protected == 1] = 0

    # In the top quarter of the image hair always wins over ear protection.
    top_quarter = int(height * 0.25)
    if hair_mask[:top_quarter, :].sum() > 60:
        hair_mask_final[:top_quarter, :] = np.maximum(
            hair_mask_final[:top_quarter, :], hair_mask[:top_quarter, :]
        )

    # Close holes, grow the mask, then smooth and re-threshold its outline.
    kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
    hair_mask_final = cv2.morphologyEx(
        hair_mask_final, cv2.MORPH_CLOSE, kernel_s, iterations=2
    )
    hair_mask_final = cv2.dilate(hair_mask_final, kernel_s, iterations=1)
    blurred = cv2.GaussianBlur(hair_mask_final.astype(np.float32), (9, 9), 3)
    hair_mask_final = (blurred > 0.28).astype(np.uint8)
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    hair_mask_final = cv2.dilate(hair_mask_final, kernel_edge, iterations=1)

    hair_pixels = np.sum(hair_mask_final)
    final_mask = hair_mask_final.copy()
    use_extended_mask = False

    if hair_pixels > 380000:
        # Very large hair area: grow the mask, force the upper 48% of the
        # image into it and cut everything below 75% of the height.
        use_extended_mask = True
        big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25))
        extended = cv2.dilate(hair_mask_final, big_kernel, iterations=1)
        upper = np.zeros_like(hair_mask_final)
        upper[:int(height * 0.48), :] = 1
        extended = np.logical_or(extended, upper).astype(np.uint8)
        extended[ears_protected == 1] = 0
        extended = cv2.morphologyEx(extended, cv2.MORPH_CLOSE, kernel_s, iterations=1)
        extended[int(height * 0.75):, :] = 0
        final_mask = extended

    return final_mask, ears_protected, hair_pixels, use_extended_mask


def _add_skin_texture(inpainted_rgb):
    """Blend random pore-like noise into *inpainted_rgb* and return uint8 RGB."""
    height, width = inpainted_rgb.shape[:2]
    pores_noise = np.random.normal(0, 12, (height, width, 3)).astype(np.float32)
    # NOTE(review): cv2.getGaussianKernel returns a column vector, so this
    # filter smooths along the vertical axis only — confirm that is intended.
    large_kernel = cv2.getGaussianKernel(61, 20)
    large_var = cv2.filter2D(pores_noise, -1, large_kernel) * 0.5
    texture_noise = np.clip(pores_noise * 0.7 + large_var, -25, 25)

    textured_area = np.clip(
        inpainted_rgb.astype(np.float32) + texture_noise, 0, 255
    ).astype(np.uint8)

    blend_factor = 0.75
    return (
        blend_factor * textured_area + (1 - blend_factor) * inpainted_rgb
    ).astype(np.uint8)


def make_realistic_bald(image_bytes: bytes) -> bytes:
    """Return a JPEG of the input portrait with the hair region inpainted.

    Args:
        image_bytes: encoded image in any format PIL can open.

    Returns:
        JPEG-encoded bytes at the original image resolution.

    Raises:
        RuntimeError: if the model cannot be loaded, the bytes cannot be
            decoded, or any processing step fails (the original exception is
            attached as ``__cause__``).
    """
    # Load model only when needed
    seg_processor, seg_model = load_model()

    try:
        image = _decode_rgb(image_bytes)
        orig_w, orig_h = image.size
        original_np = np.array(image)

        # Resize if large
        scale_factor = 1.0
        working_np = original_np
        longest_side = max(orig_w, orig_h)
        if longest_side > _MAX_PROCESS_DIM:
            scale_factor = _MAX_PROCESS_DIM / longest_side
            # Clamp to 1 so extremely elongated images never resize to 0 px.
            working_w = max(1, int(orig_w * scale_factor))
            working_h = max(1, int(orig_h * scale_factor))
            working_np = cv2.resize(
                original_np, (working_w, working_h), interpolation=cv2.INTER_AREA
            )
        working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # Segmentation
        parsing = _segment(working_np, seg_processor, seg_model)

        # Skin colour is sampled from the image the parsing map belongs to.
        _log_skin_color(working_np, parsing)

        final_mask, ears_protected, hair_pixels, use_extended_mask = _build_masks(parsing)

        # Larger hair areas get a wider inpainting radius.
        if use_extended_mask:
            radius = 18
            inpaint_flag = cv2.INPAINT_TELEA
        elif hair_pixels > 220000:
            radius = 15
            inpaint_flag = cv2.INPAINT_TELEA
        else:
            radius = 10
            inpaint_flag = cv2.INPAINT_NS

        inpainted_bgr = cv2.inpaint(
            working_bgr, final_mask * 255, inpaintRadius=radius, flags=inpaint_flag
        )
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Add realistic bald head skin texture
        blended_bald = _add_skin_texture(inpainted_rgb)

        # The morphological growth above can creep back over the ears, so the
        # protected ear pixels are removed from the mask *before* pasting;
        # doing it after the paste has no effect on the result.
        paste_mask = final_mask.copy()
        paste_mask[ears_protected == 1] = 0

        result_small = working_np.copy()
        paste_here = paste_mask == 1
        result_small[paste_here] = blended_bald[paste_here]

        if scale_factor < 1.0:
            result = cv2.resize(
                result_small, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4
            )
        else:
            result = result_small

        output_bytes = io.BytesIO()
        Image.fromarray(result).save(output_bytes, format="JPEG")
        return output_bytes.getvalue()

    except Exception as main_err:
        print("ERROR in make_realistic_bald:")
        traceback.print_exc()
        raise RuntimeError("Bald processing failed: " + str(main_err)) from main_err