""" FastAPI应用初始化和生命周期管理 """ import asyncio import multiprocessing import os import sys from contextlib import asynccontextmanager from typing import Optional from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp from typing import Callable, Awaitable from playwright.async_api import Browser as AsyncBrowser, Playwright as AsyncPlaywright # --- 配置模块导入 --- from config import * # --- models模块导入 --- from models import WebSocketConnectionManager # --- logging_utils模块导入 --- from logging_utils import setup_server_logging, restore_original_streams # --- browser_utils模块导入 --- from browser_utils import ( _initialize_page_logic, _close_page_logic, load_excluded_models, _handle_initial_model_state_and_storage ) import stream from asyncio import Queue, Lock from . import auth_utils # 全局状态变量(这些将在server.py中被引用) playwright_manager: Optional[AsyncPlaywright] = None browser_instance: Optional[AsyncBrowser] = None page_instance = None is_playwright_ready = False is_browser_connected = False is_page_ready = False is_initializing = False global_model_list_raw_json = None parsed_model_list = [] model_list_fetch_event = None current_ai_studio_model_id = None model_switching_lock = None excluded_model_ids = set() request_queue = None processing_lock = None worker_task = None page_params_cache = {} params_cache_lock = None log_ws_manager = None STREAM_QUEUE = None STREAM_PROCESS = None # --- Lifespan Context Manager --- def _setup_logging(): import server log_level_env = os.environ.get('SERVER_LOG_LEVEL', 'INFO') redirect_print_env = os.environ.get('SERVER_REDIRECT_PRINT', 'false') server.log_ws_manager = WebSocketConnectionManager() return setup_server_logging( logger_instance=server.logger, log_ws_manager=server.log_ws_manager, log_level_name=log_level_env, redirect_print_str=redirect_print_env ) def _initialize_globals(): import server server.request_queue = Queue() server.processing_lock = Lock() server.model_switching_lock = Lock() server.params_cache_lock = Lock() auth_utils.initialize_keys() server.logger.info("API keys and global locks initialized.") def _initialize_proxy_settings(): import server STREAM_PORT = os.environ.get('STREAM_PORT') if STREAM_PORT == '0': PROXY_SERVER_ENV = os.environ.get('HTTPS_PROXY') or os.environ.get('HTTP_PROXY') else: PROXY_SERVER_ENV = f"http://127.0.0.1:{STREAM_PORT or 3120}/" if PROXY_SERVER_ENV: server.PLAYWRIGHT_PROXY_SETTINGS = {'server': PROXY_SERVER_ENV} if NO_PROXY_ENV: server.PLAYWRIGHT_PROXY_SETTINGS['bypass'] = NO_PROXY_ENV.replace(',', ';') server.logger.info(f"Playwright proxy settings configured: {server.PLAYWRIGHT_PROXY_SETTINGS}") else: server.logger.info("No proxy configured for Playwright.") async def _start_stream_proxy(): import server STREAM_PORT = os.environ.get('STREAM_PORT') if STREAM_PORT != '0': port = int(STREAM_PORT or 3120) STREAM_PROXY_SERVER_ENV = os.environ.get('UNIFIED_PROXY_CONFIG') or os.environ.get('HTTPS_PROXY') or os.environ.get('HTTP_PROXY') server.logger.info(f"Starting STREAM proxy on port {port} with upstream proxy: {STREAM_PROXY_SERVER_ENV}") server.STREAM_QUEUE = multiprocessing.Queue() server.STREAM_PROCESS = multiprocessing.Process(target=stream.start, args=(server.STREAM_QUEUE, port, STREAM_PROXY_SERVER_ENV)) server.STREAM_PROCESS.start() server.logger.info("STREAM proxy process started.") async def _initialize_browser_and_page(): import server from playwright.async_api import async_playwright server.logger.info("Starting Playwright...") server.playwright_manager = await async_playwright().start() server.is_playwright_ready = True server.logger.info("Playwright started.") ws_endpoint = os.environ.get('CAMOUFOX_WS_ENDPOINT') launch_mode = os.environ.get('LAUNCH_MODE', 'unknown') if not ws_endpoint and launch_mode != "direct_debug_no_browser": raise ValueError("CAMOUFOX_WS_ENDPOINT environment variable is missing.") if ws_endpoint: server.logger.info(f"Connecting to browser at: {ws_endpoint}") server.browser_instance = await server.playwright_manager.firefox.connect(ws_endpoint, timeout=30000) server.is_browser_connected = True server.logger.info(f"Connected to browser: {server.browser_instance.version}") server.page_instance, server.is_page_ready = await _initialize_page_logic(server.browser_instance) if server.is_page_ready: await _handle_initial_model_state_and_storage(server.page_instance) server.logger.info("Page initialized successfully.") else: server.logger.error("Page initialization failed.") if not server.model_list_fetch_event.is_set(): server.model_list_fetch_event.set() async def _shutdown_resources(): import server logger = server.logger logger.info("Shutting down resources...") if server.STREAM_PROCESS: server.STREAM_PROCESS.terminate() logger.info("STREAM proxy terminated.") if server.worker_task and not server.worker_task.done(): server.worker_task.cancel() try: await asyncio.wait_for(server.worker_task, timeout=5.0) except (asyncio.TimeoutError, asyncio.CancelledError): pass logger.info("Worker task stopped.") if server.page_instance: await _close_page_logic() if server.browser_instance and server.browser_instance.is_connected(): await server.browser_instance.close() logger.info("Browser connection closed.") if server.playwright_manager: await server.playwright_manager.stop() logger.info("Playwright stopped.") @asynccontextmanager async def lifespan(app: FastAPI): """FastAPI application life cycle management""" import server from server import queue_worker original_streams = sys.stdout, sys.stderr initial_stdout, initial_stderr = _setup_logging() logger = server.logger _initialize_globals() _initialize_proxy_settings() load_excluded_models(EXCLUDED_MODELS_FILENAME) server.is_initializing = True logger.info("Starting AI Studio Proxy Server...") try: await _start_stream_proxy() await _initialize_browser_and_page() launch_mode = os.environ.get('LAUNCH_MODE', 'unknown') if server.is_page_ready or launch_mode == "direct_debug_no_browser": server.worker_task = asyncio.create_task(queue_worker()) logger.info("Request processing worker started.") else: raise RuntimeError("Failed to initialize browser/page, worker not started.") logger.info("Server startup complete.") server.is_initializing = False yield except Exception as e: logger.critical(f"Application startup failed: {e}", exc_info=True) await _shutdown_resources() raise RuntimeError(f"Application startup failed: {e}") from e finally: logger.info("Shutting down server...") await _shutdown_resources() restore_original_streams(initial_stdout, initial_stderr) restore_original_streams(*original_streams) logger.info("Server shutdown complete.") class APIKeyAuthMiddleware(BaseHTTPMiddleware): def __init__(self, app: ASGIApp): super().__init__(app) self.excluded_paths = [ "/v1/models", "/health", "/docs", "/openapi.json", # FastAPI 自动生成的其他文档路径 "/redoc", "/favicon.ico" ] async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable]): if not auth_utils.API_KEYS: # 如果 API_KEYS 为空,则不进行验证 return await call_next(request) # 检查是否是需要保护的路径 if not request.url.path.startswith("/v1/"): return await call_next(request) # 检查是否是排除的路径 for excluded_path in self.excluded_paths: if request.url.path == excluded_path or request.url.path.startswith(excluded_path + "/"): return await call_next(request) # 支持多种认证头格式以兼容OpenAI标准 api_key = None # 1. 优先检查标准的 Authorization: Bearer 头 auth_header = request.headers.get("Authorization") if auth_header and auth_header.startswith("Bearer "): api_key = auth_header[7:] # 移除 "Bearer " 前缀 # 2. 回退到自定义的 X-API-Key 头(向后兼容) if not api_key: api_key = request.headers.get("X-API-Key") if not api_key or not auth_utils.verify_api_key(api_key): return JSONResponse( status_code=401, content={ "error": { "message": "Invalid or missing API key. Please provide a valid API key using 'Authorization: Bearer ' or 'X-API-Key: ' header.", "type": "invalid_request_error", "param": None, "code": "invalid_api_key" } } ) return await call_next(request) def create_app() -> FastAPI: """创建FastAPI应用实例""" app = FastAPI( title="AI Studio Proxy Server (集成模式)", description="通过 Playwright与 AI Studio 交互的代理服务器。", version="0.6.0-integrated", lifespan=lifespan ) # 添加中间件 app.add_middleware(APIKeyAuthMiddleware) # 注册路由 from .routes import ( read_index, get_css, get_js, get_api_info, health_check, list_models, chat_completions, cancel_request, get_queue_status, websocket_log_endpoint, get_api_keys, add_api_key, test_api_key, delete_api_key ) from fastapi.responses import FileResponse app.get("/", response_class=FileResponse)(read_index) app.get("/webui.css")(get_css) app.get("/webui.js")(get_js) app.get("/api/info")(get_api_info) app.get("/health")(health_check) app.get("/v1/models")(list_models) app.post("/v1/chat/completions")(chat_completions) app.post("/v1/cancel/{req_id}")(cancel_request) app.get("/v1/queue")(get_queue_status) app.websocket("/ws/logs")(websocket_log_endpoint) # API密钥管理端点 app.get("/api/keys")(get_api_keys) app.post("/api/keys")(add_api_key) app.post("/api/keys/test")(test_api_key) app.delete("/api/keys")(delete_api_key) return app