File size: 4,726 Bytes
8a37e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler


class SessionRunnerBase(ABC):
    """Abstract interface for a session runner.

    Concrete runners must implement `start`, `run` and `run_node`.
    """

    @abstractmethod
    def start(
        self,
        services: InvocationServices,
        cancel_event: Event,
        profiler: Optional[Profiler] = None,
    ) -> None:
        """Start the session runner.

        Args:
            services: The invocation services.
            cancel_event: The cancel event.
            profiler: The profiler used for session profiling via cProfile. Leave as None to disable profiling;
                basic session stats will still be recorded and logged when profiling is disabled.
        """
        ...

    @abstractmethod
    def run(self, queue_item: SessionQueueItem) -> None:
        """Run a session.

        Args:
            queue_item: The session to run.
        """
        ...

    @abstractmethod
    def run_node(
        self,
        invocation: BaseInvocation,
        queue_item: SessionQueueItem,
    ) -> None:
        """Run one node of the graph.

        Args:
            invocation: The invocation to run.
            queue_item: The session queue item.
        """
        ...


class SessionProcessorBase(ABC):
    """Abstract interface for a session processor.

    A session processor executes sessions. It runs a simple polling loop that checks the
    session queue for new sessions to execute, and it must coordinate with the invocation
    queue so that only one session executes at a time.
    """

    @abstractmethod
    def resume(self) -> SessionProcessorStatus:
        """Start or resume the session processor."""
        ...

    @abstractmethod
    def pause(self) -> SessionProcessorStatus:
        """Pause the session processor."""
        ...

    @abstractmethod
    def get_status(self) -> SessionProcessorStatus:
        """Get the status of the session processor."""
        ...


class OnBeforeRunNode(Protocol):
    """Signature of a callback invoked before a node is executed."""

    def __call__(
        self,
        invocation: BaseInvocation,
        queue_item: SessionQueueItem,
    ) -> None:
        """Handle the "before run node" event.

        Args:
            invocation: The invocation that is about to be executed.
            queue_item: The session queue item.
        """
        ...


class OnAfterRunNode(Protocol):
    """Signature of a callback invoked after a node has been executed."""

    def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
        """Callback to run after executing a node.

        Args:
            invocation: The invocation that was executed.
            queue_item: The session queue item.
            output: The output of the invocation.
        """
        ...


class OnNodeError(Protocol):
    """Signature of a callback invoked when a node has an error."""

    def __call__(
        self,
        invocation: BaseInvocation,
        queue_item: SessionQueueItem,
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Handle a node error.

        Args:
            invocation: The invocation that errored.
            queue_item: The session queue item.
            error_type: Name of the error type, e.g. "ValueError".
            error_message: Message of the error, e.g. "Invalid value".
            error_traceback: The error traceback, stringified.
        """
        ...


class OnBeforeRunSession(Protocol):
    """Signature of a callback invoked before a session is executed."""

    def __call__(self, queue_item: SessionQueueItem) -> None:
        """Handle the "before run session" event.

        Args:
            queue_item: The session queue item.
        """
        ...


class OnAfterRunSession(Protocol):
    """Signature of a callback invoked after a session has been executed."""

    def __call__(self, queue_item: SessionQueueItem) -> None:
        """Handle the "after run session" event.

        Args:
            queue_item: The session queue item.
        """
        ...


class OnNonFatalProcessorError(Protocol):
    """Signature of a callback invoked when the processor hits a non-fatal error."""

    def __call__(
        self,
        queue_item: Optional[SessionQueueItem],
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Handle a non-fatal processor error.

        Args:
            queue_item: The session queue item being executed when the error occurred, or None if there was none.
            error_type: Name of the error type, e.g. "ValueError".
            error_message: Message of the error, e.g. "Invalid value".
            error_traceback: The error traceback, stringified.
        """
        ...