Spaces:
Sleeping
Sleeping
zhiweili committed on
Commit ·
40b1711
1
Parent(s): e8e2aa0
auto dilate the hair mask
Browse files- .gitignore +2 -1
- app.py +73 -6
.gitignore
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
-
.vscode
|
|
|
|
|
|
| 1 |
+
.vscode
|
| 2 |
+
.DS_Store
|
app.py
CHANGED
|
@@ -4,23 +4,30 @@ import numpy as np
|
|
| 4 |
from PIL import Image
|
| 5 |
from mediapipe.tasks import python
|
| 6 |
from mediapipe.tasks.python import vision
|
| 7 |
-
from scipy.ndimage import binary_dilation
|
| 8 |
|
| 9 |
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
|
| 10 |
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
|
| 11 |
|
| 12 |
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
|
|
|
|
| 13 |
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
|
| 14 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
| 15 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
| 16 |
-
|
| 17 |
|
| 18 |
def segment(input_image, category):
|
| 19 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
| 20 |
segmentation_result = segmenter.segment(image)
|
| 21 |
category_mask = segmentation_result.category_mask
|
| 22 |
category_mask_np = category_mask.numpy_view()
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Generate solid color images for showing the output segmentation mask.
|
| 26 |
image_data = image.numpy_view()
|
|
@@ -29,18 +36,78 @@ def segment(input_image, category):
|
|
| 29 |
bg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
| 30 |
bg_image[:] = BG_COLOR
|
| 31 |
|
| 32 |
-
|
| 33 |
-
condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
|
| 34 |
|
| 35 |
output_image = np.where(condition, fg_image, bg_image)
|
| 36 |
output_image = Image.fromarray(output_image)
|
| 37 |
return output_image
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
with gr.Blocks() as app:
|
| 40 |
with gr.Row():
|
| 41 |
with gr.Column():
|
| 42 |
input_image = gr.Image(type='pil', label='Upload image')
|
| 43 |
-
category = gr.Dropdown(label='Category', choices=
|
| 44 |
submit_btn = gr.Button(value='Submit', variant='primary')
|
| 45 |
with gr.Column():
|
| 46 |
output_image = gr.Image(type='pil', label='Image Output')
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
from mediapipe.tasks import python
|
| 6 |
from mediapipe.tasks.python import vision
|
| 7 |
+
from scipy.ndimage import binary_dilation, label
|
| 8 |
|
| 9 |
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
|
| 10 |
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
|
| 11 |
|
| 12 |
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
|
| 13 |
+
category_options = ["hair", "clothes", "background"]
|
| 14 |
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
|
| 15 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
| 16 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
| 17 |
+
labels = segmenter.labels
|
| 18 |
|
| 19 |
def segment(input_image, category):
|
| 20 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
| 21 |
segmentation_result = segmenter.segment(image)
|
| 22 |
category_mask = segmentation_result.category_mask
|
| 23 |
category_mask_np = category_mask.numpy_view()
|
| 24 |
+
|
| 25 |
+
if category == "hair":
|
| 26 |
+
target_mask = get_hair_mask(category_mask_np, should_dilate=True)
|
| 27 |
+
elif category == "clothes":
|
| 28 |
+
target_mask = get_clothes_mask(category_mask_np)
|
| 29 |
+
else:
|
| 30 |
+
target_mask = category_mask_np == 0
|
| 31 |
|
| 32 |
# Generate solid color images for showing the output segmentation mask.
|
| 33 |
image_data = image.numpy_view()
|
|
|
|
| 36 |
bg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
| 37 |
bg_image[:] = BG_COLOR
|
| 38 |
|
| 39 |
+
condition = np.stack((target_mask,) * 3, axis=-1) > 0.2
|
|
|
|
| 40 |
|
| 41 |
output_image = np.where(condition, fg_image, bg_image)
|
| 42 |
output_image = Image.fromarray(output_image)
|
| 43 |
return output_image
|
| 44 |
|
| 45 |
+
def get_clothes_mask(category_mask_np):
    """Return a dilated boolean mask covering body-skin and clothes pixels.

    Args:
        category_mask_np: Array of per-pixel category indices. Index 2 is
            treated as body skin and index 4 as clothes.

    Returns:
        Boolean array of the same shape, True wherever a pixel is within
        four dilation steps of a body-skin or clothes pixel.
    """
    # Select both categories in one expression, then grow the region a little
    # so the mask comfortably covers the garment edges.
    body_or_clothes = (category_mask_np == 2) | (category_mask_np == 4)
    return binary_dilation(body_or_clothes, iterations=4)
|
| 51 |
+
|
| 52 |
+
def get_hair_mask(category_mask_np, should_dilate=False):
    """Build a boolean hair mask from a per-pixel category mask.

    Category indices are interpreted as 1 = hair, 2 = body skin,
    3 = face skin and 4 = clothes.

    Args:
        category_mask_np: 2-D array of per-pixel category indices.
        should_dilate: When False, return only the hair pixels grown by four
            dilation steps. When True, additionally grow the hair region
            until it touches the body/clothes area, then trim it.

    Returns:
        Boolean array with the same shape as ``category_mask_np``.

    If the image lacks what the expansion needs (no hair, no face, no hair
    component starting at or above the top of the face, or no body/clothes
    pixels away from the face), the plain dilated hair mask is returned
    instead of raising or looping forever.
    """
    hair_mask = binary_dilation(category_mask_np == 1, iterations=4)
    if not should_dilate:
        return hair_mask

    body_skin_mask = category_mask_np == 2
    face_skin_mask = category_mask_np == 3
    clothes_mask = category_mask_np == 4

    # np.min on an empty index array raises ValueError, so bail out early
    # when there is nothing to anchor the expansion to.
    if not hair_mask.any() or not face_skin_mask.any():
        return hair_mask

    min_face_y = np.min(np.where(face_skin_mask)[0])

    # Keep only the hair components whose top edge is at or above the top of
    # the face; these seed the expansion.
    labeled_hair, hair_features = label(hair_mask)
    top_hair_mask = np.zeros_like(hair_mask)
    for i in range(1, hair_features + 1):
        component_mask = labeled_hair == i
        if np.min(np.where(component_mask)[0]) <= min_face_y:
            top_hair_mask[component_mask] = True

    # Reference area: body skin and clothes, excluding a wide band around
    # the face.
    expanded_face_mask = binary_dilation(face_skin_mask, iterations=40)
    reference_mask = np.logical_and(
        np.logical_or(body_skin_mask, clothes_mask), ~expanded_face_mask
    )

    # An empty seed never grows and an empty reference can never be reached;
    # either case would make the loop below spin forever.
    if not top_hair_mask.any() or not reference_mask.any():
        return hair_mask

    # Grow the seed (dilation spreads in every direction, not only downward)
    # until it touches the reference area. With a non-empty seed and a
    # non-empty reference the growing region must eventually overlap it.
    expanded_hair_mask = top_hair_mask
    while not np.any(np.logical_and(expanded_hair_mask, reference_mask)):
        expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)

    # Trim 1: drop everything more than 20 rows above the topmost hair pixel.
    # Clamp at 0: a negative stop index would count from the bottom of the
    # image and erase almost the whole mask.
    min_hair_y = max(int(np.min(np.where(hair_mask)[0])) - 20, 0)
    expanded_hair_mask[:min_hair_y, :] = False

    # Trim 2: drop columns outside the horizontal extent of the clothes.
    # Skipped when there are no clothes pixels (the reference area may consist
    # of body skin only).
    clothes_columns = np.where(clothes_mask)[1]
    if clothes_columns.size:
        min_clothes_x = np.min(clothes_columns)
        max_clothes_x = np.max(clothes_columns)
        expanded_hair_mask[:, :min_clothes_x] = False
        expanded_hair_mask[:, max_clothes_x + 1:] = False

    # Exclude the face-skin, body-skin and clothes areas from the grown region.
    occupied = face_skin_mask | body_skin_mask | clothes_mask
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~occupied)

    # The (dilated) hair itself is always part of the result.
    return np.logical_or(hair_mask, expanded_hair_mask)
|
| 105 |
+
|
| 106 |
with gr.Blocks() as app:
|
| 107 |
with gr.Row():
|
| 108 |
with gr.Column():
|
| 109 |
input_image = gr.Image(type='pil', label='Upload image')
|
| 110 |
+
category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
|
| 111 |
submit_btn = gr.Button(value='Submit', variant='primary')
|
| 112 |
with gr.Column():
|
| 113 |
output_image = gr.Image(type='pil', label='Image Output')
|