# bald_processor_clean.py
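"""Make a portrait realistically bald.

Pipeline: segment hair with a SegFormer face-parsing model, inpaint the
hair region with OpenCV, then composite the inpainted area back onto the
original image so everything outside the hair mask stays untouched.
"""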
import cv2
import torch
import numpy as np
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

try:
    logger.info("Loading SegFormer face-parsing model...")
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}", exc_info=True)
    raise RuntimeError("SegFormer model load failed!") from e

# Class ids from the jonathandinu/face-parsing label map (CelebAMask-HQ):
# 13 = hair; the ears are 8 (l_ear) and 9 (r_ear) -- 7 is r_brow, not an ear.
hair_class_id = 13
ear_class_ids = [8, 9]

def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    if input_image is None:
        raise ValueError("No input image provided!")

    try:
        # Normalize to RGB so RGBA/grayscale/palette inputs don't break cv2.cvtColor.
        input_image = input_image.convert("RGB")
        orig_w, orig_h = input_image.size
        original_np = np.array(input_image)
        original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)

        # Downscale very large inputs before segmentation; the bald region is
        # composited back at full resolution at the end.
        MAX_DIM = 2048
        scale_factor = 1.0
        working_np = original_np.copy()
        working_bgr = original_bgr.copy()
        working_h, working_w = orig_h, orig_w

        if max(orig_w, orig_h) > MAX_DIM:
            scale_factor = MAX_DIM / max(orig_w, orig_h)
            working_w, working_h = int(orig_w*scale_factor), int(orig_h*scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
            working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # Segmentation
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Upsample logits to working resolution; per-pixel argmax gives the class map.
        upsampled = torch.nn.functional.interpolate(
            logits, size=(working_h, working_w),
            mode="bilinear", align_corners=False
        )
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        # Hair mask: keep hair pixels, then subtract the ears so they are
        # never inpainted.
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_mask = np.zeros_like(hair_mask)
        for cls in ear_class_ids:
            ears_mask[parsing == cls] = 1
        hair_mask[ears_mask == 1] = 0

        # Clean the mask: close holes, grow it slightly, then feather the
        # edge with a Gaussian blur and re-threshold.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
        hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)

        hair_pixels = np.sum(hair_mask)
        if hair_pixels < 50:
            raise ValueError("NO_HAIR_DETECTED")

        # Inpainting (no extra noise, no blur): large masks get a bigger
        # radius and the TELEA algorithm, small ones Navier-Stokes.
        radius = 15 if hair_pixels > 220000 else 10
        flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(working_bgr, hair_mask*255, inpaintRadius=radius, flags=flag)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Upscale bald area if needed
        if scale_factor < 1.0:
            bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
            mask_up = cv2.resize(hair_mask.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        else:
            bald_up = inpainted_rgb
            mask_up = hair_mask

        # Composite only bald area, rest untouched
        result = original_np.copy()
        result[mask_up==1] = bald_up[mask_up==1]

        return Image.fromarray(result)

    except UnidentifiedImageError:
        raise ValueError("Invalid image format or corrupt image!")
    except ValueError:
        # Preserve sentinel errors such as NO_HAIR_DETECTED for the caller.
        raise
    except Exception as e:
        logger.error(f"Bald processing failed: {e}", exc_info=True)
        raise RuntimeError(f"Bald processing failed: {e}") from e
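
# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption: not part of the original module; the
# file names below are placeholders).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    src = sys.argv[1] if len(sys.argv) > 1 else "portrait.jpg"  # placeholder
    try:
        with Image.open(src) as img:
            result = make_realistic_bald(img)
        result.save("bald_result.png")
        logger.info("Saved bald_result.png")
    except ValueError as err:
        # Invalid input or the NO_HAIR_DETECTED sentinel lands here.
        logger.error(f"Cannot process {src}: {err}")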