michaelarutyunov committed on
Commit
e33834b
·
verified ·
1 Parent(s): e0c31e4

Update final_agent.py

Browse files
Files changed (1) hide show
  1. final_agent.py +231 -0
final_agent.py CHANGED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard libraries
2
+ import json
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from typing import Dict, List, Any, Optional, Annotated
6
+ from typing_extensions import TypedDict
7
+
8
+ # Langchain and langgraph
9
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, AnyMessage
10
+ from langgraph.graph import StateGraph, START, END, add_messages
11
+ from langgraph.prebuilt import ToolNode, tools_condition
12
+
13
+ # Custom modules
14
+ from prompts import MAIN_SYSTEM_PROMPT, QUESTION_DECOMPOSITION_PROMPT, TOOL_USE_INSTRUCTION, EXECUTION_INSTRUCTION
15
+ from utils import check_api_keys, setup_llm
16
+ from tools import (
17
+ calculator_tool, extract_text_from_image, transcribe_audio, execute_python_code,
18
+ read_file, web_search, wikipedia_search, arxiv_search, chess_board_image_analysis,
19
+ find_phrase_in_text, download_youtube_audio, web_content_extract, analyse_tabular_data
20
+ )
21
+
22
+ # AGENT STATE
23
class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph workflow."""
    task_id: Optional[str]          # identifier of the task being answered, if any
    file_name: Optional[str]        # attached file name — set to None by the caller; presumably filled by file-handling logic elsewhere
    file_type: Optional[str]        # attached file type — see note on file_name
    file_path: Optional[str]        # attached file path — see note on file_name
    question_decomposition: Optional[str]  # LLM-produced breakdown of the question (set by the decomposition node)
    # Conversation history; the add_messages reducer merges node returns into it.
    messages: Annotated[list[AnyMessage], add_messages]
    tool_results: Dict              # initialized to {} by the caller; presumably for collected tool outputs
    error_message: Optional[str]    # set when a node cannot proceed (e.g. no question found)
32
+
33
+ # WORKFLOW CREATION
34
def create_workflow_for_final_agent():
    """
    Create and compile the LangGraph workflow.

    Graph shape: START -> decomposition -> agent, after which the agent loops
    through the tool node (routed by tools_condition) until it produces a
    response with no tool calls, which ends the run.

    Returns:
        A compiled LangGraph app whose ``invoke()`` accepts and returns an
        ``AgentState`` dict.
    """
    # Only the first two LLMs returned by setup_llm() are used here.
    llm_agent_management, llm_question_decomposition, _, _, _ = setup_llm()

    tools = [
        web_search,
        web_content_extract,
        wikipedia_search,
        calculator_tool,
        extract_text_from_image,
        transcribe_audio,
        execute_python_code,
        read_file,
        arxiv_search,
        chess_board_image_analysis,
        find_phrase_in_text,
        download_youtube_audio,
        analyse_tabular_data
    ]

    llm_agent_management_with_tools = llm_agent_management.bind_tools(tools)

    # Define nodes
    def question_decomposition_node(state: AgentState):
        """Run the decomposition LLM on the user's question; store the result
        in state["question_decomposition"], or set error_message if no
        HumanMessage is present."""
        new_state = state.copy()
        messages = new_state.get("messages", [])  # .get for safety
        # The question is the first HumanMessage in the history.
        question = None
        for msg in messages:
            if isinstance(msg, HumanMessage):
                question = msg.content
                break
        if not question:
            new_state["error_message"] = "No question found for decomposition."
            # Ensure messages list exists even if we return early.
            if "messages" not in new_state or not isinstance(new_state["messages"], list):
                new_state["messages"] = []
            return new_state

        question_decomposition_prompt_messages = [
            SystemMessage(content=QUESTION_DECOMPOSITION_PROMPT),
            HumanMessage(content=f"Decompose this question: {question}")
        ]
        question_decomposition_object = llm_question_decomposition.invoke(question_decomposition_prompt_messages)
        new_state["question_decomposition"] = question_decomposition_object.content
        # Ensure messages list exists.
        if "messages" not in new_state or not isinstance(new_state["messages"], list):
            new_state["messages"] = []
        return new_state

    def call_model_node(state: AgentState):
        """Invoke the tool-enabled agent LLM and append its response to the
        message history."""
        new_state = state.copy()
        messages = new_state.get("messages", [])  # .get for safety
        question_decomposition = new_state.get("question_decomposition", "")

        llm_messages = list(messages)  # mutable copy

        # Inject the decomposition only on a "fresh" turn: skip it while we are
        # mid tool-call loop (last message is a ToolMessage).
        add_decomposition = question_decomposition and (not llm_messages or not isinstance(llm_messages[-1], ToolMessage))
        if add_decomposition:
            # BUGFIX: "\\n" put a literal backslash-n into the prompt text;
            # a real newline was intended.
            decomposition_message = SystemMessage(content=f"Question decomposition: {question_decomposition}\nUse this analysis to guide your actions.")
            llm_messages.append(decomposition_message)

        response = llm_agent_management_with_tools.invoke(llm_messages)

        # Ensure new_state["messages"] exists and is a list before extending.
        current_messages = new_state.get("messages", [])
        if not isinstance(current_messages, list):
            current_messages = []
        new_state["messages"] = current_messages + [response]
        return new_state

    workflow = StateGraph(AgentState)
    workflow.add_node("decomposition", question_decomposition_node)
    workflow.add_node("agent", call_model_node)
    workflow.add_node("tools", ToolNode(tools))

    workflow.add_edge(START, "decomposition")
    workflow.add_edge("decomposition", "agent")
    # Routes to "tools" when the agent emitted tool calls, otherwise to END.
    workflow.add_conditional_edges("agent", tools_condition)
    workflow.add_edge("tools", "agent")

    return workflow.compile()
118
+
119
+
120
class FinalAgent:
    """Question-answering agent wrapping the compiled LangGraph workflow."""

    def __init__(self):
        """Load .config / .env, validate API keys, and build the workflow.

        Raises:
            ValueError: if check_api_keys() reports missing/invalid keys.
        """
        print("FinalAgent initializing...")
        load_dotenv()

        if not os.path.exists('.config'):
            print("Warning: .config file not found. Using default values or expecting environment variables.")
            self.config = {}  # Default to empty config
        else:
            with open('.config', 'r') as f:
                self.config = json.load(f)

        # .config values win; environment variables are the fallback.
        self.base_url = self.config.get('BASE_URL', os.getenv('BASE_URL'))
        self.debug_mode = self.config.get('DEBUG_MODE', str(os.getenv('DEBUG_MODE', 'False')).lower() == 'true')

        if not check_api_keys():
            # check_api_keys itself prints messages
            raise ValueError("API keys are missing or invalid. Please set the required environment variables.")

        self.workflow = create_workflow_for_final_agent()
        print("FinalAgent initialized successfully.")

    @staticmethod
    def _extract_final_answer(messages: List[Any]) -> str:
        """Scan messages newest-first and pull the answer text.

        Any message containing "FINAL ANSWER:" wins outright (even an older
        one); otherwise the most recent AIMessage with string content is used
        as a fallback. Returns "" when neither is found.
        """
        fallback = ""
        for msg in reversed(messages):
            if hasattr(msg, "content") and msg.content:
                content = msg.content
                if isinstance(content, str):
                    if "FINAL ANSWER:" in content:
                        return content.split("FINAL ANSWER:", 1)[1].strip()
                    if isinstance(msg, AIMessage) and not fallback:
                        fallback = content
        return fallback

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer *question* via the workflow.

        Returns the extracted final answer, or an "AGENT ERROR: ..." string
        when the workflow fails or no answer can be extracted.
        """
        print(f"FinalAgent received question for task_id '{task_id}': {question[:100]}...")

        # BUGFIX: "\\n\\n" inserted literal backslash-n pairs into the system
        # prompt; real blank-line separators were intended.
        initial_messages = [
            SystemMessage(content=MAIN_SYSTEM_PROMPT + "\n\n" + TOOL_USE_INSTRUCTION + "\n\n" + EXECUTION_INSTRUCTION),
            HumanMessage(content=question)
        ]

        initial_state: AgentState = {
            "messages": initial_messages,
            "task_id": task_id,
            "file_name": None,
            "file_path": None,
            "file_type": None,
            "question_decomposition": None,
            "tool_results": {},
            "error_message": None
        }

        try:
            result_state = self.workflow.invoke(initial_state)
        except Exception as e:
            print(f"Error invoking workflow for task {task_id}: {e}")
            import traceback
            traceback.print_exc()
            return f"AGENT ERROR: Failed to process question due to an internal error: {e}"

        messages = result_state.get("messages", [])
        if not messages:
            print(f"No messages found in the result state for task {task_id}.")
            return "AGENT ERROR: No messages returned by the agent."

        final_answer = self._extract_final_answer(messages)
        if not final_answer:
            final_answer = "AGENT ERROR: Could not extract a final answer from the agent's messages."
            print(f"Could not extract final answer for task {task_id}. Messages: {messages}")

        print(f"FinalAgent returning answer for task_id '{task_id}': {final_answer[:100]}...")
        return final_answer
196
+
197
if __name__ == '__main__':
    print("Running a simple test for FinalAgent...")

    # Provide a minimal .config so FinalAgent.__init__ has one to load.
    if not os.path.exists('.config'):
        print("Creating a dummy .config file for testing.")
        with open('.config', 'w') as f:
            json.dump({"DEBUG_MODE": True, "BASE_URL": "http://localhost:8000"}, f)

    # Check for .env and API keys
    if not load_dotenv():  # attempts to load .env; True when successful
        print("Warning: .env file not found or failed to load. API keys might be missing.")

    if not (os.getenv("OPENAI_API_KEY") or os.getenv("DEEPSEEK_API_KEY") or os.getenv("TAVILY_API_KEY")):
        # BUGFIX: typo "no found" -> "not found"; "\\n" -> real newline.
        print("\nWARNING: Required API key not found in environment variables (OPENAI_API_KEY, DEEPSEEK_API_KEY, TAVILY_API_KEY).")
        print("The agent will likely fail to initialize or run properly without at least one.")
        print("Please set them in your .env file or environment for testing.\n")

    try:
        agent = FinalAgent()
        test_question = "What is the capital of France? And what is the weather like there today?"
        print(f"Test Question 1: {test_question}")
        answer = agent(test_question, task_id="test_001")
        print(f"Test Answer 1: {answer}")

        test_question_calc = "What is 123 * 4 / 2 + 6?"
        print(f"\nTest Question 2 (Calc): {test_question_calc}")
        answer_calc = agent(test_question_calc, task_id="test_002")
        print(f"Test Answer 2 (Calc): {answer_calc}")

    except ValueError as ve:
        print(f"Initialization Error: {ve}")
    except Exception as e:
        print(f"An error occurred during the test: {e}")
        import traceback
        traceback.print_exc()