| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import os |
| | import tempfile |
| | from collections import deque |
| | from typing import List |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from diffusers import FluxTransformer2DModel |
| | from diffusers.modular_pipelines import ( |
| | ComponentSpec, |
| | InputParam, |
| | ModularPipelineBlocks, |
| | OutputParam, |
| | PipelineState, |
| | WanModularPipeline, |
| | ) |
| |
|
| | from ..testing_utils import nightly, require_torch, slow |
| |
|
| |
|
class DummyCustomBlockSimple(ModularPipelineBlocks):
    """Minimal custom block that prefixes the user prompt.

    Reads the required ``prompt`` input and publishes ``output_prompt``
    ("Modular diffusers + " followed by the prompt) to the pipeline state.
    When ``use_dummy_model_component`` is true, the block additionally
    declares a ``transformer`` component typed as ``FluxTransformer2DModel``.
    """

    def __init__(self, use_dummy_model_component=False):
        # Assigned before ``super().__init__()`` so the flag already exists
        # in case the base initializer consults ``expected_components``.
        self.use_dummy_model_component = use_dummy_model_component
        super().__init__()

    @property
    def expected_components(self):
        """Components required by this block: a transformer only when requested."""
        if not self.use_dummy_model_component:
            return []
        return [ComponentSpec("transformer", FluxTransformer2DModel)]

    @property
    def inputs(self) -> List[InputParam]:
        """User-facing inputs: a single required string ``prompt``."""
        prompt_param = InputParam("prompt", type_hint=str, required=True, description="Prompt to use")
        return [prompt_param]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """This block consumes nothing produced by earlier blocks."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Values this block writes into the state: the rewritten prompt."""
        return [OutputParam("output_prompt", type_hint=str, description="Modified prompt")]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Store ``"Modular diffusers + " + prompt`` as ``output_prompt`` in ``state``.

        Note: despite the annotation, this returns the ``(components, state)`` tuple.
        """
        current = self.get_block_state(state)
        current.output_prompt = "Modular diffusers + " + current.prompt
        self.set_block_state(state, current)
        return components, state
| |
|
| |
|
# Source of the "remote" module that ``test_custom_block_saving_loading`` writes
# next to the saved ``modular_config.json`` so the block can be reloaded with
# ``trust_remote_code=True``. It mirrors ``DummyCustomBlockSimple`` defined above.
# The text is loaded as its own module (presumably through the ``auto_map``
# entry), so every name it references must be imported inside the string itself.
# This includes ``FluxTransformer2DModel``, which ``expected_components`` uses when
# ``use_dummy_model_component=True``.
CODE_STR = """
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
    WanModularPipeline,
)
from typing import List

class DummyCustomBlockSimple(ModularPipelineBlocks):
    def __init__(self, use_dummy_model_component=False):
        self.use_dummy_model_component = use_dummy_model_component
        super().__init__()

    @property
    def expected_components(self):
        if self.use_dummy_model_component:
            return [ComponentSpec("transformer", FluxTransformer2DModel)]
        else:
            return []

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "output_prompt",
                type_hint=str,
                description="Modified prompt",
            )
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        old_prompt = block_state.prompt
        block_state.output_prompt = "Modular diffusers + " + old_prompt
        self.set_block_state(state, block_state)

        return components, state
"""
| |
|
| |
|
class TestModularCustomBlocks:
    """Tests for defining, running, saving and reloading a simple custom block."""

    def _test_block_properties(self, block):
        """Assert ``block`` declares exactly the interface of ``DummyCustomBlockSimple``."""
        assert not block.expected_components
        assert not block.intermediate_inputs

        input_names = [param.name for param in block.inputs]
        output_names = [param.name for param in block.intermediate_outputs]
        assert input_names == ["prompt"]
        assert output_names == ["output_prompt"]

    def _check_prompt_output(self, block, output):
        """Assert ``output`` holds exactly ``block``'s declared names and the prefixed prompt."""
        declared_names = [param.name for param in block.inputs]
        declared_names += [param.name for param in block.intermediate_outputs]
        assert sorted(output.values) == sorted(declared_names)

        assert output.values["output_prompt"].startswith("Modular diffusers + ")

    def test_custom_block_properties(self):
        self._test_block_properties(DummyCustomBlockSimple())

    def test_custom_block_output(self):
        block = DummyCustomBlockSimple()
        pipeline = block.init_pipeline()
        result = pipeline(prompt="Diffusers is nice")
        self._check_prompt_output(block, result)

    def test_custom_block_saving_loading(self):
        block = DummyCustomBlockSimple()

        with tempfile.TemporaryDirectory() as tmpdir:
            block.save_pretrained(tmpdir)
            saved_files = os.listdir(tmpdir)
            assert any("modular_config.json" in name for name in saved_files)

            config_path = os.path.join(tmpdir, "modular_config.json")
            with open(config_path, "r") as config_file:
                saved_config = json.load(config_file)
            assert saved_config["auto_map"] == {
                "ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"
            }

            # The auto_map above points at a module file that the test writes itself.
            # Presumably save_pretrained does not copy the block's source; confirm.
            module_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
            with open(module_path, "w") as module_file:
                module_file.write(CODE_STR)

            reloaded = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)

            pipeline = reloaded.init_pipeline()
            result = pipeline(prompt="Diffusers is nice")
            self._check_prompt_output(reloaded, result)

    def test_custom_block_supported_components(self):
        block = DummyCustomBlockSimple(use_dummy_model_component=True)
        pipeline = block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe")
        pipeline.load_components()

        assert len(pipeline.components) == 1
        assert pipeline.component_names[0] == "transformer"

    def test_custom_block_loads_from_hub(self):
        block = ModularPipelineBlocks.from_pretrained(
            "hf-internal-testing/tiny-modular-diffusers-block", trust_remote_code=True
        )
        self._test_block_properties(block)

        pipeline = block.init_pipeline()
        result = pipeline(prompt="Diffusers is nice")
        assert result.values["output_prompt"].startswith("Modular diffusers + ")
| |
|
| |
|
@slow
@nightly
@require_torch
class TestKreaCustomBlocksIntegration:
    """End-to-end checks for the ``krea/krea-realtime-video`` custom blocks (loads onto CUDA)."""

    repo_id = "krea/krea-realtime-video"

    def _load_pipeline(self, blocks):
        """Wrap ``blocks`` in a ``WanModularPipeline`` and load its components onto CUDA."""
        pipe = WanModularPipeline(blocks, self.repo_id)
        pipe.load_components(
            trust_remote_code=True,
            device_map="cuda",
            torch_dtype={"default": torch.bfloat16, "vae": torch.float16},
        )
        return pipe

    def test_loading_from_hub(self):
        blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
        assert sorted(blocks.sub_blocks) == sorted(["text_encoder", "before_denoise", "denoise", "decode"])

        pipe = self._load_pipeline(blocks)
        expected_components = [
            "text_encoder",
            "tokenizer",
            "guider",
            "scheduler",
            "vae",
            "transformer",
            "video_processor",
        ]
        assert len(pipe.components) == len(expected_components)
        assert sorted(pipe.components) == sorted(expected_components)

    def test_forward(self):
        blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
        pipe = self._load_pipeline(blocks)

        num_frames_per_block = 2
        num_blocks = 2

        # The same state object is threaded through every per-block call below;
        # the frame cache is bounded by the pipeline's configured length.
        state = PipelineState()
        state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len))

        prompt = ["a cat sitting on a boat"]

        for transformer_block in pipe.transformer.blocks:
            transformer_block.self_attn.fuse_projections()

        # Per block index: (expected frame count, first 4 + last 4 flattened pixel values).
        expectations = {
            0: (5, np.array([211, 229, 238, 208, 195, 180, 188, 193])),
            1: (8, np.array([179, 203, 214, 176, 194, 181, 187, 191])),
        }

        for block_idx in range(num_blocks):
            state = pipe(
                state,
                prompt=prompt,
                num_inference_steps=2,
                num_blocks=num_blocks,
                num_frames_per_block=num_frames_per_block,
                block_idx=block_idx,
                generator=torch.manual_seed(42),
            )
            frames = np.array(state.values["videos"][0])
            flat = frames.flatten()
            actual_slices = np.concatenate([flat[:4], flat[-4:]]).tolist()

            expected_num_frames, expected_slices = expectations[block_idx]
            assert frames.shape == (expected_num_frames, 480, 832, 3)
            assert np.allclose(actual_slices, expected_slices)
| |
|