Spaces:
Sleeping
Sleeping
| from __future__ import absolute_import | |
| from __future__ import print_function | |
| from __future__ import division | |
| import os.path as osp | |
| from glob import glob | |
| from collections import defaultdict | |
| import cv2 | |
| import torch | |
| import pickle | |
| import joblib | |
| import argparse | |
| import numpy as np | |
| from loguru import logger | |
| from progress.bar import Bar | |
| from configs import constants as _C | |
| from lib.models.smpl import SMPL | |
| from lib.models.preproc.extractor import FeatureExtractor | |
| from lib.models.preproc.backbone.utils import process_image | |
# Directory holding the per-person 2D keypoint detection files
# (read below as '<seq>_<person id>.npy').
detection_results_dir = 'dataset/detection_results/3DPW'

# TCMR-preprocessed 3DPW training annotations, loaded with joblib.
tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_train_db.pt'

# Module-level accumulator: each key collects one entry per (sequence, person)
# and is concatenated into a single tensor at the end of `preprocess`.
dataset = defaultdict(list)
@torch.no_grad()
def preprocess(batch_size):
    """Build the 3DPW training database with image features and save it.

    For every (sequence, person) pair referenced by the TCMR annotation file,
    this gathers the SMPL parameters, re-derives the translation so that the
    TCMR pose lines up with the original 3DPW mesh in the camera frame,
    projects the resulting 3D joints with the camera intrinsics, loads the
    pre-computed 2D detections, and encodes one image crop per frame with the
    feature extractor. Results are appended to the module-level ``dataset``
    dict, concatenated per key, and dumped with joblib to
    ``<_C.PATHS.PARSED_DATA>/3pdw_train_vit.pth``.

    The whole function runs under ``torch.no_grad()``: nothing here is
    trained, so no autograd graph should be kept alive or serialized.

    Args:
        batch_size: Number of image crops sent through the feature extractor
            per forward pass (also forwarded as ``max_batch_size``).

    Raises:
        FileNotFoundError: If no frame image is found for a sequence/person,
            or the first frame image cannot be read.
    """
    # Load pretrained feature extractor
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # NOTE: the '3pdw' spelling is kept on purpose; other code may look the
    # file up under exactly this name.
    save_pth = osp.join(_C.PATHS.PARSED_DATA, '3pdw_train_vit.pth')  # Use ViT feature extractor
    extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size)

    tcmr_data = joblib.load(tcmr_annot_pth)
    tcmr_vid_names = tcmr_data['vid_name']

    # Unique video names in order of first appearance (np.unique sorts
    # alphabetically, so re-order by the index of the first occurrence).
    unique_names, first_idxs = np.unique(tcmr_vid_names, return_index=True)
    unique_names = unique_names[np.argsort(first_idxs)]

    # Video names carry a two-character person suffix (e.g. '_0'); strip it to
    # get the sequence file, then drop duplicates while preserving order.
    annot_file_list = [
        osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', 'train', name[:-2] + '.pkl')
        for name in unique_names
    ]
    annot_file_list = list(dict.fromkeys(annot_file_list))

    vid_idx = 0
    for annot_file in annot_file_list:
        # basename() instead of split('/') so the parsing does not depend on
        # the path separator produced by osp.join.
        seq = osp.basename(annot_file).split('.')[0]

        # NOTE(review): pickle executes arbitrary code on load; only use
        # sequence files from a trusted source.
        with open(annot_file, 'rb') as annot_fp:
            data = pickle.load(annot_fp, encoding='latin1')

        num_people = len(data['poses'])
        num_frames = len(data['img_frame_ids'])
        assert (data['poses2d'][0].shape[0] == num_frames)

        K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float()

        for p_id in range(num_people):
            logger.info(f'==> {seq} {p_id}')
            gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]]
            smpl_gender = SMPL(model_path=_C.BMODEL.FLDR, gender=gender)

            # ======== Add TCMR data ======== #
            vid_name = f'{seq}_{p_id}'
            # NOTE(review): substring match, so '<seq>_1' would also match a
            # hypothetical '<seq>_10' -- confirm person ids stay single-digit.
            tcmr_ids = [i for i, v in enumerate(tcmr_vid_names) if vid_name in v]
            frame_ids = tcmr_data['frame_id'][tcmr_ids]

            pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids]
            shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1)
            trans = torch.from_numpy(data['trans'][p_id]).float()[frame_ids]
            cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float()

            # ======== Align the mesh params ======== #
            # Rotation / translation parts of the per-frame camera pose
            # (presumably world-to-camera extrinsics -- TODO confirm).
            Rc = cam_poses[:, :3, :3]
            Tc = cam_poses[:, :3, 3]

            # Mean vertex of the original 3DPW mesh (with translation applied).
            org_output = smpl_gender.get_output(
                betas=shape, body_pose=pose[:, 3:], global_orient=pose[:, :3], transl=trans)
            org_v0 = (org_output.vertices + org_output.offset.unsqueeze(1)).mean(1)

            # Mean vertex of the TCMR-posed mesh without translation.
            pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float()
            output = smpl_gender.get_output(
                betas=shape, body_pose=pose[:, 3:], global_orient=pose[:, :3])
            v0 = (output.vertices + output.offset.unsqueeze(1)).mean(1)

            # Translation that puts the TCMR mesh centroid where the original
            # centroid lands after applying (Rc, Tc).
            trans = (Rc @ org_v0.reshape(-1, 3, 1)).reshape(-1, 3) + Tc - v0
            j3d = output.joints + (output.offset + trans).unsqueeze(1)

            # Perspective projection: divide by the last coordinate, apply K.
            j2d = torch.div(j3d, j3d[..., 2:])
            kp2d = torch.matmul(K, j2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
            # ======== Align the mesh params ======== #

            # ======== Get detection results ======== #
            fname = f'{seq}_{p_id}.npy'
            pred_kp2d = torch.from_numpy(
                np.load(osp.join(detection_results_dir, fname))
            ).float()[frame_ids]
            # ======== Get detection results ======== #

            img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg')))
            # Set lookup instead of scanning the frame-id array for every image.
            wanted_frames = set(frame_ids.tolist())
            img_paths = [img_path for i, img_path in enumerate(img_paths) if i in wanted_frames]
            if not img_paths:
                raise FileNotFoundError(f'No frame images found for {seq} (person {p_id})')

            img = cv2.imread(img_paths[0])
            if img is None:
                # cv2.imread returns None rather than raising on failure.
                raise FileNotFoundError(f'Could not read image: {img_paths[0]}')
            res_h, res_w = img.shape[:2]

            vid_idxs = torch.from_numpy(np.array([vid_idx] * len(img_paths)).astype(int))
            vid_idx += 1

            # ======== Append data ======== #
            dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float())
            dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float())
            dataset['vid'].append(vid_idxs)
            dataset['pose'].append(pose)
            dataset['betas'].append(shape)
            dataset['transl'].append(trans)
            dataset['kp2d'].append(pred_kp2d)
            dataset['joints3D'].append(j3d)
            dataset['joints2D'].append(kp2d)
            dataset['frame_id'].append(torch.from_numpy(frame_ids))
            dataset['cam_poses'].append(cam_poses)
            dataset['gender'].append(
                torch.tensor([['male', 'female'].index(gender)]).repeat(len(frame_ids)))
            # ======== Append data ======== #

            # ======== Extract features ======== #
            patch_list = []
            bboxes = dataset['bbox'][-1].clone().numpy()
            bar = Bar('Load images', fill='#', max=len(img_paths))
            for img_path, bbox in zip(img_paths, bboxes):
                img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
                # bbox[:2] / bbox[2] are presumably the box center and size;
                # the /200 scale convention comes from process_image -- TODO confirm.
                norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256)
                patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float())
                bar.next()
            bar.finish()

            patch_list = torch.split(torch.cat(patch_list), batch_size)
            features = []
            for patch in patch_list:
                # Use the selected device instead of a hard-coded .cuda() so the
                # CPU fallback above works (assumes the extractor places its
                # model on `device` -- it is constructed with it).
                pred = extractor.model(patch.to(device), encode=True)
                features.append(pred.cpu())

            features = torch.cat(features)
            dataset['features'].append(features)
            # ======== Extract features ======== #

    # Collapse the per-(sequence, person) lists into one tensor per key.
    for key in list(dataset):
        dataset[key] = torch.cat(dataset[key])

    joblib.dump(dataset, save_pth)
    logger.info('\n ==> Done !')
if __name__ == '__main__':
    # Command-line entry point: the only option is the feature-extraction
    # batch size, which is handed straight to preprocess().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-b', '--batch_size',
        type=int,
        default=128,
        help='Data split',
    )
    cli_args = cli.parse_args()
    preprocess(cli_args.batch_size)