Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The Lightricks team and The HuggingFace Team. | |
| # All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from fractions import Fraction | |
| from typing import Optional | |
| import torch | |
| from ...utils import is_av_available | |
| _CAN_USE_AV = is_av_available() | |
| if _CAN_USE_AV: | |
| import av | |
| else: | |
| raise ImportError( | |
| "PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`" | |
| ) | |
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
    """
    Add an AAC audio stream to `container` and configure its encoder.

    Args:
        container: Output container the new stream is added to.
        audio_sample_rate: Sample rate used for the stream rate, the encoder sample rate and the encoder
            time base (`1 / audio_sample_rate`).

    Returns:
        The newly added audio stream, with its encoder layout set to stereo.
    """
    stream = container.add_stream("aac", rate=audio_sample_rate)

    encoder = stream.codec_context
    encoder.sample_rate = audio_sample_rate
    encoder.layout = "stereo"
    # One time-base tick per audio sample.
    encoder.time_base = Fraction(1, audio_sample_rate)

    return stream
def _resample_audio(
    container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
    """
    Convert `frame_in` to the encoder's sample format, encode it and mux the resulting packets.

    The audio encoder is flushed before returning, so this is intended to be called once per stream with a frame
    that holds the complete audio.

    Args:
        container: Output container the encoded packets are muxed into.
        audio_stream: Audio stream whose codec context provides the target format, layout and sample rate.
        frame_in: Frame holding the audio to write. Its `sample_rate` is used as a fallback target rate when the
            codec context does not define one.
    """
    cc = audio_stream.codec_context

    # Use the encoder's format/layout/rate as the *target*, with fallbacks for unset values.
    target_format = cc.format or "fltp"  # AAC → usually fltp
    target_layout = cc.layout or "stereo"
    target_rate = cc.sample_rate or frame_in.sample_rate

    audio_resampler = av.audio.resampler.AudioResampler(
        format=target_format,
        layout=target_layout,
        rate=target_rate,
    )

    resampled_frames = list(audio_resampler.resample(frame_in))
    # Passing `None` drains samples still buffered inside the resampler; without this the tail of the audio is
    # dropped whenever an actual rate conversion takes place. A pass-through resampler may hand back `None`
    # itself, so those entries are filtered out.
    resampled_frames.extend(frame for frame in audio_resampler.resample(None) if frame is not None)

    audio_next_pts = 0
    for rframe in resampled_frames:
        # Frames without a timestamp are placed right after the samples written so far.
        if rframe.pts is None:
            rframe.pts = audio_next_pts
        audio_next_pts += rframe.samples
        # The resampler output is at `target_rate`; labelling it with the input rate would make the encoder
        # treat already-converted audio as if it still needed conversion when the two rates differ.
        rframe.sample_rate = target_rate
        container.mux(audio_stream.encode(rframe))

    # flush audio encoder
    for packet in audio_stream.encode():
        container.mux(packet)
def _write_audio(
    container: av.container.Container,
    audio_stream: av.audio.AudioStream,
    samples: torch.Tensor,
    audio_sample_rate: int,
) -> None:
    """
    Pack a stereo waveform into an interleaved 16-bit frame and write it to `audio_stream`.

    Args:
        container: Output container the encoded audio is muxed into.
        audio_stream: Audio stream that receives the samples.
        samples: Stereo waveform shaped `(num_samples, 2)` or `(2, num_samples)`; the latter is transposed.
            `torch.int16` data is used as is, any other dtype is clipped to `[-1.0, 1.0]` and scaled to the
            int16 range.
        audio_sample_rate: Sample rate assigned to the frame built from `samples`.

    Raises:
        ValueError: If `samples` is not 1-D or 2-D, or cannot be arranged as `(num_samples, 2)`.
    """
    # Reject 0-D and >2-D input up front: a tensor such as `(1, 2, N)` would otherwise satisfy the channel check
    # below and then be flattened in the wrong order, silently producing scrambled audio.
    if samples.ndim not in (1, 2):
        raise ValueError(f"Expected samples with 1 or 2 dimensions; got shape {samples.shape}.")

    if samples.ndim == 1:
        samples = samples[:, None]

    # Accept channel-first input by moving the channel axis last.
    if samples.shape[1] != 2 and samples.shape[0] == 2:
        samples = samples.T

    if samples.shape[1] != 2:
        raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")

    # Convert to int16 packed for ingestion; resampler converts to encoder fmt.
    if samples.dtype != torch.int16:
        samples = torch.clip(samples, -1.0, 1.0)
        samples = (samples * 32767.0).to(torch.int16)

    # Flattening the contiguous (num_samples, 2) tensor row by row yields the interleaved L/R order that the
    # packed "s16" format expects.
    frame_in = av.AudioFrame.from_ndarray(
        samples.contiguous().reshape(1, -1).cpu().numpy(),
        format="s16",
        layout="stereo",
    )
    frame_in.sample_rate = audio_sample_rate

    _resample_audio(container, audio_stream, frame_in)
def encode_video(
    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
    """
    Encode `video` (and optionally `audio`) into a file at `output_path`.

    The video is encoded with libx264 in the `yuv420p` pixel format; audio, when given, is written to an
    additional stream after all video packets.

    Args:
        video: Tensor unpacked as `(num_frames, height, width, channels)`. Each frame is handed to PyAV as
            `rgb24` (presumably uint8 data with 3 channels; confirm against the caller).
        fps: Frame rate of the video stream; converted with `int(fps)`.
        audio: Optional waveform to write alongside the video.
        audio_sample_rate: Sample rate of `audio`. Required when `audio` is provided.
        output_path: Path of the file to write.

    Raises:
        ValueError: If `audio` is provided without `audio_sample_rate`.
    """
    # Validate before touching the filesystem so that a bad call never leaves an empty or truncated file behind.
    if audio is not None and audio_sample_rate is None:
        raise ValueError("audio_sample_rate is required when audio is provided")

    video_np = video.cpu().numpy()
    _, height, width, _ = video_np.shape

    # The context manager closes the container even when encoding fails midway, so the file handle is not leaked.
    with av.open(output_path, mode="w") as container:
        stream = container.add_stream("libx264", rate=int(fps))
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        # The audio stream is added before any packet is muxed, but its data is written after the video.
        audio_stream = None
        if audio is not None:
            audio_stream = _prepare_audio_stream(container, audio_sample_rate)

        for frame_array in video_np:
            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush encoder
        for packet in stream.encode():
            container.mux(packet)

        if audio_stream is not None:
            _write_audio(container, audio_stream, audio, audio_sample_rate)