olyandrevn commited on
Commit
6109e7b
·
1 Parent(s): 267b8cc
Files changed (4) hide show
  1. agent.py +8 -11
  2. app.py +1 -1
  3. src/nodes.py +3 -3
  4. src/tools.py +1 -1
agent.py CHANGED
@@ -10,30 +10,27 @@ load_dotenv()
10
 
11
  class ReActAgent:
12
  def __init__(self):
13
- print("ReActAgent initialized.")
14
  self.graph = build_graph()
15
  with open("prompts/system_prompt_short.txt", "r", encoding="utf-8") as f:
16
  self.system_message = f.read()
17
 
18
- self.result_file = open('results/result6.jsonl', 'a')
19
-
20
- def __call__(self, question: str, file_name: str) -> str:
21
- print(f"Agent received question (first 50 chars): {question[:50]}...")
22
 
 
23
  initial_state = {
24
  'system_message': self.system_message,
25
- 'question': question,
26
- 'file_name': file_name,
27
  }
28
- final_state = graph.invoke(initial_state)
29
  final_answer = final_state.get("final_answer", None)
30
 
31
- row = {'task_id': task.task_id, 'question': task.question, 'gt': task['Final answer'], 'agent_answer': final_answer}
32
  json.dump(row, self.result_file)
33
  self.result_file.write('\n')
34
 
35
- print(f"Agent returning fixed answer: {fixed_answer}")
36
- return fixed_answer
37
 
38
  def main():
39
  agent = ReActAgent()
 
10
 
11
  class ReActAgent:
12
  def __init__(self):
13
+ # print("ReActAgent initialized.")
14
  self.graph = build_graph()
15
  with open("prompts/system_prompt_short.txt", "r", encoding="utf-8") as f:
16
  self.system_message = f.read()
17
 
18
+ self.result_file = open('results/result7.jsonl', 'a')
 
 
 
19
 
20
+ def __call__(self, task) -> str:
21
  initial_state = {
22
  'system_message': self.system_message,
23
+ 'question': task.get("question"),
24
+ 'file_name': task.get("file_name"),
25
  }
26
+ final_state = self.graph.invoke(initial_state)
27
  final_answer = final_state.get("final_answer", None)
28
 
29
+ row = {'task_id': task.get("task_id"), 'question': task.get("question"), 'agent_answer': final_answer}
30
  json.dump(row, self.result_file)
31
  self.result_file.write('\n')
32
 
33
+ return final_answer
 
34
 
35
  def main():
36
  agent = ReActAgent()
app.py CHANGED
@@ -95,7 +95,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
95
  print(f"Skipping item with missing task_id or question: {item}")
96
  continue
97
  try:
98
- submitted_answer = agent(question_text, file_name)
99
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
100
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
101
  except Exception as e:
 
95
  print(f"Skipping item with missing task_id or question: {item}")
96
  continue
97
  try:
98
+ submitted_answer = agent(item)
99
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
100
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
101
  except Exception as e:
src/nodes.py CHANGED
@@ -26,9 +26,9 @@ class AnswerTemplate(BaseModel):
26
  tools = [
27
  calculator,
28
  wiki_search,
29
- web_search,
30
  reverse_string,
31
- tool_download_image,
32
  tool_read_files,
33
  ]
34
 
@@ -62,7 +62,7 @@ def assistant(state: AgentState):
62
  temperature=0,
63
  timeout=None,
64
  max_retries=2,
65
- top_p=0.7,
66
  # truncation='auto',
67
  )
68
 
 
26
  tools = [
27
  calculator,
28
  wiki_search,
29
+ # web_search,
30
  reverse_string,
31
+ # tool_download_image,
32
  tool_read_files,
33
  ]
34
 
 
62
  temperature=0,
63
  timeout=None,
64
  max_retries=2,
65
+ top_p=0.8,
66
  # truncation='auto',
67
  )
68
 
src/tools.py CHANGED
@@ -61,7 +61,7 @@ def wiki_search(query: str) -> dict:
61
 
62
  Args:
63
  query: The search query."""
64
- search_docs = WikipediaRetriever(load_max_docs=1).invoke(query)
65
  wiki_results = []
66
 
67
  for doc in search_docs:
 
61
 
62
  Args:
63
  query: The search query."""
64
+ search_docs = WikipediaRetriever(load_max_docs=1, top_k_results=2).invoke(query)
65
  wiki_results = []
66
 
67
  for doc in search_docs: