# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Decoder blocks for the WorldEngine modular pipeline."""

from typing import List, Union

import numpy as np
import PIL.Image
import torch

from diffusers import AutoModel
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)

logger = logging.get_logger(__name__)


class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Decodes denoised latents back to an RGB image using the VAE."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoModel),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": True,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Decodes denoised latents to an RGB image using the VAE decoder"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Denoised latent tensor of shape [1, 1, C, H, W]",
            ),
            InputParam(
                "output_type",
                default="pil",
                description="The output format for the generated images (pil, latent, pt, or np)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "images",
                type_hint=Union[PIL.Image.Image, torch.Tensor, np.ndarray],
                description="Decoded RGB image in the requested output format",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)

        latents = block_state.latents
        output_type = block_state.output_type or "pil"

        if output_type == "latent":
            block_state.images = latents
        else:
            # Decode to image. The VAE expects [B, C, H, W] input, so squeeze
            # out the frame dim; it returns an [H, W, 3] uint8 tensor.
            image = components.vae.decode(latents.squeeze(1))

            # Postprocess based on output_type.
            if output_type == "pt":
                block_state.images = image
            elif output_type == "np":
                block_state.images = image.cpu().numpy()
            else:  # "pil"
                block_state.images = PIL.Image.fromarray(image.cpu().numpy())

        # Clear latents so the next frame generates fresh random noise.
        block_state.latents = None

        self.set_block_state(state, block_state)
        return components, state
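

# Example usage (a minimal sketch, not part of the module). It assumes the
# diffusers Modular Pipelines API (``init_pipeline`` /
# ``load_default_components``) and a hypothetical repository id; swap in a
# real checkpoint that provides the ``vae`` component. ``latents`` stands for
# a denoised [1, 1, C, H, W] tensor produced by the upstream denoising blocks.
#
#     import torch
#
#     blocks = WorldEngineDecodeStep()
#     pipe = blocks.init_pipeline("your-org/world-engine-modular")  # hypothetical repo id
#     pipe.load_default_components(torch_dtype=torch.bfloat16)
#     pipe.to("cuda")
#
#     images = pipe(latents=latents, output="images")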