|
|
|
|
|
""" |
|
|
Final Answer Tool for GAIA Agent System |
|
|
Extracts precise, EXACT MATCH compliant answers from agent results |
|
|
""" |
|
|
|
|
|
import re |
|
|
import logging |
|
|
from typing import Dict, Any, Optional |
|
|
|
|
|
from models.qwen_client import QwenClient, ModelTier |
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class FinalAnswerTool:
    """
    Tool for extracting precise, GAIA-compliant final answers

    Ensures EXACT MATCH compatibility for Unit 4 API submission.

    Workflow: build a category-specific extraction prompt, ask the LLM client
    for a bare answer, clean the raw response (strip filler prefixes and
    quotes, reduce over-long answers by question category, cap the length)
    and attach a heuristic confidence score.
    """

    # Keywords used to classify a question. They are matched as whole words
    # (optionally followed by "s"/"es"/"ed"/"ing") instead of raw substrings,
    # so e.g. "country" no longer triggers the numeric "count" rule,
    # "electricity" no longer triggers the location "city" rule and
    # "yesterday" no longer triggers the yes/no rule.
    _NUMERIC_KEYWORDS = ("how many", "count", "number", "albums")
    _BOOLEAN_KEYWORDS = ("yes", "true", "false")
    _NAME_KEYWORDS = ("name", "surname", "who", "whom", "whose", "which person")
    _LOCATION_KEYWORDS = ("where", "location", "city", "country")

    # Tokens accepted as a boolean answer.
    _BOOLEAN_ANSWERS = frozenset({"yes", "no", "true", "false"})

    # Filler a model may prepend despite the prompt; matched case-insensitively.
    _PREFIXES_TO_REMOVE = (
        "the answer is",
        "answer:",
        "final answer:",
        "result:",
        "response:",
        "conclusion:",
        "based on",
        "according to",
        "from the",
    )

    # First number in free text; accepts thousands separators such as "1,250".
    _NUMBER_RE = re.compile(r"-?\d+(?:,\d{3}(?!\d))*(?:\.\d+)?")
    # A bare integer or decimal (used for validation).
    _PURE_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

    # Length thresholds used by _clean_answer.
    _LONG_ANSWER_CHARS = 50   # above this, reduce by question category
    _LONG_CLAUSE_CHARS = 30   # above this, keep only the first clause
    _MAX_ANSWER_CHARS = 40    # final cap, applied on a word boundary

    # Category-specific rule sections appended to the extraction prompt.
    _CATEGORY_RULES = {
        "numeric": """
- If asking for a count/number: respond with ONLY the number (e.g., "5", "23", "0")
- If asking for calculation: respond with ONLY the result (e.g., "42", "3.14", "100")
- No units unless specifically requested in the question
""",
        "text_manipulation": """
- If text is reversed: provide the corrected text
- If asking for opposite: provide ONLY the opposite word (e.g., "right" for opposite of "left")
- If asking to decode: provide ONLY the decoded answer
""",
        "boolean": """
- If yes/no question: respond with ONLY "yes" or "no" (lowercase)
- If true/false question: respond with ONLY "true" or "false" (lowercase)
""",
        "name": """
- If asking for a name: provide ONLY the name (e.g., "John Smith", "Einstein")
- If asking for first name only: provide ONLY first name (e.g., "John")
- If asking for last name only: provide ONLY last name (e.g., "Smith")
""",
        "location": """
- If asking for location: provide ONLY the location name (e.g., "Paris", "USA", "New York")
- No additional descriptors unless specifically requested
""",
        "general": """
- Provide ONLY the direct answer to the question
- No explanations, context, or additional information
- Be as concise as possible while being accurate
""",
    }

    # Closing section of the extraction prompt (examples + call to action).
    _EXAMPLES_SECTION = """

EXAMPLES OF CORRECT FORMAT:
- Question: "How many albums?" → Answer: "5"
- Question: "What is the opposite of left?" → Answer: "right"
- Question: "True or false?" → Answer: "true"
- Question: "Who discovered X?" → Answer: "Einstein"
- Question: "Which city?" → Answer: "London"

Extract the precise answer NOW:"""

    def __init__(self, llm_client: "QwenClient"):
        """Store the LLM client used to run the extraction prompt."""
        self.llm_client = llm_client

    def extract_final_answer(self, question: str, agent_results: str, question_type: str = "") -> Dict[str, Any]:
        """
        Extract the precise final answer in GAIA-compliant format

        Args:
            question: The original GAIA question
            agent_results: Combined results from multiple agents
            question_type: Type of question (for specialized extraction);
                None is treated like ""

        Returns:
            Dict with keys "answer" (str), "confidence" (float) and
            "reasoning" (str). This method does not raise: an unsuccessful
            LLM call yields answer "Processing failed" and any exception
            yields answer "Extraction error", both with confidence 0.0.
        """
        question_type = question_type or ""
        try:
            logger.info("🎯 Extracting GAIA-compliant final answer")

            extraction_prompt = self._create_extraction_prompt(question, agent_results, question_type)

            # A short completion suffices: the prompt demands the bare answer only.
            result = self.llm_client.generate(
                extraction_prompt,
                tier=ModelTier.COMPLEX,
                max_tokens=50,
            )

            if not result.success:
                logger.error("Final answer extraction failed")
                return {
                    "answer": "Processing failed",
                    "confidence": 0.0,
                    "reasoning": f"Extraction failed: {result.response}",
                }

            # A successful call with a missing response is treated as an empty
            # answer (scored 0.0 by validation) rather than crashing on .strip().
            extracted_answer = self._clean_answer(result.response or "", question, question_type)
            validation_result = self._validate_answer(extracted_answer, question_type)

            logger.info("✅ Final answer extracted: '%s'", extracted_answer)

            return {
                "answer": extracted_answer,
                "confidence": validation_result["confidence"],
                "reasoning": f"Extracted using 72B model. Validation: {validation_result['status']}",
            }

        except Exception as e:
            error_msg = f"Final answer extraction error: {str(e)}"
            # logger.exception keeps the traceback that logger.error discarded.
            logger.exception(error_msg)
            return {
                "answer": "Extraction error",
                "confidence": 0.0,
                "reasoning": error_msg,
            }

    @staticmethod
    def _mentions_any(text: str, keywords) -> bool:
        """Return True if any keyword appears in *text* as a whole word/phrase.

        Matching is case-insensitive and tolerates a trailing "s", "es", "ed"
        or "ing" ("counts", "named"), but a keyword embedded inside another
        word ("count" in "country") does not match.
        """
        lowered = (text or "").lower()
        return any(
            re.search(r"\b" + re.escape(keyword) + r"(?:s|es|ed|ing)?\b", lowered)
            for keyword in keywords
        )

    def _question_category(self, question: str, question_type: str) -> str:
        """Classify a question so prompt rules and answer cleaning agree.

        Returns one of "numeric", "text_manipulation", "boolean", "name",
        "location" or "general". Checks run in this priority order and the
        first match wins; an explicit question_type hint counts as a match.
        """
        qtype = (question_type or "").lower()
        if "mathematical" in qtype or self._mentions_any(question, self._NUMERIC_KEYWORDS):
            return "numeric"
        if "text_manipulation" in qtype or "reverse" in (question or "").lower():
            return "text_manipulation"
        if "yes_no" in qtype or self._mentions_any(question, self._BOOLEAN_KEYWORDS):
            return "boolean"
        if "name" in qtype or self._mentions_any(question, self._NAME_KEYWORDS):
            return "name"
        if "location" in qtype or self._mentions_any(question, self._LOCATION_KEYWORDS):
            return "location"
        return "general"

    def _create_extraction_prompt(self, question: str, agent_results: str, question_type: str) -> str:
        """Create specialized extraction prompt based on question type

        The prompt contains the question, the agents' results, one rule
        section chosen by _question_category, and a fixed set of examples.
        """
        base_prompt = f"""
CRITICAL: This is for GAIA benchmark evaluation using EXACT MATCH comparison.
Your response must be ONLY the precise answer - no explanations, no "FINAL ANSWER:", no extra text.

Question: {question}

Agent Analysis Results:
{agent_results}

EXTRACTION RULES:
"""
        category = self._question_category(question, question_type)
        base_prompt += self._CATEGORY_RULES.get(category, self._CATEGORY_RULES["general"])
        base_prompt += self._EXAMPLES_SECTION
        return base_prompt

    def _clean_answer(self, raw_answer: str, question: str, question_type: str) -> str:
        """Clean and format the extracted answer

        Steps:
        1. Strip whitespace and, repeatedly, any known filler prefix.
        2. Remove one pair of surrounding double quotes, then single quotes.
        3. If still longer than 50 chars, reduce it by question category.
        4. Lowercase single-word answers when question_type mentions
           "text_manipulation".
        5. Cap at 40 chars on a word boundary and strip trailing punctuation.
        """
        answer = self._strip_prefixes((raw_answer or "").strip())

        if answer.startswith('"') and answer.endswith('"'):
            answer = answer[1:-1]
        if answer.startswith("'") and answer.endswith("'"):
            answer = answer[1:-1]

        if len(answer) > self._LONG_ANSWER_CHARS:
            category = self._question_category(question, question_type)
            answer = self._shorten_long_answer(answer, category)

        if "text_manipulation" in (question_type or "").lower() and len(answer.split()) == 1:
            answer = answer.lower()

        answer = self._truncate_words(answer, self._MAX_ANSWER_CHARS)
        return answer.rstrip(".,!?;:").strip()

    def _strip_prefixes(self, answer: str) -> str:
        """Repeatedly remove leading filler phrases (case-insensitive).

        Looping handles stacked prefixes such as
        "Final answer: The answer is: 5", which a single ordered pass left as
        "The answer is: 5". Each removal shortens the string, so this ends.
        """
        changed = True
        while changed:
            changed = False
            lowered = answer.lower()
            for prefix in self._PREFIXES_TO_REMOVE:
                if lowered.startswith(prefix):
                    # Also drop a colon left behind by e.g. "The answer is: 5".
                    answer = answer[len(prefix):].lstrip(" :").strip()
                    changed = True
                    break
        return answer

    def _shorten_long_answer(self, answer: str, category: str) -> str:
        """Reduce an over-long answer to the part most likely to be the answer.

        numeric  -> first number (thousands separators removed)
        boolean  -> first yes/no/true/false token
        name     -> first three words
        location -> first two words
        When the numeric/boolean extraction finds nothing, and for every
        other category, fall back to the first sentence/clause.
        """
        if category == "numeric":
            match = self._NUMBER_RE.search(answer)
            if match:
                return match.group().replace(",", "")
        elif category == "boolean":
            # Tokenize on letters so "Yes," still yields "yes".
            for token in re.findall(r"[a-z]+", answer.lower()):
                if token in self._BOOLEAN_ANSWERS:
                    return token
        elif category == "name":
            return " ".join(answer.split()[:3])
        elif category == "location":
            return " ".join(answer.split()[:2])
        return self._first_sentence_or_clause(answer)

    def _first_sentence_or_clause(self, answer: str) -> str:
        """Return the first non-empty sentence, or its first clause if still long."""
        # Sentence punctuation only counts when followed by whitespace or the
        # end of text, so decimals like "3.14" are not split.
        sentences = re.split(r"[.!?](?=\s|$)", answer)
        answer = next((s.strip() for s in sentences if s.strip()), answer.strip())

        if len(answer) > self._LONG_CLAUSE_CHARS:
            # Split on ";" always, and on "," / ":" unless between two digits,
            # so "1,000" and "10:30" stay intact.
            clauses = re.split(r";|(?<!\d)[,:]|[,:](?!\d)", answer)
            answer = next((c.strip() for c in clauses if c.strip()), answer)
        return answer

    @staticmethod
    def _truncate_words(answer: str, limit: int) -> str:
        """Cap *answer* at *limit* chars, cutting on a word boundary when possible.

        Each kept word is budgeted with one trailing space. If even the first
        word does not fit, the first *limit* characters are kept instead.
        """
        if len(answer) <= limit:
            return answer
        kept = []
        used = 0
        for word in answer.split():
            if used + len(word) + 1 > limit:
                break
            kept.append(word)
            used += len(word) + 1
        return " ".join(kept) if kept else answer[:limit].strip()

    def _validate_answer(self, answer: str, question_type: str) -> Dict[str, Any]:
        """Validate the extracted answer format

        Returns a dict with "status" (str) and a heuristic "confidence"
        (float). Type-specific checks apply only when question_type mentions
        "mathematical" or "yes_no"; otherwise shorter answers score higher.
        """
        if not answer:
            return {"status": "empty_answer", "confidence": 0.0}

        if len(answer) > 100:
            return {"status": "too_long", "confidence": 0.3}

        qtype = (question_type or "").lower()
        if "mathematical" in qtype:
            if self._PURE_NUMBER_RE.fullmatch(answer):
                return {"status": "valid_number", "confidence": 0.9}
            return {"status": "invalid_number_format", "confidence": 0.5}

        if "yes_no" in qtype:
            if answer.lower() in self._BOOLEAN_ANSWERS:
                return {"status": "valid_boolean", "confidence": 0.9}
            return {"status": "invalid_boolean_format", "confidence": 0.4}

        if len(answer) <= 20:
            return {"status": "concise_answer", "confidence": 0.8}
        if len(answer) <= 50:
            return {"status": "moderate_length", "confidence": 0.6}
        return {"status": "long_answer", "confidence": 0.4}