Spaces:
Paused
Paused
| from fastapi import FastAPI, HTTPException, Request, Depends, status | |
| from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse | |
| from .models import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ModelList | |
| from .gemini import GeminiClient, ResponseWrapper | |
| from .utils import handle_gemini_error, protect_from_abuse, APIKeyManager, test_api_key, format_log_message | |
| import os | |
| import json | |
| import asyncio | |
| from typing import Literal | |
| import random | |
| import requests | |
| from datetime import datetime, timedelta | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| import sys | |
| import logging | |
# Turn off uvicorn's own server and access loggers.
for _disabled_logger_name in ("uvicorn", "uvicorn.access"):
    logging.getLogger(_disabled_logger_name).disabled = True
del _disabled_logger_name

# Application logger used throughout this module, accepting DEBUG and above.
# NOTE(review): no handler is attached here; presumably one is configured
# elsewhere (e.g. in .utils) — otherwise only WARNING+ reaches stderr.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)
def translate_error(message: str) -> str:
    """Translate a few well-known English error phrases into Chinese text.

    Matching is a case-insensitive substring test, checked in a fixed order;
    the first phrase found wins. A message containing none of the phrases is
    returned unchanged.
    """
    lowered = message.lower()
    translations = (
        ("quota exceeded", "API 密钥配额已用尽"),
        ("invalid argument", "无效参数"),
        ("internal server error", "服务器内部错误"),
        ("service unavailable", "服务不可用"),
    )
    for needle, translated in translations:
        if needle in lowered:
            return translated
    return message
def handle_exception(exc_type, exc_value, exc_traceback):
    """Replacement for ``sys.excepthook`` that logs uncaught exceptions.

    ``KeyboardInterrupt`` is handed to the interpreter's original hook so Ctrl+C
    still behaves normally. Any other exception is translated with
    ``translate_error`` and logged at ERROR level, with its traceback attached.
    """
    if issubclass(exc_type, KeyboardInterrupt):
        # Use the interpreter's original hook: calling sys.excepthook here would
        # recurse forever once this function is installed as sys.excepthook.
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    error_message = translate_error(str(exc_value))
    log_msg = format_log_message('ERROR', f"未捕获的异常: {error_message}", extra={'status_code': 500, 'error_message': error_message})
    # Keep the traceback, which the default hook would otherwise have printed.
    logger.error(log_msg, exc_info=(exc_type, exc_value, exc_traceback))
# Install handle_exception as the process-wide hook for uncaught exceptions.
# Installed before the app is created, so a failure in FastAPI() is logged too.
# NOTE(review): sys.excepthook only sees exceptions that escape to the top level;
# errors inside request handlers are presumably handled by FastAPI — confirm.
sys.excepthook = handle_exception
# The FastAPI application instance.
app = FastAPI()
def int_from_env(name, default):
    """Return environment variable *name* as an int, or *default* when unset or blank.

    A plain ``int(os.environ.get(name, "30"))`` raises ValueError at import time
    when the variable exists but is empty, which is easy to do with deployment secrets.
    """
    raw = os.environ.get(name, "").strip()
    return int(raw) if raw else default


# Password expected in "Authorization: Bearer <PASSWORD>"; an empty value disables the check.
# NOTE(review): "123" is a publicly visible fallback — set PASSWORD in every deployment.
PASSWORD = os.environ.get("PASSWORD", "123")
# Per-IP rate limits passed to protect_from_abuse().
MAX_REQUESTS_PER_MINUTE = int_from_env("MAX_REQUESTS_PER_MINUTE", 30)
MAX_REQUESTS_PER_DAY_PER_IP = int_from_env("MAX_REQUESTS_PER_DAY_PER_IP", 600)
# Retry delay bounds (units presumably seconds). NOTE(review): not referenced in the visible code.
RETRY_DELAY = 1
MAX_RETRY_DELAY = 16
# Harm categories covered by both safety profiles below, in request order.
_HARM_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_CIVIC_INTEGRITY",
)

# Default profile: every category at threshold "BLOCK_NONE".
safety_settings = [{"category": category, "threshold": "BLOCK_NONE"} for category in _HARM_CATEGORIES]

# Alternate profile with threshold "OFF"; process_request picks it for model
# names containing 'gemini-2.0-flash-exp'.
safety_settings_g2 = [{"category": category, "threshold": "OFF"} for category in _HARM_CATEGORIES]
# Module-wide API key manager. Per the original note, its key stack is
# initialised in APIKeyManager.__init__.
key_manager = APIKeyManager()
# Module-level "current" key. It is reassigned by switch_api_key() and
# process_request() and shared by all in-flight requests.
current_api_key = key_manager.get_available_key()
def switch_api_key():
    """Point the module-level ``current_api_key`` at the next key from ``key_manager``.

    If ``key_manager.get_available_key()`` returns a falsy value, the current key
    is left as it was and an error is logged instead.
    """
    global current_api_key
    next_key = key_manager.get_available_key()  # key-stack handling lives in get_available_key
    if not next_key:
        logger.error(format_log_message('ERROR', "API key 替换失败,所有API key都已尝试,请重新配置或稍后重试", extra={'key': 'N/A', 'request_type': 'switch_key', 'status_code': 'N/A'}))
        return
    current_api_key = next_key
    key_prefix = current_api_key[:8]
    logger.info(format_log_message('INFO', f"API key 替换为 → {key_prefix}...", extra={'key': key_prefix, 'request_type': 'switch_key'}))
async def check_keys():
    """Run ``test_api_key`` on every configured key and return the ones that pass.

    Keys are tested one at a time, in ``key_manager.api_keys`` order, and each
    verdict is logged. An error is logged when no key passes. The (possibly
    empty) list of passing keys is returned either way.
    """
    passing = []
    for candidate in key_manager.api_keys:
        ok = await test_api_key(candidate)
        verdict = "有效" if ok else "无效"
        logger.info(format_log_message('INFO', f"API Key {candidate[:10]}... {verdict}."))
        if ok:
            passing.append(candidate)
    if len(passing) == 0:
        logger.error(format_log_message('ERROR', "没有可用的 API 密钥!", extra={'key': 'N/A', 'request_type': 'startup', 'status_code': 'N/A'}))
    return passing
async def startup_event():
    """Validate the configured API keys and load the model list at startup.

    Only keys that pass ``check_keys`` are kept in ``key_manager``. The model
    list is then fetched with the first validated key and stored, with the
    ``models/`` prefix removed, in ``GeminiClient.AVAILABLE_MODELS``. When no key
    validates, the key list and model list are left untouched; check_keys has
    already logged the error.
    """
    logger.info(format_log_message('INFO', "Starting Gemini API proxy..."))
    available_keys = await check_keys()
    if not available_keys:
        # key_manager.api_keys still holds the keys that just failed validation;
        # do not use them to list models.
        return
    key_manager.api_keys = available_keys
    key_manager._reset_key_stack()  # also build a fresh random key stack at startup
    key_manager.show_all_keys()
    key_count = len(key_manager.api_keys)
    logger.info(format_log_message('INFO', f"可用 API 密钥数量:{key_count}"))
    # process_request makes one attempt per key, so the retry budget equals the key count.
    logger.info(format_log_message('INFO', f"最大重试次数设置为:{key_count}"))
    all_models = await GeminiClient.list_available_models(key_manager.api_keys[0])
    GeminiClient.AVAILABLE_MODELS = [model.replace("models/", "") for model in all_models]
    logger.info(format_log_message('INFO', "Available models loaded."))
def list_models():
    """Return every name in ``GeminiClient.AVAILABLE_MODELS`` wrapped in a ``ModelList``.

    Each entry carries fixed ``object``, ``created`` and ``owned_by`` values.
    """
    logger.info(format_log_message('INFO', "Received request to list models", extra={'request_type': 'list_models', 'status_code': 200}))
    entries = []
    for model_id in GeminiClient.AVAILABLE_MODELS:
        entries.append({"id": model_id, "object": "model", "created": 1678888888, "owned_by": "organization-owner"})
    return ModelList(data=entries)
async def verify_password(request: Request):
    """Dependency that enforces ``Authorization: Bearer <PASSWORD>``.

    The check is skipped entirely when ``PASSWORD`` is empty.

    Raises:
        HTTPException: 401 when the header is missing, does not use the Bearer
            scheme, or carries the wrong token.
    """
    import secrets  # constant-time comparison for the shared secret

    if not PASSWORD:
        return
    auth_header = request.headers.get("Authorization")
    prefix = "Bearer "
    if not auth_header or not auth_header.startswith(prefix):
        raise HTTPException(
            status_code=401, detail="Unauthorized: Missing or invalid token")
    # Take everything after the scheme. split(" ")[1] returned "" for
    # "Bearer  x" and cut tokens at the first internal space.
    token = auth_header[len(prefix):].strip()
    # compare_digest does not leak, through timing, how many leading characters
    # matched. Comparing bytes also works for non-ASCII passwords.
    if not secrets.compare_digest(token.encode("utf-8"), PASSWORD.encode("utf-8")):
        raise HTTPException(
            status_code=401, detail="Unauthorized: Invalid token")
async def _stream_sse_events(gemini_client, chat_request, contents, request_safety_settings, system_instruction, api_key, request_type):
    """Yield a Gemini streaming reply as server-sent-event ``data:`` lines.

    Each chunk from ``gemini_client.stream_chat`` is wrapped in a
    ``chat.completion.chunk`` object, and the stream ends with
    ``data: [DONE]``. A Gemini error is passed to ``handle_gemini_error`` and
    sent to the client as an ``error`` event. Cancellation is logged and
    re-raised.
    """
    try:
        async for chunk in gemini_client.stream_chat(chat_request, contents, request_safety_settings, system_instruction):
            formatted_chunk = {"id": "chatcmpl-someid", "object": "chat.completion.chunk", "created": 1234567,
                               "model": chat_request.model, "choices": [{"delta": {"role": "assistant", "content": chunk}, "index": 0, "finish_reason": None}]}
            yield f"data: {json.dumps(formatted_chunk)}\n\n"
        yield "data: [DONE]\n\n"
    except asyncio.CancelledError:
        extra_log_cancel = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端已断开连接'}
        logger.info(format_log_message('INFO', "客户端连接已中断", extra=extra_log_cancel))
        # Cancellation must propagate. Swallowing it makes the generator finish
        # as if the stream had completed normally.
        raise
    except Exception as e:
        # The response has already started, so the error is reported in-band
        # rather than retried with another key.
        error_detail = handle_gemini_error(e, api_key, key_manager)
        yield f"data: {json.dumps({'error': {'message': error_detail, 'type': 'gemini_error'}})}\n\n"


async def _complete_unless_disconnected(gemini_client, chat_request, contents, request_safety_settings, system_instruction, http_request, api_key, request_type):
    """Run ``gemini_client.complete_chat`` in a worker thread while watching for a client disconnect.

    Returns the ``complete_chat`` result. If the client disconnects first, the
    Gemini task is cancelled and ``HTTPException(408)`` is raised. Exceptions
    raised by ``complete_chat`` propagate to the caller. Both helper tasks are
    cancelled on every exit path.
    """
    async def run_gemini_completion():
        try:
            return await asyncio.to_thread(gemini_client.complete_chat, chat_request, contents, request_safety_settings, system_instruction)
        except asyncio.CancelledError:
            extra_log_gemini_cancel = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '客户端断开导致API调用取消'}
            logger.info(format_log_message('INFO', "API调用因客户端断开而取消", extra=extra_log_gemini_cancel))
            raise

    async def check_client_disconnect():
        # Poll every 0.5 s; this only returns once the client has gone away.
        while True:
            if await http_request.is_disconnected():
                extra_log_client_disconnect = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': '检测到客户端断开连接'}
                logger.info(format_log_message('INFO', "客户端连接已中断,正在取消API请求", extra=extra_log_client_disconnect))
                return True
            await asyncio.sleep(0.5)

    gemini_task = asyncio.create_task(run_gemini_completion())
    disconnect_task = asyncio.create_task(check_client_disconnect())
    try:
        done, _pending = await asyncio.wait(
            [gemini_task, disconnect_task],
            return_when=asyncio.FIRST_COMPLETED
        )
        if disconnect_task in done:
            gemini_task.cancel()
            try:
                await gemini_task
            except asyncio.CancelledError:
                extra_log_gemini_task_cancel = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': 'API任务已终止'}
                logger.info(format_log_message('INFO', "API任务已成功取消", extra=extra_log_gemini_task_cancel))
            # Raise rather than return so the caller's retry loop stops.
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="客户端连接已中断")
        # With FIRST_COMPLETED, gemini_task must be the finished one here.
        disconnect_task.cancel()
        try:
            await disconnect_task
        except asyncio.CancelledError:
            pass
        return gemini_task.result()
    except asyncio.CancelledError:
        extra_log_request_cancel = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'error_message': "请求被取消"}
        logger.info(format_log_message('INFO', "请求取消", extra=extra_log_request_cancel))
        raise
    finally:
        # If this coroutine is cancelled while waiting, neither child task would
        # otherwise be stopped.
        for task in (gemini_task, disconnect_task):
            if not task.done():
                task.cancel()


async def process_request(chat_request: ChatCompletionRequest, http_request: Request, request_type: Literal['stream', 'non-stream']):
    """Proxy one chat-completion request to Gemini, retrying across API keys.

    Steps:
    - Rate-limit the caller with ``protect_from_abuse``.
    - Reject models that are not in ``GeminiClient.AVAILABLE_MODELS`` with 400.
    - Convert the messages with ``GeminiClient.convert_messages``.
    - Make up to one attempt per configured key (at least one attempt).

    A failed attempt is reported to ``handle_gemini_error``, and the next attempt
    uses the key chosen by ``switch_api_key()``. An empty non-stream reply is
    retried with the same key.

    Streaming requests return a ``StreamingResponse`` immediately; errors that
    happen mid-stream are sent in-band and are not retried.

    Raises:
        HTTPException: 400 for an unknown model, 408 when the client disconnects
            during a non-stream call, 500 when no attempt succeeds.
    """
    global current_api_key
    protect_from_abuse(
        http_request, MAX_REQUESTS_PER_MINUTE, MAX_REQUESTS_PER_DAY_PER_IP)
    if chat_request.model not in GeminiClient.AVAILABLE_MODELS:
        error_msg = "无效的模型"
        extra_log = {'request_type': request_type, 'model': chat_request.model, 'status_code': 400, 'error_message': error_msg}
        logger.error(format_log_message('ERROR', error_msg, extra=extra_log))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)

    key_manager.reset_tried_keys_for_request()  # reset tried-key tracking at the start of every request
    contents, system_instruction = GeminiClient.convert_messages(
        GeminiClient, chat_request.messages)
    # The model is fixed for the whole request, so choose the safety profile once.
    request_safety_settings = safety_settings_g2 if 'gemini-2.0-flash-exp' in chat_request.model else safety_settings

    retry_attempts = len(key_manager.api_keys) if key_manager.api_keys else 1  # one attempt per key, at least one
    # current_api_key is module-global and shared by every in-flight request.
    # Pin this request's key locally so concurrent requests cannot swap it
    # (across an await) before it is used, logged, or blamed in handle_gemini_error.
    api_key = None
    for attempt in range(1, retry_attempts + 1):
        if attempt == 1:
            current_api_key = key_manager.get_available_key()  # key-stack logic lives in get_available_key
            api_key = current_api_key
        if api_key is None:
            log_msg_no_key = format_log_message('WARNING', "没有可用的 API 密钥,跳过本次尝试", extra={'request_type': request_type, 'model': chat_request.model, 'status_code': 'N/A'})
            logger.warning(log_msg_no_key)
            break

        extra_log = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 'N/A', 'error_message': ''}
        logger.info(format_log_message('INFO', f"第 {attempt}/{retry_attempts} 次尝试 ... 使用密钥: {api_key[:8]}...", extra=extra_log))
        gemini_client = GeminiClient(api_key)
        try:
            if chat_request.stream:
                return StreamingResponse(
                    _stream_sse_events(gemini_client, chat_request, contents, request_safety_settings, system_instruction, api_key, request_type),
                    media_type="text/event-stream")

            response_content = await _complete_unless_disconnected(
                gemini_client, chat_request, contents, request_safety_settings, system_instruction, http_request, api_key, request_type)
            if response_content.text == "":
                extra_log_empty_response = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 204}
                logger.info(format_log_message('INFO', "Gemini API 返回空响应", extra=extra_log_empty_response))
                continue  # retry; the key is not switched on this path
            response = ChatCompletionResponse(id="chatcmpl-someid", object="chat.completion", created=1234567890, model=chat_request.model,
                                              choices=[{"index": 0, "message": {"role": "assistant", "content": response_content.text}, "finish_reason": "stop"}])
            extra_log_success = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model, 'status_code': 200}
            logger.info(format_log_message('INFO', "请求处理成功", extra=extra_log_success))
            return response
        except HTTPException as e:
            if e.status_code == status.HTTP_408_REQUEST_TIMEOUT:
                extra_log = {'key': api_key[:8], 'request_type': request_type, 'model': chat_request.model,
                             'status_code': 408, 'error_message': '客户端连接中断'}
                logger.error(format_log_message('ERROR', "客户端连接中断,终止后续重试", extra=extra_log))
            raise
        except Exception as e:
            handle_gemini_error(e, api_key, key_manager)
            if attempt < retry_attempts:
                switch_api_key()
                # switch_api_key updates the global synchronously (no await in
                # between), so this reads the key it just selected.
                api_key = current_api_key

    msg = "所有API密钥均失败,请稍后重试"
    extra_log_all_fail = {'key': "ALL", 'request_type': request_type, 'model': chat_request.model, 'status_code': 500, 'error_message': msg}
    logger.error(format_log_message('ERROR', msg, extra=extra_log_all_fail))
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=msg)
async def chat_completions(request: ChatCompletionRequest, http_request: Request, _: None = Depends(verify_password)):
    """Handle a chat-completion call after the ``verify_password`` dependency passes.

    Presumably registered as the chat-completions route (the decorator is not
    visible here). All work is delegated to ``process_request``.
    """
    mode = "non-stream"
    if request.stream:
        mode = "stream"
    return await process_request(request, http_request, mode)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and return a 500 JSON ``ErrorResponse``.

    The log line uses the ``translate_error`` text, while the response body
    carries ``str(exc)`` unchanged.
    NOTE(review): sending raw exception text to clients may expose internal
    details; consider returning the translated message instead.
    """
    raw_message = str(exc)
    friendly_message = translate_error(raw_message)
    logger.error(format_log_message('ERROR', f"Unhandled exception: {friendly_message}", extra={'status_code': 500, 'error_message': friendly_message}))
    body = ErrorResponse(message=raw_message, type="internal_error").dict()
    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=body)
async def root():
    """Build the HTML status page as a string.

    The page shows the number of API keys and models, plus the rate-limit
    settings. The key count appears twice, because the retry budget equals the
    key count.
    NOTE(review): the return value is a plain str; confirm the route uses
    HTMLResponse (the decorator is not visible here).
    """
    key_count = len(key_manager.api_keys)
    model_count = len(GeminiClient.AVAILABLE_MODELS)
    html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>Gemini API 代理服务</title>
<style>
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
line-height: 1.6;
}}
h1 {{
color: #333;
text-align: center;
margin-bottom: 30px;
}}
.info-box {{
background-color: #f8f9fa;
border: 1px solid #dee2e6;
border-radius: 4px;
padding: 20px;
margin-bottom: 20px;
}}
.status {{
color: #28a745;
font-weight: bold;
}}
</style>
</head>
<body>
<h1>🤖 Gemini API 代理服务</h1>
<div class="info-box">
<h2>🟢 运行状态</h2>
<p class="status">服务运行中</p>
<p>可用API密钥数量: {key_count}</p>
<p>可用模型数量: {model_count}</p>
</div>
<div class="info-box">
<h2>⚙️ 环境配置</h2>
<p>每分钟请求限制: {MAX_REQUESTS_PER_MINUTE}</p>
<p>每IP每日请求限制: {MAX_REQUESTS_PER_DAY_PER_IP}</p>
<p>最大重试次数: {key_count}</p>
</div>
</body>
</html>
"""
    return html_content