Spaces:
Sleeping
Sleeping
| from sam2.build_sam import build_sam2 | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| import cv2 | |
| import numpy as np | |
| from XMem2.inference.interact.interactive_utils import overlay_davis | |
| from config import DEVICE | |
| from tools.mask_display import visualize_unique_mask | |
| import torch | |
| from tools.mask_merge import create_mask, merge_masks | |
class Segmenter:
    """Prompt-based single-image segmentation built on SAM2's image predictor.

    Typical use: ``set_image`` once, ``predict`` any number of times, then
    ``reset_image`` before embedding a different image.
    """

    def __init__(
        self,
        device: str = DEVICE,
        sam2_checkpoint: str = 'checkpoints/sam2.1_hiera_base_plus.pt',
        model_cfg: str = 'configs/sam2.1/sam2.1_hiera_b+.yaml',
    ):
        """Build the SAM2 model and wrap it in a ``SAM2ImagePredictor``.

        Args:
            device: Device string forwarded to ``build_sam2``.
            sam2_checkpoint: Path to the SAM2 weights file.
            model_cfg: SAM2 model config forwarded to ``build_sam2``.
        """
        self.device = device
        model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.predictor = SAM2ImagePredictor(model)
        # True while an image embedding is cached inside the predictor.
        self.embedded = False
        # The image whose embedding is currently cached (None until set).
        self.original_image = None

    def set_image(self, image: np.ndarray):
        """Compute and cache the embedding for ``image``.

        If an image is already embedded, print a hint and return without
        touching either the embedding or ``original_image``; call
        ``reset_image`` first.
        """
        if self.embedded:
            print('please reset_image')
            return
        self.predictor.set_image(image)
        # Only record the image once its embedding is actually cached, so
        # original_image never disagrees with what predict() segments.
        self.original_image = image
        self.embedded = True

    def reset_image(self):
        """Drop the cached embedding so a new image can be set."""
        self.predictor.reset_predictor()
        self.embedded = False

    def predict(self, prompt, mode='point', multimask=True):
        """Run SAM2 on the embedded image with the given prompt.

        Args:
            prompt: Mapping holding ``'point_coords'`` and ``'point_labels'``
                (point mode), ``'boxes'`` (box mode), or all three (both
                mode). Values are forwarded unchanged to
                ``SAM2ImagePredictor.predict``.
            mode: One of ``'point'``, ``'box'`` or ``'both'``; selects which
                prompt entries are used.
            multimask: Forwarded as ``multimask_output``.

        Returns:
            ``(masks, scores, logits)`` exactly as returned by
            ``SAM2ImagePredictor.predict``.

        Raises:
            AssertionError: If no image is embedded or ``mode`` is invalid.
            ValueError: For an invalid ``mode`` when asserts are disabled.
        """
        assert self.embedded, 'dont set image'
        assert mode in ['point', 'box', 'both'], 'mode can be point, box or both'
        if mode == 'point':
            prompt_kwargs = {
                'point_coords': prompt['point_coords'],
                'point_labels': prompt['point_labels'],
            }
        elif mode == 'box':
            prompt_kwargs = {'box': prompt['boxes']}
        elif mode == 'both':
            prompt_kwargs = {
                'point_coords': prompt['point_coords'],
                'point_labels': prompt['point_labels'],
                'box': prompt['boxes'],
            }
        else:
            # Only reachable when asserts are stripped (python -O).
            raise ValueError(f'mode can be point, box or both, got {mode!r}')
        masks, scores, logits = self.predictor.predict(
            multimask_output=multimask,
            **prompt_kwargs,
        )
        return masks, scores, logits
if __name__ == '__main__':
    # Manual smoke test: segment the first frame of a test video with SAM2
    # and display the merged masks, both on their own and overlaid on the frame.
    # path = 'video-test/truck.jpg'  # still-image alternative
    path = 'video-test/video.mp4'
    video = cv2.VideoCapture(path)
    ret, frame = video.read()
    video.release()
    if not ret:
        raise SystemExit(f'could not read a frame from {path!r}')

    # Active prompt set: one positive point per object.
    prompts = {
        'mode': 'point',
        'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]],
        'point_labels': [1, 1, 1, 1],
    }
    # Alternative prompt sets, kept as examples of the supported formats:
    # prompts = {
    #     'mode': 'point',
    #     'point_coords': [[[531, 230], [45, 321]], [226, 360], [194, 313]],
    #     'point_labels': [[1, 0], 1, 1],
    # }
    # prompts = {
    #     'mode': 'box',
    #     'boxes': [
    #         [476, 166, 578, 320],
    #         [8, 252, 99, 401],
    #         [106, 335, 317, 425],
    #         [155, 283, 225, 339],
    #     ],
    # }
    # prompts = {
    #     'mode': 'both',
    #     'point_coords': [[575, 750]],
    #     'point_labels': [0],
    #     'boxes': [[425, 600, 700, 875]],
    # }
    # prompts = {
    #     'mode': 'box',
    #     'boxes': [
    #         [75, 275, 1725, 850],
    #         [425, 600, 700, 875],
    #         [1375, 550, 1650, 800],
    #         [1240, 675, 1400, 750],
    #     ],
    # }

    # OpenCV decodes BGR; the SAM2 image predictor is fed RGB.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    seg = Segmenter()
    seg.set_image(frame)

    maskss = []
    mode = prompts['mode']
    if mode == 'point':
        # One prediction per point so every point yields its own object mask.
        for point_c, point_l in zip(prompts['point_coords'], prompts['point_labels']):
            prompt = {
                'point_coords': np.array([point_c]),
                'point_labels': np.array([point_l]),
                'boxes': None,
            }
            masks, scores, logits = seg.predict(prompt, mode)
            # Keep the highest-scoring of the candidate masks.
            maskss.append(masks[np.argmax(scores)])
    elif mode == 'box':
        for box in prompts['boxes']:
            prompt = {'boxes': np.array([box])}
            masks, scores, logits = seg.predict(prompt, mode, multimask=True)
            maskss.append(masks[np.argmax(scores)])
    else:
        masks, scores, logits = seg.predict(prompts, mode, multimask=False)
        # With multimask=False there is one mask per prompt; flatten any
        # leading (batch/candidate) dims so each becomes its own entry.
        # Previously this branch never populated maskss at all.
        maskss.extend(masks.reshape(-1, *masks.shape[-2:]))
    print(len(maskss))
    print(len(masks))

    # Note: the former "colorize" loop appended to maskss while iterating it
    # (never terminating) and called .squeeze(0) on (H, W) masks; it is removed.
    mask, unique_mask = merge_masks(maskss)
    overlay = overlay_davis(frame, unique_mask)
    mask_vis = visualize_unique_mask(unique_mask)
    # overlay is built from the RGB frame; cv2.imshow expects BGR.
    overlay = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)
    # Separate windows: reusing one name made the mask view vanish instantly.
    cv2.imshow('mask', mask_vis)
    cv2.imshow('overlay', overlay)
    cv2.waitKey(0)
    cv2.destroyAllWindows()