from __future__ import annotations from typing import Optional """Async queue processor for message trees. Handles the async processing lifecycle of tree nodes. """ import asyncio from collections.abc import Awaitable, Callable from loguru import logger from providers.common import get_user_facing_error_message from .data import MessageNode, MessageState, MessageTree class TreeQueueProcessor: """ Handles async queue processing for a single tree. Separates the async processing logic from the data management. """ def __init__( self, queue_update_callback: Callable[[MessageTree], Optional[Awaitable[None]]] = None, node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None = None, ): self._queue_update_callback = queue_update_callback self._node_started_callback = node_started_callback def set_queue_update_callback( self, queue_update_callback: Callable[[MessageTree], Optional[Awaitable[None]]], ) -> None: """Update the callback used to refresh queue positions.""" self._queue_update_callback = queue_update_callback def set_node_started_callback( self, node_started_callback: Callable[[MessageTree, str], Optional[Awaitable[None]]], ) -> None: """Update the callback used when a queued node starts processing.""" self._node_started_callback = node_started_callback async def _notify_queue_updated(self, tree: MessageTree) -> None: """Invoke queue update callback if set.""" if not self._queue_update_callback: return try: await self._queue_update_callback(tree) except Exception as e: logger.warning(f"Queue update callback failed: {e}") async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None: """Invoke node started callback if set.""" if not self._node_started_callback: return try: await self._node_started_callback(tree, node_id) except Exception as e: logger.warning(f"Node started callback failed: {e}") async def process_node( self, tree: MessageTree, node: MessageNode, processor: Callable[[str, MessageNode], Awaitable[None]], ) -> None: """Process a single node and then check the queue.""" # Skip if already in terminal state (e.g. from error propagation) if node.state == MessageState.ERROR: logger.info( f"Skipping node {node.node_id} as it is already in state {node.state}" ) # Still need to check for next messages await self._process_next(tree, processor) return try: await processor(node.node_id, node) except asyncio.CancelledError: logger.info(f"Task for node {node.node_id} was cancelled") raise except Exception as e: logger.error(f"Error processing node {node.node_id}: {e}") await tree.update_state( node.node_id, MessageState.ERROR, error_message=get_user_facing_error_message(e), ) finally: async with tree.with_lock(): tree.clear_current_node() # Check if there are more messages in the queue await self._process_next(tree, processor) async def _process_next( self, tree: MessageTree, processor: Callable[[str, MessageNode], Awaitable[None]], ) -> None: """Process the next message in queue, if any.""" next_node_id = None node = None async with tree.with_lock(): next_node_id = await tree.dequeue() if not next_node_id: tree.set_processing_state(None, False) logger.debug(f"Tree {tree.root_id} queue empty, marking as free") return tree.set_processing_state(next_node_id, True) logger.info(f"Processing next queued node {next_node_id}") # Process next node (outside lock) node = tree.get_node(next_node_id) if node: tree.set_current_task( asyncio.create_task(self.process_node(tree, node, processor)) ) # Notify that this node has started processing and refresh queue positions. if next_node_id: await self._notify_node_started(tree, next_node_id) await self._notify_queue_updated(tree) async def enqueue_and_start( self, tree: MessageTree, node_id: str, processor: Callable[[str, MessageNode], Awaitable[None]], ) -> bool: """ Enqueue a node or start processing immediately. Args: tree: The message tree node_id: Node to process processor: Async function to process the node Returns: True if queued, False if processing immediately """ async with tree.with_lock(): if tree.is_processing: tree.put_queue_unlocked(node_id) queue_size = tree.get_queue_size() logger.info(f"Queued node {node_id}, position {queue_size}") return True else: tree.set_processing_state(node_id, True) # Process outside the lock node = tree.get_node(node_id) if node: tree.set_current_task( asyncio.create_task(self.process_node(tree, node, processor)) ) return False def cancel_current(self, tree: MessageTree) -> bool: """Cancel the currently running task in a tree.""" return tree.cancel_current_task()