|
|
import os |
|
|
import random |
|
|
import numpy as np |
|
|
from decord import VideoReader |
|
|
import glob |
|
|
from tqdm import tqdm |
|
|
import pickle |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from torch.utils.data.dataset import Dataset |
|
|
from decord import cpu, gpu |
|
|
from torchvision.io import read_video, write_video |
|
|
|
|
|
import json |
|
|
import traceback |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AMDConsecutiveVideo(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir: str = '', |
|
|
sample_size: int = 32, |
|
|
sample_stride: int = 2, |
|
|
sample_n_frames:int = 16, |
|
|
ref_drop_ratio = 0.0, |
|
|
): |
|
|
|
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.ref_drop_ratio = ref_drop_ratio |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(sample_size[0]), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
if 'pkl' in video_dir: |
|
|
with open(video_dir, 'rb') as f: |
|
|
video_files = pickle.load(f) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
elif '.txt' in video_dir: |
|
|
with open(video_dir, 'r') as file: |
|
|
lines = file.readlines() |
|
|
video_dirs = [line.strip() for line in lines] |
|
|
|
|
|
video_files = [] |
|
|
for dir in video_dirs: |
|
|
video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
|
|
|
else: |
|
|
video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
|
|
|
|
|
|
self.metadata_list = [] |
|
|
for file_path in tqdm(video_files): |
|
|
d = {} |
|
|
d['name'] = self.get_file_name(file_path) |
|
|
d['video_path'] = file_path |
|
|
self.metadata_list.append(d) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
file_name = self.metadata_list[idx]['name'] |
|
|
file_name,videos,ref_img = self.get_batch(idx) |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
|
|
|
print('error',e) |
|
|
idx = random.randint(0, self.length-1) |
|
|
|
|
|
sample = dict(name=file_name,videos=videos,ref_img = ref_img) |
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
file_name = meta_data['name'] |
|
|
video_path = meta_data['video_path'] |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path, ctx=cpu(0)) |
|
|
video_length = len(video_reader) |
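# Sample sample_n_frames + 1 frames, spread evenly over a randomly placed window of roughly
# sample_stride * sample_n_frames frames; after the pixel transforms, frame 0 becomes the
# reference image and frames 1.. become the target clip.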
|
|
|
|
|
sample_frames = self.sample_n_frames + 1 |
|
|
clip_length = min(video_length, (sample_frames - 1) * self.sample_stride + 1) |
|
|
start_idx = random.randint(0, video_length - clip_length) |
|
|
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int) |
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
|
|
|
|
|
|
|
|
|
videos_cache = self.pixel_transforms(videos) |
|
|
videos = videos_cache[1:,:,:,:] |
|
|
ref_frame = videos_cache[0,:,:,:] |
|
|
|
|
|
|
|
|
ref_frame = ref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) |
|
|
|
|
|
|
|
|
return file_name,videos,ref_frame |
|
|
|
|
|
|
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
|
|
|
name = [item['name'] for item in batch] |
|
|
|
|
|
|
|
|
videos = [item['videos'] for item in batch] |
|
|
videos = torch.stack(videos) |
|
|
|
|
|
|
|
|
ref_img = [item['ref_img'] for item in batch] |
|
|
ref_img = torch.stack(ref_img) |
|
|
|
|
|
randomref_img = None |
|
|
|
|
|
|
|
|
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img) |
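# Example (sketch, not part of the training code): wiring AMDConsecutiveVideo into a DataLoader
# with its custom collate_fn. The video_dir value below is a hypothetical path.
#
#   from torch.utils.data import DataLoader
#   dataset = AMDConsecutiveVideo(video_dir='/data/videos', sample_size=256,
#                                 sample_stride=2, sample_n_frames=16)
#   loader = DataLoader(dataset, batch_size=2, shuffle=True,
#                       collate_fn=AMDConsecutiveVideo.collate_fn)
#   batch = next(iter(loader))  # batch['videos']: (B, sample_n_frames, 3, 256, 256)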
|
|
|
|
|
class AMDConsecutiveVideoBalance(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir: str = '', |
|
|
sample_size: int = 32, |
|
|
sample_stride: int = 2, |
|
|
sample_n_frames:int = 16, |
|
|
ref_drop_ratio = 0.0, |
|
|
): |
|
|
|
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.ref_drop_ratio = ref_drop_ratio |
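# Here video_dir is a .txt file listing exactly two entries (one per line); each entry is itself
# a valid video_dir spec for AMDConsecutiveVideo (a directory, a .pkl list, or a .txt list).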
|
|
|
|
|
assert '.txt' in video_dir |
|
|
|
|
|
with open(video_dir, 'r') as file: |
|
|
lines = file.readlines() |
|
|
video_paths = [line.strip() for line in lines] |
|
|
|
|
|
assert len(video_paths) == 2 |
|
|
|
|
|
self.dataset1 = AMDConsecutiveVideo(video_dir = video_paths[0], |
|
|
sample_size = sample_size, |
|
|
sample_stride = sample_stride, |
|
|
sample_n_frames = sample_n_frames, |
|
|
ref_drop_ratio = 0.0,) |
|
|
|
|
|
self.dataset2 = AMDConsecutiveVideo(video_dir = video_paths[1], |
|
|
sample_size = sample_size, |
|
|
sample_stride = sample_stride, |
|
|
sample_n_frames = sample_n_frames, |
|
|
ref_drop_ratio = 0.0,) |
|
|
|
|
|
self.len1 = len(self.dataset1) |
|
|
self.len2 = len(self.dataset2) |
|
|
print(f'Total {self.len1 + self.len2} !!!') |
|
|
|
|
|
def __len__(self): |
|
|
|
|
|
return 2 * max(self.len1,self.len2) |
|
|
|
|
|
def __getitem__(self, idx): |
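# Draw from the two sub-datasets with equal probability so both contribute equally per epoch,
# regardless of their raw sizes; the incoming idx is ignored and a random index is used instead.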
|
|
|
|
|
if torch.rand(1).item() < 0.5: |
|
|
|
|
|
a_idx = torch.randint(0, len(self.dataset1), (1,)).item() |
|
|
return self.dataset1[a_idx] |
|
|
else: |
|
|
|
|
|
b_idx = torch.randint(0, len(self.dataset2), (1,)).item() |
|
|
return self.dataset2[b_idx] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
|
|
|
name = [item['name'] for item in batch] |
|
|
|
|
|
|
|
|
videos = [item['videos'] for item in batch] |
|
|
videos = torch.stack(videos) |
|
|
|
|
|
|
|
|
ref_img = [item['ref_img'] for item in batch] |
|
|
ref_img = torch.stack(ref_img) |
|
|
|
|
|
randomref_img = None |
|
|
|
|
|
|
|
|
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img) |
|
|
|
|
|
class AMDConsecutiveVideoDoubleRef(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir: str = '', |
|
|
sample_size: int = 32, |
|
|
sample_stride: int = 2, |
|
|
sample_n_frames:int = 16, |
|
|
ref_drop_ratio = 0.0, |
|
|
): |
|
|
|
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.ref_drop_ratio = ref_drop_ratio |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(sample_size[0]), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
if 'pkl' in video_dir: |
|
|
with open(video_dir, 'rb') as f: |
|
|
video_files = pickle.load(f) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
elif '.txt' in video_dir: |
|
|
with open(video_dir, 'r') as file: |
|
|
lines = file.readlines() |
|
|
video_dirs = [line.strip() for line in lines] |
|
|
|
|
|
video_files = [] |
|
|
for dir in video_dirs: |
|
|
video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
|
|
|
else: |
|
|
video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
|
|
|
|
|
|
self.metadata_list = [] |
|
|
for file_path in tqdm(video_files): |
|
|
d = {} |
|
|
d['name'] = self.get_file_name(file_path) |
|
|
d['video_path'] = file_path |
|
|
self.metadata_list.append(d) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
file_name = self.metadata_list[idx]['name'] |
|
|
file_name,videos,ref_img,randomref_img = self.get_batch(idx) |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
|
|
|
print('error',e) |
|
|
idx = random.randint(0, self.length-1) |
|
|
|
|
|
sample = dict(name=file_name,videos=videos,ref_img = ref_img,randomref_img=randomref_img) |
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
file_name = meta_data['name'] |
|
|
video_path = meta_data['video_path'] |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path, ctx=cpu(0)) |
|
|
video_length = len(video_reader) |
|
|
|
|
|
sample_frames = self.sample_n_frames + 1 |
|
|
clip_length = min(video_length, (sample_frames - 1) * self.sample_stride + 1) |
|
|
start_idx = random.randint(0, video_length - clip_length) |
|
|
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_frames, dtype=int) |
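# Choose one extra "random reference" frame from outside the sampled window when possible; if the
# window covers the whole video, fall back to the first frame of the clip.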
|
|
|
|
|
|
|
|
idx_all = np.arange(0,video_length) |
|
|
occ_idx = np.arange(start_idx, start_idx + clip_length) |
|
|
randomref_idx = [x for x in idx_all if x not in occ_idx] |
|
|
if len(randomref_idx) == 0: |
|
|
ref_frame_idx = batch_index[0] |
|
|
else: |
|
|
i = torch.randint(low=0, high=len(randomref_idx), size=(1,)).item() |
|
|
ref_frame_idx = randomref_idx[i] |
|
|
|
|
|
batch_index = [ref_frame_idx] + list(batch_index) |
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
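# After the pixel transforms: index 0 is the random reference frame, index 1 is the consecutive
# reference, and indices 2.. are the target frames (see the split below).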
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
videos_cache = self.pixel_transforms(videos) |
|
|
videos = videos_cache[2:,:,:,:] |
|
|
ref_frame = videos_cache[1,:,:,:] |
|
|
randomref_frame = videos_cache[0,:,:,:] |
|
|
|
|
|
|
|
|
ref_frame = ref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) |
|
|
randomref_frame = randomref_frame.unsqueeze(0).repeat(videos.shape[0],1,1,1) |
|
|
|
|
|
|
|
|
return file_name,videos,ref_frame,randomref_frame |
|
|
|
|
|
|
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
|
|
|
name = [item['name'] for item in batch] |
|
|
|
|
|
|
|
|
videos = [item['videos'] for item in batch] |
|
|
videos = torch.stack(videos) |
|
|
|
|
|
|
|
|
ref_img = [item['ref_img'] for item in batch] |
|
|
ref_img = torch.stack(ref_img) |
|
|
|
|
|
randomref_img = [item['randomref_img'] for item in batch] |
|
|
randomref_img = torch.stack(randomref_img) |
|
|
|
|
|
|
|
|
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img) |
|
|
|
|
|
class AMDConsecutiveVideoDoubleRefBalance(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir: str = '', |
|
|
sample_size: int = 32, |
|
|
sample_stride: int = 2, |
|
|
sample_n_frames:int = 16, |
|
|
ref_drop_ratio = 0.0, |
|
|
): |
|
|
|
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.ref_drop_ratio = ref_drop_ratio |
|
|
|
|
|
assert '.txt' in video_dir |
|
|
|
|
|
with open(video_dir, 'r') as file: |
|
|
lines = file.readlines() |
|
|
video_paths = [line.strip() for line in lines] |
|
|
|
|
|
self.datasets = [] |
|
|
for vp in video_paths: |
|
|
self.datasets.append(AMDConsecutiveVideoDoubleRef(video_dir = vp, |
|
|
sample_size = sample_size, |
|
|
sample_stride = sample_stride, |
|
|
sample_n_frames = sample_n_frames, |
|
|
ref_drop_ratio = 0.0,)) |
|
|
|
|
|
self.length = len(self.datasets) * max([len(d) for d in self.datasets]) |
|
|
|
|
|
print(self.length) |
|
|
|
|
|
def __len__(self): |
|
|
|
|
|
return self.length |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
|
|
dataset_num = len(self.datasets) |
|
|
cur_idx = idx % dataset_num |
|
|
cur_dataset = self.datasets[cur_idx] |
|
|
idx = torch.randint(0, len(cur_dataset), (1,)).item() |
|
|
return cur_dataset[idx] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
|
|
|
name = [item['name'] for item in batch] |
|
|
|
|
|
|
|
|
videos = [item['videos'] for item in batch] |
|
|
videos = torch.stack(videos) |
|
|
|
|
|
|
|
|
ref_img = [item['ref_img'] for item in batch] |
|
|
ref_img = torch.stack(ref_img) |
|
|
|
|
|
randomref_img = [item['randomref_img'] for item in batch] |
|
|
randomref_img = torch.stack(randomref_img) |
|
|
|
|
|
|
|
|
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=randomref_img) |
|
|
|
|
|
|
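# NOTE: `generate_non_equal_random_lists` is called in AMDRandomPair.get_batch below but is not
# defined or imported in this file. The sketch below is an assumption inferred from its name and
# call site (frame_num, sample_num): it draws two index lists of length sample_num from
# [0, frame_num) whose entries differ position-wise. Swap in the project's own helper if it
# exists elsewhere.
def generate_non_equal_random_lists(frame_num: int, sample_num: int):
    # Hypothetical implementation: reference indices sampled uniformly, paired indices re-drawn
    # until they differ from the corresponding reference index.
    ref_idx = [random.randint(0, frame_num - 1) for _ in range(sample_num)]
    video_idx = []
    for r in ref_idx:
        v = random.randint(0, frame_num - 1)
        while frame_num > 1 and v == r:
            v = random.randint(0, frame_num - 1)
        video_idx.append(v)
    return ref_idx, video_idx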
|
|
|
|
|
class AMDRandomPair(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir: str = '', |
|
|
sample_size: int = 32, |
|
|
sample_stride: int = 4, |
|
|
sample_n_frames:int = 16, |
|
|
ref_drop_ratio = 0.0, |
|
|
): |
|
|
|
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.ref_drop_ratio = ref_drop_ratio |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
if 'pkl' in video_dir: |
|
|
with open(video_dir, 'rb') as f: |
|
|
video_files = pickle.load(f) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
elif '.txt' in video_dir: |
|
|
with open(video_dir, 'r') as file: |
|
|
lines = file.readlines() |
|
|
video_dirs = [line.strip() for line in lines] |
|
|
|
|
|
video_files = [] |
|
|
for dir in video_dirs: |
|
|
video_files += glob.glob(os.path.join(dir, '**', '*.mp4'), recursive=True) |
|
|
print(f'Total {len(video_files)} !!!') |
|
|
|
|
|
else: |
|
|
video_files = glob.glob(os.path.join(video_dir, '**', '*.mp4'), recursive=True) |
|
|
|
|
|
|
|
|
self.metadata_list = [] |
|
|
for file_path in tqdm(video_files): |
|
|
d = {} |
|
|
d['name'] = self.get_file_name(file_path) |
|
|
d['video_path'] = file_path |
|
|
self.metadata_list.append(d) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
file_name = self.metadata_list[idx]['name'] |
|
|
file_name,videos,ref_img = self.get_batch(idx) |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
|
|
|
print('error',e) |
|
|
idx = random.randint(0, self.length-1) |
|
|
|
|
|
sample = dict(name=file_name,videos=videos,ref_img = ref_img) |
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
file_name = meta_data['name'] |
|
|
video_path = meta_data['video_path'] |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
video_length = len(video_reader) |
|
|
|
|
|
|
|
|
ref_idx,video_idx = generate_non_equal_random_lists(frame_num=video_length,sample_num=self.sample_n_frames) |
|
|
|
|
|
ref_videos = torch.from_numpy(video_reader.get_batch(ref_idx).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
ref_videos = ref_videos / 255.0 |
|
|
ref_videos = self.pixel_transforms(ref_videos) |
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(video_idx).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
|
|
|
return file_name,videos,ref_videos |
|
|
|
|
|
|
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
|
|
|
name = [item['name'] for item in batch] |
|
|
|
|
|
|
|
|
videos = [item['videos'] for item in batch] |
|
|
videos = torch.stack(videos) |
|
|
|
|
|
|
|
|
ref_img = [item['ref_img'] for item in batch] |
|
|
ref_img = torch.stack(ref_img) |
|
|
|
|
|
return dict(name=name, videos=videos, ref_img=ref_img,randomref_img=None) |
|
|
|
|
|
class A2MVideoAudio(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0]) |
|
|
|
|
|
|
|
|
sample_frames = self.sample_n_frames + 1 |
|
|
clip_length = (sample_frames - 1) * self.sample_stride + 1 |
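# If the requested clip would be longer than the video, keep only the valid frame indices and
# zero-pad gt_video / gt_audio back up to sample_n_frames; `mask` marks which frames are real.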
|
|
|
|
|
if clip_length > video_length : |
|
|
batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int) |
|
|
batch_index = np.array([d for d in batch_index if d <= video_length-1],dtype=int) |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[0,:] |
|
|
gt_video = videos[1:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[0,:] |
|
|
gt_audio = audios[1:,:] |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
|
|
|
pad_length = self.sample_n_frames - gt_video.shape[0] |
|
|
video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype) |
|
|
gt_video = torch.cat([gt_video, video_pad], dim=0) |
|
|
|
|
|
audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype) |
|
|
gt_audio = torch.cat([gt_audio, audio_pad], dim=0) |
|
|
else: |
|
|
start_idx = np.random.randint(0, video_length - clip_length + 1) |
|
|
end_idx = start_idx + clip_length |
|
|
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[0,:] |
|
|
gt_video = videos[1:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[0,:] |
|
|
gt_audio = audios[1:,:] |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
assert gt_video.shape[0] == self.sample_n_frames, f'{gt_video.shape[0]} {self.sample_n_frames}'
assert gt_audio.shape[0] == self.sample_n_frames, f'{gt_audio.shape[0]} {self.sample_n_frames}'
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
|
|
|
meta_ = dict( |
|
|
video_length = video_length, |
|
|
video_path = video_path, |
|
|
audio_path = whisper_path, |
|
|
) |
|
|
|
|
|
return dict( |
|
|
ref_video=ref_video, |
|
|
gt_video=gt_video, |
|
|
ref_audio=ref_audio, |
|
|
gt_audio=gt_audio, |
|
|
mask = mask, |
|
|
meta=meta_ |
|
|
) |
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
meta = [item["meta"] for item in batch], |
|
|
ref_video = torch.stack([item['ref_video'] for item in batch]), |
|
|
gt_video = torch.stack([item['gt_video'] for item in batch]), |
|
|
ref_audio = torch.stack([item['ref_audio'] for item in batch]), |
|
|
gt_audio = torch.stack([item['gt_audio'] for item in batch]), |
|
|
mask = torch.stack([item['mask'] for item in batch]) |
|
|
) |
|
|
|
|
|
class A2MVideoAudioPose(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
audio_drop_ratio:float = 0.0, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.audio_drop_ratio = audio_drop_ratio |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
pose_path = meta_data['pose_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
pose_reader = VideoReader(pose_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) |
|
|
|
|
|
|
|
|
sample_frames = self.sample_n_frames + 1 |
|
|
clip_length = (sample_frames - 1) * self.sample_stride + 1 |
|
|
|
|
|
if clip_length > video_length : |
|
|
batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int) |
|
|
batch_index = np.array([d for d in batch_index if d <= video_length-1],dtype=int) |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[0,:] |
|
|
gt_video = videos[1:,:] |
|
|
|
|
|
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
poses = poses / 255.0 |
|
|
poses = self.pixel_transforms(poses) |
|
|
ref_pose = poses[0,:] |
|
|
gt_pose = poses[1:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[0,:] |
|
|
gt_audio = audios[1:,:] |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
|
|
|
pad_length = self.sample_n_frames - gt_video.shape[0] |
|
|
|
|
|
video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype) |
|
|
gt_video = torch.cat([gt_video, video_pad], dim=0) |
|
|
|
|
|
pose_pad = torch.zeros((pad_length, *gt_pose.shape[1:]), dtype=gt_pose.dtype) |
|
|
gt_pose = torch.cat([gt_pose, pose_pad], dim=0) |
|
|
|
|
|
audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype) |
|
|
gt_audio = torch.cat([gt_audio, audio_pad], dim=0) |
|
|
else: |
|
|
start_idx = np.random.randint(0, video_length - clip_length + 1) |
|
|
end_idx = start_idx + clip_length |
|
|
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[0,:] |
|
|
gt_video = videos[1:,:] |
|
|
|
|
|
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
poses = poses / 255.0 |
|
|
poses = self.pixel_transforms(poses) |
|
|
ref_pose = poses[0,:] |
|
|
gt_pose = poses[1:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[0,:] |
|
|
gt_audio = audios[1:,:] |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
assert gt_video.shape[0] == self.sample_n_frames, f'{gt_video.shape[0]} {self.sample_n_frames}'
assert gt_audio.shape[0] == self.sample_n_frames, f'{gt_audio.shape[0]} {self.sample_n_frames}'
assert gt_pose.shape[0] == self.sample_n_frames, f'{gt_pose.shape[0]} {self.sample_n_frames}'
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
|
|
|
if torch.rand(1).item() < self.audio_drop_ratio: |
|
|
ref_audio = torch.zeros_like(ref_audio) |
|
|
gt_audio = torch.zeros_like(gt_audio) |
|
|
|
|
|
|
|
|
meta_ = dict( |
|
|
video_length = video_length, |
|
|
video_path = video_path, |
|
|
audio_path = whisper_path, |
|
|
) |
|
|
|
|
|
return dict( |
|
|
ref_video=ref_video, |
|
|
gt_video=gt_video, |
|
|
ref_pose = ref_pose, |
|
|
gt_pose = gt_pose, |
|
|
ref_audio=ref_audio, |
|
|
gt_audio=gt_audio, |
|
|
mask = mask, |
|
|
meta=meta_ |
|
|
) |
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
meta = [item["meta"] for item in batch], |
|
|
ref_video = torch.stack([item['ref_video'] for item in batch]), |
|
|
gt_video = torch.stack([item['gt_video'] for item in batch]), |
|
|
ref_pose = torch.stack([item['ref_pose'] for item in batch]), |
|
|
gt_pose = torch.stack([item['gt_pose'] for item in batch]), |
|
|
ref_audio = torch.stack([item['ref_audio'] for item in batch]), |
|
|
gt_audio = torch.stack([item['gt_audio'] for item in batch]), |
|
|
mask = torch.stack([item['mask'] for item in batch]) |
|
|
) |
|
|
|
|
|
class A2MVideoAudioPoseRandomRef(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
pose_path = meta_data['pose_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
pose_reader = VideoReader(pose_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) |
|
|
|
|
|
|
|
|
sample_frames = self.sample_n_frames |
|
|
clip_length = (sample_frames - 1) * self.sample_stride + 1 |
|
|
|
|
|
if clip_length > video_length : |
|
|
batch_index = np.linspace(0, clip_length - 1, sample_frames, dtype=int) |
|
|
batch_index = list(np.array([d for d in batch_index if d <= video_length-1],dtype=int)) |
|
|
|
|
|
|
|
|
idx_all = np.arange(0,video_length) |
|
|
start_idx = 0 |
|
|
occ_idx = np.arange(start_idx, start_idx + clip_length) |
|
|
ref_idx = [x for x in idx_all if x not in occ_idx] |
|
|
if len(ref_idx) == 0: |
|
|
ref_frame_idx = batch_index[0] |
|
|
else: |
|
|
np.random.shuffle(ref_idx) |
|
|
ref_frame_idx = ref_idx[0] |
|
|
batch_index = [ref_frame_idx] + batch_index |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[0,:] |
|
|
gt_video = videos[1:,:] |
|
|
|
|
|
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
poses = poses / 255.0 |
|
|
poses = self.pixel_transforms(poses) |
|
|
ref_pose = poses[0,:] |
|
|
gt_pose = poses[1:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[0,:] |
|
|
gt_audio = audios[1:,:] |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
|
|
|
pad_length = self.sample_n_frames - gt_video.shape[0] |
|
|
|
|
|
video_pad = torch.zeros((pad_length, *gt_video.shape[1:]), dtype=gt_video.dtype) |
|
|
gt_video = torch.cat([gt_video, video_pad], dim=0) |
|
|
|
|
|
pose_pad = torch.zeros((pad_length, *gt_pose.shape[1:]), dtype=gt_pose.dtype) |
|
|
gt_pose = torch.cat([gt_pose, pose_pad], dim=0) |
|
|
|
|
|
audio_pad = torch.zeros((pad_length, *gt_audio.shape[1:]), dtype=gt_audio.dtype) |
|
|
gt_audio = torch.cat([gt_audio, audio_pad], dim=0) |
|
|
else: |
|
|
start_idx = np.random.randint(0, video_length - clip_length + 1) |
|
|
end_idx = start_idx + clip_length |
|
|
batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)) |
|
|
|
|
|
|
|
|
idx_all = np.arange(0,video_length) |
|
|
occ_idx = np.arange(start_idx, start_idx + clip_length) |
|
|
ref_idx = [x for x in idx_all if x not in occ_idx] |
|
|
if len(ref_idx) == 0: |
|
|
ref_frame_idx = batch_index[0] |
|
|
else: |
|
|
np.random.shuffle(ref_idx) |
|
|
ref_frame_idx = ref_idx[0] |
|
|
|
|
|
batch_index = [ref_frame_idx] + batch_index |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[0,:] |
|
|
gt_video = videos[1:,:] |
|
|
|
|
|
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
poses = poses / 255.0 |
|
|
poses = self.pixel_transforms(poses) |
|
|
ref_pose = poses[0,:] |
|
|
gt_pose = poses[1:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[0,:] |
|
|
gt_audio = audios[1:,:] |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
assert gt_video.shape[0] == self.sample_n_frames, f'{gt_video.shape[0]} {self.sample_n_frames}'
assert gt_audio.shape[0] == self.sample_n_frames, f'{gt_audio.shape[0]} {self.sample_n_frames}'
assert gt_pose.shape[0] == self.sample_n_frames, f'{gt_pose.shape[0]} {self.sample_n_frames}'
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
|
|
|
meta_ = dict( |
|
|
video_length = video_length, |
|
|
video_path = video_path, |
|
|
audio_path = whisper_path, |
|
|
) |
|
|
|
|
|
return dict( |
|
|
ref_video=ref_video, |
|
|
gt_video=gt_video, |
|
|
ref_pose = ref_pose, |
|
|
gt_pose = gt_pose, |
|
|
ref_audio=ref_audio, |
|
|
gt_audio=gt_audio, |
|
|
mask = mask, |
|
|
meta=meta_ |
|
|
) |
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
meta = [item["meta"] for item in batch], |
|
|
ref_video = torch.stack([item['ref_video'] for item in batch]), |
|
|
gt_video = torch.stack([item['gt_video'] for item in batch]), |
|
|
ref_pose = torch.stack([item['ref_pose'] for item in batch]), |
|
|
gt_pose = torch.stack([item['gt_pose'] for item in batch]), |
|
|
ref_audio = torch.stack([item['ref_audio'] for item in batch]), |
|
|
gt_audio = torch.stack([item['gt_audio'] for item in batch]), |
|
|
mask = torch.stack([item['mask'] for item in batch]) |
|
|
) |
|
|
|
|
|
class A2MVideoAudioPoseMultiSample(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
audio_drop_ratio:float = 0.0, |
|
|
num_sample:int = 4, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.audio_drop_ratio = audio_drop_ratio |
|
|
self.num_sample = num_sample |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
|
|
|
self.length =len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
|
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
pose_path = meta_data['pose_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
pose_reader = VideoReader(pose_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) |
|
|
|
|
|
|
|
|
sample_frames = self.sample_n_frames + 1 |
|
|
clip_length = (sample_frames - 1) * self.sample_stride + 1 |
|
|
|
|
|
ref_video_list = [] |
|
|
gt_video_list = [] |
|
|
ref_pose_list = [] |
|
|
gt_pose_list = [] |
|
|
ref_audio_list = [] |
|
|
gt_audio_list = [] |
|
|
mask_list = [] |
|
|
|
|
|
for i in range(self.num_sample): |
|
|
start_idx = np.random.randint(0, video_length - clip_length + 1) |
|
|
end_idx = start_idx + clip_length |
|
|
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[0,:] |
|
|
gt_video = videos[1:,:] |
|
|
|
|
|
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
poses = poses / 255.0 |
|
|
poses = self.pixel_transforms(poses) |
|
|
ref_pose = poses[0,:] |
|
|
gt_pose = poses[1:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[0,:] |
|
|
gt_audio = audios[1:,:] |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
assert gt_video.shape[0] == self.sample_n_frames, f'{gt_video.shape[0]} {self.sample_n_frames}'
assert gt_audio.shape[0] == self.sample_n_frames, f'{gt_audio.shape[0]} {self.sample_n_frames}'
assert gt_pose.shape[0] == self.sample_n_frames, f'{gt_pose.shape[0]} {self.sample_n_frames}'
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
|
|
|
if torch.rand(1).item() < self.audio_drop_ratio: |
|
|
ref_audio = torch.zeros_like(ref_audio) |
|
|
gt_audio = torch.zeros_like(gt_audio) |
|
|
|
|
|
|
|
|
ref_video_list.append(ref_video) |
|
|
gt_video_list.append(gt_video) |
|
|
ref_pose_list.append(ref_pose) |
|
|
gt_pose_list.append(gt_pose) |
|
|
ref_audio_list.append(ref_audio) |
|
|
gt_audio_list.append(gt_audio) |
|
|
mask_list.append(mask) |
|
|
|
|
|
|
|
|
return dict( |
|
|
ref_video=torch.stack(ref_video_list), |
|
|
gt_video=torch.stack(gt_video_list), |
|
|
ref_pose = torch.stack(ref_pose_list), |
|
|
gt_pose = torch.stack(gt_pose_list), |
|
|
ref_audio=torch.stack(ref_audio_list), |
|
|
gt_audio=torch.stack(gt_audio_list), |
|
|
mask = torch.stack(mask_list), |
|
|
) |
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
ref_video = torch.cat([item['ref_video'] for item in batch],dim=0), |
|
|
gt_video = torch.cat([item['gt_video'] for item in batch],dim=0), |
|
|
ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0), |
|
|
gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0), |
|
|
ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0), |
|
|
gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0), |
|
|
mask = torch.cat([item['mask'] for item in batch],dim=0) |
|
|
) |
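# Note: unlike the single-clip datasets above, this collate_fn concatenates along dim 0 rather
# than stacking, because each __getitem__ already returns num_sample clips per video.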
|
|
|
|
|
|
|
|
class A2MVideoAudioPoseMultiSampleMultiRefBalance(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
audio_drop_ratio:float = 0.0, |
|
|
num_sample:int = 4, |
|
|
max_ref_frame:int = 8, |
|
|
random_ref_num:bool = False, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
with open(video_dir, 'r') as file: |
|
|
lines = file.readlines() |
|
|
video_dirs = [line.strip() for line in lines] |
|
|
assert len(video_dirs) == 2, 'Only 2 video dirs are supported'
|
|
|
|
|
self.dataset1 = A2MVideoAudioPoseMultiSampleMultiRef(video_dir=video_dirs[0], |
|
|
sample_size=sample_size, |
|
|
sample_stride=sample_stride, |
|
|
sample_n_frames=sample_n_frames, |
|
|
audio_drop_ratio=audio_drop_ratio, |
|
|
num_sample=num_sample, |
|
|
max_ref_frame=max_ref_frame, |
|
|
random_ref_num=random_ref_num) |
|
|
self.dataset2 = A2MVideoAudioPoseMultiSampleMultiRef(video_dir=video_dirs[1], |
|
|
sample_size=sample_size, |
|
|
sample_stride=sample_stride, |
|
|
sample_n_frames=sample_n_frames, |
|
|
audio_drop_ratio=audio_drop_ratio, |
|
|
num_sample=num_sample, |
|
|
max_ref_frame=max_ref_frame, |
|
|
random_ref_num=random_ref_num) |
|
|
self.length = 2*max(len(self.dataset1),len(self.dataset2)) |
|
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
if idx % 2 == 0: |
|
|
a_idx = torch.randint(0, len(self.dataset1), (1,)).item() |
|
|
sample = self.dataset1[a_idx] |
|
|
else: |
|
|
b_idx = torch.randint(0, len(self.dataset2), (1,)).item() |
|
|
sample = self.dataset2[b_idx] |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
ref_video = torch.cat([item['ref_video'] for item in batch],dim=0), |
|
|
gt_video = torch.cat([item['gt_video'] for item in batch],dim=0), |
|
|
ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0), |
|
|
gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0), |
|
|
ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0), |
|
|
gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0), |
|
|
mask = torch.cat([item['mask'] for item in batch],dim=0) |
|
|
) |
|
|
|
|
|
class A2MVideoAudioMultiRefDoubleRef(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
audio_drop_ratio:float = 0.0, |
|
|
num_sample:int = 4, |
|
|
max_ref_frame:int = 8, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.audio_drop_ratio = audio_drop_ratio |
|
|
self.num_sample = num_sample |
|
|
self.max_ref_frame = max_ref_frame |
|
|
self.randomref_num = 8 |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0]) |
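# Randomly pick how many reference frames precede the target clip: 0, 1, or max_ref_frame, each
# with roughly equal probability.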
|
|
|
|
|
|
|
|
r = torch.rand(1).item() |
|
|
if r < 0.33:
    ref_num = 0
elif r < 0.66:
    ref_num = 1
else:
    ref_num = self.max_ref_frame
|
|
|
|
|
sample_frames = self.sample_n_frames + ref_num |
|
|
clip_length = (sample_frames - 1) * self.sample_stride + 1 |
|
|
|
|
|
start_idx = np.random.randint(0, video_length - clip_length + 1) |
|
|
end_idx = start_idx + clip_length |
|
|
batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)) |
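# Also sample randomref_num frames spread evenly over the whole video; they are prepended to the
# decoded batch and split off as randomref_video after the pixel transforms.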
|
|
|
|
|
|
|
|
random_index = list(np.linspace(0, video_length - 1, self.randomref_num, dtype=int)) |
|
|
|
|
|
|
|
|
video_batch_index = random_index + batch_index |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(video_batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
|
|
|
randomref_video = videos[:self.randomref_num,:] |
|
|
l_videos = videos[self.randomref_num:,:] |
|
|
ref_video = l_videos[:ref_num,:] if ref_num > 0 else None |
|
|
gt_video = l_videos[ref_num:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[:ref_num,:] if ref_num > 0 else None |
|
|
gt_audio = audios[ref_num:,:] |
|
|
|
|
|
|
|
|
if ref_num == 1: |
|
|
ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:])) |
|
|
ref_video = torch.cat([ref_video_pad,ref_video],dim=0) |
|
|
|
|
|
ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:])) |
|
|
ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0) |
|
|
elif ref_num == 0: |
|
|
ref_video = torch.zeros((self.max_ref_frame,*gt_video.shape[1:])) |
|
|
ref_audio = torch.zeros((self.max_ref_frame,*gt_audio.shape[1:])) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
assert gt_video.shape[0] == self.sample_n_frames, f'{gt_video.shape[0]} {self.sample_n_frames}'
assert gt_audio.shape[0] == self.sample_n_frames, f'{gt_audio.shape[0]} {self.sample_n_frames}'
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
|
|
|
|
|
|
return dict( |
|
|
ref_video=ref_video, |
|
|
gt_video=gt_video, |
|
|
randomref_video = randomref_video, |
|
|
ref_audio= ref_audio, |
|
|
gt_audio=gt_audio, |
|
|
mask = mask, |
|
|
) |
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
ref_video = torch.stack([item['ref_video'] for item in batch],dim=0), |
|
|
gt_video = torch.stack([item['gt_video'] for item in batch],dim=0), |
|
|
randomref_video = torch.stack([item['randomref_video'] for item in batch],dim=0), |
|
|
ref_audio = torch.stack([item['ref_audio'] for item in batch],dim=0), |
|
|
gt_audio = torch.stack([item['gt_audio'] for item in batch],dim=0), |
|
|
mask = torch.stack([item['mask'] for item in batch],dim=0) |
|
|
) |
|
|
|
|
|
|
|
|
class A2MVideoAudioMultiRefDoubleRefBalance(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
audio_drop_ratio:float = 0.0, |
|
|
num_sample:int = 4, |
|
|
max_ref_frame:int = 8, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.audio_drop_ratio = audio_drop_ratio |
|
|
self.num_sample = num_sample |
|
|
self.max_ref_frame = max_ref_frame |
|
|
self.randomref_num = 8 |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0]) |
|
|
|
|
|
|
|
|
r = torch.rand(1).item() |
|
|
if r < 0.33: |
|
|
ref_num = 0 |
|
|
elif r<0.66: |
|
|
ref_num = 1 |
|
|
else: |
|
|
ref_num = self.max_ref_frame |
|
|
|
|
|
sample_frames = self.sample_n_frames + ref_num |
|
|
clip_length = (sample_frames - 1) * self.sample_stride + 1 |
|
|
|
|
|
start_idx = np.random.randint(0, video_length - clip_length + 1) |
|
|
end_idx = start_idx + clip_length |
|
|
batch_index = list(np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)) |
|
|
|
|
|
|
|
|
random_index = list(np.linspace(0, video_length - 1, self.randomref_num, dtype=int))
|
|
|
|
|
video_batch_index = random_index + batch_index |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(video_batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
randomref_video = videos[:self.randomref_num,:] |
|
|
l_videos = videos[self.randomref_num:,:] |
|
|
ref_video = l_videos[:ref_num,:] if ref_num > 0 else None |
|
|
gt_video = l_videos[ref_num:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[:ref_num,:] if ref_num > 0 else None |
|
|
gt_audio = audios[ref_num:,:] |
|
|
|
|
|
|
|
|
if ref_num == 1: |
|
|
ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:])) |
|
|
ref_video = torch.cat([ref_video_pad,ref_video],dim=0) |
|
|
|
|
|
ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:])) |
|
|
ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0) |
|
|
elif ref_num == 0: |
|
|
ref_video = torch.zeros((self.max_ref_frame,*gt_video.shape[1:])) |
|
|
ref_audio = torch.zeros((self.max_ref_frame,*gt_audio.shape[1:])) |
|
|
|
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
assert gt_video.shape[0] == self.sample_n_frames, f'{gt_video.shape[0]} {self.sample_n_frames}'
assert gt_audio.shape[0] == self.sample_n_frames, f'{gt_audio.shape[0]} {self.sample_n_frames}'
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
|
|
|
|
|
|
return dict( |
|
|
ref_video=ref_video, |
|
|
gt_video=gt_video, |
|
|
randomref_video = randomref_video, |
|
|
ref_audio= ref_audio, |
|
|
gt_audio=gt_audio, |
|
|
mask = mask, |
|
|
) |
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
ref_video = torch.stack([item['ref_video'] for item in batch],dim=0), |
|
|
gt_video = torch.stack([item['gt_video'] for item in batch],dim=0), |
|
|
randomref_video = torch.stack([item['randomref_video'] for item in batch],dim=0), |
|
|
ref_audio = torch.stack([item['ref_audio'] for item in batch],dim=0), |
|
|
gt_audio = torch.stack([item['gt_audio'] for item in batch],dim=0), |
|
|
mask = torch.stack([item['mask'] for item in batch],dim=0) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class A2MVideoAudioPoseRandomRefMultiSample(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
num_sample:int = 4, |
|
|
max_ref_frame:int = 8, |
|
|
random_ref_num:bool = False, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.num_sample = num_sample |
|
|
self.max_ref_frame = max_ref_frame |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
|
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
pose_path = meta_data['pose_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
pose_reader = VideoReader(pose_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) |
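# Take num_sample * 2 consecutive frames: the first half serve as single-frame references and the
# second half as single-frame targets, each kept with a singleton time dimension.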
|
|
|
|
|
|
|
|
sample_frames = self.num_sample * 2 |
|
|
if video_length < sample_frames: |
|
|
raise ValueError(f"视频长度{video_length}太短了,需要长度{sample_frames}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
start_idx = torch.randint(low=0, high=video_length - sample_frames + 1, size=(1,)).item()
batch_index = np.arange(start_idx, start_idx + sample_frames)
|
|
|
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[:self.num_sample,:].unsqueeze(1) |
|
|
gt_video = videos[self.num_sample:,:].unsqueeze(1) |
|
|
|
|
|
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
poses = poses / 255.0 |
|
|
poses = self.pixel_transforms(poses) |
|
|
ref_pose = poses[:self.num_sample,:].unsqueeze(1) |
|
|
gt_pose = poses[self.num_sample:,:].unsqueeze(1) |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[:self.num_sample,:].unsqueeze(1) |
|
|
gt_audio = audios[self.num_sample:,:].unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
return dict( |
|
|
ref_video=ref_video, |
|
|
gt_video=gt_video, |
|
|
ref_pose = ref_pose, |
|
|
gt_pose =gt_pose, |
|
|
ref_audio=ref_audio, |
|
|
gt_audio=gt_audio, |
|
|
mask = mask, |
|
|
) |
|
|
|
|
|
|
|
|
def get_file_name(self, file_path): |
|
|
return file_path.split('/')[-1].split('.')[0] |
|
|
|
|
|
@staticmethod |
|
|
def collate_fn(batch): |
|
|
return dict( |
|
|
ref_video = torch.cat([item['ref_video'] for item in batch],dim=0), |
|
|
gt_video = torch.cat([item['gt_video'] for item in batch],dim=0), |
|
|
ref_pose = torch.cat([item['ref_pose'] for item in batch],dim=0), |
|
|
gt_pose = torch.cat([item['gt_pose'] for item in batch],dim=0), |
|
|
ref_audio = torch.cat([item['ref_audio'] for item in batch],dim=0), |
|
|
gt_audio = torch.cat([item['gt_audio'] for item in batch],dim=0), |
|
|
mask = torch.cat([item['mask'] for item in batch],dim=0) |
|
|
) |
|
|
|
|
|
|
|
|
class A2MVideoAudioPoseMultiSampleMultiRef(Dataset): |
|
|
def __init__( |
|
|
self, |
|
|
video_dir:str, |
|
|
sample_size: int = 256, |
|
|
sample_stride: int = 1, |
|
|
sample_n_frames:int = 16, |
|
|
audio_drop_ratio:float = 0.0, |
|
|
num_sample:int = 4, |
|
|
max_ref_frame:int = 8, |
|
|
random_ref_num:bool = False, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.sample_stride = sample_stride |
|
|
self.sample_n_frames = sample_n_frames |
|
|
self.audio_drop_ratio = audio_drop_ratio |
|
|
self.num_sample = num_sample |
|
|
self.max_ref_frame = max_ref_frame |
|
|
self.random_ref_num = random_ref_num |
|
|
|
|
|
|
|
|
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) |
|
|
self.pixel_transforms = transforms.Compose([ |
|
|
transforms.Resize(min(sample_size)), |
|
|
transforms.CenterCrop(sample_size), |
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), |
|
|
]) |
|
|
|
|
|
with open(video_dir, 'rb') as f: |
|
|
self.metadata_list = pickle.load(f) |
|
|
self.length = len(self.metadata_list) |
|
|
print(f'Total {self.length} files are available')
|
|
|
|
|
def __getitem__(self, idx): |
|
|
while True: |
|
|
try: |
|
|
sample = self.get_batch(idx) |
|
|
break |
|
|
except Exception as e: |
|
|
print('error',e) |
|
|
idx = torch.randint(low=0, high=self.length, size=(1,)).item() |
|
|
|
|
|
|
|
|
return sample |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def get_batch(self, idx): |
|
|
|
|
|
""" |
|
|
videos : 31,3,256,256 |
|
|
ref_img : 3,256,256 |
|
|
audio_feature : 30,50,384 |
|
|
ref_pose : 3,256,256 |
|
|
meta |
|
|
""" |
|
|
|
|
|
|
|
|
meta_data = self.metadata_list[idx] |
|
|
video_path = meta_data['video_path'] |
|
|
whisper_path = meta_data['whisper_emb_path'] |
|
|
pose_path = meta_data['pose_path'] |
|
|
|
|
|
|
|
|
audio_feature = torch.load(whisper_path) |
|
|
|
|
|
|
|
|
video_reader = VideoReader(video_path) |
|
|
pose_reader = VideoReader(pose_path) |
|
|
video_length = min(len(video_reader),audio_feature.shape[0],len(pose_reader)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ref_video_list = [] |
|
|
gt_video_list = [] |
|
|
ref_pose_list = [] |
|
|
gt_pose_list = [] |
|
|
ref_audio_list = [] |
|
|
gt_audio_list = [] |
|
|
mask_list = [] |
|
|
|
|
|
for i in range(self.num_sample): |
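# Each of the num_sample clips draws its own number of reference frames; shorter reference
# sequences are left-padded with zeros below so every clip stacks to max_ref_frame.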
|
|
|
|
|
if self.random_ref_num: |
|
|
ref_num = torch.randint(low=1, high=self.max_ref_frame+1, size=(1,)).item() |
|
|
else: |
|
|
ref_num = [1, self.max_ref_frame][torch.randint(2, (1,)).item()] |
|
|
sample_frames = self.sample_n_frames + ref_num |
|
|
clip_length = (sample_frames - 1) * self.sample_stride + 1 |
|
|
|
|
|
start_idx = np.random.randint(0, video_length - clip_length + 1) |
|
|
end_idx = start_idx + clip_length |
|
|
batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int) |
|
|
|
|
|
|
|
|
videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
videos = videos / 255.0 |
|
|
videos = self.pixel_transforms(videos) |
|
|
ref_video = videos[:ref_num,:] |
|
|
gt_video = videos[ref_num:,:] |
|
|
|
|
|
poses = torch.from_numpy(pose_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() |
|
|
poses = poses / 255.0 |
|
|
poses = self.pixel_transforms(poses) |
|
|
ref_pose = poses[:ref_num,:] |
|
|
gt_pose = poses[ref_num:,:] |
|
|
|
|
|
audios = audio_feature[batch_index] |
|
|
ref_audio = audios[:ref_num,:] |
|
|
gt_audio = audios[ref_num:,:] |
|
|
|
|
|
|
|
|
if ref_num < self.max_ref_frame: |
|
|
ref_video_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_video.shape[1:])) |
|
|
ref_video = torch.cat([ref_video_pad,ref_video],dim=0) |
|
|
|
|
|
ref_pose_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_pose.shape[1:])) |
|
|
ref_pose = torch.cat([ref_pose_pad,ref_pose],dim=0) |
|
|
|
|
|
ref_audio_pad = torch.zeros((self.max_ref_frame-ref_num,*ref_audio.shape[1:])) |
|
|
ref_audio = torch.cat([ref_audio_pad,ref_audio],dim=0) |
|
|
|
|
|
|
|
|
cur_available_length = gt_video.shape[0] |
|
|
|
|
|
assert gt_video.shape[0] == self.sample_n_frames ,''+str(gt_video.shape[0])+' '+str(self.sample_n_frames) |
|
|
assert gt_audio.shape[0] == self.sample_n_frames ,''+str(gt_audio.shape[0])+' '+str(self.sample_n_frames) |
|
|
assert gt_pose.shape[0] == self.sample_n_frames ,''+str(gt_pose.shape[0])+' '+str(self.sample_n_frames) |
|
|
|
|
|
|
|
|
mask = torch.zeros(self.sample_n_frames) |
|
|
mask[:cur_available_length] = 1 |
|
|
|
|
|
|
|
|
if torch.rand(1).item() < self.audio_drop_ratio: |
|
|
ref_audio = torch.zeros_like(ref_audio) |
|
|
gt_audio = torch.zeros_like(gt_audio) |
|
|
|
|
|
|
|
|
ref_video_list.append(ref_video) |
|
|
gt_video_list.append(gt_video) |
|
|
ref_pose_list.append(ref_pose) |
|
|
gt_pose_list.append(gt_pose) |
|
|
ref_audio_list.append(ref_audio) |
|
|
gt_audio_list.append(gt_audio) |
|
|
mask_list.append(mask) |
|
|
|
|
|
|
|
|
return dict( |
|
|
ref_video=torch.stack(ref_video_list), |
|
|
gt_video=torch.stack(gt_video_list), |
|
|
ref_pose = torch.stack(ref_pose_list), |
|
|
gt_pose = torch.stack(gt_pose_list), |
|
|
ref_audio=torch.stack(ref_audio_list), |
|
|
gt_audio=torch.stack(gt_audio_list), |
|
|
mask = torch.stack(mask_list), |
|
|
) |
|
|
|
|
|
    def get_file_name(self, file_path):
        return file_path.split('/')[-1].split('.')[0]

    @staticmethod
    def collate_fn(batch):
        return dict(
            ref_video=torch.cat([item['ref_video'] for item in batch], dim=0),
            gt_video=torch.cat([item['gt_video'] for item in batch], dim=0),
            ref_pose=torch.cat([item['ref_pose'] for item in batch], dim=0),
            gt_pose=torch.cat([item['gt_pose'] for item in batch], dim=0),
            ref_audio=torch.cat([item['ref_audio'] for item in batch], dim=0),
            gt_audio=torch.cat([item['gt_audio'] for item in batch], dim=0),
            mask=torch.cat([item['mask'] for item in batch], dim=0),
        )


class A2VDataset(Dataset):
    """Audio-to-video dataset: yields one reference frame, the driving whisper
    features, and metadata (audio path, fps, start time) for inference."""

    def __init__(
        self,
        video_dir: str,
        sample_size: int = 256,
        sample_stride: int = 1,
        sample_n_frames: int = 120,
        **kwargs
    ):
        super().__init__()

        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(min(sample_size)),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        # video_dir points to a pickled list of metadata dicts
        # (keys: video_path / whisper_emb_path / audio_path).
        with open(video_dir, 'rb') as f:
            self.metadata_list = pickle.load(f)
        self.length = len(self.metadata_list)
        print(f'Total {self.length} files are available')

    def __getitem__(self, idx):
        while True:
            try:
                sample = self.get_batch(idx)
                break
            except Exception as e:
                print('error', e)
                idx = torch.randint(low=0, high=self.length, size=(1,)).item()

        return sample

    def __len__(self):
        return self.length

    def get_batch(self, idx):
        meta_data = self.metadata_list[idx]
        video_path = meta_data['video_path']
        whisper_path = meta_data['whisper_emb_path']
        audio_path = meta_data['audio_path']

        name = os.path.basename(video_path).split('.')[0]
        # Assume 25 fps for HDTF clips and 30 fps for everything else.
        fps = 25 if 'hdtf' in video_path else 30

        audio_feature = torch.load(whisper_path)

        video_reader = VideoReader(video_path)
        video_length = min(len(video_reader), audio_feature.shape[0])

        # A single reference frame precedes the frames to be generated.
        ref_num = 1
        sample_frames = self.sample_n_frames + ref_num
        clip_length = (sample_frames - 1) * self.sample_stride + 1

        start_idx = np.random.randint(0, video_length - clip_length + 1)
        end_idx = start_idx + clip_length
        batch_index = np.linspace(start_idx, end_idx - 1, sample_frames, dtype=int)

        videos = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        videos = videos / 255.0
        videos = self.pixel_transforms(videos)
        ref_img = videos[:ref_num, :]
        inf_video = videos[ref_num:, :]  # frames to be generated (note: gt_video below keeps the full clip, reference frame included)

        audios = audio_feature[batch_index]
        ref_audio = audios[:ref_num, :]
        inf_audio = audios[ref_num:, :]

        start_time = start_idx / fps
        _meta = {
            "name": name,
            "audio_path": audio_path,
            "fps": fps,
            "start_time": start_time,
        }

        return dict(
            meta_info=_meta,
            ref_img=ref_img,
            gt_video=videos,
            ref_audio=ref_audio,
            inf_audio=inf_audio,
            ref_pose=None,
            inf_pose=None,
        )

    @staticmethod
    def collate_fn(batch):
        return dict(
            meta_info=[item['meta_info'] for item in batch],
            ref_img=torch.stack([item["ref_img"] for item in batch]),
            ref_audio=torch.stack([item["ref_audio"] for item in batch]),
            inf_audio=torch.stack([item["inf_audio"] for item in batch]),
            # Pose conditioning is unused here; placeholder tensors keep the
            # batch structure uniform with the pose-conditioned datasets.
            ref_pose=torch.ones((2, 2)),
            inf_pose=torch.ones((2, 2)),
            gt_video=torch.stack([item["gt_video"] for item in batch]),
        )


def generate_non_equal_random_lists(frame_num, sample_num):
    # Draw `sample_num` random frame indices, then for each draw a second index
    # guaranteed to differ from the first (e.g. frame_num=10, sample_num=3 might
    # give [2, 7, 4] and [9, 1, 6]).
    list1 = [np.random.randint(0, frame_num) for _ in range(sample_num)]

    list2 = []
    for i in range(len(list1)):
        available_numbers = list(range(0, list1[i])) + list(range(list1[i] + 1, frame_num))
        list2.append(random.choice(available_numbers))

    return list1, list2


def shuffle_list(l):
    # Return a shuffled copy of the list using a torch permutation.
    shuffled_indices = torch.randperm(len(l))
    l = [l[i] for i in shuffled_indices]
    return l


if __name__ == "__main__":

    from torch.utils.data import DataLoader

    dataset = AMDVideoAudioFeature(
        video_dir='/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/code/AMD_linear/dataset/path/lhz/train.pkl',
        path_type='file'
    )
    dataloader = DataLoader(
        dataset, batch_size=2, shuffle=True, num_workers=0,
        collate_fn=dataset.collate_fn
    )

    for d in dataloader:
        video = d["videos"]
        audio_feature = d["audio_feature"]
        refimg = d["ref_img"]

        print(video.shape)
        print(audio_feature.shape)
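
    # A minimal sketch (not part of the original script) of how the multi-reference
    # dataset defined above could be exercised. The pickle path is hypothetical; it
    # must point to a pickled list of dicts with 'video_path', 'whisper_emb_path'
    # and 'pose_path' keys, which is what A2MVideoAudioPoseMultiSampleMultiRef.get_batch reads.
    multi_ref_pkl = '/path/to/multi_ref_train.pkl'  # hypothetical path
    if os.path.exists(multi_ref_pkl):
        multi_ref_dataset = A2MVideoAudioPoseMultiSampleMultiRef(
            video_dir=multi_ref_pkl,
            sample_size=256,
            sample_stride=1,
            sample_n_frames=16,
            num_sample=4,
            max_ref_frame=8,
            random_ref_num=True,
        )
        multi_ref_loader = DataLoader(
            multi_ref_dataset, batch_size=2, shuffle=True, num_workers=0,
            collate_fn=A2MVideoAudioPoseMultiSampleMultiRef.collate_fn,
        )
        for batch in multi_ref_loader:
            # collate_fn concatenates along the per-item sample dimension, so the
            # leading dimension is batch_size * num_sample.
            print(batch['ref_video'].shape, batch['gt_video'].shape, batch['mask'].shape)
            break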
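
    # A similar sketch for A2VDataset (single reference frame, audio-only conditioning).
    # The path is again hypothetical; the pickle must provide 'video_path',
    # 'whisper_emb_path' and 'audio_path' entries per item.
    a2v_pkl = '/path/to/a2v_test.pkl'  # hypothetical path
    if os.path.exists(a2v_pkl):
        a2v_dataset = A2VDataset(video_dir=a2v_pkl, sample_size=256, sample_n_frames=120)
        a2v_loader = DataLoader(
            a2v_dataset, batch_size=1, shuffle=False, num_workers=0,
            collate_fn=A2VDataset.collate_fn,
        )
        for batch in a2v_loader:
            print(batch['meta_info'][0]['name'], batch['ref_img'].shape, batch['inf_audio'].shape)
            break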