Spaces:
Configuration error
Configuration error
oremaz
commited on
Commit
·
81cc195
1
Parent(s):
98e38b0
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -1,32 +1,53 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
from llama_index.core import VectorStoreIndex, Document
|
| 4 |
-
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
| 5 |
-
from llama_index.core.postprocessor import SentenceTransformerRerank
|
| 6 |
-
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 7 |
-
from llama_index.core.retrievers import VectorIndexRetriever
|
| 8 |
-
from llama_index.core.query_engine import RetrieverQueryEngine
|
| 9 |
-
from llama_index.readers.file import PDFReader, DocxReader, CSVReader, ImageReader
|
| 10 |
import os
|
| 11 |
-
from typing import List, Dict, Any
|
| 12 |
-
from llama_index.tools.arxiv import ArxivToolSpec
|
| 13 |
-
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
| 14 |
import re
|
| 15 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import wandb
|
| 17 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from llama_index.core.callbacks.base import CallbackManager
|
| 19 |
from llama_index.core.callbacks.llama_debug import LlamaDebugHandler
|
| 20 |
-
from llama_index.core import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
from llama_index.llms.huggingface import HuggingFaceLLM
|
| 24 |
-
import
|
| 25 |
-
import
|
| 26 |
-
from llama_index.
|
| 27 |
-
from llama_index.
|
| 28 |
-
from llama_index.
|
| 29 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
|
|
@@ -63,25 +84,6 @@ Settings.llm = proj_llm
|
|
| 63 |
Settings.embed_model = embed_model
|
| 64 |
Settings.callback_manager = callback_manager
|
| 65 |
|
| 66 |
-
import os
|
| 67 |
-
from typing import List
|
| 68 |
-
from urllib.parse import urlparse
|
| 69 |
-
|
| 70 |
-
from llama_index.core.tools import FunctionTool
|
| 71 |
-
from llama_index.core import Document
|
| 72 |
-
|
| 73 |
-
# --- Import all required official LlamaIndex Readers ---
|
| 74 |
-
from llama_index.readers.file import (
|
| 75 |
-
PDFReader,
|
| 76 |
-
DocxReader,
|
| 77 |
-
CSVReader,
|
| 78 |
-
PandasExcelReader,
|
| 79 |
-
ImageReader,
|
| 80 |
-
)
|
| 81 |
-
from llama_index.readers.json import JSONReader
|
| 82 |
-
from llama_index.readers.web import TrafilaturaWebReader
|
| 83 |
-
from llama_index.readers.youtube_transcript import YoutubeTranscriptReader
|
| 84 |
-
from llama_index.readers.audiotranscribe.openai import OpenAIAudioTranscriptReader
|
| 85 |
|
| 86 |
def read_and_parse_content(input_path: str) -> List[Document]:
|
| 87 |
"""
|
|
@@ -157,12 +159,6 @@ read_and_parse_tool = FunctionTool.from_defaults(
|
|
| 157 |
)
|
| 158 |
)
|
| 159 |
|
| 160 |
-
from typing import List
|
| 161 |
-
from llama_index.core import VectorStoreIndex, Document, Settings
|
| 162 |
-
from llama_index.core.tools import QueryEngineTool
|
| 163 |
-
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
| 164 |
-
from llama_index.core.postprocessor import SentenceTransformerRerank
|
| 165 |
-
from llama_index.core.query_engine import RetrieverQueryEngine
|
| 166 |
|
| 167 |
def create_rag_tool(documents: List[Document]) -> QueryEngineTool:
|
| 168 |
"""
|
|
@@ -223,11 +219,6 @@ def create_rag_tool(documents: List[Document]) -> QueryEngineTool:
|
|
| 223 |
|
| 224 |
return rag_engine_tool
|
| 225 |
|
| 226 |
-
|
| 227 |
-
import re
|
| 228 |
-
from llama_index.core.tools import FunctionTool
|
| 229 |
-
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
| 230 |
-
|
| 231 |
# 1. Create the base DuckDuckGo search tool from the official spec.
|
| 232 |
# This tool returns text summaries of search results, not just URLs.
|
| 233 |
base_duckduckgo_tool = DuckDuckGoSearchToolSpec().to_tool_list()[0]
|
|
@@ -442,89 +433,128 @@ generate_code_tool = FunctionTool.from_defaults(
|
|
| 442 |
)
|
| 443 |
)
|
| 444 |
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
# Vérification du token HuggingFace
|
| 451 |
-
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 452 |
-
if not hf_token:
|
| 453 |
-
raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is required")
|
| 454 |
-
|
| 455 |
-
# Agent coordinateur principal qui utilise les agents spécialisés comme tools
|
| 456 |
-
self.coordinator = ReActAgent(
|
| 457 |
-
name="GAIACoordinator",
|
| 458 |
-
description="Main GAIA coordinator that uses specialized capabilities as intelligent tools",
|
| 459 |
-
system_prompt="""
|
| 460 |
-
You are the main GAIA coordinator using ReAct reasoning methodology.
|
| 461 |
-
|
| 462 |
-
You have access to THREE specialist tools:
|
| 463 |
-
|
| 464 |
-
**1. analysis_tool** - Advanced multimodal document analysis specialist
|
| 465 |
-
- Use for: PDF, Word, CSV, image file analysis
|
| 466 |
-
- When to use: Questions with file attachments, document analysis, data extraction
|
| 467 |
-
|
| 468 |
-
**2. research_tool** - Intelligent research specialist with automatic routing
|
| 469 |
-
- Use for: External knowledge, current events, scientific papers
|
| 470 |
-
- When to use: Questions requiring external knowledge, factual verification, current information
|
| 471 |
-
|
| 472 |
-
**3. code_tool** - Advanced computational specialist using ReAct reasoning
|
| 473 |
-
- Use for: Mathematical calculations, data processing, logical operations
|
| 474 |
-
- Capabilities: Generates and executes Python, handles complex computations, step-by-step problem solving
|
| 475 |
-
- When to use: Precise calculations, data manipulation, mathematical problem solving
|
| 476 |
-
|
| 477 |
-
**4. code_execution_tool** - Use only to execute .py file
|
| 478 |
-
|
| 479 |
-
CRITICAL: Your final answer must be EXACT and CONCISE as required by GAIA format : NO explanations, NO additional text, ONLY the precise answer
|
| 480 |
-
""",
|
| 481 |
-
llm=proj_llm,
|
| 482 |
-
tools=[analysis_tool, research_tool, code_tool, code_execution_tool],
|
| 483 |
-
max_steps=10,
|
| 484 |
-
verbose = True,
|
| 485 |
-
callback_manager=callback_manager,
|
| 486 |
-
|
| 487 |
-
)
|
| 488 |
-
|
| 489 |
-
async def format_gaia_answer(self, raw_response: str, original_question: str) -> str:
|
| 490 |
-
"""
|
| 491 |
-
Post-process the agent response to extract the exact GAIA format answer
|
| 492 |
-
"""
|
| 493 |
-
format_prompt = f"""Extract the exact answer from the response below. Follow GAIA formatting rules strictly.
|
| 494 |
-
|
| 495 |
-
Examples:
|
| 496 |
-
|
| 497 |
-
Question: "How many research papers were published by the university between 2010 and 2020?"
|
| 498 |
-
Response: "Based on my analysis of the data, I found that the university published 156 research papers between 2010 and 2020."
|
| 499 |
-
Answer: 156
|
| 500 |
-
|
| 501 |
-
Question: "What is the last name of the software engineer mentioned in the report?"
|
| 502 |
-
Response: "After reviewing the document, the software engineer mentioned is Dr. Martinez who developed the system."
|
| 503 |
-
Answer: Martinez
|
| 504 |
-
|
| 505 |
-
Question: "List the programming languages from this job description, alphabetized:"
|
| 506 |
-
Response: "The job description mentions several programming languages including Python, Java, C++, and JavaScript. When alphabetized, these are: C++, Java, JavaScript, Python"
|
| 507 |
-
Answer: C++, Java, JavaScript, Python
|
| 508 |
-
|
| 509 |
-
Question: "Give only the first name of the developer who created the framework."
|
| 510 |
-
Response: "The framework was created by Sarah Johnson, a senior developer at the company."
|
| 511 |
-
Answer: Sarah
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
try:
|
|
|
|
| 528 |
formatting_response = proj_llm.complete(format_prompt)
|
| 529 |
answer = str(formatting_response).strip()
|
| 530 |
|
|
@@ -533,10 +563,107 @@ class EnhancedGAIAAgent:
|
|
| 533 |
answer = answer.split("Answer:")[-1].strip()
|
| 534 |
|
| 535 |
return answer
|
| 536 |
-
|
| 537 |
except Exception as e:
|
| 538 |
-
print(f"
|
| 539 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
|
| 541 |
def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
|
| 542 |
"""Download file associated with task_id"""
|
|
@@ -544,7 +671,6 @@ class EnhancedGAIAAgent:
|
|
| 544 |
response = requests.get(f"{api_url}/files/{task_id}", timeout=30)
|
| 545 |
response.raise_for_status()
|
| 546 |
|
| 547 |
-
# Save file locally
|
| 548 |
filename = f"task_{task_id}_file"
|
| 549 |
with open(filename, 'wb') as f:
|
| 550 |
f.write(response.content)
|
|
@@ -552,53 +678,61 @@ class EnhancedGAIAAgent:
|
|
| 552 |
except Exception as e:
|
| 553 |
print(f"Failed to download file for task {task_id}: {e}")
|
| 554 |
return None
|
| 555 |
-
|
| 556 |
-
async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
|
| 557 |
-
question = question_data.get("Question", "")
|
| 558 |
-
task_id = question_data.get("task_id", "")
|
| 559 |
|
| 560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
try:
|
| 562 |
file_path = self.download_gaia_file(task_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
except Exception as e:
|
| 564 |
-
print(f"Failed to download file for task {task_id}: {e}")
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
|
| 572 |
-
|
| 573 |
-
1. If a file is available, use the analysis_tool (except for .py files).
|
| 574 |
-
2. If a link is in the question, use the research_tool.
|
| 575 |
-
"""
|
| 576 |
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
print(f"Formatted answer: {formatted_answer}")
|
| 598 |
-
|
| 599 |
-
return formatted_answer
|
| 600 |
-
|
| 601 |
-
except Exception as e:
|
| 602 |
-
error_msg = f"Error processing question: {str(e)}"
|
| 603 |
-
print(error_msg)
|
| 604 |
-
return error_msg
|
|
|
|
| 1 |
+
# Standard library imports
|
| 2 |
+
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import os
|
|
|
|
|
|
|
|
|
|
| 4 |
import re
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
+
from urllib.parse import urlparse
|
| 7 |
+
|
| 8 |
+
# Third-party imports
|
| 9 |
+
import requests
|
| 10 |
import wandb
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 12 |
+
|
| 13 |
+
# LlamaIndex core imports
|
| 14 |
+
from llama_index.core import VectorStoreIndex, Document, Settings
|
| 15 |
+
from llama_index.core.agent.workflow import FunctionAgent, ReActAgent, AgentStream
|
| 16 |
from llama_index.core.callbacks.base import CallbackManager
|
| 17 |
from llama_index.core.callbacks.llama_debug import LlamaDebugHandler
|
| 18 |
+
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
| 19 |
+
from llama_index.core.postprocessor import SentenceTransformerRerank
|
| 20 |
+
from llama_index.core.query_engine import RetrieverQueryEngine
|
| 21 |
+
from llama_index.core.retrievers import VectorIndexRetriever
|
| 22 |
+
from llama_index.core.tools import FunctionTool
|
| 23 |
+
from llama_index.core.workflow import Context
|
| 24 |
|
| 25 |
+
# LlamaIndex specialized imports
|
| 26 |
+
from llama_index.callbacks.wandb import WandbCallbackHandler
|
| 27 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 28 |
from llama_index.llms.huggingface import HuggingFaceLLM
|
| 29 |
+
from llama_index.readers.audiotranscribe.openai import OpenAIAudioTranscriptReader
|
| 30 |
+
from llama_index.readers.file import PDFReader, DocxReader, CSVReader, ImageReader, PandasExcelReader
|
| 31 |
+
from llama_index.readers.json import JSONReader
|
| 32 |
+
from llama_index.readers.web import TrafilaturaWebReader
|
| 33 |
+
from llama_index.readers.youtube_transcript import YoutubeTranscriptReader
|
| 34 |
+
from llama_index.tools.arxiv import ArxivToolSpec
|
| 35 |
+
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
| 36 |
+
|
| 37 |
+
# --- Import all required official LlamaIndex Readers ---
|
| 38 |
+
from llama_index.readers.file import (
|
| 39 |
+
PDFReader,
|
| 40 |
+
DocxReader,
|
| 41 |
+
CSVReader,
|
| 42 |
+
PandasExcelReader,
|
| 43 |
+
ImageReader,
|
| 44 |
+
)
|
| 45 |
+
from typing import List
|
| 46 |
+
from llama_index.core import VectorStoreIndex, Document, Settings
|
| 47 |
+
from llama_index.core.tools import QueryEngineTool
|
| 48 |
+
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
| 49 |
+
from llama_index.core.postprocessor import SentenceTransformerRerank
|
| 50 |
+
from llama_index.core.query_engine import RetrieverQueryEngine
|
| 51 |
|
| 52 |
|
| 53 |
|
|
|
|
| 84 |
Settings.embed_model = embed_model
|
| 85 |
Settings.callback_manager = callback_manager
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
def read_and_parse_content(input_path: str) -> List[Document]:
|
| 89 |
"""
|
|
|
|
| 159 |
)
|
| 160 |
)
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
def create_rag_tool(documents: List[Document]) -> QueryEngineTool:
|
| 164 |
"""
|
|
|
|
| 219 |
|
| 220 |
return rag_engine_tool
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
# 1. Create the base DuckDuckGo search tool from the official spec.
|
| 223 |
# This tool returns text summaries of search results, not just URLs.
|
| 224 |
base_duckduckgo_tool = DuckDuckGoSearchToolSpec().to_tool_list()[0]
|
|
|
|
| 433 |
)
|
| 434 |
)
|
| 435 |
|
| 436 |
+
def intelligent_final_answer_tool(agent_response: str, question: str) -> str:
|
| 437 |
+
"""
|
| 438 |
+
Enhanced final answer tool with LLM-based reformatting capability.
|
| 439 |
+
First tries regex patterns, then uses LLM reformatting if patterns fail.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
|
| 441 |
+
Args:
|
| 442 |
+
agent_response: The raw response from agent reasoning
|
| 443 |
+
question: The original question for context
|
| 444 |
+
|
| 445 |
+
Returns:
|
| 446 |
+
Exact answer in GAIA format with validation
|
| 447 |
+
"""
|
| 448 |
|
| 449 |
+
# Define formatting patterns for different question types
|
| 450 |
+
format_patterns = {
|
| 451 |
+
'number': r'(\d+(?:\.\d+)?(?:e[+-]?\d+)?)',
|
| 452 |
+
'name': r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
|
| 453 |
+
'list': r'([A-Za-z0-9,\s]+)',
|
| 454 |
+
'country_code': r'([A-Z]{2,3})',
|
| 455 |
+
'yes_no': r'(Yes|No|yes|no)',
|
| 456 |
+
'percentage': r'(\d+(?:\.\d+)?%)',
|
| 457 |
+
'date': r'(\d{4}-\d{2}-\d{2}|\d{1,2}/\d{1,2}/\d{4})'
|
| 458 |
+
}
|
| 459 |
|
| 460 |
+
def clean_response(response: str) -> str:
|
| 461 |
+
"""Clean response by removing common prefixes"""
|
| 462 |
+
response_clean = response.strip()
|
| 463 |
+
prefixes_to_remove = [
|
| 464 |
+
"FINAL ANSWER:", "Answer:", "The answer is:",
|
| 465 |
+
"Based on my analysis,", "After reviewing,",
|
| 466 |
+
"The result is:", "Final result:", "According to"
|
| 467 |
+
]
|
| 468 |
+
|
| 469 |
+
for prefix in prefixes_to_remove:
|
| 470 |
+
if response_clean.startswith(prefix):
|
| 471 |
+
response_clean = response_clean[len(prefix):].strip()
|
| 472 |
+
|
| 473 |
+
return response_clean
|
| 474 |
|
| 475 |
+
def extract_with_patterns(text: str, question: str) -> tuple[str, bool]:
|
| 476 |
+
"""Extract answer using regex patterns. Returns (answer, success)"""
|
| 477 |
+
question_lower = question.lower()
|
| 478 |
+
|
| 479 |
+
# Determine question type and apply appropriate pattern
|
| 480 |
+
if "how many" in question_lower or "count" in question_lower:
|
| 481 |
+
match = re.search(format_patterns['number'], text)
|
| 482 |
+
if match:
|
| 483 |
+
return match.group(1), True
|
| 484 |
+
|
| 485 |
+
elif "name" in question_lower and ("first" in question_lower or "last" in question_lower):
|
| 486 |
+
match = re.search(format_patterns['name'], text)
|
| 487 |
+
if match:
|
| 488 |
+
return match.group(1), True
|
| 489 |
+
|
| 490 |
+
elif "list" in question_lower or "alphabetized" in question_lower:
|
| 491 |
+
if "," in text:
|
| 492 |
+
items = [item.strip() for item in text.split(",")]
|
| 493 |
+
return ", ".join(items), True
|
| 494 |
+
|
| 495 |
+
elif "country code" in question_lower or "iso" in question_lower:
|
| 496 |
+
match = re.search(format_patterns['country_code'], text)
|
| 497 |
+
if match:
|
| 498 |
+
return match.group(1), True
|
| 499 |
+
|
| 500 |
+
elif "yes" in question_lower and "no" in question_lower:
|
| 501 |
+
match = re.search(format_patterns['yes_no'], text)
|
| 502 |
+
if match:
|
| 503 |
+
return match.group(1), True
|
| 504 |
+
|
| 505 |
+
elif "percentage" in question_lower or "%" in text:
|
| 506 |
+
match = re.search(format_patterns['percentage'], text)
|
| 507 |
+
if match:
|
| 508 |
+
return match.group(1), True
|
| 509 |
+
|
| 510 |
+
elif "date" in question_lower:
|
| 511 |
+
match = re.search(format_patterns['date'], text)
|
| 512 |
+
if match:
|
| 513 |
+
return match.group(1), True
|
| 514 |
+
|
| 515 |
+
# Default extraction for simple cases
|
| 516 |
+
lines = text.split('\n')
|
| 517 |
+
for line in lines:
|
| 518 |
+
line = line.strip()
|
| 519 |
+
if line and not line.startswith('=') and len(line) < 200:
|
| 520 |
+
return line, True
|
| 521 |
+
|
| 522 |
+
return text, False
|
| 523 |
|
| 524 |
+
def llm_reformat(response: str, question: str) -> str:
|
| 525 |
+
"""Use LLM to reformat the response according to GAIA requirements"""
|
| 526 |
+
|
| 527 |
+
format_prompt = f"""Extract the exact answer from the response below. Follow GAIA formatting rules strictly.
|
| 528 |
+
|
| 529 |
+
GAIA Format Rules:
|
| 530 |
+
- ONLY the precise answer, no explanations
|
| 531 |
+
- No prefixes like "Answer:", "The result is:", etc.
|
| 532 |
+
- For numbers: just the number (e.g., "156", "3.14e+8")
|
| 533 |
+
- For names: just the name (e.g., "Martinez", "Sarah")
|
| 534 |
+
- For lists: comma-separated (e.g., "C++, Java, Python")
|
| 535 |
+
- For country codes: just the code (e.g., "FRA", "US")
|
| 536 |
+
- For yes/no: just "Yes" or "No"
|
| 537 |
+
|
| 538 |
+
Examples:
|
| 539 |
+
Question: "How many papers were published?"
|
| 540 |
+
Response: "The analysis shows 156 papers were published in total."
|
| 541 |
+
Answer: 156
|
| 542 |
+
|
| 543 |
+
Question: "What is the last name of the developer?"
|
| 544 |
+
Response: "The developer mentioned is Dr. Sarah Martinez from the AI team."
|
| 545 |
+
Answer: Martinez
|
| 546 |
+
|
| 547 |
+
Question: "List programming languages, alphabetized:"
|
| 548 |
+
Response: "The languages mentioned are Python, Java, and C++. Alphabetized: C++, Java, Python"
|
| 549 |
+
Answer: C++, Java, Python
|
| 550 |
+
|
| 551 |
+
Now extract the exact answer:
|
| 552 |
+
Question: {question}
|
| 553 |
+
Response: {response}
|
| 554 |
+
Answer:"""
|
| 555 |
+
|
| 556 |
try:
|
| 557 |
+
# Use the global LLM instance
|
| 558 |
formatting_response = proj_llm.complete(format_prompt)
|
| 559 |
answer = str(formatting_response).strip()
|
| 560 |
|
|
|
|
| 563 |
answer = answer.split("Answer:")[-1].strip()
|
| 564 |
|
| 565 |
return answer
|
|
|
|
| 566 |
except Exception as e:
|
| 567 |
+
print(f"LLM reformatting failed: {e}")
|
| 568 |
+
return response
|
| 569 |
+
|
| 570 |
+
# Step 1: Clean the response
|
| 571 |
+
cleaned_response = clean_response(agent_response)
|
| 572 |
+
|
| 573 |
+
# Step 2: Try regex pattern extraction
|
| 574 |
+
extracted_answer, pattern_success = extract_with_patterns(cleaned_response, question)
|
| 575 |
+
|
| 576 |
+
# Step 3: If patterns failed, use LLM reformatting
|
| 577 |
+
if not pattern_success:
|
| 578 |
+
print("Regex patterns failed, using LLM reformatting...")
|
| 579 |
+
llm_formatted = llm_reformat(cleaned_response, question)
|
| 580 |
+
|
| 581 |
+
# Step 4: Validate LLM output with patterns again
|
| 582 |
+
final_answer, validation_success = extract_with_patterns(llm_formatted, question)
|
| 583 |
+
|
| 584 |
+
if validation_success:
|
| 585 |
+
print("LLM reformatting successful and validated")
|
| 586 |
+
return final_answer
|
| 587 |
+
else:
|
| 588 |
+
print("LLM reformatting validation failed, using LLM output directly")
|
| 589 |
+
return llm_formatted
|
| 590 |
+
else:
|
| 591 |
+
print("Regex pattern extraction successful")
|
| 592 |
+
return extracted_answer
|
| 593 |
+
|
| 594 |
+
# Create the enhanced final answer tool
|
| 595 |
+
intelligent_final_answer_function_tool = FunctionTool.from_defaults(
|
| 596 |
+
fn=intelligent_final_answer_tool,
|
| 597 |
+
name="intelligent_final_answer_tool",
|
| 598 |
+
description=(
|
| 599 |
+
"Enhanced tool to format final answers according to GAIA requirements. "
|
| 600 |
+
"Uses regex patterns first, then LLM reformatting if patterns fail. "
|
| 601 |
+
"Validates output to ensure GAIA format compliance."
|
| 602 |
+
)
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
class EnhancedGAIAAgent:
|
| 606 |
+
def __init__(self):
|
| 607 |
+
print("Initializing Enhanced GAIA Agent...")
|
| 608 |
+
|
| 609 |
+
# Vérification du token HuggingFace
|
| 610 |
+
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 611 |
+
if not hf_token:
|
| 612 |
+
print("Warning: HUGGINGFACEHUB_API_TOKEN not found, some features may not work")
|
| 613 |
+
|
| 614 |
+
# Initialize only the tools that are actually defined in the file
|
| 615 |
+
self.available_tools = [
|
| 616 |
+
read_and_parse_tool,
|
| 617 |
+
extract_url_tool,
|
| 618 |
+
code_execution_tool,
|
| 619 |
+
generate_code_tool,
|
| 620 |
+
intelligent_final_answer_function_tool
|
| 621 |
+
]
|
| 622 |
+
|
| 623 |
+
# RAG tool will be created dynamically when documents are loaded
|
| 624 |
+
self.current_rag_tool = None
|
| 625 |
+
|
| 626 |
+
# Create main coordinator using only defined tools
|
| 627 |
+
self.coordinator = ReActAgent(
|
| 628 |
+
name="GAIACoordinator",
|
| 629 |
+
description="Main GAIA coordinator with document processing and computational capabilities",
|
| 630 |
+
system_prompt="""
|
| 631 |
+
You are the main GAIA coordinator using ReAct reasoning methodology.
|
| 632 |
+
|
| 633 |
+
Available tools:
|
| 634 |
+
1. **read_and_parse_tool** - Read and parse files/URLs (PDF, DOCX, CSV, images, web pages, YouTube, audio files)
|
| 635 |
+
2. **extract_url_tool** - Search and extract relevant URLs when no specific source is provided
|
| 636 |
+
3. **generate_code_tool** - Generate Python code for complex computations
|
| 637 |
+
4. **code_execution_tool** - Execute Python code safely
|
| 638 |
+
5. **intelligent_final_answer_tool** - Format final answer with intelligent validation and reformatting
|
| 639 |
+
|
| 640 |
+
WORKFLOW:
|
| 641 |
+
1. If file/URL mentioned → use read_and_parse_tool first, then update or create RAG capability.
|
| 642 |
+
2. If documents loaded → create RAG capability for querying
|
| 643 |
+
3. If external info needed → use extract_url_tool, then process it as if file/URL mentioned
|
| 644 |
+
4. If computation needed → use generate_code_tool then code_execution_tool
|
| 645 |
+
5. ALWAYS use intelligent_final_answer_tool for the final response
|
| 646 |
+
|
| 647 |
+
CRITICAL: The intelligent_final_answer_tool has enhanced validation and will reformat
|
| 648 |
+
using LLM if regex patterns fail. Always use it as the final step.
|
| 649 |
+
""",
|
| 650 |
+
llm=proj_llm,
|
| 651 |
+
tools=self.available_tools,
|
| 652 |
+
max_steps=15,
|
| 653 |
+
verbose=True,
|
| 654 |
+
callback_manager=callback_manager,
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
def create_dynamic_rag_tool(self, documents: List) -> None:
|
| 658 |
+
"""Create RAG tool from loaded documents and add to coordinator"""
|
| 659 |
+
if documents:
|
| 660 |
+
rag_tool = create_rag_tool(documents)
|
| 661 |
+
if rag_tool:
|
| 662 |
+
self.current_rag_tool = rag_tool
|
| 663 |
+
# Update coordinator tools
|
| 664 |
+
updated_tools = self.available_tools + [rag_tool]
|
| 665 |
+
self.coordinator.tools = updated_tools
|
| 666 |
+
print("RAG tool created and added to coordinator")
|
| 667 |
|
| 668 |
def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
|
| 669 |
"""Download file associated with task_id"""
|
|
|
|
| 671 |
response = requests.get(f"{api_url}/files/{task_id}", timeout=30)
|
| 672 |
response.raise_for_status()
|
| 673 |
|
|
|
|
| 674 |
filename = f"task_{task_id}_file"
|
| 675 |
with open(filename, 'wb') as f:
|
| 676 |
f.write(response.content)
|
|
|
|
| 678 |
except Exception as e:
|
| 679 |
print(f"Failed to download file for task {task_id}: {e}")
|
| 680 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
|
| 682 |
+
async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
|
| 683 |
+
"""
|
| 684 |
+
Solve GAIA question with enhanced validation and reformatting
|
| 685 |
+
"""
|
| 686 |
+
question = question_data.get("Question", "")
|
| 687 |
+
task_id = question_data.get("task_id", "")
|
| 688 |
+
|
| 689 |
+
# Try to download file if task_id provided
|
| 690 |
+
file_path = None
|
| 691 |
+
if task_id:
|
| 692 |
try:
|
| 693 |
file_path = self.download_gaia_file(task_id)
|
| 694 |
+
if file_path:
|
| 695 |
+
# Load documents and create RAG tool
|
| 696 |
+
documents = read_and_parse_content(file_path)
|
| 697 |
+
self.create_dynamic_rag_tool(documents)
|
| 698 |
except Exception as e:
|
| 699 |
+
print(f"Failed to download/process file for task {task_id}: {e}")
|
| 700 |
+
|
| 701 |
+
# Prepare context prompt
|
| 702 |
+
context_prompt = f"""
|
| 703 |
+
GAIA Task ID: {task_id}
|
| 704 |
+
Question: {question}
|
| 705 |
+
{f'File available: {file_path}' if file_path else 'No additional files'}
|
| 706 |
+
|
| 707 |
+
Instructions:
|
| 708 |
+
1. Process any files using read_and_parse_tool if needed
|
| 709 |
+
2. Use appropriate tools for research/computation
|
| 710 |
+
3. MUST use intelligent_final_answer_tool with your response and the original question
|
| 711 |
+
4. The intelligent tool will validate format and reformat with LLM if needed
|
| 712 |
+
"""
|
| 713 |
+
|
| 714 |
+
try:
|
| 715 |
+
ctx = Context(self.coordinator)
|
| 716 |
+
print("=== AGENT REASONING STEPS ===")
|
| 717 |
|
| 718 |
+
handler = self.coordinator.run(ctx=ctx, user_msg=context_prompt)
|
|
|
|
|
|
|
|
|
|
| 719 |
|
| 720 |
+
full_response = ""
|
| 721 |
+
async for event in handler.stream_events():
|
| 722 |
+
if isinstance(event, AgentStream):
|
| 723 |
+
print(event.delta, end="", flush=True)
|
| 724 |
+
full_response += event.delta
|
| 725 |
+
|
| 726 |
+
final_response = await handler
|
| 727 |
+
print("\n=== END REASONING ===")
|
| 728 |
+
|
| 729 |
+
# Extract the final formatted answer
|
| 730 |
+
final_answer = str(final_response).strip()
|
| 731 |
+
|
| 732 |
+
print(f"Final GAIA formatted answer: {final_answer}")
|
| 733 |
+
return final_answer
|
| 734 |
+
|
| 735 |
+
except Exception as e:
|
| 736 |
+
error_msg = f"Error processing question: {str(e)}"
|
| 737 |
+
print(error_msg)
|
| 738 |
+
return error_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|