"""Tree data structures for message queue. Contains MessageState, MessageNode, and MessageTree classes. """ import asyncio from collections import deque from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import UTC, datetime from enum import Enum from typing import Any from loguru import logger from ..models import IncomingMessage class _SnapshotQueue: """Queue with snapshot/remove helpers, backed by a deque and a set index.""" def __init__(self) -> None: self._deque: deque[str] = deque() self._set: set[str] = set() async def put(self, item: str) -> None: self._deque.append(item) self._set.add(item) def put_nowait(self, item: str) -> None: self._deque.append(item) self._set.add(item) def get_nowait(self) -> str: if not self._deque: raise asyncio.QueueEmpty() item = self._deque.popleft() self._set.discard(item) return item def qsize(self) -> int: return len(self._deque) def get_snapshot(self) -> list[str]: """Return current queue contents in FIFO order (read-only copy).""" return list(self._deque) def remove_if_present(self, item: str) -> bool: """Remove item from queue if present (O(1) membership check). Returns True if removed.""" if item not in self._set: return False self._set.discard(item) self._deque = deque(x for x in self._deque if x != item) return True class MessageState(Enum): """State of a message node in the tree.""" PENDING = "pending" # Queued, waiting to be processed IN_PROGRESS = "in_progress" # Currently being processed by Claude COMPLETED = "completed" # Processing finished successfully ERROR = "error" # Processing failed @dataclass class MessageNode: """ A node in the message tree. Each node represents a single message and tracks: - Its relationship to parent/children - Its processing state - Claude session information """ node_id: str # Unique ID (typically message_id) incoming: IncomingMessage # The original message status_message_id: str # Bot's status message ID state: MessageState = MessageState.PENDING parent_id: str | None = None # Parent node ID (None for root) session_id: str | None = None # Claude session ID (forked from parent) children_ids: list[str] = field(default_factory=list) created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) completed_at: datetime | None = None error_message: str | None = None context: Any = None # Additional context if needed def set_context(self, context: Any) -> None: self.context = context def to_dict(self) -> dict: """Convert to dictionary for JSON serialization.""" return { "node_id": self.node_id, "incoming": { "text": self.incoming.text, "chat_id": self.incoming.chat_id, "user_id": self.incoming.user_id, "message_id": self.incoming.message_id, "platform": self.incoming.platform, "reply_to_message_id": self.incoming.reply_to_message_id, "message_thread_id": self.incoming.message_thread_id, "username": self.incoming.username, }, "status_message_id": self.status_message_id, "state": self.state.value, "parent_id": self.parent_id, "session_id": self.session_id, "children_ids": self.children_ids, "created_at": self.created_at.isoformat(), "completed_at": self.completed_at.isoformat() if self.completed_at else None, "error_message": self.error_message, } @classmethod def from_dict(cls, data: dict) -> MessageNode: """Create from dictionary (JSON deserialization).""" incoming_data = data["incoming"] incoming = IncomingMessage( text=incoming_data["text"], chat_id=incoming_data["chat_id"], user_id=incoming_data["user_id"], message_id=incoming_data["message_id"], platform=incoming_data["platform"], reply_to_message_id=incoming_data.get("reply_to_message_id"), message_thread_id=incoming_data.get("message_thread_id"), username=incoming_data.get("username"), ) return cls( node_id=data["node_id"], incoming=incoming, status_message_id=data["status_message_id"], state=MessageState(data["state"]), parent_id=data.get("parent_id"), session_id=data.get("session_id"), children_ids=data.get("children_ids", []), created_at=datetime.fromisoformat(data["created_at"]), completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None, error_message=data.get("error_message"), ) class MessageTree: """ A tree of message nodes with queue functionality. Provides: - O(1) node lookup via hashmap - Per-tree message queue - Thread-safe operations via asyncio.Lock """ def __init__(self, root_node: MessageNode): """ Initialize tree with a root node. Args: root_node: The root message node """ self.root_id = root_node.node_id self._nodes: dict[str, MessageNode] = {root_node.node_id: root_node} self._status_to_node: dict[str, str] = { root_node.status_message_id: root_node.node_id } self._queue: _SnapshotQueue = _SnapshotQueue() self._lock = asyncio.Lock() self._is_processing = False self._current_node_id: str | None = None self._current_task: asyncio.Task | None = None logger.debug(f"Created MessageTree with root {self.root_id}") def set_current_task(self, task: asyncio.Task | None) -> None: """Set the current processing task. Caller must hold lock.""" self._current_task = task @property def is_processing(self) -> bool: """Check if tree is currently processing a message.""" return self._is_processing async def add_node( self, node_id: str, incoming: IncomingMessage, status_message_id: str, parent_id: str, ) -> MessageNode: """ Add a child node to the tree. Args: node_id: Unique ID for the new node incoming: The incoming message status_message_id: Bot's status message ID parent_id: Parent node ID Returns: The created MessageNode """ async with self._lock: if parent_id not in self._nodes: raise ValueError(f"Parent node {parent_id} not found in tree") node = MessageNode( node_id=node_id, incoming=incoming, status_message_id=status_message_id, parent_id=parent_id, state=MessageState.PENDING, ) self._nodes[node_id] = node self._status_to_node[status_message_id] = node_id self._nodes[parent_id].children_ids.append(node_id) logger.debug(f"Added node {node_id} as child of {parent_id}") return node def get_node(self, node_id: str) -> MessageNode | None: """Get a node by ID (O(1) lookup).""" return self._nodes.get(node_id) def get_root(self) -> MessageNode: """Get the root node.""" return self._nodes[self.root_id] def get_children(self, node_id: str) -> list[MessageNode]: """Get all child nodes of a given node.""" node = self._nodes.get(node_id) if not node: return [] return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes] def get_parent(self, node_id: str) -> MessageNode | None: """Get the parent node.""" node = self._nodes.get(node_id) if not node or not node.parent_id: return None return self._nodes.get(node.parent_id) def get_parent_session_id(self, node_id: str) -> str | None: """ Get the parent's session ID for forking. Returns None for root nodes. """ parent = self.get_parent(node_id) return parent.session_id if parent else None async def update_state( self, node_id: str, state: MessageState, session_id: str | None = None, error_message: str | None = None, ) -> None: """Update a node's state.""" async with self._lock: node = self._nodes.get(node_id) if not node: logger.warning(f"Node {node_id} not found for state update") return node.state = state if session_id: node.session_id = session_id if error_message: node.error_message = error_message if state in (MessageState.COMPLETED, MessageState.ERROR): node.completed_at = datetime.now(UTC) logger.debug(f"Node {node_id} state -> {state.value}") async def enqueue(self, node_id: str) -> int: """ Add a node to the processing queue. Returns: Queue position (1-indexed) """ async with self._lock: await self._queue.put(node_id) position = self._queue.qsize() logger.debug(f"Enqueued node {node_id}, position {position}") return position async def dequeue(self) -> str | None: """ Get the next node ID from the queue. Returns None if queue is empty. """ try: return self._queue.get_nowait() except asyncio.QueueEmpty: return None async def get_queue_snapshot(self) -> list[str]: """ Get a snapshot of the current queue order. Returns: List of node IDs in FIFO order. """ async with self._lock: return self._queue.get_snapshot() def get_queue_size(self) -> int: """Get number of messages waiting in queue.""" return self._queue.qsize() def remove_from_queue(self, node_id: str) -> bool: """ Remove node_id from the internal queue if present. Caller must hold the tree lock (e.g. via with_lock). Returns True if node was removed, False if not in queue. """ return self._queue.remove_if_present(node_id) @asynccontextmanager async def with_lock(self): """Async context manager for tree lock. Use when multiple operations need atomicity.""" async with self._lock: yield def set_processing_state(self, node_id: str | None, is_processing: bool) -> None: """Set processing state. Caller must hold lock for consistency with queue operations.""" self._is_processing = is_processing self._current_node_id = node_id if is_processing else None def clear_current_node(self) -> None: """Clear the currently processing node ID. Caller must hold lock.""" self._current_node_id = None def is_current_node(self, node_id: str) -> bool: """Check if node_id is the currently processing node.""" return self._current_node_id == node_id def put_queue_unlocked(self, node_id: str) -> None: """Add node to queue. Caller must hold lock (e.g. via with_lock).""" self._queue.put_nowait(node_id) def cancel_current_task(self) -> bool: """Cancel the currently running task. Returns True if a task was cancelled.""" if self._current_task and not self._current_task.done(): self._current_task.cancel() return True return False def set_node_error_sync(self, node: MessageNode, error_message: str) -> None: """Synchronously mark a node as ERROR. Caller must ensure no concurrent access.""" node.state = MessageState.ERROR node.error_message = error_message node.completed_at = datetime.now(UTC) def drain_queue_and_mark_cancelled( self, error_message: str = "Cancelled by user" ) -> list[MessageNode]: """ Drain the queue, mark each node as ERROR, and return affected nodes. Does not acquire lock; caller must ensure no concurrent queue access. """ nodes: list[MessageNode] = [] while True: try: node_id = self._queue.get_nowait() except asyncio.QueueEmpty: break node = self._nodes.get(node_id) if node: self.set_node_error_sync(node, error_message) nodes.append(node) return nodes def reset_processing_state(self) -> None: """Reset processing flags after cancel/cleanup.""" self._is_processing = False self._current_node_id = None @property def current_node_id(self) -> str | None: """Get the ID of the node currently being processed.""" return self._current_node_id def to_dict(self) -> dict: """Serialize tree to dictionary.""" return { "root_id": self.root_id, "nodes": {nid: node.to_dict() for nid, node in self._nodes.items()}, } def _add_node_from_dict(self, node: MessageNode) -> None: """Register a deserialized node into the tree's internal indices.""" self._nodes[node.node_id] = node self._status_to_node[node.status_message_id] = node.node_id @classmethod def from_dict(cls, data: dict) -> MessageTree: """Deserialize tree from dictionary.""" root_id = data["root_id"] nodes_data = data["nodes"] # Create root node first root_node = MessageNode.from_dict(nodes_data[root_id]) tree = cls(root_node) # Add remaining nodes and build status->node index for node_id, node_data in nodes_data.items(): if node_id != root_id: node = MessageNode.from_dict(node_data) tree._add_node_from_dict(node) return tree def all_nodes(self) -> list[MessageNode]: """Get all nodes in the tree.""" return list(self._nodes.values()) def has_node(self, node_id: str) -> bool: """Check if a node exists in this tree.""" return node_id in self._nodes def find_node_by_status_message(self, status_msg_id: str) -> MessageNode | None: """Find the node that has this status message ID (O(1) lookup).""" node_id = self._status_to_node.get(status_msg_id) return self._nodes.get(node_id) if node_id else None def get_descendants(self, node_id: str) -> list[str]: """ Get node_id and all descendant IDs (subtree). Returns: List of node IDs including the given node. """ if node_id not in self._nodes: return [] result: list[str] = [] stack = [node_id] while stack: nid = stack.pop() result.append(nid) node = self._nodes.get(nid) if node: stack.extend(node.children_ids) return result def remove_branch(self, branch_root_id: str) -> list[MessageNode]: """ Remove a subtree (branch_root and all descendants) from the tree. Updates parent's children_ids. Caller must hold lock for consistency. Does not acquire lock internally. Returns: List of removed nodes. """ if branch_root_id not in self._nodes: return [] parent = self.get_parent(branch_root_id) removed = [] for nid in self.get_descendants(branch_root_id): node = self._nodes.get(nid) if node: removed.append(node) del self._nodes[nid] del self._status_to_node[node.status_message_id] if parent and branch_root_id in parent.children_ids: parent.children_ids = [ c for c in parent.children_ids if c != branch_root_id ] logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)") return removed