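"""Multi-video dataset wrapper.

Aggregates per-clip SmplxDataset instances for a set of speakers and exposes
them as concatenated torch datasets, with optional pickle caching of the
parsed data.
"""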
import sys
import os

sys.path.append(os.getcwd())

import numpy as np
import pandas as pd
import torch.utils.data as data
from tqdm import tqdm
from transformers import Wav2Vec2Processor

from data_utils.utils import *
from data_utils.mesh_dataset import SmplxDataset
class MultiVidData():
    def __init__(self,
                 data_root,
                 speakers,
                 split='train',
                 limbscaling=False,
                 normalization=False,
                 norm_method='new',
                 split_trans_zero=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=None,
                 aud_feat_win_size=None,
                 aud_feat_dim=64,
                 feat_method='mel_spec',
                 context_info=False,
                 smplx=False,
                 audio_sr=16000,
                 convert_to_6d=False,
                 expression=False,
                 config=None
                 ):
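        """Collect per-sequence SmplxDataset instances for the given speakers.

        Args (selected):
            data_root: root directory with one sub-directory per speaker.
            speakers: speaker sub-directory names to load.
            split: 'train'/'val'/'test'; 'pre' is treated as 'train'.
            num_frames: clip length in frames (also passed to SmplxDataset as fps).
            num_generate_length: frames to generate; defaults to num_frames.
            config: parsed experiment config; must provide
                config.dataset_load_mode ('pickle' | 'csv' | 'json') and config.Data.
        """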
        self.data_root = data_root
        self.speakers = speakers
        self.split = split
        if split == 'pre':
            self.split = 'train'
        self.norm_method = norm_method
        self.normalization = normalization
        self.limbscaling = limbscaling
        self.convert_to_6d = convert_to_6d
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
        if num_generate_length is None:
            self.num_generate_length = num_frames
        else:
            self.num_generate_length = num_generate_length
        self.split_trans_zero = split_trans_zero

        dataset = SmplxDataset

        if self.split_trans_zero:
            self.trans_dataset_list = []
            self.zero_dataset_list = []
        else:
            self.all_dataset_list = []
        self.dataset = {}
        self.complete_data = []
        self.config = config
        load_mode = self.config.dataset_load_mode

        ###################### load with pickle file ######################
        if load_mode == 'pickle':
            import pickle
            with open(self.split + config.Data.pklname, 'rb') as f:
                self.dataset = pickle.load(f)
            for key in self.dataset:
                self.complete_data.append(self.dataset[key].complete_data)
        ###################### load with pickle file ######################
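        # NOTE: the 'pickle' branch above reads the cache that the 'json'
        # branch below writes (self.split + config.Data.pklname), so it
        # assumes a prior run with dataset_load_mode='json' produced it.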
        ###################### load with a csv file ######################
        elif load_mode == 'csv':
            # Imported from a separate code folder of mine; to be integrated
            # properly later.
            try:
                sys.path.append(self.config.config_root_path)
                from config import config_path
                from csv_parser import csv_parse
            except ImportError as e:
                print(f'err: {e}')
                raise ImportError('config root path error...')

            for speaker_name in self.speakers:
                df_intervals = pd.read_csv(self.config.voca_csv_file_path)
                df_intervals = df_intervals[df_intervals['speaker'] == speaker_name]
                df_intervals = df_intervals[df_intervals['dataset'] == self.split]
                print(f'speaker {speaker_name} {self.split} interval length: {len(df_intervals)}')

                for iter_index, (_, interval) in tqdm(
                        enumerate(df_intervals.iterrows()),
                        desc=f'load {speaker_name}', total=len(df_intervals)
                ):
                    (
                        interval_index,
                        interval_speaker,
                        interval_video_fn,
                        interval_id,
                        start_time,
                        end_time,
                        duration_time,
                        start_time_10,
                        over_flow_flag,
                        short_dur_flag,
                        big_video_dir,
                        small_video_dir_name,
                        speaker_video_path,
                        voca_basename,
                        json_basename,
                        wav_basename,
                        voca_top_clip_path,
                        voca_json_clip_path,
                        voca_wav_clip_path,
                        audio_output_fn,
                        image_output_path,
                        pifpaf_output_path,
                        mp_output_path,
                        op_output_path,
                        deca_output_path,
                        pixie_output_path,
                        cam_output_path,
                        ours_output_path,
                        merge_output_path,
                        multi_output_path,
                        gt_output_path,
                        ours_images_path,
                        pkl_fil_path,
                    ) = csv_parse(interval)
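                    # Of the fields unpacked above, only audio_output_fn,
                    # pkl_fil_path, interval_video_fn, and small_video_dir_name
                    # are used below; the rest name other pipeline outputs.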
                    if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
                        continue

                    key = f'{interval_video_fn}/{small_video_dir_name}'
                    self.dataset[key] = dataset(
                        data_root=pkl_fil_path,
                        speaker=speaker_name,
                        audio_fn=audio_output_fn,
                        audio_sr=audio_sr,
                        fps=num_frames,
                        feat_method=feat_method,
                        audio_feat_dim=aud_feat_dim,
                        train=(self.split == 'train'),
                        load_all=True,
                        split_trans_zero=self.split_trans_zero,
                        limbscaling=self.limbscaling,
                        num_frames=self.num_frames,
                        num_pre_frames=self.num_pre_frames,
                        num_generate_length=self.num_generate_length,
                        audio_feat_win_size=aud_feat_win_size,
                        context_info=context_info,
                        convert_to_6d=convert_to_6d,
                        expression=expression,
                        config=self.config
                    )
                    self.complete_data.append(self.dataset[key].complete_data)
        ###################### load with a csv file ######################
        ###################### original load method ######################
        elif load_mode == 'json':
            # The phoneme-level wav2vec2 processor is only needed by the face
            # model; it was originally gated on config.Model.model_type == 'face'.
            am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
            am_sr = 16000
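            # `am` is handed to each SmplxDataset below; the standard Hugging
            # Face usage for such a processor is (the exact call presumably
            # happens inside SmplxDataset, not in this file):
            #   inputs = am(waveform, sampling_rate=am_sr, return_tensors='pt')
            #   feats = inputs.input_values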
            for speaker_name in self.speakers:
                speaker_root = os.path.join(self.data_root, speaker_name)
                videos = [v for v in os.listdir(speaker_root)]
                print(videos)

                haode = huaide = 0  # counts of usable / broken sequences ('good' / 'broken')
                for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
                    source_vid = vid
                    # vid_pth = os.path.join(speaker_root, source_vid, 'images/half', self.split)
                    vid_pth = os.path.join(speaker_root, source_vid, self.split)
                    if smplx == 'pose':
                        seqs = [s for s in os.listdir(vid_pth) if s.startswith('clip')]
                    else:
                        try:
                            seqs = [s for s in os.listdir(vid_pth)]
                        except OSError:  # split folder missing for this video
                            continue
                    for s in seqs:
                        seq_root = os.path.join(vid_pth, s)
                        key = seq_root  # corresponds to clip******
                        audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % s)
                        motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % s)
                        if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
                            huaide = huaide + 1
                            continue
                        self.dataset[key] = dataset(
                            data_root=seq_root,
                            speaker=speaker_name,
                            motion_fn=motion_fname,
                            audio_fn=audio_fname,
                            audio_sr=audio_sr,
                            fps=num_frames,
                            feat_method=feat_method,
                            audio_feat_dim=aud_feat_dim,
                            train=(self.split == 'train'),
                            load_all=True,
                            split_trans_zero=self.split_trans_zero,
                            limbscaling=self.limbscaling,
                            num_frames=self.num_frames,
                            num_pre_frames=self.num_pre_frames,
                            num_generate_length=self.num_generate_length,
                            audio_feat_win_size=aud_feat_win_size,
                            context_info=context_info,
                            convert_to_6d=convert_to_6d,
                            expression=expression,
                            config=self.config,
                            am=am,
                            am_sr=am_sr,
                            whole_video=config.Data.whole_video
                        )
                        self.complete_data.append(self.dataset[key].complete_data)
                        haode = haode + 1
                print("huaide:{}, haode:{}".format(huaide, haode))

            # Cache everything that was just parsed so later runs can use
            # dataset_load_mode='pickle' instead of re-scanning the directories.
            import pickle
            with open(self.split + config.Data.pklname, 'wb') as f:
                pickle.dump(self.dataset, f)
        ###################### original load method ######################
        self.complete_data = np.concatenate(self.complete_data, axis=0)
        # assert self.complete_data.shape[-1] == (12+21+21)*2

        self.normalize_stats = {}
        self.data_mean = None
        self.data_std = None
    def get_dataset(self):
        self.normalize_stats['mean'] = self.data_mean
        self.normalize_stats['std'] = self.data_std

        for key in list(self.dataset.keys()):
            if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
                continue
            self.dataset[key].num_generate_length = self.num_generate_length
            self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
            self.all_dataset_list.append(self.dataset[key].all_dataset)

        if self.split_trans_zero:
            self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
            self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
        else:
            self.all_dataset = data.ConcatDataset(self.all_dataset_list)
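

# A minimal usage sketch (the paths, speaker names, and config stub below are
# hypothetical; the stub only fills the config fields this file actually reads).
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        dataset_load_mode='json',  # or 'pickle' / 'csv'
        Data=SimpleNamespace(pklname='_store.pkl', whole_video=False),
    )
    train_data = MultiVidData(
        data_root='./data',     # hypothetical: root holding one folder per speaker
        speakers=['speaker1'],  # hypothetical speaker folder name
        split='train',
        config=cfg,
    )
    train_data.get_dataset()
    loader = data.DataLoader(train_data.all_dataset, batch_size=32, shuffle=True)
    print(f'batches per epoch: {len(loader)}')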