Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry | |
def show_mask(mask, ax, random_color=False):
    """Overlay a binary segmentation mask on a matplotlib axis as an RGBA image.

    Args:
        mask: 2-D (or trailing-2-D) binary mask array.
        ax: matplotlib axis to draw on (only `imshow` is used).
        random_color: if True, pick a random RGB color; otherwise use a
            fixed dodger-blue. Alpha is 0.6 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Fixed default: dodger-blue (30, 144, 255) at 0.6 alpha.
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast the mask against the RGBA color to get an (H, W, 4) overlay.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Draw point prompts on an axis: green stars for positive (label 1),
    red stars for negative (label 0).

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: (N,) array of 0/1 labels.
        ax: matplotlib axis to draw on (only `scatter` is used).
        marker_size: scatter marker size in points^2.
    """
    # Shared marker styling for both classes of points.
    style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    ax.scatter(positives[:, 0], positives[:, 1], color='green', **style)
    ax.scatter(negatives[:, 0], negatives[:, 1], color='red', **style)
def show_box(box, ax):
    """Draw an (xmin, ymin, xmax, ymax) bounding box as a green outline.

    Args:
        box: sequence of four values in (xmin, ymin, xmax, ymax) order.
        ax: matplotlib axis to draw on (only `add_patch` is used).
    """
    xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle(
        (xmin, ymin),
        xmax - xmin,
        ymax - ymin,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(rect)
def merge_bounding_boxes(bbox1, bbox2):
    """Return the smallest axis-aligned box enclosing both input boxes.

    Args:
        bbox1, bbox2: boxes in (xmin, ymin, xmax, ymax) order.

    Returns:
        numpy array (xmin, ymin, xmax, ymax) of the union box.
    """
    a = np.asarray(bbox1)
    b = np.asarray(bbox2)
    # Elementwise min of the two min-corners, max of the two max-corners.
    mins = np.minimum(a[:2], b[:2])
    maxs = np.maximum(a[2:], b[2:])
    return np.concatenate([mins, maxs])
def init_sam(
    device="cuda",
    ckpt_path='/users/kchen157/scratch/weights/SAM/sam_vit_h_4b8939.pth',
):
    """Load a ViT-H SAM checkpoint and wrap it in a SamPredictor.

    Args:
        device: torch device string the model is moved to.
        ckpt_path: filesystem path to the SAM ViT-H checkpoint.

    Returns:
        A SamPredictor ready for `set_image` / `predict` calls.
    """
    model = sam_model_registry['vit_h'](checkpoint=ckpt_path)
    model.to(device=device)
    return SamPredictor(model)
| def segment_hand_and_object( | |
| predictor, | |
| image, | |
| hand_kpts, | |
| hand_mask=None, | |
| box_shift_ratio = 0.3, | |
| box_size_factor = 2., | |
| area_threshold = 0.2, | |
| overlap_threshold = 200): | |
| # Find bounding box for HOI | |
| input_box = {} | |
| for hand_type in ['right', 'left']: | |
| if hand_type not in hand_kpts: | |
| continue | |
| input_box[hand_type] = np.stack([hand_kpts[hand_type].min(axis=0), hand_kpts[hand_type].max(axis=0)]) | |
| box_trans = input_box[hand_type][0] * box_shift_ratio + input_box[hand_type][1] * (1 - box_shift_ratio) | |
| input_box[hand_type] = ((input_box[hand_type] - box_trans) * box_size_factor + box_trans).reshape(-1) | |
| if len(input_box) == 2: | |
| input_box = merge_bounding_boxes(input_box['right'], input_box['left']) | |
| input_point = np.array([hand_kpts['right'][0], hand_kpts['left'][0]]) | |
| input_label = np.array([1, 1]) | |
| elif 'right' in input_box: | |
| input_box = input_box['right'] | |
| input_point = np.array([hand_kpts['right'][0]]) | |
| input_label = np.array([1]) | |
| elif 'left' in input_box: | |
| input_box = input_box['left'] | |
| input_point = np.array([hand_kpts['left'][0]]) | |
| input_label = np.array([1]) | |
| box_area = (input_box[2] - input_box[0]) * (input_box[3] - input_box[1]) | |
| # segment hand using the wrist point | |
| predictor.set_image(image) | |
| if hand_mask is None: | |
| masks, scores, logits = predictor.predict( | |
| point_coords=input_point, | |
| point_labels=input_label, | |
| multimask_output=False, | |
| ) | |
| hand_mask = masks[0] | |
| # segment object in hand | |
| input_label = np.zeros_like(input_label) | |
| masks, scores, _ = predictor.predict( | |
| point_coords=input_point, | |
| point_labels=input_label, | |
| box=input_box[None, :], | |
| multimask_output=False, | |
| ) | |
| object_mask = masks[0] | |
| if (masks[0].astype(int) * hand_mask).sum() > overlap_threshold: | |
| # print('False positive: The mask overlaps the hand.') | |
| object_mask = np.zeros_like(object_mask) | |
| elif object_mask.astype(int).sum() / box_area > area_threshold: | |
| # print('False positive: The area is very big, probably the background') | |
| object_mask = np.zeros_like(object_mask) | |
| return object_mask, hand_mask |