dduseja commited on
Commit
b8daa00
·
verified ·
1 Parent(s): 54c320b

Upload arcgis_test_agent/agent.py

Browse files
Files changed (1) hide show
  1. arcgis_test_agent/agent.py +179 -0
arcgis_test_agent/agent.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core agent loop for the ArcGIS Test Script Generator.
3
+
4
+ This module implements the main agent that orchestrates:
5
+ 1. Multi-turn conversations with Azure OpenAI
6
+ 2. Tool dispatching
7
+ 3. Output extraction and file writing
8
+ """
9
+
10
+ import json
11
+ import logging
12
+ import time
13
+ from dataclasses import dataclass, field
14
+ from typing import Optional
15
+
16
+ from openai import AzureOpenAI
17
+
18
+ from .config import AzureConfig, AgentConfig, load_config
19
+ from .tools import TOOL_SCHEMAS, ToolHandlers
20
+ from .prompts import SYSTEM_PROMPT
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass
26
+ class AgentResult:
27
+ """Result from an agent run."""
28
+ success: bool
29
+ script: Optional[str] = None # The generated test script
30
+ tool_name: str = "" # Tool we generated tests for
31
+ iterations: int = 0 # How many loop iterations
32
+ tool_calls_made: list = field(default_factory=list) # Log of all tool calls
33
+ error: Optional[str] = None # Error message if failed
34
+ duration_seconds: float = 0.0
35
+
36
+
37
+ class TestGeneratorAgent:
38
+ """
39
+ The main agent that generates ArcGIS test scripts.
40
+
41
+ Usage:
42
+ agent = TestGeneratorAgent()
43
+ result = agent.generate("Buffer")
44
+ if result.success:
45
+ print(result.script)
46
+ """
47
+
48
+ def __init__(self, azure_config: Optional[AzureConfig] = None,
49
+ agent_config: Optional[AgentConfig] = None):
50
+ if azure_config is None or agent_config is None:
51
+ _azure, _agent = load_config()
52
+ azure_config = azure_config or _azure
53
+ agent_config = agent_config or _agent
54
+
55
+ self.azure_config = azure_config
56
+ self.agent_config = agent_config
57
+
58
+ self.client = AzureOpenAI(
59
+ azure_endpoint=azure_config.endpoint,
60
+ api_key=azure_config.api_key,
61
+ api_version=azure_config.api_version,
62
+ )
63
+
64
+ self.tools = ToolHandlers(agent_config)
65
+
66
+ logger.info(
67
+ f"Agent initialized. Model: {azure_config.deployment_name}, "
68
+ f"Max iterations: {agent_config.max_iterations}"
69
+ )
70
+
71
+ def generate(self, tool_name: str, additional_context: str = "") -> AgentResult:
72
+ """Generate a test script for the given ArcGIS tool."""
73
+ start_time = time.time()
74
+ tool_calls_log = []
75
+
76
+ user_message = f"Generate a complete test script for the ArcGIS tool: {tool_name}"
77
+ if additional_context:
78
+ user_message += f"\n\nAdditional context: {additional_context}"
79
+
80
+ messages = [
81
+ {"role": "system", "content": SYSTEM_PROMPT},
82
+ {"role": "user", "content": user_message},
83
+ ]
84
+
85
+ logger.info(f"Starting generation for tool: {tool_name}")
86
+
87
+ for iteration in range(1, self.agent_config.max_iterations + 1):
88
+ logger.debug(f"Iteration {iteration}/{self.agent_config.max_iterations}")
89
+
90
+ try:
91
+ response = self.client.chat.completions.create(
92
+ model=self.azure_config.deployment_name,
93
+ messages=messages,
94
+ tools=TOOL_SCHEMAS,
95
+ tool_choice="auto",
96
+ temperature=self.agent_config.temperature,
97
+ )
98
+ except Exception as e:
99
+ logger.error(f"API call failed at iteration {iteration}: {e}")
100
+ return AgentResult(
101
+ success=False, tool_name=tool_name, iterations=iteration,
102
+ tool_calls_made=tool_calls_log, error=f"API error: {str(e)}",
103
+ duration_seconds=time.time() - start_time,
104
+ )
105
+
106
+ msg = response.choices[0].message
107
+ messages.append(msg)
108
+
109
+ if msg.tool_calls:
110
+ for tc in msg.tool_calls:
111
+ func_name = tc.function.name
112
+ func_args = json.loads(tc.function.arguments)
113
+
114
+ logger.info(f" Tool call: {func_name}({json.dumps(func_args)[:100]}...)")
115
+ tool_calls_log.append({
116
+ "iteration": iteration, "tool": func_name, "arguments": func_args,
117
+ })
118
+
119
+ result = self.tools.dispatch(func_name, func_args)
120
+ messages.append({
121
+ "role": "tool", "tool_call_id": tc.id,
122
+ "content": json.dumps(result, default=str),
123
+ })
124
+ else:
125
+ final_output = msg.content
126
+ if final_output:
127
+ script = self._extract_script(final_output)
128
+ duration = time.time() - start_time
129
+ logger.info(
130
+ f"Generation complete. {iteration} iterations, "
131
+ f"{len(tool_calls_log)} tool calls, {duration:.1f}s"
132
+ )
133
+ return AgentResult(
134
+ success=True, script=script, tool_name=tool_name,
135
+ iterations=iteration, tool_calls_made=tool_calls_log,
136
+ duration_seconds=duration,
137
+ )
138
+ else:
139
+ logger.warning("Empty response from model, retrying...")
140
+ messages.append({
141
+ "role": "user",
142
+ "content": "Your response was empty. Please output the complete test script."
143
+ })
144
+
145
+ duration = time.time() - start_time
146
+ logger.error(f"Max iterations ({self.agent_config.max_iterations}) reached without completion.")
147
+ return AgentResult(
148
+ success=False, tool_name=tool_name,
149
+ iterations=self.agent_config.max_iterations,
150
+ tool_calls_made=tool_calls_log,
151
+ error=f"Exceeded max iterations ({self.agent_config.max_iterations})",
152
+ duration_seconds=duration,
153
+ )
154
+
155
+ def _extract_script(self, raw_output: str) -> str:
156
+ """Extract clean Python code from the model's output."""
157
+ text = raw_output.strip()
158
+ if text.startswith("```python"):
159
+ text = text[len("```python"):].strip()
160
+ elif text.startswith("```"):
161
+ text = text[len("```"):].strip()
162
+ if text.endswith("```"):
163
+ text = text[:-3].strip()
164
+ return text
165
+
166
+ def generate_and_save(self, tool_name: str, output_path: Optional[str] = None,
167
+ additional_context: str = "") -> AgentResult:
168
+ """Generate a test script and save it to disk."""
169
+ import os
170
+ result = self.generate(tool_name, additional_context)
171
+ if result.success and result.script:
172
+ if output_path is None:
173
+ os.makedirs(self.agent_config.output_dir, exist_ok=True)
174
+ safe_name = tool_name.lower().replace(" ", "_")
175
+ output_path = os.path.join(self.agent_config.output_dir, f"test_{safe_name}.py")
176
+ with open(output_path, "w") as f:
177
+ f.write(result.script)
178
+ logger.info(f"Test script saved to: {output_path}")
179
+ return result