File size: 1,588 Bytes
8a37e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from unittest import mock

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback


class MockExtension(ExtensionBase):
    """A mock ExtensionBase subclass for testing purposes."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


def test_extension_base_callback_registration():
    """Test that a callback can be successfully registered with an extension."""
    val = 5
    mock_extension = MockExtension(val)

    mock_ctx = mock.MagicMock()

    callbacks = mock_extension.get_callbacks()
    pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
    assert len(pre_denoise_loop_cbs) == 1

    # Call the mock callback.
    pre_denoise_loop_cbs[0].function(mock_ctx)

    # Confirm that the callback ran.
    assert mock_ctx.step_index == val


def test_extension_base_empty_callback_type():
    """Test that an empty list is returned when no callbacks are registered for a given callback type."""
    mock_extension = MockExtension(5)

    # There should be no callbacks registered for POST_DENOISE_LOOP.
    callbacks = mock_extension.get_callbacks()

    post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
    assert len(post_denoise_loop_cbs) == 0