Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| from copy import deepcopy | |
| import cv2 | |
class Clicker(object):
    """Track a sequence of clicks and, given a ground-truth mask, simulate the next one.

    Clicks are kept in insertion order in ``clicks_list``; positive and negative
    clicks are counted separately.  When a ground-truth mask is supplied,
    ``not_clicked_map`` records which pixels already hold a click so that
    simulated clicks are never placed on the same pixel twice.
    """

    def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0):
        """Create a clicker.

        Args:
            gt_mask: Optional 2-D label array.  Pixels equal to 1 are foreground;
                pixels equal to ``ignore_label`` are excluded when simulating
                clicks.  Required by ``make_next_click``.
            init_clicks: Optional iterable of clicks added immediately, in order.
            ignore_label: Label value in ``gt_mask`` that marks ignored pixels.
            click_indx_offset: Added to the running click count to form each
                click's ``indx``.
        """
        self.click_indx_offset = click_indx_offset
        if gt_mask is not None:
            self.gt_mask = gt_mask == 1
            self.not_ignore_mask = gt_mask != ignore_label
        else:
            self.gt_mask = None
            # Define the attribute in both branches so it always exists.
            self.not_ignore_mask = None

        self.reset_clicks()
        if init_clicks is not None:
            for click in init_clicks:
                self.add_click(click)

    def make_next_click(self, pred_mask):
        """Simulate the click that best corrects ``pred_mask`` and append it.

        Requires the clicker to have been created with a ``gt_mask``.
        """
        assert self.gt_mask is not None
        click = self._get_next_click(pred_mask)
        self.add_click(click)

    def get_clicks(self, clicks_limit=None):
        """Return the first ``clicks_limit`` clicks (all clicks when ``None``).

        The list itself is new, but its elements are the stored click objects.
        """
        return self.clicks_list[:clicks_limit]

    def _get_next_click(self, pred_mask, padding=True):
        """Choose the click that best corrects ``pred_mask`` against ``gt_mask``.

        The false-negative region (foreground the prediction missed) and the
        false-positive region (predicted foreground that is background) are
        computed over non-ignored pixels.  A Euclidean distance transform gives
        each pixel's distance to the nearest pixel outside its region, and
        pixels that already hold a click are zeroed out.  The click goes to the
        pixel with the greatest distance: positive if that maximum is in the
        false-negative region, negative otherwise.

        Args:
            pred_mask: Predicted mask with the same shape as ``gt_mask``;
                truthy values are foreground.
            padding: If true, surround both regions with a one-pixel zero
                border before the distance transform, so the image border also
                counts as an edge and clicks are pushed away from it.

        Returns:
            A new ``Click``; its ``indx`` is assigned later by ``add_click``.
        """
        fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
        fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)

        if padding:
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')

        # maskSize=0 selects cv2.DIST_MASK_PRECISE (exact Euclidean distances).
        fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
        fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)

        if padding:
            # Strip the border added above so coordinates line up with gt_mask.
            fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
            fp_mask_dt = fp_mask_dt[1:-1, 1:-1]

        # Never propose a pixel that already holds a click.
        fn_mask_dt = fn_mask_dt * self.not_clicked_map
        fp_mask_dt = fp_mask_dt * self.not_clicked_map

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        # Ties go to a negative click.  NOTE(review): when both maxima are 0
        # (nothing left to correct), the chosen pixel is simply the first
        # zero-valued pixel of fp_mask_dt -- callers presumably stop before this.
        is_positive = fn_max_dist > fp_max_dist
        if is_positive:
            coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist)  # coords is [y, x]
        else:
            coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist)  # coords is [y, x]

        # Several pixels may share the maximum; np.where lists them in
        # row-major order, so this takes the top-most, then left-most one.
        return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))

    def add_click(self, click):
        """Append ``click`` and update the counters.

        ``click`` is modified in place: its ``indx`` is set to
        ``click_indx_offset`` plus the number of clicks already stored.
        """
        coords = click.coords
        click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
        if click.is_positive:
            self.num_pos_clicks += 1
        else:
            self.num_neg_clicks += 1

        self.clicks_list.append(click)
        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = False

    def _remove_last_click(self):
        """Remove the most recently added click and roll back the counters.

        Raises:
            IndexError: If there are no clicks to remove.
        """
        click = self.clicks_list.pop()
        coords = click.coords

        if click.is_positive:
            self.num_pos_clicks -= 1
        else:
            self.num_neg_clicks -= 1

        if self.gt_mask is not None:
            # Only free the pixel if no remaining click still sits on it;
            # otherwise a duplicated location would become clickable again.
            still_clicked = any(
                other.coords[0] == coords[0] and other.coords[1] == coords[1]
                for other in self.clicks_list
            )
            if not still_clicked:
                self.not_clicked_map[coords[0], coords[1]] = True

    def reset_clicks(self):
        """Remove all clicks, zero the counters and clear the clicked-pixel map."""
        if self.gt_mask is not None:
            self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)

        self.num_pos_clicks = 0
        self.num_neg_clicks = 0

        self.clicks_list = []

    def get_state(self):
        """Return a deep copy of the current click list."""
        return deepcopy(self.clicks_list)

    def set_state(self, state):
        """Replace all clicks with those in ``state`` (e.g. from ``get_state``).

        The click objects in ``state`` are added directly, so their ``indx``
        values are reassigned by ``add_click``.
        """
        self.reset_clicks()
        for click in state:
            self.add_click(click)

    def __len__(self):
        """Return the number of stored clicks."""
        return len(self.clicks_list)
class Click:
    """A single click: its polarity, its location and its position in a sequence."""

    def __init__(self, is_positive, coords, indx=None):
        """Create a click.

        Args:
            is_positive: Truthy for a positive click, falsy for a negative one.
            coords: Click location as a ``(y, x)`` pair.
            indx: Position of the click in its sequence; ``None`` until assigned.
        """
        self.is_positive = is_positive
        self.coords = coords
        self.indx = indx

    def coords_and_indx(self):
        """Return the click's coordinates followed by its index as one tuple."""
        return (*self.coords, self.indx)

    def copy(self, **kwargs):
        """Return a deep copy of this click with the given attributes overridden."""
        self_copy = deepcopy(self)
        for k, v in kwargs.items():
            setattr(self_copy, k, v)
        return self_copy

    def __repr__(self):
        return (
            f"{type(self).__name__}(is_positive={self.is_positive!r}, "
            f"coords={self.coords!r}, indx={self.indx!r})"
        )