# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any import numpy as np import PIL import torch from ...configuration_utils import FrozenDict from ...models import AutoencoderKLWan from ...utils import logging from ...video_processor import VideoProcessor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam logger = logging.get_logger(__name__) # pylint: disable=invalid-name class WanVaeDecoderStep(ModularPipelineBlocks): model_name = "wan" @property def expected_components(self) -> list[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKLWan), ComponentSpec( "video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config", ), ] @property def description(self) -> str: return "Step that decodes the denoised latents into images" @property def inputs(self) -> list[tuple[str, Any]]: return [ InputParam( "latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step", ), InputParam( "output_type", default="np", type_hint=str, description="The output type of the decoded videos" ), ] @property def intermediate_outputs(self) -> list[str]: return [ OutputParam( "videos", type_hint=list[list[PIL.Image.Image]] | list[torch.Tensor] | list[np.ndarray], description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array", ) ] @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) vae_dtype = components.vae.dtype latents = block_state.latents latents_mean = ( torch.tensor(components.vae.config.latents_mean) .view(1, components.vae.config.z_dim, 1, 1, 1) .to(latents.device, latents.dtype) ) latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( 1, components.vae.config.z_dim, 1, 1, 1 ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean latents = latents.to(vae_dtype) block_state.videos = components.vae.decode(latents, return_dict=False)[0] output_type = getattr(block_state, "output_type", "np") block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type=output_type) self.set_block_state(state, block_state) return components, state