File size: 11,374 Bytes
469e046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
"""

FastAPI应用初始化和生命周期管理

"""

import asyncio
import multiprocessing
import os
import sys
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from typing import Callable, Awaitable
from playwright.async_api import Browser as AsyncBrowser, Playwright as AsyncPlaywright

# --- 配置模块导入 ---
from config import *

# --- models模块导入 ---
from models import WebSocketConnectionManager

# --- logging_utils模块导入 ---
from logging_utils import setup_server_logging, restore_original_streams

# --- browser_utils模块导入 ---
from browser_utils import (
    _initialize_page_logic,
    _close_page_logic,
    load_excluded_models,
    _handle_initial_model_state_and_storage
)

import stream
from asyncio import Queue, Lock
from . import auth_utils

# Global state variables (these are referenced from server.py).
# NOTE(review): the lifespan helpers in this module assign to `server.<name>`
# (e.g. `server.request_queue = Queue()`), not to these module-level names, so
# these copies keep their initial values unless something else rebinds them —
# confirm whether any caller still reads them from this module.

# --- Playwright / browser / page state ---
playwright_manager: Optional[AsyncPlaywright] = None
browser_instance: Optional[AsyncBrowser] = None
page_instance = None  # page object returned by _initialize_page_logic
is_playwright_ready = False
is_browser_connected = False
is_page_ready = False
is_initializing = False  # True while lifespan startup is in progress

# --- Model list state ---
global_model_list_raw_json = None
parsed_model_list: list = []
# presumably an asyncio.Event (startup calls .is_set()/.set() on server's copy) — TODO confirm
model_list_fetch_event = None

# --- Current model selection ---
current_ai_studio_model_id = None
model_switching_lock: Optional[Lock] = None  # server's copy is created in _initialize_globals

excluded_model_ids: set = set()

# --- Request processing ---
request_queue: Optional[Queue] = None  # server's copy is created in _initialize_globals
processing_lock: Optional[Lock] = None
worker_task: Optional[asyncio.Task] = None  # server's copy is set by lifespan via asyncio.create_task

# --- Page parameter cache ---
page_params_cache: dict = {}
params_cache_lock: Optional[Lock] = None

# WebSocket manager used for streaming logs (server's copy is set in _setup_logging).
log_ws_manager = None

# STREAM proxy child process and its multiprocessing queue (server's copies are set in _start_stream_proxy).
STREAM_QUEUE = None
STREAM_PROCESS: Optional[multiprocessing.Process] = None

# --- Lifespan Context Manager ---
def _setup_logging():
    """Install server logging and the log WebSocket manager.

    Reads SERVER_LOG_LEVEL (default 'INFO') and SERVER_REDIRECT_PRINT
    (default 'false') from the environment, stores a fresh
    WebSocketConnectionManager on the server module, and returns whatever
    setup_server_logging returns (the caller unpacks it as
    ``initial_stdout, initial_stderr``).
    """
    import server

    env = os.environ
    level_name = env.get('SERVER_LOG_LEVEL', 'INFO')
    redirect_flag = env.get('SERVER_REDIRECT_PRINT', 'false')

    ws_manager = WebSocketConnectionManager()
    server.log_ws_manager = ws_manager

    return setup_server_logging(
        logger_instance=server.logger,
        log_ws_manager=ws_manager,
        log_level_name=level_name,
        redirect_print_str=redirect_flag,
    )

def _initialize_globals():
    """Create the shared request queue and asyncio locks on the server module, then load API keys."""
    import server

    server.request_queue = Queue()
    # One fresh asyncio.Lock per shared resource, created in a fixed order.
    for lock_attr in ('processing_lock', 'model_switching_lock', 'params_cache_lock'):
        setattr(server, lock_attr, Lock())

    auth_utils.initialize_keys()
    server.logger.info("API keys and global locks initialized.")

def _initialize_proxy_settings():
    """Decide which proxy Playwright should use and store it on the server module.

    When STREAM_PORT is exactly '0', the proxy comes from HTTPS_PROXY or
    HTTP_PROXY (if either is set). Otherwise Playwright is pointed at the
    local STREAM proxy on 127.0.0.1 (port STREAM_PORT, default 3120).
    NO_PROXY_ENV (from config), if set, becomes the ';'-separated bypass list.
    """
    import server

    env = os.environ
    stream_port = env.get('STREAM_PORT')
    if stream_port != '0':
        proxy_url = f"http://127.0.0.1:{stream_port or 3120}/"
    else:
        proxy_url = env.get('HTTPS_PROXY') or env.get('HTTP_PROXY')

    if not proxy_url:
        server.logger.info("No proxy configured for Playwright.")
        return

    settings = {'server': proxy_url}
    if NO_PROXY_ENV:
        # Playwright expects ';'-separated bypass entries; the env var uses ','.
        settings['bypass'] = NO_PROXY_ENV.replace(',', ';')
    server.PLAYWRIGHT_PROXY_SETTINGS = settings
    server.logger.info(f"Playwright proxy settings configured: {server.PLAYWRIGHT_PROXY_SETTINGS}")

async def _start_stream_proxy():
    """Launch the STREAM proxy child process unless STREAM_PORT is exactly '0'.

    The port defaults to 3120. The upstream proxy is taken from
    UNIFIED_PROXY_CONFIG, then HTTPS_PROXY, then HTTP_PROXY. The queue and
    process handles are stored on the server module.
    """
    import server

    env = os.environ
    raw_port = env.get('STREAM_PORT')
    if raw_port == '0':
        return

    port = int(raw_port or 3120)
    upstream = (
        env.get('UNIFIED_PROXY_CONFIG')
        or env.get('HTTPS_PROXY')
        or env.get('HTTP_PROXY')
    )
    server.logger.info(f"Starting STREAM proxy on port {port} with upstream proxy: {upstream}")

    server.STREAM_QUEUE = multiprocessing.Queue()
    server.STREAM_PROCESS = multiprocessing.Process(
        target=stream.start,
        args=(server.STREAM_QUEUE, port, upstream),
    )
    server.STREAM_PROCESS.start()
    server.logger.info("STREAM proxy process started.")

async def _initialize_browser_and_page():
    """Start Playwright, connect to the browser and prepare the working page.

    The browser is reached over the CAMOUFOX_WS_ENDPOINT WebSocket. A missing
    endpoint raises ValueError unless LAUNCH_MODE is "direct_debug_no_browser".
    In every non-raising path, server.model_list_fetch_event is set at the end.
    """
    import server
    from playwright.async_api import async_playwright

    server.logger.info("Starting Playwright...")
    server.playwright_manager = await async_playwright().start()
    server.is_playwright_ready = True
    server.logger.info("Playwright started.")

    env = os.environ
    ws_endpoint = env.get('CAMOUFOX_WS_ENDPOINT')
    launch_mode = env.get('LAUNCH_MODE', 'unknown')

    if not ws_endpoint:
        # Without an endpoint, only the browser-less debug mode may continue.
        if launch_mode != "direct_debug_no_browser":
            raise ValueError("CAMOUFOX_WS_ENDPOINT environment variable is missing.")
    else:
        server.logger.info(f"Connecting to browser at: {ws_endpoint}")
        server.browser_instance = await server.playwright_manager.firefox.connect(ws_endpoint, timeout=30000)
        server.is_browser_connected = True
        server.logger.info(f"Connected to browser: {server.browser_instance.version}")

        page, page_ready = await _initialize_page_logic(server.browser_instance)
        server.page_instance, server.is_page_ready = page, page_ready
        if not page_ready:
            server.logger.error("Page initialization failed.")
        else:
            await _handle_initial_model_state_and_storage(page)
            server.logger.info("Page initialized successfully.")

    fetch_event = server.model_list_fetch_event
    if not fetch_event.is_set():
        fetch_event.set()

async def _shutdown_resources():
    """Release everything startup acquired: STREAM proxy, worker, page, browser, Playwright.

    Every step is best-effort: a failure is logged and the remaining steps
    still run. The STREAM process, browser and Playwright references are
    cleared on the server module after release, so calling this again is safe
    and does not double-stop anything.
    """
    import server
    logger = server.logger
    logger.info("Shutting down resources...")

    stream_process = server.STREAM_PROCESS
    if stream_process:
        try:
            stream_process.terminate()
            # Reap the child so it does not linger as a zombie. join() blocks,
            # so run it in the default executor instead of on the event loop.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, stream_process.join, 5.0)
            if stream_process.is_alive():
                # Still alive after terminate(): escalate.
                stream_process.kill()
                await loop.run_in_executor(None, stream_process.join, 1.0)
            logger.info("STREAM proxy terminated.")
        except Exception as e:
            logger.error(f"Error while terminating STREAM proxy: {e}", exc_info=True)
        finally:
            server.STREAM_PROCESS = None

    worker = server.worker_task
    if worker and not worker.done():
        worker.cancel()
        try:
            await asyncio.wait_for(worker, timeout=5.0)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            # Expected outcomes of cancelling the worker.
            pass
        except Exception as e:
            logger.error(f"Worker task raised during shutdown: {e}", exc_info=True)
        logger.info("Worker task stopped.")

    if server.page_instance:
        try:
            await _close_page_logic()
        except Exception as e:
            logger.error(f"Error while closing page: {e}", exc_info=True)

    if server.browser_instance:
        try:
            if server.browser_instance.is_connected():
                await server.browser_instance.close()
                logger.info("Browser connection closed.")
        except Exception as e:
            logger.error(f"Error while closing browser connection: {e}", exc_info=True)
        finally:
            server.browser_instance = None
            server.is_browser_connected = False

    if server.playwright_manager:
        try:
            await server.playwright_manager.stop()
            logger.info("Playwright stopped.")
        except Exception as e:
            logger.error(f"Error while stopping Playwright: {e}", exc_info=True)
        finally:
            server.playwright_manager = None
            server.is_playwright_ready = False

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI application life cycle management.

    Startup performs these steps in order:
    - install logging;
    - initialize the shared queue/locks, proxy settings and the excluded-model list;
    - start the optional STREAM proxy process;
    - connect Playwright to the browser;
    - start the request worker.

    Any startup failure is logged and re-raised as RuntimeError.

    Shutdown runs exactly once, whether the app exits normally or startup
    failed. It releases all resources via _shutdown_resources and always
    restores stdout/stderr, even if resource shutdown itself fails.
    """
    import server
    from server import queue_worker

    original_streams = sys.stdout, sys.stderr
    initial_stdout, initial_stderr = _setup_logging()
    logger = server.logger

    try:
        # Only startup errors are handled here. Exceptions thrown in at `yield`
        # come from the running app and must not be reported as startup
        # failures.
        try:
            _initialize_globals()
            _initialize_proxy_settings()
            load_excluded_models(EXCLUDED_MODELS_FILENAME)

            server.is_initializing = True
            logger.info("Starting AI Studio Proxy Server...")

            await _start_stream_proxy()
            await _initialize_browser_and_page()

            launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
            if server.is_page_ready or launch_mode == "direct_debug_no_browser":
                server.worker_task = asyncio.create_task(queue_worker())
                logger.info("Request processing worker started.")
            else:
                raise RuntimeError("Failed to initialize browser/page, worker not started.")

            logger.info("Server startup complete.")
            server.is_initializing = False
        except Exception as e:
            # Cleanup is left to the outer `finally`. Running it here as well
            # would shut resources down twice (e.g. stop Playwright twice).
            logger.critical(f"Application startup failed: {e}", exc_info=True)
            raise RuntimeError(f"Application startup failed: {e}") from e

        yield
    finally:
        logger.info("Shutting down server...")
        try:
            await _shutdown_resources()
        finally:
            # Always undo the stream redirection, even if shutdown raised.
            restore_original_streams(initial_stdout, initial_stderr)
            restore_original_streams(*original_streams)
            logger.info("Server shutdown complete.")


class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    """Enforce API-key authentication on ``/v1/`` routes.

    Authentication is skipped in three cases:
    - ``auth_utils.API_KEYS`` is empty;
    - the path lies outside ``/v1/``;
    - the path matches an entry in ``excluded_paths`` or is a sub-path of one.

    Otherwise the request must carry a key that ``auth_utils.verify_api_key``
    accepts. If it does not, an OpenAI-style 401 JSON error is returned.
    """

    def __init__(self, app: ASGIApp):
        super().__init__(app)
        # Paths exempt from authentication (exact match or sub-path).
        # NOTE: only "/v1/models" can take effect. Every other entry lies
        # outside /v1/ and is already let through by dispatch().
        self.excluded_paths = [
            "/v1/models",
            "/health",
            "/docs",
            "/openapi.json",
            # Other documentation paths auto-generated by FastAPI
            "/redoc",
            "/favicon.ico"
        ]

    def _is_excluded(self, path: str) -> bool:
        """Return True if *path* equals, or is a sub-path of, an excluded path."""
        return any(
            path == excluded or path.startswith(excluded + "/")
            for excluded in self.excluded_paths
        )

    @staticmethod
    def _extract_api_key(request: Request) -> Optional[str]:
        """Return the API key carried by *request*, or None if there is none.

        The standard ``Authorization: Bearer <key>`` header is checked first.
        The auth scheme is matched case-insensitively, as RFC 7235 requires,
        and whitespace around the token is trimmed. For backward
        compatibility, the custom ``X-API-Key`` header is used as a fallback.
        """
        auth_header = request.headers.get("Authorization")
        if auth_header:
            scheme, _, token = auth_header.strip().partition(" ")
            token = token.strip()
            if scheme.lower() == "bearer" and token:
                return token
        return request.headers.get("X-API-Key") or None

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable]):
        # No keys configured: authentication is disabled.
        if not auth_utils.API_KEYS:
            return await call_next(request)

        # Only /v1/ routes are protected, minus the excluded paths.
        path = request.url.path
        if not path.startswith("/v1/") or self._is_excluded(path):
            return await call_next(request)

        api_key = self._extract_api_key(request)
        if not api_key or not auth_utils.verify_api_key(api_key):
            return JSONResponse(
                status_code=401,
                content={
                    "error": {
                        "message": "Invalid or missing API key. Please provide a valid API key using 'Authorization: Bearer <your_key>' or 'X-API-Key: <your_key>' header.",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key"
                    }
                }
            )
        return await call_next(request)

def create_app() -> FastAPI:
    """Build and return the FastAPI application instance.

    Attaches the lifespan handler and installs APIKeyAuthMiddleware. It then
    registers the following routes in a fixed order:
    - Web UI;
    - OpenAI-compatible routes;
    - queue, cancel and log-WebSocket routes;
    - API-key management routes.
    """
    app = FastAPI(
        title="AI Studio Proxy Server (集成模式)",
        description="通过 Playwright与 AI Studio 交互的代理服务器。",
        version="0.6.0-integrated",
        lifespan=lifespan,
    )

    # Middleware
    app.add_middleware(APIKeyAuthMiddleware)

    # Route handlers are imported at app-construction time (presumably to
    # avoid a circular import with the routes package).
    from .routes import (
        read_index, get_css, get_js, get_api_info,
        health_check, list_models, chat_completions,
        cancel_request, get_queue_status, websocket_log_endpoint,
        get_api_keys, add_api_key, test_api_key, delete_api_key,
    )
    from fastapi.responses import FileResponse

    # The index page is the only route with an explicit response class.
    app.get("/", response_class=FileResponse)(read_index)

    # Each entry: (FastAPI route-decorator name, path, handler).
    # Registration order is preserved.
    route_table = (
        ("get", "/webui.css", get_css),
        ("get", "/webui.js", get_js),
        ("get", "/api/info", get_api_info),
        ("get", "/health", health_check),
        ("get", "/v1/models", list_models),
        ("post", "/v1/chat/completions", chat_completions),
        ("post", "/v1/cancel/{req_id}", cancel_request),
        ("get", "/v1/queue", get_queue_status),
        ("websocket", "/ws/logs", websocket_log_endpoint),
        # API-key management endpoints
        ("get", "/api/keys", get_api_keys),
        ("post", "/api/keys", add_api_key),
        ("post", "/api/keys/test", test_api_key),
        ("delete", "/api/keys", delete_api_key),
    )
    for decorator_name, path, handler in route_table:
        route_decorator = getattr(app, decorator_name)
        route_decorator(path)(handler)

    return app