# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Utility functions for the inference libraries.""" | |
| import os | |
| from glob import glob | |
| from typing import Any | |
| import mediapy as media | |
| import numpy as np | |
| import torch | |
| from cosmos_predict1.tokenizer.networks import TokenizerModels | |
| _DTYPE, _DEVICE = torch.bfloat16, "cuda" | |
| _UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) | |
| _SPATIAL_ALIGN = 16 | |
| _TEMPORAL_ALIGN = 8 | |


def load_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads a full tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: The tokenizer config dict; if None, the JIT model is
            returned directly.
        device: The device to load the model onto, default=cuda.

    Returns:
        The model loaded onto `device` and set to eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    full_model.load_state_dict(ckpts.state_dict(), strict=True)
    return full_model.eval().to(device)
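

# A minimal usage sketch (illustrative; the checkpoint path and config keys are
# placeholders, not defined by this module):
#
#   model = load_model("checkpoints/tokenizer.jit")  # JIT checkpoint only
#   model = load_model("checkpoints/tokenizer.jit", tokenizer_config={...})
#
# When a config is given, its "name" key must match an entry of
# `TokenizerModels`, and the whole dict is forwarded as keyword arguments to
# that network's constructor (see `_load_pytorch_model` below).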


def load_encoder_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the encoder of a tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: The tokenizer config dict; if None, the JIT model is
            returned directly.
        device: The device to load the model onto, default=cuda.

    Returns:
        The encoder loaded onto `device` and set to eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    encoder_model = full_model.encoder_jit()
    encoder_model.load_state_dict(ckpts.state_dict(), strict=True)
    return encoder_model.eval().to(device)


def load_decoder_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the decoder of a tokenizer model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: The tokenizer config dict; if None, the JIT model is
            returned directly.
        device: The device to load the model onto, default=cuda.

    Returns:
        The decoder loaded onto `device` and set to eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    decoder_model = full_model.decoder_jit()
    decoder_model.load_state_dict(ckpts.state_dict(), strict=True)
    return decoder_model.eval().to(device)


def _load_pytorch_model(
    jit_filepath: str | None = None,
    tokenizer_config: dict[str, Any] | None = None,
    device: str = "cuda",
) -> tuple[torch.nn.Module, torch.jit.ScriptModule]:
    """Builds a native PyTorch model and loads its JIT checkpoint.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: The tokenizer config dict; its `name` key selects
            the network from `TokenizerModels`.
        device: The device to load the checkpoint onto, default=cuda.

    Returns:
        The freshly constructed model and the loaded JIT checkpoint.
    """
    tokenizer_name = tokenizer_config["name"]
    model = TokenizerModels[tokenizer_name].value(**tokenizer_config)
    ckpts = torch.jit.load(jit_filepath, map_location=device)
    return model, ckpts
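

# Note: in the config-based path, the .jit checkpoint serves only as a weight
# container. The loaders above rebuild the network natively via
# `TokenizerModels` and copy the checkpoint's state_dict into it, which is what
# lets `load_encoder_model` / `load_decoder_model` extract the `encoder_jit()`
# / `decoder_jit()` sub-modules.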


def load_jit_model(jit_filepath: str | None = None, device: str = "cuda") -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.

    Returns:
        The JIT compiled model loaded onto `device` and set to eval mode.
    """
    model = torch.jit.load(jit_filepath, map_location=device)
    return model.eval().to(device)


def save_jit_model(
    model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule | None = None,
    jit_filepath: str | None = None,
) -> None:
    """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file.

    Args:
        model: JIT compiled model loaded onto `config.checkpoint.jit.device`.
        jit_filepath: The filepath to the JIT-compiled model.
    """
    torch.jit.save(model, jit_filepath)


def get_filepaths(input_pattern: str) -> list[str]:
    """Returns a sorted list of unique filepaths matching a glob pattern."""
    filepaths = glob(str(input_pattern))
    return sorted(set(filepaths))


def get_output_filepath(filepath: str, output_dir: str | None = None) -> str:
    """Returns the output filepath, defaulting to a sibling `reconstructions/` dir."""
    output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions"
    output_filepath = f"{output_dir}/{os.path.basename(filepath)}"
    os.makedirs(output_dir, exist_ok=True)
    return output_filepath


def read_image(filepath: str) -> np.ndarray:
    """Reads an image from a filepath.

    Args:
        filepath: The filepath to the image.

    Returns:
        The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype.
    """
    image = media.read_image(filepath)
    # Convert a grayscale image to RGB, since the tokenizers
    # always assume a 3-channel RGB image.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    # Drop the alpha channel to convert RGBA to RGB.
    if image.shape[-1] == 4:
        image = image[..., :3]
    return image


def read_video(filepath: str) -> np.ndarray:
    """Reads a video from a filepath.

    Args:
        filepath: The filepath to the video.

    Returns:
        The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype.
    """
    video = media.read_video(filepath)
    # Convert grayscale frames to RGB, since the tokenizers
    # always assume a 3-channel video.
    if video.ndim == 3:
        video = np.stack([video] * 3, axis=-1)
    # Drop the alpha channel to convert RGBA to RGB.
    if video.shape[-1] == 4:
        video = video[..., :3]
    return video


def resize_image(image: np.ndarray, short_size: int | None = None) -> np.ndarray:
    """Resizes an image so that its short side equals `short_size`.

    The long side is scaled proportionally, then bumped to an even number if odd.

    Args:
        image: The image to resize, layout HxWxC, of any range.
        short_size: The target size of the short side; None is a no-op.

    Returns:
        The resized image.
    """
    if short_size is None:
        return image
    height, width = image.shape[-3:-1]
    if height <= width:
        height_new, width_new = short_size, int(width * short_size / height + 0.5)
        width_new = width_new if width_new % 2 == 0 else width_new + 1
    else:
        height_new, width_new = (
            int(height * short_size / width + 0.5),
            short_size,
        )
        height_new = height_new if height_new % 2 == 0 else height_new + 1
    return media.resize_image(image, shape=(height_new, width_new))
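

# Worked example (assumed shapes, for illustration): a 480x640 image with
# short_size=512 maps the short side to 512 and scales the long side to
# int(640 * 512 / 480 + 0.5) = 683, which is then bumped to the even 684,
# yielding a 512x684 result.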


def resize_video(video: np.ndarray, short_size: int | None = None) -> np.ndarray:
    """Resizes a video so that its short side equals `short_size`.

    The long side is scaled proportionally, then bumped to an even number if odd.

    Args:
        video: The video to resize, layout TxHxWxC, of any range.
        short_size: The target size of the short side; None is a no-op.

    Returns:
        The resized video.
    """
    if short_size is None:
        return video
    height, width = video.shape[-3:-1]
    if height <= width:
        height_new, width_new = short_size, int(width * short_size / height + 0.5)
        width_new = width_new if width_new % 2 == 0 else width_new + 1
    else:
        height_new, width_new = (
            int(height * short_size / width + 0.5),
            short_size,
        )
        height_new = height_new if height_new % 2 == 0 else height_new + 1
    return media.resize_video(video, shape=(height_new, width_new))


def write_image(filepath: str, image: np.ndarray):
    """Writes an image to a filepath."""
    return media.write_image(filepath, image)


def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None:
    """Writes a video to a filepath."""
    return media.write_video(filepath, video, fps=fps)


def numpy2tensor(
    input_image: np.ndarray,
    dtype: torch.dtype = _DTYPE,
    device: str = _DEVICE,
    range_min: int = -1,
) -> torch.Tensor:
    """Converts a uint8 array in range [0..255] to a `dtype` tensor in [-1..1].

    Args:
        input_image: A batch of images in range [0..255], BxHxWx3 layout.
        dtype: The target torch dtype, default=torch.bfloat16.
        device: The target device, default=cuda.
        range_min: -1 maps the output to [-1..1]; any other value leaves [0..1].

    Returns:
        A torch.Tensor of layout Bx3xHxW in range [-1..1] (or [0..1]), `dtype`.
    """
    ndim = input_image.ndim
    # Move the trailing channel axis to position 1, keeping the batch axis first.
    indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1]
    image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F
    if range_min == -1:
        image = 2.0 * image - 1.0
    return torch.from_numpy(image).to(dtype).to(device)
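

# The index shuffle above generalizes to videos: for a BxTxHxWxC array
# (ndim=5), `indices` becomes [4, 1, 2, 3], so the transpose is
# (0, 4, 1, 2, 3) and the result is laid out BxCxTxHxW; for BxHxWxC images
# it is (0, 3, 1, 2), i.e. BxCxHxW.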


def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts a tensor in [-1..1] to a uint8 image array in range [0..255].

    Args:
        input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1].
        range_min: -1 if the input is in [-1..1]; any other value assumes [0..1].

    Returns:
        A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
    if range_min == -1:
        input_tensor = (input_tensor.float() + 1.0) / 2.0
    ndim = input_tensor.ndim
    output_image = input_tensor.clamp(0, 1).cpu().numpy()
    # Move the channel axis back to the end, then quantize with rounding.
    output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)
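

# Round-trip sketch (illustrative): `tensor2numpy(numpy2tensor(batch))`
# approximately recovers a uint8 BxHxWx3 batch; it is not guaranteed to be
# bit-exact because the default bfloat16 dtype quantizes intermediate values.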


def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]:
    """Pads a batch of images to be divisible by `spatial_align`.

    Args:
        batch: The batch of images to pad, layout BxHxWx3, in any range.
        spatial_align: The spatial alignment to pad to.

    Returns:
        The padded batch and the crop region.
    """
    height, width = batch.shape[1:3]
    align = spatial_align
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0
    # [y1, x1, y2, x2] of the original content inside the padded batch.
    crop_region = [
        height_to_pad >> 1,
        width_to_pad >> 1,
        height + (height_to_pad >> 1),
        width + (width_to_pad >> 1),
    ]
    batch = np.pad(
        batch,
        (
            (0, 0),
            (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
            (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)),
            (0, 0),
        ),
        mode="constant",
    )
    return batch, crop_region
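

# Worked example (assumed shapes): a (1, 100, 112, 3) batch with the default
# spatial_align=16 needs 12 rows of padding (100 % 16 == 4) and none in width,
# split 6 above / 6 below; the result is (1, 112, 112, 3) with
# crop_region == [6, 0, 106, 112].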


def pad_video_batch(
    batch: np.ndarray,
    temporal_align: int = _TEMPORAL_ALIGN,
    spatial_align: int = _SPATIAL_ALIGN,
) -> tuple[np.ndarray, list[int]]:
    """Pads a batch of videos to satisfy `temporal_align` and `spatial_align`.

    Height and width are zero-padded to multiples of `spatial_align`; the
    frame axis is edge-padded (replicating the boundary frames) so that
    `num_frames - 1` is a multiple of `temporal_align`, which handles
    causality better than zero padding.

    Args:
        batch: The batch of videos to pad, layout BxFxHxWx3, in any range.
        temporal_align: The temporal alignment to pad to.
        spatial_align: The spatial alignment to pad to.

    Returns:
        The padded batch and the crop region.
    """
    num_frames, height, width = batch.shape[-4:-1]
    align = spatial_align
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0
    align = temporal_align
    frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0
    # [f1, y1, x1, f2, y2, x2] of the original content inside the padded batch.
    crop_region = [
        frames_to_pad >> 1,
        height_to_pad >> 1,
        width_to_pad >> 1,
        num_frames + (frames_to_pad >> 1),
        height + (height_to_pad >> 1),
        width + (width_to_pad >> 1),
    ]
    batch = np.pad(
        batch,
        (
            (0, 0),
            (0, 0),
            (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
            (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)),
            (0, 0),
        ),
        mode="constant",
    )
    batch = np.pad(
        batch,
        (
            (0, 0),
            (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)),
            (0, 0),
            (0, 0),
            (0, 0),
        ),
        mode="edge",
    )
    return batch, crop_region
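

# The temporal rule suits causal tokenizers that consume 1 + n * temporal_align
# frames: e.g. 30 input frames with temporal_align=8 get 3 extra frames
# ((30 - 1) % 8 == 5), one replicated at the front and two at the back,
# for 33 = 1 + 4 * 8 total.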


def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Unpads video with `crop_region`.

    Args:
        batch: A batch of numpy videos, layout BxFxHxWxC.
        crop_region: [f1, y1, x1, f2, y2, x2] first, top, left, last, bot, right crop indices.

    Returns:
        np.ndarray: Cropped numpy video, layout BxFxHxWxC.
    """
    assert len(crop_region) == 6, "crop_region must have length 6."
    f1, y1, x1, f2, y2, x2 = crop_region
    return batch[..., f1:f2, y1:y2, x1:x2, :]


def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Unpads image with `crop_region`.

    Args:
        batch: A batch of numpy images, layout BxHxWxC.
        crop_region: [y1, x1, y2, x2] top, left, bot, right crop indices.

    Returns:
        np.ndarray: Cropped numpy image, layout BxHxWxC.
    """
    assert len(crop_region) == 4, "crop_region must have length 4."
    y1, x1, y2, x2 = crop_region
    return batch[..., y1:y2, x1:x2, :]
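

# End-to-end sketch (illustrative; `model` and the filepaths are placeholders):
#
#   video = read_video("input.mp4")                    # TxHxWxC uint8
#   batch, crop_region = pad_video_batch(video[None])  # BxFxHxWx3, aligned
#   tensor = numpy2tensor(batch)                       # Bx3xFxHxW in [-1..1]
#   # ...run the tokenizer on `tensor`, producing `output` of the same layout...
#   recon = tensor2numpy(output)                       # BxFxHxWx3 uint8
#   recon = unpad_video_batch(recon, crop_region)
#   write_video("output.mp4", recon[0], fps=24)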