File size: 5,828 Bytes
6172a47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""Async queue processor for message trees.

Handles the async processing lifecycle of tree nodes.
"""

import asyncio
from collections.abc import Awaitable, Callable

from loguru import logger

from providers.common import get_user_facing_error_message

from .data import MessageNode, MessageState, MessageTree


class TreeQueueProcessor:
    """
    Handles async queue processing for a single tree.

    Separates the async processing logic from the data management.

    Nodes are run through a caller-supplied ``processor`` coroutine inside an
    ``asyncio`` task that is registered on the tree via ``set_current_task``.
    When a node finishes (successfully, with an error, or by cancellation) the
    next queued node, if any, is dequeued and scheduled.

    Two optional async callbacks can be supplied:

    * ``queue_update_callback(tree)`` - awaited after a queued node has been
      started, so queue positions can be refreshed.
    * ``node_started_callback(tree, node_id)`` - awaited when a queued node
      starts processing.

    Exceptions raised by either callback are logged as warnings and are not
    propagated.
    """

    def __init__(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
    ):
        """
        Create a processor with optional notification callbacks.

        Args:
            queue_update_callback: Awaited with the tree after a queued node
                has been started; ``None`` disables the notification.
            node_started_callback: Awaited with the tree and node id when a
                queued node starts; ``None`` disables the notification.
        """
        self._queue_update_callback = queue_update_callback
        self._node_started_callback = node_started_callback

    def set_queue_update_callback(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used to refresh queue positions."""
        self._queue_update_callback = queue_update_callback

    def set_node_started_callback(
        self,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used when a queued node starts processing."""
        self._node_started_callback = node_started_callback

    async def _notify_queue_updated(self, tree: MessageTree) -> None:
        """Invoke queue update callback if set; callback errors are only logged."""
        if not self._queue_update_callback:
            return
        try:
            await self._queue_update_callback(tree)
        except Exception as e:
            # Best-effort notification: a failing callback must not break
            # queue processing.
            logger.warning(f"Queue update callback failed: {e}")

    async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
        """Invoke node started callback if set; callback errors are only logged."""
        if not self._node_started_callback:
            return
        try:
            await self._node_started_callback(tree, node_id)
        except Exception as e:
            # Best-effort notification: a failing callback must not break
            # queue processing.
            logger.warning(f"Node started callback failed: {e}")

    async def process_node(
        self,
        tree: MessageTree,
        node: MessageNode,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """
        Process a single node and then check the queue.

        Args:
            tree: The message tree the node belongs to.
            node: The node to process.
            processor: Async function awaited with ``(node.node_id, node)``.

        Raises:
            asyncio.CancelledError: Re-raised when the task is cancelled while
                the processor is running. The next queued node is still
                scheduled before the cancellation propagates.
        """
        # Skip nodes already marked ERROR (e.g. from error propagation)
        if node.state == MessageState.ERROR:
            logger.info(
                f"Skipping node {node.node_id} as it is already in state {node.state}"
            )
            # Still need to check for next messages
            await self._process_next(tree, processor)
            return

        try:
            await processor(node.node_id, node)
        except asyncio.CancelledError:
            logger.info(f"Task for node {node.node_id} was cancelled")
            raise
        except Exception as e:
            # Record the failure on the node instead of letting it escape the
            # task, so the queue keeps draining.
            logger.error(f"Error processing node {node.node_id}: {e}")
            await tree.update_state(
                node.node_id,
                MessageState.ERROR,
                error_message=get_user_facing_error_message(e),
            )
        finally:
            async with tree.with_lock():
                tree.clear_current_node()
            # Check if there are more messages in the queue
            await self._process_next(tree, processor)

    async def _process_next(
        self,
        tree: MessageTree,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """
        Process the next message in queue, if any.

        Queued ids that no longer resolve to a node are skipped with a warning
        so that a stale entry cannot leave the tree marked as processing with
        no task running. When the queue is exhausted the tree is marked as
        free.

        Args:
            tree: The message tree whose queue should be advanced.
            processor: Async function used to process the next node.
        """
        async with tree.with_lock():
            while True:
                next_node_id = await tree.dequeue()

                if not next_node_id:
                    tree.set_processing_state(None, False)
                    logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
                    return

                node = tree.get_node(next_node_id)
                if node:
                    break

                # Nothing to run for this id; try the next queued entry rather
                # than stalling the whole tree.
                logger.warning(
                    f"Queued node {next_node_id} not found in tree "
                    f"{tree.root_id}, skipping"
                )

            tree.set_processing_state(next_node_id, True)
            logger.info(f"Processing next queued node {next_node_id}")

            # The task is created while the lock is held; it only begins
            # running once this coroutine yields to the event loop.
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )

        # Notify that this node has started processing and refresh queue positions.
        await self._notify_node_started(tree, next_node_id)
        await self._notify_queue_updated(tree)

    async def enqueue_and_start(
        self,
        tree: MessageTree,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """
        Enqueue a node or start processing immediately.

        Args:
            tree: The message tree
            node_id: Node to process
            processor: Async function to process the node

        Returns:
            True if queued, False if processing immediately. False is also
            returned when the tree is idle but ``node_id`` does not resolve to
            a node; in that case nothing is scheduled and the tree stays free.
        """
        async with tree.with_lock():
            if tree.is_processing:
                tree.put_queue_unlocked(node_id)
                queue_size = tree.get_queue_size()
                logger.info(f"Queued node {node_id}, position {queue_size}")
                return True

            # Resolve the node before marking the tree busy: flagging the tree
            # as processing without a running task would block every later
            # enqueue behind a node that never completes.
            node = tree.get_node(node_id)
            if not node:
                logger.warning(
                    f"Node {node_id} not found in tree {tree.root_id}, "
                    "nothing to process"
                )
                return False

            tree.set_processing_state(node_id, True)

            # The task is created while the lock is held; it only begins
            # running once this coroutine yields to the event loop.
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )
            return False

    def cancel_current(self, tree: MessageTree) -> bool:
        """Cancel the currently running task in a tree.

        Returns:
            The result of ``tree.cancel_current_task()``.
        """
        return tree.cancel_current_task()