|
|
import cv2
|
|
|
from XMem2.inference.interact.interactive_utils import overlay_davis
|
|
|
from segmenter import Segmenter
|
|
|
from tools.mask_display import visualize_unique_mask
|
|
|
from tools.mask_merge import merge_masks
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
class SegmenterController:
|
|
|
def __init__(self):
|
|
|
"""
|
|
|
Инициализация контроллера для работы с Segmenter2.
|
|
|
:param device: Устройство для выполнения вычислений ('cuda' или 'cpu').
|
|
|
"""
|
|
|
self.segmenter = Segmenter()
|
|
|
self.image_set = False
|
|
|
|
|
|
def load_image(self, image: np.ndarray):
|
|
|
"""
|
|
|
Загружает изображение в Segmenter2.
|
|
|
:param image: Изображение в формате NumPy массива (H, W, C).
|
|
|
"""
|
|
|
if self.image_set:
|
|
|
print("Изображение уже загружено. Сбросьте его перед загрузкой нового.")
|
|
|
return
|
|
|
try:
|
|
|
self.segmenter.set_image(image)
|
|
|
self.image_set = True
|
|
|
print("Изображение успешно загружено.")
|
|
|
except Exception as e:
|
|
|
print(f"Ошибка при загрузке изображения: {e}")
|
|
|
|
|
|
def reset_image(self):
|
|
|
"""
|
|
|
Сбрасывает текущее изображение в Segmenter2.
|
|
|
"""
|
|
|
if not self.image_set:
|
|
|
print("Нет загруженного изображения для сброса.")
|
|
|
return
|
|
|
try:
|
|
|
self.segmenter.reset_image()
|
|
|
self.image_set = False
|
|
|
print("Изображение успешно сброшено.")
|
|
|
except Exception as e:
|
|
|
print(f"Ошибка при сбросе изображения: {e}")
|
|
|
|
|
|
def _process_point_prompt(
|
|
|
self,
|
|
|
point_coords: list[list[int] | list[list[int]]],
|
|
|
point_labels: list[list[int] | list[list[int]]],
|
|
|
) -> list[dict[str, np.ndarray]]:
|
|
|
"""
|
|
|
Обрабатывает промпт для точек.
|
|
|
:param point_coords: Координаты точек.
|
|
|
:param point_labels: Метки точек.
|
|
|
:return: Список словарей с подготовленными данными для predict.
|
|
|
"""
|
|
|
prompts = []
|
|
|
for coords, labels in zip(point_coords, point_labels):
|
|
|
|
|
|
if isinstance(coords[0], list) and isinstance(labels, list):
|
|
|
|
|
|
prompt = {
|
|
|
"point_coords": np.array(coords),
|
|
|
"point_labels": np.array(labels),
|
|
|
}
|
|
|
prompts.append((prompt, False))
|
|
|
else:
|
|
|
|
|
|
prompt = {
|
|
|
"point_coords": np.array([coords]),
|
|
|
"point_labels": np.array([labels]),
|
|
|
}
|
|
|
prompts.append((prompt, True))
|
|
|
return prompts
|
|
|
|
|
|
def _process_box_prompt(
|
|
|
self, boxes: list[list[int]]
|
|
|
) -> list[dict[str, np.ndarray]]:
|
|
|
"""
|
|
|
Обрабатывает промпт для рамок.
|
|
|
:param boxes: Рамки.
|
|
|
:return: Список словарей с подготовленными данными для predict.
|
|
|
"""
|
|
|
prompts = []
|
|
|
for box in boxes:
|
|
|
prompt = {"boxes": np.array([box])}
|
|
|
prompts.append((prompt, True))
|
|
|
return prompts
|
|
|
|
|
|
def _process_both_prompt(
|
|
|
self,
|
|
|
point_coords: list[list[int] | None],
|
|
|
point_labels: list[int | None],
|
|
|
boxes: list[list[int]],
|
|
|
) -> list[dict[str, np.ndarray]]:
|
|
|
"""
|
|
|
Обрабатывает промпт для комбинированного режима.
|
|
|
:param point_coords: Координаты точек.
|
|
|
:param point_labels: Метки точек.
|
|
|
:param boxes: Рамки.
|
|
|
:return: Список словарей с подготовленными данными для predict.
|
|
|
"""
|
|
|
prompts = []
|
|
|
for box, coords, labels in zip(boxes, point_coords, point_labels):
|
|
|
prompt = {"boxes": np.array([box])}
|
|
|
if coords is not None and labels is not None:
|
|
|
prompt["point_coords"] = np.array([coords])
|
|
|
prompt["point_labels"] = np.array([labels])
|
|
|
prompts.append((prompt, False))
|
|
|
else:
|
|
|
prompts.append((prompt, True))
|
|
|
return prompts
|
|
|
|
|
|
def predict_from_prompts(
|
|
|
self, prompts: dict[str, str | list]
|
|
|
) -> list[list[np.ndarray, np.ndarray, np.ndarray]]:
|
|
|
"""
|
|
|
Выполняет предсказание на основе заданного промпта.
|
|
|
:param prompts: Словарь с данными для предсказания.
|
|
|
:return: Список кортежей (маски, оценки, логиты).
|
|
|
"""
|
|
|
if not self.image_set:
|
|
|
raise RuntimeError("Изображение не загружено. Сначала вызовите load_image.")
|
|
|
|
|
|
mode = prompts.get("mode")
|
|
|
results = []
|
|
|
|
|
|
if mode == "point":
|
|
|
point_coords = prompts.get("point_coords", [])
|
|
|
point_labels = prompts.get("point_labels", [])
|
|
|
processed_prompts = self._process_point_prompt(point_coords, point_labels)
|
|
|
elif mode == "box":
|
|
|
boxes = prompts.get("boxes", [])
|
|
|
processed_prompts = self._process_box_prompt(boxes)
|
|
|
elif mode == "both":
|
|
|
point_coords = prompts.get(
|
|
|
"point_coords", [None] * len(prompts.get("boxes", []))
|
|
|
)
|
|
|
point_labels = prompts.get(
|
|
|
"point_labels", [None] * len(prompts.get("boxes", []))
|
|
|
)
|
|
|
boxes = prompts.get("boxes", [])
|
|
|
processed_prompts = self._process_both_prompt(
|
|
|
point_coords, point_labels, boxes
|
|
|
)
|
|
|
else:
|
|
|
raise ValueError("Режим должен быть 'point', 'box' или 'both'.")
|
|
|
|
|
|
|
|
|
|
|
|
for prompt, multimask in processed_prompts:
|
|
|
try:
|
|
|
masks, scores, logits = self.segmenter.predict(
|
|
|
prompt, mode=mode, multimask=multimask
|
|
|
)
|
|
|
results.append([masks, scores, logits])
|
|
|
except Exception as e:
|
|
|
print(f"Ошибка при выполнении предсказания: {e}")
|
|
|
raise
|
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
controller = SegmenterController()
|
|
|
|
|
|
path = 'video-test/truck.jpg'
|
|
|
path = 'video-test/video.mp4'
|
|
|
video = cv2.VideoCapture(path)
|
|
|
ret, frame = video.read()
|
|
|
frame_cop = frame.copy()
|
|
|
video.release()
|
|
|
controller.load_image(frame)
|
|
|
import timeit
|
|
|
|
|
|
|
|
|
prompts = {
|
|
|
'mode': 'point',
|
|
|
'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]],
|
|
|
'point_labels': [1, 1, 1, 1],
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_segmentation():
|
|
|
prompts = {
|
|
|
'mode': 'point',
|
|
|
'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]],
|
|
|
'point_labels': [1, 0, 1, 1],
|
|
|
}
|
|
|
return controller.predict_from_prompts(prompts)
|
|
|
|
|
|
results = controller.predict_from_prompts(prompts)
|
|
|
|
|
|
execution_time_ms = timeit.timeit(run_segmentation, number=1) * 1000
|
|
|
print(f"Время выполнения: {execution_time_ms:.2f} мс")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(len(results))
|
|
|
res = [result[np.argmax(scores)] for result, scores, logits in results]
|
|
|
mask, unique_mask = merge_masks(res)
|
|
|
f = overlay_davis(frame, unique_mask)
|
|
|
mask = visualize_unique_mask(unique_mask)
|
|
|
f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
|
|
|
cv2.imshow('asd', mask)
|
|
|
cv2.imshow('asd', f)
|
|
|
cv2.waitKey(0)
|
|
|
cv2.destroyAllWindows()
|
|
|
|
|
|
|
|
|
controller.reset_image()
|
|
|
|