| from typing import List, Optional, Iterable |
| import logging |
| from omegaconf import DictConfig |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from matanyone.inference.memory_manager import MemoryManager |
| from matanyone.inference.object_manager import ObjectManager |
| from matanyone.inference.image_feature_store import ImageFeatureStore |
| from matanyone.model.matanyone import MatAnyone |
| from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate |
|
|
# Module-level logger. NOTE(review): this grabs the *root* logger; consider
# logging.getLogger(__name__) so messages are attributable to this module —
# left unchanged here since other modules may rely on root-logger config.
log = logging.getLogger()
|
|
|
|
class InferenceCore:
    """Stateful frame-by-frame inference driver for MatAnyone.

    Holds the rolling memory stores (via MemoryManager), the object-id
    bookkeeping (via ObjectManager), and a per-frame encoder-feature cache
    (via ImageFeatureStore). Call `step()` once per incoming frame; it
    returns per-object probabilities including the background channel.
    """

    def __init__(self,
                 network: MatAnyone,
                 cfg: DictConfig,
                 *,
                 image_feature_store: Optional[ImageFeatureStore] = None):
        """
        network: the MatAnyone model used for encoding/segmentation
        cfg: inference configuration; reads mem_every, stagger_updates,
             chunk_size, save_aux, max_internal_size, flip_aug
        image_feature_store: optional shared feature cache; a private one is
             created when None
        """
        self.network = network
        self.cfg = cfg
        self.mem_every = cfg.mem_every
        stagger_updates = cfg.stagger_updates
        self.chunk_size = cfg.chunk_size
        self.save_aux = cfg.save_aux
        self.max_internal_size = cfg.max_internal_size
        self.flip_aug = cfg.flip_aug

        self.curr_ti = -1  # index of the current frame; -1 before any step
        self.last_mem_ti = 0  # frame index of the most recent memory write
        # Offsets (relative to last_mem_ti) at which the sensory memory is
        # updated, spread evenly across the mem_every interval.
        if stagger_updates >= self.mem_every:
            self.stagger_ti = set(range(1, self.mem_every + 1))
        else:
            self.stagger_ti = set(
                np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
        self.object_manager = ObjectManager()
        self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)

        if image_feature_store is None:
            self.image_feature_store = ImageFeatureStore(self.network)
        else:
            self.image_feature_store = image_feature_store

        # Cached results from the previous frame, consumed by memory reads.
        self.last_mask = None
        self.last_pix_feat = None
        self.last_msk_value = None

    def clear_memory(self):
        """Reset frame counters and rebuild the memory manager from scratch."""
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)

    def clear_non_permanent_memory(self):
        """Reset frame counters and drop all non-permanent memory."""
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory.clear_non_permanent_memory()

    def clear_sensory_memory(self):
        """Reset frame counters and drop the sensory (short-term) memory."""
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory.clear_sensory_memory()

    def update_config(self, cfg):
        """Apply a new configuration to this core and its memory manager."""
        self.mem_every = cfg['mem_every']
        self.memory.update_config(cfg)

    def clear_temp_mem(self):
        """Drop the working memory and object memory, keeping the rest."""
        self.memory.clear_work_mem()
        self.memory.clear_obj_mem()

    def _add_memory(self,
                    image: torch.Tensor,
                    pix_feat: torch.Tensor,
                    prob: torch.Tensor,
                    key: torch.Tensor,
                    shrinkage: torch.Tensor,
                    selection: torch.Tensor,
                    *,
                    is_deep_update: bool = True,
                    force_permanent: bool = False) -> None:
        """
        Memorize the given segmentation in all memory stores.

        The batch dimension is 1 if flip augmentation is not used.
        image: RGB image, (1/2)*3*H*W
        pix_feat: from the key encoder, (1/2)*_*H*W
        prob: (1/2)*num_objects*H*W, in [0, 1]
        key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
        selection can be None if not using long-term memory
        is_deep_update: whether to use deep update (e.g. with the mask encoder)
        force_permanent: whether to force the memory to be permanent
        """
        if prob.shape[1] == 0:
            # nothing to memorize
            log.warning('Trying to add an empty object mask to memory!')
            return

        # 'all' makes every frame permanent; 'first' keeps only the first
        # written frame permanent.
        if force_permanent:
            as_permanent = 'all'
        else:
            as_permanent = 'first'

        self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
        msk_value, sensory, obj_value, _ = self.network.encode_mask(
            image,
            pix_feat,
            self.memory.get_sensory(self.object_manager.all_obj_ids),
            prob,
            deep_update=is_deep_update,
            chunk_size=self.chunk_size,
            need_weights=self.save_aux)
        self.memory.add_memory(key,
                               shrinkage,
                               msk_value,
                               obj_value,
                               self.object_manager.all_obj_ids,
                               selection=selection,
                               as_permanent=as_permanent)
        self.last_mem_ti = self.curr_ti
        if is_deep_update:
            self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
        self.last_msk_value = msk_value

    def _segment(self,
                 key: torch.Tensor,
                 selection: torch.Tensor,
                 pix_feat: torch.Tensor,
                 ms_features: Iterable[torch.Tensor],
                 update_sensory: bool = True) -> torch.Tensor:
        """
        Produce a segmentation using the given features and the memory

        The batch dimension is 1 if flip augmentation is not used.
        key/selection: for anisotropic l2: (1/2) * _ * H * W
        pix_feat: from the key encoder, (1/2) * _ * H * W
        ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
                      with strides 16, 8, and 4 respectively
        update_sensory: whether to update the sensory memory

        Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
        """
        bs = key.shape[0]
        if self.flip_aug:
            assert bs == 2
        else:
            assert bs == 1

        if not self.memory.engaged:
            # no memory yet -- return an all-background dummy prediction
            log.warning('Trying to segment without any memory!')
            return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
                               device=key.device,
                               dtype=key.dtype)

        uncert_output = None

        # the first frame is read differently: only the user-provided mask is
        # available to condition on
        if self.curr_ti == 0:
            memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output)
        else:
            memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti,
                                              last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask)
        memory_readout = self.object_manager.realize_dict(memory_readout)

        sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
                                                             memory_readout,
                                                             self.memory.get_sensory(
                                                                 self.object_manager.all_obj_ids),
                                                             chunk_size=self.chunk_size,
                                                             update_sensory=update_sensory)
        # average the prediction with its horizontally-flipped counterpart
        if self.flip_aug:
            pred_prob_with_bg = (pred_prob_with_bg[0] +
                                 torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
        else:
            pred_prob_with_bg = pred_prob_with_bg[0]
        if update_sensory:
            self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
        return pred_prob_with_bg

    def pred_all_flow(self, images):
        """Precompute forward/backward optical flow for the whole clip.

        images: T*3*H*W; padded to a multiple of 16 before the flow network.
        """
        self.total_len = images.shape[0]
        images, self.pad = pad_divide_by(images, 16)
        images = images.unsqueeze(0)  # add the batch dimension: 1*T*3*H*W

        self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images)

    def encode_all_images(self, images):
        """Encode and cache features for all frames; returns the padded images."""
        images, self.pad = pad_divide_by(images, 16)
        self.image_feature_store.get_all_features(images)
        return images

    def step(self,
             image: torch.Tensor,
             mask: Optional[torch.Tensor] = None,
             objects: Optional[List[int]] = None,
             *,
             idx_mask: bool = False,
             end: bool = False,
             delete_buffer: bool = True,
             force_permanent: bool = False,
             matting: bool = True,
             first_frame_pred: bool = False) -> torch.Tensor:
        """
        Take a step with a new incoming image.
        If there is an incoming mask with new objects, we will memorize them.
        If there is no incoming mask, we will segment the image using the memory.
        In both cases, we will update the memory and return a segmentation.

        image: 3*H*W
        mask: H*W (if idx mask) or len(objects)*H*W or None
        objects: list of object ids that are valid in the mask Tensor.
                The ids themselves do not need to be consecutive/in order, but they need to be
                in the same position in the list as the corresponding mask
                in the tensor in non-idx-mask mode.
                objects is ignored if the mask is None.
                If idx_mask is False and objects is None, we sequentially infer the object ids.
        idx_mask: if True, mask is expected to contain an object id at every pixel.
                If False, mask should have multiple channels with each channel representing one object.
        end: if we are at the end of the sequence, we do not need to update memory
            if unsure just set it to False
        delete_buffer: whether to delete the image feature buffer after this step
        force_permanent: the memory recorded this frame will be added to the permanent memory
        matting: treat the mask as a soft alpha matte in [0, 255] -- TODO confirm scale with callers
        first_frame_pred: force this frame to be treated as the first frame
        """
        if objects is None and mask is not None:
            assert not idx_mask
            objects = list(range(1, mask.shape[0] + 1))

        # downscale the input if its short side exceeds the internal cap
        resize_needed = False
        if self.max_internal_size > 0:
            h, w = image.shape[-2:]
            min_side = min(h, w)
            if min_side > self.max_internal_size:
                resize_needed = True
                new_h = int(h / min_side * self.max_internal_size)
                new_w = int(w / min_side * self.max_internal_size)
                image = F.interpolate(image.unsqueeze(0),
                                      size=(new_h, new_w),
                                      mode='bilinear',
                                      align_corners=False)[0]
                if mask is not None:
                    if idx_mask:
                        # nearest-exact keeps integer object ids intact.
                        # BUGFIX: align_corners is only valid for the linear
                        # interpolation modes; passing it with a nearest mode
                        # raises ValueError, so it must be omitted here.
                        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
                                             size=(new_h, new_w),
                                             mode='nearest-exact')[0, 0].round().long()
                    else:
                        mask = F.interpolate(mask.unsqueeze(0),
                                             size=(new_h, new_w),
                                             mode='bilinear',
                                             align_corners=False)[0]

        self.curr_ti += 1

        image, self.pad = pad_divide_by(image, 16)
        image = image.unsqueeze(0)  # add the batch dimension
        if self.flip_aug:
            image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)

        # whether to write this frame into memory (never at sequence end)
        is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
                        (mask is not None)) and (not end)
        # segment when there is no input mask, or the input mask does not
        # cover all known objects
        need_segment = (mask is None) or (self.object_manager.num_obj > 0
                                          and not self.object_manager.has_all(objects))
        update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)

        # treat this frame as the first frame: reset counters and force
        # segmentation + a memory write
        if first_frame_pred:
            self.curr_ti = 0
            self.last_mem_ti = 0
            is_mem_frame = True
            need_segment = True
            update_sensory = True

        # encode the image (features are cached per frame index)
        ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
        key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)

        # segment from memory if needed
        if need_segment:
            pred_prob_with_bg = self._segment(key,
                                              selection,
                                              pix_feat,
                                              ms_feat,
                                              update_sensory=update_sensory)

        # use the input mask if provided
        if mask is not None:
            # inform the object manager of new objects; temporary ids give
            # each object's position in the tensor (1-based because of the
            # background channel)
            corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)

            mask, _ = pad_divide_by(mask, 16)
            if need_segment:
                # merge the predicted mask with the (incomplete) input mask,
                # exploiting the mutual exclusivity of segmentation
                pred_prob_no_bg = pred_prob_with_bg[1:]
                if idx_mask:
                    pred_prob_no_bg[:, mask > 0] = 0
                else:
                    # BUGFIX: Tensor.max(dim) returns a (values, indices)
                    # namedtuple -- compare against .values, not the tuple.
                    pred_prob_no_bg[:, mask.max(0).values > 0.5] = 0

                new_masks = []
                for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
                    if idx_mask:
                        this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
                    else:
                        this_mask = mask[tmp_id]
                    if tmp_id > pred_prob_no_bg.shape[0]:
                        # a newly-seen object: append a channel
                        new_masks.append(this_mask.unsqueeze(0))
                    else:
                        # override the prediction with the user-provided mask
                        pred_prob_no_bg[tmp_id - 1] = this_mask
                mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
            elif idx_mask:
                # convert the index mask to a one-hot representation
                if len(objects) == 0:
                    if delete_buffer:
                        self.image_feature_store.delete(self.curr_ti)
                    log.warning('Trying to insert an empty mask as memory!')
                    return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
                                       device=key.device,
                                       dtype=key.dtype)
                mask = torch.stack(
                    [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
                    dim=0)
            if matting:
                # alpha matte in [0, 255] -> soft foreground probability
                mask = mask.unsqueeze(0).float() / 255.
                pred_prob_with_bg = torch.cat([1-mask, mask], 0)
            else:
                pred_prob_with_bg = aggregate(mask, dim=0)
                pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)

        # cache results for the next frame's memory read
        self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
        if self.flip_aug:
            self.last_mask = torch.cat(
                [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
        self.last_pix_feat = pix_feat

        # save as memory if needed
        if is_mem_frame or force_permanent:
            # on a (forced) first frame, start from a clean working memory
            if first_frame_pred:
                self.clear_temp_mem()
            self._add_memory(image,
                             pix_feat,
                             self.last_mask,
                             key,
                             shrinkage,
                             selection,
                             force_permanent=force_permanent,
                             is_deep_update=True)
        else:
            # not a memory frame: still refresh the cached mask value so the
            # next read conditions on this frame's prediction
            msk_value, _, _, _ = self.network.encode_mask(
                image,
                pix_feat,
                self.memory.get_sensory(self.object_manager.all_obj_ids),
                self.last_mask,
                deep_update=False,
                chunk_size=self.chunk_size,
                need_weights=self.save_aux)
            self.last_msk_value = msk_value

        if delete_buffer:
            self.image_feature_store.delete(self.curr_ti)

        # undo padding, then undo the internal resize if one was applied
        output_prob = unpad(pred_prob_with_bg, self.pad)
        if resize_needed:
            output_prob = F.interpolate(output_prob.unsqueeze(0),
                                        size=(h, w),
                                        mode='bilinear',
                                        align_corners=False)[0]

        return output_prob

    def delete_objects(self, objects: List[int]) -> None:
        """
        Delete the given objects from the memory.
        """
        self.object_manager.delete_objects(objects)
        self.memory.purge_except(self.object_manager.all_obj_ids)

    def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor:
        """Convert (num_objects+1)*H*W probabilities to an output mask.

        In matting mode, returns the soft foreground channel (H*W).
        Otherwise, returns an H*W index mask remapped to original object ids.
        """
        if matting:
            new_mask = output_prob[1:].squeeze(0)
        else:
            mask = torch.argmax(output_prob, dim=0)

            # channel position in the tensor != object id -- remap here
            new_mask = torch.zeros_like(mask)
            for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
                new_mask[mask == tmp_id] = obj.id

        return new_mask
|
|