|
|
import math
|
|
|
import random
|
|
|
import torch
|
|
|
|
|
|
from torch import nn
|
|
|
from typing import Tuple
|
|
|
import os
|
|
|
import subprocess as sp
|
|
|
from PIL import Image
|
|
|
from torchvision import transforms
|
|
|
from decord import VideoReader, cpu
|
|
|
|
|
|
class PadCrop(nn.Module):
    """Crop or zero-pad a 2-D signal to exactly ``n_samples`` columns.

    The input is unpacked as ``(rows, columns)``; the output has the same
    number of rows and ``n_samples`` columns.
    """

    def __init__(self, n_samples, randomize=True):
        """Store the target length and whether the crop start is random."""
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        """Return a ``(rows, n_samples)`` window of ``signal``.

        When ``randomize`` is set, the window start is drawn uniformly with
        ``torch.randint`` from the starts that keep the window inside the
        signal; otherwise the window starts at column 0. Columns beyond the
        end of a short signal are left at zero.
        """
        rows, length = signal.shape

        if self.randomize:
            # Largest start that still leaves a full window (0 when the
            # signal is shorter than the window).
            last_start = max(0, length - self.n_samples)
            start = torch.randint(0, last_start + 1, []).item()
        else:
            start = 0

        window = signal[:, start:start + self.n_samples]

        # Same dtype/device as the input, pre-filled with zeros for padding.
        padded = signal.new_zeros([rows, self.n_samples])
        padded[:, :min(length, self.n_samples)] = window
        return padded
|
|
|
|
|
|
|
|
|
class PadCrop_Normalized_T(nn.Module):
    """Crop or zero-pad a ``(channels, samples)`` signal and report timing.

    Besides the fixed-length chunk, the call returns where the chunk sits
    inside the source (as fractions and as whole seconds) and a mask that
    marks which output positions hold source samples rather than padding.
    """

    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
        """Store the chunk length, the sample rate and the randomize flag."""
        super().__init__()
        self.n_samples = n_samples
        self.sample_rate = sample_rate
        self.randomize = randomize

    def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
        """Extract a chunk of ``self.n_samples`` samples from ``source``.

        Returns:
            A tuple ``(chunk, t_start, t_end, seconds_start, seconds_total,
            padding_mask)`` where

            * ``chunk`` has shape ``(channels, self.n_samples)``;
            * ``t_start`` / ``t_end`` are the chunk bounds divided by
              ``max(source length, self.n_samples)``;
            * ``seconds_start`` is ``floor(offset / sample_rate)``;
            * ``seconds_total`` is ``ceil(source length / sample_rate)``;
            * ``padding_mask`` has shape ``(self.n_samples,)`` with ones on
              the first ``min(source length, self.n_samples)`` positions.
        """
        n_channels, source_len = source.shape
        crop_len = self.n_samples
        rate = self.sample_rate

        whole_seconds = source_len // rate

        # Equals max(source_len, crop_len); used to normalise t_start/t_end.
        span = max(0, source_len - crop_len) + crop_len

        offset = 0
        if self.randomize and source_len > crop_len:
            candidates = []
            # Candidate starts lie on a 10-second grid.
            for second in range(0, whole_seconds, 10):
                position = second * rate
                if position + crop_len > source_len:
                    continue  # chunk would run past the end of the source
                if whole_seconds > 20 and whole_seconds - second < 15:
                    continue  # long source: skip starts within its last 15 s
                candidates.append(position)
            if candidates:
                offset = random.choice(candidates)

        t_start = offset / span
        t_end = (offset + crop_len) / span

        filled = min(source_len, crop_len)

        # Zero-initialised so a short source is padded on the right.
        chunk = source.new_zeros([n_channels, crop_len])
        chunk[:, :filled] = source[:, offset:offset + crop_len]

        seconds_start = math.floor(offset / rate)
        seconds_total = math.ceil(source_len / rate)

        padding_mask = torch.zeros([crop_len])
        padding_mask[:filled] = 1

        return (chunk, t_start, t_end, seconds_start, seconds_total, padding_mask)
|
|
|
|
|
|
|
|
|
class PhaseFlipper(nn.Module):
    """Negate a signal with probability ``p``; otherwise pass it through."""

    def __init__(self, p=0.5):
        """Store ``p``, the probability that a call negates its input."""
        super().__init__()
        self.p = p

    def __call__(self, signal):
        """Return ``-signal`` when a fresh ``random.random()`` draw is below ``p``."""
        flip = random.random() < self.p
        if flip:
            return -signal
        return signal
|
|
|
|
|
|
class Mono(nn.Module):
    """Collapse a multi-row signal to a single row by averaging over dim 0."""

    def __call__(self, signal):
        """Return the mean over dim 0 (kept as a size-1 dim).

        Inputs with at most one dimension are returned unchanged.
        """
        if len(signal.shape) <= 1:
            return signal
        return torch.mean(signal, dim=0, keepdim=True)
|
|
|
|
|
|
class Stereo(nn.Module):
    """Coerce a 1-D or 2-D signal to exactly two rows where possible."""

    def __call__(self, signal):
        """Return a two-row version of ``signal``.

        * 1-D input: a leading dim is added and the row is duplicated.
        * 2-D input with one row: the row is duplicated.
        * 2-D input with more than two rows: only the first two are kept.
        * Anything else (including 2-D input that already has two rows, or
          zero rows) is returned unchanged.
        """
        rank = len(signal.shape)

        if rank == 1:
            return signal.unsqueeze(0).repeat(2, 1)

        if rank == 2:
            rows = signal.shape[0]
            if rows == 1:
                return signal.repeat(2, 1)
            if rows > 2:
                return signal[:2, :]

        return signal
|
|
|
|
|
|
|
|
|
def adjust_video_duration(video_tensor, duration, target_fps):
    """Trim or pad ``video_tensor`` along dim 0 to ``int(duration * target_fps)`` frames.

    Args:
        video_tensor: Tensor whose first dimension indexes frames.
        duration: Requested length; multiplied by ``target_fps`` to get the
            frame count. May be a float.
        target_fps: Frames per unit of ``duration``.

    Returns:
        ``video_tensor`` itself when it already has the requested number of
        frames; a leading slice when it is too long; otherwise a new tensor
        padded at the end by repeating the last frame. An input with zero
        frames is padded with zero-valued frames, since there is no last
        frame to repeat.
    """
    current_frames = video_tensor.shape[0]

    # Frame counts must be integers: a float product (e.g. 2.5 * 2) cannot be
    # used as a slice bound or a repeat count.
    # NOTE(review): a negative duration still yields a negative slice bound,
    # which drops trailing frames -- callers should pass a non-negative value.
    target_frames = int(duration * target_fps)

    if current_frames > target_frames:
        return video_tensor[:target_frames]

    if current_frames < target_frames:
        missing = target_frames - current_frames
        if current_frames == 0:
            # No frame to repeat; fill with zeros of the same dtype/device.
            padding = video_tensor.new_zeros((missing, *video_tensor.shape[1:]))
        else:
            # Repeat only along the frame dimension, whatever the rank.
            trailing_ones = [1] * (video_tensor.dim() - 1)
            padding = video_tensor[-1:].repeat(missing, *trailing_ones)
        return torch.cat((video_tensor, padding), dim=0)

    return video_tensor
|
|
|
|
|
|
def read_video(filepath, seek_time=0., duration=-1, target_fps=2):
    """Load frames from a video or still image as a ``(frames, C, 224, 224)`` tensor.

    Args:
        filepath: Path to a video, or to a ``.jpg``/``.jpeg``/``.png`` image,
            or ``None`` to get an all-zero clip.
        seek_time: Start position; converted to a frame index with the
            video's average fps. Only used when ``duration > 0``.
        duration: Requested clip length. When positive, the result has
            exactly ``int(duration * target_fps)`` frames. When non-positive,
            a video is sampled from start to end and returned as-is, and an
            image yields a single frame. With ``filepath=None`` the duration
            must be positive, because the frame count is passed straight to
            ``torch.zeros``.
        target_fps: Desired frames per unit of ``duration``; the video is
            sub-sampled every ``ceil(fps / target_fps)`` source frames.

    Returns:
        A tensor with frames on dim 0 and spatial size 224x224.
        NOTE(review): the three branches do not share a dtype/value range --
        zeros and ``ToTensor`` output are float, while decoded video frames
        keep whatever dtype the decoder returns. Confirm callers normalise.
    """
    if filepath is None:
        return torch.zeros((int(duration * target_fps), 3, 224, 224))

    target_frames = int(duration * target_fps)

    ext = os.path.splitext(filepath)[1].lower()
    if ext in ('.jpg', '.jpeg', '.png'):
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed as soon as the conversion is done.
        with Image.open(filepath) as handle:
            image = handle.convert("RGB")
        frame = transforms.ToTensor()(image).unsqueeze(0)
        frame = transforms.Resize((224, 224))(frame)
        if duration <= 0:
            # "Whole media" for a still image is its single frame.
            return frame
        frame = frame.repeat(target_frames, 1, 1, 1)
        assert frame.shape[0] == target_frames, f"The shape of frame is {frame.shape}"
        return frame

    vr = VideoReader(filepath, ctx=cpu(0))
    fps = vr.get_avg_fps()
    total_frames = len(vr)

    # Clamp to 1 so a zero/unknown fps cannot produce a zero range step.
    frame_interval = max(1, int(math.ceil(fps / target_fps)))

    if duration > 0:
        seek_frame = int(seek_time * fps)
        end_frame = min(seek_frame + target_frames * frame_interval, total_frames)
        frame_ids = list(range(seek_frame, end_frame, frame_interval))
    else:
        frame_ids = list(range(0, total_frames, frame_interval))

    frames = vr.get_batch(frame_ids).asnumpy()
    # Move the last axis to position 1 (channels-first layout).
    frames = torch.from_numpy(frames).permute(0, 3, 1, 2)

    if frames.shape[2] != 224 or frames.shape[3] != 224:
        frames = transforms.Resize((224, 224))(frames)

    if duration <= 0:
        # No fixed length was requested: trimming/padding to a negative frame
        # count would drop frames and the length check could never pass.
        return frames

    video_tensor = adjust_video_duration(frames, duration, target_fps)
    assert video_tensor.shape[0] == target_frames, f"The shape of video_tensor is {video_tensor.shape}"
    return video_tensor
|
|
|
|
|
|
def merge_video_audio(video_path, audio_path, output_path, start_time, duration):
    """Mux the video stream of one file with the audio stream of another via ffmpeg.

    Args:
        video_path: Input whose first video stream is copied without re-encoding.
        audio_path: Input whose first audio stream is encoded as AAC.
        output_path: Destination file; overwritten if it exists (``-y``).
        start_time: Value passed to ``-ss`` ahead of the video input.
        duration: Value passed to ``-t`` ahead of the video input.

    Returns:
        ``output_path`` on success, or ``None`` when ffmpeg exits with a
        non-zero status. Other failures (e.g. ffmpeg not being installed)
        propagate as exceptions.
    """
    # -ss/-t appear before the first -i, so they are input options for the
    # video file only.
    video_input = ['-ss', str(start_time), '-t', str(duration), '-i', video_path]
    audio_input = ['-i', audio_path]
    codecs = ['-c:v', 'copy', '-c:a', 'aac']
    stream_maps = ['-map', '0:v:0', '-map', '1:a:0']

    command = [
        'ffmpeg',
        '-y',
        *video_input,
        *audio_input,
        *codecs,
        *stream_maps,
        '-shortest',
        '-strict', 'experimental',
        output_path,
    ]

    try:
        sp.run(command, check=True)
    except sp.CalledProcessError as e:
        print(f"Error merging audio and video: {e}")
        return None

    print(f"Successfully merged audio and video into {output_path}")
    return output_path
|
|
|
|
|
|
def load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total):
    """Load a fixed-length audio excerpt, or silence when no path is given.

    Args:
        audio_path: File to load with ``torchaudio.load``, or ``None``.
        sample_rate: Sample rate the returned tensor is expressed in.
        seconds_start: Excerpt start, in seconds.
        seconds_total: Excerpt length, in seconds.

    Returns:
        A tensor with ``int(sample_rate * seconds_total)`` samples on dim 1.
        For ``audio_path=None`` it is all zeros with 2 rows; otherwise it has
        as many rows as the loaded file has channels.
    """
    target_length = int(sample_rate * seconds_total)

    if audio_path is None:
        return torch.zeros((2, target_length))

    # Imported lazily so the silent (audio_path is None) path works without
    # torchaudio being importable.
    import torchaudio

    audio_tensor, sr = torchaudio.load(audio_path)

    # The indices below are computed in units of `sample_rate`, so bring the
    # waveform to that rate first; otherwise the excerpt would start at the
    # wrong time and cover the wrong duration.
    if sr != sample_rate:
        audio_tensor = torchaudio.functional.resample(audio_tensor, sr, sample_rate)

    start_index = int(sample_rate * seconds_start)
    audio_tensor = audio_tensor[:, start_index:start_index + target_length]

    missing = target_length - audio_tensor.shape[1]
    if missing > 0:
        # NOTE(review): zeros are inserted *before* the audio (left padding),
        # as in the original code -- confirm this is intended rather than
        # trailing silence.
        audio_tensor = torch.nn.functional.pad(audio_tensor, (missing, 0))

    return audio_tensor