Spaces:
No application file
No application file
| import torch | |
| import numpy as np | |
| from tkinter import messagebox | |
| from isegm.inference import clicker | |
| from isegm.inference.predictors import get_predictor | |
| from isegm.utils.vis import draw_with_blend_and_clicks | |
class InteractiveController:
    """Click-driven interactive segmentation state for a single image.

    Tracks the clicks placed on the object currently being edited, a history of
    probability maps (used for undo and for partially finishing an object), and
    a label mask that accumulates finished objects. ``update_image_callback`` is
    invoked whenever the displayed state changes.
    """

    def __init__(self, net, device, predictor_params, update_image_callback, prob_thresh=0.5):
        """Create a controller and build its first predictor.

        Args:
            net: Model forwarded to ``get_predictor``.
            device: Device forwarded to ``get_predictor``; also used for the
                initial-mask tensor created in ``set_mask``.
            predictor_params: Keyword arguments forwarded to ``get_predictor``.
            update_image_callback: Called with no arguments after edits, and
                with ``reset_canvas=True`` after ``set_image``.
            prob_thresh: Probability above which a pixel is counted as part of
                the current object.
        """
        self.net = net
        self.prob_thresh = prob_thresh
        self.clicker = clicker.Clicker()
        # Undo stack: clicker/predictor snapshots, popped by undo_click().
        self.states = []
        # One (total, additive) probability pair per step of the current object;
        # the current object's probability is their element-wise maximum.
        self.probs_history = []
        # Number of finished objects; the object being edited gets label object_count + 1.
        self.object_count = 0
        self._result_mask = None
        self._init_mask = None

        self.image = None
        self.predictor = None
        self.device = device
        self.update_image_callback = update_image_callback
        self.predictor_params = predictor_params
        self.reset_predictor()

    def set_image(self, image):
        """Load a new image and discard all objects and clicks.

        Args:
            image: Array whose first two dimensions are (height, width).
        """
        self.image = image
        # Label map: 0 is background, finished objects are labelled 1, 2, ...
        self._result_mask = np.zeros(image.shape[:2], dtype=np.uint16)
        self.object_count = 0
        self.reset_last_object(update_image=False)
        self.update_image_callback(reset_canvas=True)

    def set_mask(self, mask):
        """Seed the current object with an externally supplied mask.

        The mask becomes the first entry of the probability history. It is then
        passed as ``prev_mask`` to the predictor on subsequent clicks.

        Args:
            mask: Array with the same height and width as the current image.
                Presumably a 2-D (H, W) array of values in [0, 1] — confirm
                against callers.
        """
        if self.image is None:
            # No image means no size to validate against; previously this
            # raised AttributeError on ``self.image.shape``.
            messagebox.showwarning("Warning", "An image must be loaded before setting a segmentation mask!")
            return
        if self.image.shape[:2] != mask.shape[:2]:
            messagebox.showwarning("Warning", "A segmentation mask must have the same sizes as the current image!")
            return

        if self.probs_history:
            self.reset_last_object()

        init_mask = mask.astype(np.float32)
        self.probs_history.append((np.zeros_like(init_mask), init_mask))
        # Two leading singleton dimensions are added for the predictor's prev_mask input.
        self._init_mask = torch.tensor(init_mask, device=self.device).unsqueeze(0).unsqueeze(0)
        # NOTE(review): the meaning of this offset lives in Clicker; presumably it
        # reserves an index for the externally supplied mask — confirm.
        self.clicker.click_indx_offset = 1

    def add_click(self, x, y, is_positive):
        """Add a click to the current object and re-run the predictor.

        Args:
            x: Horizontal click coordinate.
            y: Vertical click coordinate. ``clicker.Click`` receives the
                coordinates in ``(y, x)`` order.
            is_positive: Passed to ``clicker.Click``; True marks a positive click.
        """
        # Snapshot the state *before* the click so undo_click() can restore it.
        self.states.append({
            'clicker': self.clicker.get_state(),
            'predictor': self.predictor.get_states()
        })

        click = clicker.Click(is_positive=is_positive, coords=(y, x))
        self.clicker.add_click(click)
        pred = self.predictor.get_prediction(self.clicker, prev_mask=self._init_mask)
        # NOTE(review): this repeats the identical call above for the first click
        # on a seeded mask. The predictor keeps internal state (its states are
        # saved and restored for undo), so the second pass may not be redundant;
        # it is kept — confirm intent before removing.
        if self._init_mask is not None and len(self.clicker) == 1:
            pred = self.predictor.get_prediction(self.clicker, prev_mask=self._init_mask)

        torch.cuda.empty_cache()

        if self.probs_history:
            # Carry the frozen "total" part forward; the new prediction is the additive part.
            self.probs_history.append((self.probs_history[-1][0], pred))
        else:
            self.probs_history.append((np.zeros_like(pred), pred))

        self.update_image_callback()

    def undo_click(self):
        """Revert the most recent step recorded in ``states``; no-op if there is none."""
        if not self.states:
            return

        prev_state = self.states.pop()
        self.clicker.set_state(prev_state['clicker'])
        self.predictor.set_states(prev_state['predictor'])
        self.probs_history.pop()
        if not self.probs_history:
            # Nothing is left of this object, so drop any seeded mask too.
            self.reset_init_mask()
        self.update_image_callback()

    def partially_finish_object(self):
        """Freeze the current prediction and continue refining the same object.

        The current object probability becomes the "total" part of a new history
        entry. Clicks, predictor and seeded mask are then reset, so new clicks
        add to this object. This is a no-op when there is no current object.
        """
        object_prob = self.current_object_prob
        if object_prob is None:
            return

        self.probs_history.append((object_prob, np.zeros_like(object_prob)))
        # Record an undo entry for this step so undo_click() can pop it.
        if self.states:
            self.states.append(self.states[-1])
        else:
            # History seeded by set_mask() has no click and therefore no earlier
            # snapshot to duplicate. Take one now instead of raising IndexError.
            self.states.append({
                'clicker': self.clicker.get_state(),
                'predictor': self.predictor.get_states()
            })

        self.clicker.reset_clicks()
        self.reset_predictor()
        self.reset_init_mask()
        self.update_image_callback()

    def finish_object(self):
        """Commit the current object into the label mask and start a new object."""
        if self.current_object_prob is None:
            return

        self._result_mask = self.result_mask
        self.object_count += 1
        self.reset_last_object()

    def reset_last_object(self, update_image=True):
        """Discard the clicks, history and seeded mask of the current object.

        Args:
            update_image: Whether to call ``update_image_callback`` afterwards.
        """
        self.states = []
        self.probs_history = []
        self.clicker.reset_clicks()
        self.reset_predictor()
        self.reset_init_mask()
        if update_image:
            self.update_image_callback()

    def reset_predictor(self, predictor_params=None):
        """Rebuild the predictor, optionally with new parameters.

        Args:
            predictor_params: If given, replaces the stored keyword arguments
                for ``get_predictor``.
        """
        if predictor_params is not None:
            self.predictor_params = predictor_params
        self.predictor = get_predictor(self.net, device=self.device,
                                       **self.predictor_params)
        if self.image is not None:
            self.predictor.set_input_image(self.image)

    def reset_init_mask(self):
        """Forget the mask seeded by ``set_mask`` and reset the click index offset."""
        self._init_mask = None
        self.clicker.click_indx_offset = 0

    @property
    def current_object_prob(self):
        """Probability map of the object being edited, or None if there is none.

        It is the element-wise maximum of the latest (total, additive) history pair.
        """
        if not self.probs_history:
            return None
        current_prob_total, current_prob_additive = self.probs_history[-1]
        return np.maximum(current_prob_total, current_prob_additive)

    @property
    def is_incomplete_mask(self):
        """True while the current object has history that is not yet finished."""
        return len(self.probs_history) > 0

    @property
    def result_mask(self):
        """Return a copy of the label mask with the current object drawn in.

        Pixels of the current object above ``prob_thresh`` get label
        ``object_count + 1``. The stored mask itself is not modified.
        """
        result_mask = self._result_mask.copy()
        if self.probs_history:
            result_mask[self.current_object_prob > self.prob_thresh] = self.object_count + 1
        return result_mask

    def get_visualization(self, alpha_blend, click_radius):
        """Render the image with masks and clicks overlaid.

        Args:
            alpha_blend: Blend factor passed to ``draw_with_blend_and_clicks``.
            click_radius: Radius used when drawing clicks.

        Returns:
            The rendered visualization, or None when no image is loaded.
        """
        if self.image is None:
            return None

        # ``result_mask`` returns a fresh copy, so modifying it below is safe.
        results_mask_for_vis = self.result_mask
        vis = draw_with_blend_and_clicks(self.image, mask=results_mask_for_vis, alpha=alpha_blend,
                                         clicks_list=self.clicker.clicks_list, radius=click_radius)
        if self.probs_history:
            # Blend a second time over only the pixels of the frozen "total" part
            # (presumably to set it apart from the part still being refined).
            total_mask = self.probs_history[-1][0] > self.prob_thresh
            results_mask_for_vis[np.logical_not(total_mask)] = 0
            vis = draw_with_blend_and_clicks(vis, mask=results_mask_for_vis, alpha=alpha_blend)

        return vis