import os
import random
import glob
import json
import pickle
import traceback

import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader, cpu, gpu
from torch.utils.data.dataset import Dataset
from torchvision.io import read_video, write_video
from tqdm import tqdm

# MultiSample : sample multiple clips from every video
# RandomRef   : whether the reference image is picked at random; if it is not random, it is the previous frame.
#               With a random ref, the reference video contains a single frame.
# MultiRef    : multiple reference frames, at most 8 by default


class AMDConsecutiveVideo(Dataset):
    def __init__(
        self,
        video_dir: str = '',        # video dir, .txt list of dirs, or pkl file
        sample_size: int = 32,
        sample_stride: int = 2,
        sample_n_frames: int = 16,
        ref_drop_ratio=0.0,
    ):
        # Init setting
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.ref_drop_ratio = ref_drop_ratio

        # Transform
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)  # e.g. (256, 256)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        if 'pkl' in video_dir:
            with open(video_dir, 'rb') as f:
                video_files = pickle.load(f)
            print(f'Total {len(video_files)} !!!')
        elif '.txt' in video_dir:
            with open(video_dir, 'r') as file:
                lines = file.readlines()
            video_dirs = [line.strip() for line in lines]
            video_files = []
            for dir in video_dirs:
                video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True)
            print(f'Total {len(video_files)} !!!')
        else:
            video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True)
            print(f'Total {len(video_files)} !!!')

        # Data dict
        self.metadata_list = []
        for file_path in tqdm(video_files):
            d = {}
            d['name'] = self.get_file_name(file_path)
            d['video_path'] = file_path
            self.metadata_list.append(d)

        self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')

    def __getitem__(self, idx):
        while True:
            try:
                file_name = self.metadata_list[idx]['name']
                file_name, videos, ref_img = self.get_batch(idx)
                break
            except Exception as e:
                # file_name = self.metadata_list[idx]['name']
                # print(file_name)
                print('error', e)
                idx = random.randint(0, self.length - 1)
        sample = dict(name=file_name, videos=videos, ref_img=ref_img)
        return sample

    def __len__(self):
        return self.length

    def get_batch(self, idx):
        # init
        meta_data = self.metadata_list[idx]
        file_name = meta_data['name']
        video_path = meta_data['video_path']

        # video process
        video_reader = VideoReader(video_path, ctx=cpu(0))
        video_length = len(video_reader)

        sample_frames = self.sample_n_frames + 1  # refimg + videos
        clip_length = min(video_length, (sample_frames - 1) * self.sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int)

        videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()  # (N,H,W,C) -> (N,C,H,W)
        videos = videos / 255.0

        # transform
        videos_cache = self.pixel_transforms(videos)   # F+1,C,H,W
        videos = videos_cache[1:, :, :, :]             # F,C,H,W
        ref_frame = videos_cache[0, :, :, :]           # C,H,W

        # repeat
        ref_frame = ref_frame.unsqueeze(0).repeat(videos.shape[0], 1, 1, 1)  # F,C,H,W

        return file_name, videos, ref_frame

    def get_file_name(self, file_path):
        return file_path.split('/')[-1].split('.')[0]

    @staticmethod
    def collate_fn(batch):
        # name
        name = [item['name'] for item in batch]
        # videos
        videos = [item['videos'] for item in batch]
        videos = torch.stack(videos)
        # ref_img
        ref_img = [item['ref_img'] for item in batch]
        ref_img = torch.stack(ref_img)
        randomref_img = None
        return dict(name=name, videos=videos, ref_img=ref_img, randomref_img=randomref_img)
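# Example usage of the class above (a minimal sketch; the video directory is a
# placeholder path, not one shipped with this repo):
#
#   dataset = AMDConsecutiveVideo(video_dir='/path/to/videos', sample_size=256,
#                                 sample_stride=2, sample_n_frames=16)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True,
#                                        collate_fn=AMDConsecutiveVideo.collate_fn)
#   batch = next(iter(loader))
#   # batch['videos'] : (B, F, C, H, W)    batch['ref_img'] : (B, F, C, H, W)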
class AMDConsecutiveVideoBalance(Dataset):
    def __init__(
        self,
        video_dir: str = '',        # .txt file listing exactly two video dirs / pkl files
        sample_size: int = 32,
        sample_stride: int = 2,
        sample_n_frames: int = 16,
        ref_drop_ratio=0.0,
    ):
        # Init setting
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.ref_drop_ratio = ref_drop_ratio

        assert '.txt' in video_dir
        with open(video_dir, 'r') as file:
            lines = file.readlines()
        video_paths = [line.strip() for line in lines]
        assert len(video_paths) == 2

        self.dataset1 = AMDConsecutiveVideo(video_dir=video_paths[0],
                                            sample_size=sample_size,
                                            sample_stride=sample_stride,
                                            sample_n_frames=sample_n_frames,
                                            ref_drop_ratio=0.0)
        self.dataset2 = AMDConsecutiveVideo(video_dir=video_paths[1],
                                            sample_size=sample_size,
                                            sample_stride=sample_stride,
                                            sample_n_frames=sample_n_frames,
                                            ref_drop_ratio=0.0)
        self.len1 = len(self.dataset1)
        self.len2 = len(self.dataset2)
        print(f'Total {self.len1 + self.len2} !!!')

    def __len__(self):
        # Report a length that is large enough (e.g. twice the larger dataset)
        return 2 * max(self.len1, self.len2)

    def __getitem__(self, idx):
        # Use PyTorch's RNG so sampling stays safe with multi-process data loading
        if torch.rand(1).item() < 0.5:
            # draw a random sample from dataset 1
            a_idx = torch.randint(0, len(self.dataset1), (1,)).item()
            return self.dataset1[a_idx]
        else:
            # draw a random sample from dataset 2
            b_idx = torch.randint(0, len(self.dataset2), (1,)).item()
            return self.dataset2[b_idx]

    @staticmethod
    def collate_fn(batch):
        # name
        name = [item['name'] for item in batch]
        # videos
        videos = [item['videos'] for item in batch]
        videos = torch.stack(videos)
        # ref_img
        ref_img = [item['ref_img'] for item in batch]
        ref_img = torch.stack(ref_img)
        randomref_img = None
        return dict(name=name, videos=videos, ref_img=ref_img, randomref_img=randomref_img)


class AMDConsecutiveVideoDoubleRef(Dataset):
    def __init__(
        self,
        video_dir: str = '',        # video dir, .txt list of dirs, or pkl file
        sample_size: int = 32,
        sample_stride: int = 2,
        sample_n_frames: int = 16,
        ref_drop_ratio=0.0,
    ):
        # Init setting
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.ref_drop_ratio = ref_drop_ratio

        # Transform
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)  # e.g. (256, 256)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        if 'pkl' in video_dir:
            with open(video_dir, 'rb') as f:
                video_files = pickle.load(f)
            print(f'Total {len(video_files)} !!!')
        elif '.txt' in video_dir:
            with open(video_dir, 'r') as file:
                lines = file.readlines()
            video_dirs = [line.strip() for line in lines]
            video_files = []
            for dir in video_dirs:
                video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True)
            print(f'Total {len(video_files)} !!!')
        else:
            video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True)
            print(f'Total {len(video_files)} !!!')

        # Data dict
        self.metadata_list = []
        for file_path in tqdm(video_files):
            d = {}
            d['name'] = self.get_file_name(file_path)
            d['video_path'] = file_path
            self.metadata_list.append(d)

        self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')

    def __getitem__(self, idx):
        while True:
            try:
                file_name = self.metadata_list[idx]['name']
                file_name, videos, ref_img, randomref_img = self.get_batch(idx)
                break
            except Exception as e:
                # file_name = self.metadata_list[idx]['name']
                #
print(file_name) print('error',e) idx = random.randint(0, self.length-1) sample = dict(name=file_name,videos=videos,ref_img = ref_img,randomref_img=randomref_img) return sample def __len__(self): return self.length def get_batch(self, idx): # init meta_data = self.metadata_list[idx] file_name = meta_data['name'] video_path = meta_data['video_path'] # video process video_reader = VideoReader(video_path, ctx=cpu(0)) video_length = len(video_reader) sample_frames = self.sample_n_frames + 1 # refimg + videos clip_length = min(video_length, (sample_frames - 1) * self.sample_stride + 1) start_idx = random.randint(0, video_length - clip_length) batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int) # random ref frame idx_all = np.arange(0,video_length) occ_idx = np.arange(start_idx, start_idx + clip_length) randomref_idx = [x for x in idx_all if x not in occ_idx] if len(randomref_idx) == 0: ref_frame_idx = batch_index[0] else: i = torch.randint(low=0, high=len(randomref_idx), size=(1,)).item() ref_frame_idx = randomref_idx[i] batch_index = [ref_frame_idx] + list(batch_index) videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 # # ref frame # idx_all = np.arange(0,video_length) # occ_idx = np.arange(start_idx, start_idx + clip_length) # ref_idx = [x for x in idx_all if x not in occ_idx] # if len(ref_idx) == 0: # ref_frame_idx = batch_index[0] # else: # np.random.shuffle(ref_idx) # ref_frame_idx = ref_idx[0] # ref_frame = torch.from_numpy(video_reader[ref_frame_idx].asnumpy()).permute(2, 0, 1).contiguous() # ref_frame = ref_frame / 255.0 # transform videos_cache = self.pixel_transforms(videos) # F+1,C,H,W videos = videos_cache[2:,:,:,:] # F,C,H,W ref_frame = videos_cache[1,:,:,:] # C,H,W randomref_frame = videos_cache[0,:,:,:] # C,H,W # repeat ref_frame = ref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) # F,C,H,W randomref_frame = randomref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) # F,C,H,W return file_name,videos,ref_frame,randomref_frame def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): # name name = [item['name'] for item in batch] # videos videos = [item['videos'] for item in batch] videos = torch.stack(videos) # ref_img ref_img = [item['ref_img'] for item in batch] ref_img = torch.stack(ref_img) randomref_img = [item['randomref_img'] for item in batch] randomref_img = torch.stack(randomref_img) return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img) class AMDConsecutiveVideoDoubleRefBalance(Dataset): def __init__( self, video_dir: str = '', # video dir or pkl file sample_size: int = 32, sample_stride: int = 2, sample_n_frames:int = 16, ref_drop_ratio = 0.0, ): # Init setting self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.ref_drop_ratio = ref_drop_ratio assert '.txt' in video_dir with open(video_dir, 'r') as file: lines = file.readlines() video_paths = [line.strip() for line in lines] self.datasets = [] for vp in video_paths: self.datasets.append(AMDConsecutiveVideoDoubleRef(video_dir = vp, # video dir or pkl file sample_size = sample_size, sample_stride = sample_stride, sample_n_frames = sample_n_frames, ref_drop_ratio = 0.0,)) self.length = len(self.datasets) * max([len(d) for d in self.datasets]) print(self.length) def __len__(self): # 设置为足够大的数值(例如两倍的最大数据集长度) return self.length def __getitem__(self, idx): dataset_num = 
len(self.datasets) cur_idx = idx % dataset_num cur_dataset = self.datasets[cur_idx] idx = torch.randint(0, len(cur_dataset), (1,)).item() return cur_dataset[idx] @staticmethod def collate_fn(batch): # name name = [item['name'] for item in batch] # videos videos = [item['videos'] for item in batch] videos = torch.stack(videos) # ref_img ref_img = [item['ref_img'] for item in batch] ref_img = torch.stack(ref_img) randomref_img = [item['randomref_img'] for item in batch] randomref_img = torch.stack(randomref_img) return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img) class AMDRandomPair(Dataset): def __init__( self, video_dir: str = '', # video dir or pkl file sample_size: int = 32, sample_stride: int = 4, sample_n_frames:int = 16, ref_drop_ratio = 0.0, ): # Init setting self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.ref_drop_ratio = ref_drop_ratio # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) if 'pkl' in video_dir: with open(video_dir, 'rb') as f: video_files = pickle.load(f) print(f'Total {len(video_files)} !!!') elif '.txt' in video_dir: with open(video_dir, 'r') as file: lines = file.readlines() video_dirs = [line.strip() for line in lines] video_files = [] for dir in video_dirs: video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True) print(f'Total {len(video_files)} !!!') else: video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True) # Data dict self.metadata_list = [] for file_path in tqdm(video_files): d = {} d['name'] = self.get_file_name(file_path) d['video_path'] = file_path self.metadata_list.append(d) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: file_name = self.metadata_list[idx]['name'] file_name,videos,ref_img = self.get_batch(idx) break except Exception as e: # file_name = self.metadata_list[idx]['name'] # print(file_name) print('error',e) idx = random.randint(0, self.length-1) sample = dict(name=file_name,videos=videos,ref_img = ref_img) return sample def __len__(self): return self.length def get_batch(self, idx): # init meta_data = self.metadata_list[idx] file_name = meta_data['name'] video_path = meta_data['video_path'] # video process video_reader = VideoReader(video_path) video_length = len(video_reader) ref_idx,video_idx = generate_non_equal_random_lists(frame_num=video_length,sample_num=self.sample_n_frames) ref_videos = torch.from_numpy(video_reader.get_batch(ref_idx).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) ref_videos = ref_videos / 255.0 ref_videos = self.pixel_transforms(ref_videos) videos = torch.from_numpy(video_reader.get_batch(video_idx).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) return file_name,videos,ref_videos def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): # name name = [item['name'] for item in batch] # videos videos = [item['videos'] for item in batch] videos = torch.stack(videos) # ref_img ref_img = [item['ref_img'] for item in batch] ref_img = torch.stack(ref_img) return dict(name=name, videos=videos, 
ref_img=ref_img,randomref_img=None) class A2MVideoAudio(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): """ videos : 31,3,256,256 ref_img : 3,256,256 audio_feature : 30,50,384 ref_pose : 3,256,256 meta """ # init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = meta_data['whisper_emb_path'] # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) video_length = min(len(video_reader),audio_feature.shape[0]) # sample_frames sample_frames = self.sample_n_frames + 1 # self.sample_n_frames = 4, sample_frames=5 clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9 if clip_length > video_length : batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int) batch_index = np.array([d for d in batch_index if d <= video_length-1],dtype=int) # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[0,:] # C,H,W gt_video = videos[1:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[0,:] # M,D gt_audio = audios[1:,:] # F,M,D # available length cur_available_length = gt_video.shape[0] # pad pad_length = self.sample_n_frames - gt_video.shape[0] video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype) gt_video = torch.cat([gt_video, video_pad], dim=0) # F,C,H,W audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype) gt_audio = torch.cat([gt_audio, audio_pad], dim=0) else: start_idx = np.random.randint(0, video_length - clip_length + 1) end_idx = start_idx + clip_length batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[0,:] # C,H,W gt_video = videos[1:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[0,:] # M,D gt_audio = audios[1:,:] # F,M,D # available length cur_available_length = gt_video.shape[0] assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 # meta meta_ = dict( video_length = video_length, video_path = video_path, audio_path = whisper_path, ) return dict( 
ref_video=ref_video, gt_video=gt_video, ref_audio=ref_audio, gt_audio=gt_audio, mask = mask, meta=meta_ ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( meta = [item["meta"] for item in batch], ref_video = torch.stack([item['ref_video'] for item in batch]), gt_video = torch.stack([item['gt_video'] for item in batch]), ref_audio = torch.stack([item['ref_audio'] for item in batch]), gt_audio = torch.stack([item['gt_audio'] for item in batch]), mask = torch.stack([item['mask'] for item in batch]) ) class A2MVideoAudioPose(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, audio_drop_ratio:float = 0.0, **kwargs ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.audio_drop_ratio = audio_drop_ratio # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): """ videos : 31,3,256,256 ref_img : 3,256,256 audio_feature : 30,50,384 ref_pose : 3,256,256 meta """ # init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = meta_data['whisper_emb_path'] pose_path = meta_data['pose_path'] # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) pose_reader = VideoReader(pose_path) video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) # sample_frames sample_frames = self.sample_n_frames + 1 # self.sample_n_frames = 4, sample_frames=5 clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9 if clip_length > video_length : batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int) batch_index = np.array([d for d in batch_index if d <= video_length-1],dtype=int) # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[0,:] # C,H,W gt_video = videos[1:,:] # F,C,H,W poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) poses = poses / 255.0 poses = self.pixel_transforms(poses) ref_pose = poses[0,:] # C,H,W gt_pose = poses[1:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[0,:] # M,D gt_audio = audios[1:,:] # F,M,D # available length cur_available_length = gt_video.shape[0] # pad pad_length = self.sample_n_frames - gt_video.shape[0] video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype) gt_video = torch.cat([gt_video, video_pad], dim=0) # F,C,H,W pose_pad = torch.zeros((pad_length, *gt_pose.shape[1:]), dtype=gt_pose.dtype) gt_pose = torch.cat([gt_pose, pose_pad], dim=0) # F,C,H,W audio_pad = torch.zeros((pad_length, 
*gt_audio.shape[1:]), dtype=gt_audio.dtype) gt_audio = torch.cat([gt_audio, audio_pad], dim=0) else: start_idx = np.random.randint(0, video_length - clip_length + 1) end_idx = start_idx + clip_length batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[0,:] # C,H,W gt_video = videos[1:,:] # F,C,H,W poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) poses = poses / 255.0 poses = self.pixel_transforms(poses) ref_pose = poses[0,:] # C,H,W gt_pose = poses[1:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[0,:] # M,D gt_audio = audios[1:,:] # F,M,D # available length cur_available_length = gt_video.shape[0] assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames) # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 # drop audio if torch.rand(1).item() < self.audio_drop_ratio: ref_audio = torch.zeros_like(ref_audio) gt_audio = torch.zeros_like(gt_audio) # meta meta_ = dict( video_length = video_length, video_path = video_path, audio_path = whisper_path, ) return dict( ref_video=ref_video, gt_video=gt_video, ref_pose = ref_pose, gt_pose = gt_pose, ref_audio=ref_audio, gt_audio=gt_audio, mask = mask, meta=meta_ ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( meta = [item["meta"] for item in batch], ref_video = torch.stack([item['ref_video'] for item in batch]), gt_video = torch.stack([item['gt_video'] for item in batch]), ref_pose = torch.stack([item['ref_pose'] for item in batch]), gt_pose = torch.stack([item['gt_pose'] for item in batch]), ref_audio = torch.stack([item['ref_audio'] for item in batch]), gt_audio = torch.stack([item['gt_audio'] for item in batch]), mask = torch.stack([item['mask'] for item in batch]) ) class A2MVideoAudioPoseRandomRef(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, **kwargs ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): """ videos : 31,3,256,256 ref_img : 3,256,256 audio_feature : 30,50,384 ref_pose : 3,256,256 meta """ # init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = 
meta_data['whisper_emb_path'] pose_path = meta_data['pose_path'] # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) pose_reader = VideoReader(pose_path) video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) # sample_frames sample_frames = self.sample_n_frames clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9 if clip_length > video_length : batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int) batch_index = list(np.array([d for d in batch_index if d <= video_length-1],dtype=int)) # ref idx idx_all = np.arange(0,video_length) start_idx = 0 occ_idx = np.arange(start_idx, start_idx + clip_length) ref_idx = [x for x in idx_all if x not in occ_idx] if len(ref_idx) == 0: ref_frame_idx = batch_index[0] else: np.random.shuffle(ref_idx) ref_frame_idx = ref_idx[0] batch_index = [ref_frame_idx] + batch_index # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[0,:] # C,H,W gt_video = videos[1:,:] # F,C,H,W poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) poses = poses / 255.0 poses = self.pixel_transforms(poses) ref_pose = poses[0,:] # C,H,W gt_pose = poses[1:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[0,:] # M,D gt_audio = audios[1:,:] # F,M,D # available length cur_available_length = gt_video.shape[0] # pad pad_length = self.sample_n_frames - gt_video.shape[0] video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype) gt_video = torch.cat([gt_video, video_pad], dim=0) # F,C,H,W pose_pad = torch.zeros((pad_length, *gt_pose.shape[1:]), dtype=gt_pose.dtype) gt_pose = torch.cat([gt_pose, pose_pad], dim=0) # F,C,H,W audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype) gt_audio = torch.cat([gt_audio, audio_pad], dim=0) else: start_idx = np.random.randint(0, video_length - clip_length + 1) end_idx = start_idx + clip_length batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)) # ref index idx_all = np.arange(0,video_length) occ_idx = np.arange(start_idx, start_idx + clip_length) ref_idx = [x for x in idx_all if x not in occ_idx] if len(ref_idx) == 0: ref_frame_idx = batch_index[0] else: np.random.shuffle(ref_idx) ref_frame_idx = ref_idx[0] batch_index = [ref_frame_idx] + batch_index # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[0,:] # C,H,W gt_video = videos[1:,:] # F,C,H,W poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) poses = poses / 255.0 poses = self.pixel_transforms(poses) ref_pose = poses[0,:] # C,H,W gt_pose = poses[1:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[0,:] # M,D gt_audio = audios[1:,:] # F,M,D # available length cur_available_length = gt_video.shape[0] assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' 
'+str(self.sample_n_frames) # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 # meta meta_ = dict( video_length = video_length, video_path = video_path, audio_path = whisper_path, ) return dict( ref_video=ref_video, gt_video=gt_video, ref_pose = ref_pose, gt_pose = gt_pose, ref_audio=ref_audio, gt_audio=gt_audio, mask = mask, meta=meta_ ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( meta = [item["meta"] for item in batch], ref_video = torch.stack([item['ref_video'] for item in batch]), gt_video = torch.stack([item['gt_video'] for item in batch]), ref_pose = torch.stack([item['ref_pose'] for item in batch]), gt_pose = torch.stack([item['gt_pose'] for item in batch]), ref_audio = torch.stack([item['ref_audio'] for item in batch]), gt_audio = torch.stack([item['gt_audio'] for item in batch]), mask = torch.stack([item['mask'] for item in batch]) ) class A2MVideoAudioPoseMultiSample(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, audio_drop_ratio:float = 0.0, num_sample:int = 4, **kwargs ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.audio_drop_ratio = audio_drop_ratio self.num_sample = num_sample # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length =len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): """ videos : 31,3,256,256 ref_img : 3,256,256 audio_feature : 30,50,384 ref_pose : 3,256,256 meta """ # init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = meta_data['whisper_emb_path'] pose_path = meta_data['pose_path'] # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) pose_reader = VideoReader(pose_path) video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) # sample_frames sample_frames = self.sample_n_frames + 1 # self.sample_n_frames = 4, sample_frames=5 clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9 ref_video_list = [] gt_video_list = [] ref_pose_list = [] gt_pose_list = [] ref_audio_list = [] gt_audio_list = [] mask_list = [] for i in range(self.num_sample): start_idx = np.random.randint(0, video_length - clip_length + 1) end_idx = start_idx + clip_length batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[0,:] # C,H,W gt_video = videos[1:,:] # F,C,H,W poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) poses = poses / 255.0 poses = self.pixel_transforms(poses) ref_pose = 
poses[0,:] # C,H,W gt_pose = poses[1:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[0,:] # M,D gt_audio = audios[1:,:] # F,M,D # available length cur_available_length = gt_video.shape[0] assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames) # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 # drop audio if torch.rand(1).item() < self.audio_drop_ratio: ref_audio = torch.zeros_like(ref_audio) gt_audio = torch.zeros_like(gt_audio) # cache ref_video_list.append(ref_video) gt_video_list.append(gt_video) ref_pose_list.append(ref_pose) gt_pose_list.append(gt_pose) ref_audio_list.append(ref_audio) gt_audio_list.append(gt_audio) mask_list.append(mask) return dict( ref_video=torch.stack(ref_video_list), gt_video=torch.stack(gt_video_list), ref_pose = torch.stack(ref_pose_list), gt_pose = torch.stack(gt_pose_list), ref_audio=torch.stack(ref_audio_list), gt_audio=torch.stack(gt_audio_list), mask = torch.stack(mask_list), ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( ref_video = torch.cat([item['ref_video'] for item in batch],dim=0), gt_video = torch.cat([item['gt_video'] for item in batch],dim=0), ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0), gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0), ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0), gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0), mask = torch.cat([item['mask'] for item in batch],dim=0) ) class A2MVideoAudioPoseMultiSampleMultiRefBalance(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, audio_drop_ratio:float = 0.0, num_sample:int = 4, max_ref_frame:int = 8, random_ref_num:bool = False, **kwargs ): super().__init__() with open(video_dir, 'r') as file: lines = file.readlines() video_dirs = [line.strip() for line in lines] assert len(video_dirs) == 2, 'Only support 2 video dirs' self.dataset1 = A2MVideoAudioPoseMultiSampleMultiRef(video_dir=video_dirs[0], sample_size=sample_size, sample_stride=sample_stride, sample_n_frames=sample_n_frames, audio_drop_ratio=audio_drop_ratio, num_sample=num_sample, max_ref_frame=max_ref_frame, random_ref_num=random_ref_num) self.dataset2 = A2MVideoAudioPoseMultiSampleMultiRef(video_dir=video_dirs[1], sample_size=sample_size, sample_stride=sample_stride, sample_n_frames=sample_n_frames, audio_drop_ratio=audio_drop_ratio, num_sample=num_sample, max_ref_frame=max_ref_frame, random_ref_num=random_ref_num) self.length = 2*max(len(self.dataset1),len(self.dataset2)) def __getitem__(self, idx): while True: try: if idx % 2 == 0: a_idx = torch.randint(0, len(self.dataset1), (1,)).item() sample = self.dataset1[a_idx] else: b_idx = torch.randint(0, len(self.dataset2), (1,)).item() sample = self.dataset2[b_idx] break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length @staticmethod def collate_fn(batch): return dict( ref_video = torch.cat([item['ref_video'] for item in batch],dim=0), gt_video = torch.cat([item['gt_video'] for item in batch],dim=0), ref_pose = torch.cat([item['ref_pose'] for item 
in batch], dim=0),
            gt_pose = torch.cat([item['gt_pose'] for item in batch], dim=0),
            ref_audio = torch.cat([item['ref_audio'] for item in batch], dim=0),
            gt_audio = torch.cat([item['gt_audio'] for item in batch], dim=0),
            mask = torch.cat([item['mask'] for item in batch], dim=0)
        )


class A2MVideoAudioMultiRefDoubleRef(Dataset):
    def __init__(
        self,
        video_dir: str,
        sample_size: int = 256,
        sample_stride: int = 1,
        sample_n_frames: int = 16,
        audio_drop_ratio: float = 0.0,
        num_sample: int = 4,
        max_ref_frame: int = 8,
        **kwargs
    ):
        super().__init__()
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.audio_drop_ratio = audio_drop_ratio
        self.num_sample = num_sample
        self.max_ref_frame = max_ref_frame
        self.randomref_num = 8

        # Transform
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)  # e.g. (256, 256)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(min(sample_size)),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        with open(video_dir, 'rb') as f:
            self.metadata_list = pickle.load(f)

        self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')

    def __getitem__(self, idx):
        while True:
            try:
                sample = self.get_batch(idx)
                break
            except Exception as e:
                print('error', e)
                idx = torch.randint(low=0, high=self.length, size=(1,)).item()
        return sample

    def __len__(self):
        return self.length

    def get_batch(self, idx):
        """
        videos : 31,3,256,256
        ref_img : 3,256,256
        audio_feature : 30,50,384
        ref_pose : 3,256,256
        meta
        """
        # init
        meta_data = self.metadata_list[idx]
        video_path = meta_data['video_path']
        whisper_path = meta_data['whisper_emb_path']

        # audio
        audio_feature = torch.load(whisper_path)

        # load & check
        video_reader = VideoReader(video_path)
        video_length = min(len(video_reader), audio_feature.shape[0])

        # ref num (elif, not a second if, so that ref_num = 0 is actually reachable)
        r = torch.rand(1).item()
        if r < 0.33:
            ref_num = 0
        elif r < 0.66:
            ref_num = 1
        else:
            ref_num = self.max_ref_frame

        sample_frames = self.sample_n_frames + ref_num
        clip_length = (sample_frames - 1) * self.sample_stride + 1
        start_idx = np.random.randint(0, video_length - clip_length + 1)
        end_idx = start_idx + clip_length
        batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int))

        # randomref
        random_index = list(np.linspace(0, video_length - 1, self.randomref_num, dtype=int))
        # random_index = torch.randint(low=0, high=clip_length-1, size=(1,)).item()
        video_batch_index = random_index + batch_index

        # frames
        videos = torch.from_numpy(video_reader.get_batch(video_batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()  # (N,H,W,C) -> (N,C,H,W)
        videos = videos / 255.0
        videos = self.pixel_transforms(videos)
        randomref_video = videos[:self.randomref_num, :]   # T,C,H,W
        l_videos = videos[self.randomref_num:, :]          # T,C,H,W
        ref_video = l_videos[:ref_num, :] if ref_num > 0 else None
        gt_video = l_videos[ref_num:, :]                   # F,C,H,W

        audios = audio_feature[batch_index]
        ref_audio = audios[:ref_num, :] if ref_num > 0 else None
        gt_audio = audios[ref_num:, :]                     # F,M,D

        # padding ref frame
        if ref_num == 1:
            ref_video_pad = torch.zeros((self.max_ref_frame - ref_num, *ref_video.shape[1:]))
            ref_video = torch.cat([ref_video_pad, ref_video], dim=0)   # N,T,C,H,W
            ref_audio_pad = torch.zeros((self.max_ref_frame - ref_num, *ref_audio.shape[1:]))
            ref_audio = torch.cat([ref_audio_pad, ref_audio], dim=0)
        elif ref_num == 0:
            ref_video = torch.zeros((self.max_ref_frame, *gt_video.shape[1:]))
            ref_audio = torch.zeros((self.max_ref_frame, *gt_audio.shape[1:]))

        # available length
        cur_available_length =
gt_video.shape[0] assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 return dict( ref_video=ref_video, gt_video=gt_video, randomref_video = randomref_video, ref_audio= ref_audio, gt_audio=gt_audio, mask = mask, ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( ref_video = torch.stack([item['ref_video'] for item in batch],dim=0), gt_video = torch.stack([item['gt_video'] for item in batch],dim=0), randomref_video = torch.stack([item['randomref_video'] for item in batch],dim=0), ref_audio = torch.stack([item['ref_audio'] for item in batch],dim=0), gt_audio = torch.stack([item['gt_audio'] for item in batch],dim=0), mask = torch.stack([item['mask'] for item in batch],dim=0) ) class A2MVideoAudioMultiRefDoubleRefBalance(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, audio_drop_ratio:float = 0.0, num_sample:int = 4, max_ref_frame:int = 8, **kwargs ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.audio_drop_ratio = audio_drop_ratio self.num_sample = num_sample self.max_ref_frame = max_ref_frame self.randomref_num = 8 # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): """ videos : 31,3,256,256 ref_img : 3,256,256 audio_feature : 30,50,384 ref_pose : 3,256,256 meta """ # init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = meta_data['whisper_emb_path'] # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) video_length = min(len(video_reader),audio_feature.shape[0]) # ref num r = torch.rand(1).item() if r < 0.33: ref_num = 0 elif r<0.66: ref_num = 1 else: ref_num = self.max_ref_frame sample_frames = self.sample_n_frames + ref_num clip_length = (sample_frames - 1) * self.sample_stride + 1 start_idx = np.random.randint(0, video_length - clip_length + 1) end_idx = start_idx + clip_length batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)) # randomref random_index = list(np.linspace(0, video_length - 1, self.randomref_num, dtype=int)) random_index = list(random_index) video_batch_index = random_index + batch_index # frames videos = torch.from_numpy(video_reader.get_batch(video_batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) randomref_video = videos[:self.randomref_num,:] # T,C,H,W l_videos = videos[self.randomref_num:,:] # T,C,H,W ref_video = l_videos[:ref_num,:] if ref_num > 0 
else None gt_video = l_videos[ref_num:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[:ref_num,:] if ref_num > 0 else None gt_audio = audios[ref_num:,:] # F,M,D # padding ref frame if ref_num == 1: ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:])) ref_video = torch.cat([ref_video_pad,ref_video],dim=0) # N,T,C,H,W ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:])) ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0) elif ref_num == 0: ref_video = torch.zeros((self.max_ref_frame,*gt_video.shape[1:])) ref_audio = torch.zeros((self.max_ref_frame,*gt_audio.shape[1:])) # available length cur_available_length = gt_video.shape[0] assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 return dict( ref_video=ref_video, gt_video=gt_video, randomref_video = randomref_video, ref_audio= ref_audio, gt_audio=gt_audio, mask = mask, ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( ref_video = torch.stack([item['ref_video'] for item in batch],dim=0), gt_video = torch.stack([item['gt_video'] for item in batch],dim=0), randomref_video = torch.stack([item['randomref_video'] for item in batch],dim=0), ref_audio = torch.stack([item['ref_audio'] for item in batch],dim=0), gt_audio = torch.stack([item['gt_audio'] for item in batch],dim=0), mask = torch.stack([item['mask'] for item in batch],dim=0) ) # pose img2img class A2MVideoAudioPoseRandomRefMultiSample(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, num_sample:int = 4, max_ref_frame:int = 8, random_ref_num:bool = False, **kwargs ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.num_sample = num_sample self.max_ref_frame = max_ref_frame # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): """ videos : 31,3,256,256 ref_img : 3,256,256 audio_feature : 30,50,384 ref_pose : 3,256,256 meta """ # init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = meta_data['whisper_emb_path'] pose_path = meta_data['pose_path'] # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) pose_reader = VideoReader(pose_path) video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) # batch index sample_frames = self.num_sample * 2 if video_length < sample_frames: raise ValueError(f"视频长度{video_length}太短了,需要长度{sample_frames}") # occ_idx = np.arange(0, video_length) # # # 
np.random.shuffle(occ_idx) # # # 生成随机排列的索引 # # shuffled_indices = torch.randperm(len(occ_idx)) # # # 使用索引打乱列表 # # occ_idx = [occ_idx[i] for i in shuffled_indices] # occ_idx = shuffle_list(occ_idx) # batch_index = occ_idx[:sample_frames] # batch_index = [torch.randint(low=0, high=video_length, size=(1,)).item() for i in range(sample_frames)] start_idx = torch.randint(low=0, high=video_length-sample_frames, size=(1,)).item() occ_idx = np.arange(start_idx, start_idx+video_length) batch_index = occ_idx[:sample_frames] # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[:self.num_sample,:].unsqueeze(1) # N,1,C,H,W gt_video = videos[self.num_sample:,:].unsqueeze(1) # N,1,C,H,W poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) poses = poses / 255.0 poses = self.pixel_transforms(poses) ref_pose = poses[:self.num_sample,:].unsqueeze(1) # N,1,C,H,W gt_pose = poses[self.num_sample:,:].unsqueeze(1) # N,1,C,H,W audios = audio_feature[batch_index] ref_audio = audios[:self.num_sample,:].unsqueeze(1) # N,1,C,H,W gt_audio = audios[self.num_sample:,:].unsqueeze(1) # N,1,C,H,W # available length cur_available_length = gt_video.shape[0] # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 return dict( ref_video=ref_video, gt_video=gt_video, ref_pose = ref_pose, gt_pose =gt_pose, ref_audio=ref_audio, gt_audio=gt_audio, mask = mask, ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( ref_video = torch.cat([item['ref_video'] for item in batch],dim=0), gt_video = torch.cat([item['gt_video'] for item in batch],dim=0), ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0), gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0), ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0), gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0), mask = torch.cat([item['mask'] for item in batch],dim=0) ) class A2MVideoAudioPoseMultiSampleMultiRef(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 16, audio_drop_ratio:float = 0.0, num_sample:int = 4, max_ref_frame:int = 8, random_ref_num:bool = False, **kwargs ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.audio_drop_ratio = audio_drop_ratio self.num_sample = num_sample self.max_ref_frame = max_ref_frame self.random_ref_num = random_ref_num # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): """ videos : 31,3,256,256 ref_img : 3,256,256 audio_feature : 30,50,384 ref_pose : 3,256,256 meta """ # 
init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = meta_data['whisper_emb_path'] pose_path = meta_data['pose_path'] # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) pose_reader = VideoReader(pose_path) video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) # sample_frames ref_video_list = [] gt_video_list = [] ref_pose_list = [] gt_pose_list = [] ref_audio_list = [] gt_audio_list = [] mask_list = [] for i in range(self.num_sample): # random ref num if self.random_ref_num: ref_num = torch.randint(low=1, high=self.max_ref_frame+1, size=(1,)).item() else: ref_num = [1, self.max_ref_frame][torch.randint(2, (1,)).item()] sample_frames = self.sample_n_frames + ref_num clip_length = (sample_frames - 1) * self.sample_stride + 1 start_idx = np.random.randint(0, video_length - clip_length + 1) end_idx = start_idx + clip_length batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_video = videos[:ref_num,:] # T,C,H,W gt_video = videos[ref_num:,:] # F,C,H,W poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) poses = poses / 255.0 poses = self.pixel_transforms(poses) ref_pose = poses[:ref_num,:] # T,C,H,W gt_pose = poses[ref_num:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[:ref_num,:] # T,M,D gt_audio = audios[ref_num:,:] # F,M,D # padding ref frame if ref_num < self.max_ref_frame: ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:])) ref_video = torch.cat([ref_video_pad,ref_video],dim=0) # N,T,C,H,W ref_pose_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_pose.shape[1:])) ref_pose = torch.cat([ref_pose_pad,ref_pose],dim=0) ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:])) ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0) # available length cur_available_length = gt_video.shape[0] assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames) # mask mask = torch.zeros(self.sample_n_frames) mask[:cur_available_length] = 1 # drop audio if torch.rand(1).item() < self.audio_drop_ratio: ref_audio = torch.zeros_like(ref_audio) gt_audio = torch.zeros_like(gt_audio) # cache ref_video_list.append(ref_video) gt_video_list.append(gt_video) ref_pose_list.append(ref_pose) gt_pose_list.append(gt_pose) ref_audio_list.append(ref_audio) gt_audio_list.append(gt_audio) mask_list.append(mask) return dict( ref_video=torch.stack(ref_video_list), gt_video=torch.stack(gt_video_list), ref_pose = torch.stack(ref_pose_list), gt_pose = torch.stack(gt_pose_list), ref_audio=torch.stack(ref_audio_list), gt_audio=torch.stack(gt_audio_list), mask = torch.stack(mask_list), ) def get_file_name(self, file_path): return file_path.split('/')[-1].split('.')[0] @staticmethod def collate_fn(batch): return dict( ref_video = torch.cat([item['ref_video'] for item in batch],dim=0), gt_video = torch.cat([item['gt_video'] for item in batch],dim=0), ref_pose = torch.cat([item['ref_pose'] for item in 
batch],dim=0), gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0), ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0), gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0), mask = torch.cat([item['mask'] for item in batch],dim=0) ) # inference class A2VDataset(Dataset): def __init__( self, video_dir:str, sample_size: int = 256, sample_stride: int = 1, sample_n_frames:int = 120, **kwargs ): super().__init__() self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames # Transform sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256) self.pixel_transforms = transforms.Compose([ transforms.Resize(min(sample_size)), transforms.CenterCrop(sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) with open(video_dir, 'rb') as f: self.metadata_list = pickle.load(f) self.length = len(self.metadata_list) print(f'Total {self.length} files is available') def __getitem__(self, idx): while True: try: sample = self.get_batch(idx) break except Exception as e: print('error',e) idx = torch.randint(low=0, high=self.length, size=(1,)).item() return sample def __len__(self): return self.length def get_batch(self, idx): # init meta_data = self.metadata_list[idx] video_path = meta_data['video_path'] whisper_path = meta_data['whisper_emb_path'] audio_path = meta_data['audio_path'] # pose_path = meta_data['pose_path'] name = os.path.basename(video_path).split('.')[0] fps = 25 if 'hdtf' in video_path else 30 # audio audio_feature = torch.load(whisper_path) # load & check video_reader = VideoReader(video_path) # pose_reader = VideoReader(pose_path) video_length = min(len(video_reader),audio_feature.shape[0]) # sample_frame # batch idx ref_num = 1 sample_frames = self.sample_n_frames + ref_num clip_length = (sample_frames - 1) * self.sample_stride + 1 start_idx = np.random.randint(0, video_length - clip_length + 1) end_idx = start_idx + clip_length batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) # frames videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) videos = videos / 255.0 videos = self.pixel_transforms(videos) ref_img = videos[:ref_num,:] # 1,C,H,W inf_video = videos[ref_num:,:] # F,C,H,W # poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W) # poses = poses / 255.0 # poses = self.pixel_transforms(poses) # ref_pose = poses[:ref_num,:] # T,C,H,W # gt_pose = poses[ref_num:,:] # F,C,H,W audios = audio_feature[batch_index] ref_audio = audios[:ref_num,:] # 1,M,D inf_audio = audios[ref_num:,:] # F,M,D # meta_info start_time = start_idx / fps _meta = {"name":name, "audio_path":audio_path, "fps":fps, "start_time":start_time} return dict( meta_info = _meta, ref_img=ref_img, gt_video=videos, ref_audio=ref_audio, inf_audio=inf_audio, ref_pose=None, inf_pose=None, ) @staticmethod def collate_fn(batch): return dict( meta_info = [item['meta_info'] for item in batch], ref_img= torch.stack([item["ref_img"] for item in batch]), ref_audio= torch.stack([item["ref_audio"] for item in batch]), inf_audio= torch.stack([item["inf_audio"] for item in batch]), ref_pose= torch.ones((2,2)), inf_pose= torch.ones((2,2)), gt_video= torch.stack([item["gt_video"] for item in batch]), ) def generate_non_equal_random_lists(frame_num,sample_num): list1 = [np.random.randint(0, frame_num) for _ in range(sample_num)] list2 = 
[]
    for i in range(len(list1)):
        available_numbers = list(range(0, list1[i])) + list(range(list1[i] + 1, frame_num))
        list2.append(random.choice(available_numbers))
    return list1, list2


def shuffle_list(l):
    shuffled_indices = torch.randperm(len(l))
    # shuffle the list by indexing with the random permutation
    l = [l[i] for i in shuffled_indices]
    return l


if __name__ == "__main__":
    # dataset = CelebvText()
    # dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
    # for idx, batch in enumerate(dataloader):
    #     print(batch["videos"].shape, len(batch["text"]))
    #     for i in range(batch["videos"].shape[0]):
    #         save_videos_grid(batch["videos"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)

    from torch.utils.data import DataLoader

    # dataset = AMDVideoAudioFeature(
    #     path=data_path,
    #     path_type="file",
    #     motion_seq_len=motion_seq_len,
    #     sample_n_frames=num_frames,
    #     audio_processor=audio_processor
    # )

    # NOTE: AMDVideoAudioFeature is not defined in this module; import it from
    # the module that provides it before running this smoke test.
    dataset = AMDVideoAudioFeature(
        video_dir='/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/code/AMD_linear/dataset/path/lhz/train.pkl',
        path_type='file'
    )
    dataloader = DataLoader(
        dataset, 2, True, num_workers=0, collate_fn=dataset.collate_fn
    )
    # dataloader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn, num_workers=2)

    # d = dataset[10]
    # video = d["videos"]
    # audio = d["audio_feature"]
    # refimg = d["ref_img"]

    for d in dataloader:
        video = d["videos"]
        audio_feature = d["audio_feature"]
        refimg = d["ref_img"]
        # break
        print(video.shape)
        print(audio_feature.shape)
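# Note: the datasets above all share the same clip-sampling pattern. A minimal
# sketch of that index math (illustrative only; names mirror the constructor
# arguments, values are not from any specific class):
#
#   sample_frames = sample_n_frames + 1                                   # one extra frame used as the reference
#   clip_length   = min(video_length, (sample_frames - 1) * sample_stride + 1)
#   start_idx     = random.randint(0, video_length - clip_length)
#   batch_index   = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int)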