web2api / core /api /conv_parser.py
ohmyapi's picture
feat: align hosted Space deployment with latest upstream
77169b4
"""
会话 ID 携带方式:任意字符串 → base64 → 零宽字符编码,用特殊零宽标记组包裹。
从对话内容中通过正则匹配起止标记提取会话 ID,与 session_id 的具体格式无关。
编码协议:
session_id (utf-8)
→ base64 (A-Za-z0-9+/=,最多 65 个不同符号)
→ 每个 base64 字符用 3 位 base-5 零宽字符表示(5³=125 ≥ 65)
→ 有效索引范围 0..64(64 个字符 + padding),故三元组首位最大为 2(3*25=75 > 64)
→ 因此首位为 ZW[3] 或 ZW[4] 的三元组绝不出现在正文中
→ HEAD_MARK/TAIL_MARK 正是利用首位 ≥ 3 的三元组构造,保证不会误中正文
"""
import base64
import re
from typing import Any
# 零宽字符集(5 个字符,基数 5,索引 0-4)
_ZERO_WIDTH = (
"\u200b", # 零宽空格 → 0
"\u200c", # 零宽非连接符 → 1
"\u200d", # 零宽连接符 → 2
"\ufeff", # 零宽非断空格 → 3
"\u180e", # 蒙古文元音分隔符 → 4
)
_ZW_SET = frozenset(_ZERO_WIDTH)
_ZW_TO_IDX = {c: i for i, c in enumerate(_ZERO_WIDTH)}
# base64 标准字符集(64 个字符),padding 符 "=" 用索引 64 表示
_B64_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
_B64_TO_IDX = {c: i for i, c in enumerate(_B64_CHARS)}
_PAD_IDX = 64 # "=" 的编码索引
# 起止标记:首位均为 ZW[3] 或 ZW[4],保证不出现在 payload 三元组中
_HEAD_MARK = _ZERO_WIDTH[4] * 3 + _ZERO_WIDTH[3] * 3 # 6 个零宽字符
_TAIL_MARK = _ZERO_WIDTH[3] * 3 + _ZERO_WIDTH[4] * 3 # 6 个零宽字符
_ZW_CLASS = r"[\u200b\u200c\u200d\ufeff\u180e]"
def _encode_b64idx(idx: int) -> str:
"""将 base64 字符索引 (0-64) 编码为 3 个零宽字符(3 位 base-5)。"""
a = idx // 25
r = idx % 25
b = r // 5
c = r % 5
return _ZERO_WIDTH[a] + _ZERO_WIDTH[b] + _ZERO_WIDTH[c]
def _decode_b64idx(zw3: str) -> int | None:
"""将 3 个零宽字符解码为 base64 字符索引(0-64),非法返回 None。"""
if len(zw3) != 3:
return None
a = _ZW_TO_IDX.get(zw3[0])
b = _ZW_TO_IDX.get(zw3[1])
c = _ZW_TO_IDX.get(zw3[2])
if a is None or b is None or c is None:
return None
val = a * 25 + b * 5 + c
if val > 64:
return None
return val
def encode_session_id(session_id: str) -> str:
"""
将任意字符串会话 ID 编码为不可见的零宽序列:
HEAD_MARK + zero_width_encoded(base64(utf-8(session_id))) + TAIL_MARK
"""
b64 = base64.b64encode(session_id.encode()).decode()
out: list[str] = []
for ch in b64:
if ch == "=":
out.append(_encode_b64idx(_PAD_IDX))
else:
idx = _B64_TO_IDX.get(ch)
if idx is None:
return ""
out.append(_encode_b64idx(idx))
return _HEAD_MARK + "".join(out) + _TAIL_MARK
def decode_session_id(text: str) -> str | None:
"""
从文本中提取第一个被标记包裹的会话 ID(解码零宽 → base64 → utf-8)。
若未找到有效标记或解码失败则返回 None。
"""
m = re.search(
re.escape(_HEAD_MARK) + r"(" + _ZW_CLASS + r"+?)" + re.escape(_TAIL_MARK),
text,
)
if not m:
return None
body = m.group(1)
if len(body) % 3 != 0:
return None
b64_chars: list[str] = []
for i in range(0, len(body), 3):
idx = _decode_b64idx(body[i : i + 3])
if idx is None:
return None
b64_chars.append("=" if idx == _PAD_IDX else _B64_CHARS[idx])
try:
return base64.b64decode("".join(b64_chars)).decode()
except Exception:
return None
def decode_latest_session_id(text: str) -> str | None:
"""
从文本中提取最后一个被标记包裹的会话 ID。
用于客户端保留完整历史时,优先命中最近一次返回的 session_id。
"""
matches = list(
re.finditer(
re.escape(_HEAD_MARK) + r"(" + _ZW_CLASS + r"+?)" + re.escape(_TAIL_MARK),
text,
)
)
if not matches:
return None
body = matches[-1].group(1)
if len(body) % 3 != 0:
return None
b64_chars: list[str] = []
for i in range(0, len(body), 3):
idx = _decode_b64idx(body[i : i + 3])
if idx is None:
return None
b64_chars.append("=" if idx == _PAD_IDX else _B64_CHARS[idx])
try:
return base64.b64decode("".join(b64_chars)).decode()
except Exception:
return None
def extract_session_id_marker(text: str) -> str:
"""
从文本中提取完整的零宽会话 ID 标记段(HEAD_MARK + body + TAIL_MARK),
用于在 tool_calls 的 text_content 中携带会话 ID 至下一轮对话。
若未找到则返回空字符串。
"""
m = re.search(
re.escape(_HEAD_MARK) + _ZW_CLASS + r"+?" + re.escape(_TAIL_MARK),
text,
)
return m.group(0) if m else ""
def session_id_suffix(session_id: str) -> str:
"""返回响应末尾需附加的不可见标记(含 HEAD/TAIL 包裹的零宽编码会话 ID)。"""
return encode_session_id(session_id)
def strip_session_id_suffix(text: str) -> str:
"""去掉文本中所有零宽会话 ID 标记段(HEAD_MARK...TAIL_MARK),返回干净正文。"""
return re.sub(
re.escape(_HEAD_MARK) + _ZW_CLASS + r"+?" + re.escape(_TAIL_MARK),
"",
text,
)
def _normalize_content(content: str | list[Any]) -> str:
if isinstance(content, str):
return content
parts: list[str] = []
for p in content:
if isinstance(p, dict) and p.get("type") == "text" and "text" in p:
parts.append(str(p["text"]))
elif isinstance(p, str):
parts.append(p)
return " ".join(parts)
def parse_conv_uuid_from_messages(messages: list[dict[str, Any]]) -> str | None:
"""从 messages 中解析最新会话 ID(从最后一条带标记的消息开始逆序查找)。"""
for m in reversed(messages):
content = m.get("content")
if content is None:
continue
text = _normalize_content(content)
decoded = decode_latest_session_id(text)
if decoded is not None:
return decoded
return None