import cv2
import numpy as np
import psutil
from tqdm import tqdm
from tools.mask_merge import merge_masks
from tracker_core_xmem2 import TrackerCore
from tools.overlay_image import painter_borders
from XMem2.inference.interact.interactive_utils import overlay_davis
from sam_controller import SegmenterController
from interactive_video import InteractVideo

# Tracking stops once system RAM usage (psutil percent) exceeds this value.
MAX_MEMORY_PERCENT = 90


def _memory_exhausted(limit: float = MAX_MEMORY_PERCENT) -> bool:
    """Return True when system RAM usage is above *limit* percent."""
    return psutil.virtual_memory().percent > limit


class Tracker:
    """Combine a prompt-driven segmenter with a frame-by-frame mask tracker.

    The segmenter (``SegmenterController``) produces a mask from prompts on a
    single image; the tracker (``TrackerCore``) propagates a mask through a
    sequence of frames.
    """

    def __init__(
        self, segmenter_controller: SegmenterController, tracker_core: TrackerCore
    ):
        self.sam_controller = segmenter_controller
        self.tracker = tracker_core
        print(f'used {TrackerCore.name_version}')

    def select_object(self, prompts: dict) -> np.ndarray:
        """Segment the object(s) described by *prompts* on the loaded image.

        The image must already be loaded with ``sam_controller.load_image``
        (as the script below does).

        Args:
            prompts: Forwarded to ``SegmenterController.predict_from_prompts``.
                The script passes ``{'mode': 'point', 'point_coords': [[x, y], ...],
                'point_labels': [1, ...]}``.

        Returns:
            The second value returned by ``merge_masks`` for the best-scoring
            candidate of every prediction (presumably a per-object label mask
            — confirm against ``tools.mask_merge``).
        """
        predictions = self.sam_controller.predict_from_prompts(prompts)
        # Each prediction is (candidate_masks, scores, logits); keep the
        # candidate with the highest score.
        best_masks = [
            candidates[np.argmax(scores)] for candidates, scores, _ in predictions
        ]
        _, unique_mask = merge_masks(best_masks)
        return unique_mask

    def tracking(
        self,
        frames: list[np.ndarray],
        template_mask: np.ndarray,
        exhaustive: bool = False,
    ) -> list:
        """Propagate *template_mask* from ``frames[0]`` through *frames*.

        Args:
            frames: Consecutive video frames; the first one is the frame the
                template mask was drawn on.
            template_mask: Mask for ``frames[0]``.
            exhaustive: Forwarded to ``TrackerCore.track`` for the first frame.

        Returns:
            One mask per processed frame. If RAM usage exceeds
            ``MAX_MEMORY_PERCENT`` the loop stops early, a warning is printed
            and the list is shorter than *frames*.
        """
        # TODO: improve accuracy:
        #   - check how many masks the tracker holds;
        #   - check how many objects are detected;
        #   - if they differ, add the tracker's mask to the new masks.
        masks = []
        for i, frame in enumerate(tqdm(frames, desc='Tracking')):
            if _memory_exhausted():
                tqdm.write(
                    f'warning: memory usage above {MAX_MEMORY_PERCENT}%, '
                    f'tracking stopped at frame {i}/{len(frames)}'
                )
                break
            if i == 0:
                mask = self.tracker.track(frame, template_mask, exhaustive)
            else:
                mask = self.tracker.track(frame)
            masks.append(mask)
        return masks

    def tracking_cut(
        self,
        frames: list[np.ndarray],
        templates_masks: dict[str, np.ndarray],
        exhaustive: bool = False,
    ) -> list:
        """Track through *frames*, re-seeding the tracker at template frames.

        Args:
            frames: Consecutive video frames.
            templates_masks: Masks keyed by frame index as a string (``'0'``,
                ``'15'``, ...). Every frame that has a template is tracked
                with that template instead of from memory alone.
            exhaustive: Forwarded to ``TrackerCore.track`` for the first
                template. When more than one template is given, every
                template after frame 0 is tracked with ``exhaustive=True``.

        Returns:
            One mask per processed frame; shorter than *frames* if RAM usage
            exceeds ``MAX_MEMORY_PERCENT`` (a warning is printed).
        """
        multiple_templates = len(templates_masks) > 1
        masks = []
        for i, frame in enumerate(tqdm(frames, desc='Tracking_cut')):
            if _memory_exhausted():
                tqdm.write(
                    f'warning: memory usage above {MAX_MEMORY_PERCENT}%, '
                    f'tracking stopped at frame {i}/{len(frames)}'
                )
                break
            key = str(i)
            if key in templates_masks:
                # Previously only frame 0's template was used and all later
                # templates were silently ignored.
                mask = self.tracker.track(frame, templates_masks[key], exhaustive)
            else:
                mask = self.tracker.track(frame)
            masks.append(mask)
            if multiple_templates:
                exhaustive = True
        return masks


def _fill_mask(frame: np.ndarray) -> np.ndarray:
    """Return a uint8 mask of ones with the height and width of *frame*."""
    height, width = frame.shape[:2]
    return np.ones((height, width), dtype=np.uint8)


def _segment_keyframes(tracker: Tracker, frames, keyframes) -> list[dict]:
    """Build one gap per keyframe, segmenting keyframes that have keypoints.

    Each gap runs from its keyframe up to (not including) the next keyframe;
    the last gap runs to the end of the video so trailing frames are kept.

    Args:
        tracker: Tracker whose segmenter is used for the keyframes.
        frames: All video frames, indexable by frame index.
        keyframes: ``(frame_index, point_coords)`` pairs sorted by index.

    Returns:
        Dicts with ``"gap"`` ([start, end]), ``"frame"`` (start) and ``"mask"``
        (``None`` for keyframes without keypoints).
    """
    ends = [idx for idx, _ in keyframes[1:]] + [len(frames)]
    gaps = []
    for (start, coords), end in zip(keyframes, ends):
        print(start, end)
        mask = None
        if coords:
            tracker.sam_controller.load_image(frames[start])
            try:
                mask = tracker.select_object(
                    {
                        'mode': 'point',
                        'point_coords': coords,
                        'point_labels': [1] * len(coords),
                    }
                )
            finally:
                # Always release the loaded image, even if prediction fails.
                tracker.sam_controller.reset_image()
        gaps.append({"gap": [start, end], "frame": start, "mask": mask})
    return gaps


def _track_gaps(tracker: Tracker, frames, gaps: list[dict]) -> list:
    """Return masks for consecutive frames starting at the first gap.

    Gaps with a mask are tracked from it (the tracker memory is cleared after
    each gap); gaps without one get an all-ones mask for every frame. If
    tracking of a gap is cut short, processing stops there so that every mask
    stays aligned with its frame.
    """
    masks = []
    for gap in gaps:
        start, end = gap['gap']
        print(start, end)
        if gap['mask'] is not None:
            gap_masks = tracker.tracking(frames[start:end], gap['mask'])
            tracker.tracker.clear_memory()
        else:
            # One shared read-only mask for the whole gap saves memory.
            gap_masks = [_fill_mask(frames[start])] * (end - start)
        masks += gap_masks
        if len(gap_masks) < end - start:
            print(
                f'warning: gap {start}-{end} produced only {len(gap_masks)} '
                'masks; remaining frames are skipped'
            )
            break
    return masks


def _write_video(filename: str, frames, masks, fps, frame_size) -> None:
    """Write *frames* overlaid with *masks* (paired in order) to *filename*.

    Raises:
        RuntimeError: if OpenCV cannot open the video writer.
    """
    output = cv2.VideoWriter(
        filename, cv2.VideoWriter_fourcc(*'XVID'), fps, frame_size
    )
    if not output.isOpened():
        raise RuntimeError(f'could not open video writer for {filename}')
    try:
        for frame, mask in zip(frames, masks):
            # Alternative renderer: painter_borders(frame, mask)
            output.write(overlay_davis(frame, mask))
    finally:
        # Release the writer even if rendering fails.
        output.release()


def main() -> None:
    """Run the interactive segmentation + tracking pipeline on one video."""
    path = 'video-test/VID_20241218_134328.mp4'
    key_interval = 3

    controller = InteractVideo(path, key_interval)
    controller.extract_frames()
    controller.collect_keypoints()
    results = controller.get_results()

    segmenter_controller = SegmenterController()
    tracker_core = TrackerCore()
    tracker = Tracker(segmenter_controller, tracker_core)

    frames = results['frames']
    # Keypoints are keyed by frame index as a string; sort numerically since
    # dict insertion order is not guaranteed to be chronological, and drop
    # indices that fall outside the extracted frames.
    keyframes = sorted(
        ((int(key), coords) for key, coords in results['keypoints'].items()),
        key=lambda item: item[0],
    )
    keyframes = [kf for kf in keyframes if 0 <= kf[0] < len(frames)]

    gaps = _segment_keyframes(tracker, frames, keyframes)
    masks = _track_gaps(tracker, frames, gaps)

    # Masks start at the first keyframe, so the frames must as well.
    first_frame = keyframes[0][0] if keyframes else 0
    filename = 'output_video_from_file_mem2_ved_pot.mp4'
    _write_video(
        filename, frames[first_frame:], masks, controller.fps, controller.frame_size
    )

    # Release resources.
    cv2.destroyAllWindows()
    print(f'Видео записано в файл: {filename}')


if __name__ == '__main__':
    main()