| import logging |
| from omegaconf import DictConfig |
| from typing import List, Optional, Iterable, Union,Tuple |
|
|
| import os |
| import cv2 |
| import torch |
| import imageio |
| import tempfile |
| import numpy as np |
| from tqdm import tqdm |
| from PIL import Image |
| import torch.nn.functional as F |
|
|
| from matanyone2.inference.memory_manager import MemoryManager |
| from matanyone2.inference.object_manager import ObjectManager |
| from matanyone2.inference.image_feature_store import ImageFeatureStore |
| from matanyone2.model.matanyone2 import MatAnyone2 |
| from matanyone2.utils.tensor_utils import pad_divide_by, unpad, aggregate |
| from matanyone2.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos |
| from matanyone2.utils.device import get_default_device, safe_autocast |
|
|
|
|
| log = logging.getLogger() |
|
|
|
|
| class InferenceCore: |
|
|
| def __init__(self, |
| network: Union[MatAnyone2,str], |
| cfg: DictConfig = None, |
| *, |
| image_feature_store: ImageFeatureStore = None, |
| device: Optional[Union[str, torch.device]] = None |
| ): |
| if device is None: |
| device = get_default_device() |
| self.device = device |
| if isinstance(network, str): |
| network = MatAnyone2.from_pretrained(network) |
| network.to(device) |
| network.eval() |
| self.network = network |
| cfg = cfg if cfg is not None else network.cfg |
| self.cfg = cfg |
| self.mem_every = cfg.mem_every |
| stagger_updates = cfg.stagger_updates |
| self.chunk_size = cfg.chunk_size |
| self.save_aux = cfg.save_aux |
| self.max_internal_size = cfg.max_internal_size |
| self.flip_aug = cfg.flip_aug |
|
|
| self.curr_ti = -1 |
| self.last_mem_ti = 0 |
| |
| if stagger_updates >= self.mem_every: |
| self.stagger_ti = set(range(1, self.mem_every + 1)) |
| else: |
| self.stagger_ti = set( |
| np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int)) |
| self.object_manager = ObjectManager() |
| self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager) |
|
|
| if image_feature_store is None: |
| self.image_feature_store = ImageFeatureStore(self.network) |
| else: |
| self.image_feature_store = image_feature_store |
|
|
| self.last_mask = None |
| self.last_pix_feat = None |
| self.last_msk_value = None |
|
|
| def clear_memory(self): |
| self.curr_ti = -1 |
| self.last_mem_ti = 0 |
| self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager) |
|
|
| def clear_non_permanent_memory(self): |
| self.curr_ti = -1 |
| self.last_mem_ti = 0 |
| self.memory.clear_non_permanent_memory() |
|
|
| def clear_sensory_memory(self): |
| self.curr_ti = -1 |
| self.last_mem_ti = 0 |
| self.memory.clear_sensory_memory() |
|
|
| def update_config(self, cfg): |
| self.mem_every = cfg['mem_every'] |
| self.memory.update_config(cfg) |
| |
| def clear_temp_mem(self): |
| self.memory.clear_work_mem() |
| |
| self.memory.clear_obj_mem() |
| |
|
|
| def _add_memory(self, |
| image: torch.Tensor, |
| pix_feat: torch.Tensor, |
| prob: torch.Tensor, |
| key: torch.Tensor, |
| shrinkage: torch.Tensor, |
| selection: torch.Tensor, |
| *, |
| is_deep_update: bool = True, |
| force_permanent: bool = False) -> None: |
| """ |
| Memorize the given segmentation in all memory stores. |
| |
| The batch dimension is 1 if flip augmentation is not used. |
| image: RGB image, (1/2)*3*H*W |
| pix_feat: from the key encoder, (1/2)*_*H*W |
| prob: (1/2)*num_objects*H*W, in [0, 1] |
| key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W |
| selection can be None if not using long-term memory |
| is_deep_update: whether to use deep update (e.g. with the mask encoder) |
| force_permanent: whether to force the memory to be permanent |
| """ |
| if prob.shape[1] == 0: |
| |
| log.warn('Trying to add an empty object mask to memory!') |
| return |
|
|
| if force_permanent: |
| as_permanent = 'all' |
| else: |
| as_permanent = 'first' |
|
|
| self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids) |
| msk_value, sensory, obj_value, _ = self.network.encode_mask( |
| image, |
| pix_feat, |
| self.memory.get_sensory(self.object_manager.all_obj_ids), |
| prob, |
| deep_update=is_deep_update, |
| chunk_size=self.chunk_size, |
| need_weights=self.save_aux) |
| self.memory.add_memory(key, |
| shrinkage, |
| msk_value, |
| obj_value, |
| self.object_manager.all_obj_ids, |
| selection=selection, |
| as_permanent=as_permanent) |
| self.last_mem_ti = self.curr_ti |
| if is_deep_update: |
| self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) |
| self.last_msk_value = msk_value |
|
|
| def _segment(self, |
| key: torch.Tensor, |
| selection: torch.Tensor, |
| pix_feat: torch.Tensor, |
| ms_features: Iterable[torch.Tensor], |
| update_sensory: bool = True) -> torch.Tensor: |
| """ |
| Produce a segmentation using the given features and the memory |
| |
| The batch dimension is 1 if flip augmentation is not used. |
| key/selection: for anisotropic l2: (1/2) * _ * H * W |
| pix_feat: from the key encoder, (1/2) * _ * H * W |
| ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W |
| with strides 16, 8, and 4 respectively |
| update_sensory: whether to update the sensory memory |
| |
| Returns: (num_objects+1)*H*W normalized probability; the first channel is the background |
| """ |
| bs = key.shape[0] |
| if self.flip_aug: |
| assert bs == 2 |
| else: |
| assert bs == 1 |
|
|
| if not self.memory.engaged: |
| log.warn('Trying to segment without any memory!') |
| return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), |
| device=key.device, |
| dtype=key.dtype) |
| |
| uncert_output = None |
|
|
| if self.curr_ti == 0: |
| memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output) |
| else: |
| memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti, |
| last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask) |
| memory_readout = self.object_manager.realize_dict(memory_readout) |
|
|
| sensory, _, pred_prob_with_bg = self.network.segment(ms_features, |
| memory_readout, |
| self.memory.get_sensory( |
| self.object_manager.all_obj_ids), |
| chunk_size=self.chunk_size, |
| update_sensory=update_sensory) |
| |
| if self.flip_aug: |
| |
| pred_prob_with_bg = (pred_prob_with_bg[0] + |
| torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2 |
| else: |
| pred_prob_with_bg = pred_prob_with_bg[0] |
| if update_sensory: |
| self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) |
| return pred_prob_with_bg |
| |
| def pred_all_flow(self, images): |
| self.total_len = images.shape[0] |
| images, self.pad = pad_divide_by(images, 16) |
| images = images.unsqueeze(0) |
| |
| self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images) |
|
|
| def encode_all_images(self, images): |
| images, self.pad = pad_divide_by(images, 16) |
| self.image_feature_store.get_all_features(images) |
| return images |
|
|
| def step(self, |
| image: torch.Tensor, |
| mask: Optional[torch.Tensor] = None, |
| objects: Optional[List[int]] = None, |
| *, |
| idx_mask: bool = False, |
| end: bool = False, |
| delete_buffer: bool = True, |
| force_permanent: bool = False, |
| matting: bool = True, |
| first_frame_pred: bool = False) -> torch.Tensor: |
| """ |
| Take a step with a new incoming image. |
| If there is an incoming mask with new objects, we will memorize them. |
| If there is no incoming mask, we will segment the image using the memory. |
| In both cases, we will update the memory and return a segmentation. |
| |
| image: 3*H*W |
| mask: H*W (if idx mask) or len(objects)*H*W or None |
| objects: list of object ids that are valid in the mask Tensor. |
| The ids themselves do not need to be consecutive/in order, but they need to be |
| in the same position in the list as the corresponding mask |
| in the tensor in non-idx-mask mode. |
| objects is ignored if the mask is None. |
| If idx_mask is False and objects is None, we sequentially infer the object ids. |
| idx_mask: if True, mask is expected to contain an object id at every pixel. |
| If False, mask should have multiple channels with each channel representing one object. |
| end: if we are at the end of the sequence, we do not need to update memory |
| if unsure just set it to False |
| delete_buffer: whether to delete the image feature buffer after this step |
| force_permanent: the memory recorded this frame will be added to the permanent memory |
| """ |
| if objects is None and mask is not None: |
| assert not idx_mask |
| objects = list(range(1, mask.shape[0] + 1)) |
|
|
| |
| resize_needed = False |
| if self.max_internal_size > 0: |
| h, w = image.shape[-2:] |
| min_side = min(h, w) |
| if min_side > self.max_internal_size: |
| resize_needed = True |
| new_h = int(h / min_side * self.max_internal_size) |
| new_w = int(w / min_side * self.max_internal_size) |
| image = F.interpolate(image.unsqueeze(0), |
| size=(new_h, new_w), |
| mode='bilinear', |
| align_corners=False)[0] |
| if mask is not None: |
| if idx_mask: |
| mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), |
| size=(new_h, new_w), |
| mode='nearest-exact', |
| align_corners=False)[0, 0].round().long() |
| else: |
| mask = F.interpolate(mask.unsqueeze(0), |
| size=(new_h, new_w), |
| mode='bilinear', |
| align_corners=False)[0] |
|
|
| self.curr_ti += 1 |
|
|
| image, self.pad = pad_divide_by(image, 16) |
| image = image.unsqueeze(0) |
| if self.flip_aug: |
| image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0) |
|
|
| |
| is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or |
| (mask is not None)) and (not end) |
| |
| need_segment = (mask is None) or (self.object_manager.num_obj > 0 |
| and not self.object_manager.has_all(objects)) |
| update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end) |
|
|
| |
| if first_frame_pred: |
| self.curr_ti = 0 |
| self.last_mem_ti = 0 |
| is_mem_frame = True |
| need_segment = True |
| update_sensory = True |
|
|
| |
| ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) |
| key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image) |
|
|
| |
| if need_segment: |
| pred_prob_with_bg = self._segment(key, |
| selection, |
| pix_feat, |
| ms_feat, |
| update_sensory=update_sensory) |
|
|
| |
| if mask is not None: |
| |
| |
| |
| corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects) |
|
|
| mask, _ = pad_divide_by(mask, 16) |
| if need_segment: |
| |
| pred_prob_no_bg = pred_prob_with_bg[1:] |
| |
| if idx_mask: |
| pred_prob_no_bg[:, mask > 0] = 0 |
| else: |
| pred_prob_no_bg[:, mask.max(0) > 0.5] = 0 |
|
|
| new_masks = [] |
| for mask_id, tmp_id in enumerate(corresponding_tmp_ids): |
| if idx_mask: |
| this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg) |
| else: |
| this_mask = mask[tmp_id] |
| if tmp_id > pred_prob_no_bg.shape[0]: |
| new_masks.append(this_mask.unsqueeze(0)) |
| else: |
| |
| pred_prob_no_bg[tmp_id - 1] = this_mask |
| |
| mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0) |
| elif idx_mask: |
| |
| if len(objects) == 0: |
| if delete_buffer: |
| self.image_feature_store.delete(self.curr_ti) |
| log.warn('Trying to insert an empty mask as memory!') |
| return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), |
| device=key.device, |
| dtype=key.dtype) |
| mask = torch.stack( |
| [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)], |
| dim=0) |
| if matting: |
| mask = mask.unsqueeze(0).float() / 255. |
| pred_prob_with_bg = torch.cat([1-mask, mask], 0) |
| else: |
| pred_prob_with_bg = aggregate(mask, dim=0) |
| pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0) |
|
|
| self.last_mask = pred_prob_with_bg[1:].unsqueeze(0) |
| if self.flip_aug: |
| self.last_mask = torch.cat( |
| [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0) |
| self.last_pix_feat = pix_feat |
|
|
| |
| if is_mem_frame or force_permanent: |
| |
| if first_frame_pred: |
| self.clear_temp_mem() |
| self._add_memory(image, |
| pix_feat, |
| self.last_mask, |
| key, |
| shrinkage, |
| selection, |
| force_permanent=force_permanent, |
| is_deep_update=True) |
| else: |
| msk_value, _, _, _ = self.network.encode_mask( |
| image, |
| pix_feat, |
| self.memory.get_sensory(self.object_manager.all_obj_ids), |
| self.last_mask, |
| deep_update=False, |
| chunk_size=self.chunk_size, |
| need_weights=self.save_aux) |
| self.last_msk_value = msk_value |
|
|
| if delete_buffer: |
| self.image_feature_store.delete(self.curr_ti) |
|
|
| output_prob = unpad(pred_prob_with_bg, self.pad) |
| if resize_needed: |
| |
| output_prob = F.interpolate(output_prob.unsqueeze(0), |
| size=(h, w), |
| mode='bilinear', |
| align_corners=False)[0] |
|
|
| return output_prob |
|
|
| def delete_objects(self, objects: List[int]) -> None: |
| """ |
| Delete the given objects from the memory. |
| """ |
| self.object_manager.delete_objects(objects) |
| self.memory.purge_except(self.object_manager.all_obj_ids) |
|
|
| def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor: |
| if matting: |
| new_mask = output_prob[1:].squeeze(0) |
| else: |
| mask = torch.argmax(output_prob, dim=0) |
|
|
| |
| new_mask = torch.zeros_like(mask) |
| for tmp_id, obj in self.object_manager.tmp_id_to_obj.items(): |
| new_mask[mask == tmp_id] = obj.id |
|
|
| return new_mask |
|
|
| @torch.inference_mode() |
| @safe_autocast() |
| def process_video( |
| self, |
| input_path: str, |
| mask_path: str, |
| output_path: str = None, |
| n_warmup: int = 10, |
| r_erode: int = 10, |
| r_dilate: int = 10, |
| suffix: str = "", |
| save_image: bool = False, |
| max_size: int = -1, |
| ) -> Tuple: |
| """ |
| Process a video for object segmentation and matting. |
| This method processes a video file by performing object segmentation and matting on each frame. |
| It supports warmup frames, mask erosion/dilation, and various output options. |
| Args: |
| input_path (str): Path to the input video file |
| mask_path (str): Path to the mask image file used for initial segmentation |
| output_path (str, optional): Directory path where output files will be saved. Defaults to a temporary directory |
| n_warmup (int, optional): Number of warmup frames to use. Defaults to 10 |
| r_erode (int, optional): Erosion radius for mask processing. Defaults to 10 |
| r_dilate (int, optional): Dilation radius for mask processing. Defaults to 10 |
| suffix (str, optional): Suffix to append to output filename. Defaults to "" |
| save_image (bool, optional): Whether to save individual frames. Defaults to False |
| max_size (int, optional): Maximum size for frame dimension. Use -1 for no limit. Defaults to -1 |
| Returns: |
| Tuple[str, str]: A tuple containing: |
| - Path to the output foreground video file (str) |
| - Path to the output alpha matte video file (str) |
| Output: |
| - Saves processed video files with foreground (_fgr) and alpha matte (_pha) |
| - If save_image=True, saves individual frames in separate directories |
| """ |
| output_path = output_path if output_path is not None else tempfile.TemporaryDirectory().name |
| r_erode = int(r_erode) |
| r_dilate = int(r_dilate) |
| n_warmup = int(n_warmup) |
| max_size = int(max_size) |
|
|
| vframes, fps, length, video_name = read_frame_from_videos(input_path) |
| repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1) |
| vframes = torch.cat([repeated_frames, vframes], dim=0).float() |
| length += n_warmup |
|
|
| new_h, new_w = vframes.shape[-2:] |
| if max_size > 0: |
| h, w = new_h, new_w |
| min_side = min(h, w) |
| if min_side > max_size: |
| new_h = int(h / min_side * max_size) |
| new_w = int(w / min_side * max_size) |
| vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area") |
|
|
| os.makedirs(output_path, exist_ok=True) |
| if suffix: |
| video_name = f"{video_name}_{suffix}" |
| if save_image: |
| os.makedirs(f"{output_path}/{video_name}", exist_ok=True) |
| os.makedirs(f"{output_path}/{video_name}/pha", exist_ok=True) |
| os.makedirs(f"{output_path}/{video_name}/fgr", exist_ok=True) |
|
|
| mask = np.array(Image.open(mask_path).convert("L")) |
| if r_dilate > 0: |
| mask = gen_dilate(mask, r_dilate, r_dilate) |
| if r_erode > 0: |
| mask = gen_erosion(mask, r_erode, r_erode) |
| |
| mask = torch.from_numpy(mask).float().to(self.device) |
| if max_size > 0: |
| mask = F.interpolate( |
| mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest" |
| )[0, 0] |
|
|
| bgr = (np.array([120, 255, 155], dtype=np.float32) / 255).reshape((1, 1, 3)) |
| objects = [1] |
|
|
| phas = [] |
| fgrs = [] |
| for ti in tqdm(range(length)): |
| image = vframes[ti] |
| image_np = np.array(image.permute(1, 2, 0)) |
| image = (image / 255.0).float().to(self.device) |
|
|
| if ti == 0: |
| output_prob = self.step(image, mask, objects=objects) |
| output_prob = self.step(image, first_frame_pred=True) |
| else: |
| if ti <= n_warmup: |
| output_prob = self.step(image, first_frame_pred=True) |
| else: |
| output_prob = self.step(image) |
|
|
| mask = self.output_prob_to_mask(output_prob) |
| pha = mask.unsqueeze(2).cpu().numpy() |
| com_np = image_np / 255.0 * pha + bgr * (1 - pha) |
|
|
| if ti > (n_warmup - 1): |
| com_np = (com_np * 255).astype(np.uint8) |
| pha = (pha * 255).astype(np.uint8) |
| fgrs.append(com_np) |
| phas.append(pha) |
| if save_image: |
| cv2.imwrite( |
| f"{output_path}/{video_name}/pha/{str(ti - n_warmup).zfill(5)}.png", |
| pha, |
| ) |
| cv2.imwrite( |
| f"{output_path}/{video_name}/fgr/{str(ti - n_warmup).zfill(5)}.png", |
| com_np[..., [2, 1, 0]], |
| ) |
|
|
| fgrs = np.array(fgrs) |
| phas = np.array(phas) |
| |
| fgr_filename = f"{output_path}/{video_name}_fgr.mp4" |
| alpha_filename = f"{output_path}/{video_name}_pha.mp4" |
| |
| imageio.mimwrite(fgr_filename, fgrs, fps=fps, quality=7) |
| imageio.mimwrite(alpha_filename, phas, fps=fps, quality=7) |
| |
| return (fgr_filename,alpha_filename) |
|
|