# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import os.path as osp
import pickle
import warnings
import glob
import torch.utils.data as data
from torchvision.datasets.video_utils import VideoClips
from converter import normalize, normalize_spectrogram, get_mel_spectrogram_from_audio
from torchaudio import transforms as Ta
from torchvision import transforms as Tv
from torchvision.io.video import read_video
import torch
from torchvision.transforms import InterpolationMode
class LatentDataset(data.Dataset):
""" Generic dataset for latents pregenerated from a dataset
Returns a dictionary of latents encoded from the original dataset """
exts = ['pt']
def __init__(self, data_folder, train=True):
"""
Args:
            data_folder: path to the folder with pregenerated latents. The folder
                should contain a 'train' and a 'test' directory, each holding
                the corresponding .pt latent files
"""
super().__init__()
self.train = train
folder = osp.join(data_folder, 'train' if train else 'test')
self.files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
for ext in self.exts], [])
warnings.filterwarnings('ignore')
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
while True:
try:
latents = torch.load(self.files[idx], map_location="cpu")
except Exception as e:
print(f"Dataset Exception: {e}")
idx = (idx + 1) % len(self.files)
continue
break
return latents["video"], latents["audio"], latents["y"]
class AudioVideoDataset(data.Dataset):
""" Generic dataset for videos files stored in folders
Returns BCTHW videos in the range [-0.5, 0.5] """
exts = ['avi', 'mp4', 'webm']
    def __init__(self, data_folder, train=True, resolution=64, sample_every_n_frames=1,
                 sequence_length=8, audio_channels=1, sample_rate=16000, min_length=1,
                 ignore_cache=False, labeled=True, target_video_fps=10):
"""
Args:
data_folder: path to the folder with videos. The folder
should contain a 'train' and a 'test' directory,
each with corresponding videos stored
sequence_length: length of extracted video sequences
"""
super().__init__()
self.train = train
self.sequence_length = sequence_length
self.resolution = resolution
self.sample_every_n_frames = sample_every_n_frames
self.audio_channels = audio_channels
self.sample_rate = sample_rate
self.min_length = min_length
self.labeled = labeled
folder = osp.join(data_folder, 'train' if train else 'test')
files = sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
for ext in self.exts], [])
        # infer the set of classes from the unique parent directory names
        self.classes = sorted(set(get_parent_dir(f) for f in files))
self.class_to_label = {c: i for i, c in enumerate(self.classes)}
warnings.filterwarnings('ignore')
cache_file = osp.join(folder, f"metadata_{self.sequence_length}.pkl")
        # index clips once and cache the VideoClips metadata next to the data
        if not osp.exists(cache_file) or ignore_cache:
            clips = VideoClips(files, self.sequence_length, num_workers=32, frame_rate=target_video_fps)
            with open(cache_file, 'wb') as f:
                pickle.dump(clips.metadata, f)
        else:
            with open(cache_file, 'rb') as f:
                metadata = pickle.load(f)
            clips = VideoClips(files, self.sequence_length, frame_rate=target_video_fps,
                               _precomputed_metadata=metadata)
        self._clips = clips
@property
def n_classes(self):
return len(self.classes)
def __len__(self):
return self._clips.num_clips()
def __getitem__(self, idx):
resolution = self.resolution
while True:
try:
video, _, info, _ = self._clips.get_clip(idx)
except Exception:
idx = (idx + 1) % self._clips.num_clips()
continue
break
        return (preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames),
                self.get_audio(info, idx),
                self.get_label(idx))
def get_label(self, idx):
if not self.labeled:
return -1
video_idx, clip_idx = self._clips.get_clip_location(idx)
class_name = get_parent_dir(self._clips.video_paths[video_idx])
label = self.class_to_label[class_name]
return label
def get_audio(self, info, idx):
video_idx, clip_idx = self._clips.get_clip_location(idx)
video_path = self._clips.video_paths[video_idx]
video_fps = self._clips.video_fps[video_idx]
        # convert the clip's presentation timestamps to frame indices, then to seconds
        duration_per_frame = self._clips.video_pts[video_idx][1] - self._clips.video_pts[video_idx][0]
        clip_pts = self._clips.clips[video_idx][clip_idx]
        clip_pid = clip_pts // duration_per_frame
        start_t = (clip_pid[0] / video_fps).item()
        end_t = ((clip_pid[-1] + 1) / video_fps).item()
        _, raw_audio, _ = read_video(video_path, start_t, end_t, pts_unit='sec')
        raw_audio = prepare_audio(raw_audio, info["audio_fps"], self.sample_rate,
                                  self.audio_channels, self.sequence_length, self.min_length)
_, spec = get_mel_spectrogram_from_audio(raw_audio[0].numpy())
norm_spec = normalize_spectrogram(spec)
        norm_spec = normalize(norm_spec)  # normalize to [-1, 1]; the pipeline does not normalize torch.Tensor inputs
        norm_spec = norm_spec.unsqueeze(1)  # add channel dimension
        return norm_spec
def get_parent_dir(path):
return osp.basename(osp.dirname(path))
def preprocess(video, resolution, sample_every_n_frames=1):
    # optionally subsample frames in time
    if sample_every_n_frames > 1:
        video = video[::sample_every_n_frames]
    # THWC uint8 -> TCHW float in [0, 1]
    video = video.permute(0, 3, 1, 2).float() / 255.
    # scale so the longer side equals `resolution`, then pad the shorter side to a square
    old_size = video.shape[2:4]
    ratio = min(float(resolution) / old_size[0], float(resolution) / old_size[1])
    new_size = tuple(int(i * ratio) for i in old_size)
    pad_h = resolution - new_size[0]
    pad_w = resolution - new_size[1]
    top, bottom = pad_h // 2, pad_h - pad_h // 2
    left, right = pad_w // 2, pad_w - pad_w // 2
    transform = Tv.Compose([Tv.Resize(new_size, interpolation=InterpolationMode.BICUBIC),
                            Tv.Pad((left, top, right, bottom))])
    video_new = transform(video)
    video_new = video_new * 2 - 1  # rescale [0, 1] -> [-1, 1]
    return video_new
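# Worked example of the letterbox math above (hypothetical 240x320 frame,
# resolution=64): ratio = min(64/240, 64/320) = 0.2, so new_size = (48, 64);
# pad_w = 0 and pad_h = 16 split into top = 8, bottom = 8, giving a 64x64
# frame with the original aspect ratio preserved.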
def pad_crop_audio(audio, target_length):
    # crop to target_length, or zero-pad on the right if the clip is shorter
    target_length = int(target_length)
    n, s = audio.shape
    output = audio.new_zeros([n, target_length])
    output[:, :min(s, target_length)] = audio[:, :target_length]
    return output
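# Quick check of pad_crop_audio (sketch with synthetic tensors): a clip longer
# than target_length is cropped, a shorter one is right-padded with zeros.
#
#   x = torch.randn(1, 48000)
#   assert pad_crop_audio(x, 32000).shape == (1, 32000)  # cropped
#   assert pad_crop_audio(x, 64000).shape == (1, 64000)  # zero-padded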
def prepare_audio(audio, in_sr, target_sr, target_channels, sequence_length, min_length):
    if in_sr != target_sr:
        resample_tf = Ta.Resample(in_sr, target_sr)
        audio = resample_tf(audio)
    # audio samples spanning sequence_length frames of 10 fps video (see target_video_fps)
    max_length = target_sr / 10 * sequence_length
    # round up to the nearest multiple of min_length
    target_length = max_length + (min_length - (max_length % min_length)) % min_length
    audio = pad_crop_audio(audio, target_length)
    audio = set_audio_channels(audio, target_channels)
    return audio
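# Example of the length arithmetic (defaults: target_sr=16000, sequence_length=8,
# min_length=1): max_length = 16000 / 10 * 8 = 12800 samples, i.e. the audio
# span matching 8 video frames at 10 fps; the modulo term then rounds up to the
# nearest multiple of min_length (a no-op for min_length=1).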
def set_audio_channels(audio, target_channels):
    if target_channels == 1:
        # convert to mono by keeping the first channel
        audio = audio[:1, :]
    elif target_channels == 2:
        # convert to stereo by duplicating mono or dropping extra channels
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]
    return audio
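if __name__ == "__main__":
    # Smoke-test sketch; "/path/to/videos" is a hypothetical root that must
    # contain train/<class>/<clip>.mp4 files for VideoClips to index.
    dataset = AudioVideoDataset("/path/to/videos", train=True, resolution=64,
                                sequence_length=8, sample_rate=16000)
    print(f"{len(dataset)} clips across {dataset.n_classes} classes")
    video, spec, label = dataset[0]
    print(video.shape, spec.shape, label)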