# track-anything-annotate / segmenter.py
# lniki's picture
# add model
# 0e83290 verified
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import cv2
import numpy as np
from XMem2.inference.interact.interactive_utils import overlay_davis
from config import DEVICE
from tools.mask_display import visualize_unique_mask
import torch
from tools.mask_merge import create_mask, merge_masks
class Segmenter:
    """Prompt-driven image segmenter backed by ``SAM2ImagePredictor``.

    The instance remembers whether an image has been embedded (``embedded``)
    so that prediction is refused without an image and an existing embedding
    is not replaced unless ``reset_image`` is called first.
    """

    def __init__(self, device: str = DEVICE):
        """Build the SAM 2 model and wrap it in an image predictor.

        The checkpoint and config paths are hard-coded relative paths.

        Args:
            device: Device identifier forwarded to ``build_sam2``.
        """
        self.device = device
        sam2_checkpoint = 'checkpoints/sam2.1_hiera_large.pt'
        model_cfg = 'configs/sam2.1/sam2.1_hiera_l.yaml'
        build = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.predictor = SAM2ImagePredictor(build)
        # Image currently embedded in the predictor; None until set_image succeeds.
        self.original_image = None
        self.embedded = False

    @torch.no_grad()
    def set_image(self, image: np.ndarray):
        """Embed ``image`` in the predictor unless one is already embedded.

        If an image is already embedded, a hint is printed and nothing is
        changed; call ``reset_image`` first to switch images.

        Args:
            image: Image passed unchanged to ``SAM2ImagePredictor.set_image``.
        """
        if self.embedded:
            print('please reset_image')
            return
        self.predictor.set_image(image)
        # Record the image only after embedding succeeded, so that
        # original_image always matches what the predictor holds.
        self.original_image = image
        self.embedded = True

    @torch.no_grad()
    def reset_image(self):
        """Reset the predictor and mark this instance as having no image."""
        self.predictor.reset_predictor()
        self.embedded = False

    def predict(self, prompt, mode='point', multimask=True):
        """Run the predictor on the embedded image with the given prompt.

        Args:
            prompt: Mapping holding the prompt data. ``'point_coords'`` and
                ``'point_labels'`` are read for modes ``'point'`` and
                ``'both'``; ``'boxes'`` is read for modes ``'box'`` and
                ``'both'``.
            mode: One of ``'point'``, ``'box'`` or ``'both'``.
            multimask: Forwarded as ``multimask_output``.

        Returns:
            The ``(masks, scores, logits)`` triple produced by
            ``SAM2ImagePredictor.predict``.

        Raises:
            AssertionError: If no image is embedded or ``mode`` is unknown.
            ValueError: If ``mode`` is unknown and assertions are disabled
                (``python -O``).
            KeyError: If ``prompt`` lacks a key required by ``mode``.
        """
        assert self.embedded, 'dont set image'
        assert mode in ['point', 'box', 'both'], 'mode can be point, box or both'
        if mode == 'point':
            masks, scores, logits = self.predictor.predict(
                point_coords=prompt['point_coords'],
                point_labels=prompt['point_labels'],
                multimask_output=multimask,
            )
        elif mode == 'box':
            masks, scores, logits = self.predictor.predict(
                box=prompt['boxes'],
                multimask_output=multimask,
            )
        elif mode == 'both':
            masks, scores, logits = self.predictor.predict(
                point_coords=prompt['point_coords'],
                point_labels=prompt['point_labels'],
                box=prompt['boxes'],
                multimask_output=multimask,
            )
        else:
            # Only reachable when the assert above is stripped; raise a real
            # exception instead of a bare string (which is itself a TypeError).
            raise ValueError('mode can be point, box or both')
        return masks, scores, logits
if __name__ == '__main__':
    # Demo: segment objects on the first frame of a video from a set of
    # prompts, merge the resulting masks and show them in OpenCV windows.

    # Alternative input used previously: 'video-test/truck.jpg'
    path = 'video-test/video.mp4'
    video = cv2.VideoCapture(path)
    ret, frame = video.read()
    video.release()
    if not ret or frame is None:
        # Without this check the colour conversion below fails on None.
        raise SystemExit(f'could not read a frame from {path!r}')

    prompts = {
        'mode': 'point',
        'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]],
        'point_labels': [1, 1, 1, 1],
    }
    # Alternative prompt sets kept for manual experiments:
    # prompts = {
    #     'mode': 'point',
    #     'point_coords': [[[531, 230], [45, 321]], [226, 360], [194, 313]],
    #     'point_labels': [[1, 0], 1, 1],
    # }
    # prompts = {
    #     'mode': 'box',
    #     'boxes': [
    #         [476, 166, 578, 320],
    #         [8, 252, 99, 401],
    #         [106, 335, 317, 425],
    #         [155, 283, 225, 339],
    #     ],
    # }
    # prompts = {
    #     'mode': 'both',
    #     'point_coords': [[575, 750]],
    #     'point_labels': [0],
    #     'boxes': [[425, 600, 700, 875]],
    # }
    # prompts = {
    #     'mode': 'box',
    #     'boxes': [
    #         [75, 275, 1725, 850],
    #         [425, 600, 700, 875],
    #         [1375, 550, 1650, 800],
    #         [1240, 675, 1400, 750],
    #     ],
    # }

    # OpenCV decodes frames as BGR; the frame is converted to RGB before
    # being handed to the segmenter.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    seg = Segmenter()
    seg.set_image(frame)

    mode = prompts['mode']
    best_masks = []  # highest-scoring mask for each individual prompt
    if mode == 'point':
        for point_c, point_l in zip(prompts['point_coords'], prompts['point_labels']):
            prompt = {
                'point_coords': np.array([point_c]),
                'point_labels': np.array([point_l]),
                'boxes': None,
            }
            masks, scores, _ = seg.predict(prompt, mode)
            best_masks.append(masks[np.argmax(scores)])
    elif mode == 'box':
        for box in prompts['boxes']:
            prompt = {
                'boxes': np.array([box]),
            }
            masks, scores, _ = seg.predict(prompt, mode, multimask=True)
            best_masks.append(masks[np.argmax(scores)])
    else:
        masks, scores, _ = seg.predict(prompts, mode, multimask=False)
        # Keep the result; previously it was dropped, leaving nothing to merge.
        best_masks.append(masks[np.argmax(scores)])

    if not best_masks:
        raise SystemExit('no masks were predicted; check the prompts')
    print(len(best_masks))
    print(len(masks))

    # Build the coloured masks in a separate list: appending to the list that
    # is being iterated never terminates.
    # NOTE(review): each selected mask looks to be 2-D (H, W) already, in which
    # case squeeze(0) would raise; only drop a leading axis when one exists.
    # TODO confirm the mask shape create_mask expects.
    colored_masks = [
        create_mask(m.squeeze(0) if m.ndim > 2 else m, random_color=True)
        for m in best_masks
    ]

    _, unique_mask = merge_masks(colored_masks)
    overlay = overlay_davis(frame, unique_mask)
    mask_view = visualize_unique_mask(unique_mask)
    # The frame was converted to RGB above; swap channels back for imshow.
    overlay = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)
    # Use distinct window names so the mask view is not overwritten by the
    # overlay.
    cv2.imshow('asd mask', mask_view)
    cv2.imshow('asd', overlay)
    cv2.waitKey(0)
    cv2.destroyAllWindows()