|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Block registry for WorldEngine modular pipeline.""" |
|
|
|
|
|
from diffusers.utils import logging |
|
|
from diffusers.modular_pipelines import SequentialPipelineBlocks |
|
|
from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict |
|
|
|
|
|
from .encoders import WorldEngineTextEncoderStep, WorldEngineControllerEncoderStep |
|
|
from .before_denoise import WorldEngineBeforeDenoiseStep |
|
|
from .denoise import WorldEngineDenoiseLoop |
|
|
from .decoders import WorldEngineDecodeStep |
|
|
|
|
|
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Registry of WorldEngine pipeline steps, keyed by step name.
# Insertion order matters: WorldEngineBlocks builds its `block_names` and
# `block_classes` lists from this mapping in iteration order, so the entries
# below are listed in the order the sequential pipeline receives them.
# Being an InsertableDict, it is presumably meant to allow callers to splice
# in extra steps at a given position — TODO confirm intended usage.
AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WorldEngineTextEncoderStep),
        ("controller_encoder", WorldEngineControllerEncoderStep),
        ("before_denoise", WorldEngineBeforeDenoiseStep),
        ("denoise", WorldEngineDenoiseLoop),
        ("decode", WorldEngineDecodeStep),
    ]
)
|
|
|
|
|
|
|
|
class WorldEngineBlocks(SequentialPipelineBlocks):
    """Sequential pipeline blocks for WorldEngine frame generation.

    Sub-block classes and their names are read from ``AUTO_BLOCKS`` in that
    mapping's iteration order: ``text_encoder``, ``controller_encoder``,
    ``before_denoise``, ``denoise``, ``decode``.
    """

    # Iterating the mapping directly yields its keys, so no defensive copy is
    # needed: `list(...)` already produces fresh, independent lists.
    block_classes = list(AUTO_BLOCKS.values())
    block_names = list(AUTO_BLOCKS)
|
|
|