Spaces:
Paused
Paused
| import json | |
| import os | |
| import asyncio | |
| from app.models.schemas import ChatCompletionRequest | |
| from dataclasses import dataclass | |
| from typing import Optional, Dict, Any, List | |
| import httpx | |
| import logging | |
| import secrets | |
| import string | |
| from app.utils import format_log_message | |
| import app.config.settings as settings | |
| from app.utils.logging import log | |
def generate_secure_random_string(length):
    """Build a random string of ASCII letters and digits.

    Each character is drawn independently with ``secrets.choice``.

    Args:
        length: Number of characters to generate. A value of zero or
            less produces an empty string.

    Returns:
        A string of ``length`` characters taken from
        ``string.ascii_letters + string.digits``.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(secrets.choice(alphabet))
    return ''.join(picked)
@dataclass
class GeneratedText:
    """Simple record holding a piece of generated text.

    The ``@dataclass`` decorator supplies ``__init__``, ``__repr__`` and
    ``__eq__``; without it the annotated fields below would not be accepted
    as constructor arguments.
    """

    # The generated text itself (required).
    text: str
    # Optional finish reason; defaults to None when not supplied.
    finish_reason: Optional[str] = None
class OpenAIClient:
    """Client that streams chat completions from the OpenAI-compatible
    endpoint at ``generativelanguage.googleapis.com``.

    Attributes:
        AVAILABLE_MODELS: Class-level list, empty at definition time.
        EXTRA_MODELS: Model names read from the comma-separated
            ``EXTRA_MODELS`` environment variable; blank entries are dropped.
        api_key: API key sent as a Bearer token on every request.
    """

    AVAILABLE_MODELS = []
    # "".split(",") yields [""], so blank entries are filtered out to avoid
    # a bogus empty model name when the variable is unset or has stray commas.
    EXTRA_MODELS = [
        name.strip()
        for name in os.environ.get("EXTRA_MODELS", "").split(",")
        if name.strip()
    ]

    def __init__(self, api_key: str):
        """Store the API key used to authorize requests."""
        self.api_key = api_key

    @staticmethod
    def filter_data_by_whitelist(data, allowed_keys):
        """Filter a dictionary down to a whitelist of keys.

        Declared as a static method so that it can be called both as
        ``self.filter_data_by_whitelist(data, keys)`` and through the class.

        Args:
            data (dict): The original dictionary (a JSON-like object).
            allowed_keys (list or set): Key names that may be kept.

        Returns:
            dict: A new dictionary containing only the whitelisted keys.
        """
        # A set gives constant-time membership tests for large whitelists.
        allowed_keys_set = set(allowed_keys)
        return {key: value for key, value in data.items() if key in allowed_keys_set}

    @staticmethod
    def _request_to_dict(request):
        """Convert a request object into a plain dictionary.

        Accepts a ``dict`` (shallow-copied), an object exposing
        ``model_dump()`` or ``dict()``, or any object with a ``__dict__``.
        NOTE(review): ChatCompletionRequest is presumably a pydantic model;
        confirm which dump method it provides.

        Raises:
            TypeError: If the object offers none of the above.
        """
        if isinstance(request, dict):
            return dict(request)
        for method_name in ("model_dump", "dict"):
            dump = getattr(request, method_name, None)
            if callable(dump):
                return dump()
        return dict(vars(request))

    # True streaming
    async def stream_chat(self, request: "ChatCompletionRequest"):
        """Send a streaming chat-completion request and yield parsed chunks.

        The request is reduced to a whitelist of fields and POSTed as JSON.
        When search mode is enabled in settings and the model name ends with
        ``-search``, a ``google_search`` tool is added and the suffix is
        stripped from the model name before sending.

        Args:
            request: The incoming chat-completion request.

        Yields:
            Each JSON object decoded from the response stream.

        Raises:
            Exception: Any error raised while reading the stream is logged
                and re-raised unchanged.
        """
        whitelist = ["model", "messages", "temperature", "max_tokens", "stream", "tools",
                     "reasoning_effort", "top_k", "presence_penalty"]
        payload = self.filter_data_by_whitelist(self._request_to_dict(request), whitelist)

        # Keep the name as requested (including any "-search" suffix) for logs.
        requested_model = payload.get("model") or ""

        if settings.search["search_mode"] and requested_model.endswith("-search"):
            log('INFO', "开启联网搜索模式", extra={'key': self.api_key[:8], 'model': requested_model})
            # "tools" may be absent or None; build a fresh list either way so
            # the caller's request object is never mutated.
            tools = list(payload.get("tools") or [])
            tools.append({"google_search": {}})
            payload["tools"] = tools
            payload["model"] = requested_model.removesuffix("-search")

        # Streaming request handling
        extra_log = {'key': self.api_key[:8], 'request_type': 'stream', 'model': requested_model}
        log('INFO', "流式请求开始", extra=extra_log)

        url = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        async with httpx.AsyncClient() as client:
            async with client.stream("POST", url, headers=headers, json=payload, timeout=600) as response:
                buffer = ""  # accumulates JSON text that may span several lines
                try:
                    async for line in response.aiter_lines():
                        if not line.strip():  # skip blank lines (SSE message separators)
                            continue
                        if line.startswith("data: "):
                            line = line[len("data: "):].strip()  # drop the "data: " prefix
                        # End-of-stream marker: stop reading.
                        if line == "[DONE]":
                            break
                        buffer += line
                        try:
                            # Try to parse everything accumulated so far.
                            chunk = json.loads(buffer)
                        except json.JSONDecodeError:
                            # Incomplete JSON: keep accumulating.
                            continue
                        # Parsed successfully: reset the buffer and emit.
                        buffer = ""
                        yield chunk
                except Exception as e:
                    log('ERROR', f"流式处理期间发生错误: {e}", extra=extra_log)
                    raise
                finally:
                    log('info', "流式请求结束")