|
|
|
|
|
from pathlib import Path |
|
|
import uuid |
|
|
from datetime import datetime, timezone |
|
|
import json, os |
|
|
from typing import List, Dict, Tuple, Optional |
|
|
|
|
|
|
|
|
def mk_msg_dir(BASE_MSG_DIR) -> str:
    """Create a new, uniquely named message directory inside BASE_MSG_DIR.

    The id has the form ``YYYYmmdd-HHMMSS-xxxxxx``: a local-time (naive
    ``datetime.now()``) timestamp followed by 6 hex characters of a uuid4.
    Parent directories are created as needed; an existing directory is reused.

    Returns:
        The directory id (its name only, not the full path).
    """
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S-")
    suffix = uuid.uuid4().hex[:6]
    new_id = f"{stamp}{suffix}"
    target = Path(BASE_MSG_DIR) / new_id
    target.mkdir(parents=True, exist_ok=True)
    return new_id
|
|
|
|
|
def _as_dir(BASE_MSG_DIR, m_id: str) -> str: |
|
|
|
|
|
return Path(BASE_MSG_DIR, m_id) |
|
|
|
|
|
def msg2hist(persona, msg):
    """Build a chat history from *msg* with the persona prefix removed.

    The persona text is expected to be prepended to the content of the first
    message; it is stripped from that message only.

    Args:
        persona: the prefix to strip. ``None`` or ``""`` strips nothing.
        msg: list of ``{"role", "content"}`` dicts, or ``None``.

    Returns:
        A new list. The first dict is copied before editing, so neither *msg*
        nor ``msg[0]`` is mutated. The remaining dicts are shared with *msg*
        (shallow copy). ``None`` or an empty *msg* yields ``[]``.
    """
    if not msg:
        return []

    chat_history = list(msg)
    first = dict(msg[0])
    content = first["content"]
    # Only strip when the prefix is really there. Blindly slicing off
    # len(persona) characters would corrupt a message that lacks it.
    if persona and content.startswith(persona):
        first["content"] = content[len(persona):]
    chat_history[0] = first
    return chat_history
|
|
|
|
|
def render(tok, messages: List[Dict[str, str]]) -> str:
    """Render *messages* into the final prompt text via the tokenizer's chat template.

    No tokenization is performed. The generation prompt is appended
    (``add_generation_prompt=True``), so the text ends where the model's
    reply should begin.
    """
    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    prompt_text = tok.apply_chat_template(messages, **template_kwargs)
    return prompt_text
|
|
|
|
|
def _ensure_alternating(messages): |
|
|
if not messages: |
|
|
return |
|
|
if messages[0]["role"] != "user": |
|
|
raise ValueError("messages[0] 必须是 'user'(你的模板要求从 user 开始)") |
|
|
for i, m in enumerate(messages): |
|
|
expect_user = (i % 2 == 0) |
|
|
if (m["role"] == "user") != expect_user: |
|
|
raise ValueError(f"对话必须严格交替 user/assistant,在索引 {i} 处发现 {m['role']}") |
|
|
|
|
|
def trim_by_tokens(tok, messages, prompt_budget, *, add_generation_prompt=True):
    """Drop the oldest turns until the rendered prompt fits in *prompt_budget* tokens.

    ``messages[0]`` (the user message carrying the persona) is always kept.
    Besides it, only suffixes ``messages[k:]`` with odd ``k`` are considered.
    For an alternating input ``messages[k]`` is then an assistant turn, so
    ``[messages[0]] + messages[k:]`` stays user/assistant alternating. The
    longest fitting suffix is found by binary search, which relies on the token
    count shrinking as ``k`` grows.

    Args:
        tok: tokenizer providing ``apply_chat_template(...)`` and
            ``tok(text, add_special_tokens=False).input_ids``. Presumably a
            Hugging Face tokenizer; confirm against callers.
        messages: list of ``{"role", "content"}`` dicts.
        prompt_budget: maximum number of tokens allowed for the rendered prompt.
        add_generation_prompt: render exactly the way ``render()`` builds the
            real prompt (with the generation prompt appended). Previously this
            was not counted, so the final prompt could exceed the budget. Pass
            ``False`` to get the old counting.

    Returns:
        ``[]`` for empty input, or the input list itself when it has a single
        message. Otherwise a new list. If no suffix fits, ``[messages[0]]`` is
        returned even though it alone may exceed the budget.
    """
    if not messages:
        return []
    if len(messages) == 1:
        return messages

    def _count_tokens(candidate):
        # Token length of the candidate rendered through the chat template.
        text = tok.apply_chat_template(
            candidate, tokenize=False, add_generation_prompt=add_generation_prompt
        )
        return len(tok(text, add_special_tokens=False).input_ids)

    head = [messages[0]]
    starts = range(1, len(messages), 2)  # odd start indices -> keep alternation

    best = head
    lo, hi = 0, len(starts) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        candidate = head + messages[starts[mid]:]
        if _count_tokens(candidate) <= prompt_budget:
            # Fits: remember it and try a longer suffix (smaller start index).
            best = candidate
            hi = mid - 1
        else:
            lo = mid + 1
    return best
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_json_overwrite(path: Path, data) -> None:
    """Atomically replace *path* with *data* serialized as pretty-printed UTF-8 JSON.

    The JSON is written to a uniquely named temporary file next to *path*,
    flushed and fsync'ed, then moved over *path* with ``os.replace``. A crash
    or a serialization error (e.g. ``TypeError`` for a non-JSON value)
    therefore never leaves *path* truncated: readers see either the old
    content or the new content. Non-ASCII text is written as-is
    (``ensure_ascii=False``), with LF line endings and a 2-space indent.

    Raises:
        Whatever ``json.dump`` or the filesystem raises. The temporary file is
        removed first and *path* is left unchanged.
    """
    path = Path(path)
    # Same directory as the target so os.replace() stays on one filesystem.
    tmp = path.with_name(f"{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        with open(tmp, "x", encoding="utf-8", newline="\n") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written temp file behind.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
|
|
|
|
|
|
|
|
class MsgStore: |
|
|
def __init__(self, base_dir: str | Path = "./msgs"): |
|
|
self.base = Path(base_dir) |
|
|
self.base.mkdir(parents=True, exist_ok=True) |
|
|
self.archive = self.base / "archive.jsonl" |
|
|
self.trimmed = self.base / "trimmed.json" |
|
|
if not self.archive.exists(): |
|
|
self.archive.write_text("", encoding="utf-8") |
|
|
if not self.trimmed.exists(): |
|
|
self.trimmed.write_text("[]", encoding="utf-8") |
|
|
|
|
|
def load_trimmed(self) -> List[Dict[str, str]]: |
|
|
try: |
|
|
return json.loads(self.trimmed.read_text(encoding="utf-8")) |
|
|
except Exception: |
|
|
return [] |
|
|
|
|
|
def save_trimmed(self, messages: List[Dict[str, str]]) -> None: |
|
|
write_json_overwrite(self.trimmed, messages) |
|
|
|
|
|
def append_archive(self, role: str, content: str, meta: dict | None = None) -> None: |
|
|
rec = {"ts": datetime.now(timezone.utc).isoformat(), "role": role, "content": content} |
|
|
if meta: rec["meta"] = meta |
|
|
with open(self.archive, "a", encoding="utf-8") as f: |
|
|
f.write(json.dumps(rec, ensure_ascii=False) + "\n") |
|
|
f.flush(); os.fsync(f.fileno()) |
|
|
|
|
|
|
|
|
def persist_messages(
    messages: List[Dict[str, str]],
    store_dir: str | Path = "./msgs",
    archive_last_turn: bool = True,
) -> None:
    """Save *messages* as the current trimmed state and optionally archive the last turn.

    trimmed.json under *store_dir* is always overwritten with *messages*.
    When *archive_last_turn* is true, the list is scanned from the end for the
    last adjacent user -> assistant pair. If one is found, it is appended to
    archive.jsonl as two records (user first, then assistant).

    NOTE(review): nothing tracks what was archived before. Calling this again
    without a new turn (or while the list ends with a pending user message)
    appends the same pair again.
    """
    store = MsgStore(store_dir)
    store.save_trimmed(messages)

    if not archive_last_turn:
        return

    last_pair = next(
        (
            (messages[idx]["content"], messages[idx + 1]["content"])
            for idx in reversed(range(len(messages) - 1))
            if messages[idx]["role"] == "user" and messages[idx + 1]["role"] == "assistant"
        ),
        None,
    )
    if last_pair is None:
        return

    user_text, assistant_text = last_pair
    store.append_archive("user", user_text)
    store.append_archive("assistant", assistant_text)
|
|
|
|
|
|