File size: 7,948 Bytes
fd21f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
基础提供商抽象层
定义统一的提供商接口规范
"""

import json
import time
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, AsyncGenerator, Union
from dataclasses import dataclass

from app.models.schemas import OpenAIRequest, Message
from app.utils.logger import get_logger

# Module-level logger for code in this file that has no provider instance to
# log through (ProviderRegistry.register uses it).
logger = get_logger()


@dataclass
class ProviderConfig:
    """Configuration for a single provider."""
    # Provider name; BaseProvider.__init__ copies it to ``self.name``.
    name: str
    # Endpoint of the provider's API (presumably a URL -- confirm with callers).
    api_endpoint: str
    # Request timeout, 30 by default (presumably seconds -- TODO confirm unit).
    timeout: int = 30
    # Optional string-to-string header mapping; None when not supplied.
    headers: Optional[Dict[str, str]] = None
    # Optional free-form provider-specific settings; None when not supplied.
    extra_config: Optional[Dict[str, Any]] = None


@dataclass
class ProviderResponse:
    """Result returned by a provider."""
    # Whether the call succeeded.
    success: bool
    # Response text; empty string by default.
    content: str = ""
    # Error description; None by default (presumably set only on failure).
    error: Optional[str] = None
    # Optional mapping of usage counters (name -> int); None when unavailable.
    usage: Optional[Dict[str, int]] = None
    # Optional free-form additional data; None when not supplied.
    extra_data: Optional[Dict[str, Any]] = None


class BaseProvider(ABC):
    """Abstract base class for chat-completion providers.

    Subclasses implement the three abstract coroutines
    (``chat_completion``, ``transform_request``, ``transform_response``).
    This base class supplies helpers for building OpenAI-style payloads,
    SSE framing, logging and error reporting.
    """

    def __init__(self, config: ProviderConfig):
        """Store the configuration and obtain a logger.

        Args:
            config: Provider settings; ``config.name`` becomes ``self.name``.
        """
        self.config = config
        self.name = config.name
        self.logger = get_logger()

    @abstractmethod
    async def chat_completion(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Run a chat completion.

        Args:
            request: Request in OpenAI format.
            **kwargs: Additional provider-specific arguments.

        Returns:
            Non-streaming: ``Dict[str, Any]`` -- response in OpenAI format.
            Streaming: ``AsyncGenerator[str, None]`` -- SSE-formatted stream.
        """

    @abstractmethod
    async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
        """Convert an OpenAI request into the provider-specific format.

        Args:
            request: Request in OpenAI format.

        Returns:
            The request in the provider-specific format.
        """

    @abstractmethod
    async def transform_response(
        self,
        response: Any,
        request: OpenAIRequest
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Convert a provider response into OpenAI format.

        Args:
            response: Raw response from the provider.
            request: The original request (used to build the response).

        Returns:
            The response in OpenAI format.
        """

    def get_supported_models(self) -> List[str]:
        """Return the supported model names (empty list in the base class)."""
        return []

    def create_chat_id(self) -> str:
        """Return a new chat id: ``chatcmpl-`` followed by a random UUID4 hex."""
        return f"chatcmpl-{uuid.uuid4().hex}"

    def _system_fingerprint(self) -> str:
        """Return the fingerprint string embedded in every payload."""
        return f"fp_{self.name}_001"

    def _build_completion(
        self,
        chat_id: str,
        model: str,
        message: Dict[str, Any],
        usage: Optional[Dict[str, int]],
    ) -> Dict[str, Any]:
        """Wrap ``message`` in a non-streaming ``chat.completion`` envelope.

        Shared by both public response builders so the envelope is defined
        in exactly one place.
        """
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": message,
                "finish_reason": "stop",
                "logprobs": None,
            }],
            # A falsy usage (None or an empty dict) falls back to zeroed
            # counters; a fresh dict is built on every call.
            "usage": usage or {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            },
            "system_fingerprint": self._system_fingerprint(),
        }

    def create_openai_chunk(
        self,
        chat_id: str,
        model: str,
        delta: Dict[str, Any],
        finish_reason: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build one streaming chunk in OpenAI format.

        Args:
            chat_id: Identifier placed in the ``id`` field.
            model: Model name placed in the ``model`` field.
            delta: Incremental message payload for the single choice.
            finish_reason: Finish reason for the choice, or None.

        Returns:
            A ``chat.completion.chunk`` dict with one choice at index 0.
        """
        return {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
                "logprobs": None,
            }],
            "system_fingerprint": self._system_fingerprint(),
        }

    def create_openai_response(
        self,
        chat_id: str,
        model: str,
        content: str,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build a non-streaming response in OpenAI format.

        Args:
            chat_id: Identifier placed in the ``id`` field.
            model: Model name placed in the ``model`` field.
            content: Assistant message text.
            usage: Token counters; zeroed counters are used when falsy.

        Returns:
            A ``chat.completion`` dict with one assistant choice whose
            ``finish_reason`` is ``"stop"``.
        """
        message = {"role": "assistant", "content": content}
        return self._build_completion(chat_id, model, message, usage)

    def create_openai_response_with_reasoning(
        self,
        chat_id: str,
        model: str,
        content: str,
        reasoning_content: Optional[str] = None,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build a non-streaming OpenAI-format response with reasoning text.

        Args:
            chat_id: Identifier placed in the ``id`` field.
            model: Model name placed in the ``model`` field.
            content: Assistant message text.
            reasoning_content: Reasoning text; added to the message as
                ``reasoning_content`` only when it is not None, empty or
                whitespace-only.
            usage: Token counters; zeroed counters are used when falsy.

        Returns:
            A ``chat.completion`` dict shaped like the one from
            ``create_openai_response``, plus the optional reasoning field.
        """
        message = {"role": "assistant", "content": content}

        # Only attach reasoning that carries visible text.
        if reasoning_content and reasoning_content.strip():
            message["reasoning_content"] = reasoning_content

        return self._build_completion(chat_id, model, message, usage)

    async def format_sse_chunk(self, chunk: Dict[str, Any]) -> str:
        """Serialize ``chunk`` as one SSE ``data:`` event.

        Non-ASCII characters are emitted as-is (``ensure_ascii=False``).
        """
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    async def format_sse_done(self) -> str:
        """Return the SSE end-of-stream marker."""
        return "data: [DONE]\n\n"

    def log_request(self, request: OpenAIRequest):
        """Log the model at INFO; message count and stream flag at DEBUG."""
        self.logger.info(f"🔄 {self.name} 处理请求: {request.model}")
        self.logger.debug(f"  消息数量: {len(request.messages)}")
        self.logger.debug(f"  流式模式: {request.stream}")

    def log_response(self, success: bool, error: Optional[str] = None):
        """Log the outcome: INFO on success, ERROR (with ``error``) otherwise."""
        if success:
            self.logger.info(f"✅ {self.name} 响应成功")
        else:
            self.logger.error(f"❌ {self.name} 响应失败: {error}")

    def handle_error(self, error: Exception, context: str = "") -> Dict[str, Any]:
        """Log ``error`` and return it as an error payload.

        Args:
            error: The exception to report.
            context: Optional short label for where the error occurred.

        Returns:
            ``{"error": {"message", "type", "code"}}`` with type
            ``"provider_error"`` and code ``"internal_error"``.
        """
        # Leave the context out entirely when it is empty, so the message
        # does not end up with two consecutive spaces.
        prefix = f"{self.name} {context}" if context else self.name
        error_msg = f"{prefix} 错误: {error!s}"
        self.logger.error(error_msg)

        return {
            "error": {
                "message": error_msg,
                "type": "provider_error",
                "code": "internal_error"
            }
        }


class ProviderRegistry:
    """Registry that maps provider names and model names to providers."""

    def __init__(self):
        # model name -> provider name
        self._model_mapping: Dict[str, str] = {}
        # provider name -> provider instance
        self._providers: Dict[str, BaseProvider] = {}

    def register(self, provider: BaseProvider, models: List[str]):
        """Add ``provider`` under its name and route ``models`` to it.

        An existing provider with the same name, or an existing mapping for
        any of the given models, is overwritten.
        """
        owner = provider.name
        self._providers[owner] = provider
        self._model_mapping.update((model_id, owner) for model_id in models)
        logger.info(f"📝 注册提供商: {provider.name}, 模型: {models}")

    def get_provider(self, model: str) -> Optional[BaseProvider]:
        """Return the provider registered for ``model``, or None if unknown."""
        owner = self._model_mapping.get(model)
        if not owner:
            return None
        return self._providers.get(owner)

    def get_provider_by_name(self, name: str) -> Optional[BaseProvider]:
        """Return the provider registered under ``name``, or None."""
        return self._providers.get(name)

    def list_models(self) -> List[str]:
        """Return every mapped model name, in registration order."""
        return [*self._model_mapping]

    def list_providers(self) -> List[str]:
        """Return every registered provider name, in registration order."""
        return [*self._providers]


# Global provider registry: a single module-level ProviderRegistry instance,
# created when this module is imported.
provider_registry = ProviderRegistry()