Spaces:
Sleeping
Sleeping
| import copy | |
| import os | |
| import numpy as np | |
| import torch | |
| from typing import Any, Dict, List | |
| from yacs.config import CfgNode | |
| import braceexpand | |
| import cv2 | |
| from .dataset import Dataset | |
| from .utils import get_example, expand_to_aspect_ratio | |
| def expand(s): | |
| return os.path.expanduser(os.path.expandvars(s)) | |
| def expand_urls(urls: str|List[str]): | |
| if isinstance(urls, str): | |
| urls = [urls] | |
| urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))] | |
| return urls | |
| FLIP_KEYPOINT_PERMUTATION = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] | |
| DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406]) | |
| DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225]) | |
| DEFAULT_IMG_SIZE = 256 | |
| class ImageDataset(Dataset): | |
| def __init__(self, | |
| cfg: CfgNode, | |
| dataset_file: str, | |
| img_dir: str, | |
| train: bool = True, | |
| rescale_factor = 2, | |
| prune: Dict[str, Any] = {}, | |
| **kwargs): | |
| """ | |
| Dataset class used for loading images and corresponding annotations. | |
| Args: | |
| cfg (CfgNode): Model config file. | |
| dataset_file (str): Path to npz file containing dataset info. | |
| img_dir (str): Path to image folder. | |
| train (bool): Whether it is for training or not (enables data augmentation). | |
| """ | |
| super(ImageDataset, self).__init__() | |
| self.train = train | |
| self.cfg = cfg | |
| self.img_size = cfg.MODEL.IMAGE_SIZE | |
| self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN) | |
| self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD) | |
| self.rescale_factor = rescale_factor | |
| self.img_dir = img_dir | |
| self.data = np.load(dataset_file, allow_pickle=True) | |
| self.imgname = self.data['imgname'] | |
| self.personid = np.zeros(len(self.imgname), dtype=np.int32) | |
| self.extra_info = self.data.get('extra_info', [{} for _ in range(len(self.imgname))]) | |
| self.flip_keypoint_permutation = copy.copy(FLIP_KEYPOINT_PERMUTATION) | |
| num_pose = 3 * (self.cfg.MANO.NUM_HAND_JOINTS + 1) | |
| # Bounding boxes are assumed to be in the center and scale format | |
| self.center = self.data['center'] | |
| self.scale = self.data['scale'].reshape(len(self.center), -1) / 200.0 | |
| if self.scale.shape[1] == 1: | |
| self.scale = np.tile(self.scale, (1, 2)) | |
| assert self.scale.shape == (len(self.center), 2) | |
| try: | |
| self.right = self.data['right'] | |
| except KeyError: | |
| self.right = np.ones(len(self.imgname), dtype=np.float32) | |
| # Get gt MANO parameters, if available | |
| try: | |
| self.hand_pose = self.data['hand_pose'].astype(np.float32) | |
| self.has_hand_pose = self.data['has_hand_pose'].astype(np.float32) | |
| except KeyError: | |
| self.hand_pose = np.zeros((len(self.imgname), num_pose), dtype=np.float32) | |
| self.has_hand_pose = np.zeros(len(self.imgname), dtype=np.float32) | |
| try: | |
| self.betas = self.data['betas'].astype(np.float32) | |
| self.has_betas = self.data['has_betas'].astype(np.float32) | |
| except KeyError: | |
| self.betas = np.zeros((len(self.imgname), 10), dtype=np.float32) | |
| self.has_betas = np.zeros(len(self.imgname), dtype=np.float32) | |
| # Try to get 2d keypoints, if available | |
| try: | |
| hand_keypoints_2d = self.data['hand_keypoints_2d'] | |
| except KeyError: | |
| hand_keypoints_2d = np.zeros((len(self.center), 21, 3)) | |
| self.keypoints_2d = hand_keypoints_2d | |
| # Try to get 3d keypoints, if available | |
| try: | |
| hand_keypoints_3d = self.data['hand_keypoints_3d'].astype(np.float32) | |
| except KeyError: | |
| hand_keypoints_3d = np.zeros((len(self.center), 21, 4), dtype=np.float32) | |
| self.keypoints_3d = hand_keypoints_3d | |
| def __len__(self) -> int: | |
| return len(self.scale) | |
| def __getitem__(self, idx: int) -> Dict: | |
| """ | |
| Returns an example from the dataset. | |
| """ | |
| try: | |
| image_file_rel = self.imgname[idx].decode('utf-8') | |
| except AttributeError: | |
| image_file_rel = self.imgname[idx] | |
| image_file = os.path.join(self.img_dir, image_file_rel) | |
| keypoints_2d = self.keypoints_2d[idx].copy() | |
| keypoints_3d = self.keypoints_3d[idx].copy() | |
| center = self.center[idx].copy() | |
| center_x = center[0] | |
| center_y = center[1] | |
| scale = self.scale[idx] | |
| right = self.right[idx].copy() | |
| if self.rescale_factor == -1: | |
| BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None) | |
| bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max() | |
| bbox_expand_factor = bbox_size / ((scale*200).max()) | |
| else: | |
| bbox_expand_factor = self.rescale_factor | |
| bbox_size = bbox_expand_factor*scale.max()*200 | |
| hand_pose = self.hand_pose[idx].copy().astype(np.float32) | |
| betas = self.betas[idx].copy().astype(np.float32) | |
| has_hand_pose = self.has_hand_pose[idx].copy() | |
| has_betas = self.has_betas[idx].copy() | |
| mano_params = {'global_orient': hand_pose[:3], | |
| 'hand_pose': hand_pose[3:], | |
| 'betas': betas | |
| } | |
| has_mano_params = {'global_orient': has_hand_pose, | |
| 'hand_pose': has_hand_pose, | |
| 'betas': has_betas | |
| } | |
| mano_params_is_axis_angle = {'global_orient': True, | |
| 'hand_pose': True, | |
| 'betas': False | |
| } | |
| augm_config = self.cfg.DATASETS.CONFIG | |
| # Crop image and (possibly) perform data augmentation | |
| img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size = get_example(image_file, | |
| center_x, center_y, | |
| bbox_size, bbox_size, | |
| keypoints_2d, keypoints_3d, | |
| mano_params, has_mano_params, | |
| self.flip_keypoint_permutation, | |
| self.img_size, self.img_size, | |
| self.mean, self.std, self.train, right, augm_config) | |
| item = {} | |
| # These are the keypoints in the original image coordinates (before cropping) | |
| orig_keypoints_2d = self.keypoints_2d[idx].copy() | |
| item['img'] = img_patch | |
| item['keypoints_2d'] = keypoints_2d.astype(np.float32) | |
| item['keypoints_3d'] = keypoints_3d.astype(np.float32) | |
| item['orig_keypoints_2d'] = orig_keypoints_2d | |
| item['box_center'] = self.center[idx].copy() | |
| item['box_size'] = bbox_size | |
| item['bbox_expand_factor'] = bbox_expand_factor | |
| item['img_size'] = 1.0 * img_size[::-1].copy() | |
| item['mano_params'] = mano_params | |
| item['has_mano_params'] = has_mano_params | |
| item['mano_params_is_axis_angle'] = mano_params_is_axis_angle | |
| item['imgname'] = image_file | |
| item['imgname_rel'] = image_file_rel | |
| item['personid'] = int(self.personid[idx]) | |
| item['extra_info'] = copy.deepcopy(self.extra_info[idx]) | |
| item['idx'] = idx | |
| item['_scale'] = scale | |
| item['right'] = self.right[idx].copy() | |
| return item | |
| def load_tars_as_webdataset(cfg: CfgNode, urls: str|List[str], train: bool, | |
| resampled=False, | |
| epoch_size=None, | |
| cache_dir=None, | |
| **kwargs) -> Dataset: | |
| """ | |
| Loads the dataset from a webdataset tar file. | |
| """ | |
| IMG_SIZE = cfg.MODEL.IMAGE_SIZE | |
| BBOX_SHAPE = cfg.MODEL.get('BBOX_SHAPE', None) | |
| MEAN = 255. * np.array(cfg.MODEL.IMAGE_MEAN) | |
| STD = 255. * np.array(cfg.MODEL.IMAGE_STD) | |
| def split_data(source): | |
| for item in source: | |
| datas = item['data.pyd'] | |
| for data in datas: | |
| if 'detection.npz' in item: | |
| det_idx = data['extra_info']['detection_npz_idx'] | |
| mask = item['detection.npz']['masks'][det_idx] | |
| else: | |
| mask = np.ones_like(item['jpg'][:,:,0], dtype=bool) | |
| yield { | |
| '__key__': item['__key__'], | |
| 'jpg': item['jpg'], | |
| 'data.pyd': data, | |
| 'mask': mask, | |
| } | |
| def suppress_bad_kps(item, thresh=0.0): | |
| if thresh > 0: | |
| kp2d = item['data.pyd']['keypoints_2d'] | |
| kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2]) | |
| item['data.pyd']['keypoints_2d'] = np.concatenate([kp2d[:,:2], kp2d_conf[:,None]], axis=1) | |
| return item | |
| def filter_numkp(item, numkp=4, thresh=0.0): | |
| kp_conf = item['data.pyd']['keypoints_2d'][:, 2] | |
| return (kp_conf > thresh).sum() > numkp | |
| def filter_reproj_error(item, thresh=10**4.5): | |
| losses = item['data.pyd'].get('extra_info', {}).get('fitting_loss', np.array({})).item() | |
| reproj_loss = losses.get('reprojection_loss', None) | |
| return reproj_loss is None or reproj_loss < thresh | |
| def filter_bbox_size(item, thresh=1): | |
| bbox_size_min = item['data.pyd']['scale'].min().item() * 200. | |
| return bbox_size_min > thresh | |
| def filter_no_poses(item): | |
| return (item['data.pyd']['has_hand_pose'] > 0) | |
| def supress_bad_betas(item, thresh=3): | |
| has_betas = item['data.pyd']['has_betas'] | |
| if thresh > 0 and has_betas: | |
| betas_abs = np.abs(item['data.pyd']['betas']) | |
| if (betas_abs > thresh).any(): | |
| item['data.pyd']['has_betas'] = False | |
| return item | |
| def supress_bad_poses(item): | |
| has_hand_pose = item['data.pyd']['has_hand_pose'] | |
| if has_hand_pose: | |
| hand_pose = item['data.pyd']['hand_pose'] | |
| pose_is_probable = poses_check_probable(torch.from_numpy(hand_pose)[None, 3:], amass_poses_hist100_smooth).item() | |
| if not pose_is_probable: | |
| item['data.pyd']['has_hand_pose'] = False | |
| return item | |
| def poses_betas_simultaneous(item): | |
| # We either have both hand_pose and betas, or neither | |
| has_betas = item['data.pyd']['has_betas'] | |
| has_hand_pose = item['data.pyd']['has_hand_pose'] | |
| item['data.pyd']['has_betas'] = item['data.pyd']['has_hand_pose'] = np.array(float((has_hand_pose>0) and (has_betas>0))) | |
| return item | |
| def set_betas_for_reg(item): | |
| # Always have betas set to true | |
| has_betas = item['data.pyd']['has_betas'] | |
| betas = item['data.pyd']['betas'] | |
| if not (has_betas>0): | |
| item['data.pyd']['has_betas'] = np.array(float((True))) | |
| item['data.pyd']['betas'] = betas * 0 | |
| return item | |
| # Load the dataset | |
| if epoch_size is not None: | |
| resampled = True | |
| #corrupt_filter = lambda sample: (sample['__key__'] not in CORRUPT_KEYS) | |
| import webdataset as wds | |
| dataset = wds.WebDataset(expand_urls(urls), | |
| nodesplitter=wds.split_by_node, | |
| shardshuffle=True, | |
| resampled=resampled, | |
| cache_dir=cache_dir, | |
| ) #.select(corrupt_filter) | |
| if train: | |
| dataset = dataset.shuffle(100) | |
| dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png') | |
| # Process the dataset | |
| dataset = dataset.compose(split_data) | |
| # Filter/clean the dataset | |
| SUPPRESS_KP_CONF_THRESH = cfg.DATASETS.get('SUPPRESS_KP_CONF_THRESH', 0.0) | |
| SUPPRESS_BETAS_THRESH = cfg.DATASETS.get('SUPPRESS_BETAS_THRESH', 0.0) | |
| SUPPRESS_BAD_POSES = cfg.DATASETS.get('SUPPRESS_BAD_POSES', False) | |
| POSES_BETAS_SIMULTANEOUS = cfg.DATASETS.get('POSES_BETAS_SIMULTANEOUS', False) | |
| BETAS_REG = cfg.DATASETS.get('BETAS_REG', False) | |
| FILTER_NO_POSES = cfg.DATASETS.get('FILTER_NO_POSES', False) | |
| FILTER_NUM_KP = cfg.DATASETS.get('FILTER_NUM_KP', 4) | |
| FILTER_NUM_KP_THRESH = cfg.DATASETS.get('FILTER_NUM_KP_THRESH', 0.0) | |
| FILTER_REPROJ_THRESH = cfg.DATASETS.get('FILTER_REPROJ_THRESH', 0.0) | |
| FILTER_MIN_BBOX_SIZE = cfg.DATASETS.get('FILTER_MIN_BBOX_SIZE', 0.0) | |
| if SUPPRESS_KP_CONF_THRESH > 0: | |
| dataset = dataset.map(lambda x: suppress_bad_kps(x, thresh=SUPPRESS_KP_CONF_THRESH)) | |
| if SUPPRESS_BETAS_THRESH > 0: | |
| dataset = dataset.map(lambda x: supress_bad_betas(x, thresh=SUPPRESS_BETAS_THRESH)) | |
| if SUPPRESS_BAD_POSES: | |
| dataset = dataset.map(lambda x: supress_bad_poses(x)) | |
| if POSES_BETAS_SIMULTANEOUS: | |
| dataset = dataset.map(lambda x: poses_betas_simultaneous(x)) | |
| if FILTER_NO_POSES: | |
| dataset = dataset.select(lambda x: filter_no_poses(x)) | |
| if FILTER_NUM_KP > 0: | |
| dataset = dataset.select(lambda x: filter_numkp(x, numkp=FILTER_NUM_KP, thresh=FILTER_NUM_KP_THRESH)) | |
| if FILTER_REPROJ_THRESH > 0: | |
| dataset = dataset.select(lambda x: filter_reproj_error(x, thresh=FILTER_REPROJ_THRESH)) | |
| if FILTER_MIN_BBOX_SIZE > 0: | |
| dataset = dataset.select(lambda x: filter_bbox_size(x, thresh=FILTER_MIN_BBOX_SIZE)) | |
| if BETAS_REG: | |
| dataset = dataset.map(lambda x: set_betas_for_reg(x)) # NOTE: Must be at the end | |
| use_skimage_antialias = cfg.DATASETS.get('USE_SKIMAGE_ANTIALIAS', False) | |
| border_mode = { | |
| 'constant': cv2.BORDER_CONSTANT, | |
| 'replicate': cv2.BORDER_REPLICATE, | |
| }[cfg.DATASETS.get('BORDER_MODE', 'constant')] | |
| # Process the dataset further | |
| dataset = dataset.map(lambda x: ImageDataset.process_webdataset_tar_item(x, train, | |
| augm_config=cfg.DATASETS.CONFIG, | |
| MEAN=MEAN, STD=STD, IMG_SIZE=IMG_SIZE, | |
| BBOX_SHAPE=BBOX_SHAPE, | |
| use_skimage_antialias=use_skimage_antialias, | |
| border_mode=border_mode, | |
| )) | |
| if epoch_size is not None: | |
| dataset = dataset.with_epoch(epoch_size) | |
| return dataset | |
| def process_webdataset_tar_item(item, train, | |
| augm_config=None, | |
| MEAN=DEFAULT_MEAN, | |
| STD=DEFAULT_STD, | |
| IMG_SIZE=DEFAULT_IMG_SIZE, | |
| BBOX_SHAPE=None, | |
| use_skimage_antialias=False, | |
| border_mode=cv2.BORDER_CONSTANT, | |
| ): | |
| # Read data from item | |
| key = item['__key__'] | |
| image = item['jpg'] | |
| data = item['data.pyd'] | |
| mask = item['mask'] | |
| keypoints_2d = data['keypoints_2d'] | |
| keypoints_3d = data['keypoints_3d'] | |
| center = data['center'] | |
| scale = data['scale'] | |
| hand_pose = data['hand_pose'] | |
| betas = data['betas'] | |
| right = data['right'] | |
| has_hand_pose = data['has_hand_pose'] | |
| has_betas = data['has_betas'] | |
| # image_file = data['image_file'] | |
| # Process data | |
| orig_keypoints_2d = keypoints_2d.copy() | |
| center_x = center[0] | |
| center_y = center[1] | |
| bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max() | |
| if bbox_size < 1: | |
| breakpoint() | |
| mano_params = {'global_orient': hand_pose[:3], | |
| 'hand_pose': hand_pose[3:], | |
| 'betas': betas | |
| } | |
| has_mano_params = {'global_orient': has_hand_pose, | |
| 'hand_pose': has_hand_pose, | |
| 'betas': has_betas | |
| } | |
| mano_params_is_axis_angle = {'global_orient': True, | |
| 'hand_pose': True, | |
| 'betas': False | |
| } | |
| augm_config = copy.deepcopy(augm_config) | |
| # Crop image and (possibly) perform data augmentation | |
| img_rgba = np.concatenate([image, mask.astype(np.uint8)[:,:,None]*255], axis=2) | |
| img_patch_rgba, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans = get_example(img_rgba, | |
| center_x, center_y, | |
| bbox_size, bbox_size, | |
| keypoints_2d, keypoints_3d, | |
| mano_params, has_mano_params, | |
| FLIP_KEYPOINT_PERMUTATION, | |
| IMG_SIZE, IMG_SIZE, | |
| MEAN, STD, train, right, augm_config, | |
| is_bgr=False, return_trans=True, | |
| use_skimage_antialias=use_skimage_antialias, | |
| border_mode=border_mode, | |
| ) | |
| img_patch = img_patch_rgba[:3,:,:] | |
| mask_patch = (img_patch_rgba[3,:,:] / 255.0).clip(0,1) | |
| if (mask_patch < 0.5).all(): | |
| mask_patch = np.ones_like(mask_patch) | |
| item = {} | |
| item['img'] = img_patch | |
| item['mask'] = mask_patch | |
| # item['img_og'] = image | |
| # item['mask_og'] = mask | |
| item['keypoints_2d'] = keypoints_2d.astype(np.float32) | |
| item['keypoints_3d'] = keypoints_3d.astype(np.float32) | |
| item['orig_keypoints_2d'] = orig_keypoints_2d | |
| item['box_center'] = center.copy() | |
| item['box_size'] = bbox_size | |
| item['img_size'] = 1.0 * img_size[::-1].copy() | |
| item['mano_params'] = mano_params | |
| item['has_mano_params'] = has_mano_params | |
| item['mano_params_is_axis_angle'] = mano_params_is_axis_angle | |
| item['_scale'] = scale | |
| item['_trans'] = trans | |
| item['imgname'] = key | |
| # item['idx'] = idx | |
| return item | |