File size: 5,946 Bytes
621645b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

Server Message Data 编解码工具

用于处理 Base64URL 编码的 protobuf 消息

"""
import base64
from typing import Dict, Optional, Tuple, Any
from datetime import datetime, timezone

try:
    from zoneinfo import ZoneInfo  # Python 3.9+
except Exception:
    ZoneInfo = None  # type: ignore


class MessageCodec:
    """Server Message Data 编解码器"""
    
    @staticmethod
    def b64url_decode_padded(s: str) -> bytes:
        """Base64URL解码(带填充)"""
        t = s.replace("-", "+").replace("_", "/")
        pad = (-len(t)) % 4
        if pad:
            t += "=" * pad
        return base64.b64decode(t)
    
    @staticmethod
    def b64url_encode_nopad(b: bytes) -> str:
        """Base64URL编码(无填充)"""
        return base64.urlsafe_b64encode(b).decode("ascii").rstrip("=")
    
    @staticmethod
    def read_varint(buf: bytes, i: int) -> Tuple[int, int]:
        """读取varint格式的整数"""
        shift = 0
        val = 0
        while i < len(buf):
            b = buf[i]
            i += 1
            val |= (b & 0x7F) << shift
            if not (b & 0x80):
                return val, i
            shift += 7
            if shift > 63:
                break
        raise ValueError("invalid varint")
    
    @staticmethod
    def write_varint(v: int) -> bytes:
        """写入varint格式的整数"""
        out = bytearray()
        vv = int(v)
        while True:
            to_write = vv & 0x7F
            vv >>= 7
            if vv:
                out.append(to_write | 0x80)
            else:
                out.append(to_write)
                break
        return bytes(out)
    
    @classmethod
    def make_key(cls, field_no: int, wire_type: int) -> bytes:
        """创建protobuf字段键"""
        return cls.write_varint((field_no << 3) | wire_type)
    
    @classmethod
    def decode_timestamp(cls, buf: bytes) -> Tuple[Optional[int], Optional[int]]:
        """解码google.protobuf.Timestamp"""
        i = 0
        seconds: Optional[int] = None
        nanos: Optional[int] = None
        while i < len(buf):
            key, i = cls.read_varint(buf, i)
            field_no = key >> 3
            wt = key & 0x07
            if wt == 0:  # varint
                val, i = cls.read_varint(buf, i)
                if field_no == 1:
                    seconds = int(val)
                elif field_no == 2:
                    nanos = int(val)
            elif wt == 2:  # length-delimited
                ln, i2 = cls.read_varint(buf, i)
                i = i2 + ln
            elif wt == 1:
                i += 8
            elif wt == 5:
                i += 4
            else:
                break
        return seconds, nanos
    
    @classmethod
    def encode_timestamp(cls, seconds: Optional[int], nanos: Optional[int]) -> bytes:
        """编码google.protobuf.Timestamp"""
        parts = bytearray()
        if seconds is not None:
            parts += cls.make_key(1, 0)  # field 1, varint
            parts += cls.write_varint(int(seconds))
        if nanos is not None:
            parts += cls.make_key(2, 0)  # field 2, varint
            parts += cls.write_varint(int(nanos))
        return bytes(parts)
    
    @classmethod
    def decode_server_message_data(cls, b64url: str) -> Dict:
        """解码 Base64URL 的 server_message_data"""
        try:
            raw = cls.b64url_decode_padded(b64url)
        except Exception as e:
            return {"error": f"base64url decode failed: {e}", "raw_b64url": b64url}
        
        i = 0
        uuid: Optional[str] = None
        seconds: Optional[int] = None
        nanos: Optional[int] = None
        
        while i < len(raw):
            key, i = cls.read_varint(raw, i)
            field_no = key >> 3
            wt = key & 0x07
            if wt == 2:  # length-delimited
                ln, i2 = cls.read_varint(raw, i)
                i = i2
                data = raw[i:i+ln]
                i += ln
                if field_no == 1:  # uuid string
                    try:
                        uuid = data.decode("utf-8")
                    except Exception:
                        uuid = None
                elif field_no == 3:  # google.protobuf.Timestamp
                    seconds, nanos = cls.decode_timestamp(data)
            elif wt == 0:  # varint
                _, i = cls.read_varint(raw, i)
            elif wt == 1:
                i += 8
            elif wt == 5:
                i += 4
            else:
                break
        
        out: Dict[str, Any] = {}
        if uuid is not None:
            out["uuid"] = uuid
        if seconds is not None:
            out["seconds"] = seconds
        if nanos is not None:
            out["nanos"] = nanos
        return out
    
    @classmethod
    def encode_server_message_data(cls, uuid: str = None, seconds: int = None, nanos: int = None) -> str:
        """将 uuid/seconds/nanos 组合编码为 Base64URL 字符串"""
        parts = bytearray()
        if uuid:
            b = uuid.encode("utf-8")
            parts += cls.make_key(1, 2)  # field 1, length-delimited
            parts += cls.write_varint(len(b))
            parts += b
        
        if seconds is not None or nanos is not None:
            ts = cls.encode_timestamp(seconds, nanos)
            parts += cls.make_key(3, 2)  # field 3, length-delimited
            parts += cls.write_varint(len(ts))
            parts += ts
        
        return cls.b64url_encode_nopad(bytes(parts))


# 为了向后兼容,提供简单的函数接口
decode_server_message_data = MessageCodec.decode_server_message_data
encode_server_message_data = MessageCodec.encode_server_message_data