Delete main.py
Browse files
main.py
DELETED
|
@@ -1,311 +0,0 @@
|
|
| 1 |
-
# Standard libraries
|
| 2 |
-
import argparse
|
| 3 |
-
import json
|
| 4 |
-
import logging
|
| 5 |
-
import os
|
| 6 |
-
from dotenv import load_dotenv
|
| 7 |
-
from typing_extensions import TypedDict
|
| 8 |
-
from typing import Dict, List, Any, Optional, Annotated
|
| 9 |
-
|
| 10 |
-
# Langchain and langgraph
|
| 11 |
-
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, AnyMessage
|
| 12 |
-
from langgraph.graph import StateGraph, START, END
|
| 13 |
-
from langgraph.graph.message import add_messages
|
| 14 |
-
from langgraph.prebuilt import ToolNode, tools_condition
|
| 15 |
-
|
| 16 |
-
# Custom modules
|
| 17 |
-
from prompts import MAIN_SYSTEM_PROMPT, QUESTION_DECOMPOSITION_PROMPT, TOOL_USE_INSTRUCTION, EXECUTION_INSTRUCTION
|
| 18 |
-
from utils import check_api_keys, setup_llm, download_and_save_task_file, cleanup_temp_files
|
| 19 |
-
from tools import calculator_tool, extract_text_from_image, transcribe_audio, execute_python_code, read_file, web_search, wikipedia_search, arxiv_search, chess_board_image_analysis, find_phrase_in_text, download_youtube_audio, web_content_extract, analyse_tabular_data
|
| 20 |
-
|
| 21 |
-
# SETUP
# Configure file-based logging so each run produces a fresh trace of the
# agent's activity.
logging.basicConfig(
    level=logging.INFO,  # Set default level to INFO
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    filename='agent_run.log',
    filemode='w',  # 'a' to append to the file on each run, 'w' to overwrite
)

# Pull API keys and other secrets from a local .env file into os.environ.
load_dotenv()

# Load runtime configuration from the project-local .config JSON file.
# Explicit encoding avoids platform-dependent default codecs (e.g. cp1252
# on Windows) if the file ever contains non-ASCII text.
with open('.config', 'r', encoding='utf-8') as f:
    config = json.load(f)
BASE_URL = config['BASE_URL']      # endpoint used elsewhere for task/file downloads
DEBUG_MODE = config['DEBUG_MODE']  # when truthy, also record ground-truth answers and txt reports
|
| 35 |
-
|
| 36 |
-
# AGENT STATE
class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph workflow."""
    task_id: str  # unique identifier of the task being solved
    file_name: Optional[str]  # original filename attached to the task, if any
    file_type: Optional[str]  # type of the attached file — NOTE(review): never populated in this file; confirm a caller sets it
    file_path: Optional[str]  # local path of the downloaded task file — NOTE(review): also not populated here
    question_decomposition: Optional[str]  # plan text produced by the decomposition node
    messages: Annotated[list[AnyMessage], add_messages]  # conversation history; add_messages merges node updates
    # tool_calls: Optional[List[str]] # Note: This field doesn't seem to be explicitly used or populated in the current code.
    tool_results: Dict # Results from tool executions (populated implicitly by ToolNode?)
    error_message: Optional[str]  # human-readable error recorded by a node on failure
|
| 47 |
-
|
| 48 |
-
# WORKFLOW SETUP
def create_workflow():
    """Build and compile the LangGraph agent workflow.

    Graph shape::

        START -> decomposition -> agent <-> tools
                                    |
                                   END   (when the agent emits no tool calls)

    Returns:
        A compiled LangGraph application whose ``invoke`` accepts an
        ``AgentState``-shaped dict (at minimum a ``messages`` list).
    """
    # Setup LLMs and tools. setup_llm() returns five models; only the agent
    # and decomposition models are used here.
    llm_agent_management, llm_question_decomposition, _, _, _ = setup_llm()

    tools = [
        web_search,
        web_content_extract,
        wikipedia_search,
        calculator_tool,
        extract_text_from_image,
        transcribe_audio,
        execute_python_code,
        read_file,
        arxiv_search,
        chess_board_image_analysis,
        find_phrase_in_text,
        download_youtube_audio,
        analyse_tabular_data
    ]

    # Expose the tool schemas to the agent model so it can emit tool calls.
    llm_agent_management_with_tools = llm_agent_management.bind_tools(tools)

    # Define nodes
    def question_decomposition(state: AgentState):
        """Analyze question and decompose it into a plan to answer it. The plan will be used to guide the agent's actions."""
        new_state = state.copy()  # Create a copy of the current state
        messages = new_state["messages"]  # Get the messages from the current state

        # Find the HumanMessage to analyze (the first one is the task question).
        question = None
        for msg in messages:
            if isinstance(msg, HumanMessage):
                question = msg.content
                break

        if not question:
            new_state["error_message"] = "No question found in the state for question_decomposition node"
            return new_state  # Return copy of original state unchanged

        # Create and invoke the analysis prompt
        question_decomposition_prompt = [
            SystemMessage(content=QUESTION_DECOMPOSITION_PROMPT),
            HumanMessage(content=f"Decompose this question : {question}")
        ]

        question_decomposition_object = llm_question_decomposition.invoke(question_decomposition_prompt)
        question_decomposition_response = question_decomposition_object.content

        # Update the state with the question analysis
        new_state["question_decomposition"] = question_decomposition_response

        # Return the complete new state with all fields preserved
        return new_state

    def call_model(state: AgentState):
        """Invoke the LLM with the current state."""
        new_state = state.copy()
        messages = new_state["messages"]
        question_decomposition = new_state.get("question_decomposition", "")

        # Prepare messages for the LLM call. The decomposition is injected only
        # when the last message is NOT a ToolMessage, i.e. on fresh turns rather
        # than mid tool-call loop (the tool loop already has it in context).
        llm_messages = messages.copy()
        add_decomposition = question_decomposition and (not messages or not isinstance(messages[-1], ToolMessage))

        if add_decomposition:
            decomposition_message = SystemMessage(content=f"Question decomposition: {question_decomposition}\nUse this analysis to guide your actions.")
            llm_messages.append(decomposition_message)

        response = llm_agent_management_with_tools.invoke(llm_messages)
        # Append the model reply; the injected decomposition SystemMessage is
        # deliberately NOT persisted into state, only sent for this call.
        new_state["messages"] = messages + [response]
        return new_state

    # Setup workflow. Routing from "agent" uses the prebuilt tools_condition,
    # which sends the run to "tools" when the last AI message has tool calls
    # and to END otherwise (this replaced a hand-rolled router).
    workflow = StateGraph(AgentState)

    workflow.add_node("decomposition", question_decomposition)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", ToolNode(tools))

    workflow.add_edge(START, "decomposition")
    workflow.add_edge("decomposition", "agent")
    workflow.add_conditional_edges("agent", tools_condition)
    workflow.add_edge("tools", "agent")

    return workflow.compile()
|
| 145 |
-
|
| 146 |
-
# REPORTING
def save_txt_report(state: "AgentState", task_id: str):
    """Create a txt report from the state and write it to text_report_<task_id>.txt.

    The report contains the original question, the decomposition plan, the
    full message chain (with tool calls for AI messages), and the ground-truth
    answer looked up in metadata_val.jsonl.

    Args:
        state: Final workflow state; ``state["messages"][0]`` is assumed to be
            the question message.
        task_id: Task identifier, used for the metadata lookup and file name.

    Returns:
        The report text that was written to disk.

    Raises:
        FileNotFoundError: If metadata_val.jsonl is missing (debug-only file).
    """
    messages = state["messages"]
    report = ""

    # question wording
    question = messages[0].content
    report += f"Question: {question}\n\n"

    # question decomposition
    question_decomposition = state.get("question_decomposition", "No decomposition available")
    report += f"Question decomposition: {question_decomposition}\n\n"

    # message content
    report += "Message Chain:\n"
    for msg in messages:
        msg_type = type(msg).__name__  # Get the class name (e.g., "HumanMessage")
        report += f"--- {msg_type} ---\n"
        report += f"{msg.content}\n"
        # Optionally add tool call info for AIMessages
        if hasattr(msg, 'tool_calls') and msg.tool_calls:
            report += f"Tool Calls: {msg.tool_calls}\n"
        # ToolMessage content often includes tool output directly,
        # but you could format it differently if needed.
        report += "---\n\n"

    # find the task with matching task_id; break on first hit so we stop
    # scanning once found (task_ids are unique in the validation set)
    correct_answer = ""
    with open("metadata_val.jsonl", "r", encoding="utf-8") as file:
        for line in file:
            task_metadata = json.loads(line)
            if task_metadata.get("task_id") == task_id:
                correct_answer = task_metadata.get("Final answer", "Not found")
                break

    report += f"Correct answer: {correct_answer}"

    with open(f"text_report_{task_id}.txt", "w", encoding="utf-8") as f:
        f.write(report)

    return report
|
| 191 |
-
|
| 192 |
-
# WORKFLOW EXECUTION
def execute_workflow(tasks_file: str, output_file: str):
    """Run the agent on every task in ``tasks_file`` and save answers as JSON.

    Each task is a dict with at least ``task_id`` and ``question``, and
    optionally ``file_name`` for an attachment that is downloaded to a temp
    path and referenced in the question. When DEBUG_MODE is set, the
    ground-truth answer from metadata_val.jsonl and a per-task txt report
    are saved as well. Per-task failures are logged and skipped so one bad
    task does not abort the batch.

    Args:
        tasks_file: Path to a JSON file containing the list of tasks.
        output_file: Path where the JSON list of results is written.
    """

    def _correct_answer_for(task_id):
        """Look up the ground-truth answer for task_id in the validation metadata (debug only)."""
        with open("metadata_val.jsonl", "r", encoding="utf-8") as file:
            for line in file:
                task_metadata = json.loads(line)
                if task_metadata.get("task_id") == task_id:
                    return task_metadata.get("Final answer", "Not found")
        return "Not found"

    with open(tasks_file, 'r', encoding='utf-8') as f:
        tasks = json.load(f)

    results_json = []

    # Compile the graph once for the whole batch: the compiled workflow is
    # reusable across invoke() calls, and rebuilding it per task (as the code
    # previously did) repeated the LLM/tool setup for every task.
    workflow = create_workflow()

    for task in tasks:
        task_id = task["task_id"]
        question = task["question"]
        temp_file_to_cleanup = None

        # prepare content for HumanMessage: download any attachment and point
        # the model at its local path.
        if task.get("file_name"):
            original_filename_from_task = task["file_name"]
            temp_file_path = download_and_save_task_file(task_id, original_filename_from_task)

            if temp_file_path:
                temp_file_to_cleanup = temp_file_path

                # Construct the question, relying on the LLM to infer file type from the path's extension
                question = task["question"] + f"\n\nAttached file: {temp_file_path}"
                print(f"File for task {task_id} processed and available at: {temp_file_path}")
            else:
                print(f"Failed to download or save file for task {task_id} using filename '{original_filename_from_task}'. Proceeding with question only.")

        # run agent
        try:
            print(f"Running agent for task {task_id}")
            print(f"Question: {question}")

            result = workflow.invoke({
                "messages": [
                    SystemMessage(content=MAIN_SYSTEM_PROMPT + "\n\n" + TOOL_USE_INSTRUCTION + "\n\n" + EXECUTION_INSTRUCTION),
                    HumanMessage(content=question)
                ]
            })

            # Extract final answer - result is the state itself with messages
            messages = result.get("messages", [])
            final_answer = ""

            # Get the content from the last message that has content
            for msg in reversed(messages):
                if hasattr(msg, "content") and msg.content:
                    content = msg.content
                    # Extract answer using the template format
                    if "FINAL ANSWER:" in content:
                        final_answer = content.split("FINAL ANSWER:")[1].strip()
                    else:
                        final_answer = content
                    break

            if not final_answer:
                final_answer = "No answer generated"

            # Extract question decomposition
            question_decomposition = result.get("question_decomposition", "")

            # Record the result; in debug mode also attach the ground truth.
            if DEBUG_MODE:
                results_json.append({"task_id": task_id, "model_answer": final_answer, "correct_answer": _correct_answer_for(task_id), "question_decomposition": question_decomposition})
            else:
                results_json.append({"task_id": task_id, "model_answer": final_answer, "question_decomposition": question_decomposition})

            # Save state to txt report
            if DEBUG_MODE:
                save_txt_report(result, task_id)

        except Exception as e:
            print(f"Error processing task {task_id}: {str(e)}")
            import traceback
            traceback.print_exc()

        finally:
            # Always remove the downloaded temp file, even when the task failed.
            if temp_file_to_cleanup and os.path.exists(temp_file_to_cleanup):
                cleanup_temp_files(temp_file_to_cleanup)
            elif task.get("file_name") and not temp_file_to_cleanup:
                print(f"Note for task {task_id}: A file was expected, but no temporary file was successfully processed or tracked for cleanup.")

    # save batch results
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results_json, f, indent=2)

    print(f"Results saved to {output_file}")
|
| 290 |
-
|
| 291 |
-
# MAIN
def main():
    """CLI entry point: parse arguments, verify API keys, run the batch.

    Returns 1 when the batch raises; returns None on success or when the
    required API keys are absent.
    """
    arg_parser = argparse.ArgumentParser(description='Process tasks from a JSON file and save results to an output file.')
    arg_parser.add_argument('--tasks-file', type=str, default='tasks.json', help='Path to the JSON file containing tasks')
    arg_parser.add_argument('--output-file', type=str, default='results.json', help='Path to the output JSON file')
    parsed = arg_parser.parse_args()

    # Guard clause: bail out before any work if credentials are missing.
    if not check_api_keys():
        print("API keys are missing. Please set the required environment variables.")
        return

    try:
        execute_workflow(parsed.tasks_file, parsed.output_file)
    except Exception as e:
        print(f"Critical error: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|