YI Zhongyue committed on
Commit ·
989f731
1
Parent(s): 482e90e
update agent and prompt
Browse files
agent.py
CHANGED
|
@@ -10,7 +10,11 @@ from smolagents import (
|
|
| 10 |
WikipediaSearchTool,
|
| 11 |
)
|
| 12 |
|
| 13 |
-
from prompt import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
if pathlib.Path(".env").exists():
|
| 16 |
dotenv.load_dotenv(".env")
|
|
@@ -50,6 +54,41 @@ calc_agent = CodeAgent(
|
|
| 50 |
tools=[]
|
| 51 |
)
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
class Agent:
|
| 54 |
def __init__(self):
|
| 55 |
self.agent = CodeAgent(
|
|
@@ -61,7 +100,8 @@ class Agent:
|
|
| 61 |
|
| 62 |
|
| 63 |
def __call__(self, question: str) -> str:
|
| 64 |
-
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
if __name__ == "__main__":
|
|
|
|
| 10 |
WikipediaSearchTool,
|
| 11 |
)
|
| 12 |
|
| 13 |
+
from prompt import (
|
| 14 |
+
calc_agent_prompt,
|
| 15 |
+
gen_GAIA_answer_formatter_prompt,
|
| 16 |
+
web_search_agent_prompt,
|
| 17 |
+
)
|
| 18 |
|
| 19 |
if pathlib.Path(".env").exists():
|
| 20 |
dotenv.load_dotenv(".env")
|
|
|
|
| 54 |
tools=[]
|
| 55 |
)
|
| 56 |
|
| 57 |
+
class ExtractFailedException(Exception):
    """Signal that no formatted answer could be pulled out of an LLM response."""
|
| 60 |
+
|
| 61 |
+
class GAIAAnswerFormatter:
    """Callable that asks an LLM to rewrite an agent's answer in GAIA format.

    Calling an instance builds a prompt from the question and the raw answer,
    sends it to ``model`` and returns the text found between the answer tags
    of the reply. If nothing can be extracted, the raw answer is returned
    unchanged.
    """

    # Class-level attribute: created once when the class body executes and
    # shared by every instance.
    model = OpenAIServerModel(
        model_id="gpt-4.1-mini",
        api_key=OPENAI_API_KEY,
    )

    def extract_formatted_answer(self, llm_response: str) -> str:
        """Return the stripped text enclosed by the answer tags.

        Args:
            llm_response: Raw text of the LLM reply.

        Returns:
            The content between the opening and closing answer tags, with
            surrounding whitespace removed.

        Raises:
            ExtractFailedException: If the reply is not a string or contains
                no matching pair of answer tags.
        """
        import re

        # A reply without textual content (e.g. None) would make re.search
        # raise TypeError, which the fallback in __call__ does not catch.
        # Report it as an ordinary extraction failure instead.
        if not isinstance(llm_response, str):
            raise ExtractFailedException(
                "Failed to extract formatted answer from the LLM response."
            )

        # The prompt requests the (misspelled) <formated_answer> tag; also
        # accept <formatted_answer> in case the model corrects the spelling.
        # The backreference forces the closing tag to match the opening one.
        match = re.search(
            r'<(formatt?ed_answer)>(.*?)</\1>', llm_response, re.DOTALL
        )
        if match:
            return match.group(2).strip()
        raise ExtractFailedException(
            "Failed to extract formatted answer from the LLM response."
        )

    def __call__(self, question: str, answer: str) -> str:
        """Format ``answer`` for ``question`` and return the result.

        Args:
            question: The original question given to the agent.
            answer: The agent's raw answer.

        Returns:
            The extracted formatted answer, or ``answer`` unchanged when the
            LLM reply contains no extractable formatted answer.
        """
        message = [{
            "role": "user",
            "content": gen_GAIA_answer_formatter_prompt(question, answer),
        }]
        response = self.model.generate(messages=message)
        try:
            return self.extract_formatted_answer(response.content)
        except ExtractFailedException:
            # Best-effort formatting: fall back to the unformatted answer.
            return answer
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Module-level formatter instance; Agent.__call__ passes each raw agent answer through it.
answer_formatter = GAIAAnswerFormatter()
|
| 91 |
+
|
| 92 |
class Agent:
|
| 93 |
def __init__(self):
|
| 94 |
self.agent = CodeAgent(
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
def __call__(self, question: str) -> str:
|
| 103 |
+
answer = self.agent.run(question)
|
| 104 |
+
return answer_formatter(question, answer)
|
| 105 |
|
| 106 |
|
| 107 |
if __name__ == "__main__":
|
prompt.py
CHANGED
|
@@ -288,4 +288,35 @@ Now Begin!""",
|
|
| 288 |
pre_messages="",
|
| 289 |
post_messages="",
|
| 290 |
),
|
| 291 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
pre_messages="",
|
| 289 |
post_messages="",
|
| 290 |
),
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
def gen_GAIA_answer_formatter_prompt(question: str, answer: str) -> str:
    """Build the prompt that asks an LLM to validate and reformat an answer.

    Args:
        question: The original question, inserted verbatim into the prompt.
        answer: The agent's answer, inserted verbatim into the prompt.

    Returns:
        The complete prompt text. It instructs the model to wrap its final
        answer in ``<formated_answer>`` tags.
    """
    # NOTE: the tag spelling "formated_answer" is what the extraction regex
    # in agent.py searches for, so it is kept exactly as is.
    prompt_text = f"""You are a GAIA answer format validator and formatter. Your task is to check if an agent's answer meets GAIA benchmark requirements and reformat it if necessary.

GAIA FORMAT REQUIREMENTS:
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings
- If asked for a number: don't use commas to write numbers, no units like $ or % unless specified
- If asked for a string: no articles (a, an, the), no abbreviations (e.g. for cities), write digits in plain text unless specified
- If asked for a comma separated list: apply above rules for each element

ORIGINAL QUESTION:
{question}

AGENT'S ANSWER:
{answer}

INSTRUCTIONS:
1. First, analyze what type of answer the question is asking for (number, string, or list)
2. Check if the agent's answer meets GAIA format requirements
3. If the answer is already correctly formatted, return it as is
4. If the answer needs reformatting, extract the core information and reformat according to GAIA rules
5. Provide your final answer wrapped in XML tags: <formated_answer>YOUR FORMATTED ANSWER</formated_answer>

Remember:
- Keep only essential information
- Remove unnecessary words, articles, and explanations
- Follow the specific formatting rules for numbers, strings, or lists
- If the agent's answer contains multiple pieces of information, extract only what the question specifically asks for
- The content inside <formated_answer> tags should be the clean, formatted answer without any additional text

Now analyze and format the answer:"""
    return prompt_text
|