Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """A library for Causal Video Tokenizer inference.""" | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from cosmos_predict1.tokenizer.inference.utils import ( | |
| load_decoder_model, | |
| load_encoder_model, | |
| load_model, | |
| numpy2tensor, | |
| pad_video_batch, | |
| tensor2numpy, | |
| unpad_video_batch, | |
| ) | |
class CausalVideoTokenizer(torch.nn.Module):
    """Inference wrapper around causal video tokenizer checkpoints.

    The wrapper can hold a full autoencoder, a standalone encoder, and/or a
    standalone decoder. Each one is loaded only when its checkpoint path is
    given, and is cast to ``dtype`` after loading.
    """

    def __init__(
        self,
        checkpoint: str = None,
        checkpoint_enc: str = None,
        checkpoint_dec: str = None,
        tokenizer_config: dict[str, Any] = None,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ) -> None:
        """Loads the requested tokenizer model(s).

        Args:
            checkpoint: Path to a full autoencoder checkpoint, or None to skip it.
            checkpoint_enc: Path to an encoder-only checkpoint, or None to skip it.
            checkpoint_dec: Path to a decoder-only checkpoint, or None to skip it.
            tokenizer_config: Configuration forwarded unchanged to the model loaders.
            device: Device passed to the loaders and used for input tensors in `forward`.
            dtype: Name of a `torch` dtype attribute (e.g. "bfloat16"); loaded models
                and the input tensors built in `forward` are cast to it.
        """
        super().__init__()
        self._device = device
        self._dtype = getattr(torch, dtype)
        self._full_model = (
            load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
        )
        self._enc_model = (
            load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
            if checkpoint_enc is not None
            else None
        )
        self._dec_model = (
            load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
            if checkpoint_dec is not None
            else None
        )

    def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Reconstructs a batch of video tensors after embedding into a latent.

        Uses the full autoencoder when one was loaded; otherwise chains the
        standalone encoder and decoder.

        Args:
            input_tensor: The input video, Bx3xTxHxW layout, range [-1..1].
        Returns:
            The reconstructed video, layout Bx3xTxHxW, range [-1..1].
        """
        if self._full_model is not None:
            output_tensor = self._full_model(input_tensor)
            # The full model may return extra outputs; the reconstruction comes first.
            return output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
        encoded = self.encode(input_tensor)
        # `encode` returns a bare tensor when the encoder does, and a sequence
        # otherwise. Indexing a bare tensor with [0] would pick the first batch
        # element instead of the latent, so only unpack non-tensor results.
        latent = encoded if isinstance(encoded, torch.Tensor) else encoded[0]
        return self.decode(latent)

    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
        """Encodes a video tensor into a CausalVideo latent or code.

        Args:
            input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1].
        Returns:
            If the encoder model returns a single tensor, that tensor is returned
            as-is. Otherwise the encoder's outputs minus the last one are returned.
            For causal continuous video (CV) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(t)x(h)x(w), where the compression
                rate is (T/t x H/h x W/w), and channel dimension of 16.
            For causal discrete video (DV) tokenizer, the tuple contains:
              1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which
                is formed by FSQ levels of (8,8,8,5,5,5).
              2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate
                is again (T/t x H/h x W/w), and channel dimension of 6.
        """
        assert input_tensor.ndim == 5, "input video should be of 5D."
        output_latent = self._enc_model(input_tensor)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        # NOTE(review): the encoder's last output is dropped here; presumably it is
        # auxiliary (not a latent/code) — confirm against the encoder model.
        return output_latent[:-1]

    def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
        """Decodes a CausalVideo latent (or discrete indices) back into video.

        Args:
            input_latent: The continuous latent Bx16xtxhxw for CV,
                or the discrete indices Bxtxhxw for DV.
        Returns:
            The reconstructed tensor in range [-1..1], layout Bx3xTxHxW, where the
            output size follows the tokenizer's compression rate (e.g.
            [B,3,1+(t-1)*8,h*16,w*16] for an 8x16x16 tokenizer).
        """
        assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete."
        return self._dec_model(input_latent)

    def forward(
        self,
        video: np.ndarray,
        temporal_window: int = 17,
    ) -> np.ndarray:
        """Reconstructs video using a pre-trained CausalTokenizer autoencoder.

        Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer
        in a sliding manner with a `temporal_window` size: the frames are split into
        consecutive, non-overlapping windows (the last may be shorter), each window is
        reconstructed independently, and the results are concatenated along time.

        Args:
            video: The input video BxTxHxWx3 layout, range [0..255].
            temporal_window: The length of the temporal window to process, default=17.
        Returns:
            The reconstructed video in range [0..255], layout BxTxHxWx3.
        Raises:
            ValueError: If `temporal_window` is not positive or `video` has no frames.
        """
        assert video.ndim == 5, "input video should be of 5D."
        if temporal_window < 1:
            raise ValueError(f"temporal_window must be a positive integer, got {temporal_window}.")
        num_frames = video.shape[1]  # can be of any length.
        if num_frames == 0:
            raise ValueError("input video must contain at least one frame.")
        output_video_list = []
        # Pure inference: results end up in numpy, so autograd history is never needed.
        with torch.no_grad():
            for start in tqdm(range(0, num_frames, temporal_window)):
                # Input video for the current window (numpy slicing clamps the last one).
                input_video = video[:, start : start + temporal_window, ...]
                # Spatio-temporally pad input_video so it's evenly divisible.
                padded_input_video, crop_region = pad_video_batch(input_video)
                input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device)
                output_tensor = self.autoencode(input_tensor)
                padded_output_video = tensor2numpy(output_tensor)
                output_video = unpad_video_batch(padded_output_video, crop_region)
                output_video_list.append(output_video)
        return np.concatenate(output_video_list, axis=1)