import torch import numpy as np from tkinter import messagebox from isegm.inference import clicker from isegm.inference.predictors import get_predictor from isegm.utils.vis import draw_with_blend_and_clicks class InteractiveController: def __init__(self, net, device, predictor_params, update_image_callback, prob_thresh=0.5): self.net = net self.prob_thresh = prob_thresh self.clicker = clicker.Clicker() self.states = [] self.probs_history = [] self.object_count = 0 self._result_mask = None self._init_mask = None self.image = None self.predictor = None self.device = device self.update_image_callback = update_image_callback self.predictor_params = predictor_params self.reset_predictor() def set_image(self, image): self.image = image self._result_mask = np.zeros(image.shape[:2], dtype=np.uint16) self.object_count = 0 self.reset_last_object(update_image=False) self.update_image_callback(reset_canvas=True) def set_mask(self, mask): if self.image.shape[:2] != mask.shape[:2]: messagebox.showwarning("Warning", "A segmentation mask must have the same sizes as the current image!") return if len(self.probs_history) > 0: self.reset_last_object() self._init_mask = mask.astype(np.float32) self.probs_history.append((np.zeros_like(self._init_mask), self._init_mask)) self._init_mask = torch.tensor(self._init_mask, device=self.device).unsqueeze(0).unsqueeze(0) self.clicker.click_indx_offset = 1 def add_click(self, x, y, is_positive): self.states.append({ 'clicker': self.clicker.get_state(), 'predictor': self.predictor.get_states() }) click = clicker.Click(is_positive=is_positive, coords=(y, x)) self.clicker.add_click(click) pred = self.predictor.get_prediction(self.clicker, prev_mask=self._init_mask) if self._init_mask is not None and len(self.clicker) == 1: pred = self.predictor.get_prediction(self.clicker, prev_mask=self._init_mask) torch.cuda.empty_cache() if self.probs_history: self.probs_history.append((self.probs_history[-1][0], pred)) else: self.probs_history.append((np.zeros_like(pred), pred)) self.update_image_callback() def undo_click(self): if not self.states: return prev_state = self.states.pop() self.clicker.set_state(prev_state['clicker']) self.predictor.set_states(prev_state['predictor']) self.probs_history.pop() if not self.probs_history: self.reset_init_mask() self.update_image_callback() def partially_finish_object(self): object_prob = self.current_object_prob if object_prob is None: return self.probs_history.append((object_prob, np.zeros_like(object_prob))) self.states.append(self.states[-1]) self.clicker.reset_clicks() self.reset_predictor() self.reset_init_mask() self.update_image_callback() def finish_object(self): if self.current_object_prob is None: return self._result_mask = self.result_mask self.object_count += 1 self.reset_last_object() def reset_last_object(self, update_image=True): self.states = [] self.probs_history = [] self.clicker.reset_clicks() self.reset_predictor() self.reset_init_mask() if update_image: self.update_image_callback() def reset_predictor(self, predictor_params=None): if predictor_params is not None: self.predictor_params = predictor_params self.predictor = get_predictor(self.net, device=self.device, **self.predictor_params) if self.image is not None: self.predictor.set_input_image(self.image) def reset_init_mask(self): self._init_mask = None self.clicker.click_indx_offset = 0 @property def current_object_prob(self): if self.probs_history: current_prob_total, current_prob_additive = self.probs_history[-1] return np.maximum(current_prob_total, current_prob_additive) else: return None @property def is_incomplete_mask(self): return len(self.probs_history) > 0 @property def result_mask(self): result_mask = self._result_mask.copy() if self.probs_history: result_mask[self.current_object_prob > self.prob_thresh] = self.object_count + 1 return result_mask def get_visualization(self, alpha_blend, click_radius): if self.image is None: return None results_mask_for_vis = self.result_mask vis = draw_with_blend_and_clicks(self.image, mask=results_mask_for_vis, alpha=alpha_blend, clicks_list=self.clicker.clicks_list, radius=click_radius) if self.probs_history: total_mask = self.probs_history[-1][0] > self.prob_thresh results_mask_for_vis[np.logical_not(total_mask)] = 0 vis = draw_with_blend_and_clicks(vis, mask=results_mask_for_vis, alpha=alpha_blend) return vis