|
|
""" |
|
|
Answer Formatter - Final answer formatting according to GAIA requirements |
|
|
|
|
|
The Answer Formatter is responsible for: |
|
|
1. Taking the draft answer and formatting it according to GAIA rules |
|
|
2. Extracting the final answer from comprehensive responses |
|
|
3. Ensuring exact-match compliance |
|
|
4. Handling different answer types (numbers, strings, lists) |
|
|
""" |
|
|
|
|
|
import re |
|
|
from typing import Dict, Any |
|
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage |
|
|
from langgraph.types import Command |
|
|
from langchain_groq import ChatGroq |
|
|
from observability import agent_span |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load environment variables from the "env.local" file at import time.
# NOTE(review): presumably this supplies the credentials ChatGroq needs — confirm.
load_dotenv("env.local")
|
|
|
|
|
|
|
|
# Labels that introduce an explicit answer, tried in priority order.
_ANSWER_LABELS = ("final answer", "answer", "result", "conclusion")

# Lead-in phrases removed from the start of the extracted text, in this order.
_LEADING_PHRASES = (
    "the answer is", "it is", "this is", "that is",
    "final answer:", "answer:", "result:", "conclusion:",
    "therefore", "thus", "so", "hence",
)

# A single quote/bracket character at either end of a value.
_WRAPPER_RE = re.compile(r'^["\'\[\(]|["\'\]\)]$')


def _strip_leading_phrase(text: str, phrase: str) -> str:
    """Remove *phrase* from the start of *text* only when it is a whole lead-in."""
    if text[:len(phrase)].lower() != phrase:
        return text
    rest = text[len(phrase):]
    # A phrase ending in a letter must be followed by whitespace or light
    # punctuation; otherwise it is just the start of a longer word
    # ("so" in "Sofia", "thus" in "Thusnelda") and must be left alone.
    if phrase[-1].isalnum() and rest and not (rest[0].isspace() or rest[0] in ",:;"):
        return text
    return rest.strip().lstrip(",:;").strip()


def _normalize_number(text: str) -> str:
    """Return the canonical form of a plain decimal number, or '' if not one."""
    candidate = text.strip()
    # Thousands-grouped numbers ("1,000", "12,345.50") lose their commas.
    if re.fullmatch(r'-?\d{1,3}(?:,\d{3})+(?:\.\d+)?', candidate):
        candidate = candidate.replace(',', '')
    match = re.fullmatch(r'(-?)(\d+)(?:\.(\d+))?', candidate)
    if not match:
        return ""
    sign = match.group(1)
    # Work on the digits as text: no float round-trip, so no exponent
    # notation ("1e-05") and no precision loss on long numbers.
    whole = match.group(2).lstrip('0') or '0'
    fraction = (match.group(3) or "").rstrip('0')
    digits = f"{whole}.{fraction}" if fraction else whole
    # Negative zero collapses to "0".
    return f"-{digits}" if sign and digits != "0" else digits


def extract_final_answer(text: str) -> str:
    """
    Extract the final answer from text following GAIA formatting rules.

    GAIA Rules:
    • Single number → write the number only (no commas, units, or other symbols)
    • Single string/phrase → write the text only; omit articles and abbreviations unless explicitly required
    • List → separate elements with a single comma and a space
    • Never include surrounding text, quotes, brackets, or markdown

    Processing steps, in order:
    1. If a label ("final answer", "answer", "result", "conclusion") appears
       as a whole word, keep only the rest of that line. An optional "is" and
       any colon/markdown asterisks after the label are skipped.
    2. Strip known lead-in phrases ("it is", "therefore", ...) when they are
       whole words at the start of the text.
    3. Strip one surrounding quote/bracket, ``**`` and backtick at each end.
    4. Numbers are normalised (no grouping commas, no leading zeros, no
       trailing fractional zeros); comma-separated text is re-joined with
       ", "; otherwise a leading article is dropped when more than two
       characters remain.

    Args:
        text: Raw response text (may be empty or None-like).

    Returns:
        The cleaned answer string, or "" when *text* is empty/blank.
    """
    if not text or not text.strip():
        return ""

    text = text.strip()

    # 1. Explicit "<label>: <answer>" line. Word boundaries keep "Results" or
    #    "unanswered" from being mistaken for a label.
    for label in _ANSWER_LABELS:
        match = re.search(
            rf"\b{label}\b(?:\s+is\b)?[:\s*]*(.+?)(?:\n|$)",
            text,
            re.IGNORECASE | re.MULTILINE,
        )
        if match:
            text = match.group(1).strip()
            break

    # 2. Lead-in phrases.
    for phrase in _LEADING_PHRASES:
        text = _strip_leading_phrase(text, phrase)

    # 3. Surrounding quotes/brackets and markdown emphasis/code markers.
    text = _WRAPPER_RE.sub('', text)
    text = re.sub(r'^\*\*|\*\*$', '', text)
    text = re.sub(r'^`|`$', '', text)
    text = text.strip()

    # 4a. Single number.
    number = _normalize_number(text)
    if number:
        return number

    # 4b. List: normalise to "a, b, c", dropping empty items and per-item wrappers.
    if ',' in text:
        cleaned_items = []
        for item in text.split(','):
            item = _WRAPPER_RE.sub('', item.strip())
            if item:
                cleaned_items.append(item)
        return ', '.join(cleaned_items)

    # 4c. Single phrase: drop a leading article unless almost nothing would remain.
    words = text.split()
    if len(words) > 1 and words[0].lower() in ('the', 'a', 'an'):
        remaining = ' '.join(words[1:])
        if len(remaining) > 2:
            text = remaining

    return text.strip()
|
|
|
|
|
|
|
|
# Built-in prompt used when the prompt file is not present on disk.
_DEFAULT_FORMATTER_PROMPT = """
You are a final answer formatter ensuring compliance with GAIA benchmark requirements.

Your task is to extract the precise final answer from a comprehensive response.

CRITICAL FORMATTING RULES:
• Single number → write the number only (no commas, units, or symbols)
• Single string/phrase → write the text only; omit articles unless required
• List → separate elements with comma and space
• NEVER include surrounding text like "Final Answer:", quotes, brackets, or markdown
• The response must contain ONLY the answer itself

Examples:
Question: "What is 25 + 17?"
Draft: "After calculating, the answer is 42."
Formatted: "42"

Question: "What is the capital of France?"
Draft: "The capital of France is Paris."
Formatted: "Paris"

Question: "List the first 3 prime numbers"
Draft: "The first three prime numbers are 2, 3, and 5."
Formatted: "2, 3, 5"

Extract ONLY the final answer following these rules exactly.
"""


def load_formatter_prompt() -> str:
    """Load the formatting prompt.

    Reads ``archive/prompts/verification_prompt.txt`` (relative to the current
    working directory) and returns its contents. When that file does not
    exist, the built-in default prompt is returned instead.

    NOTE(review): the file is named "verification_prompt" although it is used
    as the formatter prompt — confirm this is the intended file.

    Returns:
        The system prompt text for the formatting LLM call.
    """
    try:
        # Explicit UTF-8: the prompt text uses non-ASCII bullets/arrows, and
        # the locale default encoding is platform dependent.
        with open("archive/prompts/verification_prompt.txt", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return _DEFAULT_FORMATTER_PROMPT
|
|
|
|
|
|
|
|
def _first_user_query(messages) -> Any:
    """Return the content of the first HumanMessage in *messages*, or ""."""
    for msg in messages:
        if isinstance(msg, HumanMessage):
            return msg.content
    return ""


def _build_formatting_request(user_query: Any, draft_answer: str) -> str:
    """Build the user-turn text asking the LLM to extract the final answer."""
    return f"""
Extract the final answer from this comprehensive response following GAIA formatting rules:

Original Question: {user_query}

Draft Response:
{draft_answer}

Instructions:
1. Identify the core answer within the draft response
2. Remove all explanatory text, prefixes, and formatting
3. Apply GAIA formatting rules exactly
4. Return ONLY the final answer

What is the properly formatted final answer?
"""


def _format_with_llm(user_query: Any, draft_answer: str) -> str:
    """Ask the LLM for the formatted answer; return "" if that fails."""
    try:
        llm = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0.0,
            max_tokens=512,
        )
        response = llm.invoke([
            SystemMessage(content=load_formatter_prompt()),
            HumanMessage(content=_build_formatting_request(user_query, draft_answer)),
        ])
    except Exception as exc:
        # Boundary with an external service: any failure here (missing key,
        # network, rate limit) should trigger the caller's regex fallback
        # rather than turn the final answer into an error message.
        print(f"⚠️ LLM call failed: {exc}")
        return ""

    content = response.content
    # Only plain-text content can be post-processed by extract_final_answer.
    return extract_final_answer(content) if isinstance(content, str) else ""


def answer_formatter(state: Dict[str, Any]) -> Command:
    """
    Answer Formatter node that creates GAIA-compliant final answers.

    Takes the draft_answer and formats it according to GAIA requirements.
    Returns Command to END the workflow.

    Flow:
    1. No draft answer (missing, empty or None) → fixed "no answer" message.
    2. Otherwise the LLM is asked to extract the answer, and its reply is
       cleaned with ``extract_final_answer``.
    3. If the LLM call fails or yields nothing, ``extract_final_answer`` is
       applied directly to the draft.
    4. If that is empty too, a fixed "unable to extract" message is used.

    Args:
        state: Workflow state. Keys read: "draft_answer", "messages",
            "user_id", "session_id".

    Returns:
        Command with ``goto="__end__"`` and ``update={"final_answer": ...}``.
        Unexpected errors are reported in "final_answer" instead of raised.
    """
    print("📝 Answer Formatter: Creating final formatted answer...")

    try:
        # `or ""` also covers a key that is present but set to None, which
        # would otherwise make len() raise below.
        draft_answer = state.get("draft_answer") or ""

        with agent_span(
            "formatter",
            metadata={
                "draft_answer_length": len(draft_answer),
                "user_id": state.get("user_id", "unknown"),
                "session_id": state.get("session_id", "unknown"),
            },
        ) as span:
            if not draft_answer:
                # Nothing to format: skip prompt loading and the LLM client.
                final_answer = "No answer could be generated."
            else:
                user_query = _first_user_query(state.get("messages") or [])
                final_answer = _format_with_llm(user_query, draft_answer)

                if not final_answer:
                    print("⚠️ LLM formatting failed, using direct extraction")
                    final_answer = extract_final_answer(draft_answer)

                if not final_answer:
                    final_answer = "Unable to extract a clear answer."

            print(f"📝 Answer Formatter: Final answer = '{final_answer}'")

            if span:
                span.update_trace(output={"final_answer": final_answer})

            return Command(
                goto="__end__",
                update={
                    "final_answer": final_answer
                },
            )

    except Exception as e:
        print(f"❌ Answer Formatter Error: {e}")

        return Command(
            goto="__end__",
            update={
                "final_answer": f"Error formatting answer: {str(e)}"
            },
        )