Spaces:
Sleeping
Sleeping
| """ | |
| LLM API 客户端 | |
| 支持多个 AI 模型 API 提供商 | |
| """ | |
| import os | |
| import json | |
| import time | |
| from typing import List, Dict, Optional, Union | |
| import requests | |
| from config import LLM_API_CONFIG | |
class LLMAPIClient:
    """Unified LLM chat client supporting multiple AI model API providers.

    Speaks the OpenAI-compatible ``/chat/completions`` protocol by default
    (OpenAI, DeepSeek, Moonshot, Zhipu, DashScope, Ollama, ...) and switches
    to the Anthropic Messages API when ``provider`` is ``"anthropic"``.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the API client.

        Args:
            config: LLM API configuration dict; defaults to LLM_API_CONFIG.
        """
        self.config = config or LLM_API_CONFIG
        self.provider = self.config.get("provider", "openai")
        self.api_key = self.config.get("api_key", "")
        self.base_url = self.config.get("base_url", "")
        self.model = self.config.get("model", "gpt-4o-mini")
        self.timeout = self.config.get("timeout", 30)
        # Warn early when the feature is enabled but no credential was given,
        # instead of failing later on the first request.
        if self.config.get("enabled") and not self.api_key:
            print("警告: LLM API 已启用但未配置 API_KEY")

    def _get_endpoint(self) -> str:
        """Return the chat endpoint URL for the configured provider.

        A custom ``base_url`` always takes precedence over the built-in
        provider table; unknown providers fall back to the OpenAI endpoint.
        """
        if self.base_url:
            # Custom endpoint (e.g. a proxy or self-hosted gateway).
            return f"{self.base_url.rstrip('/')}/chat/completions"
        # Default endpoint per provider.
        endpoints = {
            "openai": "https://api.openai.com/v1/chat/completions",
            "anthropic": "https://api.anthropic.com/v1/messages",
            "deepseek": "https://api.deepseek.com/v1/chat/completions",
            "moonshot": "https://api.moonshot.cn/v1/chat/completions",
            "zhipu": "https://open.bigmodel.cn/api/paas/v4/chat/completions",
            "dashscope": "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
            "ollama": "http://localhost:11434/v1/chat/completions",
        }
        return endpoints.get(self.provider, "https://api.openai.com/v1/chat/completions")

    def _get_headers(self) -> Dict[str, str]:
        """Build HTTP headers; Anthropic uses its own auth header scheme."""
        headers = {"Content-Type": "application/json"}
        if self.provider == "anthropic":
            headers["x-api-key"] = self.api_key
            headers["anthropic-version"] = "2023-06-01"
        else:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _format_messages(
        self,
        system_prompt: str,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None
    ) -> List[Dict[str, str]]:
        """Assemble the chat message list in API order.

        Args:
            system_prompt: System prompt (skipped when empty).
            user_message: Current user message, appended last.
            conversation_history: Prior messages, inserted between system
                prompt and the current user message.

        Returns:
            Message dicts of the form ``{"role": ..., "content": ...}``.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if conversation_history:
            messages.extend(conversation_history)
        messages.append({"role": "user", "content": user_message})
        return messages

    def _call_openai_compatible_api(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None
    ) -> str:
        """Call an OpenAI-compatible chat-completions API.

        Supported providers: OpenAI, DeepSeek, Moonshot, Zhipu, DashScope,
        Ollama, and any endpoint speaking the same protocol.

        Raises:
            Exception: On timeout, transport failure, non-2xx status,
                unparseable JSON, or an unrecognized response shape.
        """
        # Explicit None checks: ``x or default`` would silently discard
        # legitimate falsy overrides such as temperature=0.0 or top_p=0.
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature if temperature is not None else self.config.get("temperature", 0.7),
            "max_tokens": max_tokens if max_tokens is not None else self.config.get("max_tokens", 512),
            "top_p": top_p if top_p is not None else self.config.get("top_p", 0.9),
        }
        try:
            response = requests.post(
                self._get_endpoint(),
                headers=self._get_headers(),
                json=payload,
                timeout=self.timeout
            )
            response.raise_for_status()
            data = response.json()
            # Standard OpenAI-compatible success shape.
            if "choices" in data and len(data["choices"]) > 0:
                return data["choices"][0]["message"]["content"]
            # Structured error payload.
            if "error" in data:
                raise Exception(f"API 错误: {data['error']}")
            raise Exception(f"未知的响应格式: {data}")
        except requests.exceptions.Timeout:
            raise Exception(f"API 请求超时 (>{self.timeout}秒)")
        except requests.exceptions.RequestException as e:
            raise Exception(f"API 请求失败: {str(e)}")
        except json.JSONDecodeError as e:
            raise Exception(f"API 响应解析失败: {str(e)}")

    def _call_anthropic_api(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Call the Anthropic Claude Messages API.

        The system prompt travels in a dedicated ``system`` field rather
        than as a message, so it is split out of the message list first.

        Raises:
            Exception: On transport failure or an unrecognized response.
        """
        # Separate the system prompt from the conversational messages.
        system_prompt = ""
        chat_messages = []
        for msg in messages:
            if msg["role"] == "system":
                system_prompt = msg["content"]
            else:
                chat_messages.append(msg)
        # Explicit None checks so temperature=0.0 / max_tokens overrides
        # are honored (``or`` would drop falsy values).
        payload = {
            "model": self.model,
            "messages": chat_messages,
            "max_tokens": max_tokens if max_tokens is not None else self.config.get("max_tokens", 512),
            "temperature": temperature if temperature is not None else self.config.get("temperature", 0.7),
        }
        if system_prompt:
            payload["system"] = system_prompt
        # Only the network/parse phase is wrapped; response-shape errors
        # below raise directly so they are not double-wrapped.
        try:
            response = requests.post(
                self._get_endpoint(),
                headers=self._get_headers(),
                json=payload,
                timeout=self.timeout
            )
            response.raise_for_status()
            data = response.json()
        except Exception as e:
            raise Exception(f"Anthropic API 调用失败: {str(e)}")
        if "content" in data and len(data["content"]) > 0:
            return data["content"][0]["text"]
        if "error" in data:
            raise Exception(f"API 错误: {data['error']}")
        raise Exception(f"未知的响应格式: {data}")

    def generate(
        self,
        system_prompt: str,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None
    ) -> str:
        """Generate a reply from the configured model.

        Args:
            system_prompt: System prompt text.
            user_message: Current user message.
            conversation_history: Optional prior messages.
            temperature: Sampling temperature override.
            max_tokens: Maximum generation length override.
            top_p: Nucleus sampling override (OpenAI-compatible APIs only).

        Returns:
            The model-generated reply text.

        Raises:
            Exception: If the API is disabled, the key is missing, or the
                underlying call fails.
        """
        if not self.config.get("enabled"):
            raise Exception("LLM API 未启用,请在配置中设置 LLM_API_ENABLED=true")
        if not self.api_key:
            raise Exception("LLM_API_KEY 未配置")
        messages = self._format_messages(system_prompt, user_message, conversation_history)
        # Dispatch to the provider-specific protocol.
        if self.provider == "anthropic":
            return self._call_anthropic_api(messages, temperature, max_tokens)
        else:
            return self._call_openai_compatible_api(messages, temperature, max_tokens, top_p)

    def generate_with_retry(
        self,
        system_prompt: str,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None
    ) -> str:
        """Generate a reply, retrying with exponential backoff on failure.

        Args:
            system_prompt: System prompt text.
            user_message: Current user message.
            conversation_history: Optional prior messages.
            max_retries: Maximum number of attempts.
            retry_delay: Base delay in seconds; doubled after each failure.
            temperature: Sampling temperature override (forwarded).
            max_tokens: Maximum generation length override (forwarded).
            top_p: Nucleus sampling override (forwarded).

        Returns:
            The model-generated reply text.

        Raises:
            Exception: After all attempts have failed.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                return self.generate(
                    system_prompt,
                    user_message,
                    conversation_history,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except Exception as e:
                last_error = e
                if attempt < max_retries - 1:
                    print(f"API 调用失败,正在重试 ({attempt + 1}/{max_retries}): {str(e)}")
                    time.sleep(retry_delay * (2 ** attempt))  # exponential backoff
                else:
                    raise Exception(f"API 调用失败,已重试 {max_retries} 次: {str(last_error)}")
# Cached module-level singleton, created lazily on first access.
_llm_api_client = None


def get_llm_api_client() -> LLMAPIClient:
    """Return the shared LLMAPIClient instance, constructing it on first call."""
    global _llm_api_client
    if _llm_api_client is None:
        _llm_api_client = LLMAPIClient()
    return _llm_api_client