File size: 4,949 Bytes
7d4338a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import json
from datetime import datetime

from helpers.api import ApiHandler, Input, Output, Request, Response
from helpers.persist_chat import (
    _serialize_context,
    _deserialize_context,
    save_tmp_chat,
)
from agent import AgentContext


def _trim_history_json(history_json: str, kept_ids: set[str], after_cut_ids: set[str]) -> str:
    """Return *history_json* with messages past the branch cut point removed.

    *kept_ids* holds the IDs of log items up to and including the cut
    point; *after_cut_ids* holds the IDs of log items after it.

    Messages are visited in order while a keep/drop state is tracked:

    * a message whose id is in *after_cut_ids* switches the state to drop;
    * a message whose id is in *kept_ids* switches it (back) to keep;
    * a message with no id, or an id in neither set, uses the state as is.

    Topics that carry a summary are retained untouched.  Unsummarized
    topics are filtered, topics left without messages are discarded, and
    the topic scan stops after the first topic that ends in the drop
    state.  The current topic is emptied when the state is drop,
    filtered when it has no summary, and left alone otherwise.  Bulks
    are not modified.  ``counter`` is rewritten to the number of messages
    in the retained unsummarized topics plus the current topic.

    An empty *history_json* is returned unchanged.
    """
    if not history_json:
        return history_json

    history = json.loads(history_json)
    # Mutable cell shared with the helper below; carries the keep/drop
    # state across successive message lists.
    state = {"keeping": True}

    def surviving(messages: list[dict]) -> list[dict]:
        survivors = []
        for message in messages:
            message_id = message.get("id", "")
            if message_id:
                if message_id in after_cut_ids:
                    state["keeping"] = False
                elif message_id in kept_ids:
                    state["keeping"] = True
            # Messages without a recognised id fall through with the
            # state unchanged.
            if state["keeping"]:
                survivors.append(message)
        return survivors

    # Past topics: summarized ones pass through, the rest are filtered.
    retained_topics = []
    for topic in history.get("topics", []):
        if not topic.get("summary"):
            remaining = surviving(topic.get("messages", []))
            if remaining:
                topic["messages"] = remaining
                retained_topics.append(topic)
            if not state["keeping"]:
                break
        else:
            retained_topics.append(topic)
    history["topics"] = retained_topics

    # Current topic.
    current_topic = history.get("current", {})
    if not state["keeping"]:
        current_topic["messages"] = []
    elif not current_topic.get("summary"):
        current_topic["messages"] = surviving(current_topic.get("messages", []))
    history["current"] = current_topic

    # Recount messages in unsummarized topics plus the current topic.
    unsummarized_count = 0
    for topic in retained_topics:
        if not topic.get("summary"):
            unsummarized_count += len(topic.get("messages", []))
    history["counter"] = unsummarized_count + len(current_topic.get("messages", []))

    return json.dumps(history, ensure_ascii=False)


class BranchChat(ApiHandler):
    """Create a new chat branched from an existing chat at a specific log message.

    Input keys read by :meth:`process`:

    * ``context`` - id of the source chat context.
    * ``log_no`` - ``no`` of the log item to branch at (inclusive).  If no
      log item carries that ``no``, the value is used as a positional
      index into the serialized log list instead.

    On success the handler returns a dict with ``ok``, the new context id
    in ``ctxid`` and a ``message``.  Failures are reported as a
    ``Response`` with status 400 (missing/invalid input, unknown
    ``log_no``) or 404 (unknown context).
    """

    async def process(self, input: Input, request: Request) -> Output:
        ctxid = input.get("context", "")
        log_no = input.get("log_no")  # LogItem.no from frontend

        if not ctxid:
            return Response("Missing context id", 400)
        if log_no is None:
            return Response("Missing log_no", 400)

        context = AgentContext.get(ctxid)
        if not context:
            return Response("Context not found", 404)

        # log_no comes straight from the request payload.  It is compared
        # against integers and used as a slice bound below, so reject
        # anything that is not a whole number up front instead of failing
        # later with a TypeError.
        if isinstance(log_no, float) and not log_no.is_integer():
            return Response("Invalid log_no", 400)
        try:
            log_no = int(log_no)
        except (TypeError, ValueError):
            return Response("Invalid log_no", 400)

        # Serialize the source context
        data = _serialize_context(context)

        # Remove id so _deserialize_context generates a new one
        data.pop("id", None)

        # Trim log entries: keep only items up to and including log_no.
        src_logs = data["log"]["logs"]
        cut_idx = next(
            (i for i, item in enumerate(src_logs) if item.get("no") == log_no),
            None,
        )

        if cut_idx is None:
            # No item carries this number; fall back to treating it as a
            # position in the log list.
            if 0 <= log_no < len(src_logs):
                cut_idx = log_no
            else:
                return Response("log_no not found in chat log", 400)

        kept_logs = src_logs[: cut_idx + 1]
        after_logs = src_logs[cut_idx + 1 :]
        data["log"]["logs"] = kept_logs

        # Build ID sets for history trimming
        kept_ids = {item["id"] for item in kept_logs if item.get("id")}
        after_cut_ids = {item["id"] for item in after_logs if item.get("id")}

        # Trim each agent's history using ID matching
        for ag in data.get("agents", []):
            ag["history"] = _trim_history_json(
                ag.get("history", ""), kept_ids, after_cut_ids
            )

        # Give the branch a distinguishable name
        src_name = data.get("name") or "Chat"
        data["name"] = f"{src_name} (branch)"
        data["created_at"] = datetime.now().isoformat()

        # Deserialize into a brand-new context (new id, fresh agent config)
        new_context = _deserialize_context(data)

        # Persist immediately
        save_tmp_chat(new_context)

        # Notify all tabs
        from helpers.state_monitor_integration import mark_dirty_all
        mark_dirty_all(reason="plugins.chat_branching.BranchChat")

        return {
            "ok": True,
            "ctxid": new_context.id,
            "message": "Chat branched successfully.",
        }