Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """LLM client abstraction for calling LLM endpoints. | |
| Provides a generic RPC abstraction: point it at an endpoint/port, tell it the | |
| protocol, and it works. OpenAI-compatible API is the first implementation, | |
| covering OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, etc. | |
| Anthropic's native API is supported via ``AnthropicClient``. | |
| Usage: | |
| client = OpenAIClient("http://localhost", 8000, model="meta-llama/...") | |
| response = await client.complete("What is 2+2?") | |
| # Or use the factory for hosted APIs: | |
| client = create_llm_client("openai", model="gpt-4", api_key="sk-...") | |
| response = await client.complete_with_tools(messages, tools) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| from openai import AsyncOpenAI | |
| class ToolCall: | |
| """A single tool/function call returned by the model.""" | |
| id: str | |
| name: str | |
| args: dict[str, Any] | |
| class LLMResponse: | |
| """Normalized response from an LLM, with optional tool calls.""" | |
| content: str | |
| tool_calls: list[ToolCall] = field(default_factory=list) | |
| def to_message_dict(self) -> dict[str, Any]: | |
| """Convert to an OpenAI-format assistant message dict.""" | |
| msg: dict[str, Any] = {"role": "assistant", "content": self.content} | |
| if self.tool_calls: | |
| msg["tool_calls"] = [ | |
| { | |
| "id": tc.id, | |
| "type": "function", | |
| "function": { | |
| "name": tc.name, | |
| "arguments": json.dumps(tc.args), | |
| }, | |
| } | |
| for tc in self.tool_calls | |
| ] | |
| return msg | |
| class LLMClient(ABC): | |
| """Abstract base for LLM endpoint clients. | |
| Subclass and implement ``complete()`` for your protocol. | |
| Args: | |
| endpoint: The base URL of the LLM service (e.g. "http://localhost"). | |
| port: The port the service listens on. | |
| """ | |
| def __init__(self, endpoint: str, port: int): | |
| self.endpoint = endpoint | |
| self.port = port | |
| async def complete(self, prompt: str, **kwargs) -> str: | |
| """Send a prompt, return the text response. | |
| Args: | |
| prompt: The user prompt to send. | |
| **kwargs: Override default parameters (temperature, max_tokens, etc.). | |
| Returns: | |
| The model's text response. | |
| """ | |
| ... | |
| async def complete_with_tools( | |
| self, | |
| messages: list[dict[str, Any]], | |
| tools: list[dict[str, Any]], | |
| **kwargs: Any, | |
| ) -> LLMResponse: | |
| """Send messages with tool definitions, return a normalized response. | |
| Messages use OpenAI-format dicts (``{"role": "...", "content": "..."}``). | |
| Tools use MCP tool definitions; they are converted internally. | |
| Args: | |
| messages: Conversation history as OpenAI-format message dicts. | |
| tools: MCP tool definitions. | |
| **kwargs: Override default parameters (temperature, max_tokens, etc.). | |
| Returns: | |
| An ``LLMResponse`` with the model's text and any tool calls. | |
| """ | |
| raise NotImplementedError( | |
| f"{type(self).__name__} does not support tool calling" | |
| ) | |
| def base_url(self) -> str: | |
| """Construct base URL from endpoint and port.""" | |
| return f"{self.endpoint}:{self.port}" | |
| class OpenAIClient(LLMClient): | |
| """Client for OpenAI-compatible APIs. | |
| Works with: OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, | |
| or any endpoint that speaks the OpenAI chat completions format. | |
| Args: | |
| endpoint: The base URL (e.g. "http://localhost"). | |
| port: The port number. | |
| model: Model name to pass to the API. | |
| api_key: API key. Defaults to "not-needed" for local endpoints. | |
| system_prompt: Optional system message prepended to every request. | |
| temperature: Default sampling temperature. | |
| max_tokens: Default max tokens in the response. | |
| """ | |
| def __init__( | |
| self, | |
| endpoint: str, | |
| port: int, | |
| model: str, | |
| api_key: str | None = None, | |
| system_prompt: str | None = None, | |
| temperature: float = 0.0, | |
| max_tokens: int = 256, | |
| ): | |
| super().__init__(endpoint, port) | |
| self.model = model | |
| self.system_prompt = system_prompt | |
| self.temperature = temperature | |
| self.max_tokens = max_tokens | |
| self._client = AsyncOpenAI( | |
| base_url=f"{self.base_url}/v1", | |
| api_key=api_key if api_key is not None else "not-needed", | |
| ) | |
| async def complete(self, prompt: str, **kwargs) -> str: | |
| """Send a chat completion request. | |
| Args: | |
| prompt: The user message. | |
| **kwargs: Overrides for temperature, max_tokens. | |
| Returns: | |
| The assistant's response text. | |
| """ | |
| messages = [] | |
| if self.system_prompt: | |
| messages.append({"role": "system", "content": self.system_prompt}) | |
| messages.append({"role": "user", "content": prompt}) | |
| response = await self._client.chat.completions.create( | |
| model=self.model, | |
| messages=messages, | |
| temperature=kwargs.get("temperature", self.temperature), | |
| max_tokens=kwargs.get("max_tokens", self.max_tokens), | |
| ) | |
| return response.choices[0].message.content or "" | |
| async def complete_with_tools( | |
| self, | |
| messages: list[dict[str, Any]], | |
| tools: list[dict[str, Any]], | |
| **kwargs: Any, | |
| ) -> LLMResponse: | |
| create_kwargs: dict[str, Any] = { | |
| "model": self.model, | |
| "messages": messages, | |
| "temperature": kwargs.get("temperature", self.temperature), | |
| "max_tokens": kwargs.get("max_tokens", self.max_tokens), | |
| } | |
| openai_tools = _mcp_tools_to_openai(tools) | |
| if openai_tools: | |
| create_kwargs["tools"] = openai_tools | |
| response = await self._client.chat.completions.create(**create_kwargs) | |
| msg = response.choices[0].message | |
| tool_calls = [] | |
| if msg.tool_calls: | |
| for tc in msg.tool_calls: | |
| tool_calls.append( | |
| ToolCall( | |
| id=tc.id, | |
| name=tc.function.name, | |
| args=json.loads(tc.function.arguments), | |
| ) | |
| ) | |
| return LLMResponse(content=msg.content or "", tool_calls=tool_calls) | |
| class AnthropicClient(LLMClient): | |
| """Client for Anthropic's Messages API. | |
| Requires the ``anthropic`` package (lazy-imported at construction time). | |
| Args: | |
| endpoint: The base URL (e.g. "https://api.anthropic.com"). | |
| port: The port number. | |
| model: Model name (e.g. "claude-sonnet-4-20250514"). | |
| api_key: Anthropic API key. | |
| system_prompt: Optional system message prepended to every request. | |
| temperature: Default sampling temperature. | |
| max_tokens: Default max tokens in the response. | |
| """ | |
| def __init__( | |
| self, | |
| endpoint: str, | |
| port: int, | |
| model: str, | |
| api_key: str | None = None, | |
| system_prompt: str | None = None, | |
| temperature: float = 0.0, | |
| max_tokens: int = 256, | |
| ): | |
| super().__init__(endpoint, port) | |
| self.model = model | |
| self.system_prompt = system_prompt | |
| self.temperature = temperature | |
| self.max_tokens = max_tokens | |
| try: | |
| from anthropic import AsyncAnthropic | |
| except ImportError as exc: | |
| raise ImportError( | |
| "AnthropicClient requires the 'anthropic' package. " | |
| "Install it with: pip install anthropic" | |
| ) from exc | |
| self._client = AsyncAnthropic( | |
| base_url=self.base_url, | |
| api_key=api_key if api_key is not None else "not-needed", | |
| ) | |
| async def complete(self, prompt: str, **kwargs) -> str: | |
| create_kwargs: dict[str, Any] = { | |
| "model": self.model, | |
| "messages": [{"role": "user", "content": prompt}], | |
| "temperature": kwargs.get("temperature", self.temperature), | |
| "max_tokens": kwargs.get("max_tokens", self.max_tokens), | |
| } | |
| if self.system_prompt: | |
| create_kwargs["system"] = self.system_prompt | |
| response = await self._client.messages.create(**create_kwargs) | |
| return "".join(block.text for block in response.content if block.type == "text") | |
| async def complete_with_tools( | |
| self, | |
| messages: list[dict[str, Any]], | |
| tools: list[dict[str, Any]], | |
| **kwargs: Any, | |
| ) -> LLMResponse: | |
| system, anthropic_msgs = _openai_msgs_to_anthropic(messages) | |
| create_kwargs: dict[str, Any] = { | |
| "model": self.model, | |
| "messages": anthropic_msgs, | |
| "temperature": kwargs.get("temperature", self.temperature), | |
| "max_tokens": kwargs.get("max_tokens", self.max_tokens), | |
| } | |
| system_text = system or self.system_prompt | |
| if system_text: | |
| create_kwargs["system"] = system_text | |
| anthropic_tools = _mcp_tools_to_anthropic(tools) | |
| if anthropic_tools: | |
| create_kwargs["tools"] = anthropic_tools | |
| response = await self._client.messages.create(**create_kwargs) | |
| content = "" | |
| tool_calls = [] | |
| for block in response.content: | |
| if block.type == "text": | |
| content += block.text | |
| elif block.type == "tool_use": | |
| tool_calls.append( | |
| ToolCall(id=block.id, name=block.name, args=block.input) | |
| ) | |
| return LLMResponse(content=content, tool_calls=tool_calls) | |
| # --------------------------------------------------------------------------- | |
| # Factory | |
| # --------------------------------------------------------------------------- | |
| _HOSTED_PROVIDERS: dict[str, tuple[str, int, type[LLMClient]]] = { | |
| "openai": ("https://api.openai.com", 443, OpenAIClient), | |
| "anthropic": ("https://api.anthropic.com", 443, AnthropicClient), | |
| } | |
| def create_llm_client( | |
| provider: str, | |
| model: str, | |
| api_key: str, | |
| *, | |
| system_prompt: str | None = None, | |
| temperature: float = 0.0, | |
| max_tokens: int = 4096, | |
| ) -> LLMClient: | |
| """Create an LLM client for a hosted provider. | |
| Args: | |
| provider: Provider name ("openai" or "anthropic"). | |
| model: Model identifier. | |
| api_key: API key for the provider. | |
| system_prompt: Optional system message prepended to every request. | |
| temperature: Sampling temperature. | |
| max_tokens: Maximum tokens in the response. | |
| Returns: | |
| A configured ``LLMClient`` instance. | |
| """ | |
| key = provider.lower() | |
| if key not in _HOSTED_PROVIDERS: | |
| raise ValueError( | |
| f"Unsupported provider: {provider!r}. " | |
| f"Supported: {sorted(_HOSTED_PROVIDERS)}" | |
| ) | |
| endpoint, port, cls = _HOSTED_PROVIDERS[key] | |
| return cls( | |
| endpoint, | |
| port, | |
| model, | |
| api_key=api_key, | |
| system_prompt=system_prompt, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # MCP tool-schema helpers | |
| # --------------------------------------------------------------------------- | |
| def _clean_mcp_schema(schema: dict[str, Any]) -> dict[str, Any]: | |
| """Normalize an MCP tool ``inputSchema`` for LLM function-calling APIs.""" | |
| if not isinstance(schema, dict): | |
| return {"type": "object", "properties": {}, "required": []} | |
| # Shallow copy to avoid mutating the caller's schema dict. | |
| schema = dict(schema) | |
| if "oneOf" in schema: | |
| for option in schema["oneOf"]: | |
| if isinstance(option, dict) and option.get("type") == "object": | |
| schema = option | |
| break | |
| else: | |
| return {"type": "object", "properties": {}, "required": []} | |
| if "allOf" in schema: | |
| merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []} | |
| for sub in schema["allOf"]: | |
| if isinstance(sub, dict): | |
| if "properties" in sub: | |
| merged["properties"].update(sub["properties"]) | |
| if "required" in sub: | |
| merged["required"].extend(sub["required"]) | |
| schema = merged | |
| if "anyOf" in schema: | |
| for option in schema["anyOf"]: | |
| if isinstance(option, dict) and option.get("type") == "object": | |
| schema = option | |
| break | |
| else: | |
| return {"type": "object", "properties": {}, "required": []} | |
| schema.setdefault("type", "object") | |
| if schema.get("type") == "object" and "properties" not in schema: | |
| schema["properties"] = {} | |
| return schema | |
| def _mcp_tools_to_openai( | |
| mcp_tools: list[dict[str, Any]], | |
| ) -> list[dict[str, Any]]: | |
| """Convert MCP tool definitions to OpenAI function-calling format.""" | |
| result = [] | |
| for tool in mcp_tools: | |
| input_schema = tool.get( | |
| "inputSchema", {"type": "object", "properties": {}, "required": []} | |
| ) | |
| result.append( | |
| { | |
| "type": "function", | |
| "function": { | |
| "name": tool["name"], | |
| "description": tool.get("description", ""), | |
| "parameters": _clean_mcp_schema(input_schema), | |
| }, | |
| } | |
| ) | |
| return result | |
| def _mcp_tools_to_anthropic( | |
| mcp_tools: list[dict[str, Any]], | |
| ) -> list[dict[str, Any]]: | |
| """Convert MCP tool definitions to Anthropic tool format.""" | |
| result = [] | |
| for tool in mcp_tools: | |
| input_schema = tool.get( | |
| "inputSchema", {"type": "object", "properties": {}, "required": []} | |
| ) | |
| result.append( | |
| { | |
| "name": tool["name"], | |
| "description": tool.get("description", ""), | |
| "input_schema": _clean_mcp_schema(input_schema), | |
| } | |
| ) | |
| return result | |
| def _openai_msgs_to_anthropic( | |
| messages: list[dict[str, Any]], | |
| ) -> tuple[str, list[dict[str, Any]]]: | |
| """Convert OpenAI-format messages to Anthropic format. | |
| Returns ``(system_text, anthropic_messages)``. System-role messages are | |
| extracted and concatenated; tool-result messages are converted to | |
| Anthropic's ``tool_result`` content blocks inside user turns. | |
| """ | |
| system_parts: list[str] = [] | |
| anthropic_msgs: list[dict[str, Any]] = [] | |
| for msg in messages: | |
| role = msg["role"] | |
| if role == "system": | |
| system_parts.append(msg["content"]) | |
| elif role == "user": | |
| anthropic_msgs.append({"role": "user", "content": msg["content"]}) | |
| elif role == "assistant": | |
| if msg.get("tool_calls"): | |
| content: list[dict[str, Any]] = [] | |
| if msg.get("content"): | |
| content.append({"type": "text", "text": msg["content"]}) | |
| for tc in msg["tool_calls"]: | |
| args = tc["function"]["arguments"] | |
| if isinstance(args, str): | |
| args = json.loads(args) | |
| content.append( | |
| { | |
| "type": "tool_use", | |
| "id": tc["id"], | |
| "name": tc["function"]["name"], | |
| "input": args, | |
| } | |
| ) | |
| anthropic_msgs.append({"role": "assistant", "content": content}) | |
| else: | |
| anthropic_msgs.append( | |
| {"role": "assistant", "content": msg.get("content", "")} | |
| ) | |
| elif role == "tool": | |
| tool_result = { | |
| "type": "tool_result", | |
| "tool_use_id": msg["tool_call_id"], | |
| "content": msg["content"], | |
| } | |
| # Anthropic requires tool results in user turns; merge if possible. | |
| if ( | |
| anthropic_msgs | |
| and anthropic_msgs[-1]["role"] == "user" | |
| and isinstance(anthropic_msgs[-1]["content"], list) | |
| ): | |
| anthropic_msgs[-1]["content"].append(tool_result) | |
| else: | |
| anthropic_msgs.append({"role": "user", "content": [tool_result]}) | |
| system = "\n\n".join(system_parts) | |
| return system, anthropic_msgs | |