Spaces:
Sleeping
Sleeping
zhiweili commited on
Commit ·
0ec7070
1
Parent(s): 0f3fb3e
add crop
Browse files
app.py
CHANGED
|
@@ -15,8 +15,10 @@ base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
|
|
| 15 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
| 16 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
| 17 |
labels = segmenter.labels
|
|
|
|
| 18 |
|
| 19 |
def segment(input_image, category):
|
|
|
|
| 20 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
| 21 |
segmentation_result = segmenter.segment(image)
|
| 22 |
category_mask = segmentation_result.category_mask
|
|
@@ -29,6 +31,41 @@ def segment(input_image, category):
|
|
| 29 |
else:
|
| 30 |
target_mask = category_mask_np == 0
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# Generate solid color images for showing the output segmentation mask.
|
| 33 |
image_data = image.numpy_view()
|
| 34 |
fg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
|
@@ -40,7 +77,7 @@ def segment(input_image, category):
|
|
| 40 |
|
| 41 |
output_image = np.where(condition, fg_image, bg_image)
|
| 42 |
output_image = Image.fromarray(output_image)
|
| 43 |
-
return
|
| 44 |
|
| 45 |
def get_clothes_mask(category_mask_np):
|
| 46 |
body_skin_mask = category_mask_np == 2
|
|
@@ -82,10 +119,10 @@ def get_hair_mask(category_mask_np, should_dilate=False):
|
|
| 82 |
expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
|
| 83 |
|
| 84 |
# Trim the expanded_hair_mask
|
| 85 |
-
# 1. Remove the area above hair_mask by
|
| 86 |
hair_indices = np.where(hair_mask)
|
| 87 |
min_hair_y = np.min(hair_indices[0])
|
| 88 |
-
expanded_hair_mask[:min_hair_y -
|
| 89 |
|
| 90 |
# 2. Remove the areas on both sides that exceed the clothing coordinates
|
| 91 |
clothes_indices = np.where(clothes_mask)
|
|
@@ -110,7 +147,8 @@ with gr.Blocks() as app:
|
|
| 110 |
category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
|
| 111 |
submit_btn = gr.Button(value='Submit', variant='primary')
|
| 112 |
with gr.Column():
|
| 113 |
-
|
|
|
|
| 114 |
|
| 115 |
submit_btn.click(
|
| 116 |
fn=segment,
|
|
@@ -118,7 +156,7 @@ with gr.Blocks() as app:
|
|
| 118 |
input_image,
|
| 119 |
category,
|
| 120 |
],
|
| 121 |
-
outputs=[output_image]
|
| 122 |
)
|
| 123 |
|
| 124 |
app.launch(debug=False, show_error=True)
|
|
|
|
| 15 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
| 16 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
| 17 |
labels = segmenter.labels
|
| 18 |
+
expand_size = 40
|
| 19 |
|
| 20 |
def segment(input_image, category):
|
| 21 |
+
original_height, original_width = input_image.size
|
| 22 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
| 23 |
segmentation_result = segmenter.segment(image)
|
| 24 |
category_mask = segmentation_result.category_mask
|
|
|
|
| 31 |
else:
|
| 32 |
target_mask = category_mask_np == 0
|
| 33 |
|
| 34 |
+
target_indices = np.where(target_mask)
|
| 35 |
+
start_y = np.min(target_indices[0]) - expand_size
|
| 36 |
+
if start_y < 0:
|
| 37 |
+
start_y = 0
|
| 38 |
+
end_y = np.max(target_indices[0]) + expand_size
|
| 39 |
+
if end_y > original_height:
|
| 40 |
+
end_y = original_height
|
| 41 |
+
start_x = np.min(target_indices[1]) - expand_size
|
| 42 |
+
if start_x < 0:
|
| 43 |
+
start_x = 0
|
| 44 |
+
end_x = np.max(target_indices[1]) + expand_size
|
| 45 |
+
if end_x > original_width:
|
| 46 |
+
end_x = original_width
|
| 47 |
+
target_height = end_y - start_y
|
| 48 |
+
target_width = end_x - start_x
|
| 49 |
+
|
| 50 |
+
# choose the max side length
|
| 51 |
+
max_side_length = max(target_height, target_width)
|
| 52 |
+
# calculate the crop area
|
| 53 |
+
crop_mask = target_mask[start_y:end_y, start_x:end_x]
|
| 54 |
+
crop_mask_height, crop_mask_width = crop_mask.shape
|
| 55 |
+
crop_mask_start_y = (max_side_length - crop_mask_height) // 2
|
| 56 |
+
crop_mask_end_y = crop_mask_start_y + crop_mask_height
|
| 57 |
+
crop_mask_start_x = (max_side_length - crop_mask_width) // 2
|
| 58 |
+
crop_mask_end_x = crop_mask_start_x + crop_mask_width
|
| 59 |
+
# create a square mask
|
| 60 |
+
crop_mask_square = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
|
| 61 |
+
crop_mask_square[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
|
| 62 |
+
# create a square image
|
| 63 |
+
crop_mask_image = Image.fromarray((crop_mask_square * 255).astype(np.uint8))
|
| 64 |
+
|
| 65 |
+
crop_image = input_image.crop((start_x, start_y, end_x, end_y))
|
| 66 |
+
crop_image_square = Image.new("RGB", (max_side_length, max_side_length))
|
| 67 |
+
crop_image_square.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
|
| 68 |
+
|
| 69 |
# Generate solid color images for showing the output segmentation mask.
|
| 70 |
image_data = image.numpy_view()
|
| 71 |
fg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
|
|
|
| 77 |
|
| 78 |
output_image = np.where(condition, fg_image, bg_image)
|
| 79 |
output_image = Image.fromarray(output_image)
|
| 80 |
+
return crop_mask_image, crop_image_square
|
| 81 |
|
| 82 |
def get_clothes_mask(category_mask_np):
|
| 83 |
body_skin_mask = category_mask_np == 2
|
|
|
|
| 119 |
expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
|
| 120 |
|
| 121 |
# Trim the expanded_hair_mask
|
| 122 |
+
# 1. Remove the area above hair_mask by 10 pixels
|
| 123 |
hair_indices = np.where(hair_mask)
|
| 124 |
min_hair_y = np.min(hair_indices[0])
|
| 125 |
+
expanded_hair_mask[:min_hair_y - 10, :] = 0
|
| 126 |
|
| 127 |
# 2. Remove the areas on both sides that exceed the clothing coordinates
|
| 128 |
clothes_indices = np.where(clothes_mask)
|
|
|
|
| 147 |
category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
|
| 148 |
submit_btn = gr.Button(value='Submit', variant='primary')
|
| 149 |
with gr.Column():
|
| 150 |
+
mask_image = gr.Image(type='pil', label='Segmentation mask')
|
| 151 |
+
output_image = gr.Image(type='pil', label='Segmented image')
|
| 152 |
|
| 153 |
submit_btn.click(
|
| 154 |
fn=segment,
|
|
|
|
| 156 |
input_image,
|
| 157 |
category,
|
| 158 |
],
|
| 159 |
+
outputs=[mask_image, output_image]
|
| 160 |
)
|
| 161 |
|
| 162 |
app.launch(debug=False, show_error=True)
|