bwilkie committed on
Commit
b1b9230
·
verified ·
1 Parent(s): a5cf9b0

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +238 -0
agent.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from langchain_core.output_parsers import PydanticOutputParser
4
+ from typing import Callable, Dict, List, Any
5
+ import time
6
+ import json
7
+ from groq_api import grok_get_llm_response, API_llama_get_llm_response, open_oss_get_llm_response, openai_get_llm_response, deepseekapi_get_llm_response
8
+ from local_templates import llama3_get_llm_response, mistral_get_llm_response, qwen_get_llm_response, deepseek_get_llm_response, grape_get_llm_response
9
+ import os
10
+ import re
11
+
12
+
13
+ max_steps = 15
14
+
15
+ base_dir = os.path.dirname(os.path.abspath(__file__))
16
+
17
+
18
+ def select_model(model_type: str):
19
+ """Return the correct LLM response function for a given model_type."""
20
+
21
+ mapping = {
22
+ "groq_api": grok_get_llm_response,
23
+ "llama_api": API_llama_get_llm_response,
24
+ "oss_api": open_oss_get_llm_response,
25
+ "openai_api": openai_get_llm_response,
26
+ "deepseek_api": deepseekapi_get_llm_response,
27
+ "llama3": llama3_get_llm_response,
28
+ "mistral": mistral_get_llm_response,
29
+ "qwen3": qwen_get_llm_response,
30
+ "deepseek": deepseek_get_llm_response,
31
+ "grape": grape_get_llm_response,
32
+ }
33
+
34
+ if model_type not in mapping:
35
+ raise ValueError(f"Unknown model_type: {model_type}")
36
+
37
+ return mapping[model_type]
38
+
39
+
40
+ def format_gaia_response(model_type, last_observation, question_out):
41
+
42
+ get_llm_response = select_model(model_type)
43
+
44
+ # Process Gaia
45
+ with open(base_dir+"/system_prompt_final.txt", "r") as f:
46
+ final_sys_prompt = f.read()
47
+
48
+ gaia_prompt = (
49
+ f"{final_sys_prompt}\n\n"
50
+ f"User Question:\n{question_out}\n\n"
51
+ f"Last Observation:\n{last_observation}\n\n"
52
+ "Please review user questions and the last obervation and respond with the correct answer, in the correct format. No extra text, just the answer."
53
+ )
54
+
55
+ final_answer_out = get_llm_response(final_sys_prompt, gaia_prompt, reasoning_format = 'hidden')
56
+
57
+ return final_answer_out
58
+
59
+
60
+ class ImprovedAgent:
61
+ def __init__(self, tools: Dict[str, Callable], model_type: str):
62
+ self.tools = tools
63
+ self.history = []
64
+ self.get_llm_response = select_model(model_type)
65
+
66
+
67
+ # Load system prompts from .txt files
68
+ self.system_prompt_plan = self.load_prompt(base_dir+"/system_prompt_planning.txt")
69
+ self.system_prompt_thought = self.load_prompt(base_dir+"/system_prompt_thought.txt")
70
+ self.system_prompt_action = self.load_prompt(base_dir+"/system_prompt_action.txt")
71
+ self.system_prompt_observe = self.load_prompt(base_dir+"/system_prompt_observe.txt")
72
+
73
+
74
+ def load_prompt(self, filepath: str) -> str:
75
+ with open(filepath, "r") as f:
76
+ return f.read()
77
+
78
+ def reset(self):
79
+ self.history = []
80
+ def strip_markdown_code_block(self, text: str) -> str:
81
+ """
82
+ Remove leading/trailing markdown code block markers like ```json or ```
83
+ """
84
+ # Remove leading ```json or ``` (case-insensitive, multiline-safe)
85
+ text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
86
+ # Remove trailing ```
87
+ text = re.sub(r"\s*```$", "", text)
88
+ return text.strip()
89
+ def parse_json_response(self, response_text: str) -> Dict:
90
+ """Attempt to parse LLM JSON response safely."""
91
+
92
+ try:
93
+
94
+ cleaned = self.strip_markdown_code_block(response_text.strip())
95
+
96
+ json_text = self.extract_json_string(cleaned)
97
+
98
+ json_text = json_text.replace("\\'", "'")
99
+ #json_text = json_text.replace("\n", "\\n")
100
+
101
+ return json.loads(json_text)
102
+
103
+ except json.JSONDecodeError as e:
104
+ print(f"[ERROR] JSON Parse Error: {e}")
105
+ print(f"[DEBUG] Raw response: {response_text}")
106
+ return {"error": f"Invalid JSON response: {str(e)}"}
107
+
108
+ def extract_json_string(self, text: str) -> str:
109
+ """Extract the first valid-looking JSON object from a string."""
110
+ match = re.search(r'\{.*\}', text, re.DOTALL)
111
+ return match.group(0) if match else text
112
+
113
+ def build_prompt_from_history(self, query: str) -> str:
114
+ return f"""User Query: {query}
115
+ History: {json.dumps(self.history, indent=2)}
116
+ """
117
+
118
+ def run(self, query: str):
119
+ self.reset()
120
+
121
+ # Step 1: Planning Agent
122
+ planning_input = f"User Query: {query}"
123
+ print("-----Stage Plan-----")
124
+
125
+ plan_response = self.get_llm_response(self.system_prompt_plan, planning_input)
126
+ print("-----Plan Text-----")
127
+ print(plan_response)
128
+ print("-------------------")
129
+ print("-----Plan Parsed-----")
130
+ parsed_plan = self.parse_json_response(plan_response)
131
+ print(parsed_plan)
132
+ print("---------------------")
133
+ self.history.append(parsed_plan)
134
+
135
+ current_input = self.build_prompt_from_history(query)
136
+
137
+ for _ in range(max_steps): # maximum 5 loops
138
+
139
+ print(f"-----Itterantion {_}-----")
140
+ # Step 2: Thought Agent
141
+ print("-----Stage Thought-----")
142
+
143
+ thought_response = self.get_llm_response(self.system_prompt_thought, current_input)
144
+ print(thought_response)
145
+ parsed_thought = self.parse_json_response(thought_response)
146
+ print("-----Thought Parsed-----")
147
+ print(parsed_thought)
148
+ print("-----------------")
149
+ self.history.append(parsed_thought)
150
+
151
+ # Step 3: Action Agent
152
+ if "thought" not in parsed_thought:
153
+ return "[ERROR] Thought agent did not return 'thought'. Ending.", ""
154
+ action_input = json.dumps({"thought": parsed_thought["thought"]})
155
+ print("-----Stage Action-----")
156
+
157
+ action_response_text = self.get_llm_response(self.system_prompt_action, action_input)
158
+
159
+ # With this:
160
+ try:
161
+ # Handle <think> tags
162
+ if '<think>' in action_response_text and '</think>' in action_response_text:
163
+ json_part = action_response_text.split('</think>')[1].strip()
164
+ else:
165
+ json_part = action_response_text.strip()
166
+
167
+ # Extract JSON
168
+ import re
169
+ json_match = re.search(r'\{.*\}', json_part)
170
+ if json_match:
171
+ parsed_action = json.loads(json_match.group())
172
+ else:
173
+ parsed_action = {'error': 'No JSON found in response'}
174
+
175
+ except Exception as e:
176
+ parsed_action = {'error': f'JSON parsing failed: {str(e)}'}
177
+ print(parsed_action)
178
+ print("-----------------")
179
+ self.history.append(parsed_action)
180
+
181
+ # Step 4: Tool Execution
182
+ tool_name = parsed_action.get("action")
183
+ tool_args = parsed_action.get("action_input", {})
184
+ # print("-----Tool Name-----")
185
+ # print(tool_name)
186
+ # print("-----Tool Args-----")
187
+ # print(tool_args)
188
+ # print("-----------------")
189
+ if not tool_name or tool_name not in self.tools:
190
+ observation = f"[ERROR] Invalid or missing tool: {tool_name}"
191
+ else:
192
+ try:
193
+ result = self.tools[tool_name](**tool_args)
194
+ observation = f"Tool `{tool_name}` executed successfully. Output: {result}"
195
+ print("-----Tool Observation OK-----")
196
+ print(observation)
197
+ print("-----------------")
198
+
199
+ except Exception as e:
200
+ observation = f"[ERROR] Tool `{tool_name}` execution failed: {str(e)}"
201
+ print("-----Tool Observation Fail-----")
202
+ print(observation)
203
+ print("-----------------")
204
+
205
+ # Store the tool result explicitly in history
206
+ self.history.append({
207
+ "tool_name": tool_name,
208
+ "tool_args": tool_args,
209
+ #"tool_output": result if 'result' in locals() else None
210
+ })
211
+
212
+ # Step 5: Observation Agent
213
+
214
+ observation_input = f"""User Query: {query}
215
+ Plan: {json.dumps(self.history[0], indent=2)}
216
+ History: {json.dumps(self.history, indent=2)}
217
+ Tool Output: {observation}
218
+ """
219
+ print("-----Stage Observe-----")
220
+ observation_response_text = self.get_llm_response(self.system_prompt_observe, observation_input)
221
+
222
+ print("-----Observation Parsed-----")
223
+ parsed_observation = self.parse_json_response(observation_response_text)
224
+ print(parsed_observation)
225
+ print("-----------------")
226
+ self.history.append(parsed_observation)
227
+
228
+ # Step 6: Check for final answer
229
+ if "final_answer" in parsed_observation:
230
+ print(parsed_observation["final_answer"])
231
+ #break
232
+ return self.history, observation_response_text, parsed_observation["final_answer"]
233
+
234
+ # Step 7: Update prompt for next loop
235
+ current_input = self.build_prompt_from_history(query)
236
+
237
+ print('ERROR LOOP LIMIT REACHED')
238
+ return self.history, observation_response_text + "This is our last observation. Make your best estimation given the question.", parsed_observation