File size: 6,481 Bytes
5109f49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import mediapipe as mp
import cv2
import os
import warnings

# Suppress MediaPipe warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore')

# Suppress MediaPipe logs
import logging
logging.getLogger('mediapipe').setLevel(logging.ERROR)

# Load model
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")

def get_facemesh_mask(image):
    image_np = np.array(image)
    height, width, _ = image_np.shape
    face_mask = np.zeros((height, width), dtype=np.uint8)
    mp_face_mesh = mp.solutions.face_mesh
    with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5) as face_mesh:
        results = face_mesh.process(image_np)
        if results.multi_face_landmarks:
            for face_landmarks in results.multi_face_landmarks:
                points = []
                for lm in face_landmarks.landmark:
                    x, y = int(lm.x * width), int(lm.y * height)
                    points.append([x, y])
                points = np.array(points, np.int32)
                # FaceMesh polygon (bao mặt, trán, không lấy cổ)
                if len(points) > 0:
                    hull = cv2.convexHull(points)
                    cv2.fillConvexPoly(face_mask, hull, 1)
    return face_mask

def expand_forehead_mask(face_mask, expand_percent=0.2):
    ys, xs = np.where(face_mask > 0)
    if len(ys) == 0:
        return face_mask  # không tìm thấy mặt
    min_y, max_y = ys.min(), ys.max()
    height = max_y - min_y
    expand = int(height * expand_percent)
    expanded_min_y = max(min_y - expand, 0)
    expanded_mask = np.zeros_like(face_mask)
    
    # Đảm bảo không lỗi khi vùng mở rộng vượt ngoài ảnh
    src_start = min_y
    src_end = max_y
    dst_start = expanded_min_y
    dst_end = expanded_min_y + (src_end - src_start)
    if dst_end > face_mask.shape[0]:
        overlap = dst_end - face_mask.shape[0]
        dst_end = face_mask.shape[0]
        src_end -= overlap

    expanded_mask[dst_start:dst_end, :] = face_mask[src_start:src_end, :]
    return expanded_mask

def extract_hair(image: Image.Image) -> Image.Image:
    """

    Return an RGBA image where hair pixels have alpha=255 and

    all other pixels have alpha=0.

    """
    rgb = image.convert("RGB")
    arr = np.array(rgb)
    h, w = arr.shape[:2]

    # Segment hair
    inputs = processor(images=rgb, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()
    up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
    seg = up.argmax(dim=1)[0].numpy()
    hair_mask = (seg == 2).astype(np.uint8)

    # Build RGBA
    alpha = (hair_mask * 255).astype(np.uint8)
    rgba  = np.dstack([arr, alpha])
    return Image.fromarray(rgba)

def get_face(image):
    image = image.convert("RGB")
    # SegFormer hair mask
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits.cpu()
    upsampled_logits = F.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
    hair_mask = (pred_seg == 2).astype(np.uint8)  # tóc

    # Face mesh mask (bao trọn mặt, trán, không cổ)
    face_mesh_mask = get_facemesh_mask(image)
    # Expand lên trên 20% chiều cao mặt (ăn gian trán)
    expanded_face_mask = expand_forehead_mask(face_mesh_mask, expand_percent=0.2)

    # Vùng trán mở rộng chỉ lấy phần không trùng vùng mặt gốc và không trùng tóc
    expanded_only_forehead = cv2.bitwise_and(expanded_face_mask, 1 - face_mesh_mask)
    expanded_only_forehead = cv2.bitwise_and(expanded_only_forehead, 1 - hair_mask)

    # Kết hợp: tóc + mặt mediapipe (gốc) + vùng trán mở rộng (phía trên mặt gốc, không trùng tóc, không trùng mặt gốc)
    combined_mask = ((face_mesh_mask + expanded_only_forehead) > 0).astype(np.uint8)

    # Làm mượt mask
    combined_mask = cv2.GaussianBlur(combined_mask.astype(np.float32), (3, 3), 0)
    combined_mask = (combined_mask > 0.5).astype(np.uint8)

    np_image = np.array(image)
    alpha = (combined_mask * 255).astype(np.uint8)
    rgba_image = np.dstack([np_image, alpha])
    return Image.fromarray(rgba_image)

def extract_hair_face_full_forehead(image):
    image = image.convert("RGB")
    # SegFormer hair mask
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits.cpu()
    upsampled_logits = F.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
    hair_mask = (pred_seg == 2).astype(np.uint8)  # tóc

    # Face mesh mask (bao trọn mặt, trán, không cổ)
    face_mesh_mask = get_facemesh_mask(image)
    # Expand lên trên 20% chiều cao mặt (ăn gian trán)
    expanded_face_mask = expand_forehead_mask(face_mesh_mask, expand_percent=0.2)

    # Vùng trán mở rộng chỉ lấy phần không trùng vùng mặt gốc và không trùng tóc
    expanded_only_forehead = cv2.bitwise_and(expanded_face_mask, 1 - face_mesh_mask)
    expanded_only_forehead = cv2.bitwise_and(expanded_only_forehead, 1 - hair_mask)

    # Kết hợp: tóc + mặt mediapipe (gốc) + vùng trán mở rộng (phía trên mặt gốc, không trùng tóc, không trùng mặt gốc)
    combined_mask = ((hair_mask + face_mesh_mask + expanded_only_forehead) > 0).astype(np.uint8)

    # Làm mượt mask
    combined_mask = cv2.GaussianBlur(combined_mask.astype(np.float32), (3, 3), 0)
    combined_mask = (combined_mask > 0.5).astype(np.uint8)

    np_image = np.array(image)
    alpha = (combined_mask * 255).astype(np.uint8)
    rgba_image = np.dstack([np_image, alpha])
    return Image.fromarray(rgba_image)