| import numpy as np |
| import torch |
| from torch.utils.data.dataset import Dataset |
| import os |
| from configs.paths import dataset_root |
| import copy |
| from tqdm import tqdm |
| from .base import BASE |
|
|
class BEDLAM(BASE):
    """BEDLAM dataset loader built on top of ``BASE``.

    Annotations are read from
    ``<dataset_root>/bedlam/bedlam_smpl_<split>.npz``. Its ``'annots'`` entry
    is a 0-d object array wrapping a dict that maps image names to per-image
    annotation dicts. Images are looked up under
    ``<dataset_root>/bedlam/<train|validation>/<img_name>``.
    """

    # Annotation splits this loader accepts (one .npz file per split).
    _SPLITS = ('train_1fps', 'train_3fps', 'train_6fps', 'validation_6fps')

    def __init__(self, split='train_6fps', **kwargs):
        """Load the annotation index for ``split``.

        Args:
            split: One of ``_SPLITS``; selects which ``.npz`` file is loaded.
            **kwargs: Forwarded unchanged to ``BASE.__init__``.

        Raises:
            ValueError: If ``split`` is not one of the known splits.
        """
        # Validate explicitly (not with assert, which is stripped under -O)
        # and before any base-class work is done.
        if split not in self._SPLITS:
            raise ValueError(
                f'Unknown BEDLAM split {split!r}; expected one of {self._SPLITS}')
        super().__init__(**kwargs)
        # NOTE(review): kid_offset is presumably configured by BASE; this
        # dataset does not support it.
        assert not self.kid_offset

        self.ds_name = 'bedlam'
        self.dataset_path = os.path.join(dataset_root, 'bedlam')
        annots_path = os.path.join(self.dataset_path, f'bedlam_smpl_{split}.npz')
        # allow_pickle is needed because 'annots' is stored as an object array.
        # Only load trusted annotation files this way. The context manager
        # closes the underlying NpzFile handle once the dict has been extracted.
        with np.load(annots_path, allow_pickle=True) as npz:
            self.annots = npz['annots'][()]
        self.img_names = list(self.annots.keys())
        # Image sub-directory on disk: every train_* split maps to 'train'.
        self.split = 'train' if 'train' in split else 'validation'

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_names)

    def cnt_instances(self):
        """Count annotated instances over all images, print and return the total.

        The per-image count is the length of that image's ``'shape'`` entry
        (the same quantity ``get_raw_data`` reports as ``'pnum'``).

        Returns:
            int: Total number of instances.
        """
        ins_cnt = 0
        for img_name in tqdm(self.img_names):
            ins_cnt += len(self.annots[img_name]['shape'])
        print(f'TOTAL: {ins_cnt}')
        return ins_cnt

    def get_raw_data(self, idx):
        """Build the raw sample dict for image ``idx``.

        ``idx`` is wrapped modulo the number of images, so any integer index
        maps to a valid image.

        Returns:
            dict with keys ``'img_path'``, ``'ds'``, ``'pnum'`` (length of the
            ``'shape'`` entry), float tensors ``'betas'``, ``'poses'``,
            ``'transl'``, ``'cam_rot'``, ``'cam_trans'`` and
            ``'cam_intrinsics'`` (with an extra leading dim of size 1), and
            the flags ``'3d_valid'``, ``'age_valid'``,
            ``'detect_all_people'``. In ``'eval'`` mode an all-zero
            ``'occ_level'`` tensor with ``pnum`` entries is added.
        """
        img_name = self.img_names[idx % len(self.img_names)]

        # Work on a private copy: torch.from_numpy shares memory with its
        # source array, so without the copy cam_intrinsics could alias the
        # cached annotations and be mutated by downstream in-place ops.
        annots = copy.deepcopy(self.annots[img_name])
        img_path = os.path.join(self.dataset_path, self.split, img_name)

        cam_intrinsics = torch.from_numpy(annots['cam_int']).unsqueeze(0)
        cam_rot = torch.from_numpy(np.stack(annots['cam_rot']))
        cam_trans = torch.from_numpy(np.stack(annots['cam_trans']))

        betas = torch.from_numpy(np.stack(annots['shape']))
        poses = torch.from_numpy(np.stack(annots['pose_world']))
        transl = torch.from_numpy(np.stack(annots['trans_world']))

        raw_data = {
            'img_path': img_path,
            'ds': self.ds_name,
            'pnum': len(betas),
            'betas': betas.float(),
            'poses': poses.float(),
            'transl': transl.float(),
            'cam_rot': cam_rot.float(),
            'cam_trans': cam_trans.float(),
            'cam_intrinsics': cam_intrinsics.float(),
            '3d_valid': True,
            'age_valid': False,
            'detect_all_people': True,
        }

        # self.mode is presumably set by BASE.
        if self.mode == 'eval':
            raw_data['occ_level'] = torch.zeros(len(betas), dtype=torch.long)

        return raw_data
|
|
|
|
|
|