Spaces:
Running
Running
| """ | |
| API推理适配器 — 对接智谱GLM和NIM的云端大模型 | |
| 支持:智谱GLM-4-flash(免费)/GLM-4-plus, NIM多模型 | |
| """ | |
| import os | |
| import json | |
| import time | |
| import logging | |
| import hashlib | |
| from typing import List, Dict, Optional | |
| from dataclasses import dataclass, field | |
| logger = logging.getLogger(__name__) | |
@dataclass
class APIModelConfig:
    """Configuration record for one API-hosted model.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``APIModelConfig(provider="zhipu", model_name="glm-4-flash", ...)``.

    NOTE(review): in the code visible in this file only ``provider``,
    ``model_name``, ``temperature``, ``priority`` and ``free`` are read back;
    ``api_key``, ``base_url``, ``max_tokens`` and ``timeout`` are stored but
    not consumed here -- confirm whether other modules use them.
    """

    provider: str  # provider key: zhipu / nim
    model_name: str  # model identifier, e.g. glm-4-flash / nim-text
    api_key: str
    base_url: str
    max_tokens: int = 512
    temperature: float = 0.7
    timeout: int = 30  # presumably seconds -- TODO confirm
    priority: int = 0  # higher value = tried first
    free: bool = True  # whether the model is free to use
@dataclass
class APIInferenceResult:
    """Outcome of a single API inference call.

    Declared as a dataclass so it can be built with keyword arguments.
    Failed calls are represented as a result with ``success=False``, an
    empty ``answer`` and the error text in ``error``.
    """

    answer: str  # text returned by the model ("" on failure)
    model: str  # model name the request was sent to
    provider: str  # provider key, e.g. "zhipu" / "nim"
    latency_ms: float  # elapsed time of the call in milliseconds
    tokens_used: int = 0
    success: bool = True
    error: str = ""
class ZhipuAdapter:
    """Adapter for the Zhipu GLM chat-completions HTTP endpoint.

    Sends a JSON body containing ``model``, ``messages``, ``max_tokens`` and
    ``temperature`` with a bearer token, and reads
    ``choices[0].message.content`` and ``usage.total_tokens`` from the reply.
    Errors are never raised to the caller; they are returned as an
    ``APIInferenceResult`` with ``success=False``.
    """

    def __init__(self, api_key: str, timeout: float = 30):
        """Store the credentials.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            timeout: Timeout passed to ``urlopen`` for each request
                (defaults to the previously hard-coded 30).
        """
        self.api_key = api_key
        self.base_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
        self.timeout = timeout

    def _post(self, model: str, messages: List[Dict],
              max_tokens: int, temperature: float) -> "APIInferenceResult":
        """Perform one blocking request and wrap the outcome (no logging)."""
        import urllib.request

        # perf_counter is monotonic, so latency cannot be skewed by clock changes.
        started = time.perf_counter()
        body = json.dumps({
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }).encode()
        request = urllib.request.Request(
            self.base_url,
            data=body,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
        )
        try:
            with urllib.request.urlopen(request, timeout=self.timeout) as resp:
                data = json.loads(resp.read().decode())
            answer = data["choices"][0]["message"]["content"]
            # "usage" may be absent or null; neither should fail a good answer.
            tokens = (data.get("usage") or {}).get("total_tokens", 0)
        except Exception as e:
            # Deliberate catch-all: network, HTTP, JSON and schema errors are
            # all reported to the caller as a failed result.
            return APIInferenceResult(
                answer="", model=model, provider="zhipu",
                latency_ms=(time.perf_counter() - started) * 1000,
                success=False, error=str(e)
            )
        return APIInferenceResult(
            answer=answer, model=model, provider="zhipu",
            latency_ms=(time.perf_counter() - started) * 1000,
            tokens_used=tokens, success=True
        )

    async def chat(self, model: str, messages: List[Dict],
                   max_tokens: int = 512, temperature: float = 0.7) -> "APIInferenceResult":
        """Asynchronous variant of :meth:`chat_sync` (does not log).

        The blocking HTTP call is run in the event loop's default executor so
        that awaiting this coroutine does not stall other tasks.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._post, model, messages, max_tokens, temperature
        )

    def chat_sync(self, model: str, messages: List[Dict],
                  max_tokens: int = 512, temperature: float = 0.7) -> "APIInferenceResult":
        """Synchronous chat call; logs the outcome and returns the result.

        Args:
            model: Model name sent in the request body.
            messages: Chat messages (dicts with ``role`` / ``content``).
            max_tokens: Value of the ``max_tokens`` request field.
            temperature: Value of the ``temperature`` request field.

        Returns:
            An ``APIInferenceResult``; ``success`` is False when the call failed.
        """
        result = self._post(model, messages, max_tokens, temperature)
        if result.success:
            logger.info(f"[Zhipu] {model} 回答成功, {result.latency_ms:.0f}ms, {result.tokens_used}tokens")
        else:
            logger.warning(f"[Zhipu] {model} 调用失败: {result.error}")
        return result
class NIMAdapter:
    """Adapter for the NIM (NVIDIA) chat-completions HTTP endpoint.

    Sends a JSON body containing ``model``, ``messages``, ``max_tokens`` and
    ``temperature`` with a bearer token, and reads
    ``choices[0].message.content`` and ``usage.total_tokens`` from the reply.
    Errors are returned as an ``APIInferenceResult`` with ``success=False``
    rather than raised.
    """

    def __init__(self, api_key: str, timeout: float = 60):
        """Store the credentials.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            timeout: Timeout passed to ``urlopen`` for each request
                (defaults to the previously hard-coded 60).
        """
        self.api_key = api_key
        self.base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
        self.timeout = timeout

    def chat_sync(self, model: str, messages: List[Dict],
                  max_tokens: int = 512, temperature: float = 0.7) -> "APIInferenceResult":
        """Synchronous chat call; logs the outcome and returns the result.

        Args:
            model: Model name sent in the request body.
            messages: Chat messages (dicts with ``role`` / ``content``).
            max_tokens: Value of the ``max_tokens`` request field.
            temperature: Value of the ``temperature`` request field.

        Returns:
            An ``APIInferenceResult``; ``success`` is False when the call failed.
        """
        import urllib.request

        # perf_counter is monotonic, so latency cannot be skewed by clock changes.
        started = time.perf_counter()
        body = json.dumps({
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }).encode()
        request = urllib.request.Request(
            self.base_url,
            data=body,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
        )
        try:
            with urllib.request.urlopen(request, timeout=self.timeout) as resp:
                data = json.loads(resp.read().decode())
            answer = data["choices"][0]["message"]["content"]
            # "usage" may be absent or null; neither should fail a good answer.
            tokens = (data.get("usage") or {}).get("total_tokens", 0)
        except Exception as e:
            # Deliberate catch-all: network, HTTP, JSON and schema errors are
            # all reported to the caller as a failed result.
            latency = (time.perf_counter() - started) * 1000
            logger.warning(f"[NIM] {model} 调用失败: {e}")
            return APIInferenceResult(
                answer="", model=model, provider="nim",
                latency_ms=latency, success=False, error=str(e)
            )
        latency = (time.perf_counter() - started) * 1000
        logger.info(f"[NIM] {model} 回答成功, {latency:.0f}ms, {tokens}tokens")
        return APIInferenceResult(
            answer=answer, model=model, provider="nim",
            latency_ms=latency, tokens_used=tokens, success=True
        )
class APIInferenceManager:
    """Registry and dispatcher for API-hosted chat models.

    Holds one adapter per provider and a list of model configurations,
    routes questions to them through each adapter's ``chat_sync`` method,
    and keeps simple call counters in ``stats``.
    """

    def __init__(self):
        self.adapters: Dict[str, object] = {}  # provider name -> adapter instance
        # register_model() keeps this ordered by descending priority.
        self.models: List["APIModelConfig"] = []
        self.stats = {"total_calls": 0, "success": 0, "failed": 0, "total_ms": 0}

    @staticmethod
    def _build_messages(question: str, context: str) -> List[Dict]:
        """Return the chat messages: optional system context, then the user question."""
        messages = []
        if context:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": question})
        return messages

    def add_provider(self, provider: str, api_key: str):
        """Create and store the adapter for ``provider``.

        Args:
            provider: ``"zhipu"`` or ``"nim"``.
            api_key: API key handed to the adapter.

        Raises:
            ValueError: If ``provider`` is not one of the known names.
        """
        if provider == "zhipu":
            self.adapters["zhipu"] = ZhipuAdapter(api_key)
        elif provider == "nim":
            self.adapters["nim"] = NIMAdapter(api_key)
        else:
            raise ValueError(f"未知提供商: {provider}")
        logger.info(f"[APIManager] 添加提供商: {provider}")

    def register_model(self, config: "APIModelConfig"):
        """Register a model and keep the model list ordered by priority."""
        self.models.append(config)
        # Highest priority first; the sort is stable, so equal priorities keep
        # their registration order.
        self.models.sort(key=lambda m: m.priority, reverse=True)
        logger.info(f"[APIManager] 注册模型: {config.provider}/{config.model_name}, 优先级={config.priority}")

    def get_available_models(self, free_only: bool = False) -> List["APIModelConfig"]:
        """Return the registered models, optionally only those marked free.

        A new list is returned in both cases, so callers cannot reorder or
        truncate the manager's internal list by accident.
        """
        if free_only:
            return [m for m in self.models if m.free]
        return list(self.models)

    def infer(self, question: str, context: str = "",
              model_name: Optional[str] = None, max_tokens: int = 512) -> "APIInferenceResult":
        """Answer ``question`` with a single model.

        If ``model_name`` is given and a registered model with that name has
        an adapter, that model is called and its result is returned whether
        or not it succeeded. Otherwise (no name, unknown name, or no adapter
        for it) models are tried in list order until one succeeds.

        Args:
            question: User message content.
            context: Optional system message content; omitted when empty.
            model_name: Name of the model to use, or None to auto-select.
            max_tokens: Passed through to the adapter.

        Returns:
            The adapter's result, or a failed result with model/provider
            ``"none"`` when no call succeeded.
        """
        messages = self._build_messages(question, context)

        # Explicitly requested model.
        if model_name:
            for m in self.models:
                if m.model_name != model_name:
                    continue
                adapter = self.adapters.get(m.provider)
                if adapter:
                    result = adapter.chat_sync(m.model_name, messages, max_tokens, m.temperature)
                    self._update_stats(result)
                    return result

        # Automatic selection: first model in list order that succeeds.
        for m in self.models:
            adapter = self.adapters.get(m.provider)
            if not adapter:
                continue
            result = adapter.chat_sync(m.model_name, messages, max_tokens, m.temperature)
            self._update_stats(result)
            if result.success:
                return result

        return APIInferenceResult(
            answer="", model="none", provider="none",
            latency_ms=0, success=False, error="无可用模型"
        )

    def infer_multi(self, question: str, context: str = "",
                    max_models: int = 3, max_tokens: int = 512,
                    timeout_per_model: float = 15.0) -> List["APIInferenceResult"]:
        """Query several models in parallel threads (pre-step for MOA).

        The first ``max_models`` models (in list order) that have an adapter
        are called concurrently, one thread each.

        Args:
            question: User message content.
            context: Optional system message content; omitted when empty.
            max_models: Maximum number of models to call.
            max_tokens: Passed through to each adapter.
            timeout_per_model: How long to wait for the calls. All calls start
                together, so this is also the overall wait.

        Returns:
            Results of the calls that finished in time, in model order.
            Calls that raised are included as failed results; calls still
            running at the deadline are omitted and not counted in ``stats``.
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed
        # Before Python 3.11 this is a different class from builtin TimeoutError.
        from concurrent.futures import TimeoutError as FuturesTimeoutError

        messages = self._build_messages(question, context)

        # Select the models to call.
        tasks = []
        for m in self.models:
            if len(tasks) >= max_models:
                break
            adapter = self.adapters.get(m.provider)
            if adapter:
                tasks.append((adapter, m))
        if not tasks:
            return []

        def _call(adapter, model):
            # Runs in a worker thread; never raises, so future.result() is safe.
            try:
                return adapter.chat_sync(model.model_name, messages, max_tokens, model.temperature)
            except Exception as e:
                return APIInferenceResult(
                    success=False, answer="", model=model.model_name,
                    provider=model.provider, latency_ms=0, error=str(e)
                )

        results = [None] * len(tasks)

        def _collect(future, idx):
            # Called on the calling thread only, so stats updates do not race.
            result = future.result()
            self._update_stats(result)
            results[idx] = result

        # Not used as a context manager: leaving a `with` block would wait for
        # every worker and defeat the timeout.
        executor = ThreadPoolExecutor(max_workers=len(tasks))
        try:
            futures = {
                executor.submit(_call, adapter, model): idx
                for idx, (adapter, model) in enumerate(tasks)
            }
            try:
                for future in as_completed(futures, timeout=timeout_per_model):
                    _collect(future, futures[future])
            except FuturesTimeoutError:
                # Deadline hit: keep anything that finished but was not yet yielded.
                for future, idx in futures.items():
                    if results[idx] is None and future.done():
                        _collect(future, idx)
        finally:
            # Do not block on stragglers; their threads end when the HTTP call does.
            executor.shutdown(wait=False)

        return [r for r in results if r is not None]

    def _update_stats(self, result: "APIInferenceResult"):
        """Add one call's outcome and latency to the running counters."""
        self.stats["total_calls"] += 1
        if result.success:
            self.stats["success"] += 1
        else:
            self.stats["failed"] += 1
        self.stats["total_ms"] += result.latency_ms

    @classmethod
    def from_env(cls) -> "APIInferenceManager":
        """Build a manager from the ``GLM_API_KEY`` / ``NIM_API_KEY`` environment variables.

        For each variable that is set and non-empty, the provider's adapter is
        added and its preset models are registered.
        """
        mgr = cls()
        # (provider, environment variable, ((model_name, priority, free), ...))
        presets = (
            ("zhipu", "GLM_API_KEY", (
                ("glm-4-flash", 10, True),
                ("glm-4-plus", 8, False),
            )),
            ("nim", "NIM_API_KEY", (
                ("meta/llama-3.1-8b-instruct", 6, True),
                ("microsoft/phi-3-mini-128k-instruct", 5, True),
            )),
        )
        for provider, env_var, models in presets:
            key = os.environ.get(env_var, "")
            if not key:
                continue
            mgr.add_provider(provider, key)
            for model_name, priority, free in models:
                mgr.register_model(APIModelConfig(
                    provider=provider, model_name=model_name,
                    api_key=key, base_url="",
                    priority=priority, free=free
                ))
        logger.info(f"[APIManager] 从环境初始化: {len(mgr.adapters)}提供商, {len(mgr.models)}模型")
        return mgr