File size: 4,760 Bytes
57eef5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""VAE model for WorldEngine frame encoding/decoding."""

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .dcae import Encoder, Decoder


@dataclass
class EncoderDecoderConfig:
    """Config object for Encoder/Decoder initialization.

    Plain data container handed to the DCAE ``Encoder``/``Decoder``
    constructors; both sides receive the same field layout but with
    different channel widths (see ``WorldEngineVAE.__init__``).
    """

    # Input spatial resolution — presumably (height, width); confirm against dcae.
    sample_size: Tuple[int, int]
    # Number of image channels (e.g. 3 for RGB).
    channels: int
    # Number of channels in the latent representation.
    latent_channels: int
    # Base channel width of the first stage.
    ch_0: int
    # Upper bound on channel width across stages.
    ch_max: int
    # Residual blocks per encoder stage.
    encoder_blocks_per_stage: List[int]
    # Residual blocks per decoder stage.
    decoder_blocks_per_stage: List[int]
    # Whether to insert a middle block between encoder and decoder stages.
    use_middle_block: bool
    # Optional DCAE variations; semantics defined in .dcae — not visible here.
    skip_logvar: bool = False
    skip_residuals: bool = False
    normalize_mu: bool = False


class WorldEngineVAE(ModelMixin, ConfigMixin):
    """
    VAE for encoding/decoding video frames using DCAE architecture.

    Encodes RGB uint8 images to latent space and decodes latents back to RGB.
    Encoder and decoder are built from the same ``EncoderDecoderConfig``
    layout but with independent channel widths.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        # Common parameters
        sample_size: Tuple[int, int] = (360, 640),
        channels: int = 3,
        latent_channels: int = 16,
        # Encoder parameters
        encoder_ch_0: int = 64,
        encoder_ch_max: int = 256,
        encoder_blocks_per_stage: Optional[List[int]] = None,
        # Decoder parameters
        decoder_ch_0: int = 128,
        decoder_ch_max: int = 1024,
        decoder_blocks_per_stage: Optional[List[int]] = None,
        # Shared parameters
        use_middle_block: bool = False,
        skip_logvar: bool = False,
        # Scaling factors.
        # NOTE(review): registered to config but never applied in
        # encode()/decode() below — confirm whether latents should be
        # shifted/scaled, or whether callers apply these externally.
        scale_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        """Build the encoder/decoder pair.

        Args:
            sample_size: Input spatial resolution, presumably (H, W).
            channels: Number of image channels.
            latent_channels: Channels in the latent representation.
            encoder_ch_0 / encoder_ch_max: Encoder base/max channel widths.
            encoder_blocks_per_stage: Blocks per encoder stage;
                defaults to one block per stage over four stages.
            decoder_ch_0 / decoder_ch_max: Decoder base/max channel widths.
            decoder_blocks_per_stage: Blocks per decoder stage; same default.
            use_middle_block: Forwarded to both configs.
            skip_logvar: Forwarded to both configs.
            scale_factor / shift_factor: Stored in config (see NOTE above).
        """
        super().__init__()

        # Avoid mutable default arguments: materialize per-call defaults here.
        if encoder_blocks_per_stage is None:
            encoder_blocks_per_stage = [1, 1, 1, 1]
        if decoder_blocks_per_stage is None:
            decoder_blocks_per_stage = [1, 1, 1, 1]

        # Encoder and decoder configs differ only in channel widths; each gets
        # its own list copies so neither side can mutate the other's config.
        encoder_config = EncoderDecoderConfig(
            sample_size=tuple(sample_size),
            channels=channels,
            latent_channels=latent_channels,
            ch_0=encoder_ch_0,
            ch_max=encoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            use_middle_block=use_middle_block,
            skip_logvar=skip_logvar,
        )

        decoder_config = EncoderDecoderConfig(
            sample_size=tuple(sample_size),
            channels=channels,
            latent_channels=latent_channels,
            ch_0=decoder_ch_0,
            ch_max=decoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            use_middle_block=use_middle_block,
            skip_logvar=skip_logvar,
        )

        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(decoder_config)

    def encode(self, img: Tensor):
        """Encode a single [H, W, C] uint8-range RGB image to latent space.

        Args:
            img: [H, W, C] image tensor with values in [0, 255].

        Returns:
            Whatever the DCAE encoder produces for a batch of one
            (return type depends on the encoder's config — not visible here).

        Raises:
            ValueError: If ``img`` is not a 3-D tensor.
        """
        # Explicit raise instead of assert: asserts are stripped under -O.
        if img.dim() != 3:
            raise ValueError(f"Expected [H, W, C] image tensor, got {img.dim()} dims")
        # Add batch dim and move to the model's device/dtype.
        img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        # HWC -> CHW, then rescale [0, 255] -> [-1, 1].
        rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
        # NOTE(review): an earlier docstring said "RGB -> RGB+D -> latent",
        # but only RGB is constructed here — confirm whether a depth channel
        # is expected to be appended inside the encoder.
        return self.encoder(rgb)

    @torch.compile
    def decode(self, latent: Tensor):
        """Decode a latent back to a [H, W, 3] uint8 RGB image tensor.

        Args:
            latent: Latent tensor as produced by :meth:`encode`.

        Returns:
            [H, W, 3] uint8 tensor with values in [0, 255].
        """
        decoded = self.decoder(latent)
        # [-1, 1] -> [0, 1], clamp, then quantize to uint8.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = (decoded * 255).round().to(torch.uint8)
        # Drop batch dim, CHW -> HWC, keep only RGB channels.
        return decoded.squeeze(0).permute(1, 2, 0)[..., :3]

    def forward(self, x: Tensor, encode: bool = True) -> Tensor:
        """
        Forward pass - encode or decode based on flag.

        Args:
            x: Input tensor (image for encode, latent for decode)
            encode: If True, encode; if False, decode

        Returns:
            Encoded latent or decoded image
        """
        if encode:
            return self.encode(x)
        else:
            return self.decode(x)