spandey8 committed
Commit f597984 · verified · 1 Parent(s): 96f2bcd

Upload 3 files

ISPFD_preprocessing/sam_clip_ispfdv1blackback.py ADDED
@@ -0,0 +1,245 @@
+import argparse
+import json
+import os
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from tqdm import tqdm
+from transformers import AutoProcessor, CLIPModel
+
+parser = argparse.ArgumentParser(
+    description="Segment fingers with SAM and pick the best mask with CLIP (ISPFD v1, black background)."
+)
+parser.add_argument("--parentdir", type=str, help="Directory containing the input images.")
+parser.add_argument("--dstndir", type=str, help="Directory where processed images are written.")
+parser.add_argument("--device", type=str, default="cuda")
+parser.add_argument("--convert-to-rle", action="store_true")
+amg_settings = parser.add_argument_group("AMG Settings")
+amg_settings.add_argument(
+    "--points-per-side",
+    type=int,
+    default=None,
+    help="Generate masks by sampling a grid over the image with this many points to a side.",
+)
+amg_settings.add_argument(
+    "--points-per-batch",
+    type=int,
+    default=None,
+    help="How many input points to process simultaneously in one batch.",
+)
+amg_settings.add_argument(
+    "--pred-iou-thresh",
+    type=float,
+    default=None,
+    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
+)
+amg_settings.add_argument(
+    "--stability-score-thresh",
+    type=float,
+    default=None,
+    help="Exclude masks with a stability score lower than this threshold.",
+)
+amg_settings.add_argument(
+    "--stability-score-offset",
+    type=float,
+    default=None,
+    help="Larger values perturb the mask more when measuring stability score.",
+)
+amg_settings.add_argument(
+    "--box-nms-thresh",
+    type=float,
+    default=None,
+    help="The overlap threshold for excluding a duplicate mask.",
+)
+amg_settings.add_argument(
+    "--crop-n-layers",
+    type=int,
+    default=None,
+    help=(
+        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
+        "The value sets how many different scales to crop at."
+    ),
+)
+amg_settings.add_argument(
+    "--crop-nms-thresh",
+    type=float,
+    default=None,
+    help="The overlap threshold for excluding duplicate masks across different crops.",
+)
+amg_settings.add_argument(
+    "--crop-overlap-ratio",
+    type=int,
+    default=None,
+    help="Larger numbers mean image crops will overlap more.",
+)
+amg_settings.add_argument(
+    "--crop-n-points-downscale-factor",
+    type=int,
+    default=None,
+    help="The number of points-per-side in each layer of crop is reduced by this factor.",
+)
+amg_settings.add_argument(
+    "--min-mask-region-area",
+    type=int,
+    default=None,
+    help=(
+        "Disconnected mask regions or holes with area smaller than this value "
+        "in pixels are removed by postprocessing."
+    ),
+)
+
+def get_amg_kwargs(args):
+    # Keep only the AMG options the user actually set on the command line.
+    amg_kwargs = {
+        "points_per_side": args.points_per_side,
+        "points_per_batch": args.points_per_batch,
+        "pred_iou_thresh": args.pred_iou_thresh,
+        "stability_score_thresh": args.stability_score_thresh,
+        "stability_score_offset": args.stability_score_offset,
+        "box_nms_thresh": args.box_nms_thresh,
+        "crop_n_layers": args.crop_n_layers,
+        "crop_nms_thresh": args.crop_nms_thresh,
+        "crop_overlap_ratio": args.crop_overlap_ratio,
+        "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
+        "min_mask_region_area": args.min_mask_region_area,
+    }
+    return {k: v for k, v in amg_kwargs.items() if v is not None}
+
+def write_masks_to_folder(masks):
+    # Despite the legacy name, this collects the binary masks and their
+    # bounding boxes in memory instead of writing them to disk.
+    masks_lst = []
+    box_lst = []
+    for mask_data in masks:
+        masks_lst.append(mask_data["segmentation"] * 255)
+        box_lst.append(mask_data["bbox"])
+    return masks_lst, box_lst
+
+def pad_and_crop_mask(mask, image, padding):
+    # Crop the image to the padded bounding box of the mask, then keep the
+    # fingertip-facing portion of the crop. (Currently unused in this script.)
+    non_zero_indices = np.where(mask == 255)
+    y_min, y_max = np.min(non_zero_indices[0]), np.max(non_zero_indices[0]) + 1
+    x_min, x_max = np.min(non_zero_indices[1]), np.max(non_zero_indices[1]) + 1
+    y_min = max(y_min - padding, 0)
+    y_max = min(y_max + padding, image.shape[0])
+    x_min = max(x_min - padding, 0)
+    x_max = min(x_max + padding, image.shape[1])
+    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    cropped_image = image_rgb[y_min:y_max, x_min:x_max]
+    h, w, _ = cropped_image.shape
+    if h > w:
+        cropped_image = cropped_image[:int((3 * h) / 4), :]
+    else:
+        cropped_image = cropped_image[:, :int(w / 2)]
+    return cropped_image
+
+def get_object_from_mask(image, mask):
+    # Black out everything outside the mask, leaving only the segmented object.
+    if not isinstance(image, np.ndarray) or not isinstance(mask, np.ndarray):
+        raise TypeError("Image and mask must be NumPy arrays.")
+    if image.shape[:2] != mask.shape:
+        raise ValueError("Image and mask must have the same spatial dimensions.")
+    object_image = np.zeros_like(image)
+    object_image[mask == 255] = image[mask == 255]
+    object_image = cv2.cvtColor(object_image, cv2.COLOR_BGR2RGB)
+    return object_image
+
+def orient_and_adjust(image, bbox):
+    # Normalize the finger's orientation, using the bounding-box midpoint to
+    # guess which edge of the frame it enters from.
+    if image.shape[1] > image.shape[0]:  # image is horizontal
+        box_mid = bbox[0] + (bbox[2] // 2)
+        if image.shape[1] // 2 < box_mid:  # coming from left
+            image = cv2.flip(image, 0)
+        else:  # coming from right
+            image = cv2.rotate(image, cv2.ROTATE_180)
+            image = cv2.flip(image, 0)
+        return "H", image
+    else:  # image is vertical
+        box_mid = bbox[1] + (bbox[3] // 2)
+        if image.shape[0] // 2 > box_mid:  # coming from down
+            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+            image = cv2.flip(image, 1)
+        else:  # coming from up
+            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+            image = cv2.flip(image, 1)
+        return "V", image
+
+def tight_crop_with_padding(image, padding=5):
+    # Crop to the largest foreground contour plus a small margin.
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    largest_contour = max(contours, key=cv2.contourArea)
+    x, y, w, h = cv2.boundingRect(largest_contour)
+    x, y, w, h = x - padding, y - padding, w + padding * 2, h + padding * 2
+    return image[y:y + h, x:x + w]
+
+def split_image_vertically(image):
+    # Keep the left 55% of the frame, where the fingertip sits after orientation.
+    height, width, channels = image.shape
+    half_width = int(0.55 * width)
+    return image[:, :half_width, :]
+
+def main(args: argparse.Namespace):
+    print("Loading model...")
+    sam = sam_model_registry['vit_h'](checkpoint="< Path to sam_vit_h_4b8939.pth cloned from SAM v1 repo >").to(device=args.device)
+    output_mode = "binary_mask"
+    amg_kwargs = get_amg_kwargs(args)
+    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
+
+    # The processor is matched to the CLIP checkpoint used for scoring.
+    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device=args.device)
+    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
+    text_prompt = "Human Finger"
+
+    parent_folder = args.parentdir
+    dstn_folder = args.dstndir
+
+    targets = [os.path.join(parent_folder, file) for file in os.listdir(parent_folder)]
+
+    exce = []
+    for t in tqdm(targets):
+        image = cv2.imread(t)
+        if image is None:
+            print(f"Could not load '{t}' as an image, skipping...")
+            continue
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        masks = generator.generate(image)
+        dstn_file = t.split("/")[-1]
+        img_lst = []
+        sim_lst = []
+        if output_mode == "binary_mask":
+            lst, box_lst = write_masks_to_folder(masks)
+            # Score every candidate mask against the text prompt; keep the best match.
+            for i in lst:
+                i = get_object_from_mask(image, i)
+                img = Image.fromarray(i)
+                inputs = processor(text=[text_prompt], images=img, return_tensors="pt", padding=True).to(device=args.device)
+                with torch.no_grad():
+                    outputs = model(**inputs)
+                sim_lst.append(outputs.logits_per_image.item())
+                img_lst.append(i)
+            best_image = img_lst[sim_lst.index(max(sim_lst))]
+            bbox = box_lst[sim_lst.index(max(sim_lst))]
+            # postprocessing
+            orienta, best_image = orient_and_adjust(best_image, bbox)
+            best_image = tight_crop_with_padding(best_image, 5)
+            best_image = split_image_vertically(best_image)
+            try:
+                cv2.imwrite(os.path.join(dstn_folder, dstn_file), best_image)
+            except Exception:
+                exce.append(dstn_file)
+    print(f"number of files skipped: {len(exce)}")
+    with open(dstn_folder.split("/")[-2] + "_" + dstn_folder.split("/")[-1] + "_exceptions.json", 'w') as js:
+        json.dump(exce, js, indent=4)
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
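
A minimal invocation sketch for this script, assuming hypothetical input/output folders and that the SAM checkpoint placeholder inside main() has been replaced with a real sam_vit_h_4b8939.pth path:

    import subprocess

    # hypothetical paths, for illustration only
    subprocess.run([
        "python", "ISPFD_preprocessing/sam_clip_ispfdv1blackback.py",
        "--parentdir", "data/v1_blackback/raw",
        "--dstndir", "data/v1_blackback/processed",
        "--device", "cuda",
        "--points-per-side", "32",  # optional AMG override; omit to use SAM defaults
    ], check=True)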
ISPFD_preprocessing/sam_clip_ispfdv1colorback.py ADDED
@@ -0,0 +1,332 @@
+import argparse
+import json
+import os
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from tqdm import tqdm
+from transformers import AutoProcessor, CLIPModel
+
+parser = argparse.ArgumentParser(
+    description="Segment fingers with SAM and pick the best mask with CLIP (ISPFD v1, color background)."
+)
+parser.add_argument("--parentdir", type=str, help="Directory containing the input images.")
+parser.add_argument("--dstndir", type=str, help="Directory where processed images are written.")
+parser.add_argument("--device", type=str, default="cuda")
+parser.add_argument("--convert-to-rle", action="store_true")
+amg_settings = parser.add_argument_group("AMG Settings")
+amg_settings.add_argument(
+    "--points-per-side",
+    type=int,
+    default=None,
+    help="Generate masks by sampling a grid over the image with this many points to a side.",
+)
+amg_settings.add_argument(
+    "--points-per-batch",
+    type=int,
+    default=None,
+    help="How many input points to process simultaneously in one batch.",
+)
+amg_settings.add_argument(
+    "--pred-iou-thresh",
+    type=float,
+    default=None,
+    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
+)
+amg_settings.add_argument(
+    "--stability-score-thresh",
+    type=float,
+    default=None,
+    help="Exclude masks with a stability score lower than this threshold.",
+)
+amg_settings.add_argument(
+    "--stability-score-offset",
+    type=float,
+    default=None,
+    help="Larger values perturb the mask more when measuring stability score.",
+)
+amg_settings.add_argument(
+    "--box-nms-thresh",
+    type=float,
+    default=None,
+    help="The overlap threshold for excluding a duplicate mask.",
+)
+amg_settings.add_argument(
+    "--crop-n-layers",
+    type=int,
+    default=None,
+    help=(
+        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
+        "The value sets how many different scales to crop at."
+    ),
+)
+amg_settings.add_argument(
+    "--crop-nms-thresh",
+    type=float,
+    default=None,
+    help="The overlap threshold for excluding duplicate masks across different crops.",
+)
+amg_settings.add_argument(
+    "--crop-overlap-ratio",
+    type=int,
+    default=None,
+    help="Larger numbers mean image crops will overlap more.",
+)
+amg_settings.add_argument(
+    "--crop-n-points-downscale-factor",
+    type=int,
+    default=None,
+    help="The number of points-per-side in each layer of crop is reduced by this factor.",
+)
+amg_settings.add_argument(
+    "--min-mask-region-area",
+    type=int,
+    default=None,
+    help=(
+        "Disconnected mask regions or holes with area smaller than this value "
+        "in pixels are removed by postprocessing."
+    ),
+)
+
+def get_amg_kwargs(args):
+    # Keep only the AMG options the user actually set on the command line.
+    amg_kwargs = {
+        "points_per_side": args.points_per_side,
+        "points_per_batch": args.points_per_batch,
+        "pred_iou_thresh": args.pred_iou_thresh,
+        "stability_score_thresh": args.stability_score_thresh,
+        "stability_score_offset": args.stability_score_offset,
+        "box_nms_thresh": args.box_nms_thresh,
+        "crop_n_layers": args.crop_n_layers,
+        "crop_nms_thresh": args.crop_nms_thresh,
+        "crop_overlap_ratio": args.crop_overlap_ratio,
+        "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
+        "min_mask_region_area": args.min_mask_region_area,
+    }
+    return {k: v for k, v in amg_kwargs.items() if v is not None}
+
+def write_masks_to_folder(masks):
+    # Despite the legacy name, this collects the binary masks and their
+    # bounding boxes in memory instead of writing them to disk.
+    masks_lst = []
+    box_lst = []
+    for mask_data in masks:
+        masks_lst.append(mask_data["segmentation"] * 255)
+        box_lst.append(mask_data["bbox"])
+    return masks_lst, box_lst
+
+def calculate_total_zeros_in_stride_right(array, start_column, stride_length):
+    # Count background (zero) pixels in a vertical band right of start_column.
+    end_column = start_column + stride_length
+    return np.sum(array[:, start_column:end_column] == 0)
+
+def calculate_total_zeros_in_left_stride(array, start_column, stride_length):
+    # Count background pixels in a vertical band left of start_column.
+    end_column = max(0, start_column - stride_length)
+    return np.sum(array[:, end_column:start_column] == 0)
+
+def calculate_total_zeros_in_downward_stride(matrix, start_row, stride_length):
+    # Count background pixels in a horizontal band below start_row.
+    end_row = min(start_row + stride_length, matrix.shape[0])
+    return np.sum(matrix[start_row:end_row, :] == 0)
+
+def calculate_total_zeros_in_upward_stride(matrix, start_row, stride_length):
+    # Count background pixels in a horizontal band above start_row.
+    end_row = max(0, start_row - stride_length)
+    return np.sum(matrix[end_row:start_row, :] == 0)
+
+def pad_and_crop_mask(mask, image, padding):
+    # Crop the image to the padded bounding box of the mask, then keep the
+    # fingertip-facing portion of the crop. (Currently unused in this script.)
+    non_zero_indices = np.where(mask == 255)
+    y_min, y_max = np.min(non_zero_indices[0]), np.max(non_zero_indices[0]) + 1
+    x_min, x_max = np.min(non_zero_indices[1]), np.max(non_zero_indices[1]) + 1
+    y_min = max(y_min - padding, 0)
+    y_max = min(y_max + padding, image.shape[0])
+    x_min = max(x_min - padding, 0)
+    x_max = min(x_max + padding, image.shape[1])
+    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    cropped_image = image_rgb[y_min:y_max, x_min:x_max]
+    h, w, _ = cropped_image.shape
+    if h > w:
+        cropped_image = cropped_image[:int((3 * h) / 4), :]
+    else:
+        cropped_image = cropped_image[:, :int(w / 2)]
+    return cropped_image
+
+def get_object_from_mask(image, mask):
+    # Black out everything outside the mask, leaving only the segmented object.
+    if not isinstance(image, np.ndarray) or not isinstance(mask, np.ndarray):
+        raise TypeError("Image and mask must be NumPy arrays.")
+    if image.shape[:2] != mask.shape:
+        raise ValueError("Image and mask must have the same spatial dimensions.")
+    object_image = np.zeros_like(image)
+    object_image[mask == 255] = image[mask == 255]
+    object_image = cv2.cvtColor(object_image, cv2.COLOR_BGR2RGB)
+    return object_image
+
+def orient_and_adjust(image, bbox):
+    # Decide which edge the finger enters from by comparing background pixel
+    # counts in 15-pixel strides at opposite ends of the bounding box, then
+    # normalize the crop's orientation accordingly.
+    if image.shape[1] > image.shape[0]:  # image is horizontal
+        img = image[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]]
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        for i in range(img.shape[1]):
+            if np.count_nonzero(img[:, i]) >= 20:
+                left_index = i
+                break
+        for i in range(img.shape[1] - 1, -1, -1):
+            if np.count_nonzero(img[:, i]) >= 20:
+                right_index = i
+                break
+        total_zeros_towards_right = calculate_total_zeros_in_stride_right(img, left_index, 15)
+        total_zeros_towards_left = calculate_total_zeros_in_left_stride(img, right_index, 15)
+        if total_zeros_towards_right > total_zeros_towards_left:  # coming from left
+            image = cv2.flip(image, 0)
+            orien = 'No'
+        else:  # coming from right
+            image = cv2.rotate(image, cv2.ROTATE_180)
+            image = cv2.flip(image, 0)
+            orien = '180'
+        return "H", image, orien
+    else:  # image is vertical
+        img = image[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]]
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        for i in range(img.shape[0]):
+            if np.count_nonzero(img[i, :]) >= 20:
+                top_index = i
+                break
+        for i in range(img.shape[0] - 1, -1, -1):
+            if np.count_nonzero(img[i, :]) >= 20:
+                bottom_index = i
+                break
+        total_zeros_towards_down = calculate_total_zeros_in_downward_stride(img, top_index, 15)
+        total_zeros_towards_up = calculate_total_zeros_in_upward_stride(img, bottom_index, 15)
+        if total_zeros_towards_down > total_zeros_towards_up:  # coming from down
+            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+            image = cv2.flip(image, 0)
+            orien = 'Rotate90anti'
+        else:  # coming from up
+            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+            image = cv2.flip(image, 0)
+            orien = 'Rotate90'
+        return "V", image, orien
+
+def tight_crop_with_padding(image, original_image, padding=5):
+    # Crop both the masked image and the original to the largest foreground
+    # contour plus a small margin.
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    largest_contour = max(contours, key=cv2.contourArea)
+    x, y, w, h = cv2.boundingRect(largest_contour)
+    x, y, w, h = x - padding, y - padding, w + padding * 2, h + padding * 2
+    return image[y:y + h, x:x + w], original_image[y:y + h, x:x + w]
+
+def split_image_vertically(image):
+    # Keep the left 55% of the frame, where the fingertip sits after orientation.
+    height, width, channels = image.shape
+    half_width = int(0.55 * width)
+    return image[:, :half_width, :]
+
+def tight_bounding_box(image):
+    # Fit an ellipse to the largest foreground contour and return the ratio of
+    # its two axes, used as a rough fingertip-shape check.
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    cnt = max(contours, key=cv2.contourArea)
+    ellipse = cv2.fitEllipse(cnt)
+    return ellipse[1][0] / ellipse[1][1]
+
+def resize_image_with_aspect_ratio(image, max_width=None, max_height=None):
+    # Downscale to fit within max_width/max_height, preserving aspect ratio.
+    height, width, _ = image.shape
+    aspect_ratio = width / height
+    if max_width and width > max_width:
+        new_width = max_width
+        new_height = int(new_width / aspect_ratio)
+    elif max_height and height > max_height:
+        new_height = max_height
+        new_width = int(new_height * aspect_ratio)
+    else:
+        return image
+    return cv2.resize(image, (new_width, new_height))
+
+def main(args: argparse.Namespace):
+    print("Loading model...")
+    sam = sam_model_registry['vit_h'](checkpoint="< Path to sam_vit_h_4b8939.pth cloned from SAM v1 repo >").to(device=args.device)
+    output_mode = "binary_mask"
+    amg_kwargs = get_amg_kwargs(args)
+    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
+
+    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device=args.device)
+    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+    text_prompt = "Human Finger"
+
+    parent_folder = args.parentdir
+    dstn_folder = args.dstndir
+
+    targets = [os.path.join(parent_folder, file) for file in os.listdir(parent_folder)]
+
+    exce = []
+    skipper = 0
+    for t in tqdm(targets):
+        image = cv2.imread(t)
+        if image is None:
+            print(f"Could not load '{t}' as an image, skipping...")
+            continue
+        # image = resize_image_with_aspect_ratio(image, 3264, 2448)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        masks = generator.generate(image)
+        dstn_file = t.split("/")[-1]
+        img_lst = []
+        sim_lst = []
+        if output_mode == "binary_mask":
+            lst, box_lst = write_masks_to_folder(masks)
+            # Score every candidate mask against the text prompt; keep the best match.
+            for i in lst:
+                i = get_object_from_mask(image, i)
+                img = Image.fromarray(i)
+                inputs = processor(text=[text_prompt], images=img, return_tensors="pt", padding=True).to(device=args.device)
+                with torch.no_grad():
+                    outputs = model(**inputs)
+                sim_lst.append(outputs.logits_per_image.item())
+                img_lst.append(i)
+            best_image = img_lst[sim_lst.index(max(sim_lst))]
+            bbox = box_lst[sim_lst.index(max(sim_lst))]
+            # postprocessing
+            try:
+                orienta, best_image, orien = orient_and_adjust(best_image, bbox)
+                # Mirror the crop's orientation change on the full frame; every
+                # branch of orient_and_adjust also applies a vertical flip.
+                if orienta == 'V' and orien == 'Rotate90':
+                    image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+                elif orienta == 'V' and orien == 'Rotate90anti':
+                    image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                elif orienta == 'H' and orien == 'No':
+                    pass
+                elif orienta == 'H' and orien == '180':
+                    image = cv2.rotate(image, cv2.ROTATE_180)
+                image = cv2.flip(image, 0)
+                best_image, image = tight_crop_with_padding(best_image, image, 5)
+                ratio = tight_bounding_box(best_image)
+                # Split off the empty half only when the ellipse aspect ratio
+                # falls outside the expected fingertip band.
+                if not (0.46 <= ratio <= 0.55):
+                    image = split_image_vertically(image)
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                cv2.imwrite(os.path.join(dstn_folder, dstn_file), image)
+            except Exception:
+                skipper += 1
+                print(dstn_file)
+                exce.append(dstn_file)
+    print(f"number of files skipped: {len(exce)}")
+    with open(dstn_folder.split("/")[-2] + "_" + dstn_folder.split("/")[-1] + "_exceptions_v2.json", 'w') as js:
+        json.dump(exce, js, indent=4)
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
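
The colorback variant decides the finger's entry edge by comparing background (zero) pixel counts in 15-pixel strides at opposite ends of the best mask's bounding box. A self-contained toy sketch of that heuristic, with a simplified stand-in for calculate_total_zeros_in_stride_right and a synthetic mask:

    import numpy as np

    def zeros_in_right_stride(array, start_column, stride_length):
        # background pixels in a vertical band starting at start_column
        return np.sum(array[:, start_column:start_column + stride_length] == 0)

    # 6x8 "mask": a finger-like blob hugging the left edge, background elsewhere
    toy = np.zeros((6, 8), dtype=np.uint8)
    toy[:, :3] = 255

    print(zeros_in_right_stride(toy, 0, 3))  # 0  -> the object fills the left band
    print(zeros_in_right_stride(toy, 5, 3))  # 18 -> pure background on the right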
ISPFD_preprocessing/sam_clip_ispfdv2colorback.py ADDED
@@ -0,0 +1,339 @@
+import argparse
+import json
+import os
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from tqdm import tqdm
+from transformers import AutoProcessor, CLIPModel
+
+parser = argparse.ArgumentParser(
+    description="Segment fingers with SAM and pick the best mask with CLIP (ISPFD v2, color background)."
+)
+parser.add_argument("--parentdir", type=str, help="Directory containing the input images.")
+parser.add_argument("--dstndir", type=str, help="Directory where processed images are written.")
+parser.add_argument("--device", type=str, default="cuda")
+parser.add_argument("--convert-to-rle", action="store_true")
+amg_settings = parser.add_argument_group("AMG Settings")
+amg_settings.add_argument(
+    "--points-per-side",
+    type=int,
+    default=None,
+    help="Generate masks by sampling a grid over the image with this many points to a side.",
+)
+amg_settings.add_argument(
+    "--points-per-batch",
+    type=int,
+    default=None,
+    help="How many input points to process simultaneously in one batch.",
+)
+amg_settings.add_argument(
+    "--pred-iou-thresh",
+    type=float,
+    default=None,
+    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
+)
+amg_settings.add_argument(
+    "--stability-score-thresh",
+    type=float,
+    default=None,
+    help="Exclude masks with a stability score lower than this threshold.",
+)
+amg_settings.add_argument(
+    "--stability-score-offset",
+    type=float,
+    default=None,
+    help="Larger values perturb the mask more when measuring stability score.",
+)
+amg_settings.add_argument(
+    "--box-nms-thresh",
+    type=float,
+    default=None,
+    help="The overlap threshold for excluding a duplicate mask.",
+)
+amg_settings.add_argument(
+    "--crop-n-layers",
+    type=int,
+    default=None,
+    help=(
+        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
+        "The value sets how many different scales to crop at."
+    ),
+)
+amg_settings.add_argument(
+    "--crop-nms-thresh",
+    type=float,
+    default=None,
+    help="The overlap threshold for excluding duplicate masks across different crops.",
+)
+amg_settings.add_argument(
+    "--crop-overlap-ratio",
+    type=int,
+    default=None,
+    help="Larger numbers mean image crops will overlap more.",
+)
+amg_settings.add_argument(
+    "--crop-n-points-downscale-factor",
+    type=int,
+    default=None,
+    help="The number of points-per-side in each layer of crop is reduced by this factor.",
+)
+amg_settings.add_argument(
+    "--min-mask-region-area",
+    type=int,
+    default=None,
+    help=(
+        "Disconnected mask regions or holes with area smaller than this value "
+        "in pixels are removed by postprocessing."
+    ),
+)
+
+def get_amg_kwargs(args):
+    # Keep only the AMG options the user actually set on the command line.
+    amg_kwargs = {
+        "points_per_side": args.points_per_side,
+        "points_per_batch": args.points_per_batch,
+        "pred_iou_thresh": args.pred_iou_thresh,
+        "stability_score_thresh": args.stability_score_thresh,
+        "stability_score_offset": args.stability_score_offset,
+        "box_nms_thresh": args.box_nms_thresh,
+        "crop_n_layers": args.crop_n_layers,
+        "crop_nms_thresh": args.crop_nms_thresh,
+        "crop_overlap_ratio": args.crop_overlap_ratio,
+        "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
+        "min_mask_region_area": args.min_mask_region_area,
+    }
+    return {k: v for k, v in amg_kwargs.items() if v is not None}
+
+def write_masks_to_folder(masks):
+    # Despite the legacy name, this collects the binary masks and their
+    # bounding boxes in memory instead of writing them to disk.
+    masks_lst = []
+    box_lst = []
+    for mask_data in masks:
+        masks_lst.append(mask_data["segmentation"] * 255)
+        box_lst.append(mask_data["bbox"])
+    return masks_lst, box_lst
+
+def calculate_total_zeros_in_stride_right(array, start_column, stride_length):
+    # Count background (zero) pixels in a vertical band right of start_column.
+    end_column = start_column + stride_length
+    return np.sum(array[:, start_column:end_column] == 0)
+
+def calculate_total_zeros_in_left_stride(array, start_column, stride_length):
+    # Count background pixels in a vertical band left of start_column.
+    end_column = max(0, start_column - stride_length)
+    return np.sum(array[:, end_column:start_column] == 0)
+
+def calculate_total_zeros_in_downward_stride(matrix, start_row, stride_length):
+    # Count background pixels in a horizontal band below start_row.
+    end_row = min(start_row + stride_length, matrix.shape[0])
+    return np.sum(matrix[start_row:end_row, :] == 0)
+
+def calculate_total_zeros_in_upward_stride(matrix, start_row, stride_length):
+    # Count background pixels in a horizontal band above start_row.
+    end_row = max(0, start_row - stride_length)
+    return np.sum(matrix[end_row:start_row, :] == 0)
+
+def pad_and_crop_mask(mask, image, padding):
+    # Crop the image to the padded bounding box of the mask, then keep the
+    # fingertip-facing portion of the crop. (Currently unused in this script.)
+    non_zero_indices = np.where(mask == 255)
+    y_min, y_max = np.min(non_zero_indices[0]), np.max(non_zero_indices[0]) + 1
+    x_min, x_max = np.min(non_zero_indices[1]), np.max(non_zero_indices[1]) + 1
+    y_min = max(y_min - padding, 0)
+    y_max = min(y_max + padding, image.shape[0])
+    x_min = max(x_min - padding, 0)
+    x_max = min(x_max + padding, image.shape[1])
+    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    cropped_image = image_rgb[y_min:y_max, x_min:x_max]
+    h, w, _ = cropped_image.shape
+    if h > w:
+        cropped_image = cropped_image[:int((3 * h) / 4), :]
+    else:
+        cropped_image = cropped_image[:, :int(w / 2)]
+    return cropped_image
+
+def get_object_from_mask(image, mask):
+    # Black out everything outside the mask, leaving only the segmented object.
+    if not isinstance(image, np.ndarray) or not isinstance(mask, np.ndarray):
+        raise TypeError("Image and mask must be NumPy arrays.")
+    if image.shape[:2] != mask.shape:
+        raise ValueError("Image and mask must have the same spatial dimensions.")
+    object_image = np.zeros_like(image)
+    object_image[mask == 255] = image[mask == 255]
+    object_image = cv2.cvtColor(object_image, cv2.COLOR_BGR2RGB)
+    return object_image
+
+def orient_and_adjust(image, bbox):
+    # Same entry-edge heuristic as the v1 colorback script, with rotations
+    # adjusted for the v2 capture orientation.
+    if image.shape[1] > image.shape[0]:  # image is horizontal
+        img = image[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]]
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        for i in range(img.shape[1]):
+            if np.count_nonzero(img[:, i]) >= 20:
+                left_index = i
+                break
+        for i in range(img.shape[1] - 1, -1, -1):
+            if np.count_nonzero(img[:, i]) >= 20:
+                right_index = i
+                break
+        total_zeros_towards_right = calculate_total_zeros_in_stride_right(img, left_index, 15)
+        total_zeros_towards_left = calculate_total_zeros_in_left_stride(img, right_index, 15)
+        if total_zeros_towards_right > total_zeros_towards_left:  # coming from left
+            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+            image = cv2.flip(image, 1)
+            orien = 'Rotate90'
+        else:  # coming from right
+            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+            image = cv2.flip(image, 1)
+            orien = 'Rotate90anti'
+        return "H", image, orien
+    else:  # image is vertical
+        img = image[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]]
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        for i in range(img.shape[0]):
+            if np.count_nonzero(img[i, :]) >= 20:
+                top_index = i
+                break
+        for i in range(img.shape[0] - 1, -1, -1):
+            if np.count_nonzero(img[i, :]) >= 20:
+                bottom_index = i
+                break
+        total_zeros_towards_down = calculate_total_zeros_in_downward_stride(img, top_index, 15)
+        total_zeros_towards_up = calculate_total_zeros_in_upward_stride(img, bottom_index, 15)
+        if total_zeros_towards_down > total_zeros_towards_up:  # coming from down
+            image = cv2.flip(image, 1)
+            orien = 'No'
+        else:  # coming from up
+            image = cv2.rotate(image, cv2.ROTATE_180)
+            image = cv2.flip(image, 1)
+            orien = '180'
+        return "V", image, orien
+
+def tight_crop_with_padding(image, original_image, padding=5):
+    # Crop both the masked image and the original to the largest foreground
+    # contour plus a small margin.
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    largest_contour = max(contours, key=cv2.contourArea)
+    x, y, w, h = cv2.boundingRect(largest_contour)
+    x, y, w, h = x - padding, y - padding, w + padding * 2, h + padding * 2
+    return image[y:y + h, x:x + w], original_image[y:y + h, x:x + w]
+
+def split_image_horizontally(image):
+    # Keep the top 80% of the frame, dropping the mostly empty bottom strip.
+    height, width, channels = image.shape
+    half_height = int(0.8 * height)
+    return image[:half_height, :, :]
+
+def tight_bounding_box(image):
+    # Fit an ellipse to the largest foreground contour and return the ratio of
+    # its two axes, used as a rough fingertip-shape check.
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    cnt = max(contours, key=cv2.contourArea)
+    ellipse = cv2.fitEllipse(cnt)
+    return ellipse[1][0] / ellipse[1][1]
+
+def resize_image_with_aspect_ratio(image, max_width=None, max_height=None):
+    # Downscale to fit within max_width/max_height, preserving aspect ratio.
+    height, width, _ = image.shape
+    aspect_ratio = width / height
+    if max_width and width > max_width:
+        new_width = max_width
+        new_height = int(new_width / aspect_ratio)
+    elif max_height and height > max_height:
+        new_height = max_height
+        new_width = int(new_height * aspect_ratio)
+    else:
+        return image
+    return cv2.resize(image, (new_width, new_height))
+
+def main(args: argparse.Namespace):
+    print("Loading model...")
+    sam = sam_model_registry['vit_h'](checkpoint="< Path to sam_vit_h_4b8939.pth cloned from SAM v1 repo >").to(device=args.device)
+    output_mode = "binary_mask"
+    amg_kwargs = get_amg_kwargs(args)
+    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
+
+    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device=args.device)
+    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+    text_prompt = "Human Finger"
+
+    parent_folder = args.parentdir
+    dstn_folder = args.dstndir
+
+    # Mirror the two-level folder hierarchy of the source dataset in the destination.
+    targets = []
+    for folder in os.listdir(parent_folder):
+        if not os.path.exists(os.path.join(dstn_folder, folder)):
+            os.mkdir(os.path.join(dstn_folder, folder))
+        for sub in os.listdir(os.path.join(parent_folder, folder)):
+            if not os.path.exists(os.path.join(dstn_folder, folder, sub)):
+                os.mkdir(os.path.join(dstn_folder, folder, sub))
+            for file in os.listdir(os.path.join(parent_folder, folder, sub)):
+                targets.append(os.path.join(parent_folder, folder, sub, file))
+
+    exce = []
+    skipper = 0
+    for t in tqdm(targets):
+        image = cv2.imread(t)
+        if image is None:
+            print(f"Could not load '{t}' as an image, skipping...")
+            continue
+        if image.shape[0] > image.shape[1]:
+            image = resize_image_with_aspect_ratio(image, 2448, 3264)
+        else:
+            image = resize_image_with_aspect_ratio(image, 3264, 2448)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        masks = generator.generate(image)
+        dstn_file = t.split("/")[-1]
+        img_lst = []
+        sim_lst = []
+        if output_mode == "binary_mask":
+            lst, box_lst = write_masks_to_folder(masks)
+            # Score every candidate mask against the text prompt; keep the best match.
+            for i in lst:
+                i = get_object_from_mask(image, i)
+                img = Image.fromarray(i)
+                inputs = processor(text=[text_prompt], images=img, return_tensors="pt", padding=True).to(device=args.device)
+                with torch.no_grad():
+                    outputs = model(**inputs)
+                sim_lst.append(outputs.logits_per_image.item())
+                img_lst.append(i)
+            best_image = img_lst[sim_lst.index(max(sim_lst))]
+            bbox = box_lst[sim_lst.index(max(sim_lst))]
+            # postprocessing
+            try:
+                orienta, best_image, orien = orient_and_adjust(best_image, bbox)
+                # Mirror the crop's orientation change on the full frame; every
+                # branch of orient_and_adjust also applies a horizontal flip.
+                if orienta == 'H' and orien == 'Rotate90':
+                    image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+                elif orienta == 'H' and orien == 'Rotate90anti':
+                    image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                elif orienta == 'V' and orien == 'No':
+                    pass
+                elif orienta == 'V' and orien == '180':
+                    image = cv2.rotate(image, cv2.ROTATE_180)
+                image = cv2.flip(image, 1)
+                best_image, image = tight_crop_with_padding(best_image, image, 5)
+                ratio = tight_bounding_box(best_image)
+                # Split off the empty strip only when the ellipse aspect ratio
+                # falls outside the expected fingertip band.
+                if not (0.46 <= ratio <= 0.55):
+                    image = split_image_horizontally(image)
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                # Write into the mirrored folder/sub directory created above.
+                cv2.imwrite(os.path.join(dstn_folder, os.path.relpath(t, parent_folder)), image)
+            except Exception:
+                skipper += 1
+                print(dstn_file)
+                exce.append(dstn_file)
+    print(f"number of files skipped: {len(exce)}")
+    with open(dstn_folder.split("/")[-2] + "_" + dstn_folder.split("/")[-1] + "_exceptions_v2.json", 'w') as js:
+        json.dump(exce, js, indent=4)
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
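
The 0.46-0.55 gate in main() compares the axis ratio of an ellipse fitted to the best mask. A self-contained sketch of the quantity tight_bounding_box() computes, on a synthetic mask (the printed value is approximate, and cv2.fitEllipse may report the axes in either order depending on the fitted angle):

    import cv2
    import numpy as np

    # white ellipse on black, standing in for a segmented fingertip mask
    canvas = np.zeros((200, 200, 3), dtype=np.uint8)
    cv2.ellipse(canvas, ((100, 100), (60, 120), 0), (255, 255, 255), -1)

    gray = cv2.cvtColor(canvas, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnt = max(contours, key=cv2.contourArea)
    ellipse = cv2.fitEllipse(cnt)
    print(ellipse[1][0] / ellipse[1][1])  # ~0.5 (or its reciprocal), near the 0.46-0.55 band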