File size: 3,993 Bytes
57eef5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""Decoder blocks for WorldEngine modular pipeline."""

from typing import List, Tuple, Union

import numpy as np
import PIL.Image
import torch

from diffusers import AutoModel
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)

logger = logging.get_logger(__name__)


class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Decodes denoised latents back to an RGB image using the VAE.

    Reads ``latents`` and ``output_type`` from the pipeline state, writes the
    result to ``images``, and resets ``latents`` to ``None`` afterwards so the
    next frame starts from fresh noise.
    """

    model_name = "world_engine"

    # Output formats this step knows how to produce; anything else falls back
    # to "pil" with a warning.
    _SUPPORTED_OUTPUT_TYPES = ("pil", "latent", "pt", "np")

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this step needs: the VAE and an image processor."""
        return [
            ComponentSpec("vae", AutoModel),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": True,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        """Human-readable summary of this step."""
        return "Decodes denoised latents to RGB image using the VAE decoder"

    @property
    def inputs(self) -> List[InputParam]:
        """State entries read by this step."""
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Denoised latent tensor [1, 1, C, H, W]",
            ),
            InputParam(
                "output_type",
                default="pil",
                type_hint=str,
                description="The output format for the generated images (pil, latent, pt, or np)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """State entries written by this step."""
        return [
            OutputParam(
                "images",
                type_hint=Union[PIL.Image.Image, torch.Tensor, np.ndarray],
                description="Decoded RGB image in requested output format",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Decode ``latents`` into ``images`` and clear ``latents``.

        Args:
            components: Pipeline object providing the ``vae`` component.
            state: Pipeline state; ``latents`` and ``output_type`` are read
                from it, ``images`` is written and ``latents`` is reset to
                ``None``.

        Returns:
            The ``(components, state)`` tuple.
        """
        block_state = self.get_block_state(state)
        latents = block_state.latents
        output_type = block_state.output_type or "pil"

        if output_type not in self._SUPPORTED_OUTPUT_TYPES:
            # Keep the historical fallback to PIL, but make a likely typo visible
            # instead of silently ignoring it.
            logger.warning(
                "Unsupported output_type %r; falling back to 'pil'. Expected one of: %s.",
                output_type,
                ", ".join(self._SUPPORTED_OUTPUT_TYPES),
            )
            output_type = "pil"

        if output_type == "latent":
            # Skip decoding entirely and hand back the raw latents.
            block_state.images = latents
        else:
            # The VAE is called with the frame dimension (dim 1) squeezed out,
            # i.e. [B, C, H, W].
            # NOTE(review): the VAE is assumed to return an [H, W, 3] uint8
            # tensor — confirm; PIL.Image.fromarray below relies on that layout.
            image = components.vae.decode(latents.squeeze(1))

            if output_type == "pt":
                block_state.images = image
            else:
                # Both "np" and "pil" need a host-side numpy array.
                image_np = image.cpu().numpy()
                block_state.images = (
                    image_np if output_type == "np" else PIL.Image.fromarray(image_np)
                )

        # Clear latents so the next frame generates fresh random noise.
        block_state.latents = None
        self.set_block_state(state, block_state)
        return components, state