| |
| |
|
|
| """ |
| 基础提供商抽象层 |
| 定义统一的提供商接口规范 |
| """ |
|
|
| import json |
| import time |
| import uuid |
| from abc import ABC, abstractmethod |
| from typing import Dict, List, Any, Optional, AsyncGenerator, Union |
| from dataclasses import dataclass |
|
|
| from app.models.schemas import OpenAIRequest, Message |
| from app.utils.logger import get_logger |
|
|
# Module-level logger; ProviderRegistry.register writes its log line through it.
logger = get_logger()
|
|
|
|
@dataclass
class ProviderConfig:
    """Static configuration describing one provider.

    Attributes:
        name: Provider identifier (BaseProvider copies it to ``self.name``).
        api_endpoint: Endpoint of the upstream API (presumably a URL).
        timeout: Request timeout, 30 by default (presumably seconds —
            confirm at the call sites).
        headers: Optional string-to-string header mapping, ``None`` if unset.
        extra_config: Optional provider-specific settings, ``None`` if unset.
    """

    name: str
    api_endpoint: str
    timeout: int = 30
    headers: Optional[Dict[str, str]] = None
    extra_config: Optional[Dict[str, Any]] = None
|
|
|
|
@dataclass
class ProviderResponse:
    """Result record for a single provider call."""

    success: bool                                 # True when the call succeeded
    content: str = ""                             # response text; empty by default
    error: Optional[str] = None                   # error description, if any
    usage: Optional[Dict[str, int]] = None        # optional name -> count usage map
    extra_data: Optional[Dict[str, Any]] = None   # optional provider-specific extras
|
|
|
|
class BaseProvider(ABC):
    """Abstract base class that every upstream provider implements.

    Subclasses must implement the three abstract coroutines
    (``chat_completion``, ``transform_request``, ``transform_response``).
    The concrete helpers build OpenAI-compatible payloads (streaming chunks
    and non-streaming responses), frame Server-Sent Events, and provide
    uniform logging and error formatting.
    """

    def __init__(self, config: ProviderConfig):
        """Initialise the provider from its configuration.

        Args:
            config: Provider configuration. ``config.name`` is also exposed as
                ``self.name`` and is embedded in log lines, error messages and
                the ``system_fingerprint`` of generated payloads.
        """
        self.config = config
        self.name = config.name
        self.logger = get_logger()

    @abstractmethod
    async def chat_completion(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Serve a chat completion request.

        Args:
            request: Request in OpenAI format.
            **kwargs: Additional provider-specific arguments.

        Returns:
            Non-streaming: an OpenAI-format response dict.
            Streaming: an async generator yielding SSE-formatted strings.
        """

    @abstractmethod
    async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
        """Convert an OpenAI-format request into the provider's own format.

        Args:
            request: Request in OpenAI format.

        Returns:
            The provider-specific request payload.
        """

    @abstractmethod
    async def transform_response(
        self,
        response: Any,
        request: OpenAIRequest
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Convert a raw provider response into OpenAI format.

        Args:
            response: The provider's raw response.
            request: The originating request (used to build the response).

        Returns:
            An OpenAI-format response dict, or an async generator of SSE
            strings when streaming.
        """

    def get_supported_models(self) -> List[str]:
        """Return the model ids this provider supports (empty by default)."""
        return []

    def create_chat_id(self) -> str:
        """Return a new unique completion id of the form ``chatcmpl-<32 hex chars>``."""
        return f"chatcmpl-{uuid.uuid4().hex}"

    def _system_fingerprint(self) -> str:
        """Return the ``system_fingerprint`` value stamped on every payload."""
        return f"fp_{self.name}_001"

    @staticmethod
    def _zero_usage() -> Dict[str, int]:
        """Return a fresh all-zero usage dict.

        A new object is built on every call so a caller mutating one
        response's usage can never affect another response.
        """
        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }

    def _build_chat_completion(
        self,
        chat_id: str,
        model: str,
        message: Dict[str, Any],
        usage: Optional[Dict[str, int]],
    ) -> Dict[str, Any]:
        """Assemble a non-streaming ``chat.completion`` payload around ``message``.

        Args:
            chat_id: Completion id to report.
            model: Model name to report.
            message: The single assistant message placed in ``choices[0]``.
            usage: Token usage; a falsy value (``None`` or ``{}``) is replaced
                with all-zero counts.

        Returns:
            The payload dict; ``created`` is the current Unix time in whole seconds.
        """
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": message,
                "finish_reason": "stop",
                "logprobs": None,
            }],
            "usage": usage or self._zero_usage(),
            "system_fingerprint": self._system_fingerprint(),
        }

    def create_openai_chunk(
        self,
        chat_id: str,
        model: str,
        delta: Dict[str, Any],
        finish_reason: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build one OpenAI-format streaming chunk (``chat.completion.chunk``).

        Args:
            chat_id: Completion id shared by all chunks of one stream.
            model: Model name to report.
            delta: Incremental message fields for this chunk.
            finish_reason: Finish reason for the last chunk, else ``None``.

        Returns:
            A chunk dict with a single choice at index 0; ``created`` is the
            current Unix time in whole seconds.
        """
        return {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
                "logprobs": None,
            }],
            "system_fingerprint": self._system_fingerprint(),
        }

    def create_openai_response(
        self,
        chat_id: str,
        model: str,
        content: str,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build an OpenAI-format non-streaming response.

        Args:
            chat_id: Completion id to report.
            model: Model name to report.
            content: Assistant message text.
            usage: Token usage; falsy values are replaced with zero counts.

        Returns:
            A ``chat.completion`` dict with ``finish_reason`` ``"stop"``.
        """
        message = {"role": "assistant", "content": content}
        return self._build_chat_completion(chat_id, model, message, usage)

    def create_openai_response_with_reasoning(
        self,
        chat_id: str,
        model: str,
        content: str,
        reasoning_content: Optional[str] = None,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build a non-streaming response that may carry reasoning content.

        Args:
            chat_id: Completion id to report.
            model: Model name to report.
            content: Assistant message text.
            reasoning_content: Reasoning text. It is added to the message as
                ``reasoning_content`` only when it contains non-whitespace.
            usage: Token usage; falsy values are replaced with zero counts.

        Returns:
            A ``chat.completion`` dict with ``finish_reason`` ``"stop"``.
        """
        message: Dict[str, Any] = {"role": "assistant", "content": content}
        # Omit the key entirely when reasoning is missing or whitespace-only.
        if reasoning_content and reasoning_content.strip():
            message["reasoning_content"] = reasoning_content
        return self._build_chat_completion(chat_id, model, message, usage)

    async def format_sse_chunk(self, chunk: Dict[str, Any]) -> str:
        """Frame ``chunk`` as an SSE ``data:`` event.

        The chunk is JSON-encoded with non-ASCII characters kept unescaped.
        This method is declared ``async`` even though it never awaits, so
        callers must ``await`` it.
        """
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    async def format_sse_done(self) -> str:
        """Return the SSE end-of-stream marker (``data: [DONE]``); must be awaited."""
        return "data: [DONE]\n\n"

    def log_request(self, request: OpenAIRequest):
        """Log the model at info level, and the message count and stream flag at debug level."""
        self.logger.info(f"🔄 {self.name} 处理请求: {request.model}")
        self.logger.debug(f" 消息数量: {len(request.messages)}")
        self.logger.debug(f" 流式模式: {request.stream}")

    def log_response(self, success: bool, error: Optional[str] = None):
        """Log the outcome: info on success, error (including ``error``) on failure."""
        if success:
            self.logger.info(f"✅ {self.name} 响应成功")
        else:
            self.logger.error(f"❌ {self.name} 响应失败: {error}")

    def handle_error(self, error: Exception, context: str = "") -> Dict[str, Any]:
        """Log ``error`` and wrap it in an OpenAI-style error payload.

        Args:
            error: The exception that occurred.
            context: Optional short description of where it happened. It is
                left out of the message when empty.

        Returns:
            ``{"error": {"message": ..., "type": "provider_error",
            "code": "internal_error"}}``.
        """
        # Only insert the context segment when given; otherwise the default
        # empty context would leave a doubled space in the message.
        prefix = f"{self.name} {context}" if context else self.name
        error_msg = f"{prefix} 错误: {str(error)}"
        self.logger.error(error_msg)

        return {
            "error": {
                "message": error_msg,
                "type": "provider_error",
                "code": "internal_error"
            }
        }
|
|
|
|
class ProviderRegistry:
    """Registry of provider instances, addressable by provider name or by model id."""

    def __init__(self):
        # provider name -> provider instance
        self._providers: Dict[str, BaseProvider] = {}
        # model id -> name of the provider that serves it
        self._model_mapping: Dict[str, str] = {}

    def register(self, provider: BaseProvider, models: List[str]):
        """Store ``provider`` under its name and route every id in ``models`` to it.

        Registering an existing name replaces the stored instance. A model that
        was already mapped is re-pointed at this provider.
        """
        self._providers[provider.name] = provider
        self._model_mapping.update((model_id, provider.name) for model_id in models)
        logger.info(f"📝 注册提供商: {provider.name}, 模型: {models}")

    def get_provider(self, model: str) -> Optional[BaseProvider]:
        """Return the provider serving ``model``, or ``None`` if none is mapped."""
        owner = self._model_mapping.get(model)
        return self._providers.get(owner) if owner else None

    def get_provider_by_name(self, name: str) -> Optional[BaseProvider]:
        """Return the provider registered under ``name``, or ``None``."""
        return self._providers.get(name)

    def list_models(self) -> List[str]:
        """Return all mapped model ids, in the order they were first registered."""
        return [*self._model_mapping]

    def list_providers(self) -> List[str]:
        """Return all provider names, in the order they were first registered."""
        return [*self._providers]
|
|
|
|
| |
# Module-level registry instance; every importer of this module shares this one object.
provider_registry: ProviderRegistry = ProviderRegistry()
|
|