|
|
"""
|
|
|
Based on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN
|
|
|
(which is based on https://github.com/seoungwugoh/ivs-demo)
|
|
|
|
|
|
This version is much simplified.
|
|
|
In this repo, we don't have
|
|
|
- local control
|
|
|
- fusion module
|
|
|
- undo
|
|
|
- timers
|
|
|
|
|
|
but it uses XMem as the backbone and is more memory-friendly (for both CPU and GPU)
|
|
|
"""
|
|
|
|
|
|
import functools
|
|
|
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
import re
|
|
|
from time import perf_counter
|
|
|
import cv2
|
|
|
|
|
|
from inference.frame_selection.frame_selection import select_next_candidates
|
|
|
|
|
|
# Drop the Qt plugin path override before PyQt5 is imported below.
# NOTE(review): presumably this variable is exported by the cv2 import above and
# would make PyQt5 pick up the wrong Qt plugins - confirm.
# A default is passed so that a missing variable is not an import-time KeyError.
os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH", None)
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox,
|
|
|
QHBoxLayout, QLabel, QPushButton, QTextEdit, QFileDialog,
|
|
|
QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut,
|
|
|
QRadioButton, QTabWidget, QErrorMessage, QMessageBox, QLineEdit)
|
|
|
|
|
|
from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon, QRegExpValidator
|
|
|
from PyQt5.QtCore import Qt, QTimer, QThreadPool, QRegExp
|
|
|
|
|
|
from model.network import XMem
|
|
|
|
|
|
from inference.inference_core import InferenceCore
|
|
|
from .s2m_controller import S2MController
|
|
|
from .fbrs_controller import FBRSController
|
|
|
|
|
|
from .interactive_utils import *
|
|
|
from .interaction import *
|
|
|
from .resource_manager import ResourceManager
|
|
|
from .gui_utils import *
|
|
|
|
|
|
|
|
|
class App(QWidget):
|
|
|
def __init__(self, net: XMem,
            resource_manager: ResourceManager,
            s2m_ctrl:S2MController,
            fbrs_ctrl:FBRSController, config):
    """Build the main window: inference state, widgets, layouts, timers and shortcuts.

    Args:
        net: Network passed to ``InferenceCore`` together with ``config``.
        resource_manager: Provides the frames; ``len()`` of it is used as the
            number of frames and its ``h`` / ``w`` as the frame size.
        s2m_ctrl: Stored as ``self.s2m_controller``.
        fbrs_ctrl: Stored as ``self.fbrs_controller``.
        config: ``config['num_objects']`` is read here; the whole object is
            kept as ``self.config``.

    NOTE: statement order matters. Signals are connected before the widgets
    get their values (e.g. the brush slider and ``radio_free.toggle()``), so
    slots of this class can run while construction is still in progress.
    """
    super().__init__()

    # Set to True only at the very end of construction.
    self.initialized = False
    self.num_objects = config['num_objects']
    self.s2m_controller = s2m_ctrl
    self.fbrs_controller = fbrs_ctrl
    self.config = config
    self.processor = InferenceCore(net, config)
    # Object labels are 1..num_objects.
    self.processor.set_all_labels(list(range(1, self.num_objects+1)))
    self.res_man = resource_manager
    self.threadpool = QThreadPool()
    self.last_opened_directory = str(Path.home())

    self.num_frames = len(self.res_man)
    self.height, self.width = self.res_man.h, self.res_man.w

    # ---- Window ----
    self.setWindowTitle('XMem++ Demo')
    self.setGeometry(100, 100, self.width, self.height+100)
    self.setWindowIcon(QIcon('docs/icon.png'))

    # ---- Action buttons ----
    self.play_button = QPushButton('Play Video')
    self.play_button.setToolTip("Play/Pause the video")
    self.play_button.clicked.connect(self.on_play_video)
    self.commit_button = QPushButton('Commit')
    self.commit_button.setToolTip("Finish current interaction with the mask")
    self.commit_button.clicked.connect(self.on_commit)
    self.save_reference_button = QPushButton('Save reference')
    self.save_reference_button.setToolTip("Save current mask in the permanent memory.\nUsed by the model as a reference ground truth.")
    self.save_reference_button.clicked.connect(self.on_save_reference)
    self.compute_candidates_button = QPushButton('Compute Annotation candidates')
    self.compute_candidates_button.setToolTip("Get next <i>k</i> frames that you should annotate.")
    self.compute_candidates_button.clicked.connect(self.on_compute_candidates)

    # NOTE(review): bare `partial` is used for the three propagation buttons while the
    # number-key shortcuts below use `functools.partial`; presumably `partial` comes from
    # one of the star imports - confirm.
    self.full_run_button = QPushButton('FULL Propagate')
    self.full_run_button.setToolTip("Clear the temporary memory, scroll to beginning and predict new masks for <b>all</b> the frames.")
    self.full_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='full'))

    self.forward_run_button = QPushButton('Forward Propagate')
    self.forward_run_button.setToolTip("Predict new masks for all the frames <b>starting</b> with the current one.")
    self.forward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='forward'))
    self.forward_run_button.setMinimumWidth(200)

    self.backward_run_button = QPushButton('Backward Propagate')
    self.backward_run_button.setToolTip("Predict new masks for all the frames <b>before</b> with the current one.")
    self.backward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='backward'))
    self.backward_run_button.setMinimumWidth(200)

    self.reset_button = QPushButton('Delete Mask')
    self.reset_button.setToolTip("Delete the mask for the current frames. <b>Cannot be undone</b>!")
    self.reset_button.clicked.connect(self.on_reset_mask)

    # Space bar stops a running propagation.
    self.spacebar = QShortcut(QKeySequence(Qt.Key_Space), self)
    self.spacebar.activated.connect(self.pause_propagation)

    # ---- Frame counter ("current / last") ----
    self.lcd = QTextEdit()
    self.lcd.setReadOnly(True)
    self.lcd.setMaximumHeight(28)
    self.lcd.setFixedWidth(120)
    self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1))

    # ---- Timeline slider (one tick per frame) ----
    self.tl_slider = QSlider(Qt.Horizontal)
    self.tl_slider.valueChanged.connect(self.tl_slide)
    self.tl_slider.setMinimum(0)
    self.tl_slider.setMaximum(self.num_frames-1)
    self.tl_slider.setValue(0)
    self.tl_slider.setTickPosition(QSlider.TicksBelow)
    self.tl_slider.setTickInterval(1)

    # ---- Brush size label and slider ----
    self.brush_label = QLabel()
    self.brush_label.setAlignment(Qt.AlignCenter)
    self.brush_label.setMinimumWidth(100)

    # The slot is connected before the range/value are set, so brush_slide can
    # run before self.interaction exists (it tolerates that).
    self.brush_slider = QSlider(Qt.Horizontal)
    self.brush_slider.valueChanged.connect(self.brush_slide)
    self.brush_slider.setMinimum(1)
    self.brush_slider.setMaximum(100)
    self.brush_slider.setValue(3)
    self.brush_slider.setTickPosition(QSlider.TicksBelow)
    self.brush_slider.setTickInterval(2)
    self.brush_slider.setMinimumWidth(300)

    # ---- Overlay (visualization) mode selector ----
    self.combo = QComboBox(self)
    self.combo.addItem("davis")
    self.combo.addItem("fade")
    self.combo.addItem("light")
    self.combo.addItem("popup")
    self.combo.addItem("layered")
    self.combo.currentTextChanged.connect(self.set_viz_mode)

    self.save_visualization_checkbox = QCheckBox(self)
    self.save_visualization_checkbox.setChecked(True)
    self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle)
    self.save_visualization = True

    # ---- Interaction mode radio buttons ----
    self.curr_interaction = 'Free'
    self.interaction_group = QButtonGroup()
    self.radio_fbrs = QRadioButton('Click')
    self.radio_fbrs.setToolTip("Clicks in/out of the current mask. <b>Careful - will delete existing mask</b>!")
    self.radio_s2m = QRadioButton('Scribble')
    self.radio_s2m.setToolTip('Draw a line in/out of the current mask. Edits existing masks directly.')
    self.radio_free = QRadioButton('Free')
    self.radio_free.setToolTip('Free drawing')
    self.interaction_group.addButton(self.radio_fbrs)
    self.interaction_group.addButton(self.radio_s2m)
    self.interaction_group.addButton(self.radio_free)
    self.radio_fbrs.toggled.connect(self.interaction_radio_clicked)
    self.radio_s2m.toggled.connect(self.interaction_radio_clicked)
    self.radio_free.toggled.connect(self.interaction_radio_clicked)
    # Selecting 'Free' here runs interaction_radio_clicked during construction.
    self.radio_free.toggle()

    # ---- Main canvas ----
    self.main_canvas = QLabel()
    self.main_canvas.setSizePolicy(QSizePolicy.Expanding,
            QSizePolicy.Expanding)
    self.main_canvas.setAlignment(Qt.AlignCenter)
    self.main_canvas.setMinimumSize(100, 100)

    # Mouse handlers are installed by overwriting the label's event methods.
    self.main_canvas.mousePressEvent = self.on_mouse_press
    self.main_canvas.mouseMoveEvent = self.on_mouse_motion
    self.main_canvas.setMouseTracking(True)
    self.main_canvas.mouseReleaseEvent = self.on_mouse_release

    # ---- Minimap ----
    self.minimap = QLabel()
    self.minimap.setSizePolicy(QSizePolicy.Expanding,
            QSizePolicy.Expanding)
    self.minimap.setAlignment(Qt.AlignTop)
    self.minimap.setMinimumSize(100, 100)

    # ---- Minimap zoom buttons ----
    self.zoom_p_button = QPushButton('Zoom +')
    self.zoom_p_button.clicked.connect(self.on_zoom_plus)
    self.zoom_m_button = QPushButton('Zoom -')
    self.zoom_m_button.clicked.connect(self.on_zoom_minus)

    # ---- Memory controls and gauges ----
    self.clear_mem_button = QPushButton('Clear TEMP and LONG memory')
    self.clear_mem_button.setToolTip("Temporary and long-term memory can have features from the previous model run.<br>"
                                     "If you had errors in the predictions, they might influence new masks.<br>"
                                     "So for a new model run either clean the memory or just use <i>FULL propagate</i>.")
    self.clear_mem_button.clicked.connect(self.on_clear_memory)

    self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size')
    self.work_mem_gauge.setToolTip("Temporary and Permanent memory together.")
    self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size')
    self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)')
    self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)')

    self.update_memory_size()
    self.update_gpu_usage()

    # ---- Memory parameter boxes, initialised from the processor's current settings ----
    self.work_mem_min, self.work_mem_min_layout = create_parameter_box(1, 100, 'Min. working memory frames',
                                                    callback=self.on_work_min_change)
    self.work_mem_max, self.work_mem_max_layout = create_parameter_box(2, 100, 'Max. working memory frames',
                                                    callback=self.on_work_max_change)
    self.long_mem_max, self.long_mem_max_layout = create_parameter_box(1000, 100000,
                                                    'Max. long-term memory size', step=1000, callback=self.update_config)
    self.num_prototypes_box, self.num_prototypes_box_layout = create_parameter_box(32, 1280,
                                                    'Number of prototypes', step=32, callback=self.update_config)
    self.mem_every_box, self.mem_every_box_layout = create_parameter_box(1, 100, 'Memory frame every (r)',
                                                    callback=self.update_config)

    self.work_mem_min.setValue(self.processor.memory.min_mt_frames)
    self.work_mem_max.setValue(self.processor.memory.max_mt_frames)
    self.long_mem_max.setValue(self.processor.memory.max_long_elements)
    self.num_prototypes_box.setValue(self.processor.memory.num_prototypes)
    self.mem_every_box.setValue(self.processor.mem_every)

    # ---- Import buttons ----
    self.import_mask_button = QPushButton('Import mask')
    self.import_mask_button.setToolTip("Import an existing .png file with a mask for a current frame.\nReplace existing mask.")
    self.import_mask_button.clicked.connect(self.on_import_mask)

    self.import_all_masks_button = QPushButton('Import ALL masks')
    self.import_all_masks_button.setToolTip("Import a list of mask for some or all frames in the video.\nIf more than 10 are imported, the invididual confirmations will not be shown.")
    self.import_all_masks_button.clicked.connect(self.on_import_all_masks)
    self.import_layer_button = QPushButton('Import layer')
    self.import_layer_button.clicked.connect(self.on_import_layer)

    # ---- Read-only text console (fixed 100 px height) ----
    self.console = QPlainTextEdit()
    self.console.setReadOnly(True)
    self.console.setMinimumHeight(100)
    self.console.setMaximumHeight(100)

    # ---- Bottom navigation bar ----
    navi = QHBoxLayout()
    navi.addWidget(self.lcd)
    navi.addWidget(self.play_button)

    interact_subbox = QVBoxLayout()
    interact_topbox = QHBoxLayout()
    interact_botbox = QHBoxLayout()
    interact_topbox.setAlignment(Qt.AlignCenter)
    interact_topbox.addWidget(self.radio_s2m)
    interact_topbox.addWidget(self.radio_fbrs)
    interact_topbox.addWidget(self.radio_free)
    interact_topbox.addWidget(self.brush_label)
    interact_botbox.addWidget(self.brush_slider)
    interact_subbox.addLayout(interact_topbox)
    interact_subbox.addLayout(interact_botbox)
    navi.addLayout(interact_subbox)

    navi.addStretch(1)
    navi.addWidget(self.reset_button)

    navi.addStretch(1)
    navi.addWidget(QLabel('Overlay Mode'))
    navi.addWidget(self.combo)
    navi.addWidget(QLabel('Save overlay during propagation'))
    navi.addWidget(self.save_visualization_checkbox)
    navi.addStretch(1)

    navi.addWidget(self.save_reference_button)

    navi.addWidget(self.commit_button)
    navi.addWidget(self.full_run_button)
    navi.addWidget(self.forward_run_button)
    navi.addWidget(self.backward_run_button)

    # ---- Drawing area: colour picker + canvas (stretch 4) + side tabs (stretch 1) ----
    self.color_picker = ColorPicker(self.num_objects, davis_palette)
    self.color_picker.clicked.connect(self.hit_number_key)
    draw_area = QHBoxLayout()
    draw_area.addWidget(self.color_picker)
    draw_area.addWidget(self.main_canvas, 4)

    self.tabs = QTabWidget()
    self.tabs.setMinimumWidth(500)
    self.map_tab = QWidget()
    self.references_tab = QWidget()

    # NOTE(review): QScrollArea is not in the explicit PyQt5 import list of this file;
    # presumably it is provided by one of the star imports - confirm.
    references_scroll = QScrollArea()
    references_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
    references_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    references_scroll.setWidgetResizable(True)
    references_scroll.setWidget(self.references_tab)

    self.references_scroll = references_scroll

    # Tab 0 = minimap and stats, tab 1 = references and candidates.
    self.tabs.addTab(self.map_tab,"Minimap && Stats")
    self.tabs.addTab(self.references_scroll, "References && Candidates")

    tabs_layout = QVBoxLayout()

    # ---- Tab 0 contents ----
    minimap_area = QVBoxLayout(self.map_tab)
    minimap_area.setAlignment(Qt.AlignTop)
    mini_label = QLabel('Minimap')
    mini_label.setAlignment(Qt.AlignTop)
    minimap_area.addWidget(mini_label)

    minimap_ctrl = QHBoxLayout()
    minimap_ctrl.setAlignment(Qt.AlignTop)
    minimap_ctrl.addWidget(self.zoom_p_button)
    minimap_ctrl.addWidget(self.zoom_m_button)
    minimap_area.addLayout(minimap_ctrl)
    minimap_area.addWidget(self.minimap)

    minimap_area.addLayout(self.work_mem_gauge_layout)
    minimap_area.addLayout(self.long_mem_gauge_layout)
    minimap_area.addLayout(self.gpu_mem_gauge_layout)
    minimap_area.addLayout(self.torch_mem_gauge_layout)
    minimap_area.addWidget(self.clear_mem_button)
    minimap_area.addLayout(self.work_mem_min_layout)
    minimap_area.addLayout(self.work_mem_max_layout)
    minimap_area.addLayout(self.long_mem_max_layout)
    minimap_area.addLayout(self.num_prototypes_box_layout)
    minimap_area.addLayout(self.mem_every_box_layout)

    import_area = QHBoxLayout()
    import_area.setAlignment(Qt.AlignTop)
    import_area.addWidget(self.import_mask_button)
    import_area.addWidget(self.import_all_masks_button)
    import_area.addWidget(self.import_layer_button)
    minimap_area.addLayout(import_area)

    # ---- Tab 1 contents: saved references and annotation candidates ----
    chosen_figures_area = QVBoxLayout(self.references_tab)
    chosen_figures_area.addWidget(QLabel("SAVED REFERENCES IN PERMANENT MEMORY"))
    self.references_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail, delete_image=self.on_remove_reference, name='Reference frames')
    chosen_figures_area.addWidget(self.references_collection)

    self.candidates_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail, name='Candidate frames')
    chosen_figures_area.addWidget(QLabel("ANNOTATION CANDIDATES"))
    chosen_figures_area.addWidget(self.candidates_collection)

    tabs_layout.addWidget(self.tabs)
    tabs_layout.addWidget(self.console)
    draw_area.addLayout(tabs_layout, 1)

    # ---- Candidate selection parameters ----
    candidates_area = QVBoxLayout()
    self.candidates_min_mask_size_edit = QLineEdit()
    self.candidates_min_mask_size_edit.setToolTip("Minimal size a mask should have to be considered, % of the total image size."
                                                  "\nIf it's smaller than the value specified, the frame will be ignored."
                                                  "\nUsed to filter out \"junk\" frames or frames with very heavy occlusions.")
    # Regular expression for a decimal number between 0 and 100.
    float_validator = QRegExpValidator(QRegExp(r"^(100(\.0+)?|[1-9]?\d(\.\d+)?|0(\.\d+)?)$"))
    self.candidates_min_mask_size_edit.setValidator(float_validator)
    self.candidates_min_mask_size_edit.setText("0.25")
    self.candidates_k_slider = NamedSlider("k", 1, 20, 1, default=5)
    self.candidates_k_slider.setToolTip("How many annotation candidates to select.")
    self.candidates_alpha_slider = NamedSlider("α", 0, 100, 1, default=50, multiplier=0.01, min_text='Frames', max_text='Masks')
    self.candidates_alpha_slider.setToolTip("Target importance."
                                            "<br>If <b>0</b> the candidates will be the same regardless of which object is being segmented."
                                            "<br>If <b>1</b> the only part of the image considered will be the one occupied by the mask.")
    candidates_area.addWidget(QLabel("Min mask size, % of the total image size, 0-100"))
    candidates_area.addWidget(self.candidates_min_mask_size_edit)
    candidates_area.addWidget(QLabel("Candidates calculation hyperparameters"))
    candidates_area.addWidget(self.candidates_k_slider)
    candidates_area.addWidget(self.candidates_alpha_slider)
    candidates_area.addWidget(self.compute_candidates_button)
    tabs_layout.addLayout(candidates_area)

    # ---- Top-level layout: drawing area, timeline, navigation bar ----
    layout = QVBoxLayout()
    layout.addLayout(draw_area)
    layout.addWidget(self.tl_slider)
    layout.addLayout(navi)
    self.setLayout(layout)

    # ---- Timers ----
    # This timer is only created and configured here; nothing is connected to it in this method.
    self.timer = QTimer()
    self.timer.setSingleShot(False)

    # Repeating 2000 ms timer that calls on_gpu_timer.
    self.gpu_timer = QTimer()
    self.gpu_timer.setSingleShot(False)
    self.gpu_timer.timeout.connect(self.on_gpu_timer)
    self.gpu_timer.setInterval(2000)
    self.gpu_timer.start()

    # ---- Current frame state ----
    self.curr_frame_dirty = False
    self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
    self.current_image_torch = None
    self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8)
    self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).cuda()

    # ---- Visualization state ----
    self.viz_mode = 'davis'
    self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
    self.vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
    self.brush_vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
    self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
    # Index of the frame currently shown (spelling kept: used throughout the class).
    self.cursur = 0
    self.on_showing = None
    self.reference_ids = set()
    self.candidates_ids = []

    # Side length in pixels of the square patch shown in the minimap.
    self.zoom_pixels = 150

    # ---- Interaction state ----
    self.interaction = None
    self.pressed = False
    self.right_click = False
    self.current_object = 1
    self.last_ex = self.last_ey = 0

    self.propagating = False

    # ---- Keyboard shortcuts: digit keys call hit_number_key, arrows step frames ----
    for i in range(1, self.num_objects+1):
        QShortcut(QKeySequence(str(i)), self).activated.connect(functools.partial(self.hit_number_key, i))

    QShortcut(QKeySequence(Qt.Key_Left), self).activated.connect(self.on_prev_frame)
    QShortcut(QKeySequence(Qt.Key_Right), self).activated.connect(self.on_next_frame)

    self.interacted_prob = None
    self.overlay_layer = None
    self.overlay_layer_torch = None

    self.vis_target_objects = [1]

    self._try_load_layer('./docs/ECCV-logo.png')

    # ---- First render; the window is shown before existing references are restored ----
    self.load_current_image_mask()
    self.show_current_frame()
    self.show()
    self.style_new_reference()
    self.load_existing_references()

    self.console_push_text('Initialized.')
    self.initialized = True
|
|
|
|
|
|
def resizeEvent(self, event):
    """Redraw the current frame when the window is resized.

    ``event`` is not inspected; the frame is simply re-rendered.
    """
    self.show_current_frame()
|
|
|
|
|
|
def console_push_text(self, text):
    """Append ``text`` followed by a newline at the end of the console widget."""
    console = self.console
    # Move to the end first so the text is not inserted at the old cursor position.
    console.moveCursor(QTextCursor.End)
    console.insertPlainText(text + '\n')
|
|
|
|
|
|
def interaction_radio_clicked(self, event):
    """Slot for the interaction-mode radio buttons.

    Stores the previous mode in ``self.last_interaction``, switches
    ``self.curr_interaction`` to the mode of the checked button and adjusts
    the brush slider and commit button. ``event`` is not used.
    """
    self.last_interaction = self.curr_interaction

    if self.radio_s2m.isChecked():
        picked = 'Scribble'
    elif self.radio_fbrs.isChecked():
        picked = 'Click'
    elif self.radio_free.isChecked():
        picked = 'Free'
    else:
        # No button checked: the current mode is left untouched.
        picked = None

    if picked in ('Scribble', 'Click'):
        # These two modes use a fixed brush size, so the slider is locked.
        self.curr_interaction = picked
        self.brush_size = 3
        self.brush_slider.setDisabled(True)
    elif picked == 'Free':
        self.brush_slider.setDisabled(False)
        self.brush_slide()
        self.curr_interaction = 'Free'

    # Only the scribble mode has the commit button enabled.
    self.commit_button.setEnabled(self.curr_interaction == 'Scribble')
|
|
|
|
|
|
def load_current_image_mask(self, no_mask=False):
    """Load the image (and, unless ``no_mask``, the mask) of frame ``self.cursur``.

    The cached torch image is always invalidated. When the mask is loaded,
    a missing mask zeroes ``self.current_mask`` in place, an existing one is
    copied, and the cached probabilities are invalidated as well.
    """
    self.current_image = self.res_man.get_image(self.cursur)
    self.current_image_torch = None

    if no_mask:
        return

    stored_mask = self.res_man.get_mask(self.cursur)
    if stored_mask is not None:
        # Copy so later edits do not touch the object returned by the resource manager.
        self.current_mask = stored_mask.copy()
    else:
        self.current_mask.fill(0)
    self.current_prob = None
|
|
|
|
|
|
def load_current_torch_image_mask(self, no_mask=False):
    """Lazily create the torch versions of the current image and mask.

    The image tensors are built only when ``self.current_image_torch`` is
    ``None``; the one-hot probabilities only when ``self.current_prob`` is
    ``None`` and ``no_mask`` is false.
    """
    if self.current_image_torch is None:
        tensors = image_to_torch(self.current_image)
        self.current_image_torch, self.current_image_torch_no_norm = tensors

    if self.current_prob is not None or no_mask:
        return

    # num_objects + 1 classes are requested for the one-hot encoding.
    one_hot = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1)
    self.current_prob = one_hot.cuda()
|
|
|
|
|
|
def compose_current_im(self):
    """Render the current image and mask into ``self.viz`` with the active overlay mode."""
    self.viz = get_visualization(
        self.viz_mode,
        self.current_image,
        self.current_mask,
        self.overlay_layer,
        self.vis_target_objects,
    )
|
|
|
|
|
|
def update_interact_vis(self):
    """Blend the stroke and brush overlays over ``self.viz`` and show the result.

    The blended uint8 image is stored in ``self.viz_with_stroke`` and drawn on
    the main canvas, scaled to the canvas with the aspect ratio kept. The
    canvas size and the unscaled image size are remembered for later use.
    """
    img_h, img_w, _ = self.viz.shape
    row_bytes = 3 * img_w  # three bytes per RGB888 pixel

    # Alpha-blend the stroke layer first, then the brush layer on top of it.
    blended = self.viz*(1-self.vis_alpha) + self.vis_map*self.vis_alpha
    blended = blended*(1-self.brush_vis_alpha) + self.brush_vis_map*self.brush_vis_alpha
    self.viz_with_stroke = blended.astype(np.uint8)

    frame = QImage(self.viz_with_stroke.data, img_w, img_h, row_bytes, QImage.Format_RGB888)
    scaled_frame = frame.scaled(self.main_canvas.size(), Qt.KeepAspectRatio, Qt.FastTransformation)
    self.main_canvas.setPixmap(QPixmap(scaled_frame))

    self.main_canvas_size = self.main_canvas.size()
    self.image_size = frame.size()
|
|
|
|
|
|
def update_minimap(self):
    """Show a square crop of ``self.viz_with_stroke`` around the last cursor position.

    The crop is ``zoom_pixels`` wide (rounded down to an even number) and is
    scaled to the minimap label with the aspect ratio kept.
    """
    half = self.zoom_pixels//2
    # Clamp the centre to [half, size - half] so the window does not run off the image edge.
    centre_x = int(round(max(half, min(self.width-half, self.last_ex))))
    centre_y = int(round(max(half, min(self.height-half, self.last_ey))))

    crop = self.viz_with_stroke[centre_y-half:centre_y+half, centre_x-half:centre_x+half, :].astype(np.uint8)

    crop_h, crop_w, _ = crop.shape
    row_bytes = 3 * crop_w
    crop_image = QImage(crop.data, crop_w, crop_h, row_bytes, QImage.Format_RGB888)
    scaled_crop = crop_image.scaled(self.minimap.size(), Qt.KeepAspectRatio, Qt.FastTransformation)
    self.minimap.setPixmap(QPixmap(scaled_crop))
|
|
|
|
|
|
def update_current_image_fast(self):
    """Render the frame from the torch image/probabilities and draw it on the canvas.

    The result is stored in ``self.viz``; when ``self.save_visualization`` is
    set it is also passed to the resource manager for saving.
    """
    self.viz = get_visualization_torch(
        self.viz_mode,
        self.current_image_torch_no_norm,
        self.current_prob,
        self.overlay_layer_torch,
        self.vis_target_objects,
    )
    if self.save_visualization:
        self.res_man.save_visualization(self.cursur, self.viz)

    img_h, img_w, _ = self.viz.shape
    row_bytes = 3 * img_w

    frame = QImage(self.viz.data, img_w, img_h, row_bytes, QImage.Format_RGB888)
    scaled_frame = frame.scaled(self.main_canvas.size(), Qt.KeepAspectRatio, Qt.FastTransformation)
    self.main_canvas.setPixmap(QPixmap(scaled_frame))
|
|
|
|
|
|
def load_current_image_thumbnail(self, *args, size=128):
    """Return a thumbnail pixmap of the current frame with its mask overlay.

    Positional arguments are accepted and ignored. ``size`` is the length in
    pixels given to the shorter side of the thumbnail.
    """
    rendered = get_visualization(
        self.viz_mode,
        self.current_image,
        self.current_mask,
        self.overlay_layer,
        self.vis_target_objects,
    )

    rows, cols, _ = rendered.shape
    row_bytes = 3 * cols
    frame = QImage(rendered.data, cols, rows, row_bytes, QImage.Format_RGB888)
    canvas_pixmap = QPixmap(frame.scaled(self.main_canvas.size(),
                                         Qt.KeepAspectRatio, Qt.FastTransformation))

    dims = canvas_pixmap.size()
    pix_h = dims.height()
    pix_w = dims.width()

    # Scale along the shorter side.
    if pix_h < pix_w:
        return canvas_pixmap.scaledToHeight(size)
    return canvas_pixmap.scaledToWidth(size)
|
|
|
|
|
|
def show_current_frame(self, fast=False):
    """Redraw the current frame and sync the frame counter, slider and reference button.

    With ``fast`` the torch-based renderer is used; otherwise the frame is
    composed, blended with the interaction overlays and the minimap updated.
    """
    if not fast:
        self.compose_current_im()
        self.update_interact_vis()
        self.update_minimap()
    else:
        self.update_current_image_fast()

    self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1))
    self.tl_slider.setValue(self.cursur)

    # The save button is styled differently for frames that are already references.
    if self.cursur not in self.reference_ids:
        self.style_new_reference()
    else:
        self.style_editing_reference()
|
|
|
|
|
|
def style_editing_reference(self):
    """Label and colour the reference button for a frame that is already a reference."""
    button = self.save_reference_button
    button.setText("Update reference")
    button.setStyleSheet('QPushButton {background-color: #E4A11B; font-weight: bold; }')
|
|
|
|
|
|
def style_new_reference(self):
    """Label and colour the reference button for a frame that is not a reference yet."""
    button = self.save_reference_button
    button.setText("Save reference")
    button.setStyleSheet('QPushButton {background-color: #14A44D; font-weight: bold;}')
|
|
|
|
|
|
def load_existing_references(self):
    """Save every frame listed in ``res_man.references`` as a reference, then return to frame 0."""
    for frame_idx in self.res_man.references:
        # Saving works on the current frame, so scroll there first.
        self.scroll_to(frame_idx)
        self.on_save_reference()
    self.scroll_to(0)
|
|
|
|
|
|
def pixel_pos_to_image_pos(self, x, y):
    """Convert a position on the canvas widget to coordinates in the unscaled image.

    Uses the image and canvas sizes remembered by the last canvas update.
    The result is not clamped and may lie outside the image.
    """
    img_h, img_w = self.image_size.height(), self.image_size.width()
    canvas_h, canvas_w = self.main_canvas_size.height(), self.main_canvas_size.width()

    # The smaller of the two ratios is taken as the scale factor in use.
    scale = min(canvas_h/img_h, canvas_w/img_w)

    # Canvas size expressed in image pixels; the surplus is split evenly on both sides.
    margin_x = (canvas_w/scale - img_w)/2
    margin_y = (canvas_h/scale - img_h)/2

    return x/scale - margin_x, y/scale - margin_y
|
|
|
|
|
|
def is_pos_out_of_bound(self, x, y):
    """Return whether the canvas position ``(x, y)`` maps outside the image."""
    img_x, img_y = self.pixel_pos_to_image_pos(x, y)
    return img_x < 0 or img_y < 0 or img_x > self.width-1 or img_y > self.height-1
|
|
|
|
|
|
def get_scaled_pos(self, x, y):
    """Map a canvas position to image coordinates, clamped to the image bounds."""
    img_x, img_y = self.pixel_pos_to_image_pos(x, y)

    clamped_x = max(0, min(self.width-1, img_x))
    clamped_y = max(0, min(self.height-1, img_y))
    return clamped_x, clamped_y
|
|
|
|
|
|
def clear_visualization(self):
    """Zero the stroke overlay and its alpha map in place."""
    for layer in (self.vis_map, self.vis_alpha):
        layer.fill(0)
|
|
|
|
|
|
def reset_this_interaction(self):
    """Finish and discard the interaction in progress and clear its overlay."""
    self.complete_interaction()
    self.clear_visualization()
    self.interaction = None

    controller = self.fbrs_controller
    if controller is None:
        return
    controller.unanchor()
|
|
|
|
|
|
def set_viz_mode(self):
    """Take the overlay mode selected in the combo box and redraw the frame."""
    selected_mode = self.combo.currentText()
    self.viz_mode = selected_mode
    self.show_current_frame()
|
|
|
|
|
|
def save_current_mask(self):
    """Hand the current mask to the resource manager under the current frame index."""
    frame_idx = self.cursur
    self.res_man.save_mask(frame_idx, self.current_mask)
|
|
|
|
|
|
def tl_slide(self):
    """Slot for the timeline slider: switch to the frame the slider points at.

    Does nothing while a propagation is running. Otherwise a modified mask is
    saved first, the interaction in progress is reset, and the new frame is
    loaded and shown.
    """
    if self.propagating:
        return

    if self.curr_frame_dirty:
        self.save_current_mask()
    self.curr_frame_dirty = False

    self.reset_this_interaction()
    self.cursur = self.tl_slider.value()
    self.load_current_image_mask()
    self.show_current_frame()
|
|
|
|
|
|
def scroll_to(self, idx):
    """Move the timeline slider to frame ``idx`` and refresh the view.

    ``idx`` must lie within the slider's range (checked with an assertion).
    """
    slider = self.tl_slider
    assert slider.minimum() <= idx <= slider.maximum()
    slider.setValue(idx)
    # Called explicitly as well, so the frame is refreshed in any case.
    self.tl_slide()
|
|
|
|
|
|
def brush_slide(self):
    """Slot for the brush slider: store the new brush size and update its label.

    If a free-drawing interaction is in progress, its brush size is updated too.
    """
    self.brush_size = self.brush_slider.value()
    self.brush_label.setText('Brush size: %d' % self.brush_size)

    # This slot can run while __init__ is still configuring the slider, i.e.
    # before self.interaction has been assigned, so a missing attribute is
    # treated like "no interaction" instead of catching AttributeError broadly
    # (which would also hide errors raised inside set_size).
    interaction = getattr(self, 'interaction', None)
    if interaction is not None and isinstance(interaction, FreeInteraction):
        interaction.set_size(self.brush_size)
|
|
|
|
|
|
def confirm_ready_for_propagation(self):
    """Return True if at least one reference is saved; otherwise show an error and return False."""
    has_reference = len(self.reference_ids) > 0
    if not has_reference:
        dialog = QErrorMessage(self)
        dialog.setWindowModality(Qt.WindowModality.WindowModal)
        dialog.showMessage("Save at least 1 reference!")
    return has_reference
|
|
|
|
|
|
def general_propagation_callback(self, propagation_type: str):
    """Common entry point of the three propagation buttons.

    Args:
        propagation_type: ``'full'``, ``'forward'`` or ``'backward'``. Any other
            value only switches to the first tab and starts nothing.
    """
    if not self.confirm_ready_for_propagation():
        return

    # Show the first tab while propagating.
    self.tabs.setCurrentIndex(0)

    handlers = {
        'full': self.on_full_propagation,
        'forward': self.on_forward_propagation,
        'backward': self.on_backward_propagation,
    }
    handler = handlers.get(propagation_type)
    if handler is not None:
        handler()
|
|
|
|
|
|
def on_full_propagation(self):
    """Clear the memory, go to the first frame and propagate forward from there."""
    steps = (self.on_clear_memory, lambda: self.scroll_to(0), self.on_forward_propagation)
    for step in steps:
        step()
|
|
|
|
|
|
def on_forward_propagation(self):
    """Start a forward propagation, or request a stop if one is already running."""
    if self.propagating:
        # Clearing the flag makes the running propagation loop exit.
        self.propagating = False
        return

    self.propagate_fn = self.on_next_frame
    for other_button in (self.full_run_button, self.backward_run_button):
        other_button.setEnabled(False)
    self.forward_run_button.setText('Pause Propagation')
    self.on_propagation()
|
|
|
|
|
|
def on_backward_propagation(self):
    """Start a backward propagation, or request a stop if one is already running."""
    if self.propagating:
        # Clearing the flag makes the running propagation loop exit.
        self.propagating = False
        return

    self.propagate_fn = self.on_prev_frame
    for other_button in (self.full_run_button, self.forward_run_button):
        other_button.setEnabled(False)
    self.backward_run_button.setText('Pause Propagation')
    self.on_propagation()
|
|
|
|
|
|
def on_pause(self):
    """Restore the idle state of the propagation controls and log that propagation stopped."""
    self.propagating = False

    idle_buttons = (
        self.full_run_button,
        self.forward_run_button,
        self.backward_run_button,
        self.clear_mem_button,
    )
    for button in idle_buttons:
        button.setEnabled(True)

    # Undo the 'Pause Propagation' caption set when a run was started.
    self.forward_run_button.setText('Forward Propagate')
    self.backward_run_button.setText('Backward Propagate')
    self.console_push_text('Propagation stopped.')
|
|
|
|
|
|
def on_propagation(self):
    """Run the propagation loop, starting at the current frame.

    The current frame is processed first. Then ``self.propagate_fn`` is called
    repeatedly (presumably it moves ``self.cursur`` to the neighbouring frame -
    it is assigned by the forward/backward handlers) and each frame is run
    through ``self.processor.step``. The loop ends when ``self.propagating`` is
    cleared or the first/last frame has been processed. Qt events are
    processed once per frame so the window stays responsive.
    """
    # Process the starting frame.
    self.load_current_torch_image_mask()
    self.show_current_frame(fast=True)

    self.console_push_text('Propagation started.')
    # For a reference frame the existing mask is passed to the processor and kept;
    # index 0 of current_prob is dropped (presumably the background channel - confirm).
    is_mask = self.cursur in self.reference_ids
    msk = self.current_prob[1:] if self.cursur in self.reference_ids else None
    current_prob, key, shrinkage, selection = self.processor.step(self.current_image_torch, msk, return_key_and_stuff=True)
    if not is_mask:
        self.current_prob = current_prob
    self.res_man.add_key_and_stuff_with_mask(self.cursur, key, shrinkage, selection, self.current_prob[1:])

    self.current_mask = torch_prob_to_numpy_mask(self.current_prob)

    # Drop any interaction that was in progress.
    self.interacted_prob = None
    self.reset_this_interaction()

    self.propagating = True
    self.clear_mem_button.setEnabled(False)

    # Runs until paused (self.propagating cleared from an event handler) or a boundary frame is reached.
    while self.propagating:
        self.propagate_fn()

        # Only the image is reloaded; with no_mask=True the previous current_prob is kept.
        self.load_current_image_mask(no_mask=True)
        self.load_current_torch_image_mask(no_mask=True)
        is_mask = self.cursur in self.reference_ids
        msk = self.current_prob[1:] if self.cursur in self.reference_ids else None
        current_prob, key, shrinkage, selection = self.processor.step(self.current_image_torch, msk, return_key_and_stuff=True)
        # NOTE(review): unlike the first frame above, this call happens BEFORE current_prob is
        # replaced by the new prediction, so the mask passed here is the one carried over
        # from the previous iteration - confirm this ordering is intended.
        self.res_man.add_key_and_stuff_with_mask(self.cursur, key, shrinkage, selection, self.current_prob[1:])

        if not is_mask:
            self.current_prob = current_prob
        self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
        self.save_current_mask()

        self.show_current_frame(fast=True)

        self.update_memory_size()
        # Let Qt handle pending events (this is how a pause request gets through).
        QApplication.processEvents()

        # Stop at either end of the video.
        if self.cursur == 0 or self.cursur == self.num_frames-1:
            break

    # Masks were saved inside the loop, so the current frame is not dirty.
    self.propagating = False
    self.curr_frame_dirty = False
    self.on_pause()
    self.tl_slide()
    QApplication.processEvents()
|
|
|
|
|
|
def pause_propagation(self):
    """Request a stop of the running propagation by clearing ``self.propagating``."""
    self.propagating = False
|
|
|
|
|
|
def on_commit(self):
    """Finish the interaction in progress, then update the interacted mask."""
    for step in (self.complete_interaction, self.update_interacted_mask):
        step()
|
|
|
|
|
|
def confirm_ready_for_candidates_selection(self):
    """Return True if the resource manager reports all masks present; else show an error and return False."""
    if not self.res_man.all_masks_present():
        dialog = QErrorMessage(self)
        dialog.setWindowModality(Qt.WindowModality.WindowModal)
        dialog.showMessage("Run propagation on all frames first!")
        return False
    return True
|
|
|
|
|
|
def on_compute_candidates(self):
    """Start the search for the next candidate frames.

    Nothing happens unless ``confirm_ready_for_candidates_selection``
    passes. Otherwise a ``Worker`` wrapping ``select_next_candidates`` is
    handed to ``self.threadpool`` and a progress dialog (range ``0..k``) is
    opened. The worker's ``progress`` signal drives the dialog and its
    ``result`` signal replaces the contents of the candidates collection.
    """
    if not self.confirm_ready_for_candidates_selection():
        return

    def _show_new_candidates(new_ids):
        print(new_ids)
        # Remove whatever was displayed by a previous run first.
        for stale_id in self.candidates_ids:
            self.candidates_collection.remove_image(stale_id)
        self.candidates_ids = new_ids

        # Every candidate frame is scrolled to before it is added; the
        # original position is restored once all of them are in.
        restore_to = self.cursur
        for frame_id in self.candidates_ids:
            self.scroll_to(frame_id)
            self.candidates_collection.add_image(frame_id)
        self.scroll_to(restore_to)
        self.tabs.setCurrentIndex(1)

    def _report_progress(step):
        candidate_progress.setValue(step)

    k = self.candidates_k_slider.value()
    alpha = self.candidates_alpha_slider.value()

    dialog_flags = Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)
    candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, dialog_flags)

    worker = Worker(
        select_next_candidates,
        self.res_man.keys,
        self.res_man.shrinkages,
        self.res_man.selections,
        self.res_man.small_masks,
        k,
        self.reference_ids,
        print_progress=False,
        min_mask_presence_percent=float(self.candidates_min_mask_size_edit.text()),
        alpha=alpha,
    )
    worker.signals.result.connect(_show_new_candidates)
    worker.signals.progress.connect(_report_progress)

    self.threadpool.start(worker)

    candidate_progress.open()
|
|
|
|
|
|
def on_save_reference(self):
    """Store the current frame and its mask as a permanent-memory reference.

    An interaction still in progress is committed first. The frame and the
    one-hot mask (first channel dropped, presumably background) are passed
    to ``processor.put_to_permanent_memory``, and the time that call took is
    printed to the console. When the call reports that an existing reference
    was updated, the old entry is removed from the reference bookkeeping
    before the frame is added again. A frame that was listed as a candidate
    stops being one. Finally the frame is redrawn and tab index 1 is shown.
    """
    if self.interaction is not None:
        self.on_commit()

    image_tensor, _ = image_to_torch(self.current_image)
    one_hot = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda()
    object_masks = one_hot[1:]

    started = perf_counter()
    replaced_existing = self.processor.put_to_permanent_memory(image_tensor, object_masks, self.cursur)
    finished = perf_counter()

    self.console_push_text(f"Saving took {(finished-started)*1000:.2f} ms.")

    if replaced_existing:
        # The frame already was a reference: drop the stale entry so the
        # add below does not duplicate it.
        self.reference_ids.remove(self.cursur)
        self.references_collection.remove_image(self.cursur)

    self.reference_ids.add(self.cursur)
    self.references_collection.add_image(self.cursur)
    self.res_man.add_reference(self.cursur)

    if self.cursur in self.candidates_ids:
        self.candidates_ids.remove(self.cursur)
        self.candidates_collection.remove_image(self.cursur)

    self.show_current_frame()
    self.tabs.setCurrentIndex(1)
|
|
|
|
|
|
def on_remove_reference(self, img_idx):
    """Stop treating frame ``img_idx`` as a reference.

    The frame is removed from the processor's permanent memory, from
    ``self.reference_ids`` and from the resource manager, in that order,
    and the current frame is redrawn afterwards.
    """
    self.processor.remove_from_permanent_memory(img_idx)
    self.reference_ids.remove(img_idx)
    self.res_man.remove_reference(img_idx)

    self.show_current_frame()
|
|
|
|
|
|
def on_prev_frame(self):
    """Step the cursor one frame back (never below 0) and sync the slider."""
    previous = self.cursur - 1
    self.cursur = previous if previous > 0 else 0
    self.tl_slider.setValue(self.cursur)
|
|
|
|
|
|
def on_next_frame(self):
    """Step the cursor one frame forward (never past the last frame) and sync the slider."""
    last_frame = self.num_frames - 1
    following = self.cursur + 1
    self.cursur = following if following <= last_frame else last_frame
    self.tl_slider.setValue(self.cursur)
|
|
|
|
|
|
def on_play_video_timer(self):
    """Advance playback by one frame, wrapping to frame 0 after the last one.

    The slider is set to the new cursor position afterwards.
    """
    following = self.cursur + 1
    # Past the final frame playback restarts from the beginning.
    self.cursur = 0 if following > self.num_frames-1 else following
    self.tl_slider.setValue(self.cursur)
|
|
|
|
|
|
def on_play_video(self):
    """Toggle timer-driven playback and update the play button label.

    A running timer is stopped and the button reads 'Play Video' again; a
    stopped timer is started with an interval of ``1000 // 30`` ms (about 30
    ticks per second) and the button reads 'Stop Video'.
    """
    if self.timer.isActive():
        self.timer.stop()
        self.play_button.setText('Play Video')
    else:
        # QTimer.start() takes an integer number of milliseconds. The former
        # ``1000 / 30`` produced a float, which is rejected with a TypeError
        # on Python >= 3.10 where implicit float -> int conversion of
        # arguments was removed. Integer division yields the same 33 ms the
        # truncating conversion used to give.
        self.timer.start(1000 // 30)
        self.play_button.setText('Stop Video')
|
|
|
|
|
|
def on_reset_mask(self):
    """Clear the mask of the current frame.

    Zeroes ``current_mask`` in place and, when present, ``current_prob`` as
    well. The frame is flagged dirty, the emptied mask is saved, the
    interaction state is reset and the frame is redrawn.
    """
    self.current_mask.fill(0)
    prob = self.current_prob
    if prob is not None:
        prob.fill_(0)

    self.curr_frame_dirty = True
    self.save_current_mask()
    self.reset_this_interaction()
    self.show_current_frame()
|
|
|
|
|
|
def on_zoom_plus(self):
    """Zoom the minimap in by shrinking ``zoom_pixels`` by 25, floored at 50."""
    self.zoom_pixels = max(50, self.zoom_pixels - 25)
    self.update_minimap()
|
|
|
|
|
|
def on_zoom_minus(self):
    """Zoom the minimap out by growing ``zoom_pixels`` by 25, capped at 300."""
    self.zoom_pixels = min(self.zoom_pixels + 25, 300)
    self.update_minimap()
|
|
|
|
|
|
def set_navi_enable(self, boolean):
    """Enable or disable the navigation widgets in one go.

    Args:
        boolean: Passed unchanged to ``setEnabled`` of the zoom buttons, the
            run button, the timeline slider, the play button and the LCD.
    """
    navigation_widgets = (
        'zoom_p_button',
        'zoom_m_button',
        'run_button',
        'tl_slider',
        'play_button',
        'lcd',
    )
    for widget_name in navigation_widgets:
        getattr(self, widget_name).setEnabled(boolean)
|
|
|
|
|
|
def hit_number_key(self, number):
    """Make ``number`` the active object.

    Does nothing when that object is already active. Otherwise the f-BRS
    controller (if there is one) is un-anchored, the change is logged to the
    console, the brush preview is redrawn at the last known cursor position
    and the colour picker is moved to the new object.
    """
    if number == self.current_object:
        return

    self.current_object = number
    controller = self.fbrs_controller
    if controller is not None:
        controller.unanchor()
    self.console_push_text(f'Current object changed to {number}.')

    # Repaint the brush preview in the colour of the new object.
    self.clear_brush()
    self.vis_brush(self.last_ex, self.last_ey)
    self.update_interact_vis()
    self.show_current_frame()
    self.color_picker.select(self.current_object)
|
|
|
|
|
|
def clear_brush(self):
    """Erase the brush preview by zero-filling both brush overlay buffers."""
    for buffer_name in ('brush_vis_map', 'brush_vis_alpha'):
        getattr(self, buffer_name).fill(0)
|
|
|
|
|
|
def vis_brush(self, ex, ey):
    """Draw the brush footprint at ``(ex, ey)`` into the brush overlay buffers.

    A filled circle (``thickness=-1``) of radius ``brush_size // 2 + 1`` is
    drawn with the current object's ``color_map`` entry into
    ``brush_vis_map`` and with value 0.5 into ``brush_vis_alpha``. The
    coordinates are rounded to integers first.
    """
    center = (int(round(ex)), int(round(ey)))
    radius = self.brush_size//2+1

    self.brush_vis_map = cv2.circle(
        self.brush_vis_map, center, radius,
        color_map[self.current_object], thickness=-1)
    self.brush_vis_alpha = cv2.circle(
        self.brush_vis_alpha, center, radius, 0.5, thickness=-1)
|
|
|
|
|
|
def on_mouse_press(self, event):
    """Handle a mouse press on the canvas.

    * A press outside the image is ignored.
    * A middle-button press toggles the object under the cursor in
      ``self.vis_target_objects`` and redraws; no interaction is started.
    * Any other press marks the mouse as pressed (remembering whether it was
      the right button) and, when the active interaction cannot be continued
      for the selected mode, completes it and creates a new
      Scribble/Free/Click interaction. The event is then forwarded to
      ``on_mouse_motion`` so the press position becomes the first point.
    """
    if self.is_pos_out_of_bound(event.x(), event.y()):
        return

    if event.button() == Qt.MidButton:
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        clicked_object = self.current_mask[int(ey),int(ex)]
        if clicked_object in self.vis_target_objects:
            self.vis_target_objects.remove(clicked_object)
        else:
            self.vis_target_objects.append(clicked_object)
        self.console_push_text(f'Target objects for visualization changed to {self.vis_target_objects}')
        self.show_current_frame()
        return

    self.right_click = (event.button() == Qt.RightButton)
    self.pressed = True

    frame_size = (self.height, self.width)

    self.load_current_torch_image_mask()
    image = self.current_image_torch

    previous = self.interaction
    mode = self.curr_interaction
    created = None
    if mode == 'Scribble':
        if previous is None or type(previous) is not ScribbleInteraction:
            self.complete_interaction()
            created = ScribbleInteraction(
                image, torch.from_numpy(self.current_mask).float().cuda(),
                frame_size, self.s2m_controller, self.num_objects)
    elif mode == 'Free':
        if previous is None or type(previous) is not FreeInteraction:
            self.complete_interaction()
            created = FreeInteraction(image, self.current_mask, frame_size, self.num_objects)
            created.set_size(self.brush_size)
    elif mode == 'Click':
        # A click session is tied to one object; a different target object
        # also forces a fresh interaction.
        if (previous is None or type(previous) is not ClickInteraction
                or previous.tar_obj != self.current_object):
            self.complete_interaction()
            self.fbrs_controller.unanchor()
            created = ClickInteraction(
                image, self.current_prob, frame_size,
                self.fbrs_controller, self.current_object)

    if created is not None:
        self.interaction = created

    # Treat the press position as the first motion sample.
    self.on_mouse_motion(event)
|
|
|
|
|
|
def on_mouse_motion(self, event):
    """Track the cursor and, while pressed, extend the current stroke.

    The scaled cursor position is remembered in ``last_ex``/``last_ey`` and
    the brush preview is redrawn there. While the mouse is pressed in
    'Scribble' or 'Free' mode the point is pushed to the active interaction
    with object id 0 for a right-button stroke and the current object
    otherwise. The interaction overlay and the minimap are refreshed in
    every case.
    """
    ex, ey = self.get_scaled_pos(event.x(), event.y())
    self.last_ex, self.last_ey = ex, ey

    self.clear_brush()
    self.vis_brush(ex, ey)

    if self.pressed and self.curr_interaction in ('Scribble', 'Free'):
        if self.right_click:
            stroke_object = 0
        else:
            stroke_object = self.current_object
        overlays = (self.vis_map, self.vis_alpha)
        self.vis_map, self.vis_alpha = self.interaction.push_point(ex, ey, stroke_object, overlays)

    self.update_interact_vis()
    self.update_minimap()
|
|
|
|
|
|
def update_interacted_mask(self):
    """Adopt the result of the last interaction as the frame's mask.

    ``interacted_prob`` becomes ``current_prob`` and is converted to
    ``current_mask`` through ``torch_prob_to_numpy_mask``. The frame is then
    redrawn, the mask is saved and the frame is flagged as not dirty.
    """
    prob = self.interacted_prob
    self.current_prob = prob
    self.current_mask = torch_prob_to_numpy_mask(prob)

    self.show_current_frame()
    self.save_current_mask()
    self.curr_frame_dirty = False
|
|
|
|
|
|
def complete_interaction(self):
    """Close the active interaction, if there is one.

    Clears the interaction visualization and resets ``self.interaction`` to
    None; without an active interaction this is a no-op.
    """
    if self.interaction is None:
        return

    self.clear_visualization()
    self.interaction = None
|
|
|
|
|
|
def on_mouse_release(self, event):
    """Finish the stroke or click in progress and apply its prediction.

    Ignored when no press is being tracked (for example because
    ``on_mouse_press`` rejected the press). Otherwise the interaction is
    logged, the final point is delivered ('Scribble'/'Free' end their path,
    'Click' pushes the click with the right-button flag), the interaction's
    prediction becomes the current mask, the GPU gauges are refreshed and
    the pressed/right-click flags are cleared.
    """
    if not self.pressed:
        return

    ex, ey = self.get_scaled_pos(event.x(), event.y())

    self.console_push_text('%s interaction at frame %d.' % (self.curr_interaction, self.cursur))
    active = self.interaction
    mode = self.curr_interaction

    if mode in ('Scribble', 'Free'):
        # The release position is the last sample of the path.
        self.on_mouse_motion(event)
        active.end_path()
        if mode == 'Free':
            self.clear_visualization()
    elif mode == 'Click':
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        overlays = (self.vis_map, self.vis_alpha)
        self.vis_map, self.vis_alpha = active.push_point(ex, ey, self.right_click, overlays)

    self.interacted_prob = active.predict()
    self.update_interacted_mask()
    self.update_gpu_usage()

    self.pressed = False
    self.right_click = False
|
|
|
|
|
|
def wheelEvent(self, event):
    """Resize the brush with the mouse wheel and redraw its preview.

    In 'Free' mode the brush slider is moved by the vertical wheel delta
    floor-divided by 30. In every mode the brush preview is redrawn at the
    cursor and the interaction overlay and minimap are refreshed.
    """
    ex, ey = self.get_scaled_pos(event.x(), event.y())

    if self.curr_interaction == 'Free':
        slider = self.brush_slider
        slider.setValue(slider.value() + event.angleDelta().y()//30)

    self.clear_brush()
    self.vis_brush(ex, ey)
    self.update_interact_vis()
    self.update_minimap()
|
|
|
|
|
|
def update_gpu_usage(self):
    """Refresh the two GPU memory gauges.

    ``gpu_mem_gauge`` shows device-wide usage (total minus free, as reported
    by ``torch.cuda.mem_get_info``) in GB and as a percentage.
    ``torch_mem_gauge`` shows ``torch.cuda.max_memory_allocated()`` in MB
    against the same device total.
    """
    gib = 2**30
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    global_free = free_bytes / gib
    global_total = total_bytes / gib
    global_used = global_total - global_free

    self.gpu_mem_gauge.setFormat(f'{global_used:.01f} GB / {global_total:.01f} GB')
    self.gpu_mem_gauge.setValue(round(global_used/global_total*100))

    used_by_torch = torch.cuda.max_memory_allocated() / (2**20)
    self.torch_mem_gauge.setFormat(f'{used_by_torch:.0f} MB / {global_total:.01f} GB')
    # used_by_torch is in MB while global_total is in GB, hence the /1024.
    self.torch_mem_gauge.setValue(round(used_by_torch/global_total*100/1024))
|
|
|
|
|
|
def on_gpu_timer(self):
    """Timer callback: refresh the GPU memory gauges."""
    self.update_gpu_usage()
|
|
|
|
|
|
def update_memory_size(self):
    """Refresh the working-memory and long-term-memory gauges.

    Sizes are read from ``self.processor.memory``; the permanent working
    memory counts towards both the current and the maximum number of working
    elements. If any attribute along the way is missing (AttributeError),
    both gauges are set to 'Unknown' with value 0.
    """
    try:
        memory = self.processor.memory
        work_capacity = memory.max_work_elements + memory.permanent_work_mem.size
        long_capacity = memory.max_long_elements

        work_used = memory.temporary_work_mem.size + memory.permanent_work_mem.size
        long_used = memory.long_mem.size

        self.work_mem_gauge.setFormat(f'{work_used} / {work_capacity}')
        self.work_mem_gauge.setValue(round(work_used/work_capacity*100))

        self.long_mem_gauge.setFormat(f'{long_used} / {long_capacity}')
        self.long_mem_gauge.setValue(round(long_used/long_capacity*100))

    except AttributeError:
        for gauge in (self.work_mem_gauge, self.long_mem_gauge):
            gauge.setFormat('Unknown')
        for gauge in (self.work_mem_gauge, self.long_mem_gauge):
            gauge.setValue(0)
|
|
|
|
|
|
def on_work_min_change(self):
    """Keep the working-memory minimum strictly below the maximum.

    Ignored until the GUI is initialized. The minimum is clamped to at most
    ``max - 1`` and the configuration is pushed afterwards.
    """
    if not self.initialized:
        return

    upper_bound = self.work_mem_max.value()-1
    self.work_mem_min.setValue(min(self.work_mem_min.value(), upper_bound))
    self.update_config()
|
|
|
|
|
|
def on_work_max_change(self):
    """Keep the working-memory maximum strictly above the minimum.

    Ignored until the GUI is initialized. The maximum is raised to at least
    ``min + 1`` and the configuration is pushed afterwards.
    """
    if not self.initialized:
        return

    current_max = self.work_mem_max.value()
    lower_bound = self.work_mem_min.value()+1
    self.work_mem_max.setValue(max(current_max, lower_bound))
    self.update_config()
|
|
|
|
|
|
def update_config(self):
    """Copy the memory settings from the GUI into the config and apply them.

    Ignored until the GUI is initialized. Each config key is filled from the
    ``value()`` of its widget, then the whole config is handed to
    ``processor.update_config``.
    """
    if not self.initialized:
        return

    key_to_widget = (
        ('min_mid_term_frames', 'work_mem_min'),
        ('max_mid_term_frames', 'work_mem_max'),
        ('max_long_term_elements', 'long_mem_max'),
        ('num_prototypes', 'num_prototypes_box'),
        ('mem_every', 'mem_every_box'),
    )
    for config_key, widget_name in key_to_widget:
        self.config[config_key] = getattr(self, widget_name).value()

    self.processor.update_config(self.config)
|
|
|
|
|
|
def on_clear_memory(self):
    """Clear the processor's memory while keeping the permanent part.

    Afterwards the CUDA cache is emptied and both the GPU and the memory-size
    gauges are refreshed.
    """
    self.processor.clear_memory(keep_permanent=True)
    torch.cuda.empty_cache()

    # Show the effect of the clean-up right away.
    self.update_gpu_usage()
    self.update_memory_size()
|
|
|
|
|
|
def _open_file(self, prompt):
    """Let the user pick a file and return its path.

    Args:
        prompt: Passed to ``QFileDialog.getOpenFileName`` as its second
            argument.

    Returns:
        The file name reported by the dialog; callers treat an empty name as
        "nothing selected". When a file was chosen,
        ``self.last_opened_directory`` is updated to its parent directory so
        the next dialog starts there.
    """
    dialog_options = QFileDialog.Options()
    chosen, _ = QFileDialog.getOpenFileName(
        self, prompt, self.last_opened_directory, "Image files (*)", options=dialog_options)

    if chosen:
        self.last_opened_directory = str(Path(chosen).parent)

    return chosen
|
|
|
|
|
|
def on_import_all_masks(self):
    """Import one mask per file from a directory chosen by the user.

    The first run of digits in each file name is taken as the frame index.
    The import is refused with an error dialog when a file name contains no
    digits or when the frame indices, taken in sorted-path order, are not
    ascending.

    With more than 10 masks a single bulk confirmation is asked; on "Yes"
    all masks are imported without per-file prompts while a progress dialog
    is shown, and the cursor is reset to 0 afterwards. With 10 masks or
    fewer each frame is scrolled to and imported through ``on_import_mask``
    with its own confirmation.

    Cancelling the directory dialog aborts the import.
    """
    dir_path = QFileDialog.getExistingDirectory()
    if not dir_path:
        # Dialog cancelled. Without this guard Path('') would stand for the
        # current working directory and its contents would be scanned.
        return
    self.last_opened_directory = dir_path

    pattern = re.compile(r'([0-9]+)')
    files_paths = sorted(Path(dir_path).iterdir())
    frame_ids = []
    incorrect_files = []
    for p_f in files_paths:
        match = pattern.search(p_f.name)
        if match:
            frame_ids.append(int(match.group()))
        else:
            incorrect_files.append(p_f.name)

    if incorrect_files or frame_ids != sorted(frame_ids):
        qm = QErrorMessage(self)
        qm.setWindowModality(Qt.WindowModality.WindowModal)
        broken_file_names = '\n'.join(incorrect_files)
        qm.showMessage(f"Files with incorrect names: {broken_file_names}")
        return

    if len(frame_ids) <= 10:
        # Few enough masks to confirm one by one.
        for i, p_f in zip(frame_ids, files_paths):
            self.scroll_to(i)
            self.on_import_mask(str(p_f), ask_confirmation=True)
        return

    qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "")
    question = "There are more than 10 masks to import, so confirmations for each individual one would not be asked. Are you willing to continue?"
    ret = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No)
    if ret != qm.Yes:
        return

    progress_dialog = QProgressDialog("Importing masks", None, 0, len(frame_ids), self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint))
    progress_dialog.open()

    a = perf_counter()
    for done, (i, p_f) in enumerate(zip(frame_ids, files_paths), start=1):
        # The cursor is moved directly (no scrolling) so that
        # on_import_mask writes the mask for frame i.
        self.cursur = i
        self.on_import_mask(str(p_f), ask_confirmation=False)
        # Progress counts processed files; the dialog's range is
        # 0..len(frame_ids), which a raw frame index may not fit.
        progress_dialog.setValue(done)
        QApplication.processEvents()
    b = perf_counter()

    self.console_push_text(f"Importing {len(frame_ids)} masks took {b-a:.2f} seconds ({len(frame_ids)/(b-a):.2f} FPS)")
    self.cursur = 0
|
|
|
|
|
|
def on_import_mask(self, mask_file_path=None, ask_confirmation=True):
    """Load an external mask image for the current frame.

    Args:
        mask_file_path: Path of the mask to load. When falsy the user is
            asked to pick a file through ``_open_file``; an empty selection
            aborts.
        ask_confirmation: When True the user has to confirm the replacement,
            and after loading the interaction state is reset, the frame is
            redrawn and the frame is saved as a reference. When False the
            mask is stored and saved without prompt, redraw or reference.

    A mask whose shape is not ``(self.height, self.width)`` or whose largest
    label exceeds ``self.num_objects`` is rejected with a console message.
    """
    file_name = mask_file_path if mask_file_path else self._open_file('Mask')
    if len(file_name) == 0:
        return

    mask = self.res_man.read_external_image(file_name, size=(self.height, self.width), force_mask=True)

    shape_ok = (
        len(mask.shape) == 2
        and mask.shape[-1] == self.width
        and mask.shape[-2] == self.height
    )
    labels_ok = mask.max() <= self.num_objects

    if not shape_ok:
        self.console_push_text(f'Expected ({self.height}, {self.width}). Got {mask.shape} instead.')
        return
    if not labels_ok:
        self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.')
        return

    if ask_confirmation:
        qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "")
        question = f"Replace mask for current frame {self.cursur} with {Path(file_name).name}?"
        answer = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No)
        if answer != qm.Yes:
            return

    self.console_push_text(f'Mask file {file_name} loaded.')
    # Cached tensors belong to the replaced mask and are dropped.
    self.current_image_torch = None
    self.current_prob = None
    self.current_mask = mask

    if ask_confirmation:
        # UI refresh only for the interactive (confirmed) path.
        self.curr_frame_dirty = False
        self.reset_this_interaction()
        self.show_current_frame()

    self.save_current_mask()

    if ask_confirmation:
        self.on_save_reference()
|
|
|
|
|
|
|
|
|
def on_import_layer(self):
    """Ask the user for a layer image and try to load it as the overlay.

    Nothing is loaded when the file dialog returns an empty name.
    """
    file_name = self._open_file('Layer')
    if len(file_name) != 0:
        self._try_load_layer(file_name)
|
|
|
|
|
|
def _try_load_layer(self, file_name):
    """Load ``file_name`` as the overlay layer, reporting problems on the console.

    A 3-channel image gets a fourth channel filled with 255 appended. The
    result must have shape ``(self.height, self.width, 4)``; otherwise the
    mismatch is printed and nothing changes. On success the layer is kept
    both as read (``overlay_layer``) and as a float CUDA tensor divided by
    255 (``overlay_layer_torch``), and the frame is redrawn. A missing file
    is reported instead of raising.
    """
    try:
        layer = self.res_man.read_external_image(file_name, size=(self.height, self.width))

        if layer.shape[-1] == 3:
            # No fourth channel: append one filled with 255.
            extra_channel = np.ones_like(layer[:,:,0:1])*255
            layer = np.concatenate([layer, extra_channel], axis=-1)

        shape_ok = (
            len(layer.shape) == 3
            and layer.shape[-1] == 4
            and layer.shape[-2] == self.width
            and layer.shape[-3] == self.height
        )

        if shape_ok:
            self.console_push_text(f'Layer file {file_name} loaded.')
            self.overlay_layer = layer
            self.overlay_layer_torch = torch.from_numpy(layer).float().cuda()/255
            self.show_current_frame()
        else:
            self.console_push_text(f'Expected ({self.height}, {self.width}, 4). Got {layer.shape}.')
    except FileNotFoundError:
        self.console_push_text(f'{file_name} not found.')
|
|
|
|
|
|
def on_save_visualization_toggle(self):
    """Mirror the checkbox state into ``self.save_visualization``."""
    self.save_visualization = self.save_visualization_checkbox.isChecked()
|
|
|
|