File size: 5,792 Bytes
95b305d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | # from __future__ import annotations
from pathlib import Path
import uuid
from datetime import datetime, timezone
import json, os
from typing import List, Dict, Tuple, Optional
# ============ 工具函数 ============
def mk_msg_dir(BASE_MSG_DIR) -> str:
    """Create a new, uniquely named message directory under ``BASE_MSG_DIR``.

    The ID is the local time formatted as ``YYYYmmdd-HHMMSS-`` followed by the
    first 6 hex characters of a random UUID4. Missing parent directories are
    created; an already existing directory is not an error.

    Returns:
        Only the generated ID (not the full path).
    """
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S-")
    suffix = uuid.uuid4().hex[:6]
    new_id = f"{stamp}{suffix}"
    (Path(BASE_MSG_DIR) / new_id).mkdir(parents=True, exist_ok=True)
    return new_id
def _as_dir(BASE_MSG_DIR, m_id: str) -> str:
# 统一把传入值规整为 ./msgs/<ID>
return Path(BASE_MSG_DIR, m_id)
def msg2hist(persona, msg):
    """Return a copy of *msg* with the persona text stripped from the first message.

    The first message's ``content`` is assumed to start with *persona*. That
    prefix is removed only when it is actually present. Previously a fixed
    ``len(persona)`` characters were chopped off unconditionally, which
    corrupted content that did not start with the persona.

    Args:
        persona: the persona text expected at the start of ``msg[0]["content"]``.
        msg: list of message dicts, or ``None``.

    Returns:
        A new list; ``[]`` when *msg* is ``None`` or empty. Only the first dict
        is copied (it is the one modified). The remaining dicts are shared with
        *msg*, and *msg* itself is never mutated.
    """
    if not msg:
        return []
    history = msg.copy()  # shallow copy of the outer list
    first = msg[0].copy()  # copy the one dict we modify so the caller's data is untouched
    first["content"] = first["content"].removeprefix(persona)
    history[0] = first
    return history
def render(tok, messages: List[Dict[str, str]]) -> str:
    """Render *messages* into the final prompt text using the tokenizer's chat template.

    The result is text only (``tokenize=False``). The template is asked to
    append its generation prompt (``add_generation_prompt=True``).
    """
    template_opts = {"tokenize": False, "add_generation_prompt": True}
    prompt = tok.apply_chat_template(messages, **template_opts)
    return prompt
def _ensure_alternating(messages):
if not messages:
return
if messages[0]["role"] != "user":
raise ValueError("messages[0] 必须是 'user'(你的模板要求从 user 开始)")
for i, m in enumerate(messages):
expect_user = (i % 2 == 0)
if (m["role"] == "user") != expect_user:
raise ValueError(f"对话必须严格交替 user/assistant,在索引 {i} 处发现 {m['role']}")
def trim_by_tokens(tok, messages, prompt_budget, *, add_generation_prompt=False):
    """Trim *messages* to fit *prompt_budget* tokens without breaking alternation.

    Keeps ``messages[0]`` (the persona-bearing user message) plus the longest
    suffix ``messages[k:]`` that starts at an odd index ``k``. Under strict
    user/assistant alternation odd indices are assistant turns, so
    ``[messages[0]] + messages[k:]`` still alternates. The start ``k`` is found
    by binary search, relying on the token count shrinking as ``k`` grows.

    Args:
        tok: tokenizer exposing ``apply_chat_template`` and a call returning an
            object with ``input_ids``. Presumably a Hugging Face tokenizer; confirm.
        messages: chat messages (dicts with at least ``role``/``content``).
        prompt_budget: maximum token count allowed for the kept messages.
        add_generation_prompt: when True, count tokens with the template's
            generation prompt appended, i.e. exactly as ``render`` builds the
            prompt. Defaults to False, the original counting, which omits those
            tokens and can therefore let the rendered prompt exceed the budget.

    Returns:
        A new list. Returns ``[]`` for empty input. Returns ``[messages[0]]``
        when no suffix fits. That fallback is not itself checked against the
        budget, and it drops every later message, including a trailing user turn.
    """
    if not messages:
        return []
    # _ensure_alternating(messages)  # alternation check is currently disabled

    head = messages[0]
    template_kwargs = {"tokenize": False}
    if add_generation_prompt:
        template_kwargs["add_generation_prompt"] = True

    def _fits(candidate):
        # Count tokens of the rendered text; special tokens are already part of the template.
        text = tok.apply_chat_template(candidate, **template_kwargs)
        return len(tok(text, add_special_tokens=False).input_ids) <= prompt_budget

    starts = range(1, len(messages), 2)  # odd indices only: valid suffix starts
    best = [head]  # fallback when nothing fits
    lo, hi = 0, len(starts) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        candidate = [head] + messages[starts[mid]:]
        if _fits(candidate):
            best = candidate  # fits: try keeping more history (move left)
            hi = mid - 1
        else:
            lo = mid + 1  # too long: drop more old messages (move right)
    return best
# ============ 原子写 可能会和onedrive同步冲突============
# def atomic_write_json(path: Path, data) -> None:
# tmp = path.with_suffix(path.suffix + ".tmp")
# with open(tmp, "w", encoding="utf-8") as f:
# json.dump(data, f, ensure_ascii=False, indent=2)
# f.flush()
# os.fsync(f.fileno())
# os.replace(tmp, path) # 同目录原子替换
# 直接覆盖
def write_json_overwrite(path: Path, data) -> None:
    """Overwrite *path* with *data* as pretty-printed UTF-8 JSON (``\\n`` newlines).

    This is a plain, non-atomic overwrite. The temp-file + ``os.replace``
    approach was dropped because it can conflict with OneDrive sync.

    *data* is serialized before the file is opened. Opening with mode ``"w"``
    truncates immediately, so streaming ``json.dump`` into the file would leave
    a partial or corrupt file if serialization failed halfway. With this order,
    a ``TypeError``/``ValueError`` from serialization leaves the existing file
    untouched.

    Raises:
        TypeError / ValueError: if *data* is not JSON-serializable.
        OSError: if the file cannot be written.
    """
    text = json.dumps(data, ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8", newline="\n") as f:
        f.write(text)
# ============ 存储层 ============
class MsgStore:
def __init__(self, base_dir: str | Path = "./msgs"):
self.base = Path(base_dir)
self.base.mkdir(parents=True, exist_ok=True)
self.archive = self.base / "archive.jsonl" # 只追加
self.trimmed = self.base / "trimmed.json" # 当前上下文
if not self.archive.exists():
self.archive.write_text("", encoding="utf-8")
if not self.trimmed.exists():
self.trimmed.write_text("[]", encoding="utf-8")
def load_trimmed(self) -> List[Dict[str, str]]:
try:
return json.loads(self.trimmed.read_text(encoding="utf-8"))
except Exception:
return []
def save_trimmed(self, messages: List[Dict[str, str]]) -> None:
write_json_overwrite(self.trimmed, messages)
def append_archive(self, role: str, content: str, meta: dict | None = None) -> None:
rec = {"ts": datetime.now(timezone.utc).isoformat(), "role": role, "content": content}
if meta: rec["meta"] = meta
with open(self.archive, "a", encoding="utf-8") as f:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
f.flush(); os.fsync(f.fileno())
# ============ 显式保存(手动调用才落盘) ============
def persist_messages(
    messages: List[Dict[str, str]],
    store_dir: str | Path = "./msgs",
    archive_last_turn: bool = True,
) -> None:
    """Explicitly save *messages* to *store_dir*. Nothing is written unless this is called.

    1. Overwrite ``trimmed.json`` with *messages*. This is a plain overwrite via
       ``MsgStore.save_trimmed``, not an atomic replace.
    2. If *archive_last_turn* is true, scan backward for the most recent adjacent
       (user, assistant) pair and append it to ``archive.jsonl`` as two records.
       If no such pair exists (e.g. only a pending user message), nothing is archived.

    NOTE(review): the scan does not require the pair to be the last two messages.
    Calling this while a user message is pending, or twice without a new turn,
    archives the same earlier pair again. Deduplicate on the caller side if that matters.
    """
    store = MsgStore(store_dir)
    # _ensure_alternating(messages)  # alternation check is currently disabled

    store.save_trimmed(messages)

    if not archive_last_turn:
        return

    # j indexes the assistant candidate; the message right before it must be the user turn.
    for j in range(len(messages) - 1, 0, -1):
        user_msg, asst_msg = messages[j - 1], messages[j]
        if user_msg["role"] == "user" and asst_msg["role"] == "assistant":
            u, a = user_msg["content"], asst_msg["content"]
            store.append_archive("user", u)
            store.append_archive("assistant", a)
            return
|