Spaces:
Paused
Paused
| import json | |
| from typing import Optional, Union | |
| from fastapi import APIRouter, Body, HTTPException, Path, Query, Request, Depends, status, Header | |
| from fastapi.responses import StreamingResponse | |
| from app.services import GeminiClient | |
| from app.utils import protect_from_abuse,generate_cache_key,openAI_from_text,log | |
| from app.utils.response import openAI_from_Gemini | |
| from app.utils.auth import custom_verify_password | |
| from .stream_handlers import process_stream_request | |
| from .nonstream_handlers import process_request, process_nonstream_with_keepalive_stream | |
| from app.models.schemas import ChatCompletionRequest, ChatCompletionResponse, ModelList, AIRequest, ChatRequestGemini | |
| import app.config.settings as settings | |
| import asyncio | |
| from app.vertex.routes import chat_api, models_api | |
| from app.vertex.models import OpenAIRequest, OpenAIMessage | |
# Create the API router; endpoints are registered onto it from main.py.
router = APIRouter()
# Shared service handles and config values.
# All of these are placeholders: main.py creates the real objects and injects
# them here via init_router() before any route is served.
key_manager = response_cache_manager = active_requests_manager = None
safety_settings = safety_settings_g2 = None
current_api_key = None
FAKE_STREAMING = FAKE_STREAMING_INTERVAL = None
PASSWORD = None
MAX_REQUESTS_PER_MINUTE = MAX_REQUESTS_PER_DAY_PER_IP = None
# Router initialization hook.
def init_router(
    _key_manager,
    _response_cache_manager,
    _active_requests_manager,
    _safety_settings,
    _safety_settings_g2,
    _current_api_key,
    _fake_streaming,
    _fake_streaming_interval,
    _password,
    _max_requests_per_minute,
    _max_requests_per_day_per_ip
):
    """Inject shared services and config into this module's globals.

    Called once from main.py after the managers are constructed; every
    route handler in this module reads these module-level names.
    """
    # Bulk-assign all module globals in one shot (equivalent to a
    # `global ...` declaration followed by individual assignments).
    globals().update(
        key_manager=_key_manager,
        response_cache_manager=_response_cache_manager,
        active_requests_manager=_active_requests_manager,
        safety_settings=_safety_settings,
        safety_settings_g2=_safety_settings_g2,
        current_api_key=_current_api_key,
        FAKE_STREAMING=_fake_streaming,
        FAKE_STREAMING_INTERVAL=_fake_streaming_interval,
        PASSWORD=_password,
        MAX_REQUESTS_PER_MINUTE=_max_requests_per_minute,
        MAX_REQUESTS_PER_DAY_PER_IP=_max_requests_per_day_per_ip,
    )
async def verify_user_agent(request: Request):
    """Dependency: reject clients whose User-Agent is not whitelisted.

    A falsy `settings.WHITELIST_USER_AGENT` disables the check entirely.
    Raises HTTP 403 for a non-whitelisted (or missing) User-Agent header.
    """
    whitelist = settings.WHITELIST_USER_AGENT
    if whitelist and request.headers.get("User-Agent") not in whitelist:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not allowed client")
# TODO: add full Gemini support (streaming return)
async def get_cache(cache_key, is_stream: bool, is_gemini=False):
    """Look up (and consume) a cached response for `cache_key`.

    Returns a response object appropriate for the requested format
    (Gemini-native or OpenAI-compatible, streaming or not), or None on a
    cache miss. A hit removes the entry from the cache.
    """
    cached_response, cache_hit = await response_cache_manager.get_and_remove(cache_key)
    if cache_hit and cached_response:
        log('info', f"缓存命中: {cache_key[:8]}...",
            extra={'request_type': 'non-stream', 'model': cached_response.model})
        if is_gemini:
            if is_stream:
                data = f"data: {json.dumps(cached_response.data, ensure_ascii=False)}\n\n"
                # FIX: wrap the SSE payload in a one-item iterator. Passing a
                # bare str makes Starlette iterate it character-by-character,
                # emitting one chunk per character instead of one SSE event.
                return StreamingResponse(iter([data]), media_type="text/event-stream")
            else:
                return cached_response.data
        if is_stream:
            # NOTE(review): assumes openAI_from_Gemini(stream=True) returns an
            # iterable of SSE chunks acceptable to StreamingResponse — confirm.
            chunk = openAI_from_Gemini(cached_response, stream=True)
            return StreamingResponse(chunk, media_type="text/event-stream")
        else:
            return openAI_from_Gemini(cached_response, stream=False)
    return None
async def aistudio_list_models(_ = Depends(custom_verify_password),
                               _2 = Depends(verify_user_agent)):
    """Return the OpenAI-style model list, honoring whitelist/blocklist settings.

    When WHITELIST_MODELS is set, only whitelisted models are returned;
    otherwise all available models except those in BLOCKED_MODELS.
    """
    if settings.WHITELIST_MODELS:
        visible = [m for m in GeminiClient.AVAILABLE_MODELS if m in settings.WHITELIST_MODELS]
    else:
        visible = [m for m in GeminiClient.AVAILABLE_MODELS if m not in settings.BLOCKED_MODELS]
    entries = [
        {"id": m, "object": "model", "created": 1678888888, "owned_by": "organization-owner"}
        for m in visible
    ]
    return ModelList(data=entries)
async def vertex_list_models(request: Request,
                             _ = Depends(custom_verify_password),
                             _2 = Depends(verify_user_agent)):
    """List models via the Vertex backend.

    Thin wrapper that delegates to vertex/routes/models_api with the
    module-level `current_api_key` injected by init_router().
    """
    # Delegate to the implementation in vertex/routes/models_api.
    return await models_api.list_models(request, current_api_key)
# API routes
async def list_models(request: Request,
                      _ = Depends(custom_verify_password),
                      _2 = Depends(verify_user_agent)):
    """Dispatch model listing: Vertex when ENABLE_VERTEX is set, else AI Studio."""
    if not settings.ENABLE_VERTEX:
        return await aistudio_list_models(_, _2)
    return await vertex_list_models(request, _, _2)
async def aistudio_chat_completions(
    request: Union[ChatCompletionRequest, AIRequest],
    http_request: Request,
    _ = Depends(custom_verify_password),
    _2 = Depends(verify_user_agent),
):
    """Process a chat completion via AI Studio (OpenAI- or Gemini-format).

    Flow: detect request format -> build a content cache key -> abuse/model
    checks -> serve from cache on a hit -> (non-public mode) deduplicate
    against an identical in-flight request -> dispatch a stream / keepalive /
    plain processing task and return its result, falling back to the cache
    once more on failure.
    """
    # Gemini-native requests carry format_type == "gemini" (set by the Gemini route).
    format_type = getattr(request, 'format_type', None)
    if format_type and (format_type == "gemini"):
        is_gemini = True
    else:
        is_gemini = False
    # Build the cache key used to match identical request content.
    if settings.PRECISE_CACHE:
        cache_key = generate_cache_key(request, is_gemini = is_gemini)
    else:
        # Fuzzier/cheaper key: only the last N messages participate.
        cache_key = generate_cache_key(request, last_n_messages = settings.CALCULATE_CACHE_ENTRIES,is_gemini = is_gemini)
    # Basic pre-flight rate limiting (per minute and per IP per day).
    await protect_from_abuse(
        http_request,
        settings.MAX_REQUESTS_PER_MINUTE,
        settings.MAX_REQUESTS_PER_DAY_PER_IP)
    if request.model not in GeminiClient.AVAILABLE_MODELS:
        log('error', "无效的模型",
            extra={'model': request.model, 'status_code': 400})
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="无效的模型")
    # Log the cache key for this request.
    log('info', f"请求缓存键: {cache_key[:8]}...",
        extra={'request_type': 'non-stream', 'model': request.model})
    # Serve directly from cache when an entry exists.
    cached_response = await get_cache(cache_key, is_stream = request.stream,is_gemini=is_gemini)
    if cached_response :
        return cached_response
    if not settings.PUBLIC_MODE:
        # Key into the active-request pool (same value as the cache key).
        pool_key = f"{cache_key}"
        # Is an identical request already being processed?
        active_task = active_requests_manager.get(pool_key)
        if active_task and not active_task.done():
            log('info', f"发现相同请求的进行中任务",
                extra={'request_type': 'stream' if request.stream else "non-stream", 'model': request.model})
            # Await the existing task instead of spawning a duplicate.
            try:
                # Bounded wait so we never hang forever on a stuck task.
                await asyncio.wait_for(active_task, timeout=240)
                # Reuse the finished task's result.
                if active_task.done() and not active_task.cancelled():
                    result = active_task.result()
                    active_requests_manager.remove(pool_key)
                    if result:
                        return result
            except (asyncio.TimeoutError, asyncio.CancelledError) as e:
                # Timed out or cancelled: log, then fall through to process normally.
                error_type = "超时" if isinstance(e, asyncio.TimeoutError) else "被取消"
                log('warning', f"等待已有任务{error_type}: {pool_key}",
                    extra={'request_type': 'non-stream', 'model': request.model})
                # Drop the dead task from the active pool.
                if active_task.done() or active_task.cancelled():
                    active_requests_manager.remove(pool_key)
                    log('info', f"已从活跃请求池移除{error_type}任务: {pool_key}",
                        extra={'request_type': 'non-stream'})
    if request.stream:
        # Streaming request: process via the stream handler.
        process_task = asyncio.create_task(
            process_stream_request(
                chat_request = request,
                key_manager=key_manager,
                response_cache_manager = response_cache_manager,
                safety_settings = safety_settings,
                safety_settings_g2 = safety_settings_g2,
                cache_key = cache_key
            )
        )
    else:
        # Non-stream: optionally keep the connection alive during processing.
        if settings.NONSTREAM_KEEPALIVE_ENABLED:
            process_task = asyncio.create_task(
                process_nonstream_with_keepalive_stream(
                    chat_request = request,
                    key_manager = key_manager,
                    response_cache_manager = response_cache_manager,
                    safety_settings = safety_settings,
                    safety_settings_g2 = safety_settings_g2,
                    cache_key = cache_key,
                    is_gemini = is_gemini
                )
            )
        else:
            # Plain non-stream processing task.
            process_task = asyncio.create_task(
                process_request(
                    chat_request = request,
                    key_manager = key_manager,
                    response_cache_manager = response_cache_manager,
                    safety_settings = safety_settings,
                    safety_settings_g2 = safety_settings_g2,
                    cache_key = cache_key
                )
            )
    if not settings.PUBLIC_MODE:
        # Register the task so identical concurrent requests can await it.
        active_requests_manager.add(pool_key, process_task)
    # Wait for the processing task to finish.
    try:
        response = await process_task
        if not settings.PUBLIC_MODE:
            active_requests_manager.remove(pool_key)
        return response
    except Exception as e:
        if not settings.PUBLIC_MODE:
            # The task failed: remove it from the active pool.
            active_requests_manager.remove(pool_key)
        # Another task may have populated the cache in the meantime.
        cached_response = await get_cache(cache_key, is_stream = request.stream,is_gemini=is_gemini)
        if cached_response :
            return cached_response
        # Surface the error to the client.
        raise HTTPException(status_code=500, detail=f" hajimi 服务器内部处理时发生错误\n具体原因:{e}")
async def vertex_chat_completions(
    request: ChatCompletionRequest,
    http_request: Request,
    _dp = Depends(custom_verify_password),
    _du = Depends(verify_user_agent),
):
    """Convert an OpenAI-format request and delegate to vertex/routes/chat_api."""
    # Convert each message dict into the Vertex OpenAIMessage model.
    openai_messages = [
        OpenAIMessage(role=msg.get('role', ''), content=msg.get('content', ''))
        for msg in request.messages
    ]
    # Map the incoming request fields onto the Vertex OpenAIRequest model;
    # optional fields absent from the schema default to None via getattr.
    vertex_request = OpenAIRequest(
        model=request.model,
        messages=openai_messages,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        stream=request.stream,
        stop=request.stop,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        seed=getattr(request, 'seed', None),
        logprobs=getattr(request, 'logprobs', None),
        response_logprobs=getattr(request, 'response_logprobs', None),
        n=request.n,
    )
    # Hand off to the vertex/routes/chat_api implementation.
    return await chat_api.chat_completions(http_request, vertex_request, current_api_key)
async def chat_completions(
    request: ChatCompletionRequest,
    http_request: Request,
    _dp = Depends(custom_verify_password),
    _du = Depends(verify_user_agent),
):
    """Main API entry point: route the chat request to Vertex or AI Studio."""
    handler = vertex_chat_completions if settings.ENABLE_VERTEX else aistudio_chat_completions
    return await handler(request, http_request, _dp, _du)
async def gemini_chat_completions(
    request: Request,
    model_and_responseType: str = Path(...),
    key: Optional[str] = Query(None),
    alt: Optional[str] = Query(None, description=" sse 或 None"),
    payload: ChatRequestGemini = Body(...),
    _dp = Depends(custom_verify_password),
    _du = Depends(verify_user_agent),
):
    """Gemini-native endpoint: parse "<model>:<action>" from the path and forward.

    An action of "streamGenerateContent" selects streaming; the request is
    wrapped in an AIRequest (format_type='gemini') and handed to
    aistudio_chat_completions.
    """
    # The path segment must contain a ":" separating model name and action.
    try:
        model_name, action_type = model_and_responseType.split(":", 1)
    except ValueError:
        raise HTTPException(status_code=400, detail="无效的请求路径")
    is_stream = action_type == "streamGenerateContent"
    geminiRequest = AIRequest(payload=payload, model=model_name, stream=is_stream, format_type='gemini')
    return await aistudio_chat_completions(geminiRequest, request, _dp, _du)