# make-me-bald / bald_processor.py
# NOTE: the original upload carried Hugging Face Space page chrome here
# ("Seniordev22's picture / Update bald_processor.py / dac57bb verified"),
# which is not Python and broke the file; preserved as a comment instead.
import os
import cv2
import torch
import numpy as np
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import io
import traceback
# Globals for lazy loading (no global load at import time)
# Inference device: prefer CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Both are populated on the first call to load_model(); None until then.
processor = None
model = None
def load_model(model_id: str = "jonathandinu/face-parsing"):
    """Lazily load and cache the SegFormer face-parsing processor and model.

    The pair is loaded once on first call and cached in the module globals
    ``processor`` / ``model``; subsequent calls return the cached pair (the
    ``model_id`` argument is ignored once a model is cached).

    Args:
        model_id: Hugging Face model repo to load. The default preserves the
            original hard-coded behavior.

    Returns:
        ``(processor, model)`` with the model moved to ``device`` and in
        eval mode.

    Raises:
        RuntimeError: if downloading or initializing the model fails.
    """
    global processor, model
    if model is None:
        print(f"Using device: {device} | CUDA available: {torch.cuda.is_available()}")
        print("Loading SegFormer face-parsing model...")
        try:
            processor = SegformerImageProcessor.from_pretrained(model_id)
            model = SegformerForSemanticSegmentation.from_pretrained(model_id)
            model.to(device)
            model.eval()
            print("Model loaded successfully!")
        except Exception as e:
            print("CRITICAL: Model loading failed!")
            traceback.print_exc()
            raise RuntimeError(f"Model loading failed: {str(e)}")
    return processor, model
# Class ids used when indexing the parsing map produced by the
# face-parsing model loaded in load_model().
hair_class_id = 13
ear_class_ids = [8, 9]  # l_ear=8, r_ear=9
skin_class_id = 1
nose_class_id = 2
def _decode_rgb_image(image_bytes: bytes) -> Image.Image:
    """Decode raw bytes into an RGB PIL image; raise ValueError on bad input."""
    try:
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except UnidentifiedImageError:
        raise ValueError("Invalid image format or corrupt bytes")
    except Exception as e:
        raise ValueError(f"Image open failed: {str(e)}")


def _segment_face(rgb: np.ndarray, processor, model) -> np.ndarray:
    """Run SegFormer face parsing on an RGB array.

    Returns a (H, W) int array of per-pixel class ids at the input's size.
    """
    h, w = rgb.shape[:2]
    inputs = processor(images=Image.fromarray(rgb), return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Logits come out at the model's internal resolution; upsample to H x W.
    upsampled_logits = torch.nn.functional.interpolate(
        outputs.logits,
        size=(h, w),
        mode="bilinear",
        align_corners=False,
    )
    return upsampled_logits.argmax(dim=1).squeeze(0).cpu().numpy()


def _estimate_skin_color(rgb: np.ndarray, parsing: np.ndarray) -> np.ndarray:
    """Estimate a representative scalp skin color, brightened by 30%.

    Samples the forehead first, then falls back to the nose, then all skin,
    then a hard-coded neutral tone.

    FIX: samples from ``rgb`` (the working-resolution image the parsing map
    was computed on). The original indexed the full-resolution image with
    working-resolution masks; whenever the input had been downscaled this
    raised a shape-mismatch IndexError that was silently swallowed, so the
    detection always fell back to the default color.
    """
    skin_mask = (parsing == skin_class_id).astype(np.uint8)
    h, w = parsing.shape
    # Forehead region as fractions of the image (center-top band).
    forehead_y_start = max(0, int(h * 0.25))
    forehead_y_end = min(h, int(h * 0.38))
    forehead_x_start = max(0, int(w * 0.38))
    forehead_x_end = min(w, int(w * 0.62))
    forehead_region = rgb[forehead_y_start:forehead_y_end, forehead_x_start:forehead_x_end]
    forehead_skin_mask = skin_mask[forehead_y_start:forehead_y_end, forehead_x_start:forehead_x_end]
    mean_color_rgb = np.array([210, 185, 170])  # Lighter neutral fallback
    try:
        if forehead_region.size > 0 and np.sum(forehead_skin_mask) > 80:
            skin_pixels = forehead_region[forehead_skin_mask == 1]
            if len(skin_pixels) > 30:
                # Prefer the brightest 30% of forehead pixels (avoids shadows).
                brightness = np.mean(skin_pixels.astype(float), axis=1)
                thresh = np.percentile(brightness, 70)
                bright_pixels = skin_pixels[brightness > thresh]
                if len(bright_pixels) > 20:
                    mean_color_rgb = np.mean(bright_pixels, axis=0).astype(int)
                else:
                    mean_color_rgb = np.mean(skin_pixels, axis=0).astype(int)
            else:
                mean_color_rgb = np.mean(forehead_region, axis=(0, 1)).astype(int)
        else:
            # Fallback 1: Nose
            nose_pixels = rgb[parsing == nose_class_id]
            if len(nose_pixels) > 50:
                mean_color_rgb = np.mean(nose_pixels, axis=0).astype(int)
            else:
                # Fallback 2: Full skin
                skin_pixels_full = rgb[skin_mask == 1]
                if len(skin_pixels_full) > 100:
                    mean_color_rgb = np.mean(skin_pixels_full, axis=0).astype(int)
    except Exception as skin_err:
        print("Skin detection error (fallback used): " + str(skin_err))
    # Make detected skin color 30% brighter.
    mean_color_rgb = np.asarray(mean_color_rgb, dtype=float) * 1.30
    return np.clip(mean_color_rgb, 0, 255).astype(int)


def _build_hair_mask(parsing: np.ndarray, working_h: int, working_w: int):
    """Build the binary hair mask, keeping the ears protected.

    Returns ``(hair_mask_final, ears_protected, ear_span)`` where ``ear_span``
    is a ``(left, right)`` column range around the ears, or ``None`` when no
    ear pixels were detected.
    """
    hair_mask = (parsing == hair_class_id).astype(np.uint8)
    ears_mask = np.zeros_like(hair_mask, dtype=np.uint8)
    for cls in ear_class_ids:
        ears_mask[parsing == cls] = 1

    ears_protected = np.zeros_like(hair_mask, dtype=np.uint8)
    ear_span = None
    ear_y, ear_x = np.where(ears_mask > 0)
    if len(ear_y) > 0:
        ear_top_y = ear_y.min()
        ear_x_min = ear_x.min()
        ear_x_max = ear_x.max()
        ear_width = ear_x_max - ear_x_min + 1
        kernel_protect = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 9))
        ears_protected = cv2.dilate(ears_mask, kernel_protect, iterations=1)
        # Do not let protection extend far above the ears themselves.
        if ear_top_y > 10:
            ears_protected[:ear_top_y - 8, :] = 0
        x_margin = int(ear_width * 0.25)
        ear_span = (max(0, ear_x_min - x_margin), min(working_w, ear_x_max + x_margin))

    hair_mask_final = hair_mask.copy()
    hair_mask_final[ears_protected == 1] = 0
    # Always keep hair in the top quarter (scalp) even if it touched protection.
    top_quarter = int(working_h * 0.25)
    if hair_mask[:top_quarter, :].sum() > 60:
        hair_mask_final[:top_quarter, :] = np.maximum(
            hair_mask_final[:top_quarter, :], hair_mask[:top_quarter, :]
        )
    # Close holes, grow outward, then smooth/re-threshold the boundary.
    kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
    hair_mask_final = cv2.morphologyEx(hair_mask_final, cv2.MORPH_CLOSE, kernel_s, iterations=2)
    hair_mask_final = cv2.dilate(hair_mask_final, kernel_s, iterations=1)
    blurred = cv2.GaussianBlur(hair_mask_final.astype(np.float32), (9, 9), 3)
    hair_mask_final = (blurred > 0.28).astype(np.uint8)
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    hair_mask_final = cv2.dilate(hair_mask_final, kernel_edge, iterations=1)
    return hair_mask_final, ears_protected, ear_span


def _select_inpaint_mask(hair_mask_final: np.ndarray, ears_protected: np.ndarray,
                         working_h: int):
    """Choose the inpaint mask, radius and algorithm from hair coverage.

    Very hairy heads (> 380k mask pixels) get an extended mask covering the
    whole upper head; radius/algorithm scale with coverage.
    Returns ``(final_mask, radius, inpaint_flag)``.
    """
    hair_pixels = np.sum(hair_mask_final)
    final_mask = hair_mask_final.copy()
    if hair_pixels > 380000:
        big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25))
        extended = cv2.dilate(hair_mask_final, big_kernel, iterations=1)
        upper = np.zeros_like(hair_mask_final)
        upper[:int(working_h * 0.48), :] = 1
        extended = np.logical_or(extended, upper).astype(np.uint8)
        extended[ears_protected == 1] = 0
        kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        extended = cv2.morphologyEx(extended, cv2.MORPH_CLOSE, kernel_s, iterations=1)
        # Never inpaint the lower quarter (face/neck).
        extended[int(working_h * 0.75):, :] = 0
        final_mask = extended
        radius, inpaint_flag = 18, cv2.INPAINT_TELEA
    elif hair_pixels > 220000:
        radius, inpaint_flag = 15, cv2.INPAINT_TELEA
    else:
        radius, inpaint_flag = 10, cv2.INPAINT_NS
    return final_mask, radius, inpaint_flag


def _apply_scalp_texture(inpainted_rgb: np.ndarray, working_h: int, working_w: int) -> np.ndarray:
    """Add pore-like noise so the inpainted scalp doesn't look plastic.

    Nondeterministic by design (unseeded noise).
    """
    pores_noise = np.random.normal(0, 12, (working_h, working_w, 3)).astype(np.float32)
    large_kernel = cv2.getGaussianKernel(61, 20)
    # FIX: sepFilter2D applies the Gaussian in both axes; the original
    # filter2D with a 61x1 column kernel blurred vertically only, producing
    # streaky "large-scale" variation instead of smooth blotches.
    large_var = cv2.sepFilter2D(pores_noise, -1, large_kernel, large_kernel) * 0.5
    texture_noise = np.clip(pores_noise * 0.7 + large_var, -25, 25)
    textured_area = np.clip(inpainted_rgb.astype(np.float32) + texture_noise, 0, 255).astype(np.uint8)
    blend_factor = 0.75
    return (blend_factor * textured_area + (1 - blend_factor) * inpainted_rgb).astype(np.uint8)


def make_realistic_bald(image_bytes: bytes) -> bytes:
    """Remove hair from a portrait and inpaint a realistic bald scalp.

    Pipeline: decode -> optional downscale to 2048px max -> SegFormer face
    parsing -> hair mask (ears protected) -> OpenCV inpaint -> skin texture
    -> upscale back -> JPEG bytes.

    Args:
        image_bytes: raw encoded image bytes (any format PIL can decode).

    Returns:
        JPEG-encoded bytes of the processed image at the original size.

    Raises:
        RuntimeError: on any failure, including undecodable input (the
            internal ValueError is wrapped, matching the original behavior).
    """
    # Load model only when needed.
    processor, model = load_model()
    try:
        image = _decode_rgb_image(image_bytes)
        orig_w, orig_h = image.size
        original_np = np.array(image)

        # Downscale large inputs so segmentation and inpainting stay fast.
        MAX_PROCESS_DIM = 2048
        scale_factor = 1.0
        working_np = original_np
        working_h, working_w = orig_h, orig_w
        if max(orig_w, orig_h) > MAX_PROCESS_DIM:
            scale_factor = MAX_PROCESS_DIM / max(orig_w, orig_h)
            working_w = int(orig_w * scale_factor)
            working_h = int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h),
                                    interpolation=cv2.INTER_AREA)
        working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        parsing = _segment_face(working_np, processor, model)

        # Debug aid only: the estimated color is logged, not applied.
        mean_color_rgb = _estimate_skin_color(working_np, parsing)
        hex_color = '#%02x%02x%02x' % tuple(mean_color_rgb)
        print("Adjusted (30% brighter) skin color → RGB: "
              + str(mean_color_rgb.tolist()) + " | Hex: " + hex_color)

        hair_mask_final, ears_protected, ear_span = _build_hair_mask(
            parsing, working_h, working_w)
        final_mask, radius, inpaint_flag = _select_inpaint_mask(
            hair_mask_final, ears_protected, working_h)

        inpainted_bgr = cv2.inpaint(working_bgr, final_mask * 255,
                                    inpaintRadius=radius, flags=inpaint_flag)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
        blended_bald = _apply_scalp_texture(inpainted_rgb, working_h, working_w)

        # FIX: clear the mask around the protected ears BEFORE compositing.
        # The original ran this cleanup after the mask had already been
        # applied, so it never affected the output.
        if ear_span is not None:
            left, right = ear_span
            side_clean_left = max(0, left - 30)
            side_clean_right = min(working_w, right + 30)
            final_mask[:, side_clean_left:side_clean_right] = np.minimum(
                final_mask[:, side_clean_left:side_clean_right],
                1 - ears_protected[:, side_clean_left:side_clean_right],
            )

        result_small = working_np.copy()
        result_small[final_mask == 1] = blended_bald[final_mask == 1]

        if scale_factor < 1.0:
            result = cv2.resize(result_small, (orig_w, orig_h),
                                interpolation=cv2.INTER_LANCZOS4)
        else:
            result = result_small

        output_bytes = io.BytesIO()
        Image.fromarray(result).save(output_bytes, format="JPEG")
        return output_bytes.getvalue()
    except Exception as main_err:
        print("ERROR in make_realistic_bald:")
        traceback.print_exc()
        raise RuntimeError("Bald processing failed: " + str(main_err))