Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| from abc import ABC, abstractmethod | |
| import torch | |
| import torch.nn.functional as F | |
| from cosmos_predict1.utils.distributed import rank0_first | |
| from cosmos_predict1.utils.misc import load_from_s3_with_cache | |
class BaseVAE(torch.nn.Module, ABC):
    """
    Abstract base class for a Variational Autoencoder (VAE).

    Subclasses must implement ``encode`` and ``decode`` and should override the
    ``spatial_compression_factor`` property. The number of latent channels is
    exposed through the read-only ``latent_ch`` property.

    Args:
        channel (int): Number of latent channels (default 3).
        name (str): Identifier for this VAE instance (default "vae").
    """

    def __init__(self, channel: int = 3, name: str = "vae"):
        super().__init__()
        self.channel = channel
        self.name = name

    @property
    def latent_ch(self) -> int:
        """
        Returns the number of latent channels in the VAE.

        This is a property rather than a method so that it can be used directly as an
        int, e.g. inside shape lists passed to ``reshape``.
        """
        return self.channel

    @abstractmethod
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encodes the input tensor into a latent representation.

        Args:
            - state (torch.Tensor): The input tensor to encode.
        Returns:
            - torch.Tensor: The encoded latent tensor.
        """

    @abstractmethod
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Decodes the latent representation back to the original space.

        Args:
            - latent (torch.Tensor): The latent tensor to decode.
        Returns:
            - torch.Tensor: The decoded tensor.
        """

    @property
    def spatial_compression_factor(self) -> int:
        """
        Returns the spatial reduction factor for the VAE.

        Raises:
            NotImplementedError: always, unless overridden by a derived class.
        """
        raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.")
class BasePretrainedImageVAE(BaseVAE):
    """
    A base class for a pretrained Variational Autoencoder (VAE).

    It loads the latent mean and standard deviation from ``mean_std_fp``, handles dtype
    conversions, and normalizes latents with those statistics. Derived classes must set
    ``self.encoder`` and ``self.decoder`` (callables used by ``encode``/``decode``) and load
    their pretrained weights themselves.

    Attributes:
        latent_mean (Tensor): Non-persistent buffer with the mean used to normalize latents.
        latent_std (Tensor): Non-persistent buffer with the std used to normalize latents.
        dtype (torch.dtype): bfloat16 when ``is_bf16`` is True, else float32; the encoder and
            decoder run in this dtype.
        backend_args: Initialized to None here; subclasses may forward it to storage loaders.

    Args:
        name (str): Name of the model.
        mean_std_fp (str): File loadable with ``torch.load(..., weights_only=True)`` that holds
            a ``(mean, std)`` pair for the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): If True, the statistics are shaped for 4-D latents,
            otherwise for 5-D latents (default is True).
        is_bf16 (bool, optional): Use bfloat16 instead of float32 (default is True).
    """

    def __init__(
        self,
        name: str,
        mean_std_fp: str,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
    ) -> None:
        # BaseVAE.__init__ already stores ``name`` and ``latent_ch`` (as ``self.channel``).
        super().__init__(latent_ch, name)
        self.dtype = torch.bfloat16 if is_bf16 else torch.float32
        self.is_image = is_image
        self.mean_std_fp = mean_std_fp
        self.backend_args = None
        self.register_mean_std(mean_std_fp)

    def register_mean_std(self, mean_std_fp: str) -> None:
        """
        Load the latent mean/std and register them as non-persistent buffers.

        Args:
            - mean_std_fp (str): File holding a ``(mean, std)`` pair. It is loaded with
              ``map_location="cuda"``, so a CUDA device is required.
        """
        latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True)
        # Read ``self.channel`` directly: ``self.latent_ch`` only yields an int when it is
        # declared as a property on the base class.
        n_ch = self.channel
        # Shape the statistics to broadcast over the channel axis of 4-D or 5-D latents.
        target_shape = [1, n_ch, 1, 1] if self.is_image else [1, n_ch, 1, 1, 1]
        self.register_buffer(
            "latent_mean",
            latent_mean.to(self.dtype).reshape(*target_shape),
            persistent=False,
        )
        self.register_buffer(
            "latent_std",
            latent_std.to(self.dtype).reshape(*target_shape),
            persistent=False,
        )

    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encode ``state`` to a normalized latent.

        The encoder runs in ``self.dtype``; the result is returned in ``state``'s dtype as
        ``(latent - latent_mean) / latent_std``.

        Args:
            - state (torch.Tensor): Input to encode.
        Returns:
            - torch.Tensor: Normalized latent, in the same dtype as ``state``.
        Raises:
            - ValueError: if the encoder returns neither a tensor nor a tuple whose first
              element is a tensor.
        """
        in_dtype = state.dtype
        encoded_state = self.encoder(state.to(self.dtype))
        # Some encoders return a tuple; the latent is its first element.
        if isinstance(encoded_state, tuple):
            if not encoded_state or not isinstance(encoded_state[0], torch.Tensor):
                raise ValueError("Invalid type of encoded state: expected a tuple starting with a torch.Tensor")
            encoded_state = encoded_state[0]
        elif not isinstance(encoded_state, torch.Tensor):
            raise ValueError("Invalid type of encoded state")
        latent_mean = self.latent_mean.to(in_dtype)
        latent_std = self.latent_std.to(in_dtype)
        return (encoded_state.to(in_dtype) - latent_mean) / latent_std

    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Undo the latent normalization and decode to the original space.

        Args:
            - latent (torch.Tensor): Normalized latent.
        Returns:
            - torch.Tensor: Decoded tensor, in the same dtype as ``latent``.
        """
        in_dtype = latent.dtype
        latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype)
        return self.decoder(latent.to(self.dtype)).to(in_dtype)

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the encoder and decoder to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        self.encoder.to(self.dtype)
class JITVAE(BasePretrainedImageVAE):
    """
    A VAE whose encoder and decoder are TorchScript (JIT) modules loaded with ``torch.jit.load``.

    The modules are loaded onto CUDA, put in eval mode, have their parameters frozen, and are
    cast to ``self.dtype``. Dtype handling and latent normalization come from
    ``BasePretrainedImageVAE``.

    Attributes:
        encoder (Module): The JIT encoder loaded from ``enc_fp``.
        decoder (Module): The JIT decoder loaded from ``dec_fp``.
        latent_mean (Tensor): The mean used for normalizing the latent representation.
        latent_std (Tensor): The standard deviation used for normalizing the latent representation.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): Path (or file-like object) of the encoder's JIT archive.
        dec_fp (str): Path (or file-like object) of the decoder's JIT archive.
        name (str): Name of the model.
        mean_std_fp (str): File containing the mean and std of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
        is_bf16 (bool, optional): Flag to use bfloat16 (default is True).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        name: str,
        mean_std_fp: str,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
    ):
        super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16)
        self.load_encoder(enc_fp)
        self.load_decoder(dec_fp)

    def _load_frozen_jit(self, fp: str) -> torch.nn.Module:
        """Load a JIT module onto CUDA, set eval mode, freeze its parameters and cast it to ``self.dtype``."""
        module = torch.jit.load(fp, map_location="cuda")
        module.eval()
        for param in module.parameters():
            param.requires_grad = False
        module.to(self.dtype)
        return module

    def load_encoder(self, enc_fp: str) -> None:
        """
        Load the frozen JIT encoder.

        Args:
            - enc_fp (str): Path (or file-like object) of the encoder's JIT archive.
        """
        self.encoder = self._load_frozen_jit(enc_fp)

    def load_decoder(self, dec_fp: str) -> None:
        """
        Load the frozen JIT decoder.

        Args:
            - dec_fp (str): Path (or file-like object) of the decoder's JIT archive.
        """
        self.decoder = self._load_frozen_jit(dec_fp)
class StateDictVAE(BasePretrainedImageVAE):
    """
    A VAE that copies the weights of JIT-serialized encoder/decoder archives into an eager
    ``vae`` module.

    Dtype handling and latent normalization come from ``BasePretrainedImageVAE``; encoding and
    decoding are delegated to ``vae.encode`` / ``vae.decode``.

    Attributes:
        vae (Module): The eager VAE with the loaded weights.
        encoder (Callable): ``vae.encode``.
        decoder (Callable): ``vae.decode``.
        latent_mean (Tensor): The mean used for normalizing the latent representation.
        latent_std (Tensor): The standard deviation used for normalizing the latent representation.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): Location of the encoder's JIT file, fetched via ``load_from_s3_with_cache``.
        dec_fp (str): Location of the decoder's JIT file, fetched via ``load_from_s3_with_cache``.
        vae (Module): VAE instance whose weights have not been loaded yet.
        name (str): Name of the model, used to build the cache file names.
        mean_std_fp (str): File containing the mean and std of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
        is_bf16 (bool, optional): Flag to use bfloat16 (default is True).
    """

    # Keys for global variables captured by TorchScript; they are dropped before the
    # (strict by default) ``load_state_dict`` call into the eager ``vae``.
    _JIT_CAPTURED_KEYS = frozenset(
        (
            "encoder.patcher.wavelets",
            "encoder.patcher._arange",
            "decoder.unpatcher.wavelets",
            "decoder.unpatcher._arange",
        )
    )

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        vae: torch.nn.Module,
        name: str,
        mean_std_fp: str,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
    ):
        super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16)
        self.load_encoder_and_decoder(enc_fp, dec_fp, vae)

    def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None:
        """
        Load the weights of the JIT encoder and decoder archives into ``vae``.

        Both archives are fetched via ``load_from_s3_with_cache`` onto the current CUDA device.
        Their state dicts are merged, JIT-captured globals are removed, and the result is
        loaded into ``vae``. The VAE is then frozen in eval mode and cast to ``self.dtype``.

        Args:
            - enc_fp (str): Location of the encoder's JIT file.
            - dec_fp (str): Location of the decoder's JIT file.
            - vae (torch.nn.Module): VAE module into which the weights are loaded.
        """
        device = torch.device(torch.cuda.current_device())
        jit_encoder = load_from_s3_with_cache(
            enc_fp,
            f"vae/{self.name}_enc.jit",
            easy_io_kwargs={"map_location": device},
            backend_args=self.backend_args,
        )
        jit_decoder = load_from_s3_with_cache(
            dec_fp,
            f"vae/{self.name}_dec.jit",
            easy_io_kwargs={"map_location": device},
            backend_args=self.backend_args,
        )
        # On duplicate keys, the decoder archive's entry wins (right operand of ``|``).
        merged_state_dict = jit_encoder.state_dict() | jit_decoder.state_dict()
        weights = {k: v for k, v in merged_state_dict.items() if k not in self._JIT_CAPTURED_KEYS}
        vae.load_state_dict(weights)
        vae.eval()
        vae.requires_grad_(False)
        vae.to(self.dtype)
        self.vae = vae
        # The base-class encode/decode call these, so route them through the eager VAE.
        self.encoder = self.vae.encode
        self.decoder = self.vae.decode

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the wrapped VAE to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.vae.to(self.dtype)
class SDVAE(BaseVAE):
    """
    Wrapper around the ``stabilityai/sd-vae-ft-mse`` ``AutoencoderKL`` from ``diffusers``.

    Pixels in [-1, 1] are mapped to [0, 1] before encoding. Latents are normalized per channel
    as ``(latent - shift) / spread`` using fixed constants, implemented as
    ``latent * scale + bias``. ``decode`` inverts both steps. The VAE runs in bfloat16.

    Args:
        batch_size (int): Maximum number of latents passed to each ``vae.decode`` call.
        count_std (bool): If True, ``encode`` samples ``mean + noise * std``; otherwise it
            returns the posterior mean.
        is_downsample (bool): If True, ``encode`` halves H and W (bilinear) before encoding and
            ``decode`` doubles them after decoding.
    """

    def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None:
        super().__init__(channel=4, name="sd_vae")
        self.dtype = torch.bfloat16
        # scale = 1 / spread and bias = -shift / spread, so x * scale + bias == (x - shift) / spread.
        # The constants are presumably per-channel statistics of SD-VAE latents — TODO confirm source.
        self.register_buffer(
            "scale",
            torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "bias",
            -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale,
            persistent=False,
        )
        self.batch_size = batch_size
        self.count_std = count_std
        self.is_downsample = is_downsample
        self.load_vae()
        self.reset_dtype()

    def reset_dtype(self, *args, **kwargs):
        """Cast the wrapped VAE to ``self.dtype``; arguments are accepted and ignored."""
        del args, kwargs
        self.vae.to(self.dtype)

    def load_vae(self) -> None:
        """Load ``stabilityai/sd-vae-ft-mse``, preferring the local cache, and freeze it in eval mode."""
        # Process-wide environment changes, set before importing diffusers, to silence
        # Hugging Face Hub symlink warnings and progress bars.
        os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        import diffusers

        vae_name = "stabilityai/sd-vae-ft-mse"
        try:
            vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True)
        except Exception:
            # Could not load the model from cache; try without local_files_only.
            # ``Exception`` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
            vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name)
        self.vae = vae.eval().requires_grad_(False)

    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encode pixels to normalized latents.

        Args:
            - state (torch.Tensor): Pixels in range [-1, 1]. Bilinear downsampling (when
              ``is_downsample``) requires a 4-D input.
        Returns:
            - torch.Tensor: Normalized latent, in the same dtype as ``state``.
        """
        if self.is_downsample:
            _h, _w = state.shape[-2:]
            state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False)
        in_dtype = state.dtype
        state = state.to(self.dtype)
        # [-1, 1] -> [0, 1]; decode applies the inverse mapping.
        state = (state + 1.0) / 2.0
        latent_dist = self.vae.encode(state)["latent_dist"]
        mean, std = latent_dist.mean, latent_dist.std
        if self.count_std:
            latent = mean + torch.randn_like(mean) * std
        else:
            latent = mean
        latent = latent * self.scale
        latent = latent + self.bias
        return latent.to(in_dtype)

    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Decode normalized latents back to pixels in [-1, 1].

        Decoding is done in chunks of ``self.batch_size`` along the first dimension.

        Args:
            - latent (torch.Tensor): Normalized latent.
        Returns:
            - torch.Tensor: Decoded pixels, in the same dtype as ``latent``.
        """
        in_dtype = latent.dtype
        latent = latent.to(self.dtype)
        latent = latent - self.bias
        latent = latent / self.scale
        latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)])
        if self.is_downsample:
            _h, _w = latent.shape[-2:]
            latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False)
        # [0, 1] -> [-1, 1], undoing the mapping applied in encode.
        return latent.to(in_dtype) * 2 - 1.0

    @property
    def spatial_compression_factor(self) -> int:
        """Spatial reduction factor of the wrapped autoencoder."""
        # NOTE(review): presumably the AutoencoderKL's own factor. With ``is_downsample=True``,
        # ``encode`` also halves H and W first, so the pixel-to-latent reduction is 16.
        # Confirm which value callers expect.
        return 8