import os
import random
import numpy as np
from decord import VideoReader
import glob
from tqdm import tqdm
import pickle
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from decord import cpu, gpu
from torchvision.io import read_video, write_video
import json
import traceback
# MultiSample : sample multiple clips from each video
# RandomRef : whether the reference image is random; if not, it is the previous frame. With a random ref, the reference video has only one frame
# MultiRef : multiple reference frames, at most 8 by default
class AMDConsecutiveVideo(Dataset):
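    """Consecutive-frame video dataset.

    Samples `sample_n_frames + 1` evenly spaced frames from a random clip of each
    video; the first frame is used as the reference image (repeated along the frame
    dimension) and the remaining frames are the target video.
    Returns dict(name, videos (F,C,H,W), ref_img (F,C,H,W)).
    """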
def __init__(
self,
video_dir: str = '', # video dir or pkl file
sample_size: int = 32,
sample_stride: int = 2,
sample_n_frames:int = 16,
ref_drop_ratio = 0.0,
):
# Init setting
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.ref_drop_ratio = ref_drop_ratio
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(sample_size[0]),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
if 'pkl' in video_dir:
with open(video_dir, 'rb') as f:
video_files = pickle.load(f)
print(f'Total {len(video_files)} !!!')
elif '.txt' in video_dir:
with open(video_dir, 'r') as file:
lines = file.readlines()
video_dirs = [line.strip() for line in lines]
video_files = []
for dir in video_dirs:
video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True)
print(f'Total {len(video_files)} !!!')
else:
video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True)
print(f'Total {len(video_files)} !!!')
# Data dict
self.metadata_list = []
for file_path in tqdm(video_files):
d = {}
d['name'] = self.get_file_name(file_path)
d['video_path'] = file_path
self.metadata_list.append(d)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
file_name = self.metadata_list[idx]['name']
file_name,videos,ref_img = self.get_batch(idx)
break
except Exception as e:
# file_name = self.metadata_list[idx]['name']
# print(file_name)
print('error',e)
idx = random.randint(0, self.length-1)
sample = dict(name=file_name,videos=videos,ref_img = ref_img)
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
# init
meta_data = self.metadata_list[idx]
file_name = meta_data['name']
video_path = meta_data['video_path']
# video process
video_reader = VideoReader(video_path, ctx=cpu(0))
video_length = len(video_reader)
sample_frames = self.sample_n_frames + 1 # refimg + videos
clip_length = min(video_length, (sample_frames - 1) * self.sample_stride + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int)
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
# transform
videos_cache = self.pixel_transforms(videos) # F+1,C,H,W
videos = videos_cache[1:,:,:,:] # F,C,H,W
ref_frame = videos_cache[0,:,:,:] # C,H,W
# repeat
ref_frame = ref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) # F,C,H,W
return file_name,videos,ref_frame
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
# name
name = [item['name'] for item in batch]
# videos
videos = [item['videos'] for item in batch]
videos = torch.stack(videos)
# ref_img
ref_img = [item['ref_img'] for item in batch]
ref_img = torch.stack(ref_img)
randomref_img = None
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img)
class AMDConsecutiveVideoBalance(Dataset):
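    """Balanced wrapper over two AMDConsecutiveVideo datasets.

    `video_dir` must be a .txt file listing exactly two video sources; each
    __getitem__ call draws a random sample from one of the two sources with equal
    probability, so both are seen equally often regardless of their size.
    """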
def __init__(
self,
video_dir: str = '', # video dir or pkl file
sample_size: int = 32,
sample_stride: int = 2,
sample_n_frames:int = 16,
ref_drop_ratio = 0.0,
):
# Init setting
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.ref_drop_ratio = ref_drop_ratio
assert '.txt' in video_dir
with open(video_dir, 'r') as file:
lines = file.readlines()
video_paths = [line.strip() for line in lines]
assert len(video_paths) == 2
self.dataset1 = AMDConsecutiveVideo(video_dir = video_paths[0], # video dir or pkl file
sample_size = sample_size,
sample_stride = sample_stride,
sample_n_frames = sample_n_frames,
ref_drop_ratio = 0.0,)
self.dataset2 = AMDConsecutiveVideo(video_dir = video_paths[1], # video dir or pkl file
sample_size = sample_size,
sample_stride = sample_stride,
sample_n_frames = sample_n_frames,
ref_drop_ratio = 0.0,)
self.len1 = len(self.dataset1)
self.len2 = len(self.dataset2)
print(f'Total {self.len1 + self.len2} !!!')
def __len__(self):
        # Set to a sufficiently large value (e.g. twice the length of the larger dataset)
return 2 * max(self.len1,self.len2)
def __getitem__(self, idx):
        # Use PyTorch's random number generator, which is safe across DataLoader worker processes
if torch.rand(1).item() < 0.5:
            # Randomly draw a sample from dataset1
a_idx = torch.randint(0, len(self.dataset1), (1,)).item()
return self.dataset1[a_idx]
else:
            # Randomly draw a sample from dataset2
b_idx = torch.randint(0, len(self.dataset2), (1,)).item()
return self.dataset2[b_idx]
@staticmethod
def collate_fn(batch):
# name
name = [item['name'] for item in batch]
# videos
videos = [item['videos'] for item in batch]
videos = torch.stack(videos)
# ref_img
ref_img = [item['ref_img'] for item in batch]
ref_img = torch.stack(ref_img)
randomref_img = None
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img)
class AMDConsecutiveVideoDoubleRef(Dataset):
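    """Consecutive-frame dataset with an extra random reference frame.

    Besides the consecutive reference (first frame of the sampled clip), a second
    reference frame is drawn from outside the clip when possible. Returns
    dict(name, videos, ref_img, randomref_img), with both reference tensors
    repeated to (F,C,H,W).
    """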
def __init__(
self,
video_dir: str = '', # video dir or pkl file
sample_size: int = 32,
sample_stride: int = 2,
sample_n_frames:int = 16,
ref_drop_ratio = 0.0,
):
# Init setting
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.ref_drop_ratio = ref_drop_ratio
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(sample_size[0]),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
if 'pkl' in video_dir:
with open(video_dir, 'rb') as f:
video_files = pickle.load(f)
print(f'Total {len(video_files)} !!!')
elif '.txt' in video_dir:
with open(video_dir, 'r') as file:
lines = file.readlines()
video_dirs = [line.strip() for line in lines]
video_files = []
for dir in video_dirs:
video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True)
print(f'Total {len(video_files)} !!!')
else:
video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True)
print(f'Total {len(video_files)} !!!')
# Data dict
self.metadata_list = []
for file_path in tqdm(video_files):
d = {}
d['name'] = self.get_file_name(file_path)
d['video_path'] = file_path
self.metadata_list.append(d)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
file_name = self.metadata_list[idx]['name']
file_name,videos,ref_img,randomref_img = self.get_batch(idx)
break
except Exception as e:
# file_name = self.metadata_list[idx]['name']
# print(file_name)
print('error',e)
idx = random.randint(0, self.length-1)
sample = dict(name=file_name,videos=videos,ref_img = ref_img,randomref_img=randomref_img)
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
# init
meta_data = self.metadata_list[idx]
file_name = meta_data['name']
video_path = meta_data['video_path']
# video process
video_reader = VideoReader(video_path, ctx=cpu(0))
video_length = len(video_reader)
sample_frames = self.sample_n_frames + 1 # refimg + videos
clip_length = min(video_length, (sample_frames - 1) * self.sample_stride + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int)
# random ref frame
idx_all = np.arange(0,video_length)
occ_idx = np.arange(start_idx, start_idx + clip_length)
randomref_idx = [x for x in idx_all if x not in occ_idx]
if len(randomref_idx) == 0:
ref_frame_idx = batch_index[0]
else:
i = torch.randint(low=0, high=len(randomref_idx), size=(1,)).item()
ref_frame_idx = randomref_idx[i]
batch_index = [ref_frame_idx] + list(batch_index)
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
# # ref frame
# idx_all = np.arange(0,video_length)
# occ_idx = np.arange(start_idx, start_idx + clip_length)
# ref_idx = [x for x in idx_all if x not in occ_idx]
# if len(ref_idx) == 0:
# ref_frame_idx = batch_index[0]
# else:
# np.random.shuffle(ref_idx)
# ref_frame_idx = ref_idx[0]
# ref_frame = torch.from_numpy(video_reader[ref_frame_idx].asnumpy()).permute(2, 0, 1).contiguous()
# ref_frame = ref_frame / 255.0
# transform
        videos_cache = self.pixel_transforms(videos) # F+2,C,H,W (random ref + consecutive ref + F frames)
videos = videos_cache[2:,:,:,:] # F,C,H,W
ref_frame = videos_cache[1,:,:,:] # C,H,W
randomref_frame = videos_cache[0,:,:,:] # C,H,W
# repeat
ref_frame = ref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) # F,C,H,W
randomref_frame = randomref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) # F,C,H,W
return file_name,videos,ref_frame,randomref_frame
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
# name
name = [item['name'] for item in batch]
# videos
videos = [item['videos'] for item in batch]
videos = torch.stack(videos)
# ref_img
ref_img = [item['ref_img'] for item in batch]
ref_img = torch.stack(ref_img)
randomref_img = [item['randomref_img'] for item in batch]
randomref_img = torch.stack(randomref_img)
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img)
class AMDConsecutiveVideoDoubleRefBalance(Dataset):
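    """Balanced wrapper over several AMDConsecutiveVideoDoubleRef datasets.

    `video_dir` is a .txt file listing one video source per line; the source is
    chosen round-robin (idx modulo the number of sources) and a random item is
    drawn from the chosen source.
    """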
def __init__(
self,
video_dir: str = '', # video dir or pkl file
sample_size: int = 32,
sample_stride: int = 2,
sample_n_frames:int = 16,
ref_drop_ratio = 0.0,
):
# Init setting
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.ref_drop_ratio = ref_drop_ratio
assert '.txt' in video_dir
with open(video_dir, 'r') as file:
lines = file.readlines()
video_paths = [line.strip() for line in lines]
self.datasets = []
for vp in video_paths:
self.datasets.append(AMDConsecutiveVideoDoubleRef(video_dir = vp, # video dir or pkl file
sample_size = sample_size,
sample_stride = sample_stride,
sample_n_frames = sample_n_frames,
ref_drop_ratio = 0.0,))
self.length = len(self.datasets) * max([len(d) for d in self.datasets])
print(self.length)
def __len__(self):
        # Set to a sufficiently large value (e.g. twice the length of the largest dataset)
return self.length
def __getitem__(self, idx):
dataset_num = len(self.datasets)
cur_idx = idx % dataset_num
cur_dataset = self.datasets[cur_idx]
idx = torch.randint(0, len(cur_dataset), (1,)).item()
return cur_dataset[idx]
@staticmethod
def collate_fn(batch):
# name
name = [item['name'] for item in batch]
# videos
videos = [item['videos'] for item in batch]
videos = torch.stack(videos)
# ref_img
ref_img = [item['ref_img'] for item in batch]
ref_img = torch.stack(ref_img)
randomref_img = [item['randomref_img'] for item in batch]
randomref_img = torch.stack(randomref_img)
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img)
class AMDRandomPair(Dataset):
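    """Random frame-pair dataset.

    For each video, draws `sample_n_frames` random reference frames and, for each
    of them, a different random target frame from the same video (see
    generate_non_equal_random_lists). Returns dict(name, videos, ref_img), both
    of shape (F,C,H,W).
    """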
def __init__(
self,
video_dir: str = '', # video dir or pkl file
sample_size: int = 32,
sample_stride: int = 4,
sample_n_frames:int = 16,
ref_drop_ratio = 0.0,
):
# Init setting
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.ref_drop_ratio = ref_drop_ratio
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
if 'pkl' in video_dir:
with open(video_dir, 'rb') as f:
video_files = pickle.load(f)
print(f'Total {len(video_files)} !!!')
elif '.txt' in video_dir:
with open(video_dir, 'r') as file:
lines = file.readlines()
video_dirs = [line.strip() for line in lines]
video_files = []
for dir in video_dirs:
video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True)
print(f'Total {len(video_files)} !!!')
else:
video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True)
# Data dict
self.metadata_list = []
for file_path in tqdm(video_files):
d = {}
d['name'] = self.get_file_name(file_path)
d['video_path'] = file_path
self.metadata_list.append(d)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
file_name = self.metadata_list[idx]['name']
file_name,videos,ref_img = self.get_batch(idx)
break
except Exception as e:
# file_name = self.metadata_list[idx]['name']
# print(file_name)
print('error',e)
idx = random.randint(0, self.length-1)
sample = dict(name=file_name,videos=videos,ref_img = ref_img)
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
# init
meta_data = self.metadata_list[idx]
file_name = meta_data['name']
video_path = meta_data['video_path']
# video process
video_reader = VideoReader(video_path)
video_length = len(video_reader)
ref_idx,video_idx = generate_non_equal_random_lists(frame_num=video_length,sample_num=self.sample_n_frames)
ref_videos = torch.from_numpy(video_reader.get_batch(ref_idx).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
ref_videos = ref_videos / 255.0
ref_videos = self.pixel_transforms(ref_videos)
videos = torch.from_numpy(video_reader.get_batch(video_idx).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
return file_name,videos,ref_videos
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
# name
name = [item['name'] for item in batch]
# videos
videos = [item['videos'] for item in batch]
videos = torch.stack(videos)
# ref_img
ref_img = [item['ref_img'] for item in batch]
ref_img = torch.stack(ref_img)
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=None)
class A2MVideoAudio(Dataset):
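    """Audio-to-motion dataset pairing video frames with Whisper audio embeddings.

    Samples `sample_n_frames + 1` frames (the first is the reference); short videos
    are zero-padded to the target length and a binary mask marks the valid frames.
    Returns ref/gt video, ref/gt audio, mask, and meta info.
    """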
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
video_length = min(len(video_reader),audio_feature.shape[0])
# sample_frames
sample_frames = self.sample_n_frames + 1 # self.sample_n_frames = 4, sample_frames=5
clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9
if clip_length > video_length :
batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int)
batch_index = np.array([d for d in batch_index if d <= video_length-1],dtype=int)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[0,:] # C,H,W
gt_video = videos[1:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[0,:] # M,D
gt_audio = audios[1:,:] # F,M,D
# available length
cur_available_length = gt_video.shape[0]
# pad
pad_length = self.sample_n_frames - gt_video.shape[0]
video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype)
gt_video = torch.cat([gt_video, video_pad], dim=0) # F,C,H,W
audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype)
gt_audio = torch.cat([gt_audio, audio_pad], dim=0)
else:
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[0,:] # C,H,W
gt_video = videos[1:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[0,:] # M,D
gt_audio = audios[1:,:] # F,M,D
# available length
cur_available_length = gt_video.shape[0]
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames)
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames)
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
# meta
meta_ = dict(
video_length = video_length,
video_path = video_path,
audio_path = whisper_path,
)
return dict(
ref_video=ref_video,
gt_video=gt_video,
ref_audio=ref_audio,
gt_audio=gt_audio,
mask = mask,
meta=meta_
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
meta = [item["meta"] for item in batch],
ref_video = torch.stack([item['ref_video'] for item in batch]),
gt_video = torch.stack([item['gt_video'] for item in batch]),
ref_audio = torch.stack([item['ref_audio'] for item in batch]),
gt_audio = torch.stack([item['gt_audio'] for item in batch]),
mask = torch.stack([item['mask'] for item in batch])
)
class A2MVideoAudioPose(Dataset):
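    """A2MVideoAudio plus an aligned pose video.

    Pose frames are read from `pose_path` at the same indices as the RGB frames;
    the audio condition is randomly dropped (zeroed) with probability
    `audio_drop_ratio`.
    """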
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
audio_drop_ratio:float = 0.0,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.audio_drop_ratio = audio_drop_ratio
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
pose_path = meta_data['pose_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
pose_reader = VideoReader(pose_path)
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader))
# sample_frames
sample_frames = self.sample_n_frames + 1 # self.sample_n_frames = 4, sample_frames=5
clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9
if clip_length > video_length :
batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int)
batch_index = np.array([d for d in batch_index if d <= video_length-1],dtype=int)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[0,:] # C,H,W
gt_video = videos[1:,:] # F,C,H,W
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
poses = poses / 255.0
poses = self.pixel_transforms(poses)
ref_pose = poses[0,:] # C,H,W
gt_pose = poses[1:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[0,:] # M,D
gt_audio = audios[1:,:] # F,M,D
# available length
cur_available_length = gt_video.shape[0]
# pad
pad_length = self.sample_n_frames - gt_video.shape[0]
video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype)
gt_video = torch.cat([gt_video, video_pad], dim=0) # F,C,H,W
pose_pad = torch.zeros((pad_length, *gt_pose.shape[1:]), dtype=gt_pose.dtype)
gt_pose = torch.cat([gt_pose, pose_pad], dim=0) # F,C,H,W
audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype)
gt_audio = torch.cat([gt_audio, audio_pad], dim=0)
else:
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[0,:] # C,H,W
gt_video = videos[1:,:] # F,C,H,W
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
poses = poses / 255.0
poses = self.pixel_transforms(poses)
ref_pose = poses[0,:] # C,H,W
gt_pose = poses[1:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[0,:] # M,D
gt_audio = audios[1:,:] # F,M,D
# available length
cur_available_length = gt_video.shape[0]
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames)
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames)
assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames)
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
# drop audio
if torch.rand(1).item() < self.audio_drop_ratio:
ref_audio = torch.zeros_like(ref_audio)
gt_audio = torch.zeros_like(gt_audio)
# meta
meta_ = dict(
video_length = video_length,
video_path = video_path,
audio_path = whisper_path,
)
return dict(
ref_video=ref_video,
gt_video=gt_video,
ref_pose = ref_pose,
gt_pose = gt_pose,
ref_audio=ref_audio,
gt_audio=gt_audio,
mask = mask,
meta=meta_
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
meta = [item["meta"] for item in batch],
ref_video = torch.stack([item['ref_video'] for item in batch]),
gt_video = torch.stack([item['gt_video'] for item in batch]),
ref_pose = torch.stack([item['ref_pose'] for item in batch]),
gt_pose = torch.stack([item['gt_pose'] for item in batch]),
ref_audio = torch.stack([item['ref_audio'] for item in batch]),
gt_audio = torch.stack([item['gt_audio'] for item in batch]),
mask = torch.stack([item['mask'] for item in batch])
)
class A2MVideoAudioPoseRandomRef(Dataset):
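    """A2MVideoAudioPose variant whose reference frame is drawn at random from
    outside the sampled clip when possible (otherwise the first clip frame is used).
    """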
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
pose_path = meta_data['pose_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
pose_reader = VideoReader(pose_path)
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader))
# sample_frames
sample_frames = self.sample_n_frames
clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9
if clip_length > video_length :
batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int)
batch_index = list(np.array([d for d in batch_index if d <= video_length-1],dtype=int))
# ref idx
idx_all = np.arange(0,video_length)
start_idx = 0
occ_idx = np.arange(start_idx, start_idx + clip_length)
ref_idx = [x for x in idx_all if x not in occ_idx]
if len(ref_idx) == 0:
ref_frame_idx = batch_index[0]
else:
np.random.shuffle(ref_idx)
ref_frame_idx = ref_idx[0]
batch_index = [ref_frame_idx] + batch_index
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[0,:] # C,H,W
gt_video = videos[1:,:] # F,C,H,W
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
poses = poses / 255.0
poses = self.pixel_transforms(poses)
ref_pose = poses[0,:] # C,H,W
gt_pose = poses[1:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[0,:] # M,D
gt_audio = audios[1:,:] # F,M,D
# available length
cur_available_length = gt_video.shape[0]
# pad
pad_length = self.sample_n_frames - gt_video.shape[0]
video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype)
gt_video = torch.cat([gt_video, video_pad], dim=0) # F,C,H,W
pose_pad = torch.zeros((pad_length, *gt_pose.shape[1:]), dtype=gt_pose.dtype)
gt_pose = torch.cat([gt_pose, pose_pad], dim=0) # F,C,H,W
audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype)
gt_audio = torch.cat([gt_audio, audio_pad], dim=0)
else:
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int))
# ref index
idx_all = np.arange(0,video_length)
occ_idx = np.arange(start_idx, start_idx + clip_length)
ref_idx = [x for x in idx_all if x not in occ_idx]
if len(ref_idx) == 0:
ref_frame_idx = batch_index[0]
else:
np.random.shuffle(ref_idx)
ref_frame_idx = ref_idx[0]
batch_index = [ref_frame_idx] + batch_index
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[0,:] # C,H,W
gt_video = videos[1:,:] # F,C,H,W
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
poses = poses / 255.0
poses = self.pixel_transforms(poses)
ref_pose = poses[0,:] # C,H,W
gt_pose = poses[1:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[0,:] # M,D
gt_audio = audios[1:,:] # F,M,D
# available length
cur_available_length = gt_video.shape[0]
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames)
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames)
assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames)
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
# meta
meta_ = dict(
video_length = video_length,
video_path = video_path,
audio_path = whisper_path,
)
return dict(
ref_video=ref_video,
gt_video=gt_video,
ref_pose = ref_pose,
gt_pose = gt_pose,
ref_audio=ref_audio,
gt_audio=gt_audio,
mask = mask,
meta=meta_
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
meta = [item["meta"] for item in batch],
ref_video = torch.stack([item['ref_video'] for item in batch]),
gt_video = torch.stack([item['gt_video'] for item in batch]),
ref_pose = torch.stack([item['ref_pose'] for item in batch]),
gt_pose = torch.stack([item['gt_pose'] for item in batch]),
ref_audio = torch.stack([item['ref_audio'] for item in batch]),
gt_audio = torch.stack([item['gt_audio'] for item in batch]),
mask = torch.stack([item['mask'] for item in batch])
)
class A2MVideoAudioPoseMultiSample(Dataset):
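    """A2MVideoAudioPose variant that draws `num_sample` independent clips per video
    and stacks them along a leading sample dimension; collate_fn then concatenates
    the samples across the batch.
    """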
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
audio_drop_ratio:float = 0.0,
num_sample:int = 4,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.audio_drop_ratio = audio_drop_ratio
self.num_sample = num_sample
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
        self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
pose_path = meta_data['pose_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
pose_reader = VideoReader(pose_path)
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader))
# sample_frames
sample_frames = self.sample_n_frames + 1 # self.sample_n_frames = 4, sample_frames=5
clip_length = (sample_frames - 1) * self.sample_stride + 1 # clip_length = 9
ref_video_list = []
gt_video_list = []
ref_pose_list = []
gt_pose_list = []
ref_audio_list = []
gt_audio_list = []
mask_list = []
for i in range(self.num_sample):
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[0,:] # C,H,W
gt_video = videos[1:,:] # F,C,H,W
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
poses = poses / 255.0
poses = self.pixel_transforms(poses)
ref_pose = poses[0,:] # C,H,W
gt_pose = poses[1:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[0,:] # M,D
gt_audio = audios[1:,:] # F,M,D
# available length
cur_available_length = gt_video.shape[0]
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames)
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames)
assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames)
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
# drop audio
if torch.rand(1).item() < self.audio_drop_ratio:
ref_audio = torch.zeros_like(ref_audio)
gt_audio = torch.zeros_like(gt_audio)
# cache
ref_video_list.append(ref_video)
gt_video_list.append(gt_video)
ref_pose_list.append(ref_pose)
gt_pose_list.append(gt_pose)
ref_audio_list.append(ref_audio)
gt_audio_list.append(gt_audio)
mask_list.append(mask)
return dict(
ref_video=torch.stack(ref_video_list),
gt_video=torch.stack(gt_video_list),
ref_pose = torch.stack(ref_pose_list),
gt_pose = torch.stack(gt_pose_list),
ref_audio=torch.stack(ref_audio_list),
gt_audio=torch.stack(gt_audio_list),
mask = torch.stack(mask_list),
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
ref_video = torch.cat([item['ref_video'] for item in batch],dim=0),
gt_video = torch.cat([item['gt_video'] for item in batch],dim=0),
ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0),
gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0),
ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0),
gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0),
mask = torch.cat([item['mask'] for item in batch],dim=0)
)
class A2MVideoAudioPoseMultiSampleMultiRefBalance(Dataset):
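    """Balanced wrapper over two A2MVideoAudioPoseMultiSampleMultiRef datasets,
    alternating between them by index parity and drawing a random item from the
    chosen dataset.
    """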
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
audio_drop_ratio:float = 0.0,
num_sample:int = 4,
max_ref_frame:int = 8,
random_ref_num:bool = False,
**kwargs
):
super().__init__()
with open(video_dir, 'r') as file:
lines = file.readlines()
video_dirs = [line.strip() for line in lines]
assert len(video_dirs) == 2, 'Only support 2 video dirs'
self.dataset1 = A2MVideoAudioPoseMultiSampleMultiRef(video_dir=video_dirs[0],
sample_size=sample_size,
sample_stride=sample_stride,
sample_n_frames=sample_n_frames,
audio_drop_ratio=audio_drop_ratio,
num_sample=num_sample,
max_ref_frame=max_ref_frame,
random_ref_num=random_ref_num)
self.dataset2 = A2MVideoAudioPoseMultiSampleMultiRef(video_dir=video_dirs[1],
sample_size=sample_size,
sample_stride=sample_stride,
sample_n_frames=sample_n_frames,
audio_drop_ratio=audio_drop_ratio,
num_sample=num_sample,
max_ref_frame=max_ref_frame,
random_ref_num=random_ref_num)
self.length = 2*max(len(self.dataset1),len(self.dataset2))
def __getitem__(self, idx):
while True:
try:
if idx % 2 == 0:
a_idx = torch.randint(0, len(self.dataset1), (1,)).item()
sample = self.dataset1[a_idx]
else:
b_idx = torch.randint(0, len(self.dataset2), (1,)).item()
sample = self.dataset2[b_idx]
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
@staticmethod
def collate_fn(batch):
return dict(
ref_video = torch.cat([item['ref_video'] for item in batch],dim=0),
gt_video = torch.cat([item['gt_video'] for item in batch],dim=0),
ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0),
gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0),
ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0),
gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0),
mask = torch.cat([item['mask'] for item in batch],dim=0)
)
class A2MVideoAudioMultiRefDoubleRef(Dataset):
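    """Audio-video dataset with a variable number of in-clip reference frames plus
    `randomref_num` frames spread evenly over the whole video. The number of in-clip
    references is 0, 1, or `max_ref_frame`, and the reference tensors are
    zero-padded to `max_ref_frame`.
    """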
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
audio_drop_ratio:float = 0.0,
num_sample:int = 4,
max_ref_frame:int = 8,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.audio_drop_ratio = audio_drop_ratio
self.num_sample = num_sample
self.max_ref_frame = max_ref_frame
self.randomref_num = 8
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
video_length = min(len(video_reader),audio_feature.shape[0])
# ref num
r = torch.rand(1).item()
if r < 0.33:
ref_num = 0
        elif r < 0.66:
ref_num = 1
else:
ref_num = self.max_ref_frame
sample_frames = self.sample_n_frames + ref_num
clip_length = (sample_frames - 1) * self.sample_stride + 1
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int))
# randomref
random_index = list(np.linspace(0, video_length - 1, self.randomref_num, dtype=int))
# random_index = torch.randint(low=0, high=clip_length-1, size=(1,)).item()
video_batch_index = random_index + batch_index
# frames
videos = torch.from_numpy(video_reader.get_batch(video_batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
randomref_video = videos[:self.randomref_num,:] # T,C,H,W
l_videos = videos[self.randomref_num:,:] # T,C,H,W
ref_video = l_videos[:ref_num,:] if ref_num > 0 else None
gt_video = l_videos[ref_num:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[:ref_num,:] if ref_num > 0 else None
gt_audio = audios[ref_num:,:] # F,M,D
# padding ref frame
if ref_num == 1:
ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:]))
ref_video = torch.cat([ref_video_pad,ref_video],dim=0) # N,T,C,H,W
ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:]))
ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0)
elif ref_num == 0:
ref_video = torch.zeros((self.max_ref_frame,*gt_video.shape[1:]))
ref_audio = torch.zeros((self.max_ref_frame,*gt_audio.shape[1:]))
# available length
cur_available_length = gt_video.shape[0]
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames)
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames)
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
return dict(
ref_video=ref_video,
gt_video=gt_video,
randomref_video = randomref_video,
ref_audio= ref_audio,
gt_audio=gt_audio,
mask = mask,
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
ref_video = torch.stack([item['ref_video'] for item in batch],dim=0),
gt_video = torch.stack([item['gt_video'] for item in batch],dim=0),
randomref_video = torch.stack([item['randomref_video'] for item in batch],dim=0),
ref_audio = torch.stack([item['ref_audio'] for item in batch],dim=0),
gt_audio = torch.stack([item['gt_audio'] for item in batch],dim=0),
mask = torch.stack([item['mask'] for item in batch],dim=0)
)
class A2MVideoAudioMultiRefDoubleRefBalance(Dataset):
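    """Despite the Balance suffix, this class uses the same single-source sampling
    scheme as A2MVideoAudioMultiRefDoubleRef; metadata is loaded from one pkl file.
    """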
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
audio_drop_ratio:float = 0.0,
num_sample:int = 4,
max_ref_frame:int = 8,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.audio_drop_ratio = audio_drop_ratio
self.num_sample = num_sample
self.max_ref_frame = max_ref_frame
self.randomref_num = 8
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
video_length = min(len(video_reader),audio_feature.shape[0])
# ref num
r = torch.rand(1).item()
if r < 0.33:
ref_num = 0
elif r<0.66:
ref_num = 1
else:
ref_num = self.max_ref_frame
sample_frames = self.sample_n_frames + ref_num
clip_length = (sample_frames - 1) * self.sample_stride + 1
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int))
# randomref
random_index = list(np.linspace(0, video_length - 1, self.randomref_num, dtype=int))
video_batch_index = random_index + batch_index
# frames
videos = torch.from_numpy(video_reader.get_batch(video_batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
randomref_video = videos[:self.randomref_num,:] # T,C,H,W
l_videos = videos[self.randomref_num:,:] # T,C,H,W
ref_video = l_videos[:ref_num,:] if ref_num > 0 else None
gt_video = l_videos[ref_num:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[:ref_num,:] if ref_num > 0 else None
gt_audio = audios[ref_num:,:] # F,M,D
# padding ref frame
if ref_num == 1:
ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:]))
ref_video = torch.cat([ref_video_pad,ref_video],dim=0) # N,T,C,H,W
ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:]))
ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0)
elif ref_num == 0:
ref_video = torch.zeros((self.max_ref_frame,*gt_video.shape[1:]))
ref_audio = torch.zeros((self.max_ref_frame,*gt_audio.shape[1:]))
# available length
cur_available_length = gt_video.shape[0]
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames)
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames)
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
return dict(
ref_video=ref_video,
gt_video=gt_video,
randomref_video = randomref_video,
ref_audio= ref_audio,
gt_audio=gt_audio,
mask = mask,
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
ref_video = torch.stack([item['ref_video'] for item in batch],dim=0),
gt_video = torch.stack([item['gt_video'] for item in batch],dim=0),
randomref_video = torch.stack([item['randomref_video'] for item in batch],dim=0),
ref_audio = torch.stack([item['ref_audio'] for item in batch],dim=0),
gt_audio = torch.stack([item['gt_audio'] for item in batch],dim=0),
mask = torch.stack([item['mask'] for item in batch],dim=0)
)
# pose img2img
class A2MVideoAudioPoseRandomRefMultiSample(Dataset):
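    """Pose image-to-image dataset: samples `num_sample * 2` consecutive frames and
    splits them into `num_sample` (reference, target) single-frame pairs for video,
    pose, and audio.
    """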
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
num_sample:int = 4,
max_ref_frame:int = 8,
random_ref_num:bool = False,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.num_sample = num_sample
self.max_ref_frame = max_ref_frame
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
pose_path = meta_data['pose_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
pose_reader = VideoReader(pose_path)
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader))
# batch index
sample_frames = self.num_sample * 2
if video_length < sample_frames:
raise ValueError(f"视频长度{video_length}太短了,需要长度{sample_frames}")
# occ_idx = np.arange(0, video_length)
# # # np.random.shuffle(occ_idx)
        # # # generate a random permutation of the indices
# # shuffled_indices = torch.randperm(len(occ_idx))
        # # # shuffle the list using the permuted indices
# # occ_idx = [occ_idx[i] for i in shuffled_indices]
# occ_idx = shuffle_list(occ_idx)
# batch_index = occ_idx[:sample_frames]
# batch_index = [torch.randint(low=0, high=video_length, size=(1,)).item() for i in range(sample_frames)]
start_idx = torch.randint(low=0, high=video_length-sample_frames, size=(1,)).item()
        batch_index = np.arange(start_idx, start_idx + sample_frames)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[:self.num_sample,:].unsqueeze(1) # N,1,C,H,W
gt_video = videos[self.num_sample:,:].unsqueeze(1) # N,1,C,H,W
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
poses = poses / 255.0
poses = self.pixel_transforms(poses)
ref_pose = poses[:self.num_sample,:].unsqueeze(1) # N,1,C,H,W
gt_pose = poses[self.num_sample:,:].unsqueeze(1) # N,1,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[:self.num_sample,:].unsqueeze(1) # N,1,C,H,W
gt_audio = audios[self.num_sample:,:].unsqueeze(1) # N,1,C,H,W
# available length
cur_available_length = gt_video.shape[0]
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
return dict(
ref_video=ref_video,
gt_video=gt_video,
ref_pose = ref_pose,
gt_pose =gt_pose,
ref_audio=ref_audio,
gt_audio=gt_audio,
mask = mask,
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
ref_video = torch.cat([item['ref_video'] for item in batch],dim=0),
gt_video = torch.cat([item['gt_video'] for item in batch],dim=0),
ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0),
gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0),
ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0),
gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0),
mask = torch.cat([item['mask'] for item in batch],dim=0)
)
class A2MVideoAudioPoseMultiSampleMultiRef(Dataset):
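    """Multi-sample, multi-reference A2M dataset: for each of `num_sample` clips,
    the number of leading reference frames is either drawn uniformly from
    [1, max_ref_frame] (random_ref_num=True) or chosen as 1 or max_ref_frame;
    references are zero-padded to `max_ref_frame`.
    """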
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 16,
audio_drop_ratio:float = 0.0,
num_sample:int = 4,
max_ref_frame:int = 8,
random_ref_num:bool = False,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.audio_drop_ratio = audio_drop_ratio
self.num_sample = num_sample
self.max_ref_frame = max_ref_frame
self.random_ref_num = random_ref_num
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
"""
videos : 31,3,256,256
ref_img : 3,256,256
audio_feature : 30,50,384
ref_pose : 3,256,256
meta
"""
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
pose_path = meta_data['pose_path']
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
pose_reader = VideoReader(pose_path)
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader))
# sample_frames
ref_video_list = []
gt_video_list = []
ref_pose_list = []
gt_pose_list = []
ref_audio_list = []
gt_audio_list = []
mask_list = []
for i in range(self.num_sample):
# random ref num
if self.random_ref_num:
ref_num = torch.randint(low=1, high=self.max_ref_frame+1, size=(1,)).item()
else:
ref_num = [1, self.max_ref_frame][torch.randint(2, (1,)).item()]
sample_frames = self.sample_n_frames + ref_num
clip_length = (sample_frames - 1) * self.sample_stride + 1
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_video = videos[:ref_num,:] # T,C,H,W
gt_video = videos[ref_num:,:] # F,C,H,W
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
poses = poses / 255.0
poses = self.pixel_transforms(poses)
ref_pose = poses[:ref_num,:] # T,C,H,W
gt_pose = poses[ref_num:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[:ref_num,:] # T,M,D
gt_audio = audios[ref_num:,:] # F,M,D
# padding ref frame
if ref_num < self.max_ref_frame:
ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:]))
ref_video = torch.cat([ref_video_pad,ref_video],dim=0) # N,T,C,H,W
ref_pose_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_pose.shape[1:]))
ref_pose = torch.cat([ref_pose_pad,ref_pose],dim=0)
ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:]))
ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0)
# available length
cur_available_length = gt_video.shape[0]
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames)
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames)
assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames)
# mask
mask = torch.zeros(self.sample_n_frames)
mask[:cur_available_length] = 1
# drop audio
if torch.rand(1).item() < self.audio_drop_ratio:
ref_audio = torch.zeros_like(ref_audio)
gt_audio = torch.zeros_like(gt_audio)
# cache
ref_video_list.append(ref_video)
gt_video_list.append(gt_video)
ref_pose_list.append(ref_pose)
gt_pose_list.append(gt_pose)
ref_audio_list.append(ref_audio)
gt_audio_list.append(gt_audio)
mask_list.append(mask)
return dict(
ref_video=torch.stack(ref_video_list),
gt_video=torch.stack(gt_video_list),
ref_pose = torch.stack(ref_pose_list),
gt_pose = torch.stack(gt_pose_list),
ref_audio=torch.stack(ref_audio_list),
gt_audio=torch.stack(gt_audio_list),
mask = torch.stack(mask_list),
)
def get_file_name(self, file_path):
return file_path.split('/')[-1].split('.')[0]
@staticmethod
def collate_fn(batch):
return dict(
ref_video = torch.cat([item['ref_video'] for item in batch],dim=0),
gt_video = torch.cat([item['gt_video'] for item in batch],dim=0),
ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0),
gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0),
ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0),
gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0),
mask = torch.cat([item['mask'] for item in batch],dim=0)
)
# inference
class A2VDataset(Dataset):
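    """Inference dataset: returns a single reference frame, the sampled clip, and
    the matching Whisper audio features (reference and inference parts), plus meta
    info (name, audio path, fps, clip start time).
    """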
def __init__(
self,
video_dir:str,
sample_size: int = 256,
sample_stride: int = 1,
sample_n_frames:int = 120,
**kwargs
):
super().__init__()
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
# Transform
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) # (256,256)
self.pixel_transforms = transforms.Compose([
transforms.Resize(min(sample_size)),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
with open(video_dir, 'rb') as f:
self.metadata_list = pickle.load(f)
self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')
def __getitem__(self, idx):
while True:
try:
sample = self.get_batch(idx)
break
except Exception as e:
print('error',e)
idx = torch.randint(low=0, high=self.length, size=(1,)).item()
return sample
def __len__(self):
return self.length
def get_batch(self, idx):
# init
meta_data = self.metadata_list[idx]
video_path = meta_data['video_path']
whisper_path = meta_data['whisper_emb_path']
audio_path = meta_data['audio_path']
# pose_path = meta_data['pose_path']
name = os.path.basename(video_path).split('.')[0]
fps = 25 if 'hdtf' in video_path else 30
# audio
audio_feature = torch.load(whisper_path)
# load & check
video_reader = VideoReader(video_path)
# pose_reader = VideoReader(pose_path)
video_length = min(len(video_reader),audio_feature.shape[0])
# sample_frame
# batch idx
ref_num = 1
sample_frames = self.sample_n_frames + ref_num
clip_length = (sample_frames - 1) * self.sample_stride + 1
start_idx = np.random.randint(0, video_length - clip_length + 1)
end_idx = start_idx + clip_length
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)
# frames
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
videos = videos / 255.0
videos = self.pixel_transforms(videos)
ref_img = videos[:ref_num,:] # 1,C,H,W
inf_video = videos[ref_num:,:] # F,C,H,W
# poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() #(N,H,W,C)->(N,C,H,W)
# poses = poses / 255.0
# poses = self.pixel_transforms(poses)
# ref_pose = poses[:ref_num,:] # T,C,H,W
# gt_pose = poses[ref_num:,:] # F,C,H,W
audios = audio_feature[batch_index]
ref_audio = audios[:ref_num,:] # 1,M,D
inf_audio = audios[ref_num:,:] # F,M,D
# meta_info
start_time = start_idx / fps
_meta = {"name":name,
"audio_path":audio_path,
"fps":fps,
"start_time":start_time}
return dict(
meta_info = _meta,
ref_img=ref_img,
gt_video=videos,
ref_audio=ref_audio,
inf_audio=inf_audio,
ref_pose=None,
inf_pose=None,
)
@staticmethod
def collate_fn(batch):
return dict(
meta_info = [item['meta_info'] for item in batch],
ref_img= torch.stack([item["ref_img"] for item in batch]),
ref_audio= torch.stack([item["ref_audio"] for item in batch]),
inf_audio= torch.stack([item["inf_audio"] for item in batch]),
ref_pose= torch.ones((2,2)),
inf_pose= torch.ones((2,2)),
gt_video= torch.stack([item["gt_video"] for item in batch]),
)
def generate_non_equal_random_lists(frame_num,sample_num):
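    """Return two index lists of length `sample_num`; for each position the second
    index is guaranteed to differ from the first."""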
list1 = [np.random.randint(0, frame_num) for _ in range(sample_num)]
list2 = []
for i in range(len(list1)):
available_numbers = list(range(0, list1[i])) + list(range(list1[i] + 1, frame_num))
list2.append(random.choice(available_numbers))
return list1, list2
def shuffle_list(l):
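    """Shuffle a list using torch.randperm (safe across DataLoader worker processes)."""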
shuffled_indices = torch.randperm(len(l))
    # shuffle the list using the permuted indices
l = [l[i] for i in shuffled_indices]
return l
if __name__ == "__main__":
# dataset = CelebvText()
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
# for idx, batch in enumerate(dataloader):
# print(batch["videos"].shape, len(batch["text"]))
# for i in range(batch["videos"].shape[0]):
# save_videos_grid(batch["videos"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
from torch.utils.data import DataLoader
# dataset = AMDVideoAudioFeature(
# path=data_path,
# path_type="file",
# motion_seq_len=motion_seq_len,
# sample_n_frames=num_frames,
# audio_processor=audio_processor
# )
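    # NOTE: AMDVideoAudioFeature is not defined in this file; it would need to be
    # imported from its defining module before this example block can run.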
dataset = AMDVideoAudioFeature(
video_dir = '/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/code/AMD_linear/dataset/path/lhz/train.pkl',
path_type = 'file'
)
dataloader = DataLoader(
        dataset, batch_size=2, shuffle=True, num_workers=0,
collate_fn=dataset.collate_fn
)
# dataloader = DataLoader(dataset,batch_size=2,collate_fn=dataset.collate_fn,num_workers=2)
# d = dataset[10]
# video = d["videos"]
# audio = d["audio_feature"]
# refimg = d["ref_img"]
for d in dataloader:
video = d["videos"]
audio_feature = d["audio_feature"]
refimg = d["ref_img"]
# break
print(video.shape)
print(audio_feature.shape)
#