|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""VAE model for WorldEngine frame encoding/decoding.""" |
|
|
|
|
|
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .dcae import Encoder, Decoder, bake_weight_norm
|
|
|
|
|
|
|
|
@dataclass
class EncoderDecoderConfig:
    """Config object for Encoder/Decoder initialization.

    Plain hyperparameter bundle handed to the DCAE ``Encoder``/``Decoder``
    constructors (see ``.dcae``). Built by ``WorldEngineVAE.__init__``, which
    creates one instance per side (encoder and decoder) so the two trunks can
    use different channel widths.
    """

    channels: int  # image channel count (3 for RGB in this file's default)
    latent_channels: int  # channel count of the latent representation
    ch_0: int  # presumably the base channel width of the trunk — confirm in .dcae
    ch_max: int  # presumably the maximum channel width cap — confirm in .dcae
    encoder_blocks_per_stage: List[int]  # blocks per encoder stage
    decoder_blocks_per_stage: List[int]  # blocks per decoder stage
    # NOTE(review): semantics not visible here; presumably skips the
    # log-variance head when True — confirm against .dcae.
    skip_logvar: bool = False
|
|
|
|
|
|
|
|
class WorldEngineVAE(ModelMixin, ConfigMixin):
    """
    VAE for encoding/decoding video frames using DCAE architecture.

    Encodes RGB uint8 images to latent space and decodes latents back to RGB.
    The encoder and decoder are configured asymmetrically (the decoder defaults
    to wider channels), which is why each side has its own width parameters.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        sample_size: Tuple[int, int] = (360, 640),
        channels: int = 3,
        latent_channels: int = 16,
        encoder_ch_0: int = 64,
        encoder_ch_max: int = 256,
        encoder_blocks_per_stage: Optional[List[int]] = None,
        decoder_ch_0: int = 128,
        decoder_ch_max: int = 1024,
        decoder_blocks_per_stage: Optional[List[int]] = None,
        skip_logvar: bool = False,
        scale_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        """
        Args:
            sample_size: Expected (height, width) of input frames. Recorded in
                the config by ``register_to_config``; not read in this class.
            channels: Number of image channels (3 for RGB).
            latent_channels: Channel count of the latent representation.
            encoder_ch_0: Base channel width of the encoder trunk.
            encoder_ch_max: Maximum channel width of the encoder trunk.
            encoder_blocks_per_stage: Blocks per encoder stage; ``None`` means
                ``[1, 1, 1, 1]``.
            decoder_ch_0: Base channel width of the decoder trunk.
            decoder_ch_max: Maximum channel width of the decoder trunk.
            decoder_blocks_per_stage: Blocks per decoder stage; ``None`` means
                ``[1, 1, 1, 1]``.
            skip_logvar: Forwarded to the DCAE config; presumably disables the
                log-variance head — confirm in ``.dcae``.
            scale_factor: Latent scale factor; stored in the config for use by
                external pipelines, unused in this class.
            shift_factor: Latent shift factor; stored in the config for use by
                external pipelines, unused in this class.
        """
        super().__init__()

        # ``None`` sentinels keep mutable lists out of the signature defaults;
        # materialize the real per-call defaults here.
        if encoder_blocks_per_stage is None:
            encoder_blocks_per_stage = [1, 1, 1, 1]
        if decoder_blocks_per_stage is None:
            decoder_blocks_per_stage = [1, 1, 1, 1]

        # Both configs carry both block schedules; each side reads the one it
        # needs. ``list(...)`` copies prevent aliasing caller-owned lists.
        encoder_config = EncoderDecoderConfig(
            channels=channels,
            latent_channels=latent_channels,
            ch_0=encoder_ch_0,
            ch_max=encoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            skip_logvar=skip_logvar,
        )

        decoder_config = EncoderDecoderConfig(
            channels=channels,
            latent_channels=latent_channels,
            ch_0=decoder_ch_0,
            ch_max=decoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            skip_logvar=skip_logvar,
        )

        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(decoder_config)

    def encode(self, img: Tensor) -> Tensor:
        """Encode a single [H, W, C] uint8 frame to a latent tensor.

        RGB -> RGB+D -> latent. The frame is normalized from [0, 255] to
        [-1, 1] and reshaped to NCHW with batch size 1 before the encoder.
        (Any RGB->depth channel expansion happens inside ``self.encoder`` —
        not visible here; confirm in ``.dcae``.)

        Raises:
            ValueError: If ``img`` is not a 3-D [H, W, C] tensor.
        """
        # Explicit check rather than ``assert`` so validation survives -O.
        if img.dim() != 3:
            raise ValueError("Expected [H, W, C] image tensor")
        img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        # HWC uint8 [0, 255] -> NCHW float [-1, 1]
        rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
        return self.encoder(rgb)

    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent tensor back to an [H, W, 3] uint8 RGB frame.

        The decoder output is mapped from [-1, 1] to [0, 255]; any channels
        past the first three (presumably depth) are dropped.
        """
        decoded = self.decoder(latent)
        # [-1, 1] -> [0, 1] -> rounded uint8 [0, 255]
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = (decoded * 255).round().to(torch.uint8)
        # Drop the batch dim, return to HWC, keep only the RGB channels.
        return decoded.squeeze(0).permute(1, 2, 0)[..., :3]

    def forward(self, x: Tensor, encode: bool = True) -> Tensor:
        """
        Forward pass - encode or decode based on flag.

        Args:
            x: Input tensor (image for encode, latent for decode)
            encode: If True, encode; if False, decode

        Returns:
            Encoded latent or decoded image
        """
        return self.encode(x) if encode else self.decode(x)

    def bake_weight_norm(self):
        """Remove weight_norm parametrizations, baking normalized weights into regular tensors.

        Call this after loading weights and before torch.compile to avoid
        CUDA graph capture errors from in-place weight updates.

        Returns:
            Self, so the call can be chained.
        """
        # Delegates to the module-level helper imported from .dcae; inside this
        # function body the bare name resolves to that import, not the method.
        bake_weight_norm(self)
        return self
|
|
|