File size: 5,475 Bytes
8a37e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import Any, Callable, Union

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContext


# Define test invocations before importing anything that uses invocations
@invocation_output("test_list_output")
class ListPassThroughInvocationOutput(BaseInvocationOutput):
    """Output of ``ListPassThroughInvocation``: the list of images it was given."""

    # Images echoed back by the pass-through node; defaults to an empty list.
    collection: list[ImageField] = OutputField(default=[])


@invocation("test_list", version="1.0.0")
class ListPassThroughInvocation(BaseInvocation):
    """Test node that returns its ``collection`` input unchanged as its output."""

    # Images to pass through; defaults to an empty list.
    collection: list[ImageField] = InputField(default=[])

    def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
        # ``context`` is not used; the node only repackages its input.
        passthrough = self.collection
        return ListPassThroughInvocationOutput(collection=passthrough)


@invocation_output("test_prompt_output")
class PromptTestInvocationOutput(BaseInvocationOutput):
    """Single-string output returned by ``PromptTestInvocation``."""

    # The prompt text; defaults to the empty string.
    prompt: str = OutputField(default="")


@invocation("test_prompt", version="1.0.0")
class PromptTestInvocation(BaseInvocation):
    """Test node that echoes its ``prompt`` input as a ``PromptTestInvocationOutput``."""

    # Prompt text to echo; defaults to the empty string.
    prompt: str = InputField(default="")

    def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
        # ``context`` is not used; the output mirrors the input field.
        echoed = self.prompt
        return PromptTestInvocationOutput(prompt=echoed)


@invocation("test_error", version="1.0.0")
class ErrorInvocation(BaseInvocation):
    """Test node whose ``invoke`` always raises a plain ``Exception``."""

    def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
        # Nothing is ever returned. The output annotation is presumably still
        # needed so @invocation can resolve an output type; confirm.
        failure_message = "This invocation is supposed to fail"
        raise Exception(failure_message)


@invocation_output("test_image_output")
class ImageTestInvocationOutput(BaseInvocationOutput):
    """Single-image output returned by the text-to-image and image-to-image test nodes."""

    # Reference to the (fake) produced image.
    image: ImageField = OutputField()


@invocation("test_text_to_image", version="1.0.0")
class TextToImageTestInvocation(BaseInvocation):
    """Fake text-to-image node: outputs an image reference named after the node id."""

    # Prompt inputs; both are accepted but not read by ``invoke``.
    prompt: str = InputField(default="")
    prompt2: str = InputField(default="")

    def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
        # No image is generated; the node's own id is used as the image name.
        fake_image = ImageField(image_name=self.id)
        return ImageTestInvocationOutput(image=fake_image)


@invocation("test_image_to_image", version="1.0.0")
class ImageToImageTestInvocation(BaseInvocation):
    """Fake image-to-image node: outputs an image reference named after the node id."""

    # Inputs are accepted (so edges can connect to them) but not read by ``invoke``.
    prompt: str = InputField(default="")
    image: Union[ImageField, None] = InputField(default=None)

    def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
        # The incoming ``image`` is ignored; the node id becomes the output image name.
        fake_image = ImageField(image_name=self.id)
        return ImageTestInvocationOutput(image=fake_image)


@invocation_output("test_prompt_collection_output")
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
    """List-of-strings output used by the prompt-collection and polymorphic test nodes."""

    # Collected strings; defaults to an empty list.
    collection: list[str] = OutputField(default=[])


@invocation("test_prompt_collection", version="1.0.0")
class PromptCollectionTestInvocation(BaseInvocation):
    """Test node that outputs a copy of its ``collection`` input."""

    # Strings to echo; no default is given here.
    collection: list[str] = InputField()

    def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
        # Hand back a fresh list so the output does not alias the input field.
        duplicated = list(self.collection)
        return PromptCollectionTestInvocationOutput(collection=duplicated)


@invocation_output("test_any_output")
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
    """Output of ``AnyTypeTestInvocation``: a single value of any type."""

    # The echoed value; annotated as ``Any``.
    value: Any = OutputField()


@invocation("test_any", version="1.0.0")
class AnyTypeTestInvocation(BaseInvocation):
    """Test node that echoes an arbitrary ``value`` input."""

    # Value of any type; defaults to None.
    value: Any = InputField(default=None)

    def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
        # ``context`` is not used; the value is forwarded as-is.
        forwarded = self.value
        return AnyTypeTestInvocationOutput(value=forwarded)


@invocation("test_polymorphic", version="1.0.0")
class PolymorphicStringTestInvocation(BaseInvocation):
    """Test node accepting either one string or a list of strings.

    A single string is wrapped in a one-element list; a list is passed through.
    """

    # Either a lone string or a list of strings; defaults to the empty string.
    value: Union[str, list[str]] = InputField(default="")

    def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
        raw = self.value
        # Normalize to a list so the output type is always list[str].
        items = [raw] if isinstance(raw, str) else raw
        return PromptCollectionTestInvocationOutput(collection=items)


# Importing these must happen after test invocations are defined or they won't register
from invokeai.app.services.events.events_base import EventServiceBase  # noqa: E402
from invokeai.app.services.shared.graph import Edge, EdgeConnection  # noqa: E402


def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
    """Build a graph ``Edge`` from ``from_id.from_field`` to ``to_id.to_field``.

    Args:
        from_id: Id of the source node.
        from_field: Output field name on the source node.
        to_id: Id of the destination node.
        to_field: Input field name on the destination node.

    Returns:
        An ``Edge`` whose ``source`` and ``destination`` are the two connections.
    """
    src = EdgeConnection(node_id=from_id, field=from_field)
    dst = EdgeConnection(node_id=to_id, field=to_field)
    return Edge(source=src, destination=dst)


class TestEvent(EventBase):
    """Minimal ``EventBase`` subclass for use in tests."""

    __test__ = False  # not a pytest test case

    # Event name string; presumably used by EventBase to identify the event type (confirm).
    __event_name__ = "test_event"


class TestEventService(EventServiceBase):
    """Event service for tests that records dispatched events in memory.

    Every event passed to ``dispatch`` is appended to ``self.events`` so tests
    can inspect it afterwards. ``emit_invocation_progress`` is a no-op.
    """

    __test__ = False  # not a pytest test case

    def __init__(self):
        super().__init__()
        # Every event passed to ``dispatch``, in call order.
        self.events: list[EventBase] = []

    def dispatch(self, event: EventBase) -> None:
        """Record ``event`` by appending it to ``self.events``."""
        self.events.append(event)

    def emit_invocation_progress(
        self,
        queue_item: "SessionQueueItem",
        invocation: "BaseInvocation",
        message: str,
        percentage: float | None = None,
        image: "ProgressImage | None" = None,
    ) -> None:
        """Ignore invocation progress updates.

        Overridden as a no-op, so progress calls on this service neither
        dispatch nor record anything.
        """
        pass


def wait_until(condition: Callable[[], bool], timeout: float = 10, interval: float = 0.1) -> None:
    """Poll ``condition`` until it returns truthy or ``timeout`` seconds elapse.

    The condition is always evaluated at least once, including when
    ``timeout`` is 0. It is evaluated once more after the deadline passes, so
    a condition that became true during the final sleep is not missed.

    Args:
        condition: Zero-argument callable polled repeatedly.
        timeout: Maximum number of seconds to wait.
        interval: Seconds to sleep between polls.

    Raises:
        TimeoutError: If ``condition`` is still falsy once the deadline has passed.
    """
    import time

    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while True:
        if condition():
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError("Condition not met")
        # Never sleep past the deadline; the loop then does one final check.
        time.sleep(min(interval, remaining))