Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| 基础提供商抽象层 | |
| 定义统一的提供商接口规范 | |
| """ | |
| import json | |
| import time | |
| import uuid | |
| from abc import ABC, abstractmethod | |
| from typing import Dict, List, Any, Optional, AsyncGenerator, Union | |
| from dataclasses import dataclass | |
| from app.models.schemas import OpenAIRequest, Message | |
| from app.utils.logger import get_logger | |
# Module-level logger; used by ProviderRegistry.register. BaseProvider
# instances obtain their own logger via get_logger() in __init__.
logger = get_logger()
@dataclass
class ProviderConfig:
    """Static configuration for a single provider.

    Declared as a dataclass so it can be constructed with positional or
    keyword arguments, e.g. ``ProviderConfig(name="x", api_endpoint="...")``.
    """

    name: str  # provider identifier; copied to BaseProvider.name
    api_endpoint: str  # API endpoint the provider talks to
    timeout: int = 30  # request timeout; presumably seconds — verify at call sites
    headers: Optional[Dict[str, str]] = None  # optional extra headers
    extra_config: Optional[Dict[str, Any]] = None  # optional provider-specific settings
@dataclass
class ProviderResponse:
    """Result of a provider call.

    Declared as a dataclass so instances can be built directly, e.g.
    ``ProviderResponse(success=False, error="...")``.
    """

    success: bool  # whether the call succeeded
    content: str = ""  # response text; empty string by default
    error: Optional[str] = None  # error message, presumably set when success is False
    usage: Optional[Dict[str, int]] = None  # usage counters (e.g. token counts), if any
    extra_data: Optional[Dict[str, Any]] = None  # any additional provider data
class BaseProvider(ABC):
    """Abstract base class for providers.

    Subclasses must implement the three abstract coroutines
    (``chat_completion``, ``transform_request``, ``transform_response``).
    This base class supplies shared helpers for building OpenAI-format
    payloads, SSE framing, logging and error reporting.
    """

    def __init__(self, config: ProviderConfig):
        """Initialise the provider.

        Args:
            config: Provider configuration. ``config.name`` is stored as
                ``self.name`` and used in log messages and in the
                ``system_fingerprint`` of every payload.
        """
        self.config = config
        self.name = config.name
        self.logger = get_logger()

    @abstractmethod
    async def chat_completion(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Handle a chat-completion request.

        Args:
            request: Request in OpenAI format.
            **kwargs: Extra, provider-specific parameters.

        Returns:
            Non-streaming: an OpenAI-format response dict.
            Streaming: an async generator yielding SSE-formatted strings.
        """

    @abstractmethod
    async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
        """Convert an OpenAI-format request into the provider-specific format.

        Args:
            request: Request in OpenAI format.

        Returns:
            The request payload in the provider's own format.
        """

    @abstractmethod
    async def transform_response(
        self,
        response: Any,
        request: OpenAIRequest
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Convert a raw provider response into OpenAI format.

        Args:
            response: The provider's raw response.
            request: The original request (used to build the response).

        Returns:
            An OpenAI-format response dict, or an async generator of SSE
            strings for streaming.
        """

    def get_supported_models(self) -> List[str]:
        """Return the models this provider supports (none by default)."""
        return []

    def create_chat_id(self) -> str:
        """Return a fresh ``chatcmpl-<hex>`` identifier based on uuid4."""
        return f"chatcmpl-{uuid.uuid4().hex}"

    def _system_fingerprint(self) -> str:
        """Return the fingerprint string embedded in every payload."""
        return f"fp_{self.name}_001"

    def _build_chat_completion(
        self,
        chat_id: str,
        model: str,
        message: Dict[str, Any],
        usage: Optional[Dict[str, int]],
    ) -> Dict[str, Any]:
        """Wrap *message* in a non-streaming ``chat.completion`` payload."""
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": message,
                "finish_reason": "stop",
                "logprobs": None,
            }],
            # A missing (or empty) usage dict is reported as all-zero counters.
            "usage": usage or {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            },
            "system_fingerprint": self._system_fingerprint(),
        }

    def create_openai_chunk(
        self,
        chat_id: str,
        model: str,
        delta: Dict[str, Any],
        finish_reason: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build one OpenAI-format streaming chunk (``chat.completion.chunk``).

        Args:
            chat_id: Identifier shared by all chunks of one completion.
            model: Model name to report.
            delta: Incremental message fields for this chunk.
            finish_reason: Finish reason for the final chunk, otherwise None.
        """
        return {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
                "logprobs": None,
            }],
            "system_fingerprint": self._system_fingerprint(),
        }

    def create_openai_response(
        self,
        chat_id: str,
        model: str,
        content: str,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build a non-streaming OpenAI-format response with a plain assistant message.

        Args:
            chat_id: Completion identifier.
            model: Model name to report.
            content: Assistant message text.
            usage: Usage counters; all zeros are reported when omitted or empty.
        """
        message = {"role": "assistant", "content": content}
        return self._build_chat_completion(chat_id, model, message, usage)

    def create_openai_response_with_reasoning(
        self,
        chat_id: str,
        model: str,
        content: str,
        reasoning_content: Optional[str] = None,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build a non-streaming OpenAI-format response that may carry reasoning text.

        Args:
            chat_id: Completion identifier.
            model: Model name to report.
            content: Assistant message text.
            reasoning_content: Optional reasoning text. It is added as
                ``message["reasoning_content"]`` only when it contains
                non-whitespace characters.
            usage: Usage counters; all zeros are reported when omitted or empty.
        """
        message: Dict[str, Any] = {"role": "assistant", "content": content}
        if reasoning_content and reasoning_content.strip():
            message["reasoning_content"] = reasoning_content
        return self._build_chat_completion(chat_id, model, message, usage)

    async def format_sse_chunk(self, chunk: Dict[str, Any]) -> str:
        """Serialise *chunk* as one SSE ``data:`` event (non-ASCII kept as-is)."""
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    async def format_sse_done(self) -> str:
        """Return the SSE end-of-stream marker."""
        return "data: [DONE]\n\n"

    def log_request(self, request: OpenAIRequest):
        """Log the incoming request: model at info level, message count and stream flag at debug."""
        self.logger.info(f"🔄 {self.name} 处理请求: {request.model}")
        self.logger.debug(f" 消息数量: {len(request.messages)}")
        self.logger.debug(f" 流式模式: {request.stream}")

    def log_response(self, success: bool, error: Optional[str] = None):
        """Log the response outcome: info on success, error (with *error*) otherwise."""
        if success:
            self.logger.info(f"✅ {self.name} 响应成功")
        else:
            self.logger.error(f"❌ {self.name} 响应失败: {error}")

    def handle_error(self, error: Exception, context: str = "") -> Dict[str, Any]:
        """Log *error* and return it wrapped in an OpenAI-style error payload.

        Args:
            error: The exception that occurred.
            context: Optional short description of where it happened.

        Returns:
            ``{"error": {"message", "type", "code"}}`` with type
            ``provider_error`` and code ``internal_error``.
        """
        # Join only non-empty parts so an empty context leaves no doubled space.
        prefix = " ".join(part for part in (self.name, context) if part)
        error_msg = f"{prefix} 错误: {error}"
        self.logger.error(error_msg)
        return {
            "error": {
                "message": error_msg,
                "type": "provider_error",
                "code": "internal_error"
            }
        }
class ProviderRegistry:
    """Registry mapping model names to provider instances.

    Providers are stored by ``provider.name``. Each model name maps to exactly
    one provider name; a later registration of the same model replaces the
    earlier mapping, and a warning is logged when the provider changes.
    """

    def __init__(self):
        self._providers: Dict[str, BaseProvider] = {}
        self._model_mapping: Dict[str, str] = {}

    def register(self, provider: BaseProvider, models: Optional[List[str]] = None):
        """Register *provider* and map each model name to it.

        Args:
            provider: Provider instance, keyed by ``provider.name``. An existing
                provider with the same name is replaced.
            models: Model names served by this provider. When ``None``, falls
                back to ``provider.get_supported_models()``.
        """
        # Materialise once so the loop and the log message see the same values.
        if models is None:
            model_list = list(provider.get_supported_models())
        else:
            model_list = list(models)

        self._providers[provider.name] = provider
        for model in model_list:
            previous = self._model_mapping.get(model)
            if previous is not None and previous != provider.name:
                # Remapping is allowed, but it should not happen silently.
                logger.warning(
                    f"⚠️ 模型 {model} 原映射到提供商 {previous}, 现覆盖为 {provider.name}"
                )
            self._model_mapping[model] = provider.name
        logger.info(f"📝 注册提供商: {provider.name}, 模型: {model_list}")

    def get_provider(self, model: str) -> Optional[BaseProvider]:
        """Return the provider mapped to *model*, or None if the model is unknown."""
        provider_name = self._model_mapping.get(model)
        if not provider_name:
            return None
        return self._providers.get(provider_name)

    def get_provider_by_name(self, name: str) -> Optional[BaseProvider]:
        """Return the provider registered under *name*, or None."""
        return self._providers.get(name)

    def list_models(self) -> List[str]:
        """Return all registered model names, in registration order."""
        return list(self._model_mapping.keys())

    def list_providers(self) -> List[str]:
        """Return all registered provider names, in registration order."""
        return list(self._providers.keys())
# Global provider registry: a module-level ProviderRegistry instance created
# at import time. Presumably other modules import and populate it — verify callers.
provider_registry = ProviderRegistry()