# swarm-backend/core/api_inference.py
# Uploaded by lk080424 with huggingface_hub (commit 6dfda2e, verified)
"""
API推理适配器 — 对接智谱GLM和NIM的云端大模型
支持:智谱GLM-4-flash(免费)/GLM-4-plus, NIM多模型
"""
import os
import json
import time
import logging
import hashlib
from typing import List, Dict, Optional
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class APIModelConfig:
    """Configuration describing one model that is reachable through an API.

    The API key is excluded from the generated ``repr`` so that logging or
    printing a configuration object does not leak the credential. It is still
    a normal attribute and still takes part in equality comparison.
    """

    provider: str  # provider identifier, e.g. "zhipu" or "nim"
    model_name: str  # model identifier, e.g. "glm-4-flash"
    api_key: str = field(repr=False)  # credential; deliberately hidden from repr()
    base_url: str
    max_tokens: int = 512
    temperature: float = 0.7
    timeout: int = 30  # presumably seconds -- confirm against the consumer
    priority: int = 0  # higher values are preferred
    free: bool = True  # whether the model is free of charge
@dataclass
class APIInferenceResult:
    """Outcome of a single API inference call, whether it succeeded or not."""

    # Required fields: every caller supplies these.
    answer: str  # generated text; the adapters in this file pass "" on failure
    model: str  # name of the model that was called
    provider: str  # name of the provider that served (or failed) the call
    latency_ms: float  # elapsed time of the call in milliseconds

    # Optional fields: the defaults describe a successful call with no usage data.
    tokens_used: int = 0
    success: bool = True
    error: str = ""  # error description; empty when nothing went wrong
class ZhipuAdapter:
    """Adapter for the Zhipu GLM chat-completions HTTP endpoint."""

    def __init__(self, api_key: str, timeout: float = 30):
        """Store the credential, the endpoint URL and the request timeout.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            timeout: Timeout in seconds handed to ``urlopen`` for each request.
        """
        self.api_key = api_key
        self.base_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
        self.timeout = timeout

    async def chat(self, model: str, messages: List[Dict],
                   max_tokens: int = 512, temperature: float = 0.7) -> APIInferenceResult:
        """Asynchronous variant of :meth:`chat_sync`.

        The HTTP request uses blocking ``urllib`` I/O, so it is executed in the
        event loop's default executor instead of directly inside the coroutine;
        otherwise every other task on the loop would stall for the duration of
        the request. Because the work is delegated to :meth:`chat_sync`, the
        same log records are emitted.

        Args:
            model: Model name placed in the request payload.
            messages: Chat messages placed in the request payload.
            max_tokens: Value of the ``max_tokens`` payload field.
            temperature: Value of the ``temperature`` payload field.

        Returns:
            The result produced by :meth:`chat_sync`.
        """
        import asyncio
        import functools

        call = functools.partial(self.chat_sync, model, messages, max_tokens, temperature)
        return await asyncio.get_running_loop().run_in_executor(None, call)

    def chat_sync(self, model: str, messages: List[Dict],
                  max_tokens: int = 512, temperature: float = 0.7) -> APIInferenceResult:
        """Send one chat-completion request and wrap the outcome.

        Any failure (payload serialization, request construction, network or
        HTTP error, unexpected response shape) is logged and reported through a
        result with ``success=False``; this method does not raise for them.

        Args:
            model: Model name placed in the request payload.
            messages: Chat messages placed in the request payload.
            max_tokens: Value of the ``max_tokens`` payload field.
            temperature: Value of the ``temperature`` payload field.

        Returns:
            An ``APIInferenceResult`` with ``provider="zhipu"``. On success it
            carries the first choice's message content and the reported total
            token count; on failure ``answer`` is empty and ``error`` holds the
            exception text. ``latency_ms`` is measured in both cases.
        """
        import urllib.request

        start = time.perf_counter()
        try:
            # Everything that can fail lives inside the try block so that the
            # "never raises, returns a failed result" contract really holds.
            payload = json.dumps({
                "model": model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature
            }).encode()
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            req = urllib.request.Request(self.base_url, data=payload, headers=headers)
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                data = json.loads(resp.read().decode())
            answer = data["choices"][0]["message"]["content"]
            # "usage" may be absent or null; neither should fail the call.
            tokens = (data.get("usage") or {}).get("total_tokens", 0)
        except Exception as e:
            latency = (time.perf_counter() - start) * 1000
            logger.warning(f"[Zhipu] {model} 调用失败: {e}")
            return APIInferenceResult(
                answer="", model=model, provider="zhipu",
                latency_ms=latency, success=False, error=str(e)
            )

        latency = (time.perf_counter() - start) * 1000
        logger.info(f"[Zhipu] {model} 回答成功, {latency:.0f}ms, {tokens}tokens")
        return APIInferenceResult(
            answer=answer, model=model, provider="zhipu",
            latency_ms=latency, tokens_used=tokens, success=True
        )
class NIMAdapter:
    """Adapter for the NVIDIA NIM chat-completions HTTP endpoint."""

    def __init__(self, api_key: str, timeout: float = 60):
        """Store the credential, the endpoint URL and the request timeout.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            timeout: Timeout in seconds handed to ``urlopen`` for each request.
        """
        self.api_key = api_key
        self.base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
        self.timeout = timeout

    def chat_sync(self, model: str, messages: List[Dict],
                  max_tokens: int = 512, temperature: float = 0.7) -> APIInferenceResult:
        """Send one chat-completion request and wrap the outcome.

        Any failure (payload serialization, request construction, network or
        HTTP error, unexpected response shape) is logged and reported through a
        result with ``success=False``; this method does not raise for them.

        Args:
            model: Model name placed in the request payload.
            messages: Chat messages placed in the request payload.
            max_tokens: Value of the ``max_tokens`` payload field.
            temperature: Value of the ``temperature`` payload field.

        Returns:
            An ``APIInferenceResult`` with ``provider="nim"``. On success it
            carries the first choice's message content and the reported total
            token count; on failure ``answer`` is empty and ``error`` holds the
            exception text. ``latency_ms`` is measured in both cases.
        """
        import urllib.request

        start = time.perf_counter()
        try:
            # Everything that can fail lives inside the try block so that the
            # "never raises, returns a failed result" contract really holds.
            payload = json.dumps({
                "model": model,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature
            }).encode()
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            req = urllib.request.Request(self.base_url, data=payload, headers=headers)
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                data = json.loads(resp.read().decode())
            answer = data["choices"][0]["message"]["content"]
            # "usage" may be absent or null; neither should fail the call.
            tokens = (data.get("usage") or {}).get("total_tokens", 0)
        except Exception as e:
            latency = (time.perf_counter() - start) * 1000
            logger.warning(f"[NIM] {model} 调用失败: {e}")
            return APIInferenceResult(
                answer="", model=model, provider="nim",
                latency_ms=latency, success=False, error=str(e)
            )

        latency = (time.perf_counter() - start) * 1000
        logger.info(f"[NIM] {model} 回答成功, {latency:.0f}ms, {tokens}tokens")
        return APIInferenceResult(
            answer=answer, model=model, provider="nim",
            latency_ms=latency, tokens_used=tokens, success=True
        )
class APIInferenceManager:
    """Single entry point for calling several API-hosted models.

    Adapters are stored per provider name, model configurations are kept
    sorted by descending priority, and ``stats`` accumulates call counters
    and total latency for every result that is recorded.
    """

    def __init__(self):
        """Create a manager with no providers, no models and zeroed stats."""
        self.adapters: Dict[str, object] = {}  # provider name -> adapter instance
        self.models: List["APIModelConfig"] = []
        self.stats = {"total_calls": 0, "success": 0, "failed": 0, "total_ms": 0}

    def add_provider(self, provider: str, api_key: str):
        """Create and store the adapter for ``provider``.

        Args:
            provider: Either ``"zhipu"`` or ``"nim"``.
            api_key: Credential handed to the adapter.

        Raises:
            ValueError: If ``provider`` is not one of the supported names.
        """
        if provider == "zhipu":
            self.adapters["zhipu"] = ZhipuAdapter(api_key)
        elif provider == "nim":
            self.adapters["nim"] = NIMAdapter(api_key)
        else:
            raise ValueError(f"未知提供商: {provider}")
        logger.info(f"[APIManager] 添加提供商: {provider}")

    def register_model(self, config: "APIModelConfig"):
        """Register a model and keep the model list sorted by priority.

        The sort is stable, so models with equal priority stay in registration
        order.
        """
        self.models.append(config)
        self.models.sort(key=lambda m: m.priority, reverse=True)
        logger.info(f"[APIManager] 注册模型: {config.provider}/{config.model_name}, 优先级={config.priority}")

    def get_available_models(self, free_only: bool = False) -> List["APIModelConfig"]:
        """Return the registered models, highest priority first.

        Args:
            free_only: When true, only models whose ``free`` flag is set.

        Returns:
            A new list in both cases, so callers cannot accidentally mutate
            the manager's internal ordering.
        """
        if free_only:
            return [m for m in self.models if m.free]
        return list(self.models)

    @staticmethod
    def _build_messages(question: str, context: str) -> List[Dict]:
        """Build the chat message list: optional system context, then the user question."""
        messages = []
        if context:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": question})
        return messages

    @staticmethod
    def _failure(model, error) -> "APIInferenceResult":
        """Build a failed result for ``model`` carrying ``error`` as text."""
        return APIInferenceResult(
            success=False, answer="", model=model.model_name,
            provider=model.provider, latency_ms=0, error=str(error)
        )

    def _safe_call(self, adapter, model, messages, max_tokens):
        """Call ``adapter.chat_sync`` and convert any exception into a failed result."""
        try:
            return adapter.chat_sync(model.model_name, messages, max_tokens, model.temperature)
        except Exception as e:
            return self._failure(model, e)

    def infer(self, question: str, context: str = "",
              model_name: Optional[str] = None, max_tokens: int = 512) -> "APIInferenceResult":
        """Answer a question with a single model.

        If ``model_name`` is given and a registered model with that name has
        an adapter, that model is called and its result is returned whether it
        succeeded or not. Otherwise models are tried in priority order and the
        first successful result is returned. An adapter that raises is treated
        as a failed call, so it does not abort the fallback chain.

        Args:
            question: Text sent as the user message.
            context: Optional text sent as the system message.
            model_name: Name of a specific registered model to use.
            max_tokens: Token limit forwarded to the adapter.

        Returns:
            The chosen model's result, or a failed result with
            ``model="none"`` when no model produced a successful answer.
        """
        messages = self._build_messages(question, context)

        if model_name:
            for m in self.models:
                if m.model_name != model_name:
                    continue
                adapter = self.adapters.get(m.provider)
                if adapter:
                    result = self._safe_call(adapter, m, messages, max_tokens)
                    self._update_stats(result)
                    return result
            # The requested model is unknown or has no adapter: fall back.
            logger.warning(f"[APIManager] 指定模型不可用: {model_name}, 回退到自动选择")

        for m in self.models:
            adapter = self.adapters.get(m.provider)
            if not adapter:
                continue
            result = self._safe_call(adapter, m, messages, max_tokens)
            self._update_stats(result)
            if result.success:
                return result

        return APIInferenceResult(
            answer="", model="none", provider="none",
            latency_ms=0, success=False, error="无可用模型"
        )

    def infer_multi(self, question: str, context: str = "",
                    max_models: int = 3, max_tokens: int = 512,
                    timeout_per_model: float = 15.0) -> List["APIInferenceResult"]:
        """Query several models in parallel using a thread pool.

        The first ``max_models`` registered models that have an adapter are
        called concurrently, one worker thread each.

        Args:
            question: Text sent as the user message.
            context: Optional text sent as the system message.
            max_models: Maximum number of models to call.
            max_tokens: Token limit forwarded to every adapter.
            timeout_per_model: Seconds to wait for results. All calls start
                together, so this is also the overall deadline.

        Returns:
            Results in model-priority order, successes and failures alike.
            Models that did not finish before the deadline are omitted (and
            not counted in ``stats``); their worker threads are left to finish
            in the background instead of blocking this call.
        """
        messages = self._build_messages(question, context)

        tasks = []
        for m in self.models:
            if len(tasks) >= max_models:
                break
            adapter = self.adapters.get(m.provider)
            if adapter:
                tasks.append((adapter, m))
        if not tasks:
            return []

        from concurrent.futures import ThreadPoolExecutor, as_completed
        # Before Python 3.11 this is NOT the builtin TimeoutError, so it has to
        # be imported explicitly for the handler below to catch it.
        from concurrent.futures import TimeoutError as FuturesTimeoutError

        results = [None] * len(tasks)

        def _record(future, idx):
            """Store the finished future's result at ``idx`` and update stats."""
            try:
                result = future.result()
            except Exception as e:
                result = self._failure(tasks[idx][1], e)
            self._update_stats(result)
            results[idx] = result

        executor = ThreadPoolExecutor(max_workers=len(tasks))
        try:
            futures = {
                executor.submit(self._safe_call, adapter, model, messages, max_tokens): idx
                for idx, (adapter, model) in enumerate(tasks)
            }
            try:
                for future in as_completed(futures, timeout=timeout_per_model):
                    _record(future, futures[future])
            except FuturesTimeoutError:
                # Pick up anything that finished right at the deadline.
                for future, idx in futures.items():
                    if results[idx] is None and future.done():
                        _record(future, idx)
                pending = [tasks[idx][1].model_name
                           for idx, r in enumerate(results) if r is None]
                if pending:
                    logger.warning(f"[APIManager] 并行推理超时({timeout_per_model}s), 未完成模型: {pending}")
        finally:
            # wait=False: a `with` block would join every worker on exit and
            # thereby defeat the deadline enforced above.
            executor.shutdown(wait=False)

        return [r for r in results if r is not None]

    def _update_stats(self, result: "APIInferenceResult"):
        """Add one call's outcome and latency to the running statistics."""
        self.stats["total_calls"] += 1
        if result.success:
            self.stats["success"] += 1
        else:
            self.stats["failed"] += 1
        self.stats["total_ms"] += result.latency_ms

    @classmethod
    def from_env(cls) -> "APIInferenceManager":
        """Build a manager from the ``GLM_API_KEY`` and ``NIM_API_KEY`` environment variables.

        A provider and its default models are registered only when the
        corresponding variable is set to a non-empty value.
        """
        mgr = cls()
        # (provider, env var, ((model name, priority, free), ...))
        defaults = (
            ("zhipu", "GLM_API_KEY", (
                ("glm-4-flash", 10, True),
                ("glm-4-plus", 8, False),
            )),
            ("nim", "NIM_API_KEY", (
                ("meta/llama-3.1-8b-instruct", 6, True),
                ("microsoft/phi-3-mini-128k-instruct", 5, True),
            )),
        )
        for provider, env_var, models in defaults:
            key = os.environ.get(env_var, "")
            if not key:
                continue
            mgr.add_provider(provider, key)
            for model_name, priority, free in models:
                mgr.register_model(APIModelConfig(
                    provider=provider, model_name=model_name,
                    api_key=key, base_url="",
                    priority=priority, free=free
                ))
        logger.info(f"[APIManager] 从环境初始化: {len(mgr.adapters)}提供商, {len(mgr.models)}模型")
        return mgr