|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Decoder blocks for WorldEngine modular pipeline.""" |
|
|
|
|
|
from typing import List, Union |
|
|
|
|
|
import numpy as np |
|
|
import PIL.Image |
|
|
import torch |
|
|
|
|
|
from diffusers import AutoModel |
|
|
from diffusers.configuration_utils import FrozenDict |
|
|
from diffusers.image_processor import VaeImageProcessor |
|
|
from diffusers.utils import logging |
|
|
from diffusers.modular_pipelines import ( |
|
|
ModularPipelineBlocks, |
|
|
ModularPipeline, |
|
|
PipelineState, |
|
|
) |
|
|
from diffusers.modular_pipelines.modular_pipeline_utils import ( |
|
|
ComponentSpec, |
|
|
InputParam, |
|
|
OutputParam, |
|
|
) |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Pipeline block that turns denoised latents into an output image.

    The block reads ``latents`` and ``output_type`` from the pipeline state,
    runs the ``vae`` component's ``decode`` unless raw latents were requested,
    stores the result under ``images`` and resets ``latents`` to ``None``.
    """

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block declares: the VAE and an image processor.

        NOTE(review): ``image_processor`` is declared here but is not
        referenced by ``__call__`` in this class.
        """
        vae_spec = ComponentSpec("vae", AutoModel)
        processor_config = FrozenDict(
            {
                "vae_scale_factor": 16,
                "do_normalize": False,
                "do_convert_rgb": True,
            }
        )
        processor_spec = ComponentSpec(
            "image_processor",
            VaeImageProcessor,
            config=processor_config,
            default_creation_method="from_config",
        )
        return [vae_spec, processor_spec]

    @property
    def description(self) -> str:
        """One-line human-readable summary of the block."""
        return "Decodes denoised latents to RGB image using the VAE decoder"

    @property
    def inputs(self) -> List[InputParam]:
        """State entries consumed: required ``latents``, optional ``output_type``."""
        latents_param = InputParam(
            "latents",
            required=True,
            type_hint=torch.Tensor,
            description="Denoised latent tensor [1, 1, C, H, W]",
        )
        output_type_param = InputParam(
            "output_type",
            default="pil",
            description="The output format for the generated images (pil, latent, pt, or np)",
        )
        return [latents_param, output_type_param]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """State entries produced: ``images`` in the requested format."""
        images_param = OutputParam(
            "images",
            type_hint=Union[PIL.Image.Image, torch.Tensor, np.ndarray],
            description="Decoded RGB image in requested output format",
        )
        return [images_param]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Decode the latents held in ``state`` and store the result as ``images``.

        Args:
            components: Pipeline components; only ``components.vae`` is used.
            state: Pipeline state providing ``latents`` and ``output_type``.

        Returns:
            The ``(components, state)`` pair, with ``images`` set and
            ``latents`` reset to ``None`` on the block state.
        """
        block_state = self.get_block_state(state)
        denoised = block_state.latents
        # A falsy output_type (None, "") is treated the same as "pil".
        requested = block_state.output_type or "pil"

        if requested == "latent":
            # Raw latents requested: hand them back without touching the VAE.
            result = denoised
        else:
            # Drop dim 1 before decoding (a no-op if that dim is not size 1).
            decoded = components.vae.decode(denoised.squeeze(1))
            if requested == "pt":
                result = decoded
            else:
                pixels = decoded.cpu().numpy()
                # Any value other than "np" ends up as a PIL image. No scaling,
                # permutation or dtype conversion happens here, so the decode
                # output is presumably already PIL-compatible -- TODO confirm.
                result = pixels if requested == "np" else PIL.Image.fromarray(pixels)

        block_state.images = result
        # Latents are cleared for every output type, including "latent".
        block_state.latents = None
        self.set_block_state(state, block_state)
        return components, state
|
|
|