| import cv2 | |
| import os | |
| import sys | |
| import clip | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
def convert_box_xywh_to_xyxy(box):
    """Convert ``[x, y, w, h]`` boxes to ``[x1, y1, x2, y2]`` corner form.

    Args:
        box: Either a single box (four scalars) or a sequence of boxes,
            which may itself be nested; nesting is handled recursively.

    Returns:
        A list ``[x, y, x + w, y + h]`` for a single box, or a list of such
        lists mirroring the nesting of the input. An empty input yields
        an empty list.
    """
    # A single box is four *scalars*. Checking the dimensionality of the
    # first element keeps a batch that happens to contain exactly four
    # boxes from being mistaken for one box.
    if len(box) == 4 and np.ndim(box[0]) == 0:
        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
    return [convert_box_xywh_to_xyxy(b) for b in box]
def segment_image(image, bbox):
    """Keep the pixels of ``image`` inside ``bbox`` and paint the rest white.

    Args:
        image: Source image; it is converted with ``np.array`` and its
            ``.size`` is used for the output canvas (presumably a PIL
            image -- confirm against callers).
        bbox: ``(x1, y1, x2, y2)`` corner coordinates used directly as
            slice bounds, so they need to be integers.

    Returns:
        A new "RGB" image of the same size whose region inside ``bbox`` is
        copied from ``image`` and whose remainder is (255, 255, 255).
    """
    pixels = np.array(image)
    left, top, right, bottom = bbox

    # Copy just the boxed region onto an otherwise zero-filled array.
    cropped = np.zeros_like(pixels)
    cropped[top:bottom, left:right] = pixels[top:bottom, left:right]
    cropped_image = Image.fromarray(cropped)

    # White canvas that shows through wherever the paste mask is 0.
    canvas = Image.new("RGB", image.size, (255, 255, 255))

    # Paste mask: 255 inside the box, 0 everywhere else.
    height, width = pixels.shape[0], pixels.shape[1]
    box_mask = np.zeros((height, width), dtype=np.uint8)
    box_mask[top:bottom, left:right] = 255
    box_mask_image = Image.fromarray(box_mask, mode="L")

    canvas.paste(cropped_image, mask=box_mask_image)
    return canvas
def format_results(result, filter=0):
    """Turn a segmentation result into a list of per-mask annotation dicts.

    Args:
        result: Object exposing ``masks.data``, ``boxes.data`` and
            ``boxes.conf``, indexable per detection (presumably an
            Ultralytics result with tensor fields -- confirm against
            callers).
        filter: Minimum number of mask pixels equal to 1.0; masks with
            fewer pixels are dropped. (The name shadows the builtin but is
            kept for caller compatibility.)

    Returns:
        A list of dicts with keys ``'id'`` (index of the mask in
        ``result``), ``'segmentation'`` (boolean NumPy array), ``'bbox'``,
        ``'score'`` and ``'area'`` (number of true pixels).
    """
    annotations = []
    # NOTE(review): assumes ``masks`` is None when nothing was detected;
    # treat that as "no annotations" instead of failing on ``.data``.
    if result.masks is None:
        return annotations

    for i in range(len(result.masks.data)):
        mask = result.masks.data[i] == 1.0
        # ``mask.sum()`` avoids relying on a module-level ``torch`` name.
        if mask.sum() < filter:
            continue
        segmentation = mask.cpu().numpy()
        # Must be a dict: the fields are stored under string keys.
        annotations.append({
            'id': i,
            'segmentation': segmentation,
            'bbox': result.boxes.data[i],
            'score': result.boxes.conf[i],
            'area': segmentation.sum(),
        })
    return annotations
def filter_masks(annotations):
    """Drop masks that are mostly contained in a larger mask.

    ``annotations`` is sorted in place by ``'area'``, largest first. A mask
    is removed when it is strictly smaller than an earlier mask and more
    than 80% of its own pixels overlap that earlier mask.

    Args:
        annotations: List of dicts with an ``'area'`` value and a boolean
            ``'segmentation'`` array (all arrays of the same shape).

    Returns:
        A tuple ``(kept, to_remove)``: the surviving annotations in sorted
        order, and the set of indices (into the sorted list) that were
        removed.
    """
    annotations.sort(key=lambda x: x['area'], reverse=True)
    to_remove = set()
    for i, a in enumerate(annotations):
        for j in range(i + 1, len(annotations)):
            if j in to_remove:
                continue
            b = annotations[j]
            if not b['area'] < a['area']:
                continue
            b_pixels = b['segmentation'].sum()
            if b_pixels == 0:
                # An empty mask has no overlap ratio; avoid dividing 0 by 0.
                continue
            overlap = (a['segmentation'] & b['segmentation']).sum()
            if overlap / b_pixels > 0.8:
                to_remove.add(j)
    kept = [a for i, a in enumerate(annotations) if i not in to_remove]
    return kept, to_remove
| def get_bbox_from_mask(mask): | |
| mask = mask.astype(np.uint8) | |
| contours,hierarchy = cv2.findContours(mask,cv2.RETR) |