# Source: Hugging Face Space file "bald_processor.py"
# (uploaded by Seniordev22, commit a93bbf1 verified, "Update bald_processor.py")
# bald_processor_clean.py
import cv2
import torch
import numpy as np
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import logging
# Module-level logging so model-load progress and failures are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Run inference on GPU when available; everything below uses this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
try:
    logger.info("Loading SegFormer face-parsing model...")
    # Pretrained face-parsing model: predicts a per-pixel class label
    # (hair, ears, skin, ...) used to build the hair mask below.
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()  # inference only; freezes dropout/batch-norm behavior
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}", exc_info=True)
    # Chain the cause so the original load error is not lost in the traceback.
    raise RuntimeError("SegFormer model load failed!") from e

# Class ids in the face-parsing label map.
# NOTE(review): 13 = hair, 7/8 = ears — inferred from how they are used in
# make_realistic_bald; verify against the model card's id2label mapping.
hair_class_id = 13
ear_class_ids = [7, 8]
def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    """Remove detected hair from a face photo and inpaint the scalp.

    Pipeline: SegFormer face parsing -> hair mask (minus ears) ->
    morphological cleanup -> OpenCV inpainting -> composite the inpainted
    region back onto the untouched original at full resolution.

    Args:
        input_image: Source photo as a PIL image. Any mode is accepted;
            it is converted to RGB internally.

    Returns:
        A new RGB PIL image where only the former hair region differs
        from the input.

    Raises:
        ValueError: if no image is given, the image is invalid/corrupt,
            or no hair is detected (message "NO_HAIR_DETECTED").
        RuntimeError: for any other processing failure (original cause
            is chained).
    """
    if input_image is None:
        raise ValueError("No input image provided!")
    try:
        # Normalize mode first: RGBA/L/P inputs would break
        # cv2.COLOR_RGB2BGR, which expects a 3-channel RGB array.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")
        orig_w, orig_h = input_image.size
        original_np = np.array(input_image)
        original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)

        # Downscale very large images so segmentation/inpainting stay fast;
        # the result is composited back at full resolution below.
        MAX_DIM = 2048
        scale_factor = 1.0
        working_np = original_np   # never mutated, so no copy needed
        working_bgr = original_bgr
        working_h, working_w = orig_h, orig_w
        if max(orig_w, orig_h) > MAX_DIM:
            scale_factor = MAX_DIM / max(orig_w, orig_h)
            working_w = int(orig_w * scale_factor)
            working_h = int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h),
                                    interpolation=cv2.INTER_AREA)
            working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # --- Segmentation ---
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # Logits come out at the model's internal resolution; upsample back
        # to the working size before taking the per-pixel argmax.
        upsampled = torch.nn.functional.interpolate(
            logits, size=(working_h, working_w),
            mode="bilinear", align_corners=False,
        )
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        # --- Hair mask (exclude ears so they are not inpainted away) ---
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_mask = np.isin(parsing, ear_class_ids)
        hair_mask[ears_mask] = 0

        # Morphological cleanup: close holes, dilate to cover hair edges,
        # then blur + threshold to smooth the mask boundary.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
        hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)

        hair_pixels = int(np.sum(hair_mask))
        if hair_pixels < 50:
            # Sentinel message consumed by callers; must stay a ValueError.
            raise ValueError("NO_HAIR_DETECTED")

        # --- Inpainting: TELEA with a larger radius for big hair regions,
        # Navier-Stokes with a smaller radius otherwise. ---
        large_region = hair_pixels > 220000
        radius = 15 if large_region else 10
        flag = cv2.INPAINT_TELEA if large_region else cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(working_bgr, hair_mask * 255,
                                    inpaintRadius=radius, flags=flag)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Upscale the inpainted result and mask back to the original
        # resolution if we worked on a downscaled copy.
        if scale_factor < 1.0:
            bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h),
                                 interpolation=cv2.INTER_LANCZOS4)
            mask_up = cv2.resize(hair_mask, (orig_w, orig_h),
                                 interpolation=cv2.INTER_NEAREST)
        else:
            bald_up = inpainted_rgb
            mask_up = hair_mask

        # Composite: only the (former) hair area changes; every other pixel
        # is taken unmodified from the original image.
        result = original_np.copy()
        result[mask_up == 1] = bald_up[mask_up == 1]
        return Image.fromarray(result)
    except UnidentifiedImageError as e:
        raise ValueError("Invalid image format or corrupt image!") from e
    except ValueError:
        # Preserve domain errors (e.g. NO_HAIR_DETECTED) so callers can
        # distinguish them from generic processing failures; the original
        # code converted them into RuntimeError, hiding the sentinel.
        raise
    except Exception as e:
        logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
        raise RuntimeError(f"Bald processing failed: {str(e)}") from e