File size: 5,201 Bytes
6172a47 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | from unittest.mock import AsyncMock, MagicMock
import pytest
from messaging.handler import ClaudeMessageHandler
from messaging.trees.data import MessageState
async def _gen_session(events):
    """Replay ``events`` one by one as an async generator.

    Used by the fake ``start_task`` implementations in this module to mimic a
    CLI session's streamed event output.
    """
    for event in events:
        yield event
@pytest.fixture
def handler(mock_platform, mock_cli_manager, mock_session_store):
    """Provide a ClaudeMessageHandler wired to the mock platform, CLI manager and session store."""
    message_handler = ClaudeMessageHandler(
        mock_platform, mock_cli_manager, mock_session_store
    )
    return message_handler
@pytest.mark.asyncio
async def test_sibling_replies_fork_from_parent_session_id(
    handler, mock_cli_manager, incoming_message_factory
):
    """Sibling replies must each fork from their shared parent's session.

    Root A is completed with ``sess_A``; replies R1 and R2 both target A.
    Each must start its CLI task resuming ``sess_A`` with
    ``fork_session=True``, and each must then record the child session id
    reported by its own task.
    """
    # Root node A with a known session_id.
    root_incoming = incoming_message_factory(text="A", message_id="A")
    tree = await handler.tree_queue.create_tree(
        node_id="A", incoming=root_incoming, status_message_id="status_A"
    )
    await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")

    # Add two sibling replies R1 and R2 under A.
    r1_incoming = incoming_message_factory(
        text="R1", message_id="R1", reply_to_message_id="A"
    )
    r2_incoming = incoming_message_factory(
        text="R2", message_id="R2", reply_to_message_id="A"
    )
    _, r1_node = await handler.tree_queue.add_to_tree(
        "A", "R1", r1_incoming, "status_R1"
    )
    _, r2_node = await handler.tree_queue.add_to_tree(
        "A", "R2", r2_incoming, "status_R2"
    )

    # Mock a fresh cli_session per node. Every start_task call is recorded as
    # (prompt, session_id, fork_session), and the task reports the child
    # session "sess_<prompt>" via a session_info event.
    calls = []

    async def _get_or_create_session(session_id=None):
        cli_session = MagicMock()

        async def _start_task(prompt, session_id=None, fork_session=False):
            calls.append((prompt, session_id, fork_session))
            child_sid = f"sess_{prompt}"
            async for ev in _gen_session(
                [
                    {"type": "session_info", "session_id": child_sid},
                    {"type": "exit", "code": 0, "stderr": None},
                ]
            ):
                yield ev

        cli_session.start_task = _start_task
        # Placeholder id; the assertions below expect each node to end up with
        # the id from its session_info event instead.
        return cli_session, f"pending_{len(calls) + 1}", True

    mock_cli_manager.get_or_create_session = AsyncMock(
        side_effect=_get_or_create_session
    )

    await handler._process_node("R1", r1_node)
    await handler._process_node("R2", r2_node)

    # Exactly one task per sibling, in processing order, both resuming the
    # parent session. Comparing the whole list (rather than indexing
    # calls[0]/calls[1]) fails with a readable diff instead of an IndexError
    # when a task never started or an extra one ran.
    assert [(prompt, sid) for prompt, sid, _ in calls] == [
        ("R1", "sess_A"),
        ("R2", "sess_A"),
    ]
    # Both must fork; require the literal True, not merely a truthy value.
    assert all(fork is True for _, _, fork in calls)
    # Each sibling records the child session id reported by its own task.
    assert r1_node.session_id == "sess_R1"
    assert r2_node.session_id == "sess_R2"
@pytest.mark.asyncio
async def test_grandchild_reply_forks_from_branch_session(
    handler, mock_cli_manager, incoming_message_factory
):
    """A reply to a branch node must fork from that branch's session.

    Root A is completed with ``sess_A``. Reply R1 is processed and records
    its own forked session ``sess_R1``. Grandchild C1, replying to R1, must
    then resume from ``sess_R1`` rather than the root's ``sess_A``.
    """
    root_incoming = incoming_message_factory(text="A", message_id="A")
    tree = await handler.tree_queue.create_tree(
        node_id="A", incoming=root_incoming, status_message_id="status_A"
    )
    await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")

    r1_incoming = incoming_message_factory(
        text="R1", message_id="R1", reply_to_message_id="A"
    )
    _, r1_node = await handler.tree_queue.add_to_tree(
        "A", "R1", r1_incoming, "status_R1"
    )

    calls = []

    def _install_cli_session(child_sid, pending_id):
        """Route ``get_or_create_session`` to a fresh fake CLI session.

        The fake ``start_task`` appends ``(prompt, session_id, fork_session)``
        to ``calls``. It then emits a ``session_info`` event carrying
        ``child_sid``, followed by a clean exit.
        """

        async def _get_or_create_session(session_id=None):
            cli_session = MagicMock()

            async def _start_task(prompt, session_id=None, fork_session=False):
                calls.append((prompt, session_id, fork_session))
                async for ev in _gen_session(
                    [
                        {"type": "session_info", "session_id": child_sid},
                        {"type": "exit", "code": 0, "stderr": None},
                    ]
                ):
                    yield ev

            cli_session.start_task = _start_task
            return cli_session, pending_id, True

        mock_cli_manager.get_or_create_session = AsyncMock(
            side_effect=_get_or_create_session
        )

    # R1 gets its own forked session id.
    _install_cli_session("sess_R1", "pending_R1")
    await handler._process_node("R1", r1_node)
    assert r1_node.session_id == "sess_R1"

    # Grandchild C1 replies to R1 and must fork from sess_R1, not sess_A.
    c1_incoming = incoming_message_factory(
        text="C1", message_id="C1", reply_to_message_id="R1"
    )
    _, c1_node = await handler.tree_queue.add_to_tree(
        "R1", "C1", c1_incoming, "status_C1"
    )

    _install_cli_session("sess_C1", "pending_C1")
    await handler._process_node("C1", c1_node)

    # The most recent task belongs to C1 and must resume (and fork) sess_R1.
    last_prompt, last_session_id, last_fork = calls[-1]
    assert last_prompt == "C1"
    assert last_session_id == "sess_R1"
    assert last_fork is True
|