YI Zhongyue commited on
Commit
989f731
·
1 Parent(s): 482e90e

update agent and prompt

Browse files
Files changed (2) hide show
  1. agent.py +42 -2
  2. prompt.py +32 -1
agent.py CHANGED
@@ -10,7 +10,11 @@ from smolagents import (
10
  WikipediaSearchTool,
11
  )
12
 
13
- from prompt import calc_agent_prompt, web_search_agent_prompt
 
 
 
 
14
 
15
  if pathlib.Path(".env").exists():
16
  dotenv.load_dotenv(".env")
@@ -50,6 +54,41 @@ calc_agent = CodeAgent(
50
  tools=[]
51
  )
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  class Agent:
54
  def __init__(self):
55
  self.agent = CodeAgent(
@@ -61,7 +100,8 @@ class Agent:
61
 
62
 
63
  def __call__(self, question: str) -> str:
64
- return self.agent.run(question)
 
65
 
66
 
67
  if __name__ == "__main__":
 
10
  WikipediaSearchTool,
11
  )
12
 
13
+ from prompt import (
14
+ calc_agent_prompt,
15
+ gen_GAIA_answer_formatter_prompt,
16
+ web_search_agent_prompt,
17
+ )
18
 
19
  if pathlib.Path(".env").exists():
20
  dotenv.load_dotenv(".env")
 
54
  tools=[]
55
  )
56
 
57
+ class ExtractFailedException(Exception):
58
+ """Custom exception for failed extraction of formatted answer."""
59
+ pass
60
+
61
+ class GAIAAnswerFormatter:
62
+
63
+ model = OpenAIServerModel(
64
+ model_id="gpt-4.1-mini",
65
+ api_key=OPENAI_API_KEY,
66
+ )
67
+
68
+ def extract_formatted_answer(self, llm_response: str) -> str:
69
+ import re
70
+
71
+ match = re.search(r'<formated_answer>(.*?)</formated_answer>', llm_response, re.DOTALL)
72
+ if match:
73
+ return match.group(1).strip()
74
+ raise ExtractFailedException(
75
+ "Failed to extract formatted answer from the LLM response."
76
+ )
77
+
78
+ def __call__(self, question: str, answer: str) -> str:
79
+ message = [{
80
+ "role": "user",
81
+ "content": gen_GAIA_answer_formatter_prompt(question, answer),
82
+ }]
83
+ response = self.model.generate(messages=message)
84
+ try:
85
+ return self.extract_formatted_answer(response.content)
86
+ except ExtractFailedException:
87
+ return answer
88
+
89
+
90
+ answer_formatter = GAIAAnswerFormatter()
91
+
92
  class Agent:
93
  def __init__(self):
94
  self.agent = CodeAgent(
 
100
 
101
 
102
  def __call__(self, question: str) -> str:
103
+ answer = self.agent.run(question)
104
+ return answer_formatter(question, answer)
105
 
106
 
107
  if __name__ == "__main__":
prompt.py CHANGED
@@ -288,4 +288,35 @@ Now Begin!""",
288
  pre_messages="",
289
  post_messages="",
290
  ),
291
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  pre_messages="",
289
  post_messages="",
290
  ),
291
+ )
292
+
293
+ def gen_GAIA_answer_formatter_prompt(question: str, answer: str) -> str:
294
+ return f"""You are a GAIA answer format validator and formatter. Your task is to check if an agent's answer meets GAIA benchmark requirements and reformat it if necessary.
295
+
296
+ GAIA FORMAT REQUIREMENTS:
297
+ - YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings
298
+ - If asked for a number: don't use commas to write numbers, no units like $ or % unless specified
299
+ - If asked for a string: no articles (a, an, the), no abbreviations (e.g. for cities), write digits in plain text unless specified
300
+ - If asked for a comma separated list: apply above rules for each element
301
+
302
+ ORIGINAL QUESTION:
303
+ {question}
304
+
305
+ AGENT'S ANSWER:
306
+ {answer}
307
+
308
+ INSTRUCTIONS:
309
+ 1. First, analyze what type of answer the question is asking for (number, string, or list)
310
+ 2. Check if the agent's answer meets GAIA format requirements
311
+ 3. If the answer is already correctly formatted, return it as is
312
+ 4. If the answer needs reformatting, extract the core information and reformat according to GAIA rules
313
+ 5. Provide your final answer wrapped in XML tags: <formated_answer>YOUR FORMATTED ANSWER</formated_answer>
314
+
315
+ Remember:
316
+ - Keep only essential information
317
+ - Remove unnecessary words, articles, and explanations
318
+ - Follow the specific formatting rules for numbers, strings, or lists
319
+ - If the agent's answer contains multiple pieces of information, extract only what the question specifically asks for
320
+ - The content inside <formated_answer> tags should be the clean, formatted answer without any additional text
321
+
322
+ Now analyze and format the answer:"""