"""
LLM abstraction layer using LiteLLM.
Supports any model LiteLLM supports — switch with a single string:
- OpenAI: "gpt-4o", "gpt-5.4", "o3-pro"
- Anthropic: "claude-opus-4-6", "claude-sonnet-4-6"
- Local: "ollama/llama3", "ollama/mistral"
- And 100+ more providers
API keys are read from environment variables (loaded from root .env):
OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.
Usage:
from agent.llm import LLMClient
llm = LLMClient(model="gpt-4o")
response = llm.chat(
messages=[{"role": "user", "content": "Hello"}],
tools=[...],
)
"""
import json
import logging
from typing import Any, Dict, List, Optional
import litellm
logger = logging.getLogger(__name__)
class LLMClient:
    """
    Thin wrapper around LiteLLM for consistent tool-calling across providers.

    The same code works whether you're hitting GPT-4o, Claude, or a local
    Ollama model — LiteLLM handles the translation.
    """

    # Models that reject custom sampling parameters: they only accept
    # temperature=1.0 and need extra completion budget because hidden
    # reasoning tokens count against max_tokens.
    _REASONING_MODELS = {"o3-pro", "o3-mini", "o3", "o1", "o1-mini", "o1-pro", "gpt-5"}

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        max_tokens: int = 1024,
    ):
        """
        Args:
            model: Any LiteLLM model string (e.g. "gpt-4o", "ollama/llama3").
            temperature: Sampling temperature. Forced to 1.0 for reasoning
                models, which reject other values.
            max_tokens: Completion token budget. Raised to at least 4096 for
                reasoning models.
        """
        self.model = model
        # One entry appended per chat() call:
        # {"prompt_tokens": int, "completion_tokens": int}
        self.usage_log: List[Dict[str, int]] = []
        if model in self._REASONING_MODELS:
            self.temperature = 1.0
            self.max_tokens = max(max_tokens, 4096)
            if temperature != 1.0:
                # Lazy %-args: the message is only formatted if INFO is enabled.
                logger.info(
                    "Model %s requires temperature=1.0, overriding from %s",
                    model,
                    temperature,
                )
        else:
            self.temperature = temperature
            self.max_tokens = max_tokens

    def chat(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Send messages to the LLM and get a response.

        Args:
            messages: Conversation history in OpenAI format
            tools: Optional list of tools in OpenAI function-calling format

        Returns:
            LiteLLM ModelResponse (same shape as OpenAI ChatCompletion).
        """
        kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if tools:
            kwargs["tools"] = tools
            # Let the model decide whether to call a tool or answer in text.
            kwargs["tool_choice"] = "auto"
        logger.debug(
            "LLM request: model=%s, messages=%d, tools=%d",
            self.model,
            len(messages),
            len(tools or []),
        )
        response = litellm.completion(**kwargs)
        logger.debug(
            "LLM response: finish_reason=%s", response.choices[0].finish_reason
        )
        # Some providers omit usage; record it only when present.
        if hasattr(response, "usage") and response.usage:
            self.usage_log.append({
                "prompt_tokens": getattr(response.usage, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(response.usage, "completion_tokens", 0) or 0,
            })
        return response

    @property
    def total_usage(self) -> Dict[str, int]:
        """Aggregate token usage across all calls."""
        return {
            "prompt_tokens": sum(u["prompt_tokens"] for u in self.usage_log),
            "completion_tokens": sum(u["completion_tokens"] for u in self.usage_log),
            "total_calls": len(self.usage_log),
        }

    @staticmethod
    def extract_tool_calls(response) -> List[Dict[str, Any]]:
        """
        Extract tool calls from an LLM response.

        Returns:
            A list of {"id", "name", "arguments"} dicts, with arguments
            parsed from JSON strings into dicts when necessary. Empty list
            if the model made no tool calls.
        """
        choice = response.choices[0]
        if not choice.message.tool_calls:
            return []
        calls = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            # Most providers return arguments as a JSON string; some SDKs
            # deliver an already-parsed dict. Normalize to a dict either way.
            if isinstance(args, str):
                args = json.loads(args)
            calls.append({
                "id": tc.id,
                "name": tc.function.name,
                "arguments": args,
            })
        return calls

    @staticmethod
    def get_text_response(response) -> Optional[str]:
        """Extract plain text content from an LLM response (if any)."""
        choice = response.choices[0]
        return choice.message.content