File size: 3,650 Bytes
b4bbb92
 
 
 
 
 
 
 
2ba71a4
 
b4bbb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from pathlib import Path
import os
from typing import Callable
import tempfile

import numpy as np
import soundfile as sf
from moviepy import VideoFileClip, AudioFileClip
# from moviepy.audio.AudioClip import AudioArrayClip
# from moviepy.audio.fx import AudioLoop
import torch
import torchvision


def merge_audio_video(
    audio: str | Path | np.ndarray,
    video_path: str | Path,
    target_path: str | Path,
    backend: str = "moviepy",
    logging: bool = False,
    audio_fps: int | None = None
):
    """
    Merge audio and video into a single file.

    Args:
        audio_path (str | Path): Path to the audio file.
        video_path (str | Path): Path to the video file.
        target_path (str | Path): Path to the target file.
        backend (str, optional): The backend to use for merging. Defaults to "moviepy".
    """
    assert backend in [
        "moviepy", "ffmpeg"
    ], "Backend should be moviepy or ffmpeg"
    if backend == "moviepy":
        video = VideoFileClip(video_path.__str__())
        video = video.without_audio()
        if isinstance(audio, np.ndarray):
            assert audio_fps is not None
            # write to a temp file, then use AudioFileClip to load
            with tempfile.NamedTemporaryFile(
                suffix=".wav", delete=False
            ) as tmp_wav:
                sf.write(tmp_wav.name, audio, samplerate=audio_fps)
            audio = AudioFileClip(tmp_wav.name)
        else:
            audio = AudioFileClip(audio.__str__())
            tmp_wav = None

        video = video.with_audio(audio)

        target_path = Path(target_path)
        video.write_videofile(
            target_path,
            logger=None if not logging else "bar",
            threads=8,
            preset="ultrafast",
            ffmpeg_params=["-crf", "23"]
        )
        if tmp_wav:
            os.remove(tmp_wav.name)
    else:
        logging_arg = "" if logging else "-loglevel quiet"
        command = f"ffmpeg {logging_arg} -i '{video_path.__str__()}' -i '{audio.__str__()}' -c:v copy " \
                  f"-c:a copy -map 0:v:0 -map 1:a:0 '{target_path.__str__()}'"
        os.system(command)


def read_video_frames(
    video_path: str,
    duration: float | None = 10.0,
    fps: int = 10,
    video_size: tuple[int] = (256, 256),
    resize_transform: Callable | None = None,
):
    try:
        video, _, meta = torchvision.io.read_video(
            str(video_path), start_pts=0, end_pts=duration, pts_unit='sec'
        )
        video_duration = video.shape[0] / meta["video_fps"]

        if duration and video_duration < duration:
            num_frames, height, width, channels = video.shape
            padding_length = int(duration * meta["video_fps"]) - num_frames
            padding = torch.zeros((padding_length, height, width, channels),
                                  dtype=video.dtype)
            video = torch.cat([video, padding], dim=0)
            target_length = int(duration * fps)
        else:
            target_length = int(video_duration * fps)

        indices = torch.linspace(0, video.shape[0] - 1,
                                 steps=target_length).long()
        video = video[indices]
        video = video.permute(0, 3, 1, 2)  # [T, C, H, W]
        if resize_transform is None:
            resize_transform = torchvision.transforms.Resize(video_size)
        video = resize_transform(video)
        return video
    except Exception as e:
        print(f"error reading video {video_path}: {e}")
        assert duration is not None
        target_length = int(duration * fps)
        return torch.zeros(target_length, 3, *video_size)