michaelarutyunov committed on
Commit
bcac1d0
·
verified ·
1 Parent(s): e33834b

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -311
main.py DELETED
@@ -1,311 +0,0 @@
1
- # Standard libraries
2
- import argparse
3
- import json
4
- import logging
5
- import os
6
- from dotenv import load_dotenv
7
- from typing_extensions import TypedDict
8
- from typing import Dict, List, Any, Optional, Annotated
9
-
10
- # Langchain and langgraph
11
- from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, AnyMessage
12
- from langgraph.graph import StateGraph, START, END
13
- from langgraph.graph.message import add_messages
14
- from langgraph.prebuilt import ToolNode, tools_condition
15
-
16
- # Custom modules
17
- from prompts import MAIN_SYSTEM_PROMPT, QUESTION_DECOMPOSITION_PROMPT, TOOL_USE_INSTRUCTION, EXECUTION_INSTRUCTION
18
- from utils import check_api_keys, setup_llm, download_and_save_task_file, cleanup_temp_files
19
- from tools import calculator_tool, extract_text_from_image, transcribe_audio, execute_python_code, read_file, web_search, wikipedia_search, arxiv_search, chess_board_image_analysis, find_phrase_in_text, download_youtube_audio, web_content_extract, analyse_tabular_data
20
-
21
- # SETUP
22
- logging.basicConfig(
23
- level=logging.INFO, # Set default level to INFO
24
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
25
- filename='agent_run.log',
26
- filemode='w', # 'a' to append to the file on each run, 'w' to overwrite
27
- )
28
-
29
- load_dotenv()
30
-
31
- with open('.config', 'r') as f:
32
- config = json.load(f)
33
- BASE_URL = config['BASE_URL']
34
- DEBUG_MODE = config['DEBUG_MODE']
35
-
36
- # AGENT STATE
37
- class AgentState(TypedDict):
38
- task_id: str
39
- file_name: Optional[str]
40
- file_type: Optional[str]
41
- file_path: Optional[str]
42
- question_decomposition: Optional[str]
43
- messages: Annotated[list[AnyMessage], add_messages]
44
- # tool_calls: Optional[List[str]] # Note: This field doesn't seem to be explicitly used or populated in the current code.
45
- tool_results: Dict # Results from tool executions (populated implicitly by ToolNode?)
46
- error_message: Optional[str]
47
-
48
- # WORKFLOW SETUP
49
- def create_workflow():
50
-
51
- # Setup LLMs and tools
52
- llm_agent_management, llm_question_decomposition, _, _, _ = setup_llm()
53
-
54
- tools = [
55
- web_search,
56
- web_content_extract,
57
- wikipedia_search,
58
- calculator_tool,
59
- extract_text_from_image,
60
- transcribe_audio,
61
- execute_python_code,
62
- read_file,
63
- arxiv_search,
64
- chess_board_image_analysis,
65
- find_phrase_in_text,
66
- download_youtube_audio,
67
- analyse_tabular_data
68
- ]
69
-
70
- llm_agent_management_with_tools = llm_agent_management.bind_tools(tools)
71
-
72
- # Define nodes
73
- def question_decomposition(state: AgentState):
74
- """Analyze question and decompose it into a plan to answer it. The plan will be used to guide the agent's actions."""
75
- new_state = state.copy() # Create a copy of the current state
76
- messages = new_state["messages"] # Get the messages from the current state
77
-
78
- # Find the HumanMessage to analyze
79
- question = None
80
- for msg in messages:
81
- if isinstance(msg, HumanMessage):
82
- question = msg.content
83
- break
84
-
85
- if not question:
86
- new_state["error_message"] = "No question found in the state for question_decomposition node"
87
- return new_state # Return copy of original state unchanged
88
-
89
- # Create and invoke the analysis prompt
90
- question_decomposition_prompt = [
91
- SystemMessage(content=QUESTION_DECOMPOSITION_PROMPT),
92
- HumanMessage(content=f"Decompose this question : {question}")
93
- ]
94
-
95
- question_decomposition_object = llm_question_decomposition.invoke(question_decomposition_prompt)
96
- question_decomposition_response = question_decomposition_object.content
97
-
98
- # Update the state with the question analysis
99
- new_state["question_decomposition"] = question_decomposition_response
100
-
101
- # Return the complete new state with all fields preserved
102
- return new_state
103
-
104
- def call_model(state: AgentState):
105
- """Invoke the LLM with the current state."""
106
- new_state = state.copy()
107
- messages = new_state["messages"]
108
- question_decomposition = new_state.get("question_decomposition", "")
109
-
110
- # Prepare messages for the LLM call
111
- llm_messages = messages.copy()
112
- add_decomposition = question_decomposition and (not messages or not isinstance(messages[-1], ToolMessage))
113
-
114
- if add_decomposition:
115
- decomposition_message = SystemMessage(content=f"Question decomposition: {question_decomposition}\nUse this analysis to guide your actions.")
116
- llm_messages.append(decomposition_message)
117
-
118
- response = llm_agent_management_with_tools.invoke(llm_messages)
119
- new_state["messages"] = messages + [response]
120
- return new_state
121
- """
122
- def router(state: AgentState):
123
- #Determine whether to continue to tools or end the workflow.
124
- messages = state["messages"]
125
- last_message = messages[-1]
126
-
127
- if last_message.tool_calls:
128
- return "tools"
129
- return END
130
- """
131
-
132
- # Setup workflow
133
- workflow = StateGraph(AgentState)
134
-
135
- workflow.add_node("decomposition", question_decomposition)
136
- workflow.add_node("agent", call_model)
137
- workflow.add_node("tools", ToolNode(tools))
138
-
139
- workflow.add_edge(START, "decomposition")
140
- workflow.add_edge("decomposition", "agent")
141
- workflow.add_conditional_edges("agent", tools_condition) # router, {"tools": "tools", END: END}
142
- workflow.add_edge("tools", "agent")
143
-
144
- return workflow.compile()
145
-
146
- # REPORTING
147
- def save_txt_report(state: AgentState, task_id: str):
148
- """Create a txt report from the state."""
149
- messages = state["messages"]
150
- report = ""
151
-
152
- # question wording
153
- question = messages[0].content
154
- report += f"Question: {question}\n\n"
155
-
156
- # question decomposition
157
- question_decomposition = state.get("question_decomposition", "No decomposition available")
158
- report += f"Question decomposition: {question_decomposition}\n\n"
159
-
160
- # message content
161
- report += "Message Chain:\n"
162
- for msg in messages:
163
- msg_type = type(msg).__name__ # Get the class name (e.g., "HumanMessage")
164
- report += f"--- {msg_type} ---\n"
165
- report += f"{msg.content}\n"
166
- # Optionally add tool call info for AIMessages
167
- if hasattr(msg, 'tool_calls') and msg.tool_calls:
168
- report += f"Tool Calls: {msg.tool_calls}\n"
169
- # ToolMessage content often includes tool output directly,
170
- # but you could format it differently if needed.
171
- report += "---\n\n"
172
-
173
- # find the task with matching task_id
174
- validation_data = []
175
- correct_answer = ""
176
-
177
- with open("metadata_val.jsonl", "r", encoding="utf-8") as file:
178
- for line in file:
179
- validation_data.append(json.loads(line))
180
-
181
- for task_metadata in validation_data:
182
- if task_metadata.get("task_id") == task_id:
183
- correct_answer = task_metadata.get("Final answer", "Not found")
184
-
185
- report += f"Correct answer: {correct_answer}"
186
-
187
- with open(f"text_report_{task_id}.txt", "w", encoding="utf-8") as f:
188
- f.write(report)
189
-
190
- return report
191
-
192
- # WORKFLOW EXECUTION
193
- def execute_workflow(tasks_file: str, output_file: str):
194
-
195
- with open(tasks_file, 'r', encoding='utf-8') as f:
196
- tasks = json.load(f)
197
-
198
- results_json = []
199
-
200
- for task in tasks:
201
- task_id = task["task_id"]
202
- question = task["question"]
203
- temp_file_to_cleanup = None
204
-
205
- # prepare content for HumanMessage
206
- if task.get("file_name"):
207
- original_filename_from_task = task["file_name"]
208
- temp_file_path = download_and_save_task_file(task_id, original_filename_from_task)
209
-
210
- if temp_file_path:
211
- temp_file_to_cleanup = temp_file_path
212
-
213
- # Construct the question, relying on the LLM to infer file type from the path's extension
214
- question = task["question"] + f"\n\nAttached file: {temp_file_path}"
215
- print(f"File for task {task_id} processed and available at: {temp_file_path}")
216
- else:
217
- print(f"Failed to download or save file for task {task_id} using filename '{original_filename_from_task}'. Proceeding with question only.")
218
-
219
- # run agent
220
- try:
221
- print(f"Running agent for task {task_id}")
222
- print(f"Question: {question}")
223
-
224
- workflow = create_workflow()
225
-
226
- result = workflow.invoke({
227
- "messages": [
228
- SystemMessage(content=MAIN_SYSTEM_PROMPT + "\n\n" + TOOL_USE_INSTRUCTION + "\n\n" + EXECUTION_INSTRUCTION),
229
- HumanMessage(content=question)
230
- ]
231
- })
232
-
233
- # Extract final answer - result is the state itself with messages
234
- messages = result.get("messages", [])
235
- final_answer = ""
236
-
237
- # Get the content from the last message that has content
238
- for msg in reversed(messages):
239
- if hasattr(msg, "content") and msg.content:
240
- content = msg.content
241
- # Extract answer using the template format
242
- if "FINAL ANSWER:" in content:
243
- final_answer = content.split("FINAL ANSWER:")[1].strip()
244
- else:
245
- final_answer = content
246
- break
247
-
248
- if not final_answer:
249
- final_answer = "No answer generated"
250
-
251
- # Extract question decomposition
252
- question_decomposition = result.get("question_decomposition", "")
253
-
254
- # Save results to results.json
255
- if DEBUG_MODE:
256
- validation_data = []
257
- with open("metadata_val.jsonl", "r", encoding="utf-8") as file:
258
- for line in file:
259
- validation_data.append(json.loads(line))
260
-
261
- correct_answer = "Not found"
262
- for task_metadata in validation_data:
263
- if task_metadata.get("task_id") == task_id:
264
- correct_answer = task_metadata.get("Final answer", "Not found")
265
- break
266
- results_json.append({"task_id": task_id, "model_answer": final_answer, "correct_answer": correct_answer, "question_decomposition": question_decomposition})
267
- else:
268
- results_json.append({"task_id": task_id, "model_answer": final_answer, "question_decomposition": question_decomposition})
269
-
270
- # Save state to txt report
271
- if DEBUG_MODE:
272
- save_txt_report(result, task_id)
273
-
274
- except Exception as e:
275
- print(f"Error processing task {task_id}: {str(e)}")
276
- import traceback
277
- traceback.print_exc()
278
-
279
- finally:
280
- if temp_file_to_cleanup and os.path.exists(temp_file_to_cleanup):
281
- cleanup_temp_files(temp_file_to_cleanup)
282
- elif task.get("file_name") and not temp_file_to_cleanup:
283
- print(f"Note for task {task_id}: A file was expected, but no temporary file was successfully processed or tracked for cleanup.")
284
-
285
- # save batch results
286
- with open(output_file, 'w', encoding='utf-8') as f:
287
- json.dump(results_json, f, indent=2)
288
-
289
- return print(f"Results saved to {output_file}")
290
-
291
- # MAIN
292
- def main():
293
- parser = argparse.ArgumentParser(description='Process tasks from a JSON file and save results to an output file.')
294
- parser.add_argument('--tasks-file', type=str, default='tasks.json', help='Path to the JSON file containing tasks')
295
- parser.add_argument('--output-file', type=str, default='results.json', help='Path to the output JSON file')
296
- args = parser.parse_args()
297
-
298
- if not check_api_keys():
299
- print("API keys are missing. Please set the required environment variables.")
300
- return
301
-
302
- try:
303
- execute_workflow(args.tasks_file, args.output_file)
304
- except Exception as e:
305
- print(f"Critical error: {str(e)}")
306
- import traceback
307
- traceback.print_exc()
308
- return 1
309
-
310
- if __name__ == "__main__":
311
- main()