Spaces:
Build error
Build error
| import json | |
| import os | |
| from os import path | |
| from pathlib import Path | |
| import shutil | |
| import collections | |
| import cv2 | |
| from PIL import Image | |
| import torch | |
| from util.image_loader import PaletteConverter | |
| if not hasattr(Image, 'Resampling'): # Pillow<9.0 | |
| Image.Resampling = Image | |
| import numpy as np | |
| from util.palette import davis_palette | |
| import progressbar | |
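# functools.lru_cache cannot evict a single entry (https://bugs.python.org/issue28178),
# so a minimal LRU cache with per-key invalidation is implemented here instead.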
class LRU:
    """Memoise ``func`` with least-recently-used eviction and per-key invalidation.

    Results are keyed by the tuple of positional arguments. At most ``maxsize``
    results are retained; when the limit is exceeded the entry that was used
    longest ago is discarded first.
    """

    def __init__(self, func, maxsize=128):
        # args tuple -> result, ordered from least to most recently used
        self.cache = collections.OrderedDict()
        self.func = func
        self.maxsize = maxsize

    def __call__(self, *args):
        """Return ``func(*args)``, computing it only on a cache miss."""
        entries = self.cache
        try:
            value = entries[args]
        except KeyError:
            pass
        else:
            # hit: mark as most recently used
            entries.move_to_end(args)
            return value

        value = self.func(*args)
        entries[args] = value
        if len(entries) > self.maxsize:
            # drop the least recently used entry
            entries.popitem(last=False)
        return value

    def invalidate(self, key):
        """Forget the cached result for ``key`` (an args tuple); unknown keys are ignored."""
        self.cache.pop(key, None)
class ResourceManager:
    """Manage the on-disk workspace of one video plus in-memory per-frame buffers.

    A workspace directory contains ``images/`` (the input frames), ``masks/``
    (palettised PNG masks), an optional ``visualization/`` folder and an
    ``info.json`` file storing the reference-frame ids and the object count.
    Frames and masks are read through ``LRU`` caches (``get_image`` /
    ``get_mask``).
    """

    def __init__(self, config):
        """Create or reuse the workspace described by ``config``.

        Args:
            config: mapping read for the keys ``images``, ``video``,
                ``workspace``, ``size``, ``num_objects``,
                ``num_objects_default_value`` and ``buffer_size``.

        Raises:
            NotImplementedError: if ``images``, ``video`` and ``workspace``
                are all None.
            AssertionError: if the workspace ends up without any frames.
        """
        # determine inputs
        images = config['images']
        video = config['video']
        self.workspace = config['workspace']
        self.size = config['size']
        self.palette = davis_palette
        self.palette_converter = PaletteConverter(self.palette)

        # create a workspace under ./workspace if none was specified
        if self.workspace is None:
            if images is not None:
                p_images = Path(images)
                # Compare with the absolute path so relative inputs such as
                # "workspace/<name>/images" are recognised as well.
                in_workspace = (Path.cwd() / 'workspace') in Path(path.abspath(images)).parents
                if p_images.name == 'JPEGImages' or in_workspace:
                    # take the name instead of actual images dir (second case checks for videos already in ./workspace)
                    basename = p_images.parent.name
                else:
                    basename = p_images.name
            elif video is not None:
                # splitext strips extensions of any length (.mp4, .webm, or none)
                basename = path.splitext(path.basename(video))[0]
            else:
                raise NotImplementedError(
                    'Either images, video, or workspace has to be specified')

            self.workspace = path.join('./workspace', basename)

        print(f'Workspace is in: {self.workspace}')

        self.workspace_info_file = path.join(self.workspace, 'info.json')
        self.references = set()
        self._num_objects = None
        self._try_load_info()

        if config['num_objects'] is not None:  # forced overwrite from user
            self._num_objects = config['num_objects']
        elif self._num_objects is None:  # both are None, single object first run use case
            self._num_objects = config['num_objects_default_value']
        self._save_info()

        # determine the location of input images
        need_decoding = False
        need_resizing = False
        if path.exists(path.join(self.workspace, 'images')):
            pass  # frames are already present in the workspace
        elif images is not None:
            need_resizing = True
        elif video is not None:
            # will decode video into frames later
            need_decoding = True

        # create workspace subdirectories
        self.image_dir = path.join(self.workspace, 'images')
        self.mask_dir = path.join(self.workspace, 'masks')
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.mask_dir, exist_ok=True)

        # convert read functions to be buffered
        self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size'])
        self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size'])

        if need_decoding:
            self._extract_frames(video)
        if need_resizing:
            self._copy_resize_frames(images)

        # Read all frame file names; hidden files (e.g. .DS_Store) are not frames.
        # The real file names are kept so non-.jpg frames can be opened, while
        # self.names holds the names without extension (used for masks etc.).
        self._image_files = sorted(f for f in os.listdir(self.image_dir) if not f.startswith('.'))
        self.names = [path.splitext(f)[0] for f in self._image_files]
        self.length = len(self.names)

        assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.'

        print(f'{self.length} images found.')

        self.height, self.width = self.get_image(0).shape[:2]
        self.visualization_init = False

        self._resize = None
        self._masks = None
        self._keys = None
        # initialised like _keys so the accessors return None instead of raising
        self._shrinkages = None
        self._selections = None
        self._keys_processed = np.zeros(self.length, dtype=bool)
        self.key_h = None
        self.key_w = None

    def _extract_frames(self, video):
        """Decode ``video`` into ``frame_000000.jpg``, ``frame_000001.jpg``, ... in the image dir.

        Each frame is resized with ``_resize_frame``. The capture is released
        even if decoding or writing fails.
        """
        cap = cv2.VideoCapture(video)
        frame_index = 0
        print(f'Extracting frames from {video} into {self.image_dir}...')
        bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength)
        try:
            while cap.isOpened():
                _, frame = cap.read()
                if frame is None:
                    break
                frame = self._resize_frame(frame)
                cv2.imwrite(path.join(self.image_dir, f'frame_{frame_index:06d}.jpg'), frame)
                frame_index += 1
                bar.update(frame_index)
        finally:
            cap.release()
        bar.finish()
        print('Done!')

    def _copy_resize_frames(self, images):
        """Copy (``size <= 0``) or resize-and-rewrite every image file of ``images`` into the image dir.

        Sub-directories, hidden files and files cv2 cannot decode are skipped.
        """
        image_list = os.listdir(images)
        print(f'Copying/resizing frames into {self.image_dir}...')
        for image_name in progressbar.progressbar(image_list):
            src = path.join(images, image_name)
            if image_name.startswith('.') or not path.isfile(src):
                continue
            if self.size <= 0:
                # no resizing requested: just copy
                shutil.copy2(src, self.image_dir)
                continue
            frame = cv2.imread(src)
            if frame is None:
                print(f'Could not read {src} as an image, skipping it.')
                continue
            cv2.imwrite(path.join(self.image_dir, image_name), self._resize_frame(frame))
        print('Done!')

    @staticmethod
    def _target_size(w, h, size):
        """Return ``(new_w, new_h)`` with the shorter side scaled to ``size``.

        Returns None when ``size <= 0`` (resizing disabled) or when the frame
        already has the target size.
        """
        if size <= 0:
            return None
        short_side = min(w, h)
        new_w = w * size // short_side
        new_h = h * size // short_side
        if new_w == w and new_h == h:
            return None
        return new_w, new_h

    def _resize_frame(self, frame):
        """Resize a cv2 frame per ``self.size`` with INTER_AREA; return it unchanged if not needed."""
        h, w = frame.shape[:2]
        target = self._target_size(w, h, self.size)
        if target is None:
            return frame
        return cv2.resize(frame, dsize=target, interpolation=cv2.INTER_AREA)

    def add_key_and_stuff_with_mask(self, ti, key, shrinkage, selection, mask):
        """Store the key, shrinkage, selection and mask tensors of frame ``ti``.

        Buffers for all ``self.length`` frames are allocated on the first call,
        using the shapes/dtype/device of that call's ``key`` and ``mask``.
        Key, shrinkage and selection are written once per frame; the mask is
        overwritten on every call.
        """
        if self._keys is None:
            # Use the trailing three dims instead of squeeze(): squeeze would also
            # drop c, h or w whenever one of them equals 1.
            c, h, w = key.shape[-3:]
            if self.key_h is None:
                self.key_h = h
            if self.key_w is None:
                self.key_w = w
            c_mask, h_mask, w_mask = mask.shape
            self._keys = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device)
            self._shrinkages = torch.empty((self.length, 1, h, w), dtype=key.dtype, device=key.device)
            self._selections = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device)
            self._masks = torch.empty((self.length, c_mask, h_mask, w_mask), dtype=mask.dtype, device=key.device)

        if not self._keys_processed[ti]:
            # keys don't change for the video, so we only save them once
            self._keys[ti] = key
            self._shrinkages[ti] = shrinkage
            self._selections[ti] = selection
            self._keys_processed[ti] = True

        self._masks[ti] = mask

    def all_masks_present(self):
        """Return True once every frame has had its key stored."""
        return bool(self._keys_processed.sum() == self.length)

    def add_reference(self, frame_id: int):
        """Mark ``frame_id`` as a reference frame and persist the change."""
        self.references.add(frame_id)
        self._save_info()

    def remove_reference(self, frame_id: int):
        """Unmark ``frame_id`` as a reference frame (KeyError if absent) and persist."""
        self.references.remove(frame_id)
        self._save_info()

    def _save_info(self):
        """Write the references and object count to ``info.json``, creating its directory."""
        p_workspace_subdir = Path(self.workspace_info_file).parent
        p_workspace_subdir.mkdir(parents=True, exist_ok=True)
        with open(self.workspace_info_file, 'wt') as f:
            data = {'references': sorted(self.references), 'num_objects': self._num_objects}
            json.dump(data, f, indent=4)

    def _try_load_info(self):
        """Load references and the object count from ``info.json`` if it is usable.

        A missing file (first run) is ignored silently; an unreadable or
        malformed file is reported and otherwise ignored. Either key may be
        absent, e.g. an imported project may have ``num_objects`` but no
        ``references``.
        """
        try:
            with open(self.workspace_info_file) as f:
                data = json.load(f)
        except FileNotFoundError:
            return
        except (OSError, ValueError) as e:  # ValueError covers JSON decode errors
            print(f'Could not read {self.workspace_info_file}, ignoring it: {e}')
            return
        if not isinstance(data, dict):
            print(f'Unexpected content in {self.workspace_info_file}, ignoring it.')
            return
        self._num_objects = data.get('num_objects', self._num_objects)
        self.references = set(data.get('references') or [])

    def save_mask(self, ti, mask):
        """Save ``mask`` (uint8 H*W, no channel dim) as ``masks/<name>.png`` and drop its cached copy."""
        assert 0 <= ti < self.length
        assert isinstance(mask, np.ndarray)

        mask = Image.fromarray(mask)
        mask.putpalette(self.palette)
        mask.save(path.join(self.mask_dir, self.names[ti] + '.png'))
        self.invalidate(ti)

    def save_visualization(self, ti, image):
        """Save ``image`` as ``visualization/<name>.jpg`` for frame ``ti``.

        ``image`` is a uint8 array passed to ``Image.fromarray``, so a colour
        image must be laid out H*W*3.
        """
        assert 0 <= ti < self.length
        assert isinstance(image, np.ndarray)
        if not self.visualization_init:
            self.visualization_dir = path.join(self.workspace, 'visualization')
            os.makedirs(self.visualization_dir, exist_ok=True)
            self.visualization_init = True

        image = Image.fromarray(image)
        image.save(path.join(self.visualization_dir, self.names[ti] + '.jpg'))

    def _get_image_unbuffered(self, ti):
        """Read frame ``ti`` from disk as an H*W*3 uint8 RGB array."""
        assert 0 <= ti < self.length
        with Image.open(path.join(self.image_dir, self._image_files[ti])) as image:
            # convert() guarantees three channels for grayscale / RGBA / palette files
            return np.array(image.convert('RGB'))

    def _get_mask_unbuffered(self, ti):
        """Read the saved mask of frame ``ti`` as an H*W array, or None if none was saved."""
        assert 0 <= ti < self.length
        mask_path = path.join(self.mask_dir, self.names[ti] + '.png')
        if not path.exists(mask_path):
            return None
        with Image.open(mask_path) as mask:
            return np.array(mask)

    def read_external_image(self, file_name, size=None, force_mask=False):
        """Load an image (or mask) from ``file_name`` as a numpy array.

        Args:
            file_name: path of the image to read.
            size: optional ``(height, width)`` to resize to; nearest-neighbour
                is used for masks ('L'/'P' mode or ``force_mask``), bicubic
                otherwise.
            force_mask: treat the file as a mask; non-'P' images are passed
                through ``self.palette_converter.image_to_index_mask``.
        """
        with Image.open(file_name) as image:
            is_mask = image.mode in ['L', 'P']

            if size is not None:
                # PIL uses (width, height)
                image = image.resize((size[1], size[0]),
                                     resample=Image.Resampling.NEAREST if is_mask or force_mask else Image.Resampling.BICUBIC)

            if force_mask and image.mode != 'P':
                image = self.palette_converter.image_to_index_mask(image)

            return np.array(image)

    def invalidate(self, ti):
        """Drop the cached mask of frame ``ti`` (the image buffer is never invalidated)."""
        self.get_mask.invalidate((ti,))

    def __len__(self):
        return self.length

    def h(self):
        """Frame height in pixels."""
        return self.height

    def w(self):
        """Frame width in pixels."""
        return self.width

    def small_masks(self):
        """Per-frame mask buffer filled by ``add_key_and_stuff_with_mask`` (None before the first call)."""
        return self._masks

    def keys(self):
        """Per-frame key buffer (None before the first ``add_key_and_stuff_with_mask``)."""
        return self._keys

    def shrinkages(self):
        """Per-frame shrinkage buffer (None before the first ``add_key_and_stuff_with_mask``)."""
        return self._shrinkages

    def selections(self):
        """Per-frame selection buffer (None before the first ``add_key_and_stuff_with_mask``)."""
        return self._selections

    def num_objects(self):
        """Number of objects tracked in this workspace."""
        return self._num_objects