# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM client abstraction for calling LLM endpoints.

Provides a generic RPC abstraction: point it at an endpoint/port, tell it the
protocol, and it works. OpenAI-compatible API is the first implementation,
covering OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, etc.
Anthropic's native API is supported via ``AnthropicClient``.

Usage:
    client = OpenAIClient("http://localhost", 8000, model="meta-llama/...")
    response = await client.complete("What is 2+2?")

    # Or use the factory for hosted APIs:
    client = create_llm_client("openai", model="gpt-4", api_key="sk-...")
    response = await client.complete_with_tools(messages, tools)
"""

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

from openai import AsyncOpenAI


@dataclass
class ToolCall:
    """A single tool/function call returned by the model."""

    id: str
    name: str
    args: dict[str, Any]


@dataclass
class LLMResponse:
    """Normalized response from an LLM, with optional tool calls."""

    content: str
    tool_calls: list[ToolCall] = field(default_factory=list)

    def to_message_dict(self) -> dict[str, Any]:
        """Convert to an OpenAI-format assistant message dict.

        Returns:
            A dict with ``role`` and ``content`` keys, plus a ``tool_calls``
            list (arguments JSON-encoded) when the response has tool calls.
        """
        msg: dict[str, Any] = {"role": "assistant", "content": self.content}
        if self.tool_calls:
            msg["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.name,
                        "arguments": json.dumps(tc.args),
                    },
                }
                for tc in self.tool_calls
            ]
        return msg


class LLMClient(ABC):
    """Abstract base for LLM endpoint clients.

    Subclass and implement ``complete()`` for your protocol.

    Args:
        endpoint: The base URL of the LLM service (e.g. "http://localhost").
        port: The port the service listens on.
    """

    def __init__(self, endpoint: str, port: int):
        self.endpoint = endpoint
        self.port = port

    @abstractmethod
    async def complete(self, prompt: str, **kwargs: Any) -> str:
        """Send a prompt, return the text response.

        Args:
            prompt: The user prompt to send.
            **kwargs: Override default parameters (temperature, max_tokens, etc.).

        Returns:
            The model's text response.
        """
        ...

    async def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Send messages with tool definitions, return a normalized response.

        Messages use OpenAI-format dicts (``{"role": "...", "content": "..."}``).
        Tools use MCP tool definitions; they are converted internally.

        Args:
            messages: Conversation history as OpenAI-format message dicts.
            tools: MCP tool definitions.
            **kwargs: Override default parameters (temperature, max_tokens, etc.).

        Returns:
            An ``LLMResponse`` with the model's text and any tool calls.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support tool calling"
        )

    @property
    def base_url(self) -> str:
        """Construct base URL from endpoint and port."""
        # Strip a trailing slash so "http://host/" does not become
        # "http://host/:8000".
        return f"{self.endpoint.rstrip('/')}:{self.port}"


class OpenAIClient(LLMClient):
    """Client for OpenAI-compatible APIs.

    Works with: OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API,
    or any endpoint that speaks the OpenAI chat completions format.

    Args:
        endpoint: The base URL (e.g. "http://localhost").
        port: The port number.
        model: Model name to pass to the API.
        api_key: API key. Defaults to "not-needed" for local endpoints.
        system_prompt: Optional system message prepended to every request.
        temperature: Default sampling temperature.
        max_tokens: Default max tokens in the response.
    """

    def __init__(
        self,
        endpoint: str,
        port: int,
        model: str,
        api_key: str | None = None,
        system_prompt: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 256,
    ):
        super().__init__(endpoint, port)
        self.model = model
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens

        self._client = AsyncOpenAI(
            base_url=f"{self.base_url}/v1",
            api_key=api_key if api_key is not None else "not-needed",
        )

    async def complete(self, prompt: str, **kwargs: Any) -> str:
        """Send a chat completion request.

        Args:
            prompt: The user message.
            **kwargs: Overrides for temperature, max_tokens.

        Returns:
            The assistant's response text.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = await self._client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=kwargs.get("temperature", self.temperature),
            max_tokens=kwargs.get("max_tokens", self.max_tokens),
        )
        return response.choices[0].message.content or ""

    async def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a chat completion request with tool definitions.

        If the client has a ``system_prompt`` and ``messages`` contains no
        system-role message, the prompt is prepended to the request (the
        caller's list is not modified).

        Args:
            messages: Conversation history as OpenAI-format message dicts.
            tools: MCP tool definitions; converted to OpenAI function format.
            **kwargs: Overrides for temperature, max_tokens.

        Returns:
            An ``LLMResponse`` with the assistant text and parsed tool calls.
        """
        request_messages = messages
        has_system = any(m.get("role") == "system" for m in messages)
        if self.system_prompt and not has_system:
            # Build a new list so the caller's conversation history is untouched.
            request_messages = [
                {"role": "system", "content": self.system_prompt},
                *messages,
            ]

        create_kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": request_messages,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }
        openai_tools = _mcp_tools_to_openai(tools)
        if openai_tools:
            # Only send "tools" when there is at least one definition.
            create_kwargs["tools"] = openai_tools

        response = await self._client.chat.completions.create(**create_kwargs)
        msg = response.choices[0].message

        tool_calls = [
            ToolCall(
                id=tc.id,
                name=tc.function.name,
                args=_parse_tool_arguments(tc.function.arguments),
            )
            for tc in (msg.tool_calls or [])
        ]
        return LLMResponse(content=msg.content or "", tool_calls=tool_calls)


class AnthropicClient(LLMClient):
    """Client for Anthropic's Messages API.

    Requires the ``anthropic`` package (lazy-imported at construction time).

    Args:
        endpoint: The base URL (e.g. "https://api.anthropic.com").
        port: The port number.
        model: Model name (e.g. "claude-sonnet-4-20250514").
        api_key: Anthropic API key.
        system_prompt: Optional system message prepended to every request.
        temperature: Default sampling temperature.
        max_tokens: Default max tokens in the response.

    Raises:
        ImportError: If the ``anthropic`` package is not installed.
    """

    def __init__(
        self,
        endpoint: str,
        port: int,
        model: str,
        api_key: str | None = None,
        system_prompt: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 256,
    ):
        super().__init__(endpoint, port)
        self.model = model
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens

        try:
            from anthropic import AsyncAnthropic
        except ImportError as exc:
            raise ImportError(
                "AnthropicClient requires the 'anthropic' package. "
                "Install it with: pip install anthropic"
            ) from exc

        self._client = AsyncAnthropic(
            base_url=self.base_url,
            api_key=api_key if api_key is not None else "not-needed",
        )

    async def complete(self, prompt: str, **kwargs: Any) -> str:
        """Send a single-turn request to the Messages API.

        Args:
            prompt: The user message.
            **kwargs: Overrides for temperature, max_tokens.

        Returns:
            The concatenated text of all ``text`` blocks in the response.
        """
        create_kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }
        if self.system_prompt:
            create_kwargs["system"] = self.system_prompt

        response = await self._client.messages.create(**create_kwargs)
        return "".join(
            block.text for block in response.content if block.type == "text"
        )

    async def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Send OpenAI-format messages plus tool definitions to the Messages API.

        System-role messages found in ``messages`` take precedence; the
        client's ``system_prompt`` is used only when there are none.

        Args:
            messages: Conversation history as OpenAI-format message dicts.
            tools: MCP tool definitions; converted to Anthropic tool format.
            **kwargs: Overrides for temperature, max_tokens.

        Returns:
            An ``LLMResponse`` with concatenated text and any ``tool_use`` calls.
        """
        system, anthropic_msgs = _openai_msgs_to_anthropic(messages)

        create_kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": anthropic_msgs,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }
        system_text = system or self.system_prompt
        if system_text:
            create_kwargs["system"] = system_text
        anthropic_tools = _mcp_tools_to_anthropic(tools)
        if anthropic_tools:
            create_kwargs["tools"] = anthropic_tools

        response = await self._client.messages.create(**create_kwargs)

        text_parts: list[str] = []
        tool_calls: list[ToolCall] = []
        for block in response.content:
            if block.type == "text":
                text_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append(
                    ToolCall(id=block.id, name=block.name, args=block.input)
                )
        return LLMResponse(content="".join(text_parts), tool_calls=tool_calls)


# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------

# Provider name -> (endpoint, port, client class) for hosted APIs.
_HOSTED_PROVIDERS: dict[str, tuple[str, int, type[LLMClient]]] = {
    "openai": ("https://api.openai.com", 443, OpenAIClient),
    "anthropic": ("https://api.anthropic.com", 443, AnthropicClient),
}


def create_llm_client(
    provider: str,
    model: str,
    api_key: str,
    *,
    system_prompt: str | None = None,
    temperature: float = 0.0,
    max_tokens: int = 4096,
) -> LLMClient:
    """Create an LLM client for a hosted provider.

    Args:
        provider: Provider name ("openai" or "anthropic"), case-insensitive.
        model: Model identifier.
        api_key: API key for the provider.
        system_prompt: Optional system message prepended to every request.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens in the response.

    Returns:
        A configured ``LLMClient`` instance.

    Raises:
        ValueError: If ``provider`` is not a supported provider name.
    """
    key = provider.lower()
    if key not in _HOSTED_PROVIDERS:
        raise ValueError(
            f"Unsupported provider: {provider!r}. "
            f"Supported: {sorted(_HOSTED_PROVIDERS)}"
        )
    endpoint, port, cls = _HOSTED_PROVIDERS[key]
    return cls(
        endpoint,
        port,
        model,
        api_key=api_key,
        system_prompt=system_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    )


# ---------------------------------------------------------------------------
# MCP tool-schema helpers
# ---------------------------------------------------------------------------


def _empty_object_schema() -> dict[str, Any]:
    """Return a fresh schema for an object that takes no properties."""
    return {"type": "object", "properties": {}, "required": []}


def _first_object_option(options: Any) -> dict[str, Any] | None:
    """Return a copy of the first ``type: object`` dict in *options*, if any."""
    for option in options:
        if isinstance(option, dict) and option.get("type") == "object":
            # Copy so later normalization never writes into the caller's schema.
            return dict(option)
    return None


def _parse_tool_arguments(raw: Any) -> Any:
    """Decode tool-call arguments as they appear in OpenAI-format messages.

    Args:
        raw: A JSON string, an already-decoded value, or ``None``.

    Returns:
        ``{}`` for ``None`` or an empty/whitespace-only string (treated as
        "no arguments"), the decoded JSON for a non-empty string, and *raw*
        unchanged for anything else.

    Raises:
        json.JSONDecodeError: If *raw* is a non-empty string of invalid JSON.
    """
    if raw is None:
        return {}
    if isinstance(raw, str):
        return json.loads(raw) if raw.strip() else {}
    return raw


def _clean_mcp_schema(schema: dict[str, Any]) -> dict[str, Any]:
    """Normalize an MCP tool ``inputSchema`` for LLM function-calling APIs.

    Combinators are flattened in the order ``oneOf``, ``allOf``, ``anyOf``:
    ``oneOf``/``anyOf`` are replaced by their first ``type: object`` option
    (or an empty object schema when there is none), and ``allOf`` is replaced
    by the merged ``properties``/``required`` of its members. The input
    schema and its nested dicts are never modified.

    Args:
        schema: The tool's ``inputSchema``; non-dict values are tolerated.

    Returns:
        A schema dict with ``type`` set (defaulting to ``"object"``) and, for
        object schemas, a ``properties`` key.
    """
    if not isinstance(schema, dict):
        return _empty_object_schema()

    # Shallow copy to avoid mutating the caller's schema dict.
    schema = dict(schema)

    if "oneOf" in schema:
        chosen = _first_object_option(schema["oneOf"])
        if chosen is None:
            return _empty_object_schema()
        schema = chosen

    if "allOf" in schema:
        merged = _empty_object_schema()
        for sub in schema["allOf"]:
            if not isinstance(sub, dict):
                continue
            merged["properties"].update(sub.get("properties", {}))
            for name in sub.get("required", []):
                # Keep "required" free of duplicates while preserving order.
                if name not in merged["required"]:
                    merged["required"].append(name)
        schema = merged

    if "anyOf" in schema:
        chosen = _first_object_option(schema["anyOf"])
        if chosen is None:
            return _empty_object_schema()
        schema = chosen

    schema.setdefault("type", "object")
    if schema.get("type") == "object" and "properties" not in schema:
        schema["properties"] = {}

    return schema


def _mcp_tools_to_openai(
    mcp_tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Convert MCP tool definitions to OpenAI function-calling format.

    Args:
        mcp_tools: MCP tool dicts; each must have a ``name`` and may have
            ``description`` and ``inputSchema``.

    Returns:
        One ``{"type": "function", "function": {...}}`` dict per tool.
    """
    return [
        {
            "type": "function",
            "function": {
                "name": tool["name"],
                "description": tool.get("description", ""),
                # A missing inputSchema normalizes to an empty object schema.
                "parameters": _clean_mcp_schema(
                    tool.get("inputSchema", _empty_object_schema())
                ),
            },
        }
        for tool in mcp_tools
    ]


def _mcp_tools_to_anthropic(
    mcp_tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Convert MCP tool definitions to Anthropic tool format.

    Args:
        mcp_tools: MCP tool dicts; each must have a ``name`` and may have
            ``description`` and ``inputSchema``.

    Returns:
        One ``{"name", "description", "input_schema"}`` dict per tool.
    """
    return [
        {
            "name": tool["name"],
            "description": tool.get("description", ""),
            # A missing inputSchema normalizes to an empty object schema.
            "input_schema": _clean_mcp_schema(
                tool.get("inputSchema", _empty_object_schema())
            ),
        }
        for tool in mcp_tools
    ]


def _openai_msgs_to_anthropic(
    messages: list[dict[str, Any]],
) -> tuple[str, list[dict[str, Any]]]:
    """Convert OpenAI-format messages to Anthropic format.

    Returns ``(system_text, anthropic_messages)``. System-role messages are
    extracted and concatenated; tool-result messages are converted to
    Anthropic's ``tool_result`` content blocks inside user turns.

    Args:
        messages: Conversation history as OpenAI-format message dicts.

    Returns:
        A tuple of the system text (system messages joined by blank lines,
        ``""`` when there are none) and the converted message list.
    """
    system_parts: list[str] = []
    anthropic_msgs: list[dict[str, Any]] = []

    for msg in messages:
        role = msg["role"]

        if role == "system":
            system_parts.append(msg["content"])

        elif role == "user":
            anthropic_msgs.append({"role": "user", "content": msg["content"]})

        elif role == "assistant":
            if msg.get("tool_calls"):
                content: list[dict[str, Any]] = []
                # Skip the text block entirely when the assistant text is empty.
                if msg.get("content"):
                    content.append({"type": "text", "text": msg["content"]})
                for tc in msg["tool_calls"]:
                    content.append(
                        {
                            "type": "tool_use",
                            "id": tc["id"],
                            "name": tc["function"]["name"],
                            "input": _parse_tool_arguments(
                                tc["function"]["arguments"]
                            ),
                        }
                    )
                anthropic_msgs.append({"role": "assistant", "content": content})
            else:
                anthropic_msgs.append(
                    {"role": "assistant", "content": msg.get("content", "")}
                )

        elif role == "tool":
            tool_result = {
                "type": "tool_result",
                "tool_use_id": msg["tool_call_id"],
                "content": msg["content"],
            }
            # Anthropic requires tool results in user turns; merge if possible.
            if (
                anthropic_msgs
                and anthropic_msgs[-1]["role"] == "user"
                and isinstance(anthropic_msgs[-1]["content"], list)
            ):
                anthropic_msgs[-1]["content"].append(tool_result)
            else:
                anthropic_msgs.append({"role": "user", "content": [tool_result]})

    system = "\n\n".join(system_parts)
    return system, anthropic_msgs