File size: 20,315 Bytes
c1ae554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

Warp Protobuf编解码服务器启动文件



纯protobuf编解码服务器,提供JSON<->Protobuf转换、WebSocket监控和静态文件服务。

"""

from typing import Dict, Optional, Tuple
import base64
from pathlib import Path
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi import Query, HTTPException
from fastapi.responses import Response

# 新增:类型导入
from typing import Any

from warp2protobuf.api.protobuf_routes import app as protobuf_app
from warp2protobuf.core.logging import logger, set_log_file
from warp2protobuf.api.protobuf_routes import EncodeRequest, _encode_smd_inplace
from warp2protobuf.core.protobuf_utils import dict_to_protobuf_bytes
from warp2protobuf.core.schema_sanitizer import sanitize_mcp_input_schema_in_packet
from warp2protobuf.core.auth import acquire_anonymous_access_token
from warp2protobuf.config.models import get_all_unique_models


# ============= 工具:input_schema 清理与校验 =============


def _is_empty_value(value: Any) -> bool:
    """Tell whether ``value`` carries no information.

    A value is considered empty when it is ``None``, a string that is blank
    after stripping whitespace, or a list/dict without elements. Anything
    else (including ``0`` and ``False``) is treated as non-empty.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return value.strip() == ""
    if isinstance(value, (list, dict)):
        return len(value) == 0
    return False


def _deep_clean(value: Any) -> Any:
    """Recursively strip strings and prune empty entries from containers.

    - Strings are returned with surrounding whitespace removed.
    - Dicts and lists are rebuilt from their cleaned children, skipping every
      child that ``_is_empty_value`` reports as empty after cleaning.
    - Any other value is returned unchanged.

    The input containers are not mutated; new dict/list objects are returned.
    """
    if isinstance(value, str):
        return value.strip()

    if isinstance(value, dict):
        pruned: Dict[str, Any] = {}
        for key, raw in value.items():
            child = _deep_clean(raw)
            if not _is_empty_value(child):
                pruned[key] = child
        return pruned

    if isinstance(value, list):
        cleaned_items = (_deep_clean(item) for item in value)
        return [item for item in cleaned_items if not _is_empty_value(item)]

    return value


def _infer_type_for_property(prop_name: str) -> str:
    """Guess a JSON-schema type from a property name (case-insensitive).

    Names that usually hold structured data ("headers", "options", "params",
    "payload", "data") map to ``"object"``; every other name, including the
    URL-like ones ("url", "uri", "href", "link"), maps to ``"string"``.
    """
    object_like = ("headers", "options", "params", "payload", "data")
    return "object" if prop_name.lower() in object_like else "string"


def _ensure_property_schema(name: str, schema: Dict[str, Any]) -> Dict[str, Any]:
    """Return a cleaned copy of a property schema with the mandatory fields set.

    The schema is deep-cleaned first (non-dict input is treated as ``{}``).
    A missing, non-string or blank ``type`` is inferred from the property
    name, and a missing, non-string or blank ``description`` defaults to
    ``"<name> parameter"``.

    A property named "headers" (case-insensitive) gets extra treatment:
    its type is forced to "object", its ``properties`` may not be empty
    (a default ``user-agent`` entry is inserted when nothing usable is left),
    each header sub-schema gets a ``type``/``description``, ``required`` is
    limited to headers that actually exist (and dropped when empty), and an
    empty-dict ``additionalProperties`` is removed.
    """

    def _lacks_text(mapping: Dict[str, Any], key: str) -> bool:
        # True when the key is absent, not a string, or only whitespace.
        current = mapping.get(key)
        return not (isinstance(current, str) and current.strip())

    result = _deep_clean(dict(schema) if isinstance(schema, dict) else {})

    # Mandatory fields: type & description.
    if _lacks_text(result, "type"):
        result["type"] = _infer_type_for_property(name)
    if _lacks_text(result, "description"):
        result["description"] = f"{name} parameter"

    if name.lower() != "headers":
        return result

    # Special handling for headers: must be an object with non-empty properties.
    result["type"] = "object"
    raw_headers = result.get("properties")
    header_schemas = _deep_clean(raw_headers if isinstance(raw_headers, dict) else {})
    if header_schemas:
        # Make sure every header sub-schema carries type/description.
        normalized: Dict[str, Any] = {}
        for header_name, header_schema in header_schemas.items():
            entry = _deep_clean(header_schema if isinstance(header_schema, dict) else {})
            if _lacks_text(entry, "type"):
                entry["type"] = "string"
            if _lacks_text(entry, "description"):
                entry["description"] = f"{header_name} header"
            normalized[header_name] = entry
        header_schemas = normalized
    else:
        header_schemas = {
            "user-agent": {
                "type": "string",
                "description": "User-Agent header for the request",
            }
        }
    result["properties"] = header_schemas

    # required: keep only names of existing headers; never leave an empty list.
    declared_required = result.get("required")
    if isinstance(declared_required, list):
        still_valid = [
            r for r in declared_required if isinstance(r, str) and r in header_schemas
        ]
        if still_valid:
            result["required"] = still_valid
        else:
            result.pop("required", None)

    # additionalProperties: drop an empty dict, keep explicit True/False.
    extra = result.get("additionalProperties")
    if isinstance(extra, dict) and not extra:
        result.pop("additionalProperties", None)

    return result


def _sanitize_json_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Return a cleaned, normalized copy of a JSON schema.

    Steps, in order: deep-clean the input (non-dict input becomes ``{}``);
    default the top-level ``type`` to "object" when ``properties`` exists
    without a string type; replace a missing or non-string ``$schema`` with
    the draft-07 URI; normalize every property via
    ``_ensure_property_schema``; restrict ``required`` to existing properties
    (dropping it when nothing is left); and remove an empty-dict
    ``additionalProperties``.
    """
    cleaned = _deep_clean(schema if isinstance(schema, dict) else {})

    # A schema that lists properties is an object at the top level.
    if "properties" in cleaned and not isinstance(cleaned.get("type"), str):
        cleaned["type"] = "object"

    # Fix up $schema: discard a non-string value, then fall back to draft-07.
    if "$schema" in cleaned and not isinstance(cleaned["$schema"], str):
        del cleaned["$schema"]
    cleaned.setdefault("$schema", "http://json-schema.org/draft-07/schema#")

    props = cleaned.get("properties")
    has_props = isinstance(props, dict)
    if has_props:
        cleaned["properties"] = {
            prop_name: _ensure_property_schema(
                prop_name, sub if isinstance(sub, dict) else {}
            )
            for prop_name, sub in props.items()
        }

    # required: drop unknown property names; an empty list is not allowed.
    declared_required = cleaned.get("required")
    if isinstance(declared_required, list):
        surviving = (
            [r for r in declared_required if isinstance(r, str) and r in props]
            if has_props
            else []
        )
        if surviving:
            cleaned["required"] = surviving
        else:
            cleaned.pop("required", None)

    # additionalProperties: an empty dict is treated as invalid and removed.
    extra = cleaned.get("additionalProperties")
    if isinstance(extra, dict) and not extra:
        cleaned.pop("additionalProperties", None)

    return cleaned


class _InputSchemaSanitizerMiddleware:
    """Deprecated placeholder; use sanitize_mcp_input_schema_in_packet in handlers."""


# ============= 应用创建 =============


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Awaits ``startup_tasks()`` before the application starts serving.
    Nothing is executed on shutdown.
    """
    await startup_tasks()  # startup phase
    # The application runs while suspended here; no shutdown work follows.
    yield


def create_app() -> FastAPI:
    """Create the main FastAPI application.

    The application gets the ``lifespan`` handler, optional static file
    serving plus a ``/gui`` page, the ``/api/warp/encode_raw`` and
    ``/v1/models`` endpoints, and finally the protobuf routes application
    mounted at ``/``.

    The mount is registered LAST on purpose: routes are matched in
    registration order and a mount at ``/`` matches every path, so anything
    registered after it would never be reached. As a consequence, routes
    defined here take precedence over same-path routes of the mounted app.

    Returns:
        The configured FastAPI application.
    """
    # Redirect server logs to a dedicated file (best effort: a failure here
    # must not prevent the server from starting, but it should leave a trace).
    try:
        set_log_file("warp_server.log")
    except Exception as e:
        logger.warning(f"⚠️ 日志文件设置失败: {e}")

    # Main application with the lifespan handler attached.
    app = FastAPI(lifespan=lifespan)

    # Serve the static front-end when the directory exists.
    static_dir = Path("static")
    if static_dir.exists():
        # Mount static file serving.
        app.mount("/static", StaticFiles(directory="static"), name="static")
        logger.info("✅ 静态文件服务已启用: /static")

        # GUI entry point backed by static/index.html.
        @app.get("/gui", response_class=HTMLResponse)
        async def serve_gui():
            """Serve the front-end GUI page."""
            index_file = static_dir / "index.html"
            if index_file.exists():
                return HTMLResponse(content=index_file.read_text(encoding="utf-8"))
            else:
                return HTMLResponse(
                    content="""

                <html>

                    <body>

                        <h1>前端界面文件未找到</h1>

                        <p>请确保 static/index.html 文件存在</p>

                    </body>

                </html>

                """
                )
    else:
        logger.warning("静态文件目录不存在,GUI界面将不可用")

        @app.get("/gui", response_class=HTMLResponse)
        async def no_gui():
            """Explain that the GUI is unavailable because 'static' is missing."""
            return HTMLResponse(
                content="""

            <html>

                <body>

                    <h1>GUI界面未安装</h1>

                    <p>静态文件目录 'static' 不存在</p>

                    <p>请创建前端界面文件</p>

                </body>

            </html>

            """
            )

    # ============= Endpoint: return the protobuf-encoded AI request bytes =============
    @app.post("/api/warp/encode_raw")
    async def encode_ai_request_raw(
        request: EncodeRequest,
        output: str = Query(
            "raw",
            description="输出格式:raw(默认,返回application/x-protobuf字节) 或 base64",
            regex=r"^(raw|base64)$",
        ),
    ):
        """Encode the request payload to protobuf; return raw bytes or base64 JSON."""
        try:
            # Fetch the actual payload and validate it.
            actual_data = request.get_data()
            if not actual_data:
                raise HTTPException(400, "数据包不能为空")

            # Before encoding, sanitize mcp_context.tools[*].input_schema.
            if isinstance(actual_data, dict):
                wrapped = {"json_data": actual_data}
                wrapped = sanitize_mcp_input_schema_in_packet(wrapped)
                actual_data = wrapped.get("json_data", actual_data)

            # Encode a server_message_data object (if any) to a Base64URL string.
            actual_data = _encode_smd_inplace(actual_data)

            # Encode to protobuf bytes.
            protobuf_bytes = dict_to_protobuf_bytes(actual_data, request.message_type)
            logger.info(f"✅ AI请求编码为protobuf成功: {len(protobuf_bytes)} 字节")

            if output == "raw":
                # Return the binary protobuf content directly.
                return Response(
                    content=protobuf_bytes,
                    media_type="application/x-protobuf",
                    headers={"Content-Length": str(len(protobuf_bytes))},
                )
            else:
                # Return base64 text, convenient for JSON transport/debugging.
                # (base64 is imported at module level.)
                return {
                    "protobuf_base64": base64.b64encode(protobuf_bytes).decode("utf-8"),
                    "size": len(protobuf_bytes),
                    "message_type": request.message_type,
                }
        except HTTPException:
            # Pass deliberate HTTP errors (e.g. the 400 above) through untouched.
            raise
        except Exception as e:
            logger.error(f"❌ AI请求编码失败: {e}")
            raise HTTPException(500, f"编码失败: {str(e)}")

    # ============= OpenAI compatibility: model list endpoint =============
    @app.get("/v1/models")
    async def list_models():
        """OpenAI-compatible endpoint that lists available models."""
        try:
            models = get_all_unique_models()
            return {"object": "list", "data": models}
        except Exception as e:
            logger.error(f"❌ 获取模型列表失败: {e}")
            raise HTTPException(500, f"获取模型列表失败: {str(e)}")

    # Include the protobuf routes application. This catch-all mount must come
    # after every route above; otherwise it would shadow all of them.
    app.mount("/", protobuf_app)

    return app


############################################################
# server_message_data 深度编解码工具
############################################################

# 说明:
# 根据抓包与分析,server_message_data 是 Base64URL 编码的 proto3 消息:
#   - 字段 1:string(通常为 36 字节 UUID)
#   - 字段 3:google.protobuf.Timestamp(字段1=seconds,字段2=nanos)
# 可能出现:仅 Timestamp、仅 UUID、或 UUID + Timestamp。

# Optional dependency: zoneinfo is only in the standard library on Python 3.9+.
# When it is missing, ZoneInfo is set to None so callers can feature-test it.
try:
    from zoneinfo import ZoneInfo  # Python 3.9+
except ImportError:
    ZoneInfo = None  # type: ignore


def _b64url_decode_padded(s: str) -> bytes:
    """Decode a Base64URL string whose ``=`` padding may have been stripped.

    The input is padded back to a multiple of four characters before
    decoding, so both padded and unpadded forms are accepted.
    """
    missing = (-len(s)) % 4
    return base64.urlsafe_b64decode(s + "=" * missing)


def _b64url_encode_nopad(b: bytes) -> str:
    """Encode bytes as Base64URL text with the trailing ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(b)
    return encoded.rstrip(b"=").decode("ascii")


def _read_varint(buf: bytes, i: int) -> Tuple[int, int]:
    """Read one base-128 varint from ``buf`` starting at index ``i``.

    Returns:
        ``(value, next_index)`` where ``next_index`` points just past the
        last byte of the varint.

    Raises:
        ValueError: if the buffer ends before the varint terminates, or the
            varint is still unterminated after ten bytes.
    """
    result = 0
    pos = i
    # Ten groups of 7 bits (shifts 0, 7, ..., 63) is the longest accepted form.
    for shift in range(0, 70, 7):
        if pos >= len(buf):
            break
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not (byte & 0x80):  # continuation bit clear: last byte
            return result, pos
    raise ValueError("invalid varint")


def _write_varint(v: int) -> bytes:
    """Encode an integer as a protobuf base-128 varint.

    Non-negative values are written in the minimal number of bytes.
    Negative values are written the way protobuf encodes negative
    int32/int64 fields: as the 64-bit two's complement, which always takes
    ten bytes. (Without this, a negative input would never terminate,
    because an arithmetic right shift of a negative int never reaches zero.)

    Args:
        v: The value to encode; it is converted with ``int()`` first.

    Returns:
        The varint bytes, least-significant 7-bit group first.
    """
    value = int(v)
    if value < 0:
        # Reinterpret as unsigned 64-bit so the shift loop below terminates.
        value &= 0xFFFFFFFFFFFFFFFF

    out = bytearray()
    while True:
        low7 = value & 0x7F
        value >>= 7
        if value:
            out.append(low7 | 0x80)  # more groups follow: set continuation bit
        else:
            out.append(low7)
            return bytes(out)


def _make_key(field_no: int, wire_type: int) -> bytes:
    """Build the varint-encoded field key: field number in the high bits, wire type in the low three."""
    tag = (field_no << 3) | wire_type
    return _write_varint(tag)


def _decode_timestamp(buf: bytes) -> Tuple[Optional[int], Optional[int]]:
    """Parse a serialized google.protobuf.Timestamp.

    Field 1 (varint) is ``seconds`` and field 2 (varint) is ``nanos``.
    Other fields are skipped according to their wire type; an unsupported
    wire type stops parsing and whatever was read so far is returned.

    Returns:
        ``(seconds, nanos)``; each is ``None`` when the field is absent.
    """
    seconds: Optional[int] = None
    nanos: Optional[int] = None
    pos = 0
    end = len(buf)

    while pos < end:
        tag, pos = _read_varint(buf, pos)
        field_no, wire_type = tag >> 3, tag & 0x07

        if wire_type == 0:  # varint
            number, pos = _read_varint(buf, pos)
            if field_no == 1:
                seconds = int(number)
            elif field_no == 2:
                nanos = int(number)
        elif wire_type == 2:  # length-delimited (not expected inside Timestamp)
            length, pos = _read_varint(buf, pos)
            pos += length
        elif wire_type == 1:  # fixed 64-bit
            pos += 8
        elif wire_type == 5:  # fixed 32-bit
            pos += 4
        else:
            break

    return seconds, nanos


def _encode_timestamp(seconds: Optional[int], nanos: Optional[int]) -> bytes:
    """Serialize a google.protobuf.Timestamp from optional seconds/nanos.

    ``seconds`` is written as varint field 1 and ``nanos`` as varint field 2;
    a component that is ``None`` is omitted entirely.
    """
    encoded = bytearray()
    for field_no, component in ((1, seconds), (2, nanos)):
        if component is None:
            continue
        encoded += _make_key(field_no, 0)  # wire type 0 = varint
        encoded += _write_varint(int(component))
    return bytes(encoded)


def decode_server_message_data(b64url: str) -> Dict:
    """Decode a Base64URL ``server_message_data`` value into structured fields.

    The payload is parsed as a protobuf message where field 1
    (length-delimited) is a UTF-8 UUID string and field 3 (length-delimited)
    is a google.protobuf.Timestamp. Other fields are skipped by wire type;
    an unsupported wire type stops parsing.

    Returns:
        A dict holding whichever of ``uuid``, ``seconds`` and ``nanos`` were
        found. If the Base64URL decoding itself fails, a dict with ``error``
        and ``raw_b64url`` is returned instead.
    """
    try:
        payload = _b64url_decode_padded(b64url)
    except Exception as e:
        return {"error": f"base64url decode failed: {e}", "raw_b64url": b64url}

    uuid_text: Optional[str] = None
    seconds: Optional[int] = None
    nanos: Optional[int] = None

    pos = 0
    total = len(payload)
    while pos < total:
        tag, pos = _read_varint(payload, pos)
        field_no, wire_type = tag >> 3, tag & 0x07

        if wire_type == 2:  # length-delimited
            length, pos = _read_varint(payload, pos)
            chunk = payload[pos : pos + length]
            pos += length
            if field_no == 1:  # uuid string
                try:
                    uuid_text = chunk.decode("utf-8")
                except Exception:
                    uuid_text = None
            elif field_no == 3:  # google.protobuf.Timestamp
                seconds, nanos = _decode_timestamp(chunk)
        elif wire_type == 0:  # varint -> not expected, skip
            _, pos = _read_varint(payload, pos)
        elif wire_type == 1:  # fixed 64-bit
            pos += 8
        elif wire_type == 5:  # fixed 32-bit
            pos += 4
        else:
            break

    found = (("uuid", uuid_text), ("seconds", seconds), ("nanos", nanos))
    return {key: val for key, val in found if val is not None}


def encode_server_message_data(
    uuid: Optional[str] = None,
    seconds: Optional[int] = None,
    nanos: Optional[int] = None,
) -> str:
    """Encode uuid/seconds/nanos into a Base64URL string (no padding).

    A non-empty ``uuid`` is written as length-delimited field 1. When
    ``seconds`` or ``nanos`` is given, a Timestamp is written as
    length-delimited field 3. With no arguments the result is an empty string.
    """
    chunks = []

    if uuid:
        uuid_bytes = uuid.encode("utf-8")
        # field 1, length-delimited
        chunks += [_make_key(1, 2), _write_varint(len(uuid_bytes)), uuid_bytes]

    if not (seconds is None and nanos is None):
        ts_bytes = _encode_timestamp(seconds, nanos)
        # field 3, length-delimited
        chunks += [_make_key(3, 2), _write_varint(len(ts_bytes)), ts_bytes]

    return _b64url_encode_nopad(b"".join(chunks))


async def startup_tasks():
    """Run the startup checks and log the server banner.

    Initializes the protobuf runtime (a failure is logged and re-raised),
    checks the JWT token (requesting an anonymous access token when none is
    found; every failure in this step is only logged), and finally logs the
    list of endpoints and test commands.
    """
    banner = "=" * 60
    divider = "-" * 40

    logger.info(banner)
    logger.info("Warp Protobuf编解码服务器启动")
    logger.info(banner)

    # Protobuf runtime check: fatal when it fails.
    try:
        from warp2protobuf.core.protobuf import ensure_proto_runtime

        ensure_proto_runtime()
        logger.info("✅ Protobuf运行时初始化成功")
    except Exception as e:
        logger.error(f"❌ Protobuf运行时初始化失败: {e}")
        raise

    # JWT token check: problems are reported as warnings only.
    try:
        from warp2protobuf.core.auth import get_jwt_token, is_token_expired

        token = get_jwt_token()
        if not token:
            logger.warning("⚠️ 未找到JWT token,尝试申请匿名访问token用于额度初始化…")
            try:
                new_token = await acquire_anonymous_access_token()
                if new_token:
                    logger.info("✅ 匿名访问token申请成功")
                else:
                    logger.warning("⚠️ 匿名访问token申请失败")
            except Exception as e2:
                logger.warning(f"⚠️ 匿名访问token申请异常: {e2}")
        elif is_token_expired(token):
            logger.warning("⚠️ JWT token无效或已过期,建议运行: uv run refresh_jwt.py")
        else:
            logger.info("✅ JWT token有效")
    except Exception as e:
        logger.warning(f"⚠️ JWT检查失败: {e}")

    # For the OpenAI compatibility layer, run src/openai_compat_server.py separately.

    # Log the endpoint list followed by the test commands.
    endpoint_lines = (
        "  GET  /                   - 服务信息",
        "  GET  /healthz            - 健康检查",
        "  GET  /gui                - Web GUI界面",
        "  POST /api/encode         - JSON -> Protobuf编码",
        "  POST /api/decode         - Protobuf -> JSON解码",
        "  POST /api/stream-decode  - 流式protobuf解码",
        "  POST /api/warp/send      - JSON -> Protobuf -> Warp API转发",
        "  POST /api/warp/send_stream - JSON -> Protobuf -> Warp API转发(返回解析事件)",
        "  POST /api/warp/send_stream_sse - JSON -> Protobuf -> Warp API转发(实时SSE,事件已解析)",
        "  POST /api/warp/graphql/* - GraphQL请求转发到Warp API(带鉴权)",
        "  GET  /api/schemas        - Protobuf schema信息",
        "  GET  /api/auth/status    - JWT认证状态",
        "  POST /api/auth/refresh   - 刷新JWT token",
        "  GET  /api/auth/user_id   - 获取当前用户ID",
        "  GET  /api/packets/history - 数据包历史记录",
        "  WS   /ws                 - WebSocket实时监控",
    )
    test_command_lines = (
        "  uv run main.py --test basic    - 运行基础测试",
        "  uv run main.py --list          - 查看所有测试场景",
    )

    logger.info(divider)
    logger.info("可用的API端点:")
    for line in endpoint_lines:
        logger.info(line)
    logger.info(divider)
    logger.info("测试命令:")
    for line in test_command_lines:
        logger.info(line)
    logger.info(banner)


def main():
    """Entry point: build the application and serve it with uvicorn.

    The server listens on 0.0.0.0:8000. A keyboard interrupt is logged as a
    user-requested stop; any other exception is logged and re-raised.
    """
    server_options = dict(host="0.0.0.0", port=8000, log_level="info", access_log=True)

    # Build the application first; errors here propagate unlogged.
    application = create_app()

    # Run the server.
    try:
        uvicorn.run(application, **server_options)
    except KeyboardInterrupt:
        logger.info("服务器被用户停止")
    except Exception as exc:
        logger.error(f"服务器启动失败: {exc}")
        raise


# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()