Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Final Answer Tool for GAIA Agent System | |
| Extracts precise, EXACT MATCH compliant answers from agent results | |
| """ | |
| import re | |
| import logging | |
| from typing import Dict, Any, Optional | |
| from models.qwen_client import QwenClient, ModelTier | |
# Module-level logger, named after this module via ``__name__`` so its records
# can be filtered/configured per module.
logger = logging.getLogger(name=__name__)
class FinalAnswerTool:
    """
    Tool for extracting precise, GAIA-compliant final answers.

    Ensures EXACT MATCH compatibility for Unit 4 API submission: the LLM client
    is prompted for a bare answer, and the raw response is then cleaned,
    shortened and scored with simple heuristics.
    """

    # Keyword groups used to route a question to type-specific prompt rules and
    # cleaning heuristics. Shared by the prompt builder and the cleaner so both
    # classify a question the same way. Matched as whole words (see _has_word),
    # so "count" no longer fires on "country" nor "yes" on "eyes".
    _NUMERIC_WORDS = ("how many", "count", "number", "albums")
    _NAME_WORDS = ("name", "who", "whom", "whose", "which person")
    _LOCATION_WORDS = ("where", "location", "city", "country")
    _BOOLEAN_WORDS = ("yes", "true", "false")

    # Lead-ins the model may emit despite the prompt; matched case-insensitively
    # at the start of the answer and removed repeatedly.
    _ANSWER_PREFIXES = (
        "the answer is",
        "answer:",
        "final answer:",
        "result:",
        "response:",
        "conclusion:",
        "based on",
        "according to",
        "from the",
    )

    # First number in a text. The first alternative accepts thousands
    # separators ("1,234") so such numbers are not cut down to their first group.
    _NUMBER_RE = re.compile(r"-?\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|-?\d+(?:\.\d+)?")
    # A stand-alone yes/no/true/false token -- not the "no" inside "not"/"know".
    _BOOLEAN_RE = re.compile(r"\b(yes|no|true|false)\b", re.IGNORECASE)
    # Sentence / clause breaks count only when followed by whitespace or the end
    # of text, so decimals ("3.14"), times ("10:30") and "1,234" stay intact.
    _SENTENCE_SPLIT_RE = re.compile(r"[.!?](?=\s|$)")
    _CLAUSE_SPLIT_RE = re.compile(r"[,;:](?=\s|$)")

    def __init__(self, llm_client: "QwenClient"):
        self.llm_client = llm_client

    def extract_final_answer(self, question: str, agent_results: str, question_type: str = "") -> Dict[str, Any]:
        """
        Extract the precise final answer in GAIA-compliant format.

        Args:
            question: The original GAIA question
            agent_results: Combined results from multiple agents
            question_type: Type of question (for specialized extraction);
                ``None`` is treated as ``""``.

        Returns:
            Dict with keys "answer" (str), "confidence" (float) and
            "reasoning" (str). This method does not raise: an unsuccessful
            LLM call yields answer "Processing failed" and any exception
            yields "Extraction error", both with confidence 0.0.
        """
        question_type = question_type or ""
        try:
            logger.info("🎯 Extracting GAIA-compliant final answer")

            extraction_prompt = self._create_extraction_prompt(question, agent_results, question_type)

            # COMPLEX tier is the 72B model; the small token budget forces a
            # terse reply suitable for exact-match scoring.
            result = self.llm_client.generate(
                extraction_prompt,
                tier=ModelTier.COMPLEX,
                max_tokens=50,
            )

            if not result.success:
                logger.error("Final answer extraction failed")
                return {
                    "answer": "Processing failed",
                    "confidence": 0.0,
                    "reasoning": f"Extraction failed: {result.response}",
                }

            extracted_answer = self._clean_answer(result.response, question, question_type)
            validation_result = self._validate_answer(extracted_answer, question_type)

            logger.info(f"✅ Final answer extracted: '{extracted_answer}'")
            return {
                "answer": extracted_answer,
                "confidence": validation_result["confidence"],
                "reasoning": f"Extracted using 72B model. Validation: {validation_result['status']}",
            }
        except Exception as e:
            # Top-level boundary: report failure as a result dict, but keep the
            # traceback in the log instead of only the message.
            error_msg = f"Final answer extraction error: {str(e)}"
            logger.exception(error_msg)
            return {
                "answer": "Extraction error",
                "confidence": 0.0,
                "reasoning": error_msg,
            }

    @staticmethod
    def _has_word(text: str, words: tuple) -> bool:
        """Return True if any of ``words`` occurs in ``text`` as a whole word.

        Matching is case-insensitive. A trailing "s"/"es"/"d"/"ed"/"ing" is
        tolerated so "counts", "named" or "locations" still match, while
        embedded occurrences ("country", "eyes", "whole", "elsewhere") do not.
        """
        lowered = (text or "").lower()
        return any(
            re.search(r"\b" + re.escape(word) + r"(?:s|es|d|ed|ing)?\b", lowered)
            for word in words
        )

    def _create_extraction_prompt(self, question: str, agent_results: str, question_type: str) -> str:
        """Create the extraction prompt with rules chosen from the question type/wording.

        The first matching rule set wins, in this order: numeric, text
        manipulation, yes/no, name, location, generic.
        """
        qtype = (question_type or "").lower()
        question_lower = (question or "").lower()

        base_prompt = f"""
CRITICAL: This is for GAIA benchmark evaluation using EXACT MATCH comparison.
Your response must be ONLY the precise answer - no explanations, no "FINAL ANSWER:", no extra text.
Question: {question}
Agent Analysis Results:
{agent_results}
EXTRACTION RULES:
"""

        # Add type-specific rules
        if "mathematical" in qtype or self._has_word(question, self._NUMERIC_WORDS):
            base_prompt += """
- If asking for a count/number: respond with ONLY the number (e.g., "5", "23", "0")
- If asking for calculation: respond with ONLY the result (e.g., "42", "3.14", "100")
- No units unless specifically requested in the question
"""
        elif "text_manipulation" in qtype or "reverse" in question_lower:
            # Plain substring on purpose: also catches "reversed"/"reversing".
            base_prompt += """
- If text is reversed: provide the corrected text
- If asking for opposite: provide ONLY the opposite word (e.g., "right" for opposite of "left")
- If asking to decode: provide ONLY the decoded answer
"""
        elif self._has_word(question, self._BOOLEAN_WORDS):
            base_prompt += """
- If yes/no question: respond with ONLY "yes" or "no" (lowercase)
- If true/false question: respond with ONLY "true" or "false" (lowercase)
"""
        elif self._has_word(question, self._NAME_WORDS):
            base_prompt += """
- If asking for a name: provide ONLY the name (e.g., "John Smith", "Einstein")
- If asking for first name only: provide ONLY first name (e.g., "John")
- If asking for last name only: provide ONLY last name (e.g., "Smith")
"""
        elif self._has_word(question, self._LOCATION_WORDS):
            base_prompt += """
- If asking for location: provide ONLY the location name (e.g., "Paris", "USA", "New York")
- No additional descriptors unless specifically requested
"""
        else:
            base_prompt += """
- Provide ONLY the direct answer to the question
- No explanations, context, or additional information
- Be as concise as possible while being accurate
"""

        base_prompt += """
EXAMPLES OF CORRECT FORMAT:
- Question: "How many albums?" → Answer: "5"
- Question: "What is the opposite of left?" → Answer: "right"
- Question: "True or false?" → Answer: "true"
- Question: "Who discovered X?" → Answer: "Einstein"
- Question: "Which city?" → Answer: "London"
Extract the precise answer NOW:"""
        return base_prompt

    def _clean_answer(self, raw_answer: str, question: str, question_type: str) -> str:
        """Clean and format the extracted answer for exact-match submission.

        Steps: strip lead-ins and wrapping quotes; condense answers longer than
        50 characters by question type; lowercase single-word text-manipulation
        answers; word-truncate anything still over 40 characters; drop trailing
        punctuation.
        """
        qtype = (question_type or "").lower()
        answer = self._strip_wrappers(raw_answer or "")

        # AGGRESSIVE LENGTH ENFORCEMENT FOR GAIA: keep only the core information.
        if len(answer) > 50:
            answer = self._condense_long_answer(answer, question or "", qtype)

        # For reversed-text questions, normalise single-word answers to lowercase.
        if "text_manipulation" in qtype and len(answer.split()) == 1:
            answer = answer.lower()

        # Final aggressive truncation if still too long.
        if len(answer) > 40:
            answer = self._truncate_words(answer, 40)

        # Remove any trailing punctuation that's not part of the answer.
        return answer.rstrip(".,!?;:").strip()

    @classmethod
    def _strip_wrappers(cls, answer: str) -> str:
        """Remove boilerplate lead-ins ("Final answer:", ...) and enclosing quotes."""
        answer = answer.strip()
        # Repeat until stable so stacked lead-ins ("Final answer: the answer is 5")
        # are all removed; each pass shortens the string, so this terminates.
        while True:
            lowered = answer.lower()
            prefix = next((p for p in cls._ANSWER_PREFIXES if lowered.startswith(p)), None)
            if prefix is None:
                break
            # Also drop a separator left behind, e.g. "The answer is: 42".
            answer = answer[len(prefix):].lstrip(" :").strip()

        # Remove one pair of double, then single, quotes wrapping the whole answer.
        for quote in ('"', "'"):
            if len(answer) >= 2 and answer.startswith(quote) and answer.endswith(quote):
                answer = answer[1:-1]
        return answer

    def _condense_long_answer(self, answer: str, question: str, qtype: str) -> str:
        """Reduce an over-long answer to its core, using type-specific heuristics."""
        if "mathematical" in qtype or self._has_word(question, self._NUMERIC_WORDS):
            # Keep just the first number, without thousands separators.
            number_match = self._NUMBER_RE.search(answer)
            return number_match.group().replace(",", "") if number_match else answer

        if "name" in qtype or self._has_word(question, self._NAME_WORDS):
            words = answer.split()
            return " ".join(words[:3]) if len(words) > 3 else answer

        if "location" in qtype or self._has_word(question, self._LOCATION_WORDS):
            words = answer.split()
            return " ".join(words[:2]) if len(words) > 2 else answer

        if (
            "yes_no" in qtype
            or self._has_word(question, self._BOOLEAN_WORDS)
            or self._BOOLEAN_RE.match(answer)
        ):
            # Reduce to the first stand-alone yes/no/true/false token, if any.
            boolean_match = self._BOOLEAN_RE.search(answer)
            return boolean_match.group(1).lower() if boolean_match else answer

        # Other types: first non-empty sentence, then first clause if still long.
        sentences = (s.strip() for s in self._SENTENCE_SPLIT_RE.split(answer))
        answer = next((s for s in sentences if s), answer)
        if len(answer) > 30:
            answer = self._CLAUSE_SPLIT_RE.split(answer)[0].strip()
        return answer

    @staticmethod
    def _truncate_words(answer: str, limit: int) -> str:
        """Keep as many leading whole words as fit in ``limit`` characters.

        Falls back to a hard character cut when even the first word is too long.
        """
        kept = []
        used = 0
        for word in answer.split():
            # +1 accounts for the separating space.
            if used + len(word) + 1 > limit:
                break
            kept.append(word)
            used += len(word) + 1
        return " ".join(kept) if kept else answer[:limit].strip()

    def _validate_answer(self, answer: str, question_type: str) -> Dict[str, Any]:
        """Score the extracted answer's format.

        Returns:
            Dict with "status" (str label) and "confidence" (float in [0, 1]).
        """
        if not answer:
            return {"status": "empty_answer", "confidence": 0.0}

        # GAIA answers should be concise.
        if len(answer) > 100:
            return {"status": "too_long", "confidence": 0.3}

        qtype = (question_type or "").lower()
        if "mathematical" in qtype:
            if re.match(r'^-?\d+(?:\.\d+)?$', answer):
                return {"status": "valid_number", "confidence": 0.9}
            return {"status": "invalid_number_format", "confidence": 0.5}

        if "yes_no" in qtype:
            if answer.lower() in ("yes", "no", "true", "false"):
                return {"status": "valid_boolean", "confidence": 0.9}
            return {"status": "invalid_boolean_format", "confidence": 0.4}

        # General validation - prefer short, direct answers.
        if len(answer) <= 20:
            return {"status": "concise_answer", "confidence": 0.8}
        if len(answer) <= 50:
            return {"status": "moderate_length", "confidence": 0.6}
        return {"status": "long_answer", "confidence": 0.4}