File size: 13,124 Bytes
c8b347c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9567a7f
 
 
c8b347c
 
 
 
 
9567a7f
 
c8b347c
 
 
 
 
 
9567a7f
c8b347c
 
9567a7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfda2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9567a7f
 
c8b347c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""
API推理适配器 — 对接智谱GLM和NIM的云端大模型
支持:智谱GLM-4-flash(免费)/GLM-4-plus, NIM多模型
"""
import os
import json
import time
import logging
import hashlib
from typing import List, Dict, Optional
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

@dataclass
class APIModelConfig:
    """Configuration for one remotely hosted chat model.

    Fields (positional order is significant):
        provider: adapter key, ``"zhipu"`` or ``"nim"``.
        model_name: model id sent to the provider, e.g. ``"glm-4-flash"``.
        api_key: credential for the provider.
        base_url: endpoint URL; ``from_env`` passes ``""`` and the adapters
            use their own built-in URL instead.
        max_tokens / temperature / timeout: per-model request settings.
            NOTE(review): only ``temperature`` is forwarded by the manager in
            this module; ``max_tokens`` and ``timeout`` are not read there.
        priority: higher values are preferred (tried first).
        free: whether calling the model is free of charge.
    """

    # identity of the model and how to reach it
    provider: str
    model_name: str
    api_key: str
    base_url: str
    # request tuning
    max_tokens: int = field(default=512)
    temperature: float = field(default=0.7)
    timeout: int = field(default=30)
    # scheduling / billing metadata
    priority: int = field(default=0)
    free: bool = field(default=True)

@dataclass
class APIInferenceResult:
    """Outcome of a single API inference call.

    Fields (positional order is significant):
        answer: generated text; the adapters set ``""`` on failure.
        model / provider: which model and provider handled the call.
        latency_ms: elapsed time of the call in milliseconds.
        tokens_used: total token count reported by the provider (default 0).
        success: whether the call produced an answer.
        error: error description when ``success`` is False, else ``""``.
    """

    # what was produced and by whom
    answer: str
    model: str
    provider: str
    latency_ms: float
    # accounting and status
    tokens_used: int = field(default=0)
    success: bool = field(default=True)
    error: str = field(default="")

class ZhipuAdapter:
    """Adapter for the Zhipu GLM chat-completions endpoint.

    Requests are sent with the standard-library ``urllib``. Every call returns
    an ``APIInferenceResult``: request-building, transport, HTTP and
    response-parsing errors are captured in ``error`` with ``success=False``
    instead of being raised.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"

    async def chat(self, model: str, messages: List[Dict],
                   max_tokens: int = 512, temperature: float = 0.7,
                   timeout: float = 30) -> APIInferenceResult:
        """Async variant of :meth:`chat_sync`.

        The blocking HTTP request runs in the running loop's default executor,
        so awaiting this coroutine no longer stalls the event loop (the
        previous version called ``urlopen`` directly inside the coroutine).
        Arguments and result are the same as for :meth:`chat_sync`.
        """
        import asyncio
        import functools

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            functools.partial(self.chat_sync, model, messages,
                              max_tokens, temperature, timeout),
        )

    def chat_sync(self, model: str, messages: List[Dict],
                  max_tokens: int = 512, temperature: float = 0.7,
                  timeout: float = 30) -> APIInferenceResult:
        """Send one chat-completion request and block until it finishes.

        Args:
            model: Zhipu model id, e.g. ``"glm-4-flash"``.
            messages: chat messages as ``{"role": ..., "content": ...}`` dicts.
            max_tokens: upper bound on generated tokens.
            temperature: sampling temperature.
            timeout: socket timeout in seconds (previously fixed at 30).

        Returns:
            ``APIInferenceResult`` with ``success=True`` and the answer text, or
            ``success=False`` with ``answer=""`` and an error description.
        """
        import urllib.error
        import urllib.request

        start = time.perf_counter()  # monotonic: immune to wall-clock jumps
        try:
            payload = json.dumps({
                "model": model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            }).encode("utf-8")
            req = urllib.request.Request(
                self.base_url,
                data=payload,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                data = json.loads(resp.read().decode("utf-8"))
            answer, tokens = self._parse_response(data)
        except urllib.error.HTTPError as e:
            # Must precede the generic handler: HTTPError carries a body.
            error = self._describe_http_error(e)
        except Exception as e:
            error = str(e) or type(e).__name__
        else:
            latency = (time.perf_counter() - start) * 1000
            logger.info("[Zhipu] %s 回答成功, %.0fms, %stokens", model, latency, tokens)
            return APIInferenceResult(
                answer=answer, model=model, provider="zhipu",
                latency_ms=latency, tokens_used=tokens, success=True,
            )

        latency = (time.perf_counter() - start) * 1000
        logger.warning("[Zhipu] %s 调用失败: %s", model, error)
        return APIInferenceResult(
            answer="", model=model, provider="zhipu",
            latency_ms=latency, success=False, error=error,
        )

    @staticmethod
    def _parse_response(data):
        """Return ``(answer, total_tokens)`` from a chat-completions response.

        Raises:
            ValueError: if ``choices[0].message.content`` is missing. The
                message includes a snippet of the body, which is more useful
                than a bare ``KeyError: 'choices'``.
        """
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            snippet = json.dumps(data, ensure_ascii=False)[:300]
            raise ValueError(f"unexpected response format: {snippet}") from exc
        usage = data.get("usage") or {}  # tolerate missing or null "usage"
        return content or "", usage.get("total_tokens", 0)

    @staticmethod
    def _describe_http_error(err) -> str:
        """Format an ``HTTPError`` plus (a prefix of) its response body, if any."""
        try:
            body = err.read().decode("utf-8", errors="replace").strip()
        except Exception:
            body = ""
        return f"{err} | {body[:500]}" if body else str(err)

class NIMAdapter:
    """Adapter for NVIDIA NIM's chat-completions endpoint.

    Requests are sent with the standard-library ``urllib``. Every call returns
    an ``APIInferenceResult``: request-building, transport, HTTP and
    response-parsing errors are captured in ``error`` with ``success=False``
    instead of being raised.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://integrate.api.nvidia.com/v1/chat/completions"

    def chat_sync(self, model: str, messages: List[Dict],
                  max_tokens: int = 512, temperature: float = 0.7,
                  timeout: float = 60) -> APIInferenceResult:
        """Send one chat-completion request and block until it finishes.

        Args:
            model: NIM model id, e.g. ``"meta/llama-3.1-8b-instruct"``.
            messages: chat messages as ``{"role": ..., "content": ...}`` dicts.
            max_tokens: upper bound on generated tokens.
            temperature: sampling temperature.
            timeout: socket timeout in seconds (previously fixed at 60).

        Returns:
            ``APIInferenceResult`` with ``success=True`` and the answer text, or
            ``success=False`` with ``answer=""`` and an error description.
        """
        import urllib.error
        import urllib.request

        start = time.perf_counter()  # monotonic: immune to wall-clock jumps
        try:
            payload = json.dumps({
                "model": model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            }).encode("utf-8")
            req = urllib.request.Request(
                self.base_url,
                data=payload,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                data = json.loads(resp.read().decode("utf-8"))
            answer, tokens = self._parse_response(data)
        except urllib.error.HTTPError as e:
            # Must precede the generic handler: HTTPError carries a body.
            error = self._describe_http_error(e)
        except Exception as e:
            error = str(e) or type(e).__name__
        else:
            latency = (time.perf_counter() - start) * 1000
            logger.info("[NIM] %s 回答成功, %.0fms, %stokens", model, latency, tokens)
            return APIInferenceResult(
                answer=answer, model=model, provider="nim",
                latency_ms=latency, tokens_used=tokens, success=True,
            )

        latency = (time.perf_counter() - start) * 1000
        logger.warning("[NIM] %s 调用失败: %s", model, error)
        return APIInferenceResult(
            answer="", model=model, provider="nim",
            latency_ms=latency, success=False, error=error,
        )

    @staticmethod
    def _parse_response(data):
        """Return ``(answer, total_tokens)`` from a chat-completions response.

        Raises:
            ValueError: if ``choices[0].message.content`` is missing. The
                message includes a snippet of the body, which is more useful
                than a bare ``KeyError: 'choices'``.
        """
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            snippet = json.dumps(data, ensure_ascii=False)[:300]
            raise ValueError(f"unexpected response format: {snippet}") from exc
        usage = data.get("usage") or {}  # tolerate missing or null "usage"
        return content or "", usage.get("total_tokens", 0)

    @staticmethod
    def _describe_http_error(err) -> str:
        """Format an ``HTTPError`` plus (a prefix of) its response body, if any."""
        try:
            body = err.read().decode("utf-8", errors="replace").strip()
        except Exception:
            body = ""
        return f"{err} | {body[:500]}" if body else str(err)

class APIInferenceManager:
    """API inference manager: registers providers/models and dispatches calls.

    ``adapters`` maps a provider key (``"zhipu"`` / ``"nim"``) to an adapter
    exposing ``chat_sync(model, messages, max_tokens, temperature)``.
    ``models`` holds registered ``APIModelConfig`` entries, kept sorted by
    descending ``priority``. ``stats`` accumulates call counts and total
    latency (ms) over ``infer`` / ``infer_multi`` calls.
    """

    def __init__(self):
        self.adapters: Dict[str, object] = {}  # provider -> adapter
        self.models: List[APIModelConfig] = []  # highest priority first
        self.stats = {"total_calls": 0, "success": 0, "failed": 0, "total_ms": 0}

    def add_provider(self, provider: str, api_key: str):
        """Create and register the adapter for ``provider``.

        Raises:
            ValueError: if ``provider`` is neither ``"zhipu"`` nor ``"nim"``.
        """
        if provider == "zhipu":
            self.adapters["zhipu"] = ZhipuAdapter(api_key)
        elif provider == "nim":
            self.adapters["nim"] = NIMAdapter(api_key)
        else:
            raise ValueError(f"未知提供商: {provider}")
        logger.info("[APIManager] 添加提供商: %s", provider)

    def register_model(self, config: APIModelConfig):
        """Register a model and keep ``models`` sorted by descending priority.

        ``list.sort`` is stable, so models with equal priority keep their
        registration order.
        """
        self.models.append(config)
        self.models.sort(key=lambda m: m.priority, reverse=True)
        logger.info("[APIManager] 注册模型: %s/%s, 优先级=%s",
                    config.provider, config.model_name, config.priority)

    def get_available_models(self, free_only: bool = False) -> List[APIModelConfig]:
        """Return registered models in priority order, optionally only free ones.

        A new list is always returned, so callers cannot accidentally reorder
        or mutate the manager's registry (previously ``self.models`` itself
        was returned when ``free_only`` was False).
        """
        if free_only:
            return [m for m in self.models if m.free]
        return list(self.models)

    @staticmethod
    def _build_messages(question: str, context: str) -> List[Dict]:
        """Build the chat messages: an optional system ``context``, then the user ``question``."""
        messages = []
        if context:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": question})
        return messages

    def _usable_models(self):
        """Yield ``(adapter, config)`` in priority order for models whose provider has an adapter."""
        for m in self.models:
            adapter = self.adapters.get(m.provider)
            if adapter is not None:
                yield adapter, m

    def _run(self, adapter, model: APIModelConfig, messages: List[Dict],
             max_tokens: int) -> APIInferenceResult:
        """Call one model synchronously and record the outcome in ``stats``."""
        result = adapter.chat_sync(model.model_name, messages, max_tokens, model.temperature)
        self._update_stats(result)
        return result

    def infer(self, question: str, context: str = "",
              model_name: Optional[str] = None, max_tokens: int = 512) -> APIInferenceResult:
        """Single-model inference.

        If ``model_name`` matches a registered model whose provider is
        configured, that model is called and its result returned, even on
        failure. Otherwise, models are tried in priority order and the first
        successful result is returned. When nothing succeeds, a failed result
        with ``error="无可用模型"`` is returned.
        """
        messages = self._build_messages(question, context)

        # Explicitly requested model (falls through to auto-selection if unusable).
        if model_name:
            for adapter, m in self._usable_models():
                if m.model_name == model_name:
                    return self._run(adapter, m, messages, max_tokens)

        # Automatic selection: highest priority first, stop at first success.
        for adapter, m in self._usable_models():
            result = self._run(adapter, m, messages, max_tokens)
            if result.success:
                return result

        return APIInferenceResult(
            answer="", model="none", provider="none",
            latency_ms=0, success=False, error="无可用模型"
        )

    @staticmethod
    def _safe_chat(adapter, model: APIModelConfig, messages: List[Dict],
                   max_tokens: int) -> APIInferenceResult:
        """Worker body for ``infer_multi``: any exception becomes a failed result."""
        try:
            return adapter.chat_sync(model.model_name, messages, max_tokens, model.temperature)
        except Exception as e:
            return APIInferenceResult(
                success=False, answer="", model=model.model_name,
                provider=model.provider, latency_ms=0, error=str(e)
            )

    def infer_multi(self, question: str, context: str = "",
                    max_models: int = 3, max_tokens: int = 512,
                    timeout_per_model: float = 15.0) -> List[APIInferenceResult]:
        """Query up to ``max_models`` models in parallel (MoA pre-step).

        All selected models run concurrently on a thread pool. The method
        waits at most ``timeout_per_model`` seconds overall. Because the calls
        run in parallel, this bounds each model's allowed time. The parameter
        was previously ignored in favor of a hard-coded 30s.

        Models that have not finished by then get a failed result
        (``success=False``) and are counted in ``stats``. Their threads are
        not waited for. Running HTTP requests cannot be cancelled; they finish
        in the background and their results are discarded.

        Returns:
            One result per selected model, in priority order. The list is
            empty if no usable model is registered or ``max_models <= 0``.
        """
        import itertools
        from concurrent.futures import ThreadPoolExecutor, wait

        messages = self._build_messages(question, context)
        selected = list(itertools.islice(self._usable_models(), max(0, max_models)))
        if not selected:
            return []

        start = time.perf_counter()
        executor = ThreadPoolExecutor(max_workers=len(selected))
        try:
            futures = [
                executor.submit(self._safe_chat, adapter, m, messages, max_tokens)
                for adapter, m in selected
            ]
            # wait() returns (done, not_done) instead of raising on timeout.
            # This avoids the pre-3.11 mismatch between builtin TimeoutError
            # and concurrent.futures.TimeoutError.
            wait(futures, timeout=timeout_per_model)
        finally:
            # Do not block on stragglers (the old `with` block did, which made
            # the timeout ineffective).
            executor.shutdown(wait=False)

        results = []
        for future, (_, m) in zip(futures, selected):
            if future.done():
                result = future.result()  # _safe_chat converts exceptions to results
            else:
                future.cancel()
                result = APIInferenceResult(
                    success=False, answer="", model=m.model_name, provider=m.provider,
                    latency_ms=(time.perf_counter() - start) * 1000,
                    error=f"调用超时(>{timeout_per_model}s)"
                )
            self._update_stats(result)
            results.append(result)
        return results

    def _update_stats(self, result: APIInferenceResult):
        """Count one call as success/failure and add its latency to ``total_ms``."""
        self.stats["total_calls"] += 1
        if result.success:
            self.stats["success"] += 1
        else:
            self.stats["failed"] += 1
        self.stats["total_ms"] += result.latency_ms

    @classmethod
    def from_env(cls) -> "APIInferenceManager":
        """Build a manager from environment variables.

        ``GLM_API_KEY`` enables Zhipu (``glm-4-flash`` priority 10 free,
        ``glm-4-plus`` priority 8 paid). ``NIM_API_KEY`` enables NIM
        (``meta/llama-3.1-8b-instruct`` priority 6 and
        ``microsoft/phi-3-mini-128k-instruct`` priority 5, both free).
        Missing or empty variables leave that provider unconfigured.
        """
        mgr = cls()

        # Zhipu GLM
        zhipu_key = os.environ.get("GLM_API_KEY", "")
        if zhipu_key:
            mgr.add_provider("zhipu", zhipu_key)
            mgr.register_model(APIModelConfig(
                provider="zhipu", model_name="glm-4-flash",
                api_key=zhipu_key, base_url="",
                priority=10, free=True
            ))
            mgr.register_model(APIModelConfig(
                provider="zhipu", model_name="glm-4-plus",
                api_key=zhipu_key, base_url="",
                priority=8, free=False
            ))

        # NVIDIA NIM
        nim_key = os.environ.get("NIM_API_KEY", "")
        if nim_key:
            mgr.add_provider("nim", nim_key)
            mgr.register_model(APIModelConfig(
                provider="nim", model_name="meta/llama-3.1-8b-instruct",
                api_key=nim_key, base_url="",
                priority=6, free=True
            ))
            mgr.register_model(APIModelConfig(
                provider="nim", model_name="microsoft/phi-3-mini-128k-instruct",
                api_key=nim_key, base_url="",
                priority=5, free=True
            ))

        logger.info("[APIManager] 从环境初始化: %s提供商, %s模型",
                    len(mgr.adapters), len(mgr.models))
        return mgr