File size: 4,082 Bytes
15503f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599c9bd
15503f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599c9bd
 
 
 
 
 
15503f9
 
599c9bd
 
 
 
 
 
 
 
 
15503f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
LLM abstraction layer using LiteLLM.

Supports any model LiteLLM supports — switch with a single string:
  - OpenAI:     "gpt-4o", "gpt-5.4", "o3-pro"
  - Anthropic:  "claude-opus-4-6", "claude-sonnet-4-6"
  - Local:      "ollama/llama3", "ollama/mistral"
  - And 100+ more providers

API keys are read from environment variables (loaded from root .env):
  OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.

Usage:
    from agent.llm import LLMClient

    llm = LLMClient(model="gpt-4o")
    response = llm.chat(
        messages=[{"role": "user", "content": "Hello"}],
        tools=[...],
    )
"""

import json
import logging
from typing import Any, Dict, List, Optional

import litellm

logger = logging.getLogger(__name__)


class LLMClient:
    """
    Thin wrapper around LiteLLM for consistent tool-calling across providers.

    The same code works whether you're hitting GPT-4o, Claude, or a local
    Ollama model — LiteLLM handles the translation.
    """

    # Models that reject custom sampling: they require temperature=1.0 and
    # benefit from a larger completion budget (see __init__).
    _REASONING_MODELS = {"o3-pro", "o3-mini", "o3", "o1", "o1-mini", "o1-pro", "gpt-5"}

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        max_tokens: int = 1024,
    ):
        """
        Args:
            model: Any LiteLLM model string (e.g. "gpt-4o", "ollama/llama3").
            temperature: Sampling temperature. Forced to 1.0 for reasoning
                models, which do not accept other values.
            max_tokens: Completion token cap. Raised to at least 4096 for
                reasoning models so chains of thought aren't truncated.
        """
        self.model = model
        # One entry appended per chat() call that reported usage; aggregated
        # by the total_usage property.
        self.usage_log: List[Dict[str, int]] = []

        # Strip an optional provider prefix ("openai/o3" -> "o3") so the
        # reasoning-model check also matches prefixed LiteLLM model strings.
        base_model = model.split("/")[-1]
        if base_model in self._REASONING_MODELS:
            self.temperature = 1.0
            self.max_tokens = max(max_tokens, 4096)
            if temperature != 1.0:
                # Lazy %-args: the message is only formatted if emitted.
                logger.info(
                    "Model %s requires temperature=1.0, overriding from %s",
                    model,
                    temperature,
                )
        else:
            self.temperature = temperature
            self.max_tokens = max_tokens

    def chat(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Send messages to the LLM and get a response.

        Args:
            messages: Conversation history in OpenAI format
            tools: Optional list of tools in OpenAI function-calling format

        Returns:
            LiteLLM ModelResponse (same shape as OpenAI ChatCompletion).
        """
        kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

        if tools:
            kwargs["tools"] = tools
            # Let the model decide whether to call a tool or answer in text.
            kwargs["tool_choice"] = "auto"

        logger.debug(
            "LLM request: model=%s, messages=%d, tools=%d",
            self.model,
            len(messages),
            len(tools or []),
        )
        response = litellm.completion(**kwargs)
        logger.debug(
            "LLM response: finish_reason=%s", response.choices[0].finish_reason
        )

        # Some providers omit usage; record only when present, defaulting
        # missing/None fields to 0 so total_usage never sees bad data.
        if hasattr(response, "usage") and response.usage:
            self.usage_log.append({
                "prompt_tokens": getattr(response.usage, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(response.usage, "completion_tokens", 0) or 0,
            })

        return response

    @property
    def total_usage(self) -> Dict[str, int]:
        """Aggregate token usage across all calls."""
        return {
            "prompt_tokens": sum(u["prompt_tokens"] for u in self.usage_log),
            "completion_tokens": sum(u["completion_tokens"] for u in self.usage_log),
            "total_calls": len(self.usage_log),
        }

    @staticmethod
    def extract_tool_calls(response) -> List[Dict[str, Any]]:
        """
        Extract tool calls from an LLM response.

        Returns:
            A list of dicts with "id", "name", and "arguments" (parsed to a
            dict when the provider returned a JSON string); empty list if the
            response contains no tool calls.
        """
        choice = response.choices[0]
        if not choice.message.tool_calls:
            return []

        calls = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            # Providers may return arguments as a JSON string or a dict.
            if isinstance(args, str):
                args = json.loads(args)
            calls.append({
                "id": tc.id,
                "name": tc.function.name,
                "arguments": args,
            })
        return calls

    @staticmethod
    def get_text_response(response) -> Optional[str]:
        """Extract plain text content from an LLM response (None if absent)."""
        choice = response.choices[0]
        return choice.message.content