File size: 5,201 Bytes
6172a47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from unittest.mock import AsyncMock, MagicMock

import pytest

from messaging.handler import ClaudeMessageHandler
from messaging.trees.data import MessageState


async def _gen_session(events):
    """Replay ``events`` as an asynchronous stream, one item at a time, in order.

    Used by the CLI-session stubs below to mimic the async event stream that
    ``start_task`` produces.
    """
    # Private sentinel so that falsy/None events are still yielded.
    _exhausted = object()
    stream = iter(events)
    while (event := next(stream, _exhausted)) is not _exhausted:
        yield event


@pytest.fixture
def handler(mock_platform, mock_cli_manager, mock_session_store):
    """Provide a ``ClaudeMessageHandler`` wired to the mocked collaborators.

    The platform, CLI manager and session store are the mock fixtures of the
    same names; tests reconfigure ``mock_cli_manager`` as needed.
    """
    message_handler = ClaudeMessageHandler(
        mock_platform,
        mock_cli_manager,
        mock_session_store,
    )
    return message_handler


@pytest.mark.asyncio
async def test_sibling_replies_fork_from_parent_session_id(
    handler, mock_cli_manager, incoming_message_factory
):
    """Sibling replies must each fork from their shared parent's session.

    Builds a tree whose root ``A`` is COMPLETED with session ``sess_A``, adds
    two replies ``R1`` and ``R2`` under it, and processes both. Every
    ``start_task`` call is expected to pass ``session_id="sess_A"`` with
    ``fork_session=True``, i.e. R2 is not chained onto R1's session.
    """
    # Root node A with a known session_id.
    root_incoming = incoming_message_factory(text="A", message_id="A")
    tree = await handler.tree_queue.create_tree(
        node_id="A", incoming=root_incoming, status_message_id="status_A"
    )
    await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")

    # Add two sibling replies R1 and R2 under A.
    r1_incoming = incoming_message_factory(
        text="R1", message_id="R1", reply_to_message_id="A"
    )
    r2_incoming = incoming_message_factory(
        text="R2", message_id="R2", reply_to_message_id="A"
    )
    _, r1_node = await handler.tree_queue.add_to_tree(
        "A", "R1", r1_incoming, "status_R1"
    )
    _, r2_node = await handler.tree_queue.add_to_tree(
        "A", "R2", r2_incoming, "status_R2"
    )

    # Every start_task invocation, recorded as (prompt, session_id, fork_session).
    calls = []
    # Sessions handed out so far. Kept separate from `calls` so each pending
    # id is unique even if a session's start_task is never invoked.
    sessions_created = 0

    async def _get_or_create_session(session_id=None):
        # Mock a fresh cli_session per node.
        nonlocal sessions_created
        sessions_created += 1
        cli_session = MagicMock()

        async def _start_task(prompt, session_id=None, fork_session=False):
            calls.append((prompt, session_id, fork_session))
            # Each node reports its own child session id derived from its prompt.
            child_sid = f"sess_{prompt}"
            async for ev in _gen_session(
                [
                    {"type": "session_info", "session_id": child_sid},
                    {"type": "exit", "code": 0, "stderr": None},
                ]
            ):
                yield ev

        cli_session.start_task = _start_task
        return cli_session, f"pending_{sessions_created}", True

    mock_cli_manager.get_or_create_session = AsyncMock(
        side_effect=_get_or_create_session
    )

    await handler._process_node("R1", r1_node)
    await handler._process_node("R2", r2_node)

    # One task per sibling; fail with the recorded calls instead of an IndexError.
    assert len(calls) == 2, calls

    # Both siblings must resume from the same parent session and fork.
    assert calls[0][0] == "R1"
    assert calls[0][1] == "sess_A"
    assert calls[0][2] is True

    assert calls[1][0] == "R2"
    assert calls[1][1] == "sess_A"
    assert calls[1][2] is True


@pytest.mark.asyncio
async def test_grandchild_reply_forks_from_branch_session(
    handler, mock_cli_manager, incoming_message_factory
):
    """A reply to a branch node forks from that branch's session, not the root's.

    Root ``A`` is COMPLETED with ``sess_A``. Processing reply ``R1`` yields
    session ``sess_R1``, which must be stored on R1's node. A grandchild ``C1``
    replying to ``R1`` must then start its task from ``sess_R1`` with
    ``fork_session=True``.
    """
    root_msg = incoming_message_factory(text="A", message_id="A")
    tree = await handler.tree_queue.create_tree(
        node_id="A", incoming=root_msg, status_message_id="status_A"
    )
    await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")

    branch_msg = incoming_message_factory(
        text="R1", message_id="R1", reply_to_message_id="A"
    )
    _, r1_node = await handler.tree_queue.add_to_tree(
        "A", "R1", branch_msg, "status_R1"
    )

    # Every start_task invocation, recorded as (prompt, session_id, fork_session).
    recorded = []

    def _session_stub(child_sid, pending_id):
        """Build a get_or_create_session stub whose task reports ``child_sid``."""

        async def _get_or_create_session(session_id=None):
            fake_session = MagicMock()

            async def _start_task(prompt, session_id=None, fork_session=False):
                recorded.append((prompt, session_id, fork_session))
                stream = [
                    {"type": "session_info", "session_id": child_sid},
                    {"type": "exit", "code": 0, "stderr": None},
                ]
                async for event in _gen_session(stream):
                    yield event

            fake_session.start_task = _start_task
            return fake_session, pending_id, True

        return _get_or_create_session

    # R1 gets its own forked session id.
    mock_cli_manager.get_or_create_session = AsyncMock(
        side_effect=_session_stub("sess_R1", "pending_R1")
    )
    await handler._process_node("R1", r1_node)
    assert r1_node.session_id == "sess_R1"

    # Grandchild C1 replies to R1 and must fork from sess_R1, not sess_A.
    leaf_msg = incoming_message_factory(
        text="C1", message_id="C1", reply_to_message_id="R1"
    )
    _, c1_node = await handler.tree_queue.add_to_tree(
        "R1", "C1", leaf_msg, "status_C1"
    )
    mock_cli_manager.get_or_create_session = AsyncMock(
        side_effect=_session_stub("sess_C1", "pending_C1")
    )
    await handler._process_node("C1", c1_node)

    # The most recent task belongs to C1 and resumed (and forked) sess_R1.
    last_prompt, last_session, last_fork = recorded[-1]
    assert last_prompt == "C1"
    assert last_session == "sess_R1"
    assert last_fork is True