# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""VAE model for WorldEngine frame encoding/decoding."""

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .dcae import Encoder, Decoder, bake_weight_norm


@dataclass
class EncoderDecoderConfig:
    """Config object for Encoder/Decoder initialization.

    Mirrors the keyword arguments of ``WorldEngineVAE.__init__`` for one
    side (encoder or decoder) so the DCAE modules can be built from a
    single object.
    """

    # Number of image channels (3 for RGB in this module).
    channels: int
    # Number of channels in the latent representation.
    latent_channels: int
    # Base channel width of the first stage.
    ch_0: int
    # Upper bound on channel width in deeper stages.
    ch_max: int
    # Residual blocks per encoder stage.
    encoder_blocks_per_stage: List[int]
    # Residual blocks per decoder stage.
    decoder_blocks_per_stage: List[int]
    # Presumably disables the encoder's log-variance head when True —
    # semantics live in .dcae; confirm there.
    skip_logvar: bool = False


class WorldEngineVAE(ModelMixin, ConfigMixin):
    """
    VAE for encoding/decoding video frames using DCAE architecture.
    Encodes RGB uint8 images to latent space and decodes latents back to RGB.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        # Common parameters
        sample_size: Tuple[int, int] = (360, 640),
        channels: int = 3,
        latent_channels: int = 16,
        # Encoder parameters
        encoder_ch_0: int = 64,
        encoder_ch_max: int = 256,
        encoder_blocks_per_stage: Optional[List[int]] = None,
        # Decoder parameters
        decoder_ch_0: int = 128,
        decoder_ch_max: int = 1024,
        decoder_blocks_per_stage: Optional[List[int]] = None,
        # Shared parameters
        skip_logvar: bool = False,
        # Scaling factors
        scale_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        """Build the encoder/decoder pair from the given hyperparameters.

        Args:
            sample_size: Nominal (height, width) of input frames; recorded in
                the config (not used directly in this module).
            channels: Image channel count shared by both halves.
            latent_channels: Channel count of the latent space.
            encoder_ch_0 / encoder_ch_max: Base and maximum channel widths
                for the encoder stages.
            encoder_blocks_per_stage: Blocks per encoder stage; defaults to
                ``[1, 1, 1, 1]`` when ``None`` (avoids a mutable default).
            decoder_ch_0 / decoder_ch_max: Base and maximum channel widths
                for the decoder stages.
            decoder_blocks_per_stage: Blocks per decoder stage; defaults to
                ``[1, 1, 1, 1]`` when ``None``.
            skip_logvar: Forwarded to both configs; see ``EncoderDecoderConfig``.
            scale_factor / shift_factor: Latent scaling factors recorded in
                the config (not applied in this module).
        """
        super().__init__()

        # None sentinels stand in for mutable list defaults.
        if encoder_blocks_per_stage is None:
            encoder_blocks_per_stage = [1, 1, 1, 1]
        if decoder_blocks_per_stage is None:
            decoder_blocks_per_stage = [1, 1, 1, 1]

        # Both configs carry both block lists; only ch_0/ch_max differ
        # between the encoder and decoder sides.
        encoder_config = EncoderDecoderConfig(
            channels=channels,
            latent_channels=latent_channels,
            ch_0=encoder_ch_0,
            ch_max=encoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            skip_logvar=skip_logvar,
        )
        decoder_config = EncoderDecoderConfig(
            channels=channels,
            latent_channels=latent_channels,
            ch_0=decoder_ch_0,
            ch_max=decoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            skip_logvar=skip_logvar,
        )

        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(decoder_config)

    def encode(self, img: Tensor):
        """Encode a single [H, W, C] uint8-range RGB image into a latent.

        The image is moved to the module's device/dtype, rescaled from
        [0, 255] to [-1, 1], converted to NCHW, and passed through the
        encoder. (Original note: "RGB -> RGB+D -> latent" — any depth
        channel handling happens inside the .dcae Encoder.)

        Args:
            img: [H, W, C] image tensor with values in [0, 255].

        Returns:
            Whatever the DCAE encoder produces for a batch of one.

        Raises:
            ValueError: If ``img`` is not a 3-D tensor. (Was an ``assert``,
                which vanishes under ``python -O``.)
        """
        if img.dim() != 3:
            raise ValueError("Expected [H, W, C] image tensor")
        img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        # HWC -> NCHW, then map [0, 255] -> [-1, 1].
        rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
        return self.encoder(rgb)

    def decode(self, latent: Tensor):
        """Decode a latent back into a [H, W, 3] uint8 RGB image.

        The decoder output is mapped from [-1, 1] to [0, 255], rounded to
        uint8, and the batch dimension is dropped. Only the first 3
        channels are kept — the decoder may emit extra channels (e.g.
        depth) that are discarded here.
        """
        decoded = self.decoder(latent)
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = (decoded * 255).round().to(torch.uint8)
        return decoded.squeeze(0).permute(1, 2, 0)[..., :3]

    def forward(self, x: Tensor, encode: bool = True) -> Tensor:
        """
        Forward pass - encode or decode based on flag.

        Args:
            x: Input tensor (image for encode, latent for decode)
            encode: If True, encode; if False, decode

        Returns:
            Encoded latent or decoded image
        """
        if encode:
            return self.encode(x)
        else:
            return self.decode(x)

    def bake_weight_norm(self):
        """Remove weight_norm parametrizations, baking normalized weights
        into regular tensors.

        Call this after loading weights and before torch.compile to avoid
        CUDA graph capture errors from in-place weight updates.

        Returns:
            self, for fluent chaining.
        """
        # Resolves to the module-level function imported from .dcae, not
        # this method: class attributes are not on the function-scope
        # name-lookup chain, so there is no accidental recursion.
        bake_weight_norm(self)
        return self