"""
Agent Runner

Core orchestrator for AI agent execution with tool calling support.
Manages the full request cycle: LLM generation → tool execution → final response.
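
Typical usage (a sketch; assumes a populated AgentConfiguration and an
MCPToolRegistry with registered tools, constructed elsewhere):

    runner = AgentRunner(config=config, tool_registry=registry)
    result = await runner.execute(
        messages=[{"role": "user", "content": "What's on my calendar?"}],
        user_id=current_user_id,  # always injected by the backend
    )
    print(result["content"])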
"""

import asyncio
import logging
from typing import Any, Dict, List, Optional

from .agent_config import AgentConfiguration
from .providers.base import LLMProvider
from .providers.gemini import GeminiProvider
from .providers.openrouter import OpenRouterProvider
from .providers.cohere import CohereProvider
from ..mcp.tool_registry import MCPToolRegistry, ToolExecutionResult

logger = logging.getLogger(__name__)


class AgentRunner:
    """
    Agent execution orchestrator with tool calling support.

    This class manages the full agent request cycle:
    1. Generate an LLM response with tool definitions attached
    2. If tool calls are requested, execute the tools with user context injected
    3. Generate the final response from the tool results
    4. Fall back to a secondary provider on failure (e.g. rate limiting)
    """

    def __init__(self, config: AgentConfiguration, tool_registry: MCPToolRegistry):
        """
        Initialize the agent runner.

        Args:
            config: Agent configuration
            tool_registry: MCP tool registry
        """
        self.config = config
        self.tool_registry = tool_registry
        self.primary_provider = self._create_provider(config.provider)
        self.fallback_provider = None

        if config.fallback_provider:
            self.fallback_provider = self._create_provider(config.fallback_provider)

        logger.info(f"Initialized AgentRunner with provider: {config.provider}")

    def _create_provider(self, provider_name: str) -> LLMProvider:
        """
        Create an LLM provider instance.

        Args:
            provider_name: Provider name (gemini, openrouter, cohere)

        Returns:
            LLMProvider instance

        Raises:
            ValueError: If provider is not supported or API key is missing
        """
        api_key = self.config.get_provider_api_key(provider_name)
        if not api_key:
            raise ValueError(f"API key not configured for provider: {provider_name}")

        model = self.config.get_provider_model(provider_name)

        provider_classes = {
            "gemini": GeminiProvider,
            "openrouter": OpenRouterProvider,
            "cohere": CohereProvider,
        }
        provider_class = provider_classes.get(provider_name)
        if provider_class is None:
            raise ValueError(f"Unsupported provider: {provider_name}")

        return provider_class(
            api_key=api_key,
            model=model,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens
        )

    async def execute(
        self,
        messages: List[Dict[str, str]],
        user_id: int,
        system_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Execute agent request with tool calling support.

        SECURITY: user_id is injected by the backend, never taken from LLM output.

        Args:
            messages: Conversation history [{"role": "user", "content": "..."}]
            user_id: User ID (injected by backend for security)
            system_prompt: Optional system prompt (uses config default if not provided)

        Returns:
            Dict with response content and metadata
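
            Example shape (illustrative values):
                {
                    "content": "...",               # final response text
                    "tool_calls": [...] or None,    # calls the model requested
                    "tool_results": [...] or None,  # ToolExecutionResult entries
                    "provider": "..."               # provider.get_provider_name()
                }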
        """
        prompt = system_prompt or self.config.system_prompt

        try:
            return await self._execute_with_provider(
                provider=self.primary_provider,
                messages=messages,
                user_id=user_id,
                system_prompt=prompt
            )

        except Exception as e:
            logger.error(f"Agent execution failed with primary provider: {str(e)}")

            # Try fallback provider if configured
            if self.fallback_provider:
                logger.info("Attempting fallback provider")
                try:
                    return await self._execute_with_provider(
                        provider=self.fallback_provider,
                        messages=messages,
                        user_id=user_id,
                        system_prompt=prompt
                    )
                except Exception as fallback_error:
                    logger.error(f"Fallback provider also failed: {str(fallback_error)}")
                    raise

            raise

    async def _execute_with_provider(
        self,
        provider: LLMProvider,
        messages: List[Dict[str, str]],
        user_id: int,
        system_prompt: str
    ) -> Dict[str, Any]:
        """
        Execute the agent request cycle with a specific provider; shared by the primary and fallback paths.

        Args:
            provider: LLM provider to use
            messages: Conversation history
            user_id: User ID
            system_prompt: System prompt

        Returns:
            Dict with response content and metadata (same shape as execute())
        """
        tool_definitions = self.tool_registry.get_tool_definitions()

        logger.info(f"Executing agent for user {user_id} with {len(tool_definitions)} tools")

        # Generate initial response with tool definitions
        response = await provider.generate_response_with_tools(
            messages=messages,
            system_prompt=system_prompt,
            tools=tool_definitions
        )

        # If tool calls were requested, execute them and feed the results back
        if response.tool_calls:
            logger.info(f"Agent requested {len(response.tool_calls)} tool calls")

            tool_results = []
            for tool_call in response.tool_calls:
                result = await self.tool_registry.execute_tool(
                    tool_name=tool_call["name"],
                    arguments=tool_call["arguments"],
                    user_id=user_id  # Inject user context for security
                )
                tool_results.append(result)

            # Generate final response with tool results
            final_response = await provider.generate_response_with_tool_results(
                messages=messages,
                tool_calls=response.tool_calls,
                tool_results=tool_results
            )

            return {
                "content": final_response.content,
                "tool_calls": response.tool_calls,
                "tool_results": tool_results,
                "provider": provider.get_provider_name()
            }

        # No tool calls: return the direct response
        logger.info("Agent generated direct response (no tool calls)")
        return {
            "content": response.content,
            "tool_calls": None,
            "tool_results": None,
            "provider": provider.get_provider_name()
        }

    async def execute_simple(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Execute a simple agent request without tool calling.

        Args:
            messages: Conversation history
            system_prompt: Optional system prompt

        Returns:
            Response content as string
        """
        prompt = system_prompt or self.config.system_prompt
        provider = self.primary_provider

        try:
            response = await provider.generate_simple_response(
                messages=messages,
                system_prompt=prompt
            )
            return response.content or ""

        except Exception as e:
            logger.error(f"Simple execution failed: {str(e)}")

            # Try fallback provider
            if self.fallback_provider:
                logger.info("Attempting fallback provider for simple execution")
                response = await self.fallback_provider.generate_simple_response(
                    messages=messages,
                    system_prompt=prompt
                )
                return response.content or ""

            raise
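

# Minimal smoke-test sketch, not part of the library surface. The default
# constructions of AgentConfiguration and MCPToolRegistry below are
# illustrative assumptions; substitute the project's real setup.
if __name__ == "__main__":
    async def _demo() -> None:
        config = AgentConfiguration()    # assumed default-constructible
        registry = MCPToolRegistry()     # assumed default-constructible
        runner = AgentRunner(config=config, tool_registry=registry)
        reply = await runner.execute_simple(
            messages=[{"role": "user", "content": "Say hello."}]
        )
        print(reply)

    asyncio.run(_demo())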