| | import traceback |
| | from contextlib import suppress |
| | from threading import BoundedSemaphore, Thread |
| | from threading import Event as ThreadEvent |
| | from typing import Optional |
| |
|
| | from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput |
| | from invokeai.app.services.events.events_common import ( |
| | BatchEnqueuedEvent, |
| | FastAPIEvent, |
| | QueueClearedEvent, |
| | QueueItemStatusChangedEvent, |
| | register_events, |
| | ) |
| | from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError |
| | from invokeai.app.services.invoker import Invoker |
| | from invokeai.app.services.session_processor.session_processor_base import ( |
| | InvocationServices, |
| | OnAfterRunNode, |
| | OnAfterRunSession, |
| | OnBeforeRunNode, |
| | OnBeforeRunSession, |
| | OnNodeError, |
| | OnNonFatalProcessorError, |
| | SessionProcessorBase, |
| | SessionRunnerBase, |
| | ) |
| | from invokeai.app.services.session_processor.session_processor_common import CanceledException, SessionProcessorStatus |
| | from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError |
| | from invokeai.app.services.shared.graph import NodeInputError |
| | from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context |
| | from invokeai.app.util.profiler import Profiler |
| |
|
| |
|
| | class DefaultSessionRunner(SessionRunnerBase): |
| | """Processes a single session's invocations.""" |
| |
|
| | def __init__( |
| | self, |
| | on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None, |
| | on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None, |
| | on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None, |
| | on_node_error_callbacks: Optional[list[OnNodeError]] = None, |
| | on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None, |
| | ): |
| | """ |
| | Args: |
| | on_before_run_session_callbacks: Callbacks to run before the session starts. |
| | on_before_run_node_callbacks: Callbacks to run before each node starts. |
| | on_after_run_node_callbacks: Callbacks to run after each node completes. |
| | on_node_error_callbacks: Callbacks to run when a node errors. |
| | on_after_run_session_callbacks: Callbacks to run after the session completes. |
| | """ |
| |
|
| | self._on_before_run_session_callbacks = on_before_run_session_callbacks or [] |
| | self._on_before_run_node_callbacks = on_before_run_node_callbacks or [] |
| | self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] |
| | self._on_node_error_callbacks = on_node_error_callbacks or [] |
| | self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] |
| |
|
| | def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): |
| | self._services = services |
| | self._cancel_event = cancel_event |
| | self._profiler = profiler |
| |
|
| | def _is_canceled(self) -> bool: |
| | """Check if the cancel event is set. This is also passed to the invocation context builder and called during |
| | denoising to check if the session has been canceled.""" |
| | return self._cancel_event.is_set() |
| |
|
| | def run(self, queue_item: SessionQueueItem): |
| | |
| |
|
| | self._on_before_run_session(queue_item=queue_item) |
| |
|
| | |
| | while True: |
| | try: |
| | invocation = queue_item.session.next() |
| | |
| | except NodeInputError as e: |
| | error_type = e.__class__.__name__ |
| | error_message = str(e) |
| | error_traceback = traceback.format_exc() |
| | self._on_node_error( |
| | invocation=e.node, |
| | queue_item=queue_item, |
| | error_type=error_type, |
| | error_message=error_message, |
| | error_traceback=error_traceback, |
| | ) |
| | break |
| |
|
| | if invocation is None or self._is_canceled(): |
| | break |
| |
|
| | self.run_node(invocation, queue_item) |
| |
|
| | |
| | |
| | |
| | if ( |
| | queue_item.session.is_complete() |
| | or self._is_canceled() |
| | or queue_item.status in ["failed", "canceled", "completed"] |
| | ): |
| | break |
| |
|
| | self._on_after_run_session(queue_item=queue_item) |
| |
|
| | def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): |
| | try: |
| | |
| | with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): |
| | self._on_before_run_node(invocation, queue_item) |
| |
|
| | data = InvocationContextData( |
| | invocation=invocation, |
| | source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], |
| | queue_item=queue_item, |
| | ) |
| | context = build_invocation_context( |
| | data=data, |
| | services=self._services, |
| | is_canceled=self._is_canceled, |
| | ) |
| |
|
| | |
| | output = invocation.invoke_internal(context=context, services=self._services) |
| | |
| | queue_item.session.complete(invocation.id, output) |
| |
|
| | self._on_after_run_node(invocation, queue_item, output) |
| |
|
| | except KeyboardInterrupt: |
| | |
| | pass |
| | except CanceledException: |
| | |
| | |
| | |
| | |
| | |
| | |
| | pass |
| | except Exception as e: |
| | error_type = e.__class__.__name__ |
| | error_message = str(e) |
| | error_traceback = traceback.format_exc() |
| | self._on_node_error( |
| | invocation=invocation, |
| | queue_item=queue_item, |
| | error_type=error_type, |
| | error_message=error_message, |
| | error_traceback=error_traceback, |
| | ) |
| |
|
| | def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: |
| | """Called before a session is run. |
| | |
| | - Start the profiler if profiling is enabled. |
| | - Run any callbacks registered for this event. |
| | """ |
| |
|
| | self._services.logger.debug( |
| | f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}" |
| | ) |
| |
|
| | |
| | if self._profiler is not None: |
| | self._profiler.start(profile_id=queue_item.session_id) |
| |
|
| | for callback in self._on_before_run_session_callbacks: |
| | callback(queue_item=queue_item) |
| |
|
| | def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: |
| | """Called after a session is run. |
| | |
| | - Stop the profiler if profiling is enabled. |
| | - Update the queue item's session object in the database. |
| | - If not already canceled or failed, complete the queue item. |
| | - Log and reset performance statistics. |
| | - Run any callbacks registered for this event. |
| | """ |
| |
|
| | self._services.logger.debug( |
| | f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}" |
| | ) |
| |
|
| | |
| | if self._profiler is not None: |
| | profile_path = self._profiler.stop() |
| | stats_path = profile_path.with_suffix(".json") |
| | self._services.performance_statistics.dump_stats( |
| | graph_execution_state_id=queue_item.session.id, output_path=stats_path |
| | ) |
| |
|
| | try: |
| | |
| | |
| | |
| | queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) |
| |
|
| | |
| | |
| | if queue_item.status not in ["canceled", "failed"]: |
| | queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id) |
| |
|
| | |
| | |
| | with suppress(GESStatsNotFoundError): |
| | self._services.performance_statistics.log_stats(queue_item.session.id) |
| | self._services.performance_statistics.reset_stats() |
| |
|
| | for callback in self._on_after_run_session_callbacks: |
| | callback(queue_item=queue_item) |
| | except SessionQueueItemNotFoundError: |
| | pass |
| |
|
| | def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): |
| | """Called before a node is run. |
| | |
| | - Emits an invocation started event. |
| | - Run any callbacks registered for this event. |
| | """ |
| |
|
| | self._services.logger.debug( |
| | f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" |
| | ) |
| |
|
| | |
| | self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation) |
| |
|
| | for callback in self._on_before_run_node_callbacks: |
| | callback(invocation=invocation, queue_item=queue_item) |
| |
|
| | def _on_after_run_node( |
| | self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput |
| | ): |
| | """Called after a node is run. |
| | |
| | - Emits an invocation complete event. |
| | - Run any callbacks registered for this event. |
| | """ |
| |
|
| | self._services.logger.debug( |
| | f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" |
| | ) |
| |
|
| | |
| | self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output) |
| |
|
| | for callback in self._on_after_run_node_callbacks: |
| | callback(invocation=invocation, queue_item=queue_item, output=output) |
| |
|
| | def _on_node_error( |
| | self, |
| | invocation: BaseInvocation, |
| | queue_item: SessionQueueItem, |
| | error_type: str, |
| | error_message: str, |
| | error_traceback: str, |
| | ): |
| | """Called when a node errors. Node errors may occur when running or preparing the node.. |
| | |
| | - Set the node error on the session object. |
| | - Log the error. |
| | - Fail the queue item. |
| | - Emits an invocation error event. |
| | - Run any callbacks registered for this event. |
| | """ |
| |
|
| | self._services.logger.debug( |
| | f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" |
| | ) |
| |
|
| | |
| | node_error = f"{error_type}: {error_message}" |
| | queue_item.session.set_node_error(invocation.id, node_error) |
| | self._services.logger.error( |
| | f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}" |
| | ) |
| | self._services.logger.error(error_traceback) |
| |
|
| | |
| | queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) |
| | queue_item = self._services.session_queue.fail_queue_item( |
| | queue_item.item_id, error_type, error_message, error_traceback |
| | ) |
| |
|
| | |
| | self._services.events.emit_invocation_error( |
| | queue_item=queue_item, |
| | invocation=invocation, |
| | error_type=error_type, |
| | error_message=error_message, |
| | error_traceback=error_traceback, |
| | ) |
| |
|
| | for callback in self._on_node_error_callbacks: |
| | callback( |
| | invocation=invocation, |
| | queue_item=queue_item, |
| | error_type=error_type, |
| | error_message=error_message, |
| | error_traceback=error_traceback, |
| | ) |
| |
|
| |
|
| | class DefaultSessionProcessor(SessionProcessorBase): |
| | def __init__( |
| | self, |
| | session_runner: Optional[SessionRunnerBase] = None, |
| | on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None, |
| | thread_limit: int = 1, |
| | polling_interval: int = 1, |
| | ) -> None: |
| | super().__init__() |
| |
|
| | self.session_runner = session_runner if session_runner else DefaultSessionRunner() |
| | self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] |
| | self._thread_limit = thread_limit |
| | self._polling_interval = polling_interval |
| |
|
| | def start(self, invoker: Invoker) -> None: |
| | self._invoker: Invoker = invoker |
| | self._queue_item: Optional[SessionQueueItem] = None |
| | self._invocation: Optional[BaseInvocation] = None |
| |
|
| | self._resume_event = ThreadEvent() |
| | self._stop_event = ThreadEvent() |
| | self._poll_now_event = ThreadEvent() |
| | self._cancel_event = ThreadEvent() |
| |
|
| | register_events(QueueClearedEvent, self._on_queue_cleared) |
| | register_events(BatchEnqueuedEvent, self._on_batch_enqueued) |
| | register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed) |
| |
|
| | self._thread_semaphore = BoundedSemaphore(self._thread_limit) |
| |
|
| | |
| | |
| | self._profiler = ( |
| | Profiler( |
| | logger=self._invoker.services.logger, |
| | output_dir=self._invoker.services.configuration.profiles_path, |
| | prefix=self._invoker.services.configuration.profile_prefix, |
| | ) |
| | if self._invoker.services.configuration.profile_graphs |
| | else None |
| | ) |
| |
|
| | self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) |
| | self._thread = Thread( |
| | name="session_processor", |
| | target=self._process, |
| | kwargs={ |
| | "stop_event": self._stop_event, |
| | "poll_now_event": self._poll_now_event, |
| | "resume_event": self._resume_event, |
| | "cancel_event": self._cancel_event, |
| | }, |
| | ) |
| | self._thread.start() |
| |
|
| | def stop(self, *args, **kwargs) -> None: |
| | self._stop_event.set() |
| |
|
| | def _poll_now(self) -> None: |
| | self._poll_now_event.set() |
| |
|
| | async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: |
| | if self._queue_item and self._queue_item.queue_id == event[1].queue_id: |
| | self._cancel_event.set() |
| | self._poll_now() |
| |
|
| | async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: |
| | self._poll_now() |
| |
|
| | async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: |
| | if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | if event[1].status == "canceled": |
| | self._cancel_event.set() |
| | self._poll_now() |
| |
|
| | def resume(self) -> SessionProcessorStatus: |
| | if not self._resume_event.is_set(): |
| | self._resume_event.set() |
| | return self.get_status() |
| |
|
| | def pause(self) -> SessionProcessorStatus: |
| | if self._resume_event.is_set(): |
| | self._resume_event.clear() |
| | return self.get_status() |
| |
|
| | def get_status(self) -> SessionProcessorStatus: |
| | return SessionProcessorStatus( |
| | is_started=self._resume_event.is_set(), |
| | is_processing=self._queue_item is not None, |
| | ) |
| |
|
| | def _process( |
| | self, |
| | stop_event: ThreadEvent, |
| | poll_now_event: ThreadEvent, |
| | resume_event: ThreadEvent, |
| | cancel_event: ThreadEvent, |
| | ): |
| | try: |
| | |
| | self._thread_semaphore.acquire() |
| | stop_event.clear() |
| | resume_event.set() |
| | cancel_event.clear() |
| |
|
| | while not stop_event.is_set(): |
| | poll_now_event.clear() |
| | try: |
| | |
| | |
| | resume_event.wait() |
| |
|
| | |
| | self._queue_item = self._invoker.services.session_queue.dequeue() |
| |
|
| | if self._queue_item is None: |
| | |
| | self._invoker.services.logger.debug("Waiting for next polling interval or event") |
| | poll_now_event.wait(self._polling_interval) |
| | continue |
| |
|
| | self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") |
| | cancel_event.clear() |
| |
|
| | |
| | self.session_runner.run(queue_item=self._queue_item) |
| |
|
| | except Exception as e: |
| | error_type = e.__class__.__name__ |
| | error_message = str(e) |
| | error_traceback = traceback.format_exc() |
| | self._on_non_fatal_processor_error( |
| | queue_item=self._queue_item, |
| | error_type=error_type, |
| | error_message=error_message, |
| | error_traceback=error_traceback, |
| | ) |
| | |
| | poll_now_event.wait(self._polling_interval) |
| | continue |
| | except Exception as e: |
| | |
| | error_type = e.__class__.__name__ |
| | error_message = str(e) |
| | error_traceback = traceback.format_exc() |
| | self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}") |
| | self._invoker.services.logger.error(error_traceback) |
| | pass |
| | finally: |
| | stop_event.clear() |
| | poll_now_event.clear() |
| | self._queue_item = None |
| | self._thread_semaphore.release() |
| |
|
| | def _on_non_fatal_processor_error( |
| | self, |
| | queue_item: Optional[SessionQueueItem], |
| | error_type: str, |
| | error_message: str, |
| | error_traceback: str, |
| | ) -> None: |
| | """Called when a non-fatal error occurs in the processor. |
| | |
| | - Log the error. |
| | - If a queue item is provided, update the queue item with the completed session & fail it. |
| | - Run any callbacks registered for this event. |
| | """ |
| |
|
| | self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}") |
| | self._invoker.services.logger.error(error_traceback) |
| |
|
| | if queue_item is not None: |
| | |
| | queue_item = self._invoker.services.session_queue.set_queue_item_session( |
| | queue_item.item_id, queue_item.session |
| | ) |
| | queue_item = self._invoker.services.session_queue.fail_queue_item( |
| | item_id=queue_item.item_id, |
| | error_type=error_type, |
| | error_message=error_message, |
| | error_traceback=error_traceback, |
| | ) |
| |
|
| | for callback in self._on_non_fatal_processor_error_callbacks: |
| | callback( |
| | queue_item=queue_item, |
| | error_type=error_type, |
| | error_message=error_message, |
| | error_traceback=error_traceback, |
| | ) |
| |
|