xfu314's picture
Add phantom project with submodules and dependencies
96da58e
import copy
import os
import numpy as np
import torch
from typing import Any, Dict, List
from yacs.config import CfgNode
import braceexpand
import cv2
from .dataset import Dataset
from .utils import get_example, expand_to_aspect_ratio
def expand(s):
    """Return *s* with environment variables and a leading ``~`` expanded.

    Environment variables are substituted first (``os.path.expandvars``),
    then the user home directory (``os.path.expanduser``).
    """
    with_env_vars = os.path.expandvars(s)
    return os.path.expanduser(with_env_vars)
def expand_urls(urls: str|List[str]):
    """Expand one or more URL patterns into a flat list of concrete URLs.

    Each pattern first goes through ``expand`` (environment variables and
    ``~``) and is then brace-expanded with ``braceexpand.braceexpand``.

    Args:
        urls: A single pattern, or a list of patterns.
    Returns:
        A new list with the expansions of every pattern, in input order.
    """
    patterns = [urls] if isinstance(urls, str) else urls
    expanded = []
    for pattern in patterns:
        expanded.extend(braceexpand.braceexpand(expand(pattern)))
    return expanded
# Keypoint permutation handed to get_example for flip augmentation.
# It is the identity over 21 indices, i.e. no keypoints are swapped.
FLIP_KEYPOINT_PERMUTATION = list(range(21))

# Fallback per-channel normalization constants on a 0-255 scale, used as
# defaults by ImageDataset.process_webdataset_tar_item.
DEFAULT_MEAN = np.array([0.485, 0.456, 0.406]) * 255.
DEFAULT_STD = np.array([0.229, 0.224, 0.225]) * 255.

# Fallback side length passed to get_example as the output patch size.
DEFAULT_IMG_SIZE = 256
class ImageDataset(Dataset):
    """Image dataset with hand keypoint and MANO annotations.

    Samples can be read either from an ``.npz`` annotation file plus an image
    folder (instance usage via ``__getitem__``) or from webdataset tar shards
    (``load_tars_as_webdataset``). In both cases cropping/augmentation is
    delegated to ``get_example``.
    """

    def __init__(self,
                 cfg: CfgNode,
                 dataset_file: str,
                 img_dir: str,
                 train: bool = True,
                 rescale_factor = 2,
                 prune: Dict[str, Any] | None = None,
                 **kwargs):
        """
        Dataset class used for loading images and corresponding annotations.
        Args:
            cfg (CfgNode): Model config file.
            dataset_file (str): Path to npz file containing dataset info.
            img_dir (str): Path to image folder.
            train (bool): Whether it is for training or not (enables data augmentation).
            rescale_factor: Multiplier applied to the largest box side to obtain the
                crop size. The special value -1 derives the crop size from
                ``expand_to_aspect_ratio`` and ``cfg.MODEL.BBOX_SHAPE`` instead.
            prune: Currently unused; accepted for interface compatibility.
            **kwargs: Ignored.
        """
        super().__init__()
        self.train = train
        self.cfg = cfg

        self.img_size = cfg.MODEL.IMAGE_SIZE
        self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
        self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
        self.rescale_factor = rescale_factor

        self.img_dir = img_dir
        # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code. Only load trusted dataset files.
        self.data = np.load(dataset_file, allow_pickle=True)

        self.imgname = self.data['imgname']
        self.personid = np.zeros(len(self.imgname), dtype=np.int32)
        self.extra_info = self.data.get('extra_info', [{} for _ in range(len(self.imgname))])

        self.flip_keypoint_permutation = copy.copy(FLIP_KEYPOINT_PERMUTATION)

        num_pose = 3 * (self.cfg.MANO.NUM_HAND_JOINTS + 1)

        # Bounding boxes are assumed to be in the center and scale format
        self.center = self.data['center']
        self.scale = self.data['scale'].reshape(len(self.center), -1) / 200.0
        if self.scale.shape[1] == 1:
            # A single scale per box is duplicated so every box has two entries.
            self.scale = np.tile(self.scale, (1, 2))
        assert self.scale.shape == (len(self.center), 2)

        # Missing 'right' flags default to 1 for every sample.
        try:
            self.right = self.data['right']
        except KeyError:
            self.right = np.ones(len(self.imgname), dtype=np.float32)

        # Get gt MANO parameters, if available
        try:
            self.hand_pose = self.data['hand_pose'].astype(np.float32)
            self.has_hand_pose = self.data['has_hand_pose'].astype(np.float32)
        except KeyError:
            self.hand_pose = np.zeros((len(self.imgname), num_pose), dtype=np.float32)
            self.has_hand_pose = np.zeros(len(self.imgname), dtype=np.float32)
        try:
            self.betas = self.data['betas'].astype(np.float32)
            self.has_betas = self.data['has_betas'].astype(np.float32)
        except KeyError:
            self.betas = np.zeros((len(self.imgname), 10), dtype=np.float32)
            self.has_betas = np.zeros(len(self.imgname), dtype=np.float32)

        # Try to get 2d keypoints, if available
        try:
            hand_keypoints_2d = self.data['hand_keypoints_2d']
        except KeyError:
            hand_keypoints_2d = np.zeros((len(self.center), 21, 3))
        self.keypoints_2d = hand_keypoints_2d

        # Try to get 3d keypoints, if available
        try:
            hand_keypoints_3d = self.data['hand_keypoints_3d'].astype(np.float32)
        except KeyError:
            hand_keypoints_3d = np.zeros((len(self.center), 21, 4), dtype=np.float32)
        self.keypoints_3d = hand_keypoints_3d

    def __len__(self) -> int:
        """Return the number of annotated boxes (rows of ``self.scale``)."""
        return len(self.scale)

    @staticmethod
    def _pack_mano_params(hand_pose, betas, has_hand_pose, has_betas):
        """Group MANO annotations into the (params, availability, is_axis_angle) dicts.

        The first three entries of ``hand_pose`` are used as ``global_orient``
        and the remainder as ``hand_pose``.
        """
        mano_params = {'global_orient': hand_pose[:3],
                       'hand_pose': hand_pose[3:],
                       'betas': betas
                       }
        has_mano_params = {'global_orient': has_hand_pose,
                           'hand_pose': has_hand_pose,
                           'betas': has_betas
                           }
        mano_params_is_axis_angle = {'global_orient': True,
                                     'hand_pose': True,
                                     'betas': False
                                     }
        return mano_params, has_mano_params, mano_params_is_axis_angle

    def __getitem__(self, idx: int) -> Dict:
        """
        Returns an example from the dataset.

        The returned dict contains the cropped image (``img``), the keypoints
        returned by ``get_example`` (``keypoints_2d``, ``keypoints_3d``), the
        untouched 2D keypoints (``orig_keypoints_2d``), box information
        (``box_center``, ``box_size``, ``bbox_expand_factor``, ``_scale``),
        ``img_size``, the MANO dicts (``mano_params``, ``has_mano_params``,
        ``mano_params_is_axis_angle``) and bookkeeping fields (``imgname``,
        ``imgname_rel``, ``personid``, ``extra_info``, ``idx``, ``right``).
        """
        # Image names may be stored as bytes or as str.
        try:
            image_file_rel = self.imgname[idx].decode('utf-8')
        except AttributeError:
            image_file_rel = self.imgname[idx]
        image_file = os.path.join(self.img_dir, image_file_rel)

        keypoints_2d = self.keypoints_2d[idx].copy()
        keypoints_3d = self.keypoints_3d[idx].copy()

        center = self.center[idx].copy()
        center_x = center[0]
        center_y = center[1]
        # Copied like the other per-sample fields so that consumers of
        # item['_scale'] cannot mutate the dataset's own array.
        scale = self.scale[idx].copy()
        right = self.right[idx].copy()

        if self.rescale_factor == -1:
            BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
            bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
            bbox_expand_factor = bbox_size / ((scale*200).max())
        else:
            bbox_expand_factor = self.rescale_factor
            bbox_size = bbox_expand_factor*scale.max()*200

        hand_pose = self.hand_pose[idx].copy().astype(np.float32)
        betas = self.betas[idx].copy().astype(np.float32)
        has_hand_pose = self.has_hand_pose[idx].copy()
        has_betas = self.has_betas[idx].copy()

        mano_params, has_mano_params, mano_params_is_axis_angle = self._pack_mano_params(
            hand_pose, betas, has_hand_pose, has_betas)

        augm_config = self.cfg.DATASETS.CONFIG

        # Crop image and (possibly) perform data augmentation
        img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size = get_example(
            image_file,
            center_x, center_y,
            bbox_size, bbox_size,
            keypoints_2d, keypoints_3d,
            mano_params, has_mano_params,
            self.flip_keypoint_permutation,
            self.img_size, self.img_size,
            self.mean, self.std, self.train, right, augm_config)

        item = {}
        # These are the keypoints in the original image coordinates (before cropping)
        orig_keypoints_2d = self.keypoints_2d[idx].copy()

        item['img'] = img_patch
        item['keypoints_2d'] = keypoints_2d.astype(np.float32)
        item['keypoints_3d'] = keypoints_3d.astype(np.float32)
        item['orig_keypoints_2d'] = orig_keypoints_2d
        item['box_center'] = self.center[idx].copy()
        item['box_size'] = bbox_size
        item['bbox_expand_factor'] = bbox_expand_factor
        item['img_size'] = 1.0 * img_size[::-1].copy()
        item['mano_params'] = mano_params
        item['has_mano_params'] = has_mano_params
        item['mano_params_is_axis_angle'] = mano_params_is_axis_angle
        item['imgname'] = image_file
        item['imgname_rel'] = image_file_rel
        item['personid'] = int(self.personid[idx])
        item['extra_info'] = copy.deepcopy(self.extra_info[idx])
        item['idx'] = idx
        item['_scale'] = scale
        item['right'] = self.right[idx].copy()
        return item

    @staticmethod
    def load_tars_as_webdataset(cfg: CfgNode, urls: str|List[str], train: bool,
                                resampled=False,
                                epoch_size=None,
                                cache_dir=None,
                                **kwargs) -> Dataset:
        """
        Loads the dataset from a webdataset tar file.

        Args:
            cfg (CfgNode): Config; ``cfg.MODEL`` supplies image size / normalization
                and ``cfg.DATASETS`` the optional cleaning and filtering switches.
            urls: One or more shard URL patterns, expanded with ``expand_urls``.
            train (bool): Enables sample shuffling and is forwarded to
                ``process_webdataset_tar_item``.
            resampled: Forwarded to ``wds.WebDataset``; forced to True when
                ``epoch_size`` is given.
            epoch_size: If not None, applied through ``with_epoch``.
            cache_dir: Forwarded to ``wds.WebDataset``.
            **kwargs: Ignored.
        Returns:
            The composed webdataset pipeline.
        """
        IMG_SIZE = cfg.MODEL.IMAGE_SIZE
        BBOX_SHAPE = cfg.MODEL.get('BBOX_SHAPE', None)
        MEAN = 255. * np.array(cfg.MODEL.IMAGE_MEAN)
        STD = 255. * np.array(cfg.MODEL.IMAGE_STD)

        def split_data(source):
            """Yield one sample per annotation in ``item['data.pyd']``.

            Each sample carries the matching mask from ``item['detection.npz']``
            when present, otherwise an all-True mask with the image's first two
            dimensions.
            """
            for item in source:
                datas = item['data.pyd']
                for data in datas:
                    if 'detection.npz' in item:
                        det_idx = data['extra_info']['detection_npz_idx']
                        mask = item['detection.npz']['masks'][det_idx]
                    else:
                        mask = np.ones_like(item['jpg'][:,:,0], dtype=bool)
                    yield {
                        '__key__': item['__key__'],
                        'jpg': item['jpg'],
                        'data.pyd': data,
                        'mask': mask,
                    }

        def suppress_bad_kps(item, thresh=0.0):
            """Zero the third keypoint column (confidence) wherever it is below ``thresh``."""
            if thresh > 0:
                kp2d = item['data.pyd']['keypoints_2d']
                kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2])
                item['data.pyd']['keypoints_2d'] = np.concatenate([kp2d[:,:2], kp2d_conf[:,None]], axis=1)
            return item

        def filter_numkp(item, numkp=4, thresh=0.0):
            """Keep samples with more than ``numkp`` keypoints whose confidence exceeds ``thresh``."""
            kp_conf = item['data.pyd']['keypoints_2d'][:, 2]
            return (kp_conf > thresh).sum() > numkp

        def filter_reproj_error(item, thresh=10**4.5):
            """Keep samples with no recorded reprojection loss, or one below ``thresh``."""
            # 'fitting_loss' is unwrapped with .item(); the default is an empty dict
            # wrapped the same way so the lookup below works in both cases.
            losses = item['data.pyd'].get('extra_info', {}).get('fitting_loss', np.array({})).item()
            reproj_loss = losses.get('reprojection_loss', None)
            return reproj_loss is None or reproj_loss < thresh

        def filter_bbox_size(item, thresh=1):
            """Keep samples whose smallest box side (scale * 200) exceeds ``thresh``."""
            bbox_size_min = item['data.pyd']['scale'].min().item() * 200.
            return bbox_size_min > thresh

        def filter_no_poses(item):
            """Keep samples whose ``has_hand_pose`` flag is positive."""
            return (item['data.pyd']['has_hand_pose'] > 0)

        def supress_bad_betas(item, thresh=3):
            """Clear ``has_betas`` when any beta exceeds ``thresh`` in absolute value."""
            has_betas = item['data.pyd']['has_betas']
            if thresh > 0 and has_betas:
                betas_abs = np.abs(item['data.pyd']['betas'])
                if (betas_abs > thresh).any():
                    item['data.pyd']['has_betas'] = False
            return item

        def supress_bad_poses(item):
            """Clear ``has_hand_pose`` when the pose fails the probability check."""
            # NOTE(review): `poses_check_probable` and `amass_poses_hist100_smooth`
            # are not defined or imported in the visible part of this module.
            # Unless they are provided elsewhere, enabling
            # DATASETS.SUPPRESS_BAD_POSES raises NameError here -- confirm.
            has_hand_pose = item['data.pyd']['has_hand_pose']
            if has_hand_pose:
                hand_pose = item['data.pyd']['hand_pose']
                pose_is_probable = poses_check_probable(torch.from_numpy(hand_pose)[None, 3:], amass_poses_hist100_smooth).item()
                if not pose_is_probable:
                    item['data.pyd']['has_hand_pose'] = False
            return item

        def poses_betas_simultaneous(item):
            """Set both availability flags to 1.0 only if both were positive, else 0.0."""
            # We either have both hand_pose and betas, or neither
            has_betas = item['data.pyd']['has_betas']
            has_hand_pose = item['data.pyd']['has_hand_pose']
            item['data.pyd']['has_betas'] = item['data.pyd']['has_hand_pose'] = np.array(float((has_hand_pose>0) and (has_betas>0)))
            return item

        def set_betas_for_reg(item):
            """Mark betas as available for every sample, zeroing them where they were not."""
            # Always have betas set to true
            has_betas = item['data.pyd']['has_betas']
            betas = item['data.pyd']['betas']
            if not (has_betas>0):
                item['data.pyd']['has_betas'] = np.array(float((True)))
                item['data.pyd']['betas'] = betas * 0
            return item

        # Load the dataset
        if epoch_size is not None:
            # A fixed epoch size is applied with with_epoch() below; resampling is
            # switched on together with it.
            resampled = True
        #corrupt_filter = lambda sample: (sample['__key__'] not in CORRUPT_KEYS)
        import webdataset as wds
        dataset = wds.WebDataset(expand_urls(urls),
                                 nodesplitter=wds.split_by_node,
                                 shardshuffle=True,
                                 resampled=resampled,
                                 cache_dir=cache_dir,
                                 ) #.select(corrupt_filter)
        if train:
            dataset = dataset.shuffle(100)
        dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png')

        # Process the dataset
        dataset = dataset.compose(split_data)

        # Filter/clean the dataset
        SUPPRESS_KP_CONF_THRESH = cfg.DATASETS.get('SUPPRESS_KP_CONF_THRESH', 0.0)
        SUPPRESS_BETAS_THRESH = cfg.DATASETS.get('SUPPRESS_BETAS_THRESH', 0.0)
        SUPPRESS_BAD_POSES = cfg.DATASETS.get('SUPPRESS_BAD_POSES', False)
        POSES_BETAS_SIMULTANEOUS = cfg.DATASETS.get('POSES_BETAS_SIMULTANEOUS', False)
        BETAS_REG = cfg.DATASETS.get('BETAS_REG', False)
        FILTER_NO_POSES = cfg.DATASETS.get('FILTER_NO_POSES', False)
        FILTER_NUM_KP = cfg.DATASETS.get('FILTER_NUM_KP', 4)
        FILTER_NUM_KP_THRESH = cfg.DATASETS.get('FILTER_NUM_KP_THRESH', 0.0)
        FILTER_REPROJ_THRESH = cfg.DATASETS.get('FILTER_REPROJ_THRESH', 0.0)
        FILTER_MIN_BBOX_SIZE = cfg.DATASETS.get('FILTER_MIN_BBOX_SIZE', 0.0)
        if SUPPRESS_KP_CONF_THRESH > 0:
            dataset = dataset.map(lambda x: suppress_bad_kps(x, thresh=SUPPRESS_KP_CONF_THRESH))
        if SUPPRESS_BETAS_THRESH > 0:
            dataset = dataset.map(lambda x: supress_bad_betas(x, thresh=SUPPRESS_BETAS_THRESH))
        if SUPPRESS_BAD_POSES:
            dataset = dataset.map(supress_bad_poses)
        if POSES_BETAS_SIMULTANEOUS:
            dataset = dataset.map(poses_betas_simultaneous)
        if FILTER_NO_POSES:
            dataset = dataset.select(filter_no_poses)
        if FILTER_NUM_KP > 0:
            dataset = dataset.select(lambda x: filter_numkp(x, numkp=FILTER_NUM_KP, thresh=FILTER_NUM_KP_THRESH))
        if FILTER_REPROJ_THRESH > 0:
            dataset = dataset.select(lambda x: filter_reproj_error(x, thresh=FILTER_REPROJ_THRESH))
        if FILTER_MIN_BBOX_SIZE > 0:
            dataset = dataset.select(lambda x: filter_bbox_size(x, thresh=FILTER_MIN_BBOX_SIZE))
        if BETAS_REG:
            dataset = dataset.map(set_betas_for_reg)       # NOTE: Must be at the end

        use_skimage_antialias = cfg.DATASETS.get('USE_SKIMAGE_ANTIALIAS', False)
        # Unknown BORDER_MODE values raise KeyError from this lookup.
        border_mode = {
            'constant': cv2.BORDER_CONSTANT,
            'replicate': cv2.BORDER_REPLICATE,
        }[cfg.DATASETS.get('BORDER_MODE', 'constant')]

        # Process the dataset further
        dataset = dataset.map(lambda x: ImageDataset.process_webdataset_tar_item(x, train,
                                                        augm_config=cfg.DATASETS.CONFIG,
                                                        MEAN=MEAN, STD=STD, IMG_SIZE=IMG_SIZE,
                                                        BBOX_SHAPE=BBOX_SHAPE,
                                                        use_skimage_antialias=use_skimage_antialias,
                                                        border_mode=border_mode,
                                                        ))
        if epoch_size is not None:
            dataset = dataset.with_epoch(epoch_size)

        return dataset

    @staticmethod
    def process_webdataset_tar_item(item, train,
                                    augm_config=None,
                                    MEAN=DEFAULT_MEAN,
                                    STD=DEFAULT_STD,
                                    IMG_SIZE=DEFAULT_IMG_SIZE,
                                    BBOX_SHAPE=None,
                                    use_skimage_antialias=False,
                                    border_mode=cv2.BORDER_CONSTANT,
                                    ):
        """Turn one webdataset sample into a training/evaluation example.

        Args:
            item: Dict with ``'__key__'``, ``'jpg'`` (image array), ``'data.pyd'``
                (annotation dict) and ``'mask'``.
            train: Forwarded to ``get_example``.
            augm_config: Augmentation config; deep-copied before being forwarded.
            MEAN, STD: Normalization constants forwarded to ``get_example``.
            IMG_SIZE: Output patch width and height forwarded to ``get_example``.
            BBOX_SHAPE: Target aspect ratio for ``expand_to_aspect_ratio``.
            use_skimage_antialias, border_mode: Forwarded to ``get_example``.
        Returns:
            Dict with ``img``, ``mask``, ``keypoints_2d``, ``keypoints_3d``,
            ``orig_keypoints_2d``, ``box_center``, ``box_size``, ``img_size``,
            ``mano_params``, ``has_mano_params``, ``mano_params_is_axis_angle``,
            ``_scale``, ``_trans`` and ``imgname`` (the sample key).
        Raises:
            ValueError: If the expanded bounding box is smaller than one pixel.
        """
        # Read data from item
        key = item['__key__']
        image = item['jpg']
        data = item['data.pyd']
        mask = item['mask']

        keypoints_2d = data['keypoints_2d']
        keypoints_3d = data['keypoints_3d']
        center = data['center']
        scale = data['scale']
        hand_pose = data['hand_pose']
        betas = data['betas']
        right = data['right']
        has_hand_pose = data['has_hand_pose']
        has_betas = data['has_betas']
        # image_file = data['image_file']

        # Process data
        orig_keypoints_2d = keypoints_2d.copy()
        center_x = center[0]
        center_y = center[1]
        bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
        if bbox_size < 1:
            # Fail loudly instead of dropping into an interactive debugger inside
            # the data pipeline; DATASETS.FILTER_MIN_BBOX_SIZE can be used to
            # discard such samples before they reach this point.
            raise ValueError(
                f"Degenerate bounding box for sample {key!r}: bbox_size={bbox_size}"
            )

        mano_params, has_mano_params, mano_params_is_axis_angle = ImageDataset._pack_mano_params(
            hand_pose, betas, has_hand_pose, has_betas)

        augm_config = copy.deepcopy(augm_config)

        # Crop image and (possibly) perform data augmentation
        # The mask is appended as a fourth channel (0/255) so that it goes through
        # the same get_example call as the image.
        img_rgba = np.concatenate([image, mask.astype(np.uint8)[:,:,None]*255], axis=2)
        img_patch_rgba, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans = get_example(
            img_rgba,
            center_x, center_y,
            bbox_size, bbox_size,
            keypoints_2d, keypoints_3d,
            mano_params, has_mano_params,
            FLIP_KEYPOINT_PERMUTATION,
            IMG_SIZE, IMG_SIZE,
            MEAN, STD, train, right, augm_config,
            is_bgr=False, return_trans=True,
            use_skimage_antialias=use_skimage_antialias,
            border_mode=border_mode,
        )
        img_patch = img_patch_rgba[:3,:,:]
        mask_patch = (img_patch_rgba[3,:,:] / 255.0).clip(0,1)
        if (mask_patch < 0.5).all():
            # A mask with no value >= 0.5 is replaced by an all-ones mask.
            mask_patch = np.ones_like(mask_patch)

        out = {}

        out['img'] = img_patch
        out['mask'] = mask_patch
        # out['img_og'] = image
        # out['mask_og'] = mask
        out['keypoints_2d'] = keypoints_2d.astype(np.float32)
        out['keypoints_3d'] = keypoints_3d.astype(np.float32)
        out['orig_keypoints_2d'] = orig_keypoints_2d
        out['box_center'] = center.copy()
        out['box_size'] = bbox_size
        out['img_size'] = 1.0 * img_size[::-1].copy()
        out['mano_params'] = mano_params
        out['has_mano_params'] = has_mano_params
        out['mano_params_is_axis_angle'] = mano_params_is_axis_angle
        out['_scale'] = scale
        out['_trans'] = trans
        out['imgname'] = key
        # out['idx'] = idx
        return out