# NOTE(review): removed non-Python paste residue ("Spaces:" / "Runtime error" log lines)
# that preceded the source and would break parsing.
import io
import os
import traceback

import cv2
import numpy as np
import torch
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# Globals for lazy loading (no global load at import time)
device = "cuda" if torch.cuda.is_available() else "cpu"  # inference device, chosen once at import
processor = None  # SegformerImageProcessor instance; populated by load_model()
model = None  # SegformerForSemanticSegmentation instance; populated by load_model()
def load_model():
    """Lazily load and cache the SegFormer face-parsing processor and model.

    The pair is created on first call only, moved to the module-level
    ``device``, and put in eval mode; subsequent calls return the cached
    globals.

    Returns:
        Tuple ``(processor, model)``.

    Raises:
        RuntimeError: if downloading or initialising the model fails.
    """
    global processor, model
    # Guard clause: already loaded — nothing to do.
    if model is not None:
        return processor, model
    print(f"Using device: {device} | CUDA available: {torch.cuda.is_available()}")
    print("Loading SegFormer face-parsing model...")
    try:
        processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
        model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
        model.to(device)
        model.eval()
        print("Model loaded successfully!")
    except Exception as e:
        print("CRITICAL: Model loading failed!")
        traceback.print_exc()
        raise RuntimeError(f"Model loading failed: {str(e)}")
    return processor, model
# Segmentation class indices in the face-parsing model's output
# (presumably the "jonathandinu/face-parsing" label map — confirm against model card)
hair_class_id = 13
ear_class_ids = [8, 9]  # l_ear=8, r_ear=9
skin_class_id = 1
nose_class_id = 2
def make_realistic_bald(image_bytes: bytes) -> bytes:
    """Remove hair from a portrait and render a realistic bald head.

    Pipeline: decode -> optional downscale -> SegFormer face parsing ->
    skin-tone estimation (debug print only) -> hair-mask construction with
    ear protection -> inpainting -> skin-texture synthesis -> upscale back
    to the original resolution -> JPEG bytes.

    Args:
        image_bytes: Raw encoded image bytes (any format PIL can decode).

    Returns:
        JPEG-encoded bytes of the edited image at the original resolution.

    Raises:
        RuntimeError: if model loading or any processing step fails
            (a ValueError raised for undecodable input is wrapped too,
            matching the original behavior).
    """
    # Load model only when needed
    processor, model = load_model()
    try:
        # --- Decode input safely ---
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except UnidentifiedImageError:
            raise ValueError("Invalid image format or corrupt bytes")
        except Exception as e:
            raise ValueError(f"Image open failed: {str(e)}")

        orig_w, orig_h = image.size
        original_np = np.array(image)

        # --- Downscale very large inputs to bound processing cost ---
        MAX_PROCESS_DIM = 2048
        scale_factor = 1.0
        working_np = original_np
        working_h, working_w = orig_h, orig_w
        if max(orig_w, orig_h) > MAX_PROCESS_DIM:
            scale_factor = MAX_PROCESS_DIM / max(orig_w, orig_h)
            working_w = int(orig_w * scale_factor)
            working_h = int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h),
                                    interpolation=cv2.INTER_AREA)
        working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # --- Face-parsing segmentation at working resolution ---
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=(working_h, working_w),
            mode="bilinear",
            align_corners=False,
        )
        parsing = upsampled_logits.argmax(dim=1).squeeze(0).cpu().numpy()

        skin_mask = (parsing == skin_class_id).astype(np.uint8)

        # --- Estimate skin color from a central forehead window ---
        # BUGFIX: all sampling below uses working_np. The original indexed
        # original_np with working-resolution coordinates, which reads the
        # wrong region on downscaled images and makes the boolean-mask
        # fallbacks raise IndexError (silently swallowed by the except).
        h, w = parsing.shape
        forehead_y_start = max(0, int(h * 0.25))
        forehead_y_end = min(h, int(h * 0.38))
        forehead_x_start = max(0, int(w * 0.38))
        forehead_x_end = min(w, int(w * 0.62))
        forehead_region = working_np[forehead_y_start:forehead_y_end,
                                     forehead_x_start:forehead_x_end]
        forehead_skin_mask = skin_mask[forehead_y_start:forehead_y_end,
                                       forehead_x_start:forehead_x_end]

        mean_color_rgb = np.array([210, 185, 170])  # Lighter neutral fallback
        try:
            if forehead_region.size > 0 and np.sum(forehead_skin_mask) > 80:
                skin_pixels = forehead_region[forehead_skin_mask == 1]
                if len(skin_pixels) > 30:
                    # Prefer the brightest 30% of forehead skin (avoids shadows)
                    brightness = np.mean(skin_pixels.astype(float), axis=1)
                    thresh = np.percentile(brightness, 70)
                    bright_pixels = skin_pixels[brightness > thresh]
                    if len(bright_pixels) > 20:
                        mean_color_rgb = np.mean(bright_pixels, axis=0).astype(int)
                    else:
                        mean_color_rgb = np.mean(skin_pixels, axis=0).astype(int)
                else:
                    mean_color_rgb = np.mean(forehead_region, axis=(0, 1)).astype(int)
            else:
                # Fallback 1: Nose
                nose_mask = (parsing == nose_class_id).astype(np.uint8)
                nose_pixels = working_np[nose_mask == 1]
                if len(nose_pixels) > 50:
                    mean_color_rgb = np.mean(nose_pixels, axis=0).astype(int)
                else:
                    # Fallback 2: Full skin
                    skin_pixels_full = working_np[skin_mask == 1]
                    if len(skin_pixels_full) > 100:
                        mean_color_rgb = np.mean(skin_pixels_full, axis=0).astype(int)
        except Exception as skin_err:
            print("Skin detection error (fallback used): " + str(skin_err))

        # Make detected skin color 30% brighter.
        # NOTE: this color is only printed for debugging; it does not feed
        # the inpainting below (kept for behavioral parity).
        mean_color_rgb = np.array(mean_color_rgb, dtype=float)
        mean_color_rgb = np.clip(mean_color_rgb * 1.30, 0, 255).astype(int)
        hex_color = '#%02x%02x%02x' % tuple(mean_color_rgb)
        print("Adjusted (30% brighter) skin color → RGB: " + str(mean_color_rgb.tolist()) + " | Hex: " + hex_color)

        # --- Hair and ear masks ---
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_mask = np.zeros_like(hair_mask, dtype=np.uint8)
        for cls in ear_class_ids:
            ears_mask[parsing == cls] = 1

        ears_protected = np.zeros_like(hair_mask, dtype=np.uint8)
        ear_y, ear_x = np.where(ears_mask > 0)
        left, right = 0, 0
        if len(ear_y) > 0:
            ear_top_y = ear_y.min()
            ear_x_min = ear_x.min()
            ear_x_max = ear_x.max()
            ear_width = ear_x_max - ear_x_min + 1
            kernel_protect = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 9))
            ears_protected = cv2.dilate(ears_mask, kernel_protect, iterations=1)
            if ear_top_y > 10:
                # Don't protect the region above the ears — hair there must go
                ears_protected[:ear_top_y - 8, :] = 0
            x_margin = int(ear_width * 0.25)
            left = max(0, ear_x_min - x_margin)
            right = min(working_w, ear_x_max + x_margin)

        hair_mask_final = hair_mask.copy()
        hair_mask_final[ears_protected == 1] = 0
        top_quarter = int(working_h * 0.25)
        if hair_mask[:top_quarter, :].sum() > 60:
            # Restore full crown coverage even where ear protection clipped it
            hair_mask_final[:top_quarter, :] = np.maximum(
                hair_mask_final[:top_quarter, :], hair_mask[:top_quarter, :]
            )

        # Close holes, grow, and feather-threshold the hair mask
        kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask_final = cv2.morphologyEx(hair_mask_final, cv2.MORPH_CLOSE, kernel_s, iterations=2)
        hair_mask_final = cv2.dilate(hair_mask_final, kernel_s, iterations=1)
        blurred = cv2.GaussianBlur(hair_mask_final.astype(np.float32), (9, 9), 3)
        hair_mask_final = (blurred > 0.28).astype(np.uint8)
        kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        hair_mask_final = cv2.dilate(hair_mask_final, kernel_edge, iterations=1)

        hair_pixels = np.sum(hair_mask_final)
        final_mask = hair_mask_final.copy()
        # SIMPLIFIED: the original's second test (hair_pixels > 420000) was
        # redundant — it implies this flag is already set at > 380000.
        use_extended_mask = hair_pixels > 380000
        if use_extended_mask:
            # Very hairy image: widen mask and force the upper half of the head
            big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25))
            extended = cv2.dilate(hair_mask_final, big_kernel, iterations=1)
            upper = np.zeros_like(hair_mask_final)
            upper[:int(working_h * 0.48), :] = 1
            extended = np.logical_or(extended, upper).astype(np.uint8)
            extended[ears_protected == 1] = 0
            extended = cv2.morphologyEx(extended, cv2.MORPH_CLOSE, kernel_s, iterations=1)
            extended[int(working_h * 0.75):, :] = 0  # never touch lower face/neck
            final_mask = extended

        # BUGFIX: this ear side-cleanup originally ran AFTER the result was
        # composed, so it had no effect. It must run before the mask is used.
        if len(ear_x) > 0:
            side_clean_left = max(0, left - 30)
            side_clean_right = min(working_w, right + 30)
            final_mask[:, side_clean_left:side_clean_right] = np.minimum(
                final_mask[:, side_clean_left:side_clean_right],
                1 - ears_protected[:, side_clean_left:side_clean_right]
            )

        # --- Inpaint; bigger masks get a stronger radius and TELEA ---
        if use_extended_mask:
            radius = 18
            inpaint_flag = cv2.INPAINT_TELEA
        elif hair_pixels > 220000:
            radius = 15
            inpaint_flag = cv2.INPAINT_TELEA
        else:
            radius = 10
            inpaint_flag = cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(working_bgr, final_mask * 255, inpaintRadius=radius, flags=inpaint_flag)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # --- Add realistic bald head skin texture ---
        pores_noise = np.random.normal(0, 12, (working_h, working_w, 3)).astype(np.float32)
        # BUGFIX: cv2.getGaussianKernel(61, 20) is a (61,1) column vector, so
        # filter2D blurred vertically only; use an isotropic 2-D Gaussian.
        large_var = cv2.GaussianBlur(pores_noise, (61, 61), 20) * 0.5
        texture_noise = np.clip(pores_noise * 0.7 + large_var, -25, 25)
        textured_area = np.clip(inpainted_rgb.astype(np.float32) + texture_noise, 0, 255).astype(np.uint8)
        blend_factor = 0.75
        blended_bald = (blend_factor * textured_area + (1 - blend_factor) * inpainted_rgb).astype(np.uint8)

        result_small = working_np.copy()
        result_small[final_mask == 1] = blended_bald[final_mask == 1]

        # --- Restore original resolution and encode as JPEG ---
        if scale_factor < 1.0:
            result = cv2.resize(result_small, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
        else:
            result = result_small
        output_bytes = io.BytesIO()
        Image.fromarray(result).save(output_bytes, format="JPEG")
        return output_bytes.getvalue()
    except Exception as main_err:
        print("ERROR in make_realistic_bald:")
        traceback.print_exc()
        raise RuntimeError("Bald processing failed: " + str(main_err))