#!/usr/bin/env python3 """ 虫群模型适配器 — 统一不同API的调用接口 每个提供商一个适配器,暴露统一的 call(prompt, tools) 方法 """ import os import time import json from typing import Dict, List, Optional, Any from dataclasses import dataclass @dataclass class CallResult: """模型调用结果""" text: str # 回复文本 tool_calls: List[Dict] # 工具调用列表 model: str # 实际使用的模型 provider: str # 提供商 latency_ms: float # 延迟(ms) tokens_in: int = 0 # 输入token数 tokens_out: int = 0 # 输出token数 thinking: str = "" # 思维链内容 error: str = "" # 错误信息 class ModelAdapter: """模型调用适配器基类""" def __init__(self, provider: str, api_key: str = "", base_url: str = ""): self.provider = provider self.api_key = api_key self.base_url = base_url async def call(self, messages: List[Dict], model: str = "", tools: List[Dict] = None, max_tokens: int = 500, temperature: float = 0.7) -> CallResult: raise NotImplementedError class ZhipuAdapter(ModelAdapter): """智谱AI适配器""" def __init__(self, api_key: str = ""): key = api_key or os.environ.get("GLM_API_KEY", "") super().__init__("zhipu", key, "https://open.bigmodel.cn/api/paas/v4") async def call(self, messages: List[Dict], model: str = "glm-4.7-flash", tools: List[Dict] = None, max_tokens: int = 500, temperature: float = 0.7) -> CallResult: import httpx t0 = time.time() headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } body = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } if tools: body["tools"] = tools body["tool_choice"] = "auto" try: async with httpx.AsyncClient(timeout=60) as client: r = await client.post( f"{self.base_url}/chat/completions", headers=headers, json=body ) ms = round((time.time() - t0) * 1000) if r.status_code != 200: return CallResult(text="", tool_calls=[], model=model, provider=self.provider, latency_ms=ms, error=f"HTTP {r.status_code}: {r.text[:200]}") data = r.json() choice = data["choices"][0] msg = choice["message"] # 提取工具调用 tool_calls = [] if "tool_calls" in msg and msg["tool_calls"]: for tc in msg["tool_calls"]: tool_calls.append({ "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] }) # 提取思维链 thinking = "" if "reasoning_content" in msg: thinking = msg["reasoning_content"] usage = data.get("usage", {}) return CallResult( text=msg.get("content", "") or "", tool_calls=tool_calls, model=model, provider=self.provider, latency_ms=ms, tokens_in=usage.get("prompt_tokens", 0), tokens_out=usage.get("completion_tokens", 0), thinking=thinking ) except Exception as e: ms = round((time.time() - t0) * 1000) return CallResult(text="", tool_calls=[], model=model, provider=self.provider, latency_ms=ms, error=str(e)) class NvidiaNimAdapter(ModelAdapter): """NVIDIA NIM适配器""" def __init__(self, api_key: str = ""): key = api_key or os.environ.get("NVIDIA_KEY", "") super().__init__("nvidia_nim", key, "https://integrate.api.nvidia.com/v1") async def call(self, messages: List[Dict], model: str = "meta/llama-4-maverick-17b-128e-instruct", tools: List[Dict] = None, max_tokens: int = 500, temperature: float = 0.7) -> CallResult: import httpx t0 = time.time() headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } body = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } if tools: body["tools"] = tools body["tool_choice"] = "auto" try: async with httpx.AsyncClient(timeout=35) as client: r = await client.post( f"{self.base_url}/chat/completions", headers=headers, json=body ) ms = round((time.time() - t0) * 1000) if r.status_code != 200: return CallResult(text="", tool_calls=[], model=model, provider=self.provider, latency_ms=ms, error=f"HTTP {r.status_code}: {r.text[:200]}") data = r.json() choice = data["choices"][0] msg = choice["message"] tool_calls = [] if "tool_calls" in msg and msg["tool_calls"]: for tc in msg["tool_calls"]: tool_calls.append({ "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] }) thinking = "" if "reasoning_content" in msg: thinking = msg["reasoning_content"] usage = data.get("usage", {}) return CallResult( text=msg.get("content", "") or "", tool_calls=tool_calls, model=model, provider=self.provider, latency_ms=ms, tokens_in=usage.get("prompt_tokens", 0), tokens_out=usage.get("completion_tokens", 0), thinking=thinking ) except Exception as e: ms = round((time.time() - t0) * 1000) return CallResult(text="", tool_calls=[], model=model, provider=self.provider, latency_ms=ms, error=str(e)) class SiliconFlowAdapter(ModelAdapter): """硅基流动适配器""" def __init__(self, api_key: str = ""): key = api_key or os.environ.get("SILICONFLOW_KEY", "") super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1") async def call(self, messages: List[Dict], model: str = "Qwen/Qwen2.5-7B-Instruct", tools: List[Dict] = None, max_tokens: int = 500, temperature: float = 0.7) -> CallResult: import httpx t0 = time.time() headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } body = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } if tools: body["tools"] = tools body["tool_choice"] = "auto" try: async with httpx.AsyncClient(timeout=30) as client: r = await client.post( f"{self.base_url}/chat/completions", headers=headers, json=body ) ms = round((time.time() - t0) * 1000) if r.status_code != 200: return CallResult(text="", tool_calls=[], model=model, provider=self.provider, latency_ms=ms, error=f"HTTP {r.status_code}: {r.text[:200]}") data = r.json() choice = data["choices"][0] msg = choice["message"] tool_calls = [] if "tool_calls" in msg and msg["tool_calls"]: for tc in msg["tool_calls"]: tool_calls.append({ "name": tc["function"]["name"], "arguments": tc["function"]["arguments"] }) usage = data.get("usage", {}) return CallResult( text=msg.get("content", "") or "", tool_calls=tool_calls, model=model, provider=self.provider, latency_ms=ms, tokens_in=usage.get("prompt_tokens", 0), tokens_out=usage.get("completion_tokens", 0) ) except Exception as e: ms = round((time.time() - t0) * 1000) return CallResult(text="", tool_calls=[], model=model, provider=self.provider, latency_ms=ms, error=str(e)) class SiliconFlowEmbeddingAdapter(ModelAdapter): """硅基流动嵌入模型适配器""" def __init__(self, api_key: str = ""): key = api_key or os.environ.get("SILICONFLOW_KEY", "") super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1") async def embed(self, texts: List[str], model: str = "BAAI/bge-m3", encoding_format: str = "float") -> Dict: """获取文本嵌入向量""" import httpx t0 = time.time() headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } body = { "model": model, "input": texts, "encoding_format": encoding_format, } try: async with httpx.AsyncClient(timeout=30) as client: r = await client.post( f"{self.base_url}/embeddings", headers=headers, json=body ) ms = round((time.time() - t0) * 1000) if r.status_code != 200: return {"error": f"HTTP {r.status_code}: {r.text[:200]}", "latency_ms": ms} data = r.json() return { "embeddings": [d["embedding"] for d in data["data"]], "model": data.get("model", model), "usage": data.get("usage", {}), "latency_ms": ms } except Exception as e: ms = round((time.time() - t0) * 1000) return {"error": str(e), "latency_ms": ms} class SiliconFlowRerankerAdapter(ModelAdapter): """硅基流动重排序模型适配器""" def __init__(self, api_key: str = ""): key = api_key or os.environ.get("SILICONFLOW_KEY", "") super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1") async def rerank(self, query: str, documents: List[str], model: str = "BAAI/bge-reranker-v2-m3", top_n: int = 5) -> Dict: """对文档进行重排序""" import httpx t0 = time.time() headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } body = { "model": model, "query": query, "documents": documents, "top_n": min(top_n, len(documents)), "return_documents": True, } try: async with httpx.AsyncClient(timeout=30) as client: r = await client.post( f"{self.base_url}/rerank", headers=headers, json=body ) ms = round((time.time() - t0) * 1000) if r.status_code != 200: return {"error": f"HTTP {r.status_code}: {r.text[:200]}", "latency_ms": ms} data = r.json() return { "results": data.get("results", []), "model": data.get("model", model), "latency_ms": ms } except Exception as e: ms = round((time.time() - t0) * 1000) return {"error": str(e), "latency_ms": ms} # 适配器工厂 def create_adapter(provider: str, api_key: str = "") -> ModelAdapter: """根据提供商创建适配器""" adapters = { "zhipu": ZhipuAdapter, "nvidia_nim": NvidiaNimAdapter, "siliconflow": SiliconFlowAdapter, } cls = adapters.get(provider) if cls: return cls(api_key) raise ValueError(f"未知提供商: {provider}, 可选: {list(adapters.keys())}") if __name__ == "__main__": import asyncio async def test(): print("=== 模型适配器测试 ===\n") # 测试智谱 zhipu = ZhipuAdapter() if zhipu.api_key: r = await zhipu.call( messages=[{"role": "user", "content": "1+1=? 只回复OK"}], model="glm-4-flash-250414", max_tokens=10 ) print(f"智谱 GLM-4-Flash: {r.text[:50]} ({r.latency_ms}ms)") if r.error: print(f" 错误: {r.error}") else: print("智谱: 无API Key, 跳过") # 测试NIM nim = NvidiaNimAdapter() if nim.api_key: r = await nim.call( messages=[{"role": "user", "content": "1+1=? 只回复OK"}], model="meta/llama-4-maverick-17b-128e-instruct", max_tokens=10 ) print(f"NIM Llama-4: {r.text[:50]} ({r.latency_ms}ms)") if r.error: print(f" 错误: {r.error}") else: print("NIM: 无API Key, 跳过") print("\n✅ 适配器测试完成") asyncio.run(test())