import pickle
import sys
import os

sys.path.append(os.getcwd())

import json
import random
from glob import glob

import numpy as np
import torch
import torch.utils.data as data

from data_utils.utils import *
from data_utils.consts import speaker_id
from data_utils.lower_body import count_part
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
with open('data_utils/hand_component.json') as file_obj:
    comp = json.load(file_obj)

left_hand_c = np.asarray(comp['left'])
right_hand_c = np.asarray(comp['right'])

def to3d(data):
    """Expand the 12-dim PCA hand coefficients (columns 75:87 and 87:99) into
    full 45-dim axis-angle poses per hand using the precomputed components."""
    left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
    right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
    data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
    return data
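
# Shape example: a (T, 99) PCA-compressed batch becomes a (T, 165) full
# axis-angle batch, i.e. 55 joints x 3 values each.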

class SmplxDataset():
    '''
    Create a dataset for every segment and concatenate.
    '''

    def __init__(self,
                 data_root,
                 speaker,
                 motion_fn,
                 audio_fn,
                 audio_sr,
                 fps,
                 feat_method='mel_spec',
                 audio_feat_dim=64,
                 audio_feat_win_size=None,
                 train=True,
                 load_all=False,
                 split_trans_zero=False,
                 limbscaling=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=25,
                 context_info=False,
                 convert_to_6d=False,
                 expression=False,
                 config=None,
                 am=None,
                 am_sr=None,
                 whole_video=False
                 ):
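        # num_pre_frames of history plus num_generate_length frames to
        # generate make up one training window (see __getitem__).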
        self.data_root = data_root
        self.speaker = speaker
        self.feat_method = feat_method
        self.audio_fn = audio_fn
        self.audio_sr = audio_sr
        self.fps = fps
        self.audio_feat_dim = audio_feat_dim
        self.audio_feat_win_size = audio_feat_win_size
        self.context_info = context_info  # whether audio features include history frames
        self.convert_to_6d = convert_to_6d
        self.expression = expression

        self.train = train
        self.load_all = load_all
        self.split_trans_zero = split_trans_zero
        self.limbscaling = limbscaling
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
        self.num_generate_length = num_generate_length
        # print('num_generate_length ', self.num_generate_length)

        self.config = config
        self.am_sr = am_sr
        self.whole_video = whole_video

        load_mode = self.config.dataset_load_mode
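        # Note: the mode names do not match the formats they load: 'csv'
        # unpickles a single file at data_root, while 'json' globs *.pkl
        # annotation files under data_root.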
        if load_mode == 'pickle':
            raise NotImplementedError

        elif load_mode == 'csv':
            with open(data_root, 'rb') as f:
                data = pickle.load(f)
            self.data = data[0]
            if self.load_all:
                self._load_npz_all()
        elif load_mode == 'json':
            self.annotations = glob(data_root + '/*pkl')
            if len(self.annotations) == 0:
                raise FileNotFoundError(data_root + ' is empty')
            self.annotations = sorted(self.annotations)
            self.img_name_list = self.annotations

            if self.load_all:
                self._load_them_all(am, am_sr, motion_fn)

    def _load_npz_all(self):
        self.loaded_data = {}
        self.complete_data = []
        data = self.data
        shape = data['body_pose_axis'].shape[0]
        self.betas = data['betas']
        self.img_name_list = []
        for index in range(shape):
            img_name = f'{index:06d}'  # zero-padded frame name
            self.img_name_list.append(img_name)
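
            # Per-frame axis-angle layout: jaw(3) + leye(3) + reye(3) +
            # global_orient(3) + body(63) + 12 PCA coefficients per hand = 99.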
            jaw_pose = data['jaw_pose'][index]
            leye_pose = data['leye_pose'][index]
            reye_pose = data['reye_pose'][index]
            global_orient = data['global_orient'][index]
            body_pose = data['body_pose_axis'][index]
            left_hand_pose = data['left_hand_pose'][index]
            right_hand_pose = data['right_hand_pose'][index]

            full_body = np.concatenate(
                (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
            assert full_body.shape[0] == 99
            if self.convert_to_6d:
                # to3d expects a batch dimension; add it and strip it again
                full_body = to3d(full_body[np.newaxis])[0]  # (99,) -> (165,)
                full_body = torch.from_numpy(full_body)
                full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(55, 3))).reshape(-1)
                full_body = np.asarray(full_body)
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body, expression))
                    # full_body = np.concatenate((full_body, non_zero))
            else:
                full_body = to3d(full_body[np.newaxis])[0]
                if self.expression:
                    expression = data['expression'][index]
                    full_body = np.concatenate((full_body, expression))

            self.loaded_data[img_name] = full_body.reshape(-1)
            self.complete_data.append(full_body.reshape(-1))

        self.complete_data = np.array(self.complete_data)
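
        # Audio features are extracted at the motion frame rate, so one
        # feature row aligns with one pose frame (see __getitem__ slicing).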
        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
            # print(self.audio_feat.shape)
        else:
            if self.feat_method == 'mel_spec':
                self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
            elif self.feat_method == 'mfcc':
                self.audio_feat = get_mfcc(self.audio_fn,
                                           smlpx=True,
                                           sr=self.audio_sr,
                                           n_mfcc=self.audio_feat_dim,
                                           win_size=self.audio_feat_win_size
                                           )

    def _load_them_all(self, am, am_sr, motion_fn):
        self.loaded_data = {}
        self.complete_data = []
        # read-only access is enough; the context manager also closes the file
        with open(motion_fn, 'rb') as f:
            data = pickle.load(f)

        self.betas = np.array(data['betas'])
        jaw_pose = np.array(data['jaw_pose'])
        leye_pose = np.array(data['leye_pose'])
        reye_pose = np.array(data['reye_pose'])
        global_orient = np.array(data['global_orient']).squeeze()
        body_pose = np.array(data['body_pose_axis'])
        left_hand_pose = np.array(data['left_hand_pose'])
        right_hand_pose = np.array(data['right_hand_pose'])

        full_body = np.concatenate(
            (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
        assert full_body.shape[1] == 99
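
        # convert_to_6d maps (T, 165) axis-angle (55 joints x 3) to the 6D
        # rotation representation, giving 55 x 6 = 330 pose dims per frame.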
        if self.convert_to_6d:
            full_body = to3d(full_body)
            full_body = torch.from_numpy(full_body)
            full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
            full_body = np.asarray(full_body)
            if self.expression:
                expression = np.array(data['expression'])
                full_body = np.concatenate((full_body, expression), axis=1)
        else:
            full_body = to3d(full_body)
            expression = np.array(data['expression'])
            full_body = np.concatenate((full_body, expression), axis=1)

        self.complete_data = np.array(full_body)

        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
        else:
            # if self.feat_method == 'mel_spec':
            #     self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
            # elif self.feat_method == 'mfcc':
            self.audio_feat = get_mfcc_ta(self.audio_fn,
                                          smlpx=True,
                                          fps=30,
                                          sr=self.audio_sr,
                                          n_mfcc=self.audio_feat_dim,
                                          win_size=self.audio_feat_win_size,
                                          type=self.feat_method,
                                          am=am,
                                          am_sr=am_sr,
                                          encoder_choice=self.config.Model.encoder_choice,
                                          )
            # with open(audio_file, 'w', encoding='utf-8') as file:
            #     file.write(json.dumps(self.audio_feat.__array__().tolist(), indent=0, ensure_ascii=False))

    def get_dataset(self, normalization=False, normalize_stats=None, split='train'):
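        # The inner Dataset names its instance `child` so that `self` keeps
        # referring to the enclosing SmplxDataset inside these methods.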
        class __Worker__(data.Dataset):
            def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
                super().__init__()
                child.index_list = index_list
                child.normalization = normalization
                child.normalize_stats = normalize_stats
                child.split = split

            def __getitem__(child, index):
                num_generate_length = self.num_generate_length
                num_pre_frames = self.num_pre_frames
                seq_len = num_generate_length + num_pre_frames
                # print(num_generate_length)

                index = child.index_list[index]
                # jitter the window start by 0 or 3 frames (randrange(0, 5, 3))
                index_new = index + random.randrange(0, 5, 3)
                if index_new + seq_len > self.complete_data.shape[0]:
                    index_new = index
                index = index_new

                if child.split in ['val', 'pre', 'test'] or self.whole_video:
                    index = 0
                    seq_len = self.complete_data.shape[0]
                assert index + seq_len <= self.complete_data.shape[0]
                # print(seq_len)
                seq_data = self.complete_data[index:(index + seq_len), :]
                seq_data = np.array(seq_data)

                # audio feature
                if not self.context_info:
                    if not self.whole_video:
                        audio_feat = self.audio_feat[index:index + seq_len, ...]
                        if audio_feat.shape[0] < seq_len:
                            audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
                                                mode='reflect')

                        assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
                    else:
                        audio_feat = self.audio_feat
                else:  # include history frames in the audio feature
                    if self.audio_feat_win_size is None:
                        audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
                        if audio_feat.shape[0] < seq_len + num_pre_frames:
                            # padding/assert use num_frames, which equals
                            # num_pre_frames under the default configuration
                            audio_feat = np.pad(audio_feat,
                                                [[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
                                                mode='constant')
                        assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[1] == self.audio_feat_dim

                if child.normalization:
                    data_mean = child.normalize_stats['mean'].reshape(1, -1)
                    data_std = child.normalize_stats['std'].reshape(1, -1)
                    # normalize pose dims only; expression dims (330:) are left as-is
                    seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std

                if child.split in ['train', 'test']:
                    if self.convert_to_6d:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                # 'nzero': seq_data[:, 375:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas,
                                'aud_file': self.audio_fn,
                            }
                        else:
                            data_sample = {
                                'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                                'nzero': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    else:
                        if self.expression:
                            data_sample = {
                                'poses': seq_data[:, :165].astype(np.float64).transpose(1, 0),
                                'expression': seq_data[:, 165:].astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                # 'wv2_feat': wv2_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'aud_file': self.audio_fn,
                                'betas': self.betas
                            }
                        else:
                            data_sample = {
                                'poses': seq_data.astype(np.float64).transpose(1, 0),
                                'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    return data_sample
                else:
                    data_sample = {
                        'poses': seq_data[:, :330].astype(np.float64).transpose(1, 0),
                        'expression': seq_data[:, 330:].astype(np.float64).transpose(1, 0),
                        # 'nzero': seq_data[:, 325:].astype(np.float64).transpose(1, 0),
                        'aud_feat': audio_feat.astype(np.float64).transpose(1, 0),
                        'aud_file': self.audio_fn,
                        'speaker': speaker_id[self.speaker],
                        'betas': self.betas
                    }
                    return data_sample

            def __len__(child):
                return len(child.index_list)

        if split == 'train':
            # one candidate window every 6 frames, leaving room for a full sequence
            index_list = list(
                range(0,
                      min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
                      6))
        elif split in ['val', 'test']:
            index_list = [0]
        if self.whole_video:
            index_list = [0]

        self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)

    def __len__(self):
        return len(self.img_name_list)
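

# Usage sketch (hypothetical paths and config): build the per-segment dataset,
# materialize an index split, and wrap the exposed `all_dataset` in a DataLoader.
#
#   dataset = SmplxDataset(data_root='videos/oliver/seg_000', speaker='oliver',
#                          motion_fn='videos/oliver/seg_000/motion.pkl',
#                          audio_fn='videos/oliver/seg_000/audio.wav',
#                          audio_sr=16000, fps=30, load_all=True, config=config)
#   dataset.get_dataset(normalization=False, split='train')
#   loader = data.DataLoader(dataset.all_dataset, batch_size=32, shuffle=True)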