"""Utilities for sampling video frames at multiple frame rates and for
writing video back to disk with an audio track (PyAV / ffmpeg-based)."""

from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Optional
import os
import subprocess  # kept from original file; may be relied on by importers
import tempfile

import av
import cv2
import numpy as np
import torch
from av import AudioFrame


@dataclass
class VideoInfo:
    """Frames extracted from a video plus basic stream metadata.

    Attributes:
        duration_sec: clip duration in seconds.
        fps: source frame rate.
        clip_frames: frames sampled for the CLIP branch.
        sync_frames: frames sampled for the sync branch.
        all_frames: every decoded RGB frame, or None when not requested.
    """
    duration_sec: float
    fps: Fraction
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    all_frames: Optional[list[np.ndarray]]

    @property
    def height(self) -> int:
        # Only valid when all_frames was populated (need_all_frames=True).
        return self.all_frames[0].shape[0]

    @property
    def width(self) -> int:
        # Only valid when all_frames was populated (need_all_frames=True).
        return self.all_frames[0].shape[1]

    @classmethod
    def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
                        fps: Fraction) -> 'VideoInfo':
        """Build a still-image "video" by repeating one frame for the whole duration."""
        num_frames = int(duration_sec * fps)
        all_frames = [image_info.original_frame] * num_frames
        return cls(duration_sec=duration_sec,
                   fps=fps,
                   clip_frames=image_info.clip_frames,
                   sync_frames=image_info.sync_frames,
                   all_frames=all_frames)


@dataclass
class ImageInfo:
    """Frames extracted from a single still image."""
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    original_frame: Optional[np.ndarray]

    @property
    def height(self) -> int:
        return self.original_frame.shape[0]

    @property
    def width(self) -> int:
        return self.original_frame.shape[1]


def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    """Decode [start_sec, end_sec] of a video and resample it at several frame rates.

    Args:
        video_path: path to the input video.
        list_of_fps: target sampling rates; one stacked frame array is returned per rate.
        start_sec / end_sec: time window to decode.
        need_all_frames: when True, also return every decoded RGB frame.

    Returns:
        (output_frames, all_frames, fps) where output_frames[i] is an
        (N, H, W, 3) uint8 array sampled at list_of_fps[i], all_frames is a
        list of RGB frames (empty unless need_all_frames), and fps is the
        source frame rate as a Fraction.

    Raises:
        RuntimeError: if the video cannot be opened or its fps read.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Could not open {video_path}")

    fps_val = cap.get(cv2.CAP_PROP_FPS)
    if not fps_val or fps_val <= 0:
        cap.release()
        raise RuntimeError(f"Could not read fps from {video_path}")
    fps = Fraction(fps_val).limit_denominator()

    start_frame = int(start_sec * fps_val)
    end_frame = int(end_sec * fps_val)
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

    # One output list and one "next sample due at" clock per requested fps.
    output_frames = [[] for _ in list_of_fps]
    next_frame_time_for_each_fps = [start_sec for _ in list_of_fps]
    time_delta_for_each_fps = [1 / f for f in list_of_fps]
    all_frames = []

    frame_idx = start_frame
    while frame_idx <= end_frame:
        ok, frame_bgr = cap.read()
        if not ok:
            break
        frame_idx += 1
        # NOTE(review): frame_time is computed after the increment, so the
        # first decoded frame is stamped (start_frame + 1) / fps — looks like
        # a deliberate "end of frame" timestamp, but confirm against callers
        # before changing; preserved as-is.
        frame_time = frame_idx / fps_val  # seconds
        frame_rgb = None
        if need_all_frames:
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            all_frames.append(frame_rgb)
        for i, _ in enumerate(list_of_fps):
            # Emit (possibly duplicated) frames until this rate's clock
            # catches up with the decode position; converts BGR→RGB lazily.
            while frame_time >= next_frame_time_for_each_fps[i]:
                if frame_rgb is None:
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                output_frames[i].append(frame_rgb)
                next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
    cap.release()

    # np.stack raises on an empty list — i.e. if no frames landed in the
    # window; preserved original behavior (fail loudly).
    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps


def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    """Re-encode video_info.all_frames as H.264 and `audio` as AAC into output_path.

    Args:
        video_info: must have all_frames populated (used for frames and size).
        output_path: destination container file (format inferred from suffix).
        audio: mono waveform; assumes a shape AudioFrame.from_ndarray accepts
            for format='flt'/layout='mono' (i.e. (1, samples)) — TODO confirm.
        sampling_rate: audio sample rate in Hz.
    """
    container = av.open(str(output_path), 'w')
    output_video_stream = container.add_stream('h264', video_info.fps)
    # Fix: PyAV expects an int bit rate; the original assigned the float 10 * 1e6.
    output_video_stream.codec_context.bit_rate = 10_000_000  # 10 Mbps
    output_video_stream.width = video_info.width
    output_video_stream.height = video_info.height
    output_video_stream.pix_fmt = 'yuv420p'
    output_audio_stream = container.add_stream('aac', sampling_rate)

    try:
        # Encode video frames (from_ndarray defaults to rgb24 input).
        for image in video_info.all_frames:
            video_frame = av.VideoFrame.from_ndarray(image)
            for packet in output_video_stream.encode(video_frame):
                container.mux(packet)
        for packet in output_video_stream.encode():  # flush the encoder
            container.mux(packet)

        # Fix: detach().cpu() so CUDA / autograd tensors do not crash .numpy().
        audio_np = audio.detach().cpu().numpy().astype(np.float32)
        audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate
        for packet in output_audio_stream.encode(audio_frame):
            container.mux(packet)
        for packet in output_audio_stream.encode():  # flush the encoder
            container.mux(packet)
    finally:
        # Fix: close the container even if encoding raises, so a partial
        # file is at least finalized/released.
        container.close()


def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor,
                     sampling_rate: int):
    """Mux `audio` (written to a temporary wav) with the video stream of video_path.

    The video is copied without re-encoding by the project helper
    combine_video_with_audio_tracks.
    """
    from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
    import torchaudio

    # Create the temp path, then close the handle before torchaudio writes to
    # it (also makes this safe on platforms with exclusive open semantics).
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
        temp_path = Path(f.name)
    try:
        # torchaudio.save expects (channels, samples); promote 1-D mono input.
        torchaudio.save(str(temp_path),
                        audio.unsqueeze(0) if audio.dim() == 1 else audio,
                        sampling_rate)
        # Fix: os.makedirs(os.path.dirname(p)) raised FileNotFoundError when
        # output_path had no directory component ('' dirname); Path.parent
        # yields '.' in that case and mkdir is a no-op.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        combine_video_with_audio_tracks(video_path, [str(temp_path)], output_path)
    finally:
        # Fix: remove the temp wav even when saving/combining raises.
        temp_path.unlink(missing_ok=True)


def remux_with_audio_old(video_path: Path, audio: torch.Tensor, output_path: Path,
                         sampling_rate: int):
    """
    NOTE: I don't think we can get the exact video duration right without re-encoding
    so we are not using this
    but keeping it here for reference
    """
    video = av.open(video_path)
    output = av.open(output_path, 'w')
    input_video_stream = video.streams.video[0]
    output_video_stream = output.add_stream(template=input_video_stream)
    output_audio_stream = output.add_stream('aac', sampling_rate)

    for packet in video.demux(input_video_stream):
        # We need to skip the "flushing" packets that `demux` generates.
        if packet.dts is None:
            continue
        # We need to assign the packet to the new stream.
        packet.stream = output_video_stream
        output.mux(packet)

    # Fix: detach().cpu() so CUDA / autograd tensors do not crash .numpy();
    # also removed the unused `duration_sec` local the original computed.
    audio_np = audio.detach().cpu().numpy().astype(np.float32)
    audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
    audio_frame.sample_rate = sampling_rate
    for packet in output_audio_stream.encode(audio_frame):
        output.mux(packet)
    for packet in output_audio_stream.encode():  # flush the encoder
        output.mux(packet)

    video.close()
    # Fix: the original called output.close() twice; once is enough.
    output.close()