| """ |
| 会话 ID 携带方式:任意字符串 → base64 → 零宽字符编码,用特殊零宽标记组包裹。 |
| 从对话内容中通过正则匹配起止标记提取会话 ID,与 session_id 的具体格式无关。 |
| |
| 编码协议: |
| session_id (utf-8) |
| → base64 (A-Za-z0-9+/=,最多 65 个不同符号) |
| → 每个 base64 字符用 3 位 base-5 零宽字符表示(5³=125 ≥ 65) |
| → 有效索引范围 0..64(64 个字符 + padding),故三元组首位最大为 2(3*25=75 > 64) |
| → 因此首位为 ZW[3] 或 ZW[4] 的三元组绝不出现在正文中 |
| → HEAD_MARK/TAIL_MARK 正是利用首位 ≥ 3 的三元组构造,保证不会误中正文 |
| """ |
|
|
| import base64 |
| import re |
| from typing import Any |
|
|
| |
# Alphabet of five invisible code points. A character's position in this tuple
# is its base-5 digit value (0-4).
_ZERO_WIDTH = (
    "\u200b",  # ZERO WIDTH SPACE
    "\u200c",  # ZERO WIDTH NON-JOINER
    "\u200d",  # ZERO WIDTH JOINER
    "\ufeff",  # ZERO WIDTH NO-BREAK SPACE (BOM)
    "\u180e",  # MONGOLIAN VOWEL SEPARATOR
)
_ZW_SET = frozenset(_ZERO_WIDTH)
_ZW_TO_IDX = {c: i for i, c in enumerate(_ZERO_WIDTH)}


# Standard base64 alphabet (indices 0-63). The padding character "=" is not
# part of it and is assigned the next free index instead.
_B64_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
_B64_TO_IDX = {c: i for i, c in enumerate(_B64_CHARS)}
_PAD_IDX = len(_B64_CHARS)  # 64; derived so it cannot drift from the alphabet


# Start/end markers are made only of digits 3 and 4. Payload triplets encode
# indices 0-64, so their leading digit is at most 2 and a marker can never be
# produced by encoded payload.
_HEAD_MARK = _ZERO_WIDTH[4] * 3 + _ZERO_WIDTH[3] * 3
_TAIL_MARK = _ZERO_WIDTH[3] * 3 + _ZERO_WIDTH[4] * 3


# Regex character class matching any single alphabet character. Built from
# _ZERO_WIDTH rather than spelled out a second time, so there is one source of
# truth for the alphabet.
_ZW_CLASS = "[" + "".join(re.escape(c) for c in _ZERO_WIDTH) + "]"
|
|
|
|
def _encode_b64idx(idx: int) -> str:
    """Encode a base64 symbol index (0-64) as three zero-width characters.

    The index is written as a three-digit base-5 number, most significant
    digit first, and each digit selects one entry of ``_ZERO_WIDTH``.
    """
    high, remainder = divmod(idx, 25)
    mid, low = divmod(remainder, 5)
    return "".join(_ZERO_WIDTH[digit] for digit in (high, mid, low))
|
|
|
|
def _decode_b64idx(zw3: str) -> int | None:
    """Decode three zero-width characters into a base64 symbol index (0-64).

    Returns ``None`` when the input is not exactly three characters long,
    contains a character outside the zero-width alphabet, or represents a
    value above 64.
    """
    if len(zw3) != 3:
        return None
    value = 0
    for ch in zw3:
        digit = _ZW_TO_IDX.get(ch)
        if digit is None:
            return None
        # Accumulate base-5 digits, most significant first.
        value = value * 5 + digit
    # 64 is the padding index; anything larger is not a valid symbol.
    return value if value <= 64 else None
|
|
|
|
def encode_session_id(session_id: str) -> str:
    """Encode an arbitrary session ID string as an invisible zero-width marker.

    The result has the layout::

        HEAD_MARK + zero_width_encoded(base64(utf-8(session_id))) + TAIL_MARK

    Args:
        session_id: Any string; it is UTF-8 encoded before base64 encoding.

    Returns:
        The marker-wrapped zero-width sequence.

    Raises:
        UnicodeEncodeError: If ``session_id`` cannot be encoded as UTF-8
            (for example, it contains a lone surrogate).
    """
    b64 = base64.b64encode(session_id.encode("utf-8")).decode("ascii")
    # b64encode only ever emits characters from the base64 alphabet plus "=",
    # so the table lookup cannot miss; no silent empty-string fallback needed.
    body = "".join(
        _encode_b64idx(_PAD_IDX if ch == "=" else _B64_TO_IDX[ch]) for ch in b64
    )
    return _HEAD_MARK + body + _TAIL_MARK
|
|
|
|
# Compiled once at import: HEAD_MARK, a lazily matched non-empty run of
# zero-width characters (captured as group 1), then TAIL_MARK.
_SESSION_MARKER_RE = re.compile(
    re.escape(_HEAD_MARK) + r"(" + _ZW_CLASS + r"+?)" + re.escape(_TAIL_MARK)
)


def _decode_marker_body(body: str) -> str | None:
    """Decode the zero-width payload found between HEAD_MARK and TAIL_MARK.

    Returns the decoded session ID, or ``None`` when the payload length is not
    a multiple of three, a triplet is invalid, the base64 text is malformed,
    or the decoded bytes are not valid UTF-8.
    """
    if len(body) % 3 != 0:
        return None
    b64_chars: list[str] = []
    for start in range(0, len(body), 3):
        idx = _decode_b64idx(body[start : start + 3])
        if idx is None:
            return None
        b64_chars.append("=" if idx == _PAD_IDX else _B64_CHARS[idx])
    try:
        return base64.b64decode("".join(b64_chars)).decode()
    except ValueError:
        # binascii.Error (bad base64) and UnicodeDecodeError (bad UTF-8) are
        # both ValueError subclasses; anything else is a real bug and should
        # propagate instead of being hidden.
        return None


def decode_session_id(text: str) -> str | None:
    """Extract the first marker-wrapped session ID found in ``text``.

    The payload is decoded zero-width -> base64 -> UTF-8.

    Returns:
        The decoded session ID, or ``None`` if no marker is found or the
        first marker's payload fails to decode.
    """
    match = _SESSION_MARKER_RE.search(text)
    if match is None:
        return None
    return _decode_marker_body(match.group(1))


def decode_latest_session_id(text: str) -> str | None:
    """Extract the last marker-wrapped session ID found in ``text``.

    Intended for clients that keep the full history, so that the most
    recently returned session ID takes precedence.

    Returns:
        The decoded session ID, or ``None`` if no marker is found or the
        last marker's payload fails to decode.
    """
    last_match = None
    # Walk the matches lazily and keep only the final one.
    for last_match in _SESSION_MARKER_RE.finditer(text):
        pass
    if last_match is None:
        return None
    return _decode_marker_body(last_match.group(1))
|
|
|
|
def extract_session_id_marker(text: str) -> str:
    """Return the first complete session-ID marker segment found in ``text``.

    The segment is returned verbatim (HEAD_MARK + body + TAIL_MARK) without
    decoding, so it can be carried in the text content of tool calls into the
    next conversation turn. Returns an empty string when no marker is present.
    """
    pattern = "".join(
        (re.escape(_HEAD_MARK), _ZW_CLASS, r"+?", re.escape(_TAIL_MARK))
    )
    found = re.search(pattern, text)
    if found is None:
        return ""
    return found.group(0)
|
|
|
|
def session_id_suffix(session_id: str) -> str:
    """Return the invisible marker to append to the end of a response.

    The marker is the HEAD/TAIL-wrapped zero-width encoding of ``session_id``
    produced by ``encode_session_id``.
    """
    marker = encode_session_id(session_id)
    return marker
|
|
|
|
def strip_session_id_suffix(text: str) -> str:
    """Remove every zero-width session-ID marker segment from ``text``.

    Each HEAD_MARK...TAIL_MARK segment is deleted and the remaining text is
    returned unchanged.
    """
    marker = re.compile(
        re.escape(_HEAD_MARK) + _ZW_CLASS + r"+?" + re.escape(_TAIL_MARK)
    )
    return marker.sub("", text)
|
|
|
|
| def _normalize_content(content: str | list[Any]) -> str: |
| if isinstance(content, str): |
| return content |
| parts: list[str] = [] |
| for p in content: |
| if isinstance(p, dict) and p.get("type") == "text" and "text" in p: |
| parts.append(str(p["text"])) |
| elif isinstance(p, str): |
| parts.append(p) |
| return " ".join(parts) |
|
|
|
|
def parse_conv_uuid_from_messages(messages: list[dict[str, Any]]) -> str | None:
    """Find the most recent session ID carried in ``messages``.

    Messages are scanned from last to first. For each message with non-``None``
    content, the content is flattened to text and its last marker is decoded;
    the first successful decode is returned. Returns ``None`` when no message
    yields a session ID.
    """
    for message in reversed(messages):
        raw = message.get("content")
        if raw is not None:
            session_id = decode_latest_session_id(_normalize_content(raw))
            if session_id is not None:
                return session_id
    return None
|
|