Spaces:
Running
Running
File size: 4,101 Bytes
0141af9 d7ab42c 2bb163a d7ab42c 5ab0443 d7ab42c e2eb2c5 d7ab42c e2eb2c5 5ab0443 d7ab42c 0141af9 d7ab42c 0141af9 e2eb2c5 d7ab42c e2eb2c5 0141af9 d7ab42c e2eb2c5 e8a4bf5 e2eb2c5 0141af9 d7ab42c 0141af9 e8a4bf5 e2eb2c5 0141af9 a93bbf1 e2eb2c5 0141af9 a93bbf1 e2eb2c5 5ab0443 e2eb2c5 0141af9 a93bbf1 0141af9 d7ab42c 0141af9 d7ab42c 0141af9 a93bbf1 d7ab42c 0141af9 e2eb2c5 0141af9 a93bbf1 211710a a93bbf1 e8a4bf5 211710a a93bbf1 e8a4bf5 0141af9 d7ab42c 0141af9 d7ab42c 5ab0443 0141af9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
# bald_processor_clean.py
import cv2
import torch
import numpy as np
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Prefer GPU inference when available; everything below works on CPU too.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: %s", device)

# Load the face-parsing model once at import time so each request only pays
# for inference, not model construction.
try:
    logger.info("Loading SegFormer face-parsing model...")
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}", exc_info=True)
    # Chain the original exception so the root cause survives the re-raise.
    raise RuntimeError("SegFormer model load failed!") from e

# Class indices in the face-parsing label map.
# NOTE(review): 13 = hair; 7/8 presumably the two ear classes — confirm
# against the model card's id2label mapping.
hair_class_id = 13
ear_class_ids = [7, 8]
def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    """Remove the hair from a face photo and inpaint the area to look bald.

    Segments hair with the module-level SegFormer face-parsing model, cleans
    the mask with morphology, inpaints the hair region, and composites the
    inpainted area back onto the untouched original pixels.

    Args:
        input_image: Face photo as a PIL image (any mode; converted to RGB).

    Returns:
        A new PIL image where only the hair region is replaced; every other
        pixel is identical to the input.

    Raises:
        ValueError: If no image is given, the image is invalid/corrupt, or
            too little hair is found ("NO_HAIR_DETECTED").
        RuntimeError: On any other processing failure (original cause chained).
    """
    if input_image is None:
        raise ValueError("No input image provided!")
    try:
        # Normalize mode so the RGB<->BGR conversions below are valid even
        # for RGBA, grayscale, or palette inputs.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")
        orig_w, orig_h = input_image.size
        original_np = np.array(input_image)
        original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)

        # Downscale very large images for segmentation/inpainting speed; the
        # result is upscaled and composited back at full resolution below.
        MAX_DIM = 2048
        scale_factor = 1.0
        working_np = original_np.copy()
        working_bgr = original_bgr.copy()
        working_h, working_w = orig_h, orig_w
        if max(orig_w, orig_h) > MAX_DIM:
            scale_factor = MAX_DIM / max(orig_w, orig_h)
            working_w, working_h = int(orig_w * scale_factor), int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
            working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # Per-pixel face parsing at working resolution; logits are upsampled
        # back to the working image size before the argmax.
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        upsampled = torch.nn.functional.interpolate(
            logits, size=(working_h, working_w),
            mode="bilinear", align_corners=False
        )
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        # Binary hair mask; explicitly carve out ear pixels so they are
        # never inpainted.
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_mask = np.zeros_like(hair_mask)
        for cls in ear_class_ids:
            ears_mask[parsing == cls] = 1
        hair_mask[ears_mask == 1] = 0

        # Close holes, grow slightly past the hairline, then threshold a
        # blurred copy to smooth jagged mask edges.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
        hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)

        hair_pixels = np.sum(hair_mask)
        if hair_pixels < 50:
            # Sentinel callers can match on to mean "nothing to do".
            raise ValueError("NO_HAIR_DETECTED")

        # Larger radius + TELEA for big hair regions, NS otherwise
        # (thresholds kept from the original implementation).
        radius = 15 if hair_pixels > 220000 else 10
        flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(working_bgr, hair_mask * 255, inpaintRadius=radius, flags=flag)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Bring the inpainted area and mask back to original resolution if
        # we downscaled (NEAREST keeps the mask strictly binary).
        if scale_factor < 1.0:
            bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
            mask_up = cv2.resize(hair_mask.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        else:
            bald_up = inpainted_rgb
            mask_up = hair_mask

        # Composite only the bald area; every other pixel stays untouched.
        result = original_np.copy()
        result[mask_up == 1] = bald_up[mask_up == 1]
        return Image.fromarray(result)
    except UnidentifiedImageError:
        raise ValueError("Invalid image format or corrupt image!")
    except ValueError:
        # Re-raise sentinel/validation errors unchanged so callers can
        # distinguish "NO_HAIR_DETECTED" from generic failures (previously
        # these were swallowed into RuntimeError by the blanket handler).
        raise
    except Exception as e:
        logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
        # Chain the cause so the real traceback is not lost.
        raise RuntimeError(f"Bald processing failed: {str(e)}") from e
|