import cv2 from XMem2.inference.interact.interactive_utils import overlay_davis from segmenter import Segmenter from tools.mask_display import visualize_unique_mask from tools.mask_merge import merge_masks import numpy as np class SegmenterController: def __init__(self): """ Инициализация контроллера для работы с Segmenter2. :param device: Устройство для выполнения вычислений ('cuda' или 'cpu'). """ self.segmenter = Segmenter() self.image_set = False def load_image(self, image: np.ndarray): """ Загружает изображение в Segmenter2. :param image: Изображение в формате NumPy массива (H, W, C). """ if self.image_set: print("Изображение уже загружено. Сбросьте его перед загрузкой нового.") return try: self.segmenter.set_image(image) self.image_set = True print("Изображение успешно загружено.") except Exception as e: print(f"Ошибка при загрузке изображения: {e}") def reset_image(self): """ Сбрасывает текущее изображение в Segmenter2. """ if not self.image_set: print("Нет загруженного изображения для сброса.") return try: self.segmenter.reset_image() self.image_set = False print("Изображение успешно сброшено.") except Exception as e: print(f"Ошибка при сбросе изображения: {e}") def _process_point_prompt( self, point_coords: list[list[int] | list[list[int]]], point_labels: list[list[int] | list[list[int]]], ) -> list[dict[str, np.ndarray]]: """ Обрабатывает промпт для точек. :param point_coords: Координаты точек. :param point_labels: Метки точек. :return: Список словарей с подготовленными данными для predict. """ prompts = [] for coords, labels in zip(point_coords, point_labels): # Определяем, является ли текущий элемент списком координат или одной координатой if isinstance(coords[0], list) and isinstance(labels, list): # Если несколько точек и меток, multimask=False prompt = { "point_coords": np.array(coords), "point_labels": np.array(labels), } prompts.append((prompt, False)) else: # Если одна точка, multimask=True prompt = { "point_coords": np.array([coords]), "point_labels": np.array([labels]), } prompts.append((prompt, True)) return prompts def _process_box_prompt( self, boxes: list[list[int]] ) -> list[dict[str, np.ndarray]]: """ Обрабатывает промпт для рамок. :param boxes: Рамки. :return: Список словарей с подготовленными данными для predict. """ prompts = [] for box in boxes: prompt = {"boxes": np.array([box])} prompts.append((prompt, True)) # multimask=True для каждой рамки return prompts def _process_both_prompt( self, point_coords: list[list[int] | None], point_labels: list[int | None], boxes: list[list[int]], ) -> list[dict[str, np.ndarray]]: """ Обрабатывает промпт для комбинированного режима. :param point_coords: Координаты точек. :param point_labels: Метки точек. :param boxes: Рамки. :return: Список словарей с подготовленными данными для predict. """ prompts = [] for box, coords, labels in zip(boxes, point_coords, point_labels): prompt = {"boxes": np.array([box])} if coords is not None and labels is not None: prompt["point_coords"] = np.array([coords]) prompt["point_labels"] = np.array([labels]) prompts.append((prompt, False)) # multimask=False, если есть точки else: prompts.append((prompt, True)) # multimask=True, если точек нет return prompts def predict_from_prompts( self, prompts: dict[str, str | list] ) -> list[list[np.ndarray, np.ndarray, np.ndarray]]: """ Выполняет предсказание на основе заданного промпта. :param prompts: Словарь с данными для предсказания. :return: Список кортежей (маски, оценки, логиты). """ if not self.image_set: raise RuntimeError("Изображение не загружено. Сначала вызовите load_image.") mode = prompts.get("mode") results = [] if mode == "point": point_coords = prompts.get("point_coords", []) point_labels = prompts.get("point_labels", []) processed_prompts = self._process_point_prompt(point_coords, point_labels) elif mode == "box": boxes = prompts.get("boxes", []) processed_prompts = self._process_box_prompt(boxes) elif mode == "both": point_coords = prompts.get( "point_coords", [None] * len(prompts.get("boxes", [])) ) point_labels = prompts.get( "point_labels", [None] * len(prompts.get("boxes", [])) ) boxes = prompts.get("boxes", []) processed_prompts = self._process_both_prompt( point_coords, point_labels, boxes ) else: raise ValueError("Режим должен быть 'point', 'box' или 'both'.") # TODO: добавить вариант без цикла for prompt, multimask in processed_prompts: try: masks, scores, logits = self.segmenter.predict( prompt, mode=mode, multimask=multimask ) results.append([masks, scores, logits]) except Exception as e: print(f"Ошибка при выполнении предсказания: {e}") raise return results if __name__ == '__main__': # Создаем контроллер controller = SegmenterController() path = 'video-test/truck.jpg' path = 'video-test/video.mp4' video = cv2.VideoCapture(path) ret, frame = video.read() frame_cop = frame.copy() video.release() controller.load_image(frame) import timeit # Пример 1: Точки prompts = { 'mode': 'point', 'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]], 'point_labels': [1, 1, 1, 1], } # prompts = { # 'mode': 'point', # 'point_coords': [[[531, 230], [45, 321]], [226, 360], [194, 313]], # 'point_labels': [[1, 0], 1, 1], # } def run_segmentation(): prompts = { 'mode': 'point', 'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]], 'point_labels': [1, 0, 1, 1], } return controller.predict_from_prompts(prompts) results = controller.predict_from_prompts(prompts) execution_time_ms = timeit.timeit(run_segmentation, number=1) * 1000 print(f"Время выполнения: {execution_time_ms:.2f} мс") # Пример 2: Рамки # prompts = { # 'mode': 'box', # 'boxes': [ # [476, 166, 578, 320], # [8, 252, 99, 401], # [106, 335, 317, 425], # [155, 283, 225, 339], # ], # } # results = controller.predict_from_prompts(prompts) # Пример 3: Комбинированный режим # prompts = { # 'mode': 'both', # 'point_coords': [[575, 750]], # 'point_labels': [0], # 'boxes': [[425, 600, 700, 875]], # } # results = controller.predict_from_prompts(prompts) print(len(results)) res = [result[np.argmax(scores)] for result, scores, logits in results] mask, unique_mask = merge_masks(res) f = overlay_davis(frame, unique_mask) mask = visualize_unique_mask(unique_mask) f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB) cv2.imshow('asd', mask) cv2.imshow('asd', f) cv2.waitKey(0) cv2.destroyAllWindows() # Сбрасываем изображение controller.reset_image()