swarm-backend / core /model_adapter.py
lk080424's picture
Upload folder using huggingface_hub
17fba62 verified
#!/usr/bin/env python3
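"""
Swarm model adapters — a unified calling interface over different provider APIs.
One adapter per provider; the chat adapters expose a common async
call(messages, model, tools, max_tokens, temperature) method.
"""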
import os
import time
import json
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
@dataclass
class CallResult:
    """Outcome of one chat-model call.

    Failures are reported through ``error`` (non-empty) instead of being
    raised, so check ``error`` before relying on ``text``.
    """

    text: str  # reply text; "" on failure or when the reply has no content
    tool_calls: List[Dict]  # one {"name": ..., "arguments": ...} dict per tool call
    model: str  # model name the request was made with
    provider: str  # provider identifier, e.g. "zhipu"
    latency_ms: float  # wall-clock duration of the request, in milliseconds
    tokens_in: int = 0  # prompt tokens reported by the API (0 if not reported)
    tokens_out: int = 0  # completion tokens reported by the API (0 if not reported)
    thinking: str = ""  # "reasoning_content" of the reply, when present
    error: str = ""  # error description; "" means the call succeeded


class ModelAdapter:
    """Base class for provider adapters.

    Holds the provider name, API key and base URL, and implements the
    chat-completions request/response handling shared by the chat adapters,
    which all POST the same body shape to ``{base_url}/chat/completions``.
    """

    # Request timeout in seconds used by _chat_completion; subclasses override.
    request_timeout: float = 60

    def __init__(self, provider: str, api_key: str = "", base_url: str = ""):
        self.provider = provider
        self.api_key = api_key
        self.base_url = base_url

    async def call(self, messages: List[Dict], model: str = "",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> CallResult:
        """Send ``messages`` to ``model``; concrete adapters must override.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    @staticmethod
    def _parse_completion(data: Dict[str, Any], model: str, provider: str,
                          latency_ms: float) -> CallResult:
        """Build a CallResult from a decoded chat-completions response body.

        Reads ``data["choices"][0]["message"]`` and the optional ``usage``
        object. Fields that are absent or JSON ``null`` fall back to
        ``""`` / ``[]`` / ``0`` so the result always matches the declared
        field types.

        Raises:
            KeyError / IndexError: if ``choices`` or ``message`` is missing.
        """
        message = data["choices"][0]["message"]
        tool_calls = [
            {
                "name": tc["function"]["name"],
                "arguments": tc["function"]["arguments"],
            }
            for tc in message.get("tool_calls") or []
        ]
        # "usage" may be missing or null; either way treat it as empty.
        usage = data.get("usage") or {}
        return CallResult(
            text=message.get("content") or "",
            tool_calls=tool_calls,
            model=model,
            provider=provider,
            latency_ms=latency_ms,
            tokens_in=usage.get("prompt_tokens") or 0,
            tokens_out=usage.get("completion_tokens") or 0,
            thinking=message.get("reasoning_content") or "",
        )

    async def _chat_completion(self, messages: List[Dict], model: str,
                               tools: Optional[List[Dict]], max_tokens: int,
                               temperature: float) -> CallResult:
        """POST a chat-completions request and convert the reply.

        Never raises for request/response problems: non-200 responses and
        exceptions are returned as a CallResult whose ``error`` is set.
        """
        # Imported lazily so the module can be imported without httpx installed.
        import httpx

        started = time.time()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if tools:
            body["tools"] = tools
            body["tool_choice"] = "auto"

        try:
            async with httpx.AsyncClient(timeout=self.request_timeout) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers, json=body,
                )
            latency = round((time.time() - started) * 1000)
            if response.status_code != 200:
                return CallResult(
                    text="", tool_calls=[], model=model,
                    provider=self.provider, latency_ms=latency,
                    error=f"HTTP {response.status_code}: {response.text[:200]}",
                )
            return self._parse_completion(
                response.json(), model, self.provider, latency
            )
        except Exception as exc:
            # Deliberately broad: this is the boundary where every failure is
            # turned into an error result. Some exceptions (e.g. timeouts) can
            # stringify to "", which would look like success, so fall back to
            # repr() to keep ``error`` non-empty.
            latency = round((time.time() - started) * 1000)
            return CallResult(
                text="", tool_calls=[], model=model,
                provider=self.provider, latency_ms=latency,
                error=str(exc) or repr(exc),
            )


class ZhipuAdapter(ModelAdapter):
    """Zhipu AI adapter; the key falls back to the GLM_API_KEY env variable."""

    request_timeout = 60

    def __init__(self, api_key: str = ""):
        key = api_key or os.environ.get("GLM_API_KEY", "")
        super().__init__("zhipu", key, "https://open.bigmodel.cn/api/paas/v4")

    async def call(self, messages: List[Dict], model: str = "glm-4.7-flash",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> CallResult:
        """Run a chat completion against Zhipu; errors come back in ``error``."""
        return await self._chat_completion(
            messages, model, tools, max_tokens, temperature
        )


class NvidiaNimAdapter(ModelAdapter):
    """NVIDIA NIM adapter; the key falls back to the NVIDIA_KEY env variable."""

    request_timeout = 35

    def __init__(self, api_key: str = ""):
        key = api_key or os.environ.get("NVIDIA_KEY", "")
        super().__init__("nvidia_nim", key, "https://integrate.api.nvidia.com/v1")

    async def call(self, messages: List[Dict],
                   model: str = "meta/llama-4-maverick-17b-128e-instruct",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> CallResult:
        """Run a chat completion against NVIDIA NIM; errors come back in ``error``."""
        return await self._chat_completion(
            messages, model, tools, max_tokens, temperature
        )


class SiliconFlowAdapter(ModelAdapter):
    """SiliconFlow adapter; the key falls back to the SILICONFLOW_KEY env variable."""

    request_timeout = 30

    def __init__(self, api_key: str = ""):
        key = api_key or os.environ.get("SILICONFLOW_KEY", "")
        super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1")

    async def call(self, messages: List[Dict],
                   model: str = "Qwen/Qwen2.5-7B-Instruct",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> CallResult:
        """Run a chat completion against SiliconFlow; errors come back in ``error``."""
        return await self._chat_completion(
            messages, model, tools, max_tokens, temperature
        )
class SiliconFlowEmbeddingAdapter(ModelAdapter):
    """SiliconFlow embedding adapter (POSTs to ``{base_url}/embeddings``).

    The key falls back to the SILICONFLOW_KEY environment variable.
    """

    def __init__(self, api_key: str = ""):
        key = api_key or os.environ.get("SILICONFLOW_KEY", "")
        super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1")

    async def embed(self, texts: List[str], model: str = "BAAI/bge-m3",
                    encoding_format: str = "float") -> Dict:
        """Fetch embedding vectors for ``texts``.

        Returns:
            On success: ``{"embeddings", "model", "usage", "latency_ms"}``,
            with one embedding per item of the response's ``data`` list.
            On failure: ``{"error", "latency_ms"}``. Never raises for
            request/response problems.
        """
        # Nothing to embed: answer locally instead of sending an empty request.
        if not texts:
            return {"embeddings": [], "model": model, "usage": {}, "latency_ms": 0}

        # Imported lazily so the module can be imported without httpx installed.
        import httpx

        started = time.time()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": model,
            "input": texts,
            "encoding_format": encoding_format,
        }
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(
                    f"{self.base_url}/embeddings",
                    headers=headers, json=body,
                )
            latency = round((time.time() - started) * 1000)
            if response.status_code != 200:
                return {
                    "error": f"HTTP {response.status_code}: {response.text[:200]}",
                    "latency_ms": latency,
                }
            data = response.json()
            return {
                "embeddings": [item["embedding"] for item in data["data"]],
                "model": data.get("model", model),
                # "usage" may be missing or null; normalise to a dict.
                "usage": data.get("usage") or {},
                "latency_ms": latency,
            }
        except Exception as exc:
            # Deliberately broad: every failure becomes an error dict. Some
            # exceptions (e.g. timeouts) can stringify to "", so fall back to
            # repr() to keep "error" truthy for callers that test it.
            latency = round((time.time() - started) * 1000)
            return {"error": str(exc) or repr(exc), "latency_ms": latency}
class SiliconFlowRerankerAdapter(ModelAdapter):
    """SiliconFlow reranker adapter (POSTs to ``{base_url}/rerank``).

    The key falls back to the SILICONFLOW_KEY environment variable.
    """

    def __init__(self, api_key: str = ""):
        key = api_key or os.environ.get("SILICONFLOW_KEY", "")
        super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1")

    async def rerank(self, query: str, documents: List[str],
                     model: str = "BAAI/bge-reranker-v2-m3", top_n: int = 5) -> Dict:
        """Rerank ``documents`` against ``query``.

        ``top_n`` is capped at ``len(documents)`` before being sent.

        Returns:
            On success: ``{"results", "model", "latency_ms"}`` where
            ``results`` is the response's ``results`` list.
            On failure: ``{"error", "latency_ms"}``. Never raises for
            request/response problems.
        """
        # Nothing to rank: answer locally instead of sending top_n=0 upstream.
        if not documents:
            return {"results": [], "model": model, "latency_ms": 0}

        # Imported lazily so the module can be imported without httpx installed.
        import httpx

        started = time.time()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": model,
            "query": query,
            "documents": documents,
            "top_n": min(top_n, len(documents)),
            "return_documents": True,
        }
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(
                    f"{self.base_url}/rerank",
                    headers=headers, json=body,
                )
            latency = round((time.time() - started) * 1000)
            if response.status_code != 200:
                return {
                    "error": f"HTTP {response.status_code}: {response.text[:200]}",
                    "latency_ms": latency,
                }
            data = response.json()
            return {
                # "results" may be missing or null; normalise to a list.
                "results": data.get("results") or [],
                "model": data.get("model", model),
                "latency_ms": latency,
            }
        except Exception as exc:
            # Deliberately broad: every failure becomes an error dict. Some
            # exceptions (e.g. timeouts) can stringify to "", so fall back to
            # repr() to keep "error" truthy for callers that test it.
            latency = round((time.time() - started) * 1000)
            return {"error": str(exc) or repr(exc), "latency_ms": latency}
# Adapter factory
def create_adapter(provider: str, api_key: str = "") -> ModelAdapter:
    """Instantiate the chat adapter registered under ``provider``.

    Args:
        provider: one of "zhipu", "nvidia_nim", "siliconflow".
        api_key: passed to the adapter's constructor.

    Raises:
        ValueError: if ``provider`` is not a registered name.
    """
    registry = {
        "zhipu": ZhipuAdapter,
        "nvidia_nim": NvidiaNimAdapter,
        "siliconflow": SiliconFlowAdapter,
    }
    adapter_cls = registry.get(provider)
    if not adapter_cls:
        raise ValueError(f"未知提供商: {provider}, 可选: {list(registry)}")
    return adapter_cls(api_key)
if __name__ == "__main__":
    import asyncio

    async def test():
        """Smoke-test the Zhipu and NIM adapters; skip any without an API key."""
        print("=== 模型适配器测试 ===\n")

        # Zhipu
        zhipu = ZhipuAdapter()
        if not zhipu.api_key:
            print("智谱: 无API Key, 跳过")
        else:
            reply = await zhipu.call(
                messages=[{"role": "user", "content": "1+1=? 只回复OK"}],
                model="glm-4-flash-250414",
                max_tokens=10,
            )
            print(f"智谱 GLM-4-Flash: {reply.text[:50]} ({reply.latency_ms}ms)")
            if reply.error:
                print(f" 错误: {reply.error}")

        # NVIDIA NIM
        nim = NvidiaNimAdapter()
        if not nim.api_key:
            print("NIM: 无API Key, 跳过")
        else:
            reply = await nim.call(
                messages=[{"role": "user", "content": "1+1=? 只回复OK"}],
                model="meta/llama-4-maverick-17b-128e-instruct",
                max_tokens=10,
            )
            print(f"NIM Llama-4: {reply.text[:50]} ({reply.latency_ms}ms)")
            if reply.error:
                print(f" 错误: {reply.error}")

        print("\n✅ 适配器测试完成")

    asyncio.run(test())