import gradio as gr
import cv2
import numpy as np
from PIL import Image, ImageDraw

from XMem2.inference.interact.interactive_utils import overlay_davis
from sam_controller import SegmenterController
from tracker import Tracker
from tracker_core_xmem2 import TrackerCore

# Upper bound on how many frames are decoded from an uploaded video.
MAX_FRAMES = 100

# Radius, in pixels, of the marker drawn for each clicked point.
POINT_RADIUS = 5


def _initial_video_state(fps=30, count_frames=0):
    """Return a fresh per-session video state containing every key the handlers read.

    Both the ``gr.State`` default and ``extract_all_frames`` build the state
    through this helper, so ``"mask"`` and the ``"annotation_*"`` keys are
    always present.
    """
    return {
        "fps": fps,
        "count_frames": count_frames,
        "mask": None,
        "annotation_masks": [],
        "annotation_images": [],
    }


def _initial_annotations_state():
    """Return an empty click-annotation state (no points selected)."""
    return {"frame_id": 0, "point": []}


def _reset_annotations():
    """Clear the selected points and the annotation info after a new video is loaded.

    Returns:
        ``(annotations_state, mask_info)`` for the corresponding Gradio outputs.
    """
    return _initial_annotations_state(), ''


# --- Frame extraction ---
def extract_all_frames(video_input, max_frames=MAX_FRAMES):
    """Decode up to ``max_frames`` RGB frames from the uploaded video.

    The first decoded frame is also passed to ``tracker.sam_controller.load_image``
    so that point prompts can later be segmented on it.

    Args:
        video_input: Video path as delivered by ``gr.Video``; ``None`` when
            the video is cleared.
        max_frames: Maximum number of frames to keep.

    Returns:
        ``(first_frame, frames, video_state, video_info)``. When no frame
        could be decoded, ``first_frame`` is ``None``, ``frames`` is empty,
        and ``video_state`` is reset to its defaults.
    """
    video_path = video_input
    if video_path is None:
        # The video component was cleared: reset everything that depends on it.
        return None, [], _initial_video_state(), ''

    frames = []
    fps = 0
    count_frames = 0
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        count_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        # Check the limit before reading, so no extra frame is decoded and dropped.
        while cap.isOpened() and len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert to RGB for display and segmentation.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except (OSError, TypeError, ValueError, KeyError, SyntaxError, cv2.error) as e:
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
    finally:
        if cap is not None:
            cap.release()

    if not frames:
        # Nothing usable was decoded: avoid indexing an empty list.
        return None, [], _initial_video_state(), 'Не удалось прочитать кадры из видео'

    tracker.sam_controller.load_image(frames[0])

    # Start from a complete state so a mask from a previous video is not reused.
    video_state = _initial_video_state(fps=fps, count_frames=count_frames)
    video_info = f'FPS: {video_state["fps"]} , Кадров: {video_state["count_frames"]}, Будет обработано: {len(frames)}'
    return frames[0], frames, video_state, video_info


# --- Manual point annotation (first frame) ---
def on_image_click(image, evt: gr.SelectData, annotations_state):
    """Record a clicked point and redraw every selected point on the frame.

    Args:
        image: Current value of the first-frame image component (a numpy array).
        evt: Gradio select event; ``evt.index`` holds the clicked ``(x, y)``.
        annotations_state: Session dict whose ``"point"`` list is appended to in place.

    Returns:
        ``(image_with_points, annotations_state, mask_info)``.
    """
    if image is None:
        return image, annotations_state, 'Сначала загрузите видео'

    x, y = evt.index[0], evt.index[1]
    annotations_state["point"].append([x, y])

    # Draw all points selected so far.
    img = Image.fromarray(image)
    draw = ImageDraw.Draw(img)
    r = POINT_RADIUS
    for x_p, y_p in annotations_state["point"]:
        draw.ellipse((x_p - r, y_p - r, x_p + r, y_p + r), fill="blue")

    mask_info = f'Выбрано объектов: {len(annotations_state["point"])}, Координаты: {annotations_state["point"]}'
    return img, annotations_state, mask_info


# --- Propagate the first-frame mask to all frames ---
def tracking(frames: list[np.ndarray], video_state: dict) -> tuple[dict, list, str]:
    """Track the first-frame mask through all extracted frames and overlay the results.

    Args:
        frames: RGB frames produced by ``extract_all_frames``.
        video_state: Session state; ``video_state["mask"]`` must be set by ``annotations``.

    Returns:
        ``(video_state, annotation_images, annotation_info)``. If there are no
        frames or no mask yet, the state is returned unchanged with a hint message.
    """
    mask = video_state.get("mask")
    if not frames or mask is None:
        return (
            video_state,
            video_state.get("annotation_images", []),
            'Сначала получите маску объекта',
        )

    try:
        # Assumed to yield one mask per frame; the masks are zipped with frames below.
        masks = tracker.tracking(frames, mask)
    finally:
        # Clear the tracker's memory even if tracking fails, so the next run starts clean.
        tracker.tracker.clear_memory()

    video_state["annotation_masks"] = masks
    video_state["annotation_images"] = [
        overlay_davis(frame, frame_mask) for frame, frame_mask in zip(frames, masks)
    ]
    annotation_info = f'Аннотированных кадров: {len(video_state["annotation_images"])}'
    return video_state, video_state["annotation_images"], annotation_info


# --- Mask from point prompts ---
def annotations(
    frame: np.ndarray, annotations_state: dict, video_state: dict, mask_info
) -> tuple:
    """Segment the first frame from the clicked points and store the mask in ``video_state``.

    Args:
        frame: Current first-frame image. It may already contain the blue point
            markers drawn by ``on_image_click``.
        annotations_state: Session dict with the clicked ``"point"`` coordinates.
        video_state: Session state; ``"mask"`` is overwritten with the new mask.
        mask_info: Current info text, returned unchanged on success.

    Returns:
        ``(image, video_state, mask_info)``. ``image`` is the mask overlay, or
        the input frame when no points were selected.
    """
    points = annotations_state["point"]
    if not points:
        return frame, video_state, 'Поставьте точки на объекты'

    # Every clicked point is treated as a positive (label 1) prompt.
    prompts = {
        'mode': 'point',
        'point_coords': points,
        'point_labels': [1] * len(points),
    }
    mask = tracker.select_object(prompts)
    # NOTE(review): load_image() is only called in extract_all_frames, yet
    # reset_image() runs after every selection. If reset_image discards the
    # loaded image, a second "Получить маску" click on the same video may fail — confirm.
    tracker.sam_controller.reset_image()
    image = overlay_davis(frame, mask)
    video_state["mask"] = mask
    return image, video_state, mask_info


# Models are built once at import time; the handlers above use the global `tracker`.
segmenter_controller = SegmenterController()
tracker_core = TrackerCore()
tracker = Tracker(segmenter_controller, tracker_core)


# --- Gradio interface ---
with gr.Blocks() as demo:
    # Per-session state
    frames = gr.State([])
    annotations_state = gr.State(_initial_annotations_state())
    video_state = gr.State(_initial_video_state())

    gr.Markdown("# Трекинг объектов на видео")

    with gr.Row():
        video_input = gr.Video(label="Загрузите видео")
        with gr.Column():
            first_frame = gr.Image(label="Кадр для выбора объектов", interactive=True)
            with gr.Row():
                annotations_btn = gr.Button("Получить маску")

    with gr.Row():
        video_info = gr.Textbox(label="Информация о видео")
        mask_info = gr.Textbox(label="Информация о разметке")

    with gr.Row():
        annotation_info = gr.Textbox(label="Информация о трекинге")
        with gr.Column():
            tracking_btn = gr.Button("Трекинг")

    with gr.Column():
        annotated_gallery = gr.Gallery(label="Все кадры с разметкой", columns=3)

    # Load the video, then drop points that belonged to the previous video.
    video_input.change(
        extract_all_frames,
        inputs=video_input,
        outputs=[first_frame, frames, video_state, video_info],
    ).then(
        _reset_annotations,
        outputs=[annotations_state, mask_info],
    )

    # Click handling
    first_frame.select(
        on_image_click,
        inputs=[first_frame, annotations_state],
        outputs=[first_frame, annotations_state, mask_info],
    )

    annotations_btn.click(
        annotations,
        inputs=[first_frame, annotations_state, video_state, mask_info],
        outputs=[first_frame, video_state, mask_info],
    )

    tracking_btn.click(
        tracking,
        inputs=[frames, video_state],
        outputs=[video_state, annotated_gallery, annotation_info],
    )

if __name__ == "__main__":
    demo.launch()