hins111's picture
Upload 9 files
469e046 verified
"""
FastAPI应用初始化和生命周期管理
"""
import asyncio
import multiprocessing
import os
import sys
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from typing import Callable, Awaitable
from playwright.async_api import Browser as AsyncBrowser, Playwright as AsyncPlaywright
# --- 配置模块导入 ---
from config import *
# --- models模块导入 ---
from models import WebSocketConnectionManager
# --- logging_utils模块导入 ---
from logging_utils import setup_server_logging, restore_original_streams
# --- browser_utils模块导入 ---
from browser_utils import (
_initialize_page_logic,
_close_page_logic,
load_excluded_models,
_handle_initial_model_state_and_storage
)
import stream
from asyncio import Queue, Lock
from . import auth_utils
# 全局状态变量(这些将在server.py中被引用)
playwright_manager: Optional[AsyncPlaywright] = None
browser_instance: Optional[AsyncBrowser] = None
page_instance = None
is_playwright_ready = False
is_browser_connected = False
is_page_ready = False
is_initializing = False
global_model_list_raw_json = None
parsed_model_list = []
model_list_fetch_event = None
current_ai_studio_model_id = None
model_switching_lock = None
excluded_model_ids = set()
request_queue = None
processing_lock = None
worker_task = None
page_params_cache = {}
params_cache_lock = None
log_ws_manager = None
STREAM_QUEUE = None
STREAM_PROCESS = None
# --- Lifespan Context Manager ---
def _setup_logging():
import server
log_level_env = os.environ.get('SERVER_LOG_LEVEL', 'INFO')
redirect_print_env = os.environ.get('SERVER_REDIRECT_PRINT', 'false')
server.log_ws_manager = WebSocketConnectionManager()
return setup_server_logging(
logger_instance=server.logger,
log_ws_manager=server.log_ws_manager,
log_level_name=log_level_env,
redirect_print_str=redirect_print_env
)
def _initialize_globals():
import server
server.request_queue = Queue()
server.processing_lock = Lock()
server.model_switching_lock = Lock()
server.params_cache_lock = Lock()
auth_utils.initialize_keys()
server.logger.info("API keys and global locks initialized.")
def _initialize_proxy_settings():
import server
STREAM_PORT = os.environ.get('STREAM_PORT')
if STREAM_PORT == '0':
PROXY_SERVER_ENV = os.environ.get('HTTPS_PROXY') or os.environ.get('HTTP_PROXY')
else:
PROXY_SERVER_ENV = f"http://127.0.0.1:{STREAM_PORT or 3120}/"
if PROXY_SERVER_ENV:
server.PLAYWRIGHT_PROXY_SETTINGS = {'server': PROXY_SERVER_ENV}
if NO_PROXY_ENV:
server.PLAYWRIGHT_PROXY_SETTINGS['bypass'] = NO_PROXY_ENV.replace(',', ';')
server.logger.info(f"Playwright proxy settings configured: {server.PLAYWRIGHT_PROXY_SETTINGS}")
else:
server.logger.info("No proxy configured for Playwright.")
async def _start_stream_proxy():
import server
STREAM_PORT = os.environ.get('STREAM_PORT')
if STREAM_PORT != '0':
port = int(STREAM_PORT or 3120)
STREAM_PROXY_SERVER_ENV = os.environ.get('UNIFIED_PROXY_CONFIG') or os.environ.get('HTTPS_PROXY') or os.environ.get('HTTP_PROXY')
server.logger.info(f"Starting STREAM proxy on port {port} with upstream proxy: {STREAM_PROXY_SERVER_ENV}")
server.STREAM_QUEUE = multiprocessing.Queue()
server.STREAM_PROCESS = multiprocessing.Process(target=stream.start, args=(server.STREAM_QUEUE, port, STREAM_PROXY_SERVER_ENV))
server.STREAM_PROCESS.start()
server.logger.info("STREAM proxy process started.")
async def _initialize_browser_and_page():
import server
from playwright.async_api import async_playwright
server.logger.info("Starting Playwright...")
server.playwright_manager = await async_playwright().start()
server.is_playwright_ready = True
server.logger.info("Playwright started.")
ws_endpoint = os.environ.get('CAMOUFOX_WS_ENDPOINT')
launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
if not ws_endpoint and launch_mode != "direct_debug_no_browser":
raise ValueError("CAMOUFOX_WS_ENDPOINT environment variable is missing.")
if ws_endpoint:
server.logger.info(f"Connecting to browser at: {ws_endpoint}")
server.browser_instance = await server.playwright_manager.firefox.connect(ws_endpoint, timeout=30000)
server.is_browser_connected = True
server.logger.info(f"Connected to browser: {server.browser_instance.version}")
server.page_instance, server.is_page_ready = await _initialize_page_logic(server.browser_instance)
if server.is_page_ready:
await _handle_initial_model_state_and_storage(server.page_instance)
server.logger.info("Page initialized successfully.")
else:
server.logger.error("Page initialization failed.")
if not server.model_list_fetch_event.is_set():
server.model_list_fetch_event.set()
async def _shutdown_resources():
import server
logger = server.logger
logger.info("Shutting down resources...")
if server.STREAM_PROCESS:
server.STREAM_PROCESS.terminate()
logger.info("STREAM proxy terminated.")
if server.worker_task and not server.worker_task.done():
server.worker_task.cancel()
try:
await asyncio.wait_for(server.worker_task, timeout=5.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
logger.info("Worker task stopped.")
if server.page_instance:
await _close_page_logic()
if server.browser_instance and server.browser_instance.is_connected():
await server.browser_instance.close()
logger.info("Browser connection closed.")
if server.playwright_manager:
await server.playwright_manager.stop()
logger.info("Playwright stopped.")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""FastAPI application life cycle management"""
import server
from server import queue_worker
original_streams = sys.stdout, sys.stderr
initial_stdout, initial_stderr = _setup_logging()
logger = server.logger
_initialize_globals()
_initialize_proxy_settings()
load_excluded_models(EXCLUDED_MODELS_FILENAME)
server.is_initializing = True
logger.info("Starting AI Studio Proxy Server...")
try:
await _start_stream_proxy()
await _initialize_browser_and_page()
launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
if server.is_page_ready or launch_mode == "direct_debug_no_browser":
server.worker_task = asyncio.create_task(queue_worker())
logger.info("Request processing worker started.")
else:
raise RuntimeError("Failed to initialize browser/page, worker not started.")
logger.info("Server startup complete.")
server.is_initializing = False
yield
except Exception as e:
logger.critical(f"Application startup failed: {e}", exc_info=True)
await _shutdown_resources()
raise RuntimeError(f"Application startup failed: {e}") from e
finally:
logger.info("Shutting down server...")
await _shutdown_resources()
restore_original_streams(initial_stdout, initial_stderr)
restore_original_streams(*original_streams)
logger.info("Server shutdown complete.")
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp):
super().__init__(app)
self.excluded_paths = [
"/v1/models",
"/health",
"/docs",
"/openapi.json",
# FastAPI 自动生成的其他文档路径
"/redoc",
"/favicon.ico"
]
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable]):
if not auth_utils.API_KEYS: # 如果 API_KEYS 为空,则不进行验证
return await call_next(request)
# 检查是否是需要保护的路径
if not request.url.path.startswith("/v1/"):
return await call_next(request)
# 检查是否是排除的路径
for excluded_path in self.excluded_paths:
if request.url.path == excluded_path or request.url.path.startswith(excluded_path + "/"):
return await call_next(request)
# 支持多种认证头格式以兼容OpenAI标准
api_key = None
# 1. 优先检查标准的 Authorization: Bearer <token> 头
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
api_key = auth_header[7:] # 移除 "Bearer " 前缀
# 2. 回退到自定义的 X-API-Key 头(向后兼容)
if not api_key:
api_key = request.headers.get("X-API-Key")
if not api_key or not auth_utils.verify_api_key(api_key):
return JSONResponse(
status_code=401,
content={
"error": {
"message": "Invalid or missing API key. Please provide a valid API key using 'Authorization: Bearer <your_key>' or 'X-API-Key: <your_key>' header.",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key"
}
}
)
return await call_next(request)
def create_app() -> FastAPI:
"""创建FastAPI应用实例"""
app = FastAPI(
title="AI Studio Proxy Server (集成模式)",
description="通过 Playwright与 AI Studio 交互的代理服务器。",
version="0.6.0-integrated",
lifespan=lifespan
)
# 添加中间件
app.add_middleware(APIKeyAuthMiddleware)
# 注册路由
from .routes import (
read_index, get_css, get_js, get_api_info,
health_check, list_models, chat_completions,
cancel_request, get_queue_status, websocket_log_endpoint,
get_api_keys, add_api_key, test_api_key, delete_api_key
)
from fastapi.responses import FileResponse
app.get("/", response_class=FileResponse)(read_index)
app.get("/webui.css")(get_css)
app.get("/webui.js")(get_js)
app.get("/api/info")(get_api_info)
app.get("/health")(health_check)
app.get("/v1/models")(list_models)
app.post("/v1/chat/completions")(chat_completions)
app.post("/v1/cancel/{req_id}")(cancel_request)
app.get("/v1/queue")(get_queue_status)
app.websocket("/ws/logs")(websocket_log_endpoint)
# API密钥管理端点
app.get("/api/keys")(get_api_keys)
app.post("/api/keys")(add_api_key)
app.post("/api/keys/test")(test_api_key)
app.delete("/api/keys")(delete_api_key)
return app