# Waypoint-1-Small / vae / ae_model.py
# NOTE: the lines below are Hugging Face web-page residue that was captured
# together with the raw file; they are kept here as comments so the module
# remains valid Python:
#   dn6's picture
#   dn6 HF Staff
#   Add diffusers support
#   57eef5f verified
#   raw / history blame / 4.74 kB
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""VAE model for WorldEngine frame encoding/decoding."""
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .dcae import Encoder, Decoder, bake_weight_norm
@dataclass
class EncoderDecoderConfig:
    """Config object for Encoder/Decoder initialization.

    One instance is built per network (encoder or decoder); both block
    schedules are carried in each instance, and the two instances differ
    only in their channel widths (``ch_0`` / ``ch_max``).
    """

    channels: int  # number of image channels fed to the network
    latent_channels: int  # number of channels in the latent space
    ch_0: int  # base channel width of the first stage
    ch_max: int  # maximum channel width across stages
    encoder_blocks_per_stage: List[int]  # number of blocks in each encoder stage
    decoder_blocks_per_stage: List[int]  # number of blocks in each decoder stage
    skip_logvar: bool = False  # forwarded to dcae; semantics defined there — see .dcae
class WorldEngineVAE(ModelMixin, ConfigMixin):
    """
    VAE for encoding/decoding video frames using the DCAE architecture.

    ``encode`` maps a single RGB uint8 image of shape ``[H, W, C]`` to a
    latent tensor; ``decode`` maps a latent back to an RGB uint8 image of
    shape ``[H, W, 3]``. Pixel values are rescaled from ``[0, 255]`` to
    ``[-1, 1]`` on the way in and back to uint8 ``[0, 255]`` on the way out.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        # Common parameters
        sample_size: Tuple[int, int] = (360, 640),
        channels: int = 3,
        latent_channels: int = 16,
        # Encoder parameters
        encoder_ch_0: int = 64,
        encoder_ch_max: int = 256,
        encoder_blocks_per_stage: Optional[List[int]] = None,
        # Decoder parameters
        decoder_ch_0: int = 128,
        decoder_ch_max: int = 1024,
        decoder_blocks_per_stage: Optional[List[int]] = None,
        # Shared parameters
        skip_logvar: bool = False,
        # Scaling factors
        scale_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        """
        Args:
            sample_size: Nominal (height, width) of input frames; stored in
                the config but not enforced by ``encode``/``decode``.
            channels: Number of image channels.
            latent_channels: Number of channels in the latent space.
            encoder_ch_0: Base channel width of the encoder.
            encoder_ch_max: Maximum channel width of the encoder.
            encoder_blocks_per_stage: Blocks per encoder stage; defaults to
                ``[1, 1, 1, 1]`` when ``None``.
            decoder_ch_0: Base channel width of the decoder.
            decoder_ch_max: Maximum channel width of the decoder.
            decoder_blocks_per_stage: Blocks per decoder stage; defaults to
                ``[1, 1, 1, 1]`` when ``None``.
            skip_logvar: Forwarded to both encoder and decoder configs.
            scale_factor: Latent scale factor kept in the config.
            shift_factor: Latent shift factor kept in the config.
                NOTE(review): ``scale_factor``/``shift_factor`` are never
                applied inside this class — confirm callers apply them.
        """
        super().__init__()

        # `None` sentinel + fresh literals: avoids the shared mutable
        # default-argument pitfall.
        if encoder_blocks_per_stage is None:
            encoder_blocks_per_stage = [1, 1, 1, 1]
        if decoder_blocks_per_stage is None:
            decoder_blocks_per_stage = [1, 1, 1, 1]

        # Both configs carry both block schedules; the two networks differ
        # only in their channel widths (ch_0 / ch_max). Lists are copied so
        # the configs don't alias the caller's (or each other's) lists.
        encoder_config = EncoderDecoderConfig(
            channels=channels,
            latent_channels=latent_channels,
            ch_0=encoder_ch_0,
            ch_max=encoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            skip_logvar=skip_logvar,
        )
        decoder_config = EncoderDecoderConfig(
            channels=channels,
            latent_channels=latent_channels,
            ch_0=decoder_ch_0,
            ch_max=decoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            skip_logvar=skip_logvar,
        )

        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(decoder_config)

    def encode(self, img: Tensor):
        """RGB -> RGB+D -> latent.

        Encode a single ``[H, W, C]`` uint8 image: batch it, move it to the
        module's device/dtype, convert to ``[1, C, H, W]`` layout, rescale
        ``[0, 255] -> [-1, 1]``, and run the encoder.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept for
        # exact parity with the original behavior (AssertionError type).
        assert img.dim() == 3, "Expected [H, W, C] image tensor"
        img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        # HWC -> CHW, then map pixel range [0, 255] to [-1, 1].
        rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
        return self.encoder(rgb)

    def decode(self, latent: Tensor):
        """Decode a latent tensor back to an ``[H, W, 3]`` RGB uint8 image."""
        decoded = self.decoder(latent)
        # Map decoder output [-1, 1] -> [0, 1], then quantize to uint8.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = (decoded * 255).round().to(torch.uint8)
        # Drop the batch dim, CHW -> HWC, and keep only the first 3 channels
        # (per `encode`'s "RGB -> RGB+D" note the decoder may emit extras).
        return decoded.squeeze(0).permute(1, 2, 0)[..., :3]

    def forward(self, x: Tensor, encode: bool = True) -> Tensor:
        """
        Forward pass - encode or decode based on flag.

        Args:
            x: Input tensor (image for encode, latent for decode)
            encode: If True, encode; if False, decode

        Returns:
            Encoded latent or decoded image
        """
        return self.encode(x) if encode else self.decode(x)

    def bake_weight_norm(self):
        """Remove weight_norm parametrizations, baking normalized weights
        into regular tensors.

        Call this after loading weights and before ``torch.compile`` to
        avoid CUDA graph capture errors from in-place weight updates.

        Returns:
            This module, for call chaining.
        """
        bake_weight_norm(self)
        return self