|
|
import json
|
|
|
import os
|
|
|
from os import path
|
|
|
from pathlib import Path
|
|
|
import shutil
|
|
|
import collections
|
|
|
|
|
|
import cv2
|
|
|
from PIL import Image
|
|
|
import torch
|
|
|
|
|
|
from util.image_loader import PaletteConverter
|
|
|
|
|
|
# Compatibility shim: older Pillow releases lack the `Image.Resampling`
# namespace. Aliasing it to the `Image` module itself makes expressions such
# as `Image.Resampling.NEAREST` / `Image.Resampling.BICUBIC` (used below)
# resolve to the module-level `Image.NEAREST` / `Image.BICUBIC` constants.
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image
|
|
|
import numpy as np
|
|
|
|
|
|
from util.palette import davis_palette
|
|
|
import progressbar
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LRU:
    """Memoize ``func`` with a bounded least-recently-used cache.

    Results are keyed on the tuple of positional arguments given to
    ``__call__`` (keyword arguments are not accepted). When more than
    ``maxsize`` results are held, the least recently used one is dropped.

    Args:
        func: The callable whose results are cached.
        maxsize: Maximum number of cached results (default 128).
    """

    def __init__(self, func, maxsize=128):
        self.func = func
        self.maxsize = maxsize
        # Insertion order doubles as recency order: the oldest entry sits first.
        self.cache = collections.OrderedDict()

    def __call__(self, *args):
        """Return ``func(*args)``, reusing a cached result when available."""
        entries = self.cache
        if args not in entries:
            value = self.func(*args)
            entries[args] = value
            # Over capacity: evict from the front (least recently used).
            if len(entries) > self.maxsize:
                entries.popitem(last=False)
            return value
        # Cache hit: mark the entry as most recently used.
        entries.move_to_end(args)
        return entries[args]

    def invalidate(self, key):
        """Forget the cached result for ``key`` (an argument tuple); no-op if absent."""
        if key in self.cache:
            del self.cache[key]
|
|
|
|
|
|
|
|
|
class ResourceManager:
    """Manage a workspace directory of frames, masks and per-workspace metadata.

    The workspace holds ``images/`` (input frames), ``masks/`` (palettized PNG
    masks), a lazily created ``visualization/`` folder, and ``info.json``
    storing the reference-frame ids and the object count. On first use the
    frames are populated either by copying/resizing an image folder or by
    decoding a video. Image and mask reads go through per-instance LRU caches
    (``get_image`` / ``get_mask``), and per-frame tensors can be stored with
    ``add_key_and_stuff_with_mask``.

    Args:
        config: Mapping providing the keys ``images``, ``video``,
            ``workspace``, ``size``, ``num_objects``,
            ``num_objects_default_value`` and ``buffer_size``.
    """

    def __init__(self, config):
        images = config['images']
        video = config['video']
        self.workspace = config['workspace']
        self.size = config['size']
        self.palette = davis_palette
        self.palette_converter = PaletteConverter(self.palette)

        if self.workspace is None:
            self.workspace = self._infer_workspace(images, video)

        print(f'Workspace is in: {self.workspace}')
        self.workspace_info_file = path.join(self.workspace, 'info.json')
        self.references = set()
        self._num_objects = None
        self._try_load_info()

        # Precedence for the object count: explicit config > info.json > default.
        if config['num_objects'] is not None:
            self._num_objects = config['num_objects']
        elif self._num_objects is None:
            self._num_objects = config['num_objects_default_value']
        self._save_info()

        # Frames are only generated when <workspace>/images does not exist yet;
        # an image folder takes precedence over a video.
        self.image_dir = path.join(self.workspace, 'images')
        self.mask_dir = path.join(self.workspace, 'masks')
        images_ready = path.exists(self.image_dir)
        need_resizing = not images_ready and images is not None
        need_decoding = not images_ready and images is None and video is not None

        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.mask_dir, exist_ok=True)

        self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size'])
        self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size'])

        if need_decoding:
            self._extract_frames(video)
        if need_resizing:
            self._copy_resize_frames(images)

        # Keep the real file names so frames with any extension can be loaded;
        # ``names`` holds the extension-less stems used for masks/visualizations.
        self._image_files = sorted(os.listdir(self.image_dir))
        self.names = [path.splitext(f)[0] for f in self._image_files]
        self.length = len(self.names)
        assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.'
        print(f'{self.length} images found.')

        self.height, self.width = self.get_image(0).shape[:2]
        self.visualization_init = False

        # Per-frame tensors; allocated lazily by add_key_and_stuff_with_mask.
        self._resize = None
        self._masks = None
        self._keys = None
        self._shrinkages = None
        self._selections = None
        self._keys_processed = np.zeros(self.length, dtype=bool)
        self.key_h = None
        self.key_w = None

    @staticmethod
    def _infer_workspace(images, video):
        """Return the default ``./workspace/<name>`` path for an image folder or video.

        Raises:
            NotImplementedError: if neither ``images`` nor ``video`` is given.
        """
        if images is not None:
            p_images = Path(images)
            # For folders named JPEGImages, or folders already inside ./workspace,
            # name the workspace after the parent directory instead.
            if p_images.name == 'JPEGImages' or (Path.cwd() / 'workspace') in p_images.parents:
                basename = p_images.parent.name
            else:
                basename = p_images.name
        elif video is not None:
            # splitext handles extensions of any length (.mp4, .webm, none at all).
            basename = path.splitext(path.basename(video))[0]
        else:
            raise NotImplementedError(
                'Either images, video, or workspace has to be specified')
        return path.join('./workspace', basename)

    @staticmethod
    def _scaled_size(w, h, size):
        """Return ``(new_w, new_h)`` so the shorter side equals ``size`` (floor-rounded)."""
        short_side = min(w, h)
        return w * size // short_side, h * size // short_side

    def _resize_frame(self, frame):
        """Resize an OpenCV frame so its shorter side is ``self.size``.

        A non-positive ``self.size`` means "keep the original resolution".
        """
        if self.size <= 0:
            return frame
        h, w = frame.shape[:2]
        new_w, new_h = self._scaled_size(w, h, self.size)
        if new_w != w or new_h != h:
            frame = cv2.resize(frame, dsize=(new_w, new_h), interpolation=cv2.INTER_AREA)
        return frame

    def _extract_frames(self, video):
        """Decode every frame of ``video`` into ``image_dir`` as frame_XXXXXX.jpg."""
        cap = cv2.VideoCapture(video)
        frame_index = 0
        print(f'Extracting frames from {video} into {self.image_dir}...')
        bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength)
        try:
            while cap.isOpened():
                ok, frame = cap.read()
                if not ok or frame is None:
                    break
                frame = self._resize_frame(frame)
                cv2.imwrite(path.join(self.image_dir, f'frame_{frame_index:06d}.jpg'), frame)
                frame_index += 1
                bar.update(frame_index)
        finally:
            # Release the capture handle even if decoding/writing fails midway.
            cap.release()
        bar.finish()
        print('Done!')

    def _copy_resize_frames(self, images):
        """Copy (``size <= 0``) or resize every file of ``images`` into ``image_dir``."""
        image_list = os.listdir(images)
        print(f'Copying/resizing frames into {self.image_dir}...')
        for image_name in progressbar.progressbar(image_list):
            src = path.join(images, image_name)
            if self.size <= 0:
                # No resizing requested: copy the file verbatim (keeps metadata).
                shutil.copy2(src, self.image_dir)
                continue
            frame = cv2.imread(src)
            if frame is None:
                # cv2.imread returns None for unreadable / non-image files.
                print(f'Skipping {src}: not a readable image.')
                continue
            cv2.imwrite(path.join(self.image_dir, image_name), self._resize_frame(frame))
        print('Done!')

    def add_key_and_stuff_with_mask(self, ti, key, shrinkage, selection, mask):
        """Store key/shrinkage/selection tensors and the mask for frame ``ti``.

        Storage for all ``self.length`` frames is allocated on the first call,
        shaped from ``key.squeeze()`` (unpacked as ``(C, H, W)``) and ``mask``
        (unpacked as ``(C_mask, H_mask, W_mask)``), on ``key``'s device.
        Key, shrinkage and selection are stored only the first time a frame is
        submitted; the mask is overwritten on every call.
        """
        if self._keys is None:
            c, h, w = key.squeeze().shape
            if self.key_h is None:
                self.key_h = h
            if self.key_w is None:
                self.key_w = w
            c_mask, h_mask, w_mask = mask.shape
            self._keys = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device)
            self._shrinkages = torch.empty((self.length, 1, h, w), dtype=key.dtype, device=key.device)
            self._selections = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device)
            self._masks = torch.empty((self.length, c_mask, h_mask, w_mask), dtype=mask.dtype, device=key.device)

        if not self._keys_processed[ti]:
            # Only the first submission for a frame is kept.
            self._keys[ti] = key
            self._shrinkages[ti] = shrinkage
            self._selections[ti] = selection
            self._keys_processed[ti] = True

        self._masks[ti] = mask

    def all_masks_present(self):
        """Return True once every frame has been through add_key_and_stuff_with_mask."""
        return bool(self._keys_processed.all())

    def add_reference(self, frame_id: int):
        """Mark ``frame_id`` as a reference frame and persist ``info.json``."""
        self.references.add(frame_id)
        self._save_info()

    def remove_reference(self, frame_id: int):
        """Unmark ``frame_id`` and persist; raises KeyError if it is not a reference."""
        self.references.remove(frame_id)
        self._save_info()

    def _save_info(self):
        """Write the sorted reference ids and the object count to ``info.json``."""
        Path(self.workspace_info_file).parent.mkdir(parents=True, exist_ok=True)
        # Build the payload before opening, so a failure cannot truncate the file.
        data = {'references': sorted(self.references), 'num_objects': self._num_objects}
        with open(self.workspace_info_file, 'wt') as f:
            json.dump(data, f, indent=4)

    def _try_load_info(self):
        """Load references / object count from ``info.json`` if it exists.

        A missing file is expected (fresh workspace) and ignored. An unreadable
        or malformed file is reported and ignored, leaving the current values
        untouched; both fields are applied only if both parse.
        """
        try:
            with open(self.workspace_info_file) as f:
                data = json.load(f)
            num_objects = data['num_objects']
            references = set(data['references'])
        except FileNotFoundError:
            return
        except (OSError, ValueError, KeyError, TypeError) as e:
            print(f'Could not load {self.workspace_info_file}: {e}')
            return
        self._num_objects = num_objects
        self.references = references

    def save_mask(self, ti, mask):
        """Save ``mask`` (a numpy array) for frame ``ti`` as a palettized PNG."""
        assert 0 <= ti < self.length
        assert isinstance(mask, np.ndarray)

        mask = Image.fromarray(mask)
        mask.putpalette(self.palette)
        mask.save(path.join(self.mask_dir, self.names[ti] + '.png'))
        # Drop the cached copy so the next get_mask re-reads the new file.
        self.invalidate(ti)

    def save_visualization(self, ti, image):
        """Save ``image`` (a numpy array) for frame ``ti`` as JPEG under visualization/."""
        assert 0 <= ti < self.length
        assert isinstance(image, np.ndarray)
        if not self.visualization_init:
            self.visualization_dir = path.join(self.workspace, 'visualization')
            os.makedirs(self.visualization_dir, exist_ok=True)
            self.visualization_init = True

        image = Image.fromarray(image)
        image.save(path.join(self.visualization_dir, self.names[ti] + '.jpg'))

    def _get_image_unbuffered(self, ti):
        """Read frame ``ti`` from ``image_dir`` as a numpy array (uncached)."""
        assert 0 <= ti < self.length
        with Image.open(path.join(self.image_dir, self._image_files[ti])) as image:
            return np.array(image)

    def _get_mask_unbuffered(self, ti):
        """Read the saved mask for frame ``ti`` as a numpy array, or None (uncached)."""
        assert 0 <= ti < self.length
        mask_path = path.join(self.mask_dir, self.names[ti] + '.png')
        if not path.exists(mask_path):
            return None
        with Image.open(mask_path) as mask:
            return np.array(mask)

    def read_external_image(self, file_name, size=None, force_mask=False):
        """Read an arbitrary image file as a numpy array.

        Args:
            file_name: Path of the image to read.
            size: Optional ``(height, width)`` to resize to. Masks (mode 'L'/'P',
                or ``force_mask``) use nearest-neighbour, others bicubic.
            force_mask: Convert non-'P' images into index masks with the
                palette converter.
        """
        with Image.open(file_name) as opened:
            image = opened
            is_mask = image.mode in ['L', 'P']
            if size is not None:
                # PIL's resize takes (width, height), hence the swap.
                resample = Image.Resampling.NEAREST if is_mask or force_mask else Image.Resampling.BICUBIC
                image = image.resize((size[1], size[0]), resample=resample)
            if force_mask and image.mode != 'P':
                image = self.palette_converter.image_to_index_mask(image)
            return np.array(image)

    def invalidate(self, ti):
        """Evict frame ``ti`` from the mask cache."""
        self.get_mask.invalidate((ti,))

    def __len__(self):
        return self.length

    @property
    def h(self):
        return self.height

    @property
    def w(self):
        return self.width

    @property
    def small_masks(self):
        return self._masks

    @property
    def keys(self):
        return self._keys

    @property
    def shrinkages(self):
        return self._shrinkages

    @property
    def selections(self):
        return self._selections

    @property
    def num_objects(self):
        return self._num_objects
|
|
|
|