| from pathlib import Path
|
| from typing import Union
|
|
|
| import torch
|
| from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
|
|
|
|
|
class VideoJoiner:
    """Mux audio tensors into videos looked up by id under a source directory.

    Source clips are expected at ``<src_root>/<video_id>.mp4`` and results are
    written to ``<output_root>/<output_name>.mp4``.
    """

    def __init__(self, src_root: Union[str, Path], output_root: Union[str, Path], sample_rate: int,
                 duration_seconds: float):
        """Store the configuration and make sure the output directory exists.

        Args:
            src_root: Directory holding the source ``.mp4`` files.
            output_root: Directory that receives the merged files; it is
                created (including missing parents) if it does not exist.
            sample_rate: Sample rate forwarded to the audio encoder on every
                ``join`` call.
            duration_seconds: Clip duration forwarded on every ``join`` call.
        """
        source_dir = Path(src_root)
        target_dir = Path(output_root)

        self.src_root = source_dir
        self.output_root = target_dir
        self.sample_rate = sample_rate
        self.duration_seconds = duration_seconds

        # An already-existing directory is fine; only create what is missing.
        target_dir.mkdir(parents=True, exist_ok=True)

    def join(self, video_id: str, output_name: str, audio: torch.Tensor):
        """Merge ``audio`` into the clip ``video_id`` and save it as ``output_name``.

        Args:
            video_id: Stem of the source file inside ``src_root``.
            output_name: Stem of the file to write inside ``output_root``.
            audio: Waveform tensor handed unchanged to
                ``merge_audio_into_video``.
        """
        source_file = self.src_root.joinpath(f'{video_id}.mp4')
        target_file = self.output_root.joinpath(f'{output_name}.mp4')
        merge_audio_into_video(
            video_path=source_file,
            output_path=target_file,
            audio=audio,
            sample_rate=self.sample_rate,
            duration_seconds=self.duration_seconds,
        )
|
|
|
|
|
def merge_audio_into_video(video_path: Union[str, Path], output_path: Union[str, Path],
                           audio: torch.Tensor, sample_rate: int, duration_seconds: float,
                           frame_rate: float = 24) -> None:
    """Write the leading frames of a video, muxed with ``audio``, to a new file.

    One chunk of up to ``int(frame_rate * duration_seconds)`` RGB frames is
    decoded from ``video_path`` at ``frame_rate`` and re-encoded with libx264
    (yuv420p). ``audio`` is encoded with libmp3lame as output stream 0; the
    video becomes output stream 1.

    Args:
        video_path: Source video to read frames from.
        output_path: Destination media file.
        audio: 2-D waveform tensor. Its last dimension is used as the channel
            count (presumably laid out as ``(num_samples, num_channels)`` --
            confirm against callers). It is converted to float32 before
            encoding.
        sample_rate: Sample rate of ``audio`` in Hz.
        duration_seconds: Length of video to keep, in seconds.
        frame_rate: Frame rate requested from the decoder and used for the
            encoded video. Defaults to 24, the previously hard-coded value.

    Raises:
        ValueError: If ``audio`` is not 2-D, or if
            ``int(frame_rate * duration_seconds)`` is not positive.
        RuntimeError: If no video frames could be decoded from ``video_path``.
    """
    # Validate cheap preconditions before touching any media file so callers
    # get a clear error instead of an opaque failure from the codec layer.
    if audio.dim() != 2:
        raise ValueError(
            f'audio must be a 2-D tensor, got shape {tuple(audio.shape)}')

    frames_per_chunk = int(frame_rate * duration_seconds)
    if frames_per_chunk <= 0:
        raise ValueError(
            'frame_rate * duration_seconds must cover at least one frame, '
            f'got frame_rate={frame_rate}, duration_seconds={duration_seconds}')

    # Decode the requested span of the source video as a single chunk.
    reader = StreamingMediaDecoder(video_path)
    reader.add_basic_video_stream(
        frames_per_chunk=frames_per_chunk,
        format="rgb24",
        frame_rate=frame_rate,
    )

    reader.fill_buffer()
    video_chunk = reader.pop_chunks()[0]
    if video_chunk is None:
        # pop_chunks() yields None for a stream that produced no frames.
        raise RuntimeError(f'No video frames could be decoded from {video_path}')
    _, _, height, width = video_chunk.shape

    # Stream indices follow registration order: audio is 0, video is 1.
    writer = StreamingMediaEncoder(output_path)
    writer.add_audio_stream(
        sample_rate=sample_rate,
        num_channels=audio.shape[-1],
        encoder="libmp3lame",
    )
    writer.add_video_stream(frame_rate=frame_rate,
                            width=width,
                            height=height,
                            format="rgb24",
                            encoder="libx264",
                            encoder_format="yuv420p")

    with writer.open():
        writer.write_audio_chunk(0, audio.float())
        writer.write_video_chunk(1, video_chunk)
|
|
|
|
|
if __name__ == '__main__':
    # Usage example: mux four seconds of single-channel Gaussian noise at
    # 16 kHz into the video given as the first command-line argument and write
    # the result to the path given as the second.
    import sys

    demo_sample_rate = 16000
    demo_seconds = 4

    audio = torch.randn(demo_sample_rate * demo_seconds, 1)
    merge_audio_into_video(sys.argv[1], sys.argv[2], audio, demo_sample_rate, demo_seconds)
|
|
|