File size: 1,243 Bytes
5dfcdef
 
 
 
 
 
 
 
 
 
 
 
6109e7b
5dfcdef
 
 
 
6109e7b
5dfcdef
6109e7b
5dfcdef
 
6109e7b
 
5dfcdef
6109e7b
5dfcdef
 
6109e7b
5dfcdef
 
 
6109e7b
5dfcdef
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from src.graph import build_graph
from dotenv import load_dotenv
import logging
import pandas as pd
import json


load_dotenv()


class ReActAgent:
    def __init__(self):
        # print("ReActAgent initialized.")
        self.graph = build_graph()
        with open("prompts/system_prompt_short.txt", "r", encoding="utf-8") as f:
            self.system_message = f.read()

        self.result_file = open('results/result7.jsonl', 'a')

    def __call__(self, task) -> str:
        initial_state = {
            'system_message': self.system_message,
            'question': task.get("question"),
            'file_name': task.get("file_name"),
        }
        final_state = self.graph.invoke(initial_state)
        final_answer = final_state.get("final_answer", None)

        row = {'task_id': task.get("task_id"), 'question': task.get("question"), 'agent_answer': final_answer}
        json.dump(row, self.result_file)
        self.result_file.write('\n')

        return final_answer

def main():
    agent = ReActAgent()

    gaia_bench_1_test = pd.read_json('../gaia_bench_1_test.jsonl', lines=True)

    for i, task in gaia_bench_1_test.iterrows():
        agent(task.question, task.file_name)

if __name__ == "__main__":
    main()