from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler


class SessionRunnerBase(ABC):
    """
    Base class for session runner.
    """

    @abstractmethod
    def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
        """Starts the session runner.

        Args:
            services: The invocation services.
            cancel_event: The cancel event.
            profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic
                session stats will still be recorded and logged when profiling is disabled.
        """
        pass

    @abstractmethod
    def run(self, queue_item: SessionQueueItem) -> None:
        """Runs a session.

        Args:
            queue_item: The session to run.
        """
        pass

    @abstractmethod
    def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
        """Run a single node in the graph.

        Args:
            invocation: The invocation to run.
            queue_item: The session queue item.
        """
        pass


class SessionProcessorBase(ABC):
    """
    Base class for session processor.

    The session processor is responsible for executing sessions. It runs a simple polling loop,
    checking the session queue for new sessions to execute. It must coordinate with the
    invocation queue to ensure only one session is executing at a time.
    """

    @abstractmethod
    def resume(self) -> SessionProcessorStatus:
        """Starts or resumes the session processor"""
        pass

    @abstractmethod
    def pause(self) -> SessionProcessorStatus:
        """Pauses the session processor"""
        pass

    @abstractmethod
    def get_status(self) -> SessionProcessorStatus:
        """Gets the status of the session processor"""
        pass


class OnBeforeRunNode(Protocol):
    def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
        """Callback to run before executing a node.

        Args:
            invocation: The invocation that will be executed.
            queue_item: The session queue item.
        """
        ...


class OnAfterRunNode(Protocol):
    def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
        """Callback to run after executing a node.

        Args:
            invocation: The invocation that was executed.
            queue_item: The session queue item.
            output: The output produced by the executed invocation.
        """
        ...


class OnNodeError(Protocol):
    def __call__(
        self,
        invocation: BaseInvocation,
        queue_item: SessionQueueItem,
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Callback to run when a node has an error.

        Args:
            invocation: The invocation that errored.
            queue_item: The session queue item.
            error_type: The type of error, e.g. "ValueError".
            error_message: The error message, e.g. "Invalid value".
            error_traceback: The stringified error traceback.
        """
        ...


class OnBeforeRunSession(Protocol):
    def __call__(self, queue_item: SessionQueueItem) -> None:
        """Callback to run before executing a session.

        Args:
            queue_item: The session queue item.
        """
        ...


class OnAfterRunSession(Protocol):
    def __call__(self, queue_item: SessionQueueItem) -> None:
        """Callback to run after executing a session.

        Args:
            queue_item: The session queue item.
        """
        ...


class OnNonFatalProcessorError(Protocol):
    def __call__(
        self,
        queue_item: Optional[SessionQueueItem],
        error_type: str,
        error_message: str,
        error_traceback: str,
    ) -> None:
        """Callback to run when a non-fatal error occurs in the processor.

        Args:
            queue_item: The session queue item, if one was being executed when the error occurred.
            error_type: The type of error, e.g. "ValueError".
            error_message: The error message, e.g. "Invalid value".
            error_traceback: The stringified error traceback.
        """
        ...