|
|
from unittest import mock |
|
|
|
|
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext |
|
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType |
|
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback |
|
|
|
|
|
|
|
|
class MockExtension(ExtensionBase): |
|
|
"""A mock ExtensionBase subclass for testing purposes.""" |
|
|
|
|
|
def __init__(self, x: int): |
|
|
super().__init__() |
|
|
self._x = x |
|
|
|
|
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) |
|
|
def set_step_index(self, ctx: DenoiseContext): |
|
|
ctx.step_index = self._x |
|
|
|
|
|
|
|
|
def test_extension_base_callback_registration(): |
|
|
"""Test that a callback can be successfully registered with an extension.""" |
|
|
val = 5 |
|
|
mock_extension = MockExtension(val) |
|
|
|
|
|
mock_ctx = mock.MagicMock() |
|
|
|
|
|
callbacks = mock_extension.get_callbacks() |
|
|
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, []) |
|
|
assert len(pre_denoise_loop_cbs) == 1 |
|
|
|
|
|
|
|
|
pre_denoise_loop_cbs[0].function(mock_ctx) |
|
|
|
|
|
|
|
|
assert mock_ctx.step_index == val |
|
|
|
|
|
|
|
|
def test_extension_base_empty_callback_type(): |
|
|
"""Test that an empty list is returned when no callbacks are registered for a given callback type.""" |
|
|
mock_extension = MockExtension(5) |
|
|
|
|
|
|
|
|
callbacks = mock_extension.get_callbacks() |
|
|
|
|
|
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, []) |
|
|
assert len(post_denoise_loop_cbs) == 0 |
|
|
|