Spaces:
Paused
Paused
File size: 4,949 Bytes
7d4338a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | import json
from datetime import datetime
from helpers.api import ApiHandler, Input, Output, Request, Response
from helpers.persist_chat import (
_serialize_context,
_deserialize_context,
save_tmp_chat,
)
from agent import AgentContext
def _trim_history_json(history_json: str, kept_ids: set[str], after_cut_ids: set[str]) -> str:
"""Trim a serialized history JSON to keep only messages that appear
before the branch cut point.
*kept_ids*: IDs of log items up to and including the cut point.
*after_cut_ids*: IDs of log items after the cut point.
Walk messages in order. A message is kept while the running state
is "keep". The state flips to "drop" the first time we encounter
a message whose id is in *after_cut_ids*. Messages whose id does
not appear in any log set (unpaired) inherit the current state.
Summarized topics/bulks are always preserved.
"""
if not history_json:
return history_json
hist = json.loads(history_json)
keep = True # running state
def filter_messages(messages: list[dict]) -> list[dict]:
nonlocal keep
result = []
for msg in messages:
mid = msg.get("id", "")
if mid and mid in after_cut_ids:
keep = False
elif mid and mid in kept_ids:
keep = True
# else: unpaired – inherit current state
if keep:
result.append(msg)
return result
# Bulks are already summarized old history – always keep
# Topics: keep summarized ones; filter unsummarized
trimmed_topics = []
for topic in hist.get("topics", []):
if topic.get("summary"):
trimmed_topics.append(topic)
continue
msgs = filter_messages(topic.get("messages", []))
if msgs:
topic["messages"] = msgs
trimmed_topics.append(topic)
if not keep:
break
hist["topics"] = trimmed_topics
# Current topic
current = hist.get("current", {})
if not current.get("summary") and keep:
current["messages"] = filter_messages(current.get("messages", []))
elif not keep:
current["messages"] = []
hist["current"] = current
# Recount
total = sum(
len(t.get("messages", [])) for t in hist["topics"] if not t.get("summary")
) + len(hist.get("current", {}).get("messages", []))
hist["counter"] = total
return json.dumps(hist, ensure_ascii=False)
class BranchChat(ApiHandler):
    """Create a new chat branched from an existing chat at a specific log message.

    Input fields:
        context: id of the source chat context (required).
        log_no:  ``no`` of the log item to branch at, inclusive (required).

    The source context is serialized, its log is cut after the matching item,
    every agent's history is trimmed to match, and the result is deserialized
    into a new context that is saved and signalled via ``mark_dirty_all``.
    """

    async def process(self, input: Input, request: Request) -> Output:
        ctxid = input.get("context", "")
        log_no = input.get("log_no")  # LogItem.no from frontend
        if not ctxid:
            return Response("Missing context id", 400)
        if log_no is None:
            return Response("Missing log_no", 400)
        # Normalize to int so numeric strings match item["no"], and so a
        # non-numeric value is a client error instead of a TypeError later
        # in the range check / slicing.
        try:
            log_no = int(log_no)
        except (TypeError, ValueError):
            return Response("Invalid log_no", 400)

        context = AgentContext.get(ctxid)
        if not context:
            return Response("Context not found", 404)

        # Serialize the source context
        data = _serialize_context(context)
        # Remove id so _deserialize_context generates a new one
        data.pop("id", None)

        # Trim log entries: keep only items up to and including log_no.
        src_logs = data["log"]["logs"]
        cut_idx = next(
            (i for i, item in enumerate(src_logs) if item.get("no") == log_no),
            None,
        )
        if cut_idx is None:
            # Fallback: treat log_no as a positional index into the log.
            # NOTE(review): presumably for callers that send an index rather
            # than LogItem.no — confirm this is still needed.
            if not 0 <= log_no < len(src_logs):
                return Response("log_no not found in chat log", 400)
            cut_idx = log_no
        kept_logs = src_logs[: cut_idx + 1]
        after_logs = src_logs[cut_idx + 1 :]
        data["log"]["logs"] = kept_logs

        # Build ID sets for history trimming
        kept_ids = {item["id"] for item in kept_logs if item.get("id")}
        after_cut_ids = {item["id"] for item in after_logs if item.get("id")}

        # Trim each agent's history using ID matching
        for ag in data.get("agents", []):
            ag["history"] = _trim_history_json(
                ag.get("history", ""), kept_ids, after_cut_ids
            )

        # Give the branch a distinguishable name
        src_name = data.get("name") or "Chat"
        data["name"] = f"{src_name} (branch)"
        data["created_at"] = datetime.now().isoformat()

        # Deserialize into a brand-new context (new id, fresh agent config)
        new_context = _deserialize_context(data)
        # Persist immediately
        save_tmp_chat(new_context)

        # Notify all tabs (imported lazily, as in the original code)
        from helpers.state_monitor_integration import mark_dirty_all

        mark_dirty_all(reason="plugins.chat_branching.BranchChat")

        return {
            "ok": True,
            "ctxid": new_context.id,
            "message": "Chat branched successfully.",
        }