AItool commited on
Commit
c6b2642
·
verified ·
1 Parent(s): a9e81fe

Delete segment_utils.py

Browse files
Files changed (1) hide show
  1. segment_utils.py +0 -105
segment_utils.py DELETED
@@ -1,105 +0,0 @@
1
- import numpy as np
2
- import mediapipe as mp
3
- import uuid
4
- import os
5
-
6
- from PIL import Image
7
- from mediapipe.tasks import python
8
- from mediapipe.tasks.python import vision
9
- from scipy.ndimage import binary_dilation
10
- from croper import Croper
11
-
12
- segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
13
- base_options = python.BaseOptions(model_asset_path=segment_model)
14
- options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
15
- segmenter = vision.ImageSegmenter.create_from_options(options)
16
-
17
- def restore_result(croper, category, generated_image):
18
- square_length = croper.square_length
19
- generated_image = generated_image.resize((square_length, square_length))
20
-
21
- cropped_generated_image = generated_image.crop((croper.square_start_x, croper.square_start_y, croper.square_end_x, croper.square_end_y))
22
- cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)
23
-
24
- restored_image = croper.input_image.copy()
25
- restored_image.paste(cropped_generated_image, (croper.origin_start_x, croper.origin_start_y), cropped_square_mask_image)
26
-
27
- extension = 'png'
28
- # if restored_image.mode == 'RGBA':
29
- # extension = 'png'
30
- # else:
31
- # extension = 'jpg'
32
-
33
- tmpPrefix = "/tmp/gradio/"
34
-
35
- targetDir = f"{tmpPrefix}output/"
36
- if not os.path.exists(targetDir):
37
- os.makedirs(targetDir)
38
-
39
- path = f"{targetDir}{uuid.uuid4()}.{extension}"
40
- restored_image.save(path, quality=100)
41
-
42
- return restored_image, path
43
-
44
- def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
45
- mask_size = int(input_size)
46
- mask_expansion = int(mask_expansion)
47
-
48
- image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
49
- segmentation_result = segmenter.segment(image)
50
- category_mask = segmentation_result.category_mask
51
- category_mask_np = category_mask.numpy_view()
52
-
53
- if category == "hair":
54
- target_mask = get_hair_mask(category_mask_np, mask_dilation)
55
- elif category == "clothes":
56
- target_mask = get_clothes_mask(category_mask_np, mask_dilation)
57
- elif category == "face":
58
- target_mask = get_face_mask(category_mask_np, mask_dilation)
59
- else:
60
- target_mask = get_face_mask(category_mask_np, mask_dilation)
61
-
62
- croper = Croper(input_image, target_mask, mask_size, mask_expansion)
63
- croper.corp_mask_image()
64
- origin_area_image = croper.resized_square_image
65
-
66
- return origin_area_image, croper
67
-
68
- def get_face_mask(category_mask_np, dilation=1):
69
- face_skin_mask = category_mask_np == 3
70
- if dilation > 0:
71
- face_skin_mask = binary_dilation(face_skin_mask, iterations=dilation)
72
-
73
- return face_skin_mask
74
-
75
- def get_clothes_mask(category_mask_np, dilation=1):
76
- body_skin_mask = category_mask_np == 2
77
- clothes_mask = category_mask_np == 4
78
- combined_mask = np.logical_or(body_skin_mask, clothes_mask)
79
- combined_mask = binary_dilation(combined_mask, iterations=4)
80
- if dilation > 0:
81
- combined_mask = binary_dilation(combined_mask, iterations=dilation)
82
- return combined_mask
83
-
84
- def get_hair_mask(category_mask_np, dilation=1):
85
- hair_mask = category_mask_np == 1
86
- if dilation > 0:
87
- hair_mask = binary_dilation(hair_mask, iterations=dilation)
88
- return hair_mask
89
-
90
- def get_restore_mask_image(croper, category, generated_image):
91
- image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
92
- segmentation_result = segmenter.segment(image)
93
- category_mask = segmentation_result.category_mask
94
- category_mask_np = category_mask.numpy_view()
95
-
96
- if category == "hair":
97
- target_mask = get_hair_mask(category_mask_np, 0)
98
- elif category == "clothes":
99
- target_mask = get_clothes_mask(category_mask_np, 0)
100
- elif category == "face":
101
- target_mask = get_face_mask(category_mask_np, 0)
102
-
103
- combined_mask = np.logical_or(target_mask, croper.corp_mask)
104
- mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
105
- return mask_image