|
|
|
|
|
|
|
|
"""
|
|
|
Server Message Data 编解码工具
|
|
|
用于处理 Base64URL 编码的 protobuf 消息
|
|
|
"""
|
|
|
import base64
|
|
|
from typing import Dict, Optional, Tuple, Any
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
try:
|
|
|
from zoneinfo import ZoneInfo
|
|
|
except Exception:
|
|
|
ZoneInfo = None
|
|
|
|
|
|
|
|
|
class MessageCodec:
    """Encoder/decoder for Server Message Data.

    The payload is a small protobuf message, serialized to bytes and then
    Base64URL-encoded without ``=`` padding. The fields handled here are:

    * field 1 (length-delimited): UTF-8 uuid string
    * field 3 (length-delimited): embedded ``google.protobuf.Timestamp``
      (field 1 = seconds, field 2 = nanos)

    Other fields are skipped when decoding.
    """

    @staticmethod
    def b64url_decode_padded(s: str) -> bytes:
        """Decode a Base64URL string, restoring any missing ``=`` padding.

        Surrounding whitespace (e.g. a trailing newline) is ignored, so it does
        not skew the padding computation.

        Raises:
            binascii.Error: if *s* is not decodable Base64 (for example when its
                length is one more than a multiple of 4).
        """
        t = s.strip().replace("-", "+").replace("_", "/")
        t += "=" * (-len(t) % 4)
        return base64.b64decode(t)

    @staticmethod
    def b64url_encode_nopad(b: bytes) -> str:
        """Encode *b* as Base64URL and strip the trailing ``=`` padding."""
        return base64.urlsafe_b64encode(b).decode("ascii").rstrip("=")

    @staticmethod
    def read_varint(buf: bytes, i: int) -> Tuple[int, int]:
        """Read an unsigned base-128 varint from *buf* starting at index *i*.

        Returns:
            ``(value, next_index)``. The value is unsigned; signed protobuf
            fields must be reinterpreted by the caller (see ``_to_signed``).

        Raises:
            ValueError: if the buffer ends mid-varint or the varint is longer
                than 10 bytes.
        """
        shift = 0
        val = 0
        while i < len(buf):
            b = buf[i]
            i += 1
            val |= (b & 0x7F) << shift
            if not (b & 0x80):
                return val, i
            shift += 7
            if shift > 63:  # a 64-bit varint never needs more than 10 bytes
                break
        raise ValueError("invalid varint")

    @staticmethod
    def write_varint(v: int) -> bytes:
        """Encode *v* as a base-128 varint.

        Negative values are written as their 64-bit two's complement (a
        10-byte varint), which is how protobuf serializes negative int32/int64
        fields. Values >= 2**64 are encoded as-is, but ``read_varint`` will
        reject them.

        Raises:
            ValueError: if *v* is below the int64 range.
        """
        vv = int(v)
        if vv < 0:
            if vv < -(1 << 63):
                raise ValueError(f"varint value out of int64 range: {v}")
            # Python's >> on a negative int never reaches 0, so map the value
            # into the unsigned 64-bit range before emitting 7-bit groups.
            vv += 1 << 64
        out = bytearray()
        while vv > 0x7F:
            out.append((vv & 0x7F) | 0x80)
            vv >>= 7
        out.append(vv)
        return bytes(out)

    @staticmethod
    def _to_signed(val: int, bits: int) -> int:
        """Reinterpret the low *bits* bits of an unsigned varint as two's complement."""
        val &= (1 << bits) - 1
        if val >= 1 << (bits - 1):
            val -= 1 << bits
        return val

    @classmethod
    def _skip_field(cls, buf: bytes, i: int, wt: int) -> Optional[int]:
        """Return the index just past a field payload of wire type *wt* starting at *i*.

        Returns ``None`` for wire types this codec does not skip (3/4 groups and
        the invalid 6/7), telling the caller to stop parsing. The returned index
        may lie past the end of *buf*; callers' ``while i < len(buf)`` loops
        then simply terminate.
        """
        if wt == 0:
            _, i = cls.read_varint(buf, i)
            return i
        if wt == 1:
            return i + 8
        if wt == 2:
            ln, i = cls.read_varint(buf, i)
            return i + ln
        if wt == 5:
            return i + 4
        return None

    @classmethod
    def make_key(cls, field_no: int, wire_type: int) -> bytes:
        """Build the varint field key ``(field_no << 3) | wire_type``."""
        return cls.write_varint((field_no << 3) | wire_type)

    @classmethod
    def decode_timestamp(cls, buf: bytes) -> Tuple[Optional[int], Optional[int]]:
        """Decode a serialized ``google.protobuf.Timestamp``.

        Returns:
            ``(seconds, nanos)``; either is ``None`` when the field is absent.
            ``seconds`` is interpreted as a signed int64 and ``nanos`` as a
            signed int32, so negative (pre-1970) values round-trip correctly.

        Raises:
            ValueError: on a malformed varint.
        """
        i = 0
        seconds: Optional[int] = None
        nanos: Optional[int] = None
        while i < len(buf):
            key, i = cls.read_varint(buf, i)
            field_no, wt = key >> 3, key & 0x07
            if wt == 0 and field_no in (1, 2):
                val, i = cls.read_varint(buf, i)
                if field_no == 1:
                    seconds = cls._to_signed(val, 64)
                else:
                    nanos = cls._to_signed(val, 32)
                continue
            nxt = cls._skip_field(buf, i, wt)
            if nxt is None:
                break
            i = nxt
        return seconds, nanos

    @classmethod
    def encode_timestamp(cls, seconds: Optional[int], nanos: Optional[int]) -> bytes:
        """Encode a ``google.protobuf.Timestamp``; ``None`` fields are omitted."""
        parts = bytearray()
        if seconds is not None:
            parts += cls.make_key(1, 0)
            parts += cls.write_varint(int(seconds))
        if nanos is not None:
            parts += cls.make_key(2, 0)
            parts += cls.write_varint(int(nanos))
        return bytes(parts)

    @classmethod
    def decode_server_message_data(cls, b64url: str) -> Dict:
        """Decode a Base64URL-encoded server_message_data string.

        Returns:
            A dict containing whichever of ``"uuid"``, ``"seconds"`` and
            ``"nanos"`` were present. A uuid that is not valid UTF-8 is left
            out. If the Base64URL layer cannot be decoded, returns
            ``{"error": ..., "raw_b64url": b64url}`` instead.

        Raises:
            ValueError: if the protobuf payload is malformed (bad varint, or a
                uuid/timestamp field whose length runs past the end of the data).
        """
        try:
            raw = cls.b64url_decode_padded(b64url)
        except Exception as e:
            return {"error": f"base64url decode failed: {e}", "raw_b64url": b64url}

        uuid: Optional[str] = None
        seconds: Optional[int] = None
        nanos: Optional[int] = None

        i = 0
        while i < len(raw):
            key, i = cls.read_varint(raw, i)
            field_no, wt = key >> 3, key & 0x07
            if wt == 2 and field_no in (1, 3):
                ln, i = cls.read_varint(raw, i)
                end = i + ln
                if end > len(raw):
                    # Slicing would silently truncate and yield a corrupt value.
                    raise ValueError(
                        f"field {field_no} truncated: need {ln} bytes at offset {i}, "
                        f"have {len(raw) - i}"
                    )
                data = raw[i:end]
                i = end
                if field_no == 1:
                    try:
                        uuid = data.decode("utf-8")
                    except UnicodeDecodeError:
                        uuid = None
                else:
                    seconds, nanos = cls.decode_timestamp(data)
                continue
            nxt = cls._skip_field(raw, i, wt)
            if nxt is None:
                break
            i = nxt

        fields = {"uuid": uuid, "seconds": seconds, "nanos": nanos}
        out: Dict[str, Any] = {k: v for k, v in fields.items() if v is not None}
        return out

    @classmethod
    def encode_server_message_data(
        cls,
        uuid: Optional[str] = None,
        seconds: Optional[int] = None,
        nanos: Optional[int] = None,
    ) -> str:
        """Encode uuid/seconds/nanos into an unpadded Base64URL string.

        An empty or ``None`` uuid is omitted; the timestamp submessage (field 3)
        is written only when *seconds* or *nanos* is given.
        """
        parts = bytearray()
        if uuid:
            b = uuid.encode("utf-8")
            parts += cls.make_key(1, 2)
            parts += cls.write_varint(len(b))
            parts += b

        if seconds is not None or nanos is not None:
            ts = cls.encode_timestamp(seconds, nanos)
            parts += cls.make_key(3, 2)
            parts += cls.write_varint(len(ts))
            parts += ts

        return cls.b64url_encode_nopad(bytes(parts))
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level shortcuts so callers can use the codec without referencing
# the MessageCodec class directly.
decode_server_message_data, encode_server_message_data = (
    MessageCodec.decode_server_message_data,
    MessageCodec.encode_server_message_data,
)