# be_rejection/segmentation.py
# Face / hair segmentation utilities (SegFormer + MediaPipe FaceMesh).
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import mediapipe as mp
import cv2
import os
import warnings
# Silence TensorFlow/MediaPipe native logging (3 = errors only) and Python warnings.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore')
# Suppress MediaPipe's Python-level log output as well.
import logging
logging.getLogger('mediapipe').setLevel(logging.ERROR)
# Load the SegFormer checkpoint once at import time (downloads from the
# Hugging Face Hub on first use). Shared by every function in this module.
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
model = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")
def get_facemesh_mask(image):
    """Return a binary (H, W) uint8 mask of the face region.

    Runs MediaPipe FaceMesh on the image, gathers every landmark of each
    detected face (max one), and fills the convex hull of those points
    with 1.  The hull covers the face and forehead but not the neck.
    Returns an all-zero mask when no face is detected.
    """
    frame = np.array(image)
    h, w, _ = frame.shape
    mask = np.zeros((h, w), dtype=np.uint8)
    mesh_module = mp.solutions.face_mesh
    with mesh_module.FaceMesh(
        static_image_mode=True,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.5,
    ) as detector:
        detection = detector.process(frame)
        if detection.multi_face_landmarks:
            for landmarks in detection.multi_face_landmarks:
                # Landmarks are normalized [0, 1]; scale to pixel coords.
                pts = np.array(
                    [[int(p.x * w), int(p.y * h)] for p in landmarks.landmark],
                    np.int32,
                )
                if len(pts) > 0:
                    hull = cv2.convexHull(pts)
                    cv2.fillConvexPoly(mask, hull, 1)
    return mask
def expand_forehead_mask(face_mask, expand_percent=0.2):
    """Return a copy of *face_mask* shifted upward to fake a taller forehead.

    The whole masked region is translated up by ``expand_percent`` of the
    face's pixel height (clamped at the top edge of the image).  Callers
    intersect the result with the complement of the original mask to get
    just the new forehead band.

    Args:
        face_mask: binary (H, W) uint8 mask, nonzero where the face is.
        expand_percent: fraction of the face height to shift upward.

    Returns:
        A new mask of the same shape; the input mask unchanged (and
        returned as-is) when it contains no nonzero pixel.
    """
    ys, xs = np.where(face_mask > 0)
    if len(ys) == 0:
        return face_mask  # no face found
    min_y, max_y = ys.min(), ys.max()
    face_height = max_y - min_y
    shift = int(face_height * expand_percent)
    new_top = max(min_y - shift, 0)
    expanded = np.zeros_like(face_mask)
    # max_y is an *inclusive* row index, so the exclusive slice end is
    # max_y + 1 (the original code dropped the bottom row of the face).
    src_start, src_end = min_y, max_y + 1
    dst_start = new_top
    dst_end = dst_start + (src_end - src_start)
    # Defensive clamp: cannot trigger since new_top <= min_y, but kept so
    # the copy stays safe if the shift logic ever changes.
    if dst_end > face_mask.shape[0]:
        overflow = dst_end - face_mask.shape[0]
        dst_end = face_mask.shape[0]
        src_end -= overflow
    expanded[dst_start:dst_end, :] = face_mask[src_start:src_end, :]
    return expanded
def extract_hair(image: Image.Image) -> Image.Image:
    """Return an RGBA image whose alpha is 255 on hair pixels, 0 elsewhere.

    Hair is whatever the SegFormer model labels as class 2; logits are
    bilinearly upsampled to the input resolution before the argmax.
    """
    rgb_img = image.convert("RGB")
    pixels = np.array(rgb_img)
    height, width = pixels.shape[:2]
    # Forward pass (inference only — no gradients needed).
    batch = processor(images=rgb_img, return_tensors="pt")
    with torch.no_grad():
        raw_logits = model(**batch).logits.cpu()
    resized = F.interpolate(
        raw_logits, size=(height, width), mode="bilinear", align_corners=False
    )
    labels = resized.argmax(dim=1)[0].numpy()
    # Class 2 = hair; turn the boolean mask straight into an alpha plane.
    alpha_plane = ((labels == 2) * 255).astype(np.uint8)
    return Image.fromarray(np.dstack([pixels, alpha_plane]))
def get_face(image):
    """Return an RGBA image where only the face is opaque.

    The mask is the MediaPipe FaceMesh convex hull plus a forehead band
    made by shifting that hull up 20% of the face height; the band is
    restricted to pixels that are neither in the original hull nor
    labelled as hair by SegFormer.  Hair itself is NOT kept opaque.
    """
    image = image.convert("RGB")
    # SegFormer pass: per-pixel class logits, upsampled to image size.
    batch = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits.cpu()
    full_res = F.interpolate(
        logits,
        size=image.size[::-1],  # PIL size is (W, H); interpolate wants (H, W)
        mode="bilinear",
        align_corners=False,
    )
    labels = full_res.argmax(dim=1)[0].numpy()
    hair = (labels == 2).astype(np.uint8)  # class 2 = hair

    # FaceMesh hull (face + forehead, no neck) and its upward-shifted copy.
    face = get_facemesh_mask(image)
    grown = expand_forehead_mask(face, expand_percent=0.2)
    # Keep only the genuinely new forehead band: outside the original hull
    # and not on hair.
    band = cv2.bitwise_and(grown, 1 - face)
    band = cv2.bitwise_and(band, 1 - hair)

    combined = ((face + band) > 0).astype(np.uint8)
    # Light edge smoothing, then re-binarise at 0.5.
    blurred = cv2.GaussianBlur(combined.astype(np.float32), (3, 3), 0)
    combined = (blurred > 0.5).astype(np.uint8)

    pixels = np.array(image)
    alpha_plane = (combined * 255).astype(np.uint8)
    return Image.fromarray(np.dstack([pixels, alpha_plane]))
def extract_hair_face_full_forehead(image):
    """Return an RGBA image where hair AND face (with extended forehead) are opaque.

    Same pipeline as the face-only variant, but the SegFormer hair mask is
    unioned into the final alpha: hair + FaceMesh hull + a forehead band
    obtained by shifting the hull up 20% of the face height (band clipped
    to exclude both the original hull and hair pixels).
    """
    image = image.convert("RGB")
    # SegFormer hair segmentation at full input resolution.
    batch = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        seg_out = model(**batch)
    logits = seg_out.logits.cpu()
    full_res = F.interpolate(
        logits,
        size=image.size[::-1],  # (H, W) for interpolate; PIL gives (W, H)
        mode="bilinear",
        align_corners=False,
    )
    class_map = full_res.argmax(dim=1)[0].numpy()
    hair = (class_map == 2).astype(np.uint8)  # class 2 = hair

    # FaceMesh hull (face + forehead, no neck), then shift it up 20% of the
    # face height to fake extra forehead coverage.
    face = get_facemesh_mask(image)
    grown = expand_forehead_mask(face, expand_percent=0.2)
    # New forehead band only: not in the original hull, not on hair.
    band = cv2.bitwise_and(grown, 1 - face)
    band = cv2.bitwise_and(band, 1 - hair)

    # Union of hair, original hull, and the extra forehead band.
    combined = ((hair + face + band) > 0).astype(np.uint8)
    # Smooth the mask edge and re-binarise.
    blurred = cv2.GaussianBlur(combined.astype(np.float32), (3, 3), 0)
    combined = (blurred > 0.5).astype(np.uint8)

    pixels = np.array(image)
    alpha_plane = (combined * 255).astype(np.uint8)
    return Image.fromarray(np.dstack([pixels, alpha_plane]))