| from unittest.mock import AsyncMock, MagicMock |
|
|
| import pytest |
|
|
| from messaging.handler import ClaudeMessageHandler |
| from messaging.trees.data import MessageState |
|
|
|
|
async def _gen_session(events):
    """Replay ``events`` as an async stream, one item per iteration.

    Lets a test stand in for an async event source using a plain,
    pre-built iterable.

    Args:
        events: Iterable of event objects to emit, in order.

    Yields:
        Each element of ``events``, unchanged.
    """
    for event in events:
        yield event
|
|
|
|
@pytest.fixture
def handler(mock_platform, mock_cli_manager, mock_session_store):
    """Build the ``ClaudeMessageHandler`` under test from mocked collaborators.

    The three arguments are pytest fixtures that are not defined in the
    visible part of this file (presumably supplied by a conftest -- confirm).
    They are passed to the constructor positionally, in this order.
    """
    return ClaudeMessageHandler(mock_platform, mock_cli_manager, mock_session_store)
|
|
|
|
@pytest.mark.asyncio
async def test_sibling_replies_fork_from_parent_session_id(
    handler, mock_cli_manager, incoming_message_factory
):
    """Two replies to the same message each fork from that message's session.

    Root "A" is marked completed with session "sess_A". Replies "R1" and
    "R2" are both attached under "A" and processed one after the other. Each
    must start its task with ``session_id="sess_A"`` and
    ``fork_session=True``, even though processing "R1" reports a new session
    id ("sess_R1") before "R2" runs.
    """
    # Arrange: a root node that has already completed under "sess_A".
    root_incoming = incoming_message_factory(text="A", message_id="A")
    tree = await handler.tree_queue.create_tree(
        node_id="A", incoming=root_incoming, status_message_id="status_A"
    )
    await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")

    # Arrange: two sibling replies, both children of "A".
    r1_incoming = incoming_message_factory(
        text="R1", message_id="R1", reply_to_message_id="A"
    )
    r2_incoming = incoming_message_factory(
        text="R2", message_id="R2", reply_to_message_id="A"
    )
    _, r1_node = await handler.tree_queue.add_to_tree(
        "A", "R1", r1_incoming, "status_R1"
    )
    _, r2_node = await handler.tree_queue.add_to_tree(
        "A", "R2", r2_incoming, "status_R2"
    )

    # Every start_task invocation is recorded here as
    # (prompt, session_id, fork_session), in call order.
    calls = []

    async def _get_or_create_session(session_id=None):
        cli_session = MagicMock()

        async def _start_task(prompt, session_id=None, fork_session=False):
            calls.append((prompt, session_id, fork_session))
            # Report a fresh session id named after the prompt, then a
            # clean exit (code 0).
            child_sid = f"sess_{prompt}"
            async for ev in _gen_session(
                [
                    {"type": "session_info", "session_id": child_sid},
                    {"type": "exit", "code": 0, "stderr": None},
                ]
            ):
                yield ev

        # Plain async generator function rather than a mock, so the handler
        # can iterate the returned stream with ``async for``.
        cli_session.start_task = _start_task
        # Placeholder id is derived from the number of calls recorded so far.
        # NOTE(review): the trailing True presumably flags a newly created
        # session -- confirm against the CLI manager's contract.
        return cli_session, f"pending_{len(calls) + 1}", True

    mock_cli_manager.get_or_create_session = AsyncMock(
        side_effect=_get_or_create_session
    )

    # Act: process the siblings sequentially.
    await handler._process_node("R1", r1_node)
    await handler._process_node("R2", r2_node)

    # Assert: the first reply forks from the parent's session...
    assert calls[0][0] == "R1"
    assert calls[0][1] == "sess_A"
    assert calls[0][2] is True

    # ...and so does the second one, rather than continuing from "sess_R1".
    assert calls[1][0] == "R2"
    assert calls[1][1] == "sess_A"
    assert calls[1][2] is True
|
|
|
|
@pytest.mark.asyncio
async def test_grandchild_reply_forks_from_branch_session(
    handler, mock_cli_manager, incoming_message_factory
):
    """A reply to a reply forks from the intermediate reply's own session.

    Root "A" is marked completed with session "sess_A". Reply "R1" is
    processed and reports session "sess_R1". A further reply "C1", attached
    under "R1", must then start its task with ``session_id="sess_R1"`` and
    ``fork_session=True``.
    """
    # Arrange: a root node that has already completed under "sess_A".
    root_incoming = incoming_message_factory(text="A", message_id="A")
    tree = await handler.tree_queue.create_tree(
        node_id="A", incoming=root_incoming, status_message_id="status_A"
    )
    await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")

    # Arrange: first-level reply "R1" under "A".
    r1_incoming = incoming_message_factory(
        text="R1", message_id="R1", reply_to_message_id="A"
    )
    _, r1_node = await handler.tree_queue.add_to_tree(
        "A", "R1", r1_incoming, "status_R1"
    )

    # Every start_task invocation is recorded here as
    # (prompt, session_id, fork_session), in call order.
    calls = []

    async def _get_or_create_session(session_id=None):
        cli_session = MagicMock()

        async def _start_task(prompt, session_id=None, fork_session=False):
            calls.append((prompt, session_id, fork_session))
            # Report "sess_R1" as the session created for this task, then a
            # clean exit (code 0).
            child_sid = "sess_R1"
            async for ev in _gen_session(
                [
                    {"type": "session_info", "session_id": child_sid},
                    {"type": "exit", "code": 0, "stderr": None},
                ]
            ):
                yield ev

        # Plain async generator function rather than a mock, so the handler
        # can iterate the returned stream with ``async for``.
        cli_session.start_task = _start_task
        # NOTE(review): the trailing True presumably flags a newly created
        # session -- confirm against the CLI manager's contract.
        return cli_session, "pending_R1", True

    mock_cli_manager.get_or_create_session = AsyncMock(
        side_effect=_get_or_create_session
    )

    # Act (step 1): processing "R1" must leave its reported session id on
    # the node, since "C1" is expected to fork from it below.
    await handler._process_node("R1", r1_node)
    assert r1_node.session_id == "sess_R1"

    # Arrange: second-level reply "C1" under "R1".
    c1_incoming = incoming_message_factory(
        text="C1", message_id="C1", reply_to_message_id="R1"
    )
    _, c1_node = await handler.tree_queue.add_to_tree(
        "R1", "C1", c1_incoming, "status_C1"
    )

    async def _get_or_create_session_c1(session_id=None):
        cli_session = MagicMock()

        async def _start_task(prompt, session_id=None, fork_session=False):
            calls.append((prompt, session_id, fork_session))
            async for ev in _gen_session(
                [
                    {"type": "session_info", "session_id": "sess_C1"},
                    {"type": "exit", "code": 0, "stderr": None},
                ]
            ):
                yield ev

        cli_session.start_task = _start_task
        return cli_session, "pending_C1", True

    # Swap in the factory whose stream reports "sess_C1" before running "C1".
    mock_cli_manager.get_or_create_session = AsyncMock(
        side_effect=_get_or_create_session_c1
    )

    # Act (step 2): process the grandchild.
    await handler._process_node("C1", c1_node)

    # Assert: the most recent start_task call is C1's, forking from
    # "sess_R1" (its direct parent) rather than the root's "sess_A".
    assert calls[-1][0] == "C1"
    assert calls[-1][1] == "sess_R1"
    assert calls[-1][2] is True
|
|