File size: 1,735 Bytes
57eef5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Block registry for WorldEngine modular pipeline."""
from diffusers.utils import logging
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict
from .encoders import WorldEngineTextEncoderStep, WorldEngineControllerEncoderStep
from .before_denoise import WorldEngineBeforeDenoiseStep
from .denoise import WorldEngineDenoiseLoop
from .decoders import WorldEngineDecodeStep
logger = logging.get_logger(__name__)

# Registry of the WorldEngine pipeline steps, keyed by step name.
# Insertion order matters: WorldEngineBlocks derives its block_names /
# block_classes from this mapping in the order the entries are added here.
AUTO_BLOCKS = InsertableDict()
AUTO_BLOCKS["text_encoder"] = WorldEngineTextEncoderStep
AUTO_BLOCKS["controller_encoder"] = WorldEngineControllerEncoderStep
AUTO_BLOCKS["before_denoise"] = WorldEngineBeforeDenoiseStep
AUTO_BLOCKS["denoise"] = WorldEngineDenoiseLoop
AUTO_BLOCKS["decode"] = WorldEngineDecodeStep
class WorldEngineBlocks(SequentialPipelineBlocks):
    """WorldEngine frame-generation pipeline assembled from ``AUTO_BLOCKS``.

    The steps (text encoding, controller encoding, pre-denoise setup, the
    denoising loop, and decoding) appear in ``AUTO_BLOCKS`` insertion order.
    ``block_names[i]`` is the name registered for ``block_classes[i]``.
    """

    # Both attributes are plain-list snapshots taken when the class body runs,
    # so later edits to AUTO_BLOCKS do not change this class.
    block_names = list(AUTO_BLOCKS)
    block_classes = list(AUTO_BLOCKS.values())
|