# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fractions import Fraction
from typing import Optional

import torch

from ...utils import is_av_available


_CAN_USE_AV = is_av_available()
if _CAN_USE_AV:
    import av
else:
    raise ImportError(
        "PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
    )


def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Prepare the audio stream for writing.

    Adds a stereo AAC stream to ``container``. Its codec sample rate is ``audio_sample_rate`` and its time base is
    ``1 / audio_sample_rate``, so PTS values counted in samples map directly onto the stream's time base.

    Args:
        container: Output container opened in write mode.
        audio_sample_rate: Sample rate (Hz) of the audio that will be written.

    Returns:
        The newly added audio stream.
    """
    audio_stream = container.add_stream("aac", rate=audio_sample_rate)
    audio_stream.codec_context.sample_rate = audio_sample_rate
    audio_stream.codec_context.layout = "stereo"
    audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
    return audio_stream


def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """
    Convert ``frame_in`` to the encoder's format/layout/rate, encode it, and mux the resulting packets.

    Both the resampler and the audio encoder are flushed before returning. This function must therefore be called
    exactly once per stream, with all of the audio in ``frame_in``.

    Args:
        container: Output container the encoded packets are muxed into.
        audio_stream: Audio stream the frames are encoded with.
        frame_in: A single audio frame holding the complete audio track.
    """
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*.
    target_format = cc.format or "fltp"  # AAC -> usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    # Resample the input, then flush (``resample(None)``) so samples still buffered inside the resampler are not
    # silently dropped from the end of the track.
    resampled_frames = list(audio_resampler.resample(frame_in))
    resampled_frames.extend(audio_resampler.resample(None))

    audio_next_pts = 0
    for rframe in resampled_frames:
        if rframe.pts is None:
            # PTS is counted in samples, matching the 1 / sample_rate stream time base.
            rframe.pts = audio_next_pts
        audio_next_pts += rframe.samples
        # Resampled frames are at ``target_rate``; labelling them with the input rate would be wrong whenever the
        # two differ.
        rframe.sample_rate = target_rate
        container.mux(audio_stream.encode(rframe))

    # Flush the audio encoder.
    for packet in audio_stream.encode():
        container.mux(packet)


def _write_audio(
    container: av.container.Container,
    audio_stream: av.audio.AudioStream,
    samples: torch.Tensor,
    audio_sample_rate: int,
) -> None:
    """
    Encode an audio waveform into ``audio_stream`` and mux it into ``container``.

    Accepted layouts (``N`` = number of samples):
      * ``(N,)``, ``(N, 1)`` or ``(1, N)``: mono, duplicated onto both stereo channels;
      * ``(N, 2)``: channels-last stereo;
      * ``(2, N)``: channels-first stereo, transposed to channels-last.

    Dtype handling:
      * ``int16`` input is used as-is.
      * Any other dtype is treated as floating-point audio nominally in ``[-1, 1]``: it is clipped to that range
        and scaled to ``int16``. NOTE(review): integer PCM other than int16 (e.g. int32) is not rescaled and will
        be clipped to ``[-1, 1]`` — confirm callers never pass it.

    Args:
        container: Output container the encoded packets are muxed into.
        audio_stream: Audio stream the samples are encoded with.
        samples: Audio waveform (see layouts above).
        audio_sample_rate: Sample rate of ``samples`` in Hz.

    Raises:
        ValueError: If ``samples`` is not 1-D/2-D or cannot be interpreted as mono/stereo.
    """
    if samples.ndim == 1:
        samples = samples[:, None]
    if samples.ndim != 2:
        # Higher-rank input would otherwise be flattened into garbage audio below.
        raise ValueError(f"Expected 1-D or 2-D audio samples; got shape {tuple(samples.shape)}.")

    # Channels-first (C, N) -> channels-last (N, C). Channels-last stereo (shape[1] == 2) takes precedence when the
    # layout is ambiguous.
    if samples.shape[1] != 2 and samples.shape[0] in (1, 2):
        samples = samples.T
    # Mono -> stereo by duplicating the single channel.
    if samples.shape[1] == 1:
        samples = samples.repeat(1, 2)
    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

    # Convert to int16 packed for ingestion; the resampler converts to the encoder's format.
    if samples.dtype != torch.int16:
        # Upcast first: scaling float16/bfloat16 by 32767 in its own precision discards most of the 16-bit range.
        samples = torch.clip(samples.to(torch.float32), -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    # Packed "s16" expects a single plane of interleaved L/R samples; row-major (N, 2) flattened gives exactly that.
    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)


def encode_video(
    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
    """
    Encode ``video`` (and optionally ``audio``) into a media file at ``output_path``.

    Video is encoded with libx264 as ``yuv420p``. Audio, when given, is encoded as stereo AAC after the video
    encoder has been flushed.

    Args:
        video: Frames of shape ``(num_frames, height, width, channels)``. Each frame is handed to PyAV as ``rgb24``,
            so this presumably must be ``uint8`` with 3 channels — TODO confirm against callers.
        fps: Frame rate of the video stream (passed through ``int()``).
        audio: Optional audio waveform; mono or stereo, channels-first or channels-last.
        audio_sample_rate: Sample rate of ``audio`` in Hz. Required when ``audio`` is given.
        output_path: Destination path, passed to ``av.open`` in write mode.

    Raises:
        ValueError: If ``audio`` is given without ``audio_sample_rate``, or if the audio layout is unsupported.
    """
    # Validate arguments before any output is opened.
    if audio is not None and audio_sample_rate is None:
        raise ValueError("audio_sample_rate is required when audio is provided")

    video_np = video.cpu().numpy()
    _, height, width, _ = video_np.shape

    # The context manager closes (and finalizes) the container even if encoding fails part-way.
    with av.open(output_path, mode="w") as container:
        stream = container.add_stream("libx264", rate=int(fps))
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        audio_stream = None
        if audio is not None:
            audio_stream = _prepare_audio_stream(container, audio_sample_rate)

        for frame_array in video_np:
            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush the video encoder.
        for packet in stream.encode():
            container.mux(packet)

        if audio_stream is not None:
            _write_audio(container, audio_stream, audio, audio_sample_rate)