# visual_memory/agent/llm.py
# (uploaded via huggingface_hub by kdemon1011, commit 599c9bd)
"""
LLM abstraction layer using LiteLLM.
Supports any model LiteLLM supports — switch with a single string:
- OpenAI: "gpt-4o", "gpt-5.4", "o3-pro"
- Anthropic: "claude-opus-4-6", "claude-sonnet-4-6"
- Local: "ollama/llama3", "ollama/mistral"
- And 100+ more providers
API keys are read from environment variables (loaded from root .env):
OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.
Usage:
from agent.llm import LLMClient
llm = LLMClient(model="gpt-4o")
response = llm.chat(
messages=[{"role": "user", "content": "Hello"}],
tools=[...],
)
"""
import json
import logging
from typing import Any, Dict, List, Optional
import litellm
logger = logging.getLogger(__name__)
class LLMClient:
    """
    Thin wrapper around LiteLLM for consistent tool-calling across providers.

    The same code works whether you're hitting GPT-4o, Claude, or a local
    Ollama model — LiteLLM handles the translation.
    """

    # Bare model names that reject temperature != 1.0 and need a larger
    # completion budget for their hidden reasoning tokens.
    _REASONING_MODELS = {"o3-pro", "o3-mini", "o3", "o1", "o1-mini", "o1-pro", "gpt-5"}

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        max_tokens: int = 1024,
    ):
        """
        Args:
            model: Any LiteLLM model string, optionally provider-prefixed
                (e.g. "gpt-4o", "openai/o3-mini", "ollama/llama3").
            temperature: Sampling temperature. Forced to 1.0 for reasoning
                models, which reject other values.
            max_tokens: Completion token budget. Raised to at least 4096 for
                reasoning models, whose hidden reasoning tokens share it.
        """
        self.model = model
        self.usage_log: List[Dict[str, int]] = []
        # Compare the bare model name so provider-prefixed strings such as
        # "openai/o3-mini" are also recognized as reasoning models
        # (a plain `model in _REASONING_MODELS` check would miss them).
        base_model = model.rsplit("/", 1)[-1]
        if base_model in self._REASONING_MODELS:
            self.temperature = 1.0
            self.max_tokens = max(max_tokens, 4096)
            if temperature != 1.0:
                logger.info(
                    "Model %s requires temperature=1.0, overriding from %s",
                    model,
                    temperature,
                )
        else:
            self.temperature = temperature
            self.max_tokens = max_tokens

    def chat(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Send messages to the LLM and get a response.

        Args:
            messages: Conversation history in OpenAI format
            tools: Optional list of tools in OpenAI function-calling format

        Returns:
            LiteLLM ModelResponse (same shape as OpenAI ChatCompletion).
        """
        kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if tools:
            kwargs["tools"] = tools
            # Let the model decide whether to call a tool or answer in text.
            kwargs["tool_choice"] = "auto"
        logger.debug(
            "LLM request: model=%s, messages=%d, tools=%d",
            self.model,
            len(messages),
            len(tools or []),
        )
        response = litellm.completion(**kwargs)
        logger.debug("LLM response: finish_reason=%s", response.choices[0].finish_reason)
        # Some providers omit usage; only record it when present.
        if hasattr(response, "usage") and response.usage:
            self.usage_log.append({
                "prompt_tokens": getattr(response.usage, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(response.usage, "completion_tokens", 0) or 0,
            })
        return response

    @property
    def total_usage(self) -> Dict[str, int]:
        """Aggregate token usage across all calls."""
        return {
            "prompt_tokens": sum(u["prompt_tokens"] for u in self.usage_log),
            "completion_tokens": sum(u["completion_tokens"] for u in self.usage_log),
            "total_calls": len(self.usage_log),
        }

    @staticmethod
    def extract_tool_calls(response) -> List[Dict[str, Any]]:
        """
        Extract tool calls from an LLM response.

        Returns:
            A list of {"id", "name", "arguments"} dicts, with "arguments"
            parsed into a dict; empty list when the message has no tool calls.

        Raises:
            json.JSONDecodeError: if a provider returns malformed argument JSON.
        """
        choice = response.choices[0]
        if not choice.message.tool_calls:
            return []
        calls = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            # Providers may return arguments pre-parsed or as a JSON string.
            if isinstance(args, str):
                args = json.loads(args)
            calls.append({
                "id": tc.id,
                "name": tc.function.name,
                "arguments": args,
            })
        return calls

    @staticmethod
    def get_text_response(response) -> Optional[str]:
        """Extract plain text content from an LLM response (None if absent)."""
        choice = response.choices[0]
        return choice.message.content