File size: 26,238 Bytes
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
"""Tree-based message queue: index, async node processor, and public manager API."""

import asyncio
from collections.abc import Awaitable, Callable

from loguru import logger

from config.settings import get_settings
from core.anthropic import get_user_facing_error_message

from ..models import IncomingMessage
from ..safe_diagnostics import format_exception_for_log
from .data import MessageNode, MessageState, MessageTree


class TreeRepository:
    """
    In-memory index of trees and node-to-root mappings.

    ``_trees`` maps a root ID to its :class:`MessageTree`. ``_node_to_tree``
    maps every registered ID to the root ID of the tree it belongs to; a
    registered ID is not required to be an actual node of that tree (see
    :meth:`resolve_parent_node_id`).

    Used only by :class:`TreeQueueManager`; kept as a named type for tests.
    """

    def __init__(self) -> None:
        self._trees: dict[str, MessageTree] = {}  # root_id -> tree
        self._node_to_tree: dict[str, str] = {}  # node_id -> root_id

    def get_tree(self, root_id: str) -> MessageTree | None:
        """Get a tree by its root ID, or None if no such tree is stored."""
        return self._trees.get(root_id)

    def get_tree_for_node(self, node_id: str) -> MessageTree | None:
        """Get the tree a registered ID maps to, or None if it is unknown."""
        root_id = self._node_to_tree.get(node_id)
        if not root_id:
            return None
        return self._trees.get(root_id)

    def get_node(self, node_id: str) -> MessageNode | None:
        """Get a node from any tree, or None if its tree cannot be found."""
        tree = self.get_tree_for_node(node_id)
        return tree.get_node(node_id) if tree else None

    def add_tree(self, root_id: str, tree: MessageTree) -> None:
        """Store a tree under ``root_id`` and map the root ID to itself."""
        self._trees[root_id] = tree
        self._node_to_tree[root_id] = root_id
        logger.debug("TREE_REPO: add_tree root_id={}", root_id)

    def register_node(self, node_id: str, root_id: str) -> None:
        """Map ``node_id`` to the tree rooted at ``root_id``."""
        self._node_to_tree[node_id] = root_id
        logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)

    def has_node(self, node_id: str) -> bool:
        """Check if an ID is registered in the node-to-tree mapping."""
        return node_id in self._node_to_tree

    def tree_count(self) -> int:
        """Get the number of trees in the repository."""
        return len(self._trees)

    def is_tree_busy(self, root_id: str) -> bool:
        """Check if a tree is currently processing (False for unknown trees)."""
        tree = self._trees.get(root_id)
        return tree.is_processing if tree else False

    def is_node_tree_busy(self, node_id: str) -> bool:
        """Check if the tree containing a node is busy (False if not found)."""
        tree = self.get_tree_for_node(node_id)
        return tree.is_processing if tree else False

    def get_queue_size(self, node_id: str) -> int:
        """Get queue size for the tree containing a node (0 if not found)."""
        tree = self.get_tree_for_node(node_id)
        return tree.get_queue_size() if tree else 0

    def resolve_parent_node_id(self, msg_id: str) -> str | None:
        """
        Resolve a message ID to the actual parent node ID.

        Handles the case where msg_id is a status message ID
        (which maps to the tree but isn't an actual node).

        Returns:
            The node_id to use as parent, or None if not found
        """
        tree = self.get_tree_for_node(msg_id)
        if not tree:
            return None

        # The ID is itself a node of the tree.
        if tree.has_node(msg_id):
            return msg_id

        # Otherwise look for the node that owns this status message.
        node = tree.find_node_by_status_message(msg_id)
        return node.node_id if node else None

    def get_pending_children(self, node_id: str) -> list[MessageNode]:
        """
        Get all pending child nodes (recursively) of a given node.

        Used for error propagation - when a node fails, its pending
        children should also be marked as failed.

        The walk only descends through children that are themselves
        PENDING; the subtree below a non-pending child is not visited.
        """
        tree = self.get_tree_for_node(node_id)
        if not tree:
            return []

        pending: list[MessageNode] = []
        stack = [node_id]

        while stack:
            current_id = stack.pop()
            node = tree.get_node(current_id)
            if not node:
                continue
            for child_id in node.children_ids:
                child = tree.get_node(child_id)
                if child and child.state == MessageState.PENDING:
                    pending.append(child)
                    stack.append(child_id)

        return pending

    def all_trees(self) -> list[MessageTree]:
        """Get all trees in the repository (as a new list)."""
        return list(self._trees.values())

    def tree_ids(self) -> list[str]:
        """Get all tree root IDs (as a new list)."""
        return list(self._trees)

    def unregister_nodes(self, node_ids: list[str]) -> None:
        """Remove node IDs from the node-to-tree mapping; unknown IDs are ignored."""
        for nid in node_ids:
            self._node_to_tree.pop(nid, None)

    def remove_tree(self, root_id: str) -> MessageTree | None:
        """
        Remove a tree and all its node mappings from the repository.

        Besides the IDs of the tree's own nodes, every other registered ID
        that still points at ``root_id`` is dropped as well, so no mapping
        is left referring to a tree that no longer exists.

        Returns:
            The removed tree, or None if not found.
        """
        tree = self._trees.pop(root_id, None)
        if not tree:
            return None
        for node in tree.all_nodes():
            self._node_to_tree.pop(node.node_id, None)
        # IDs added via register_node() need not be nodes of the tree, so
        # they are not covered by the loop above. Collect first, then delete,
        # to avoid mutating the dict while iterating it.
        dangling = [
            mapped_id
            for mapped_id, mapped_root in self._node_to_tree.items()
            if mapped_root == root_id
        ]
        for mapped_id in dangling:
            del self._node_to_tree[mapped_id]
        logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
        return tree

    def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
        """Get all message IDs (incoming + status) for a given platform/chat."""
        # Compare as strings; convert the loop-invariant arguments once.
        wanted_platform = str(platform)
        wanted_chat = str(chat_id)

        msg_ids: set[str] = set()
        for tree in self._trees.values():
            for node in tree.all_nodes():
                incoming = node.incoming
                if str(incoming.platform) != wanted_platform:
                    continue
                if str(incoming.chat_id) != wanted_chat:
                    continue
                if incoming.message_id is not None:
                    msg_ids.add(str(incoming.message_id))
                if node.status_message_id:
                    msg_ids.add(str(node.status_message_id))
        return msg_ids

    def to_dict(self) -> dict:
        """Serialize all trees and a copy of the node-to-tree mapping."""
        return {
            "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
            "node_to_tree": self._node_to_tree.copy(),
        }

    @classmethod
    def from_dict(cls, data: dict) -> TreeRepository:
        """
        Deserialize from dictionary.

        Reads the ``"trees"`` and ``"node_to_tree"`` keys produced by
        :meth:`to_dict`; a missing or ``None`` value is treated as empty.
        The mapping is copied so later registrations on the repository do
        not mutate ``data``.
        """
        repo = cls()
        for root_id, tree_data in (data.get("trees") or {}).items():
            repo._trees[root_id] = MessageTree.from_dict(tree_data)
        repo._node_to_tree = dict(data.get("node_to_tree") or {})
        return repo


class TreeQueueProcessor:
    """
    Per-tree async queue processing (one manager owns one processor instance).

    Each node runs in its own asyncio task. When that task finishes, the
    next queued node of the same tree is started from within it.
    """

    def __init__(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
    ) -> None:
        self._queue_update_callback = queue_update_callback
        self._node_started_callback = node_started_callback

    def set_queue_update_callback(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used to refresh queue positions."""
        self._queue_update_callback = queue_update_callback

    def set_node_started_callback(
        self,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used when a queued node starts processing."""
        self._node_started_callback = node_started_callback

    async def _notify_queue_updated(self, tree: MessageTree) -> None:
        """Invoke queue update callback if set; failures are logged, not raised."""
        if not self._queue_update_callback:
            return
        try:
            await self._queue_update_callback(tree)
        except Exception as e:
            d = get_settings().log_messaging_error_details
            logger.warning(
                "Queue update callback failed: {}",
                format_exception_for_log(e, log_full_message=d),
            )

    async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
        """Invoke node started callback if set; failures are logged, not raised."""
        if not self._node_started_callback:
            return
        try:
            await self._node_started_callback(tree, node_id)
        except Exception as e:
            d = get_settings().log_messaging_error_details
            logger.warning(
                "Node started callback failed: {}",
                format_exception_for_log(e, log_full_message=d),
            )

    async def process_node(
        self,
        tree: MessageTree,
        node: MessageNode,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """
        Process a single node and then check the queue.

        A node already in ERROR state is skipped and the next queued node is
        started instead. Otherwise ``processor`` is awaited: a failure is
        logged and the node is moved to ERROR with a user-facing message,
        while cancellation is logged and re-raised. In every case the tree's
        current node is cleared and the next queued node is started.

        Args:
            tree: Tree that owns ``node``.
            node: Node to process.
            processor: Coroutine function called with ``(node_id, node)``.

        Raises:
            asyncio.CancelledError: If the task is cancelled while processing.
        """
        if node.state == MessageState.ERROR:
            logger.info(
                f"Skipping node {node.node_id} as it is already in state {node.state}"
            )
            await self._process_next(tree, processor)
            return

        try:
            await processor(node.node_id, node)
        except asyncio.CancelledError:
            logger.info(f"Task for node {node.node_id} was cancelled")
            raise
        except Exception as e:
            d = get_settings().log_messaging_error_details
            logger.error(
                "Error processing node {}: {}",
                node.node_id,
                format_exception_for_log(e, log_full_message=d),
            )
            await tree.update_state(
                node.node_id,
                MessageState.ERROR,
                error_message=get_user_facing_error_message(e),
            )
        finally:
            # Runs on success, failure and cancellation alike so the queue
            # keeps draining after this node.
            async with tree.with_lock():
                tree.clear_current_node()
            await self._process_next(tree, processor)

    async def _process_next(
        self,
        tree: MessageTree,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """
        Process the next message in queue, if any.

        Queue entries whose node can no longer be found in the tree are
        skipped. If the queue runs empty the tree is marked as free;
        otherwise a task is started for the dequeued node and the
        node-started and queue-update callbacks are invoked.
        """
        next_node_id: str | None = None
        async with tree.with_lock():
            while True:
                next_node_id = await tree.dequeue()

                if not next_node_id:
                    tree.set_processing_state(None, False)
                    logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
                    return

                node = tree.get_node(next_node_id)
                if node:
                    break

                # Marking the tree busy for a node that does not exist would
                # leave it busy forever: no task would ever run to free it.
                logger.warning(
                    "Queued node {} not found in tree {}; skipping",
                    next_node_id,
                    tree.root_id,
                )

            tree.set_processing_state(next_node_id, True)
            logger.info(f"Processing next queued node {next_node_id}")
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )

        # Callbacks run outside the tree lock.
        await self._notify_node_started(tree, next_node_id)
        await self._notify_queue_updated(tree)

    async def enqueue_and_start(
        self,
        tree: MessageTree,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """
        Enqueue a node or start processing immediately.

        If the tree is idle but ``node_id`` is not a node of the tree, an
        error is logged and nothing is started; the tree stays idle.

        Returns:
            True if queued, False if processing immediately
            (or if the node could not be found).
        """
        async with tree.with_lock():
            if tree.is_processing:
                tree.put_queue_unlocked(node_id)
                queue_size = tree.get_queue_size()
                logger.info(f"Queued node {node_id}, position {queue_size}")
                return True

            # Look the node up before flagging the tree as busy; otherwise a
            # missing node would leave the tree busy with no task to free it.
            node = tree.get_node(node_id)
            if not node:
                logger.error(
                    "Cannot start node {}: not found in tree {}",
                    node_id,
                    tree.root_id,
                )
                return False

            tree.set_processing_state(node_id, True)
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )
            return False

    def cancel_current(self, tree: MessageTree) -> bool:
        """Cancel the currently running task in a tree."""
        return tree.cancel_current_task()


class TreeQueueManager:
    """
    Manages multiple message trees: index + async processing.

    Each new conversation creates a new tree.
    Replies to existing messages add nodes to existing trees.

    Lookups are delegated to a :class:`TreeRepository`, scheduling to a
    :class:`TreeQueueProcessor`.
    """

    def __init__(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
        _repository: TreeRepository | None = None,
    ) -> None:
        self._repository = _repository or TreeRepository()
        self._processor = TreeQueueProcessor(
            queue_update_callback=queue_update_callback,
            node_started_callback=node_started_callback,
        )
        # Guards repository changes made by create_tree/add_to_tree/cancel_all.
        self._lock = asyncio.Lock()

        logger.info("TreeQueueManager initialized")

    async def create_tree(
        self,
        node_id: str,
        incoming: IncomingMessage,
        status_message_id: str,
    ) -> MessageTree:
        """
        Create a new tree with a root node.

        Args:
            node_id: ID for the root node
            incoming: The incoming message
            status_message_id: Bot's status message ID

        Returns:
            The created MessageTree
        """
        async with self._lock:
            root_node = MessageNode(
                node_id=node_id,
                incoming=incoming,
                status_message_id=status_message_id,
                state=MessageState.PENDING,
            )

            tree = MessageTree(root_node)
            self._repository.add_tree(node_id, tree)

            logger.info(f"Created new tree with root {node_id}")
            return tree

    async def add_to_tree(
        self,
        parent_node_id: str,
        node_id: str,
        incoming: IncomingMessage,
        status_message_id: str,
    ) -> tuple[MessageTree, MessageNode]:
        """
        Add a reply as a child node to an existing tree.

        Args:
            parent_node_id: ID of the parent message
            node_id: ID for the new node
            incoming: The incoming reply message
            status_message_id: Bot's status message ID

        Returns:
            Tuple of (tree, new_node)

        Raises:
            ValueError: If the parent is not registered or its tree is gone.
        """
        async with self._lock:
            if not self._repository.has_node(parent_node_id):
                raise ValueError(f"Parent node {parent_node_id} not found in any tree")

            tree = self._repository.get_tree_for_node(parent_node_id)
            if not tree:
                raise ValueError(f"Parent node {parent_node_id} not found in any tree")

        # The node is added outside the manager lock.
        node = await tree.add_node(
            node_id=node_id,
            incoming=incoming,
            status_message_id=status_message_id,
            parent_id=parent_node_id,
        )

        async with self._lock:
            self._repository.register_node(node_id, tree.root_id)

        logger.info(f"Added node {node_id} to tree {tree.root_id}")
        return tree, node

    def get_tree(self, root_id: str) -> MessageTree | None:
        """Get a tree by its root ID."""
        return self._repository.get_tree(root_id)

    def get_tree_for_node(self, node_id: str) -> MessageTree | None:
        """Get the tree containing a given node."""
        return self._repository.get_tree_for_node(node_id)

    def get_node(self, node_id: str) -> MessageNode | None:
        """Get a node from any tree."""
        return self._repository.get_node(node_id)

    def resolve_parent_node_id(self, msg_id: str) -> str | None:
        """Resolve a message ID to the actual parent node ID."""
        return self._repository.resolve_parent_node_id(msg_id)

    def is_tree_busy(self, root_id: str) -> bool:
        """Check if a tree is currently processing."""
        return self._repository.is_tree_busy(root_id)

    def is_node_tree_busy(self, node_id: str) -> bool:
        """Check if the tree containing a node is busy."""
        return self._repository.is_node_tree_busy(node_id)

    async def enqueue(
        self,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """
        Enqueue a node for processing.

        If the tree is not busy, processing starts immediately.
        If busy, the message is queued.

        Args:
            node_id: Node to process
            processor: Async function to process the node

        Returns:
            True if queued, False if processing immediately.
            False is also returned (after logging an error) when no tree
            is found for ``node_id``.
        """
        tree = self._repository.get_tree_for_node(node_id)
        if not tree:
            logger.error(f"No tree found for node {node_id}")
            return False

        return await self._processor.enqueue_and_start(tree, node_id, processor)

    def get_queue_size(self, node_id: str) -> int:
        """Get queue size for the tree containing a node."""
        return self._repository.get_queue_size(node_id)

    def get_pending_children(self, node_id: str) -> list[MessageNode]:
        """Get all pending child nodes (recursively) of a given node."""
        return self._repository.get_pending_children(node_id)

    async def mark_node_error(
        self,
        node_id: str,
        error_message: str,
        propagate_to_children: bool = True,
    ) -> list[MessageNode]:
        """
        Mark a node as ERROR and optionally propagate to pending children.

        Args:
            node_id: The node to mark as error
            error_message: Error description
            propagate_to_children: If True, also mark pending children as error

        Returns:
            List of all nodes marked as error (including children);
            empty if no tree is found for ``node_id``.
        """
        tree = self._repository.get_tree_for_node(node_id)
        if not tree:
            return []

        affected: list[MessageNode] = []
        node = tree.get_node(node_id)
        if node:
            await tree.update_state(
                node_id, MessageState.ERROR, error_message=error_message
            )
            affected.append(node)

        if propagate_to_children:
            for child in self._repository.get_pending_children(node_id):
                await tree.update_state(
                    child.node_id,
                    MessageState.ERROR,
                    error_message=f"Parent failed: {error_message}",
                )
                affected.append(child)

        return affected

    async def cancel_tree(self, root_id: str) -> list[MessageNode]:
        """
        Cancel all queued and in-progress messages in a tree.

        Updates node states to ERROR and returns list of affected nodes
        that were actually active or in the current processing queue.
        Other PENDING/IN_PROGRESS nodes are marked as stale but are not
        part of the returned list.
        """
        tree = self._repository.get_tree(root_id)
        if not tree:
            return []

        cancelled_nodes: list[MessageNode] = []

        cleanup_count = 0
        async with tree.with_lock():
            # 1. The node whose task is running right now.
            if tree.cancel_current_task():
                current_id = tree.current_node_id
                if current_id:
                    node = tree.get_node(current_id)
                    if node and node.state not in (
                        MessageState.COMPLETED,
                        MessageState.ERROR,
                    ):
                        tree.set_node_error_sync(node, "Cancelled by user")
                        cancelled_nodes.append(node)

            # 2. Everything still waiting in the queue.
            queue_nodes = tree.drain_queue_and_mark_cancelled()
            cancelled_nodes.extend(queue_nodes)
            cancelled_ids = {n.node_id for n in cancelled_nodes}

            # 3. Unfinished nodes that were neither running nor queued.
            for node in tree.all_nodes():
                if (
                    node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
                    and node.node_id not in cancelled_ids
                ):
                    tree.set_node_error_sync(node, "Stale task cleaned up")
                    cleanup_count += 1

            tree.reset_processing_state()

        if cancelled_nodes:
            logger.info(
                f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}"
            )
        if cleanup_count:
            logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}")

        return cancelled_nodes

    async def cancel_node(self, node_id: str) -> list[MessageNode]:
        """
        Cancel a single node (queued or in-progress) without affecting other nodes.

        Returns:
            List containing the cancelled node if it was cancellable, else empty list.
        """
        tree = self._repository.get_tree_for_node(node_id)
        if not tree:
            return []

        async with tree.with_lock():
            node = tree.get_node(node_id)
            if not node:
                return []

            if node.state in (MessageState.COMPLETED, MessageState.ERROR):
                return []

            if tree.is_current_node(node_id):
                self._processor.cancel_current(tree)

            # Best effort: a node left in the queue is skipped later because
            # its state is ERROR by then.
            try:
                tree.remove_from_queue(node_id)
            except Exception:
                logger.debug(
                    "Failed to remove node from queue; will rely on state=ERROR"
                )

            tree.set_node_error_sync(node, "Cancelled by user")

            return [node]

    async def cancel_all(self) -> list[MessageNode]:
        """Cancel all messages in all trees."""
        async with self._lock:
            all_cancelled: list[MessageNode] = []
            # tree_ids() returns a fresh list, so it is safe to iterate
            # while awaiting inside the loop.
            for root_id in self._repository.tree_ids():
                all_cancelled.extend(await self.cancel_tree(root_id))
            return all_cancelled

    def cleanup_stale_nodes(self) -> int:
        """
        Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR.
        Used on startup to reconcile restored state.

        Returns:
            Number of nodes that were marked.
        """
        count = 0
        for tree in self._repository.all_trees():
            for node in tree.all_nodes():
                if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
                    tree.set_node_error_sync(node, "Lost during server restart")
                    count += 1
        if count:
            logger.info(f"Cleaned up {count} stale nodes during startup")
        return count

    def get_tree_count(self) -> int:
        """Get the number of active message trees."""
        return self._repository.tree_count()

    def set_queue_update_callback(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
    ) -> None:
        """Set callback for queue position updates."""
        self._processor.set_queue_update_callback(queue_update_callback)

    def set_node_started_callback(
        self,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
    ) -> None:
        """Set callback for when a queued node starts processing."""
        self._processor.set_node_started_callback(node_started_callback)

    def register_node(self, node_id: str, root_id: str) -> None:
        """Register a node ID to a tree (for external mapping)."""
        self._repository.register_node(node_id, root_id)

    async def cancel_branch(self, branch_root_id: str) -> list[MessageNode]:
        """
        Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).

        Nodes already COMPLETED or in ERROR are left untouched.

        Returns:
            The nodes that were marked as cancelled.
        """
        tree = self._repository.get_tree_for_node(branch_root_id)
        if not tree:
            return []

        branch_ids = set(tree.get_descendants(branch_root_id))
        cancelled: list[MessageNode] = []

        async with tree.with_lock():
            for nid in branch_ids:
                node = tree.get_node(nid)
                if not node or node.state in (
                    MessageState.COMPLETED,
                    MessageState.ERROR,
                ):
                    continue

                if tree.is_current_node(nid):
                    self._processor.cancel_current(tree)
                else:
                    # Best effort, as in cancel_node(): one failing removal
                    # must not abort the loop and leave the branch half
                    # cancelled. A node left in the queue is skipped later
                    # because its state is ERROR by then.
                    try:
                        tree.remove_from_queue(nid)
                    except Exception:
                        logger.debug(
                            "Failed to remove node from queue; will rely on state=ERROR"
                        )

                tree.set_node_error_sync(node, "Cancelled by user")
                cancelled.append(node)

        if cancelled:
            logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}")
        return cancelled

    async def remove_branch(
        self, branch_root_id: str
    ) -> tuple[list[MessageNode], str, bool]:
        """
        Remove a branch (subtree) from the tree.

        If branch_root is the tree root, removes the entire tree
        (after cancelling it).

        Returns:
            (removed_nodes, root_id, removed_entire_tree);
            ``([], "", False)`` if no tree is found for ``branch_root_id``.
        """
        tree = self._repository.get_tree_for_node(branch_root_id)
        if not tree:
            return ([], "", False)

        root_id = tree.root_id

        if branch_root_id == root_id:
            cancelled = await self.cancel_tree(root_id)
            removed_tree = self._repository.remove_tree(root_id)
            if removed_tree:
                return (removed_tree.all_nodes(), root_id, True)
            # The tree was already gone from the repository; report the
            # nodes that were cancelled instead.
            return (cancelled, root_id, True)

        async with tree.with_lock():
            removed = tree.remove_branch(branch_root_id)

        self._repository.unregister_nodes([n.node_id for n in removed])
        return (removed, root_id, False)

    def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
        """Get all message IDs for a given platform/chat."""
        return self._repository.get_message_ids_for_chat(platform, chat_id)

    def to_dict(self) -> dict:
        """Serialize all trees."""
        return self._repository.to_dict()

    @classmethod
    def from_dict(
        cls,
        data: dict,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
    ) -> TreeQueueManager:
        """Deserialize from dictionary (the format produced by :meth:`to_dict`)."""
        return cls(
            queue_update_callback=queue_update_callback,
            node_started_callback=node_started_callback,
            _repository=TreeRepository.from_dict(data),
        )


# Names exported by a star-import of this module.
__all__ = [
    "TreeQueueManager",
    "TreeQueueProcessor",
    "TreeRepository",
]