Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| 虫群模型适配器 — 统一不同API的调用接口 | |
| 每个提供商一个适配器,暴露统一的 call(prompt, tools) 方法 | |
| """ | |
| import os | |
| import time | |
| import json | |
| from typing import Dict, List, Optional, Any | |
| from dataclasses import dataclass | |
@dataclass
class CallResult:
    """Result of a single model call, as returned by every chat adapter.

    Constructed by keyword (``CallResult(text=..., tool_calls=..., ...)``), so it
    must be a dataclass to get a generated ``__init__``.
    """
    text: str  # reply text ("" when the call failed)
    tool_calls: List[Dict]  # tool calls, each {"name": ..., "arguments": ...}
    model: str  # model name the request was sent with
    provider: str  # provider identifier of the adapter that made the call
    latency_ms: float  # round-trip latency in milliseconds
    tokens_in: int = 0  # prompt token count reported by the API
    tokens_out: int = 0  # completion token count reported by the API
    thinking: str = ""  # reasoning content, when the API returns one
    error: str = ""  # error description; empty string on success
class ModelAdapter:
    """Base class for model-call adapters.

    Stores the provider identifier, API key and base URL. Chat adapters override
    :meth:`call`; the base implementation only raises ``NotImplementedError``.
    """

    def __init__(self, provider: str, api_key: str = "", base_url: str = ""):
        self.provider = provider  # provider identifier, e.g. "zhipu"
        self.api_key = api_key  # API key used by subclasses for authentication
        self.base_url = base_url  # API root; subclasses append endpoint paths

    async def call(self, messages: List[Dict], model: str = "",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> "CallResult":
        """Send ``messages`` to ``model`` and return a ``CallResult``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement call()"
        )
class ZhipuAdapter(ModelAdapter):
    """Zhipu AI adapter for the ``/chat/completions`` endpoint."""

    def __init__(self, api_key: str = ""):
        # An explicit key wins; otherwise fall back to the GLM_API_KEY env var.
        key = api_key or os.environ.get("GLM_API_KEY", "")
        super().__init__("zhipu", key, "https://open.bigmodel.cn/api/paas/v4")

    async def call(self, messages: List[Dict], model: str = "glm-4.7-flash",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> CallResult:
        """POST ``messages`` to ``/chat/completions`` and wrap the reply.

        Args:
            messages: chat messages, forwarded verbatim as ``messages``.
            model: model name sent in the request body.
            tools: optional tool definitions; when non-empty, ``tool_choice``
                is set to ``"auto"``.
            max_tokens: completion token limit.
            temperature: sampling temperature.

        Returns:
            A ``CallResult``. Failures (non-200 status, transport errors,
            malformed responses) are reported via ``error``, never raised.
        """
        import httpx
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        t0 = time.perf_counter()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        body = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if tools:
            body["tools"] = tools
            body["tool_choice"] = "auto"
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                r = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers, json=body
                )
                ms = round((time.perf_counter() - t0) * 1000)
            if r.status_code != 200:
                return CallResult(text="", tool_calls=[], model=model,
                                  provider=self.provider, latency_ms=ms,
                                  error=f"HTTP {r.status_code}: {r.text[:200]}")
            data = r.json()
            msg = data["choices"][0]["message"]
            tool_calls = [
                {"name": tc["function"]["name"],
                 "arguments": tc["function"]["arguments"]}
                for tc in (msg.get("tool_calls") or [])
            ]
            # Fields may be present but null; `or` normalizes them so a
            # successful call is not misreported as an error.
            usage = data.get("usage") or {}
            return CallResult(
                text=msg.get("content") or "",
                tool_calls=tool_calls,
                model=model,
                provider=self.provider,
                latency_ms=ms,
                tokens_in=usage.get("prompt_tokens") or 0,
                tokens_out=usage.get("completion_tokens") or 0,
                thinking=msg.get("reasoning_content") or "",
            )
        except Exception as e:
            ms = round((time.perf_counter() - t0) * 1000)
            # Some exceptions (e.g. certain timeouts) have an empty message;
            # fall back to the type name so `error` is never falsy on failure.
            return CallResult(text="", tool_calls=[], model=model,
                              provider=self.provider, latency_ms=ms,
                              error=str(e) or type(e).__name__)
class NvidiaNimAdapter(ModelAdapter):
    """NVIDIA NIM adapter for the ``/chat/completions`` endpoint."""

    def __init__(self, api_key: str = ""):
        # An explicit key wins; otherwise fall back to the NVIDIA_KEY env var.
        key = api_key or os.environ.get("NVIDIA_KEY", "")
        super().__init__("nvidia_nim", key, "https://integrate.api.nvidia.com/v1")

    async def call(self, messages: List[Dict], model: str = "meta/llama-4-maverick-17b-128e-instruct",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> CallResult:
        """POST ``messages`` to ``/chat/completions`` and wrap the reply.

        Args:
            messages: chat messages, forwarded verbatim as ``messages``.
            model: model name sent in the request body.
            tools: optional tool definitions; when non-empty, ``tool_choice``
                is set to ``"auto"``.
            max_tokens: completion token limit.
            temperature: sampling temperature.

        Returns:
            A ``CallResult``. Failures (non-200 status, transport errors,
            malformed responses) are reported via ``error``, never raised.
        """
        import httpx
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        t0 = time.perf_counter()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        body = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if tools:
            body["tools"] = tools
            body["tool_choice"] = "auto"
        try:
            async with httpx.AsyncClient(timeout=35) as client:
                r = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers, json=body
                )
                ms = round((time.perf_counter() - t0) * 1000)
            if r.status_code != 200:
                return CallResult(text="", tool_calls=[], model=model,
                                  provider=self.provider, latency_ms=ms,
                                  error=f"HTTP {r.status_code}: {r.text[:200]}")
            data = r.json()
            msg = data["choices"][0]["message"]
            tool_calls = [
                {"name": tc["function"]["name"],
                 "arguments": tc["function"]["arguments"]}
                for tc in (msg.get("tool_calls") or [])
            ]
            # Fields may be present but null; `or` normalizes them so a
            # successful call is not misreported as an error.
            usage = data.get("usage") or {}
            return CallResult(
                text=msg.get("content") or "",
                tool_calls=tool_calls,
                model=model,
                provider=self.provider,
                latency_ms=ms,
                tokens_in=usage.get("prompt_tokens") or 0,
                tokens_out=usage.get("completion_tokens") or 0,
                thinking=msg.get("reasoning_content") or "",
            )
        except Exception as e:
            ms = round((time.perf_counter() - t0) * 1000)
            # Some exceptions (e.g. certain timeouts) have an empty message;
            # fall back to the type name so `error` is never falsy on failure.
            return CallResult(text="", tool_calls=[], model=model,
                              provider=self.provider, latency_ms=ms,
                              error=str(e) or type(e).__name__)
class SiliconFlowAdapter(ModelAdapter):
    """SiliconFlow adapter for the ``/chat/completions`` endpoint."""

    def __init__(self, api_key: str = ""):
        # An explicit key wins; otherwise fall back to the SILICONFLOW_KEY env var.
        key = api_key or os.environ.get("SILICONFLOW_KEY", "")
        super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1")

    async def call(self, messages: List[Dict], model: str = "Qwen/Qwen2.5-7B-Instruct",
                   tools: Optional[List[Dict]] = None, max_tokens: int = 500,
                   temperature: float = 0.7) -> CallResult:
        """POST ``messages`` to ``/chat/completions`` and wrap the reply.

        Args:
            messages: chat messages, forwarded verbatim as ``messages``.
            model: model name sent in the request body.
            tools: optional tool definitions; when non-empty, ``tool_choice``
                is set to ``"auto"``.
            max_tokens: completion token limit.
            temperature: sampling temperature.

        Returns:
            A ``CallResult``. Failures (non-200 status, transport errors,
            malformed responses) are reported via ``error``, never raised.
        """
        import httpx
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        t0 = time.perf_counter()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        body = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if tools:
            body["tools"] = tools
            body["tool_choice"] = "auto"
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                r = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers, json=body
                )
                ms = round((time.perf_counter() - t0) * 1000)
            if r.status_code != 200:
                return CallResult(text="", tool_calls=[], model=model,
                                  provider=self.provider, latency_ms=ms,
                                  error=f"HTTP {r.status_code}: {r.text[:200]}")
            data = r.json()
            msg = data["choices"][0]["message"]
            tool_calls = [
                {"name": tc["function"]["name"],
                 "arguments": tc["function"]["arguments"]}
                for tc in (msg.get("tool_calls") or [])
            ]
            # Fields may be present but null; `or` normalizes them so a
            # successful call is not misreported as an error.
            usage = data.get("usage") or {}
            return CallResult(
                text=msg.get("content") or "",
                tool_calls=tool_calls,
                model=model,
                provider=self.provider,
                latency_ms=ms,
                tokens_in=usage.get("prompt_tokens") or 0,
                tokens_out=usage.get("completion_tokens") or 0,
                # Same reasoning extraction as the other chat adapters; stays
                # "" for models that do not return `reasoning_content`.
                thinking=msg.get("reasoning_content") or "",
            )
        except Exception as e:
            ms = round((time.perf_counter() - t0) * 1000)
            # Some exceptions (e.g. certain timeouts) have an empty message;
            # fall back to the type name so `error` is never falsy on failure.
            return CallResult(text="", tool_calls=[], model=model,
                              provider=self.provider, latency_ms=ms,
                              error=str(e) or type(e).__name__)
class SiliconFlowEmbeddingAdapter(ModelAdapter):
    """SiliconFlow adapter for the ``/embeddings`` endpoint."""

    def __init__(self, api_key: str = ""):
        # An explicit key wins; otherwise fall back to the SILICONFLOW_KEY env var.
        key = api_key or os.environ.get("SILICONFLOW_KEY", "")
        super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1")

    async def embed(self, texts: List[str], model: str = "BAAI/bge-m3",
                    encoding_format: str = "float") -> Dict:
        """Fetch embedding vectors for ``texts``.

        Returns:
            On success ``{"embeddings", "model", "usage", "latency_ms"}``,
            where ``embeddings[i]`` corresponds to ``texts[i]``.
            On failure ``{"error", "latency_ms"}``; nothing is raised.
        """
        import httpx
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        t0 = time.perf_counter()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        body = {
            "model": model,
            "input": texts,
            "encoding_format": encoding_format,
        }
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                r = await client.post(
                    f"{self.base_url}/embeddings",
                    headers=headers, json=body
                )
                ms = round((time.perf_counter() - t0) * 1000)
            if r.status_code != 200:
                return {"error": f"HTTP {r.status_code}: {r.text[:200]}", "latency_ms": ms}
            data = r.json()
            # Order by each item's `index` (falling back to response order) so
            # embeddings line up with the input texts.
            ordered = sorted(
                enumerate(data["data"]),
                key=lambda pair: pair[1].get("index", pair[0]),
            )
            return {
                "embeddings": [item["embedding"] for _, item in ordered],
                "model": data.get("model", model),
                "usage": data.get("usage") or {},
                "latency_ms": ms
            }
        except Exception as e:
            ms = round((time.perf_counter() - t0) * 1000)
            # Guard against exceptions whose message is empty.
            return {"error": str(e) or type(e).__name__, "latency_ms": ms}
class SiliconFlowRerankerAdapter(ModelAdapter):
    """SiliconFlow adapter for the ``/rerank`` endpoint."""

    def __init__(self, api_key: str = ""):
        # An explicit key wins; otherwise fall back to the SILICONFLOW_KEY env var.
        key = api_key or os.environ.get("SILICONFLOW_KEY", "")
        super().__init__("siliconflow", key, "https://api.siliconflow.cn/v1")

    async def rerank(self, query: str, documents: List[str],
                     model: str = "BAAI/bge-reranker-v2-m3", top_n: int = 5) -> Dict:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            query: the search query.
            documents: candidate documents; ``top_n`` is capped at their count.
            model: reranker model name.
            top_n: maximum number of results requested.

        Returns:
            On success ``{"results", "model", "latency_ms"}``.
            On failure ``{"error", "latency_ms"}``; nothing is raised.
        """
        if not documents:
            # Nothing to rank. Skip the request instead of sending top_n=0.
            return {"results": [], "model": model, "latency_ms": 0}
        import httpx
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        t0 = time.perf_counter()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        body = {
            "model": model,
            "query": query,
            "documents": documents,
            "top_n": min(top_n, len(documents)),
            "return_documents": True,
        }
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                r = await client.post(
                    f"{self.base_url}/rerank",
                    headers=headers, json=body
                )
                ms = round((time.perf_counter() - t0) * 1000)
            if r.status_code != 200:
                return {"error": f"HTTP {r.status_code}: {r.text[:200]}", "latency_ms": ms}
            data = r.json()
            return {
                "results": data.get("results", []),
                "model": data.get("model", model),
                "latency_ms": ms
            }
        except Exception as e:
            ms = round((time.perf_counter() - t0) * 1000)
            # Guard against exceptions whose message is empty.
            return {"error": str(e) or type(e).__name__, "latency_ms": ms}
| # 适配器工厂 | |
def create_adapter(provider: str, api_key: str = "") -> ModelAdapter:
    """Instantiate the chat adapter registered under ``provider``.

    Args:
        provider: one of ``"zhipu"``, ``"nvidia_nim"``, ``"siliconflow"``.
        api_key: forwarded to the adapter constructor; empty means the
            adapter falls back to its environment variable.

    Raises:
        ValueError: if ``provider`` is not registered.
    """
    registry = {
        "zhipu": ZhipuAdapter,
        "nvidia_nim": NvidiaNimAdapter,
        "siliconflow": SiliconFlowAdapter,
    }
    if provider not in registry:
        raise ValueError(f"未知提供商: {provider}, 可选: {list(registry.keys())}")
    adapter_cls = registry[provider]
    return adapter_cls(api_key)
if __name__ == "__main__":
    import asyncio

    async def test():
        """Smoke-test each chat adapter whose API key is configured."""
        print("=== 模型适配器测试 ===\n")
        # (adapter class, model to probe, result label, skip label)
        probes = [
            (ZhipuAdapter, "glm-4-flash-250414", "智谱 GLM-4-Flash", "智谱"),
            (NvidiaNimAdapter, "meta/llama-4-maverick-17b-128e-instruct",
             "NIM Llama-4", "NIM"),
        ]
        for adapter_cls, model_name, label, short_label in probes:
            adapter = adapter_cls()
            if not adapter.api_key:
                print(f"{short_label}: 无API Key, 跳过")
                continue
            result = await adapter.call(
                messages=[{"role": "user", "content": "1+1=? 只回复OK"}],
                model=model_name,
                max_tokens=10,
            )
            print(f"{label}: {result.text[:50]} ({result.latency_ms}ms)")
            if result.error:
                print(f" 错误: {result.error}")
        print("\n✅ 适配器测试完成")

    asyncio.run(test())