"""
API inference adapters for cloud LLM services (Zhipu GLM and NVIDIA NIM).

Supports: Zhipu GLM-4-flash (free) / GLM-4-plus, and several NIM-hosted models.
Both providers are called through an OpenAI-style ``/chat/completions`` JSON
endpoint, so one stdlib-only (urllib) client implementation is shared.
"""
import asyncio
import concurrent.futures
import functools
import hashlib
import json
import logging
import os
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass
class APIModelConfig:
    """Configuration of one callable API model."""

    provider: str  # adapter key: "zhipu" / "nim"
    model_name: str  # e.g. glm-4-flash, meta/llama-3.1-8b-instruct
    api_key: str
    base_url: str  # informational; the adapter registered for `provider` picks the endpoint
    max_tokens: int = 512  # NOTE: not consulted by the manager; callers pass max_tokens per call
    temperature: float = 0.7
    timeout: int = 30  # HTTP timeout (seconds) used by APIInferenceManager.infer
    priority: int = 0  # higher = tried first
    free: bool = True  # whether the model is free to use


@dataclass
class APIInferenceResult:
    """Outcome of one API call; failures are reported via success/error, not raised."""

    answer: str
    model: str
    provider: str
    latency_ms: float
    tokens_used: int = 0
    success: bool = True
    error: str = ""


class _ChatCompletionsAdapter:
    """Blocking client for a ``/chat/completions`` JSON endpoint.

    Subclasses only set the provider name, log tag, default URL and default
    timeout. Every call returns an APIInferenceResult: transport, HTTP and
    response-parsing errors become ``success=False`` with ``error`` filled in.
    """

    provider = ""
    log_tag = ""
    default_base_url = ""
    default_timeout: float = 30

    def __init__(self, api_key: str, base_url: Optional[str] = None,
                 timeout: Optional[float] = None):
        """
        Args:
            api_key: Sent as ``Authorization: Bearer <api_key>``.
            base_url: Full endpoint URL; defaults to the provider's public URL.
            timeout: Default HTTP timeout in seconds; defaults to the class value.
        """
        self.api_key = api_key
        self.base_url = base_url or self.default_base_url
        self.timeout = self.default_timeout if timeout is None else timeout

    async def chat(self, model: str, messages: List[Dict], max_tokens: int = 512,
                   temperature: float = 0.7,
                   timeout: Optional[float] = None) -> APIInferenceResult:
        """Async variant of :meth:`chat_sync`.

        The HTTP call is blocking (urllib), so it is run in the event loop's
        default executor instead of directly inside the coroutine; this keeps
        the event loop responsive while the request is in flight.
        """
        loop = asyncio.get_running_loop()
        call = functools.partial(self.chat_sync, model, messages, max_tokens,
                                 temperature, timeout)
        return await loop.run_in_executor(None, call)

    def chat_sync(self, model: str, messages: List[Dict], max_tokens: int = 512,
                  temperature: float = 0.7,
                  timeout: Optional[float] = None) -> APIInferenceResult:
        """Send one chat-completion request and wait for the answer.

        Args:
            model: Provider-side model name.
            messages: Chat messages (``{"role": ..., "content": ...}`` dicts).
            max_tokens: Forwarded as ``max_tokens`` in the request body.
            temperature: Forwarded as ``temperature`` in the request body.
            timeout: HTTP timeout in seconds; ``None`` uses ``self.timeout``.

        Returns:
            APIInferenceResult; never raises for request/response failures.
        """
        start = time.time()
        payload = json.dumps({
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }).encode()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        req = urllib.request.Request(self.base_url, data=payload, headers=headers)
        effective_timeout = self.timeout if timeout is None else timeout
        try:
            with urllib.request.urlopen(req, timeout=effective_timeout) as resp:
                data = json.loads(resp.read().decode())
            answer = data["choices"][0]["message"]["content"] or ""
            # `usage` may be absent or null; neither should fail a valid answer.
            tokens = (data.get("usage") or {}).get("total_tokens", 0) or 0
        except Exception as e:  # boundary: every failure becomes a failed result
            latency = (time.time() - start) * 1000
            error = self._describe_error(e)
            logger.warning(f"[{self.log_tag}] {model} 调用失败: {error}")
            return APIInferenceResult(
                answer="", model=model, provider=self.provider,
                latency_ms=latency, success=False, error=error,
            )
        latency = (time.time() - start) * 1000
        logger.info(f"[{self.log_tag}] {model} 回答成功, {latency:.0f}ms, {tokens}tokens")
        return APIInferenceResult(
            answer=answer, model=model, provider=self.provider,
            latency_ms=latency, tokens_used=tokens, success=True,
        )

    @staticmethod
    def _describe_error(exc: Exception) -> str:
        """Return ``str(exc)``, plus the response body for HTTP errors.

        The body of an HTTP error response usually carries the provider's own
        error message, which ``str(HTTPError)`` alone does not include.
        """
        if isinstance(exc, urllib.error.HTTPError):
            try:
                body = exc.read().decode("utf-8", errors="replace").strip()
            except Exception:
                body = ""
            if body:
                return f"{exc}: {body[:500]}"
        return str(exc)


class ZhipuAdapter(_ChatCompletionsAdapter):
    """Zhipu GLM adapter (open.bigmodel.cn)."""

    provider = "zhipu"
    log_tag = "Zhipu"
    default_base_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
    default_timeout = 30


class NIMAdapter(_ChatCompletionsAdapter):
    """NVIDIA NIM adapter (integrate.api.nvidia.com)."""

    provider = "nim"
    log_tag = "NIM"
    default_base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
    default_timeout = 60


class APIInferenceManager:
    """Registry of API providers/models with priority-ordered inference."""

    # provider key -> adapter class used by add_provider
    _ADAPTER_TYPES = {"zhipu": ZhipuAdapter, "nim": NIMAdapter}

    def __init__(self):
        self.adapters: Dict[str, object] = {}  # provider -> adapter
        self.models: List[APIModelConfig] = []  # kept sorted by priority, descending
        self.stats = {"total_calls": 0, "success": 0, "failed": 0, "total_ms": 0}

    def add_provider(self, provider: str, api_key: str, base_url: Optional[str] = None):
        """Register (or replace) the adapter for ``provider``.

        Raises:
            ValueError: if ``provider`` is not "zhipu" or "nim".
        """
        adapter_cls = self._ADAPTER_TYPES.get(provider)
        if adapter_cls is None:
            raise ValueError(f"未知提供商: {provider}")
        self.adapters[provider] = adapter_cls(api_key, base_url=base_url)
        logger.info(f"[APIManager] 添加提供商: {provider}")

    def register_model(self, config: APIModelConfig):
        """Add a model and keep ``self.models`` ordered by descending priority."""
        self.models.append(config)
        # Stable sort: equal priorities keep their registration order.
        self.models.sort(key=lambda m: m.priority, reverse=True)
        logger.info(f"[APIManager] 注册模型: {config.provider}/{config.model_name}, 优先级={config.priority}")

    def get_available_models(self, free_only: bool = False) -> List[APIModelConfig]:
        """Return registered models (priority order), optionally only free ones."""
        return [m for m in self.models if m.free or not free_only]

    @staticmethod
    def _build_messages(question: str, context: str) -> List[Dict]:
        """Build the chat message list; a non-empty context becomes the system prompt."""
        messages = []
        if context:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": question})
        return messages

    @staticmethod
    def _failed_result(config: APIModelConfig, exc: BaseException) -> APIInferenceResult:
        """Failure result for a call that raised instead of returning a result."""
        return APIInferenceResult(
            answer="", model=config.model_name, provider=config.provider,
            latency_ms=0, success=False, error=str(exc),
        )

    def _run(self, adapter, config: APIModelConfig, messages: List[Dict],
             max_tokens: int) -> APIInferenceResult:
        """Call one model synchronously with its configured timeout and record stats."""
        result = adapter.chat_sync(config.model_name, messages, max_tokens,
                                   config.temperature, timeout=config.timeout)
        self._update_stats(result)
        return result

    def infer(self, question: str, context: str = "", model_name: Optional[str] = None,
              max_tokens: int = 512) -> APIInferenceResult:
        """Single-model inference.

        If ``model_name`` matches a registered model whose provider has an
        adapter, that model's result is returned as-is (success or failure).
        Otherwise models are tried in priority order and the first successful
        result is returned; if none succeeds, a failure result with
        ``error="无可用模型"`` is returned.
        """
        messages = self._build_messages(question, context)

        if model_name:
            for m in self.models:
                adapter = self.adapters.get(m.provider)
                if m.model_name == model_name and adapter:
                    return self._run(adapter, m, messages, max_tokens)

        # Automatic selection: highest priority first, fall through on failure.
        for m in self.models:
            adapter = self.adapters.get(m.provider)
            if not adapter:
                continue
            result = self._run(adapter, m, messages, max_tokens)
            if result.success:
                return result

        return APIInferenceResult(
            answer="", model="none", provider="none", latency_ms=0,
            success=False, error="无可用模型",
        )

    def _harvest(self, future: concurrent.futures.Future,
                 config: APIModelConfig) -> APIInferenceResult:
        """Turn a finished future into a result (failure if it raised) and record stats."""
        try:
            result = future.result()
        except Exception as e:
            result = self._failed_result(config, e)
        self._update_stats(result)
        return result

    def infer_multi(self, question: str, context: str = "", max_models: int = 3,
                    max_tokens: int = 512, timeout_per_model: Optional[float] = 15.0,
                    total_timeout: Optional[float] = 30.0) -> List[APIInferenceResult]:
        """Query up to ``max_models`` models in parallel (MOA pre-step) on a thread pool.

        Args:
            timeout_per_model: HTTP timeout for each request; ``None`` uses each
                model's configured ``timeout``.
            total_timeout: Overall wait budget in seconds (``None`` = unbounded).
                Models still running when it expires are left out of the result
                and the method returns without waiting for them.

        Returns:
            Results of finished calls (successful or failed), in model-priority
            order.
        """
        messages = self._build_messages(question, context)

        tasks = []  # (adapter, config) pairs, highest priority first
        for m in self.models:
            if len(tasks) >= max_models:
                break
            adapter = self.adapters.get(m.provider)
            if adapter:
                tasks.append((adapter, m))
        if not tasks:
            return []

        def _call(adapter, config: APIModelConfig) -> APIInferenceResult:
            http_timeout = config.timeout if timeout_per_model is None else timeout_per_model
            try:
                return adapter.chat_sync(config.model_name, messages, max_tokens,
                                         config.temperature, timeout=http_timeout)
            except Exception as e:
                return self._failed_result(config, e)

        results: List[Optional[APIInferenceResult]] = [None] * len(tasks)
        # Not a `with` block: its exit joins all workers, which would make
        # `total_timeout` meaningless. shutdown(wait=False) lets us return early.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=len(tasks))
        try:
            future_to_idx = {
                executor.submit(_call, adapter, config): i
                for i, (adapter, config) in enumerate(tasks)
            }
            try:
                for future in concurrent.futures.as_completed(future_to_idx, timeout=total_timeout):
                    idx = future_to_idx[future]
                    results[idx] = self._harvest(future, tasks[idx][1])
            # concurrent.futures.TimeoutError is only the builtin TimeoutError on 3.11+.
            except concurrent.futures.TimeoutError:
                # Pick up futures that finished after as_completed gave up.
                for future, idx in future_to_idx.items():
                    if results[idx] is None and future.done():
                        results[idx] = self._harvest(future, tasks[idx][1])
                missing = [tasks[i][1].model_name for i, r in enumerate(results) if r is None]
                logger.warning(f"[APIManager] 多模型推理超时, 未完成: {missing}")
        finally:
            executor.shutdown(wait=False)

        return [r for r in results if r is not None]

    def _update_stats(self, result: APIInferenceResult):
        """Accumulate call counters and total latency."""
        self.stats["total_calls"] += 1
        if result.success:
            self.stats["success"] += 1
        else:
            self.stats["failed"] += 1
        self.stats["total_ms"] += result.latency_ms

    @classmethod
    def from_env(cls) -> "APIInferenceManager":
        """Build a manager from ``GLM_API_KEY`` / ``NIM_API_KEY`` environment variables.

        Providers whose key is unset or empty are skipped.
        """
        mgr = cls()

        # Zhipu GLM
        zhipu_key = os.environ.get("GLM_API_KEY", "")
        if zhipu_key:
            mgr.add_provider("zhipu", zhipu_key)
            mgr.register_model(APIModelConfig(
                provider="zhipu", model_name="glm-4-flash", api_key=zhipu_key,
                base_url="", priority=10, free=True,
            ))
            mgr.register_model(APIModelConfig(
                provider="zhipu", model_name="glm-4-plus", api_key=zhipu_key,
                base_url="", priority=8, free=False,
            ))

        # NVIDIA NIM (60s timeout, matching NIMAdapter's default)
        nim_key = os.environ.get("NIM_API_KEY", "")
        if nim_key:
            mgr.add_provider("nim", nim_key)
            mgr.register_model(APIModelConfig(
                provider="nim", model_name="meta/llama-3.1-8b-instruct", api_key=nim_key,
                base_url="", timeout=60, priority=6, free=True,
            ))
            mgr.register_model(APIModelConfig(
                provider="nim", model_name="microsoft/phi-3-mini-128k-instruct", api_key=nim_key,
                base_url="", timeout=60, priority=5, free=True,
            ))

        logger.info(f"[APIManager] 从环境初始化: {len(mgr.adapters)}提供商, {len(mgr.models)}模型")
        return mgr