# chat_env/src/core/llm_client.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""LLM client abstraction for calling LLM endpoints.
Provides a generic RPC abstraction: point it at an endpoint/port, tell it the
protocol, and it works. OpenAI-compatible API is the first implementation,
covering OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, etc.
Anthropic's native API is supported via ``AnthropicClient``.
Usage:
client = OpenAIClient("http://localhost", 8000, model="meta-llama/...")
response = await client.complete("What is 2+2?")
# Or use the factory for hosted APIs:
client = create_llm_client("openai", model="gpt-4", api_key="sk-...")
response = await client.complete_with_tools(messages, tools)
"""
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from openai import AsyncOpenAI
@dataclass
class ToolCall:
    """A single tool/function call returned by the model.

    Built from an OpenAI ``tool_calls`` entry or an Anthropic ``tool_use``
    block, with the argument JSON already parsed into a dict.
    """

    # Provider-assigned call id; echoed back in tool-result messages.
    id: str
    # Tool/function name as declared in the tool definition.
    name: str
    # Parsed call arguments (already decoded from the wire JSON).
    args: dict[str, Any]
@dataclass
class LLMResponse:
    """Normalized response from an LLM, with optional tool calls."""

    content: str
    tool_calls: list[ToolCall] = field(default_factory=list)

    def to_message_dict(self) -> dict[str, Any]:
        """Serialize this response as an OpenAI-format assistant message.

        The ``tool_calls`` key is only present when there is at least one
        call; arguments are re-encoded to a JSON string as the OpenAI wire
        format requires.
        """
        message: dict[str, Any] = {"role": "assistant", "content": self.content}
        if not self.tool_calls:
            return message
        serialized: list[dict[str, Any]] = []
        for call in self.tool_calls:
            serialized.append(
                {
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.name,
                        "arguments": json.dumps(call.args),
                    },
                }
            )
        message["tool_calls"] = serialized
        return message
class LLMClient(ABC):
    """Abstract base class for clients that talk to an LLM endpoint.

    Implementations must provide ``complete()``; ``complete_with_tools()``
    is optional and raises ``NotImplementedError`` unless overridden.

    Args:
        endpoint: The base URL of the LLM service (e.g. "http://localhost").
        port: The port the service listens on.
    """

    def __init__(self, endpoint: str, port: int):
        self.endpoint = endpoint
        self.port = port

    @property
    def base_url(self) -> str:
        """The endpoint and port joined into a single base URL."""
        return f"{self.endpoint}:{self.port}"

    @abstractmethod
    async def complete(self, prompt: str, **kwargs) -> str:
        """Send a prompt and return the model's text response.

        Args:
            prompt: The user prompt to send.
            **kwargs: Per-call overrides for default parameters
                (temperature, max_tokens, etc.).

        Returns:
            The model's text response.
        """
        ...

    async def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Send messages with tool definitions; return a normalized response.

        Messages are OpenAI-format dicts (``{"role": "...", "content": "..."}``);
        tools are MCP tool definitions, converted internally by subclasses.

        Args:
            messages: Conversation history as OpenAI-format message dicts.
            tools: MCP tool definitions.
            **kwargs: Per-call overrides for default parameters.

        Returns:
            An ``LLMResponse`` with the model's text and any tool calls.

        Raises:
            NotImplementedError: If the subclass does not support tool calling.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support tool calling"
        )
class OpenAIClient(LLMClient):
    """Client for OpenAI-compatible APIs.

    Works with: OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API,
    or any endpoint that speaks the OpenAI chat completions format.

    Args:
        endpoint: The base URL (e.g. "http://localhost").
        port: The port number.
        model: Model name to pass to the API.
        api_key: API key. Defaults to "not-needed" for local endpoints.
        system_prompt: Optional system message prepended to every request.
        temperature: Default sampling temperature.
        max_tokens: Default max tokens in the response.
    """

    def __init__(
        self,
        endpoint: str,
        port: int,
        model: str,
        api_key: str | None = None,
        system_prompt: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 256,
    ):
        super().__init__(endpoint, port)
        self.model = model
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Local endpoints (vLLM, Ollama, ...) ignore the key, but the SDK
        # requires a non-empty value, hence the placeholder.
        self._client = AsyncOpenAI(
            base_url=f"{self.base_url}/v1",
            api_key=api_key if api_key is not None else "not-needed",
        )

    async def complete(self, prompt: str, **kwargs) -> str:
        """Send a single-turn chat completion request.

        Args:
            prompt: The user message.
            **kwargs: Overrides for temperature, max_tokens.

        Returns:
            The assistant's response text ("" if the API returned no content).
        """
        messages: list[dict[str, str]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = await self._client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=kwargs.get("temperature", self.temperature),
            max_tokens=kwargs.get("max_tokens", self.max_tokens),
        )
        return response.choices[0].message.content or ""

    async def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a tool-calling chat completion request.

        Args:
            messages: Conversation history as OpenAI-format message dicts.
            tools: MCP tool definitions (converted to OpenAI format).
            **kwargs: Overrides for temperature, max_tokens.

        Returns:
            An ``LLMResponse`` with the model's text and parsed tool calls.
        """
        create_kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }
        openai_tools = _mcp_tools_to_openai(tools)
        if openai_tools:
            create_kwargs["tools"] = openai_tools
        response = await self._client.chat.completions.create(**create_kwargs)
        msg = response.choices[0].message
        tool_calls = []
        if msg.tool_calls:
            for tc in msg.tool_calls:
                # Some OpenAI-compatible servers return "" (not "{}") for
                # tool calls with no arguments; json.loads("") would raise.
                tool_calls.append(
                    ToolCall(
                        id=tc.id,
                        name=tc.function.name,
                        args=json.loads(tc.function.arguments or "{}"),
                    )
                )
        return LLMResponse(content=msg.content or "", tool_calls=tool_calls)
class AnthropicClient(LLMClient):
    """Client for Anthropic's Messages API.

    Requires the ``anthropic`` package (lazy-imported at construction time).

    Args:
        endpoint: The base URL (e.g. "https://api.anthropic.com").
        port: The port number.
        model: Model name (e.g. "claude-sonnet-4-20250514").
        api_key: Anthropic API key.
        system_prompt: Optional system message prepended to every request.
        temperature: Default sampling temperature.
        max_tokens: Default max tokens in the response.
    """

    def __init__(
        self,
        endpoint: str,
        port: int,
        model: str,
        api_key: str | None = None,
        system_prompt: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 256,
    ):
        super().__init__(endpoint, port)
        self.model = model
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Import lazily so the package is only required when this client
        # is actually constructed.
        try:
            from anthropic import AsyncAnthropic
        except ImportError as exc:
            raise ImportError(
                "AnthropicClient requires the 'anthropic' package. "
                "Install it with: pip install anthropic"
            ) from exc
        self._client = AsyncAnthropic(
            base_url=self.base_url,
            api_key=api_key if api_key is not None else "not-needed",
        )

    async def complete(self, prompt: str, **kwargs) -> str:
        """Send a single-turn request; return the concatenated text blocks."""
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }
        if self.system_prompt:
            payload["system"] = self.system_prompt
        response = await self._client.messages.create(**payload)
        parts = [block.text for block in response.content if block.type == "text"]
        return "".join(parts)

    async def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        **kwargs: Any,
    ) -> LLMResponse:
        """Send a tool-calling request; return a normalized response."""
        system, converted = _openai_msgs_to_anthropic(messages)
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": converted,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }
        # A system message embedded in the history wins over the default.
        effective_system = system or self.system_prompt
        if effective_system:
            payload["system"] = effective_system
        tool_defs = _mcp_tools_to_anthropic(tools)
        if tool_defs:
            payload["tools"] = tool_defs
        response = await self._client.messages.create(**payload)
        text_parts: list[str] = []
        calls: list[ToolCall] = []
        for block in response.content:
            if block.type == "text":
                text_parts.append(block.text)
            elif block.type == "tool_use":
                calls.append(
                    ToolCall(id=block.id, name=block.name, args=block.input)
                )
        return LLMResponse(content="".join(text_parts), tool_calls=calls)
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
# Registry of hosted providers: name -> (endpoint, port, client class).
# Keys are the lowercase provider names accepted by create_llm_client().
_HOSTED_PROVIDERS: dict[str, tuple[str, int, type[LLMClient]]] = {
    "openai": ("https://api.openai.com", 443, OpenAIClient),
    "anthropic": ("https://api.anthropic.com", 443, AnthropicClient),
}
def create_llm_client(
    provider: str,
    model: str,
    api_key: str,
    *,
    system_prompt: str | None = None,
    temperature: float = 0.0,
    max_tokens: int = 4096,
) -> LLMClient:
    """Build a configured client for a hosted LLM provider.

    Args:
        provider: Provider name ("openai" or "anthropic"), case-insensitive.
        model: Model identifier.
        api_key: API key for the provider.
        system_prompt: Optional system message prepended to every request.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens in the response.

    Returns:
        A configured ``LLMClient`` instance.

    Raises:
        ValueError: If ``provider`` is not a known hosted provider.
    """
    normalized = provider.lower()
    if normalized not in _HOSTED_PROVIDERS:
        raise ValueError(
            f"Unsupported provider: {provider!r}. "
            f"Supported: {sorted(_HOSTED_PROVIDERS)}"
        )
    endpoint, port, client_cls = _HOSTED_PROVIDERS[normalized]
    return client_cls(
        endpoint,
        port,
        model,
        api_key=api_key,
        system_prompt=system_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    )
# ---------------------------------------------------------------------------
# MCP tool-schema helpers
# ---------------------------------------------------------------------------
def _clean_mcp_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Normalize an MCP tool ``inputSchema`` for LLM function-calling APIs."""
if not isinstance(schema, dict):
return {"type": "object", "properties": {}, "required": []}
# Shallow copy to avoid mutating the caller's schema dict.
schema = dict(schema)
if "oneOf" in schema:
for option in schema["oneOf"]:
if isinstance(option, dict) and option.get("type") == "object":
schema = option
break
else:
return {"type": "object", "properties": {}, "required": []}
if "allOf" in schema:
merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for sub in schema["allOf"]:
if isinstance(sub, dict):
if "properties" in sub:
merged["properties"].update(sub["properties"])
if "required" in sub:
merged["required"].extend(sub["required"])
schema = merged
if "anyOf" in schema:
for option in schema["anyOf"]:
if isinstance(option, dict) and option.get("type") == "object":
schema = option
break
else:
return {"type": "object", "properties": {}, "required": []}
schema.setdefault("type", "object")
if schema.get("type") == "object" and "properties" not in schema:
schema["properties"] = {}
return schema
def _mcp_tools_to_openai(
mcp_tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Convert MCP tool definitions to OpenAI function-calling format."""
result = []
for tool in mcp_tools:
input_schema = tool.get(
"inputSchema", {"type": "object", "properties": {}, "required": []}
)
result.append(
{
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
"parameters": _clean_mcp_schema(input_schema),
},
}
)
return result
def _mcp_tools_to_anthropic(
mcp_tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Convert MCP tool definitions to Anthropic tool format."""
result = []
for tool in mcp_tools:
input_schema = tool.get(
"inputSchema", {"type": "object", "properties": {}, "required": []}
)
result.append(
{
"name": tool["name"],
"description": tool.get("description", ""),
"input_schema": _clean_mcp_schema(input_schema),
}
)
return result
def _openai_msgs_to_anthropic(
messages: list[dict[str, Any]],
) -> tuple[str, list[dict[str, Any]]]:
"""Convert OpenAI-format messages to Anthropic format.
Returns ``(system_text, anthropic_messages)``. System-role messages are
extracted and concatenated; tool-result messages are converted to
Anthropic's ``tool_result`` content blocks inside user turns.
"""
system_parts: list[str] = []
anthropic_msgs: list[dict[str, Any]] = []
for msg in messages:
role = msg["role"]
if role == "system":
system_parts.append(msg["content"])
elif role == "user":
anthropic_msgs.append({"role": "user", "content": msg["content"]})
elif role == "assistant":
if msg.get("tool_calls"):
content: list[dict[str, Any]] = []
if msg.get("content"):
content.append({"type": "text", "text": msg["content"]})
for tc in msg["tool_calls"]:
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json.loads(args)
content.append(
{
"type": "tool_use",
"id": tc["id"],
"name": tc["function"]["name"],
"input": args,
}
)
anthropic_msgs.append({"role": "assistant", "content": content})
else:
anthropic_msgs.append(
{"role": "assistant", "content": msg.get("content", "")}
)
elif role == "tool":
tool_result = {
"type": "tool_result",
"tool_use_id": msg["tool_call_id"],
"content": msg["content"],
}
# Anthropic requires tool results in user turns; merge if possible.
if (
anthropic_msgs
and anthropic_msgs[-1]["role"] == "user"
and isinstance(anthropic_msgs[-1]["content"], list)
):
anthropic_msgs[-1]["content"].append(tool_result)
else:
anthropic_msgs.append({"role": "user", "content": [tool_result]})
system = "\n\n".join(system_parts)
return system, anthropic_msgs