Spaces:
Runtime error
Runtime error
| from copy import deepcopy | |
| import numpy as np | |
| from albumentations import ReplayCompose | |
| from isegm.data.transforms import remove_image_only_transforms | |
| from isegm.utils.misc import get_labels_with_sizes | |
class DSample:
    """An image paired with integer-encoded masks and a registry of objects.

    Masks are kept as a stack of layers indexed ``[:, :, layer]``; every object
    is identified by a ``(layer_index, label)`` mapping into that stack and may
    take part in a hierarchy through its ``"parent"`` / ``"children"`` entries.
    A snapshot of the un-augmented data is kept so that
    :meth:`reset_augmentation` can restore it.
    """

    def __init__(
        self,
        image,
        encoded_masks,
        objects=None,
        objects_ids=None,
        ignore_ids=None,
        sample_id=None,
    ):
        """Create a sample.

        Args:
            image: The image array; stored as-is.
            encoded_masks: Integer-encoded masks, either 2-D or already stacked
                into layers along the last axis; a 2-D array gets a trailing
                layer axis of size 1.
            objects: Pre-built object registry
                (``{id: {"parent", "mapping", "children", ...}}``), deep-copied.
                Used only when ``objects_ids`` is ``None``. If both are ``None``
                the sample starts with no objects.
            objects_ids: ``(layer, label)`` tuples, or plain labels. Plain
                labels are only allowed for single-layer masks and refer to
                layer 0. Each entry becomes a root object with id ``0..n-1``.
            ignore_ids: Regions to mark as ignored, as ``(layer, label)`` tuples
                or plain layer-0 labels. Only honoured together with
                ``objects_ids``.
            sample_id: Optional identifier stored on the sample.
        """
        self.image = image
        self.sample_id = sample_id

        if len(encoded_masks.shape) == 2:
            encoded_masks = encoded_masks[:, :, np.newaxis]
        self._encoded_masks = encoded_masks
        self._ignored_regions = []

        if objects_ids is not None:
            if not objects_ids or not isinstance(objects_ids[0], tuple):
                # Plain labels carry no layer index, so they are only
                # unambiguous when there is a single layer.
                assert encoded_masks.shape[2] == 1
                objects_ids = [(0, obj_id) for obj_id in objects_ids]

            self._objects = {
                indx: {"parent": None, "mapping": obj_mapping, "children": []}
                for indx, obj_mapping in enumerate(objects_ids)
            }

            if ignore_ids:
                if isinstance(ignore_ids[0], tuple):
                    self._ignored_regions = ignore_ids
                else:
                    self._ignored_regions = [(0, region_id) for region_id in ignore_ids]
        else:
            # Fall back to an empty registry so len(), remove_small_objects(),
            # etc. work instead of failing on None.
            self._objects = deepcopy(objects) if objects is not None else {}

        self._augmented = False
        self._soft_mask_aug = None
        # Snapshot restored by reset_augmentation(). The registry is copied
        # because augmentation adds "area" keys and removes objects from it.
        self._original_data = self.image, self._encoded_masks, deepcopy(self._objects)

    def augment(self, augmentator):
        """Apply ``augmentator`` to the original (un-augmented) data.

        ``augmentator`` is called as ``augmentator(image=..., mask=...)`` and
        must return a mapping with ``"image"`` and ``"mask"`` entries and,
        optionally, a ``"replay"`` record. A replay is passed through
        ``remove_image_only_transforms`` and restored into a callable used by
        :meth:`get_soft_object_mask`; this path does not support ignored
        regions. Afterwards object areas are recomputed and objects with zero
        area are removed.
        """
        self.reset_augmentation()

        aug_output = augmentator(image=self.image, mask=self._encoded_masks)
        self.image = aug_output["image"]
        self._encoded_masks = aug_output["mask"]

        aug_replay = aug_output.get("replay", None)
        if aug_replay:
            assert len(self._ignored_regions) == 0
            mask_replay = remove_image_only_transforms(aug_replay)
            self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay)

        self._compute_objects_areas()
        self.remove_small_objects(min_area=1)

        self._augmented = True

    def reset_augmentation(self):
        """Restore image, masks and object registry from the snapshot.

        Does nothing if the sample is not currently augmented.
        """
        if not self._augmented:
            return
        orig_image, orig_masks, orig_objects = self._original_data
        self.image = orig_image
        self._encoded_masks = orig_masks
        # Copy again so later mutations never reach the snapshot.
        self._objects = deepcopy(orig_objects)
        self._augmented = False
        self._soft_mask_aug = None

    def remove_small_objects(self, min_area):
        """Remove every object whose ``"area"`` is below ``min_area``.

        Areas are computed first if the first object has no ``"area"`` entry.
        Children of a removed object are re-attached to its parent.
        """
        if self._objects and "area" not in next(iter(self._objects.values())):
            self._compute_objects_areas()

        # Iterate over a snapshot: _remove_object deletes from the dict.
        for obj_id, obj_info in list(self._objects.items()):
            if obj_info["area"] < min_area:
                self._remove_object(obj_id)

    def get_object_mask(self, obj_id):
        """Return an int32 mask: 1 on the object, 0 elsewhere, -1 on ignored regions."""
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
        for ignore_layer, ignore_label in self._ignored_regions:
            ignore_mask = self._encoded_masks[:, :, ignore_layer] == ignore_label
            obj_mask[ignore_mask] = -1
        return obj_mask

    def get_soft_object_mask(self, obj_id):
        """Return the object's float32 mask warped by the replayed mask transforms.

        The binary mask is taken from the original (un-augmented) masks, passed
        through the replay captured in :meth:`augment`, and clipped to [0, 1].
        Requires that :meth:`augment` produced a replay.
        """
        assert self._soft_mask_aug is not None
        original_encoded_masks = self._original_data[1]
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
            np.float32
        )
        obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)[
            "image"
        ]
        return np.clip(obj_mask, 0, 1)

    def get_background_mask(self):
        """Return a boolean mask that is True where the maximum over all layers is 0."""
        return np.max(self._encoded_masks, axis=2) == 0

    def objects_ids(self):
        """Return the ids of all registered objects, in registry order."""
        return list(self._objects.keys())

    def gt_mask(self):
        """Return the mask of the sole object; the sample must hold exactly one.

        NOTE(review): ``objects_ids`` / ``gt_mask`` / ``root_objects`` are plain
        methods in this file, yet callers elsewhere may access them as
        properties; confirm which form is intended.
        """
        assert len(self._objects) == 1
        # Take the id straight from the registry: subscripting
        # ``self.objects_ids`` fails when it is a method rather than a property.
        only_id = next(iter(self._objects))
        return self.get_object_mask(only_id)

    def root_objects(self):
        """Return the ids of objects that have no parent."""
        return [
            obj_id
            for obj_id, obj_info in self._objects.items()
            if obj_info["parent"] is None
        ]

    def _compute_objects_areas(self):
        """Store each object's pixel count under its ``"area"`` key.

        Labels found in the masks that belong to no object and are not ignored
        regions are erased (set to 0). Objects whose ``(layer, label)`` is not
        found get area 0. The masks are copied before the first erase, so the
        snapshot array and the caller's array are never modified in place.
        """
        inverse_index = {
            node["mapping"]: node_id for node_id, node in self._objects.items()
        }
        ignored_regions_keys = set(self._ignored_regions)
        masks_copied = False

        for layer_indx in range(self._encoded_masks.shape[2]):
            # Parallel sequences of labels and their sizes for this layer
            # (presumably non-zero labels only; see get_labels_with_sizes).
            labels, areas = get_labels_with_sizes(
                self._encoded_masks[:, :, layer_indx]
            )
            for label, area in zip(labels, areas):
                inv_key = (layer_indx, label)
                if inv_key in ignored_regions_keys:
                    continue

                node_id = inverse_index.pop(inv_key, None)
                if node_id is not None:
                    self._objects[node_id]["area"] = area
                    continue

                # Unreferenced label: erase it. Copy first, because the layer
                # slice is a view and would otherwise write through to the
                # snapshot in _original_data and to the caller's array.
                if not masks_copied:
                    self._encoded_masks = self._encoded_masks.copy()
                    masks_copied = True
                layer = self._encoded_masks[:, :, layer_indx]
                layer[layer == label] = 0

        # Objects whose label never showed up in the masks.
        for node_id in inverse_index.values():
            self._objects[node_id]["area"] = 0

    def _remove_object(self, obj_id):
        """Delete ``obj_id`` and hand its children over to its parent."""
        obj_info = self._objects[obj_id]
        obj_parent = obj_info["parent"]
        for child_id in obj_info["children"]:
            self._objects[child_id]["parent"] = obj_parent

        if obj_parent is not None:
            parent_children = [
                x for x in self._objects[obj_parent]["children"] if x != obj_id
            ]
            self._objects[obj_parent]["children"] = (
                parent_children + obj_info["children"]
            )

        del self._objects[obj_id]

    def __len__(self):
        """Return the number of registered objects."""
        return len(self._objects)