Spaces:
Runtime error
Runtime error
Omachoko
commited on
Commit
·
2d0e062
1
Parent(s):
50f18bd
GAIA agent: strict output normalization, reasoning planner, RAG, modular tool chaining, robust error handling
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import requests
|
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
| 6 |
from typing import Any
|
|
|
|
| 7 |
|
| 8 |
# (Keep Constants as is)
|
| 9 |
# --- Constants ---
|
|
@@ -281,13 +282,81 @@ Question:
|
|
| 281 |
Answer:
|
| 282 |
"""
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
# --- Refactored ModularGAIAAgent ---
|
| 285 |
class ModularGAIAAgent:
|
| 286 |
-
|
|
|
|
| 287 |
self.api_url = api_url
|
| 288 |
-
self.tools = tool_registry or TOOL_REGISTRY
|
| 289 |
self.reasoning_trace = []
|
| 290 |
self.file_cache = set(os.listdir('.'))
|
|
|
|
| 291 |
|
| 292 |
def fetch_questions(self, from_api=True, questions_path="Hugging Face Questions"):
|
| 293 |
"""Fetch questions from API or local file."""
|
|
@@ -357,15 +426,15 @@ class ModularGAIAAgent:
|
|
| 357 |
"""Analyze file and return context for the question."""
|
| 358 |
try:
|
| 359 |
if file_type == 'audio':
|
| 360 |
-
transcript = self.tools
|
| 361 |
self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
|
| 362 |
return transcript
|
| 363 |
elif file_type == 'image':
|
| 364 |
-
caption = self.tools
|
| 365 |
self.reasoning_trace.append(f"Image caption: {caption}")
|
| 366 |
return caption
|
| 367 |
elif file_type == 'code':
|
| 368 |
-
result = self.tools
|
| 369 |
self.reasoning_trace.append(f"Code analysis result: {result}")
|
| 370 |
return result
|
| 371 |
elif file_type == 'excel':
|
|
@@ -400,41 +469,7 @@ class ModularGAIAAgent:
|
|
| 400 |
self.reasoning_trace.append(f"Analyze file error: {e}")
|
| 401 |
return None
|
| 402 |
|
| 403 |
-
def smart_tool_select(self, question, file_type=None):
|
| 404 |
-
"""Select the best tool(s) for the question, optionally using GPT-4.1 for planning."""
|
| 405 |
-
api_key = os.environ.get("OPENAI_API_KEY", "")
|
| 406 |
-
try:
|
| 407 |
-
if api_key:
|
| 408 |
-
plan_prompt = f"""
|
| 409 |
-
You are an expert AI agent. Given the following question and file type, suggest the best tool(s) to use from this list: {list(self.tools.keys())}.
|
| 410 |
-
Question: {question}
|
| 411 |
-
File type: {file_type}
|
| 412 |
-
Respond with a comma-separated list of tool names only, in order of use. If unsure, start with web_search_duckduckgo.
|
| 413 |
-
"""
|
| 414 |
-
plan = gpt4_chat(plan_prompt, api_key=api_key)
|
| 415 |
-
tool_names = [t.strip() for t in plan.split(',') if t.strip() in self.tools]
|
| 416 |
-
if tool_names:
|
| 417 |
-
return tool_names
|
| 418 |
-
except Exception as e:
|
| 419 |
-
logger.error(f"smart_tool_select planning error: {e}")
|
| 420 |
-
# Fallback: heuristic
|
| 421 |
-
if file_type == 'audio':
|
| 422 |
-
return ['asr_transcribe']
|
| 423 |
-
elif file_type == 'image':
|
| 424 |
-
return ['image_caption']
|
| 425 |
-
elif file_type == 'code':
|
| 426 |
-
return ['code_analysis']
|
| 427 |
-
elif file_type in ['excel', 'csv']:
|
| 428 |
-
return ['table_qa']
|
| 429 |
-
elif 'youtube.com' in question or 'youtu.be' in question:
|
| 430 |
-
return ['youtube_video_qa']
|
| 431 |
-
elif any(w in question.lower() for w in ['wikipedia', 'who', 'when', 'where', 'what', 'how', 'find', 'search']):
|
| 432 |
-
return ['web_search_duckduckgo']
|
| 433 |
-
else:
|
| 434 |
-
return ['llama3_chat']
|
| 435 |
-
|
| 436 |
def answer_question(self, question_obj):
|
| 437 |
-
"""Answer a question using the best tool(s) and context."""
|
| 438 |
self.reasoning_trace = []
|
| 439 |
q = question_obj["question"]
|
| 440 |
file_name = question_obj.get("file_name", "")
|
|
@@ -446,19 +481,23 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
|
|
| 446 |
if local_file:
|
| 447 |
file_type = self.detect_file_type(local_file)
|
| 448 |
file_content = self.analyze_file(local_file, file_type)
|
| 449 |
-
#
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
answer = None
|
| 452 |
-
context = file_content
|
| 453 |
for tool_name in tool_names:
|
| 454 |
-
tool = self.tools
|
| 455 |
try:
|
| 456 |
logger.info(f"Using tool: {tool_name} | Question: {q} | Context: {str(context)[:200]}")
|
| 457 |
if tool_name == 'web_search_duckduckgo':
|
| 458 |
context = tool(q)
|
| 459 |
answer = llama3_chat(build_prompt(context, q))
|
| 460 |
-
elif tool_name == 'gpt4_chat':
|
| 461 |
-
answer = tool(build_prompt(context, q))
|
| 462 |
elif tool_name == 'table_qa' and file_content:
|
| 463 |
answer = tool(q, file_content)
|
| 464 |
elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
|
|
@@ -466,7 +505,6 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
|
|
| 466 |
elif tool_name == 'youtube_video_qa':
|
| 467 |
answer = tool(q, q)
|
| 468 |
else:
|
| 469 |
-
# Always pass context if available
|
| 470 |
if context:
|
| 471 |
answer = llama3_chat(build_prompt(context, q))
|
| 472 |
else:
|
|
@@ -479,13 +517,7 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
|
|
| 479 |
continue
|
| 480 |
self.reasoning_trace.append(f"Tools used: {tool_names}")
|
| 481 |
self.reasoning_trace.append(f"Final answer: {answer}")
|
| 482 |
-
return
|
| 483 |
-
|
| 484 |
-
def format_answer(self, answer):
|
| 485 |
-
"""Strict GAIA: only the answer, no extra text, no prefix."""
|
| 486 |
-
if isinstance(answer, str):
|
| 487 |
-
return answer.strip().split('\n')[0]
|
| 488 |
-
return str(answer)
|
| 489 |
|
| 490 |
# --- Basic Agent Definition (now wraps ModularGAIAAgent) ---
|
| 491 |
class BasicAgent:
|
|
|
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
| 6 |
from typing import Any
|
| 7 |
+
import re
|
| 8 |
|
| 9 |
# (Keep Constants as is)
|
| 10 |
# --- Constants ---
|
|
|
|
| 282 |
Answer:
|
| 283 |
"""
|
| 284 |
|
| 285 |
+
# --- Centralized Output Formatting & Normalization ---
|
| 286 |
+
def gaia_normalize_answer(answer):
|
| 287 |
+
"""Normalize answer for GAIA: remove units, articles, extra text, and ensure concise, factual output."""
|
| 288 |
+
if not isinstance(answer, str):
|
| 289 |
+
answer = str(answer)
|
| 290 |
+
# Remove common articles and units unless required
|
| 291 |
+
answer = answer.strip()
|
| 292 |
+
answer = re.sub(r"\b(the|a|an)\b", "", answer, flags=re.IGNORECASE)
|
| 293 |
+
answer = re.sub(r"\s+", " ", answer)
|
| 294 |
+
# Remove currency, percent, or units unless specified (GAIA rules)
|
| 295 |
+
answer = re.sub(r"\$|%|USD|dollars|euros|eur|\bpercent\b", "", answer, flags=re.IGNORECASE)
|
| 296 |
+
# Remove leading/trailing punctuation
|
| 297 |
+
answer = answer.strip(' .,:;\n\t')
|
| 298 |
+
return answer
|
| 299 |
+
|
| 300 |
+
# --- Reasoning Planner for Tool Chaining ---
|
| 301 |
+
def reasoning_planner(question, file_type, tools):
|
| 302 |
+
"""Plan the sequence of tools to use for a question. Uses LLM or heuristic."""
|
| 303 |
+
# Heuristic: if file_type is known, use the corresponding tool; else, use web search + LLM
|
| 304 |
+
if file_type == 'audio':
|
| 305 |
+
return ['asr_transcribe', 'llama3_chat']
|
| 306 |
+
elif file_type == 'image':
|
| 307 |
+
return ['image_caption', 'llama3_chat']
|
| 308 |
+
elif file_type == 'code':
|
| 309 |
+
return ['code_analysis', 'llama3_chat']
|
| 310 |
+
elif file_type in ['excel', 'csv']:
|
| 311 |
+
return ['table_qa']
|
| 312 |
+
elif 'youtube.com' in question or 'youtu.be' in question:
|
| 313 |
+
return ['youtube_video_qa']
|
| 314 |
+
elif any(w in question.lower() for w in ['wikipedia', 'who', 'when', 'where', 'what', 'how', 'find', 'search']):
|
| 315 |
+
return ['web_search_duckduckgo', 'llama3_chat']
|
| 316 |
+
else:
|
| 317 |
+
return ['llama3_chat']
|
| 318 |
+
|
| 319 |
+
# --- Improved RAG: Context Retrieval & Chunking ---
|
| 320 |
+
def retrieve_context(question, context_files, max_chunks=3):
|
| 321 |
+
"""Retrieve relevant context chunks from large files for RAG."""
|
| 322 |
+
# Simple keyword search for now; can be replaced with semantic search
|
| 323 |
+
relevant_chunks = []
|
| 324 |
+
for file_path in context_files:
|
| 325 |
+
try:
|
| 326 |
+
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
| 327 |
+
text = f.read()
|
| 328 |
+
# Split into chunks (e.g., 500 words)
|
| 329 |
+
chunks = [text[i:i+2000] for i in range(0, len(text), 2000)]
|
| 330 |
+
for chunk in chunks:
|
| 331 |
+
if any(word.lower() in chunk.lower() for word in question.split()):
|
| 332 |
+
relevant_chunks.append(chunk)
|
| 333 |
+
if len(relevant_chunks) >= max_chunks:
|
| 334 |
+
break
|
| 335 |
+
except Exception as e:
|
| 336 |
+
logger.error(f"retrieve_context error: {e}")
|
| 337 |
+
return '\n'.join(relevant_chunks)
|
| 338 |
+
|
| 339 |
+
# --- Modular Tool Registry & Chaining ---
|
| 340 |
+
class ToolRegistry:
|
| 341 |
+
"""Central registry for tools. Allows easy addition and chaining."""
|
| 342 |
+
def __init__(self, tools):
|
| 343 |
+
self.tools = tools
|
| 344 |
+
def get(self, name):
|
| 345 |
+
return self.tools.get(name)
|
| 346 |
+
def add(self, name, func):
|
| 347 |
+
self.tools[name] = func
|
| 348 |
+
def list(self):
|
| 349 |
+
return list(self.tools.keys())
|
| 350 |
+
|
| 351 |
# --- Refactored ModularGAIAAgent ---
|
| 352 |
class ModularGAIAAgent:
|
| 353 |
+
"""GAIA-compliant agent with robust reasoning, tool chaining, RAG, and output normalization."""
|
| 354 |
+
def __init__(self, api_url=DEFAULT_API_URL, tool_registry=None, context_files=None):
|
| 355 |
self.api_url = api_url
|
| 356 |
+
self.tools = ToolRegistry(tool_registry or TOOL_REGISTRY)
|
| 357 |
self.reasoning_trace = []
|
| 358 |
self.file_cache = set(os.listdir('.'))
|
| 359 |
+
self.context_files = context_files or []
|
| 360 |
|
| 361 |
def fetch_questions(self, from_api=True, questions_path="Hugging Face Questions"):
|
| 362 |
"""Fetch questions from API or local file."""
|
|
|
|
| 426 |
"""Analyze file and return context for the question."""
|
| 427 |
try:
|
| 428 |
if file_type == 'audio':
|
| 429 |
+
transcript = self.tools.get('asr_transcribe')(file_name)
|
| 430 |
self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
|
| 431 |
return transcript
|
| 432 |
elif file_type == 'image':
|
| 433 |
+
caption = self.tools.get('image_caption')(file_name)
|
| 434 |
self.reasoning_trace.append(f"Image caption: {caption}")
|
| 435 |
return caption
|
| 436 |
elif file_type == 'code':
|
| 437 |
+
result = self.tools.get('code_analysis')(file_name)
|
| 438 |
self.reasoning_trace.append(f"Code analysis result: {result}")
|
| 439 |
return result
|
| 440 |
elif file_type == 'excel':
|
|
|
|
| 469 |
self.reasoning_trace.append(f"Analyze file error: {e}")
|
| 470 |
return None
|
| 471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
def answer_question(self, question_obj):
|
|
|
|
| 473 |
self.reasoning_trace = []
|
| 474 |
q = question_obj["question"]
|
| 475 |
file_name = question_obj.get("file_name", "")
|
|
|
|
| 481 |
if local_file:
|
| 482 |
file_type = self.detect_file_type(local_file)
|
| 483 |
file_content = self.analyze_file(local_file, file_type)
|
| 484 |
+
# RAG: retrieve context if needed
|
| 485 |
+
rag_context = ''
|
| 486 |
+
if not file_content and self.context_files:
|
| 487 |
+
rag_context = retrieve_context(q, self.context_files)
|
| 488 |
+
if rag_context:
|
| 489 |
+
self.reasoning_trace.append(f"RAG context used: {rag_context[:200]}...")
|
| 490 |
+
# Reasoning planner: decide tool chain
|
| 491 |
+
tool_names = reasoning_planner(q, file_type, self.tools.list())
|
| 492 |
answer = None
|
| 493 |
+
context = file_content or rag_context
|
| 494 |
for tool_name in tool_names:
|
| 495 |
+
tool = self.tools.get(tool_name)
|
| 496 |
try:
|
| 497 |
logger.info(f"Using tool: {tool_name} | Question: {q} | Context: {str(context)[:200]}")
|
| 498 |
if tool_name == 'web_search_duckduckgo':
|
| 499 |
context = tool(q)
|
| 500 |
answer = llama3_chat(build_prompt(context, q))
|
|
|
|
|
|
|
| 501 |
elif tool_name == 'table_qa' and file_content:
|
| 502 |
answer = tool(q, file_content)
|
| 503 |
elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
|
|
|
|
| 505 |
elif tool_name == 'youtube_video_qa':
|
| 506 |
answer = tool(q, q)
|
| 507 |
else:
|
|
|
|
| 508 |
if context:
|
| 509 |
answer = llama3_chat(build_prompt(context, q))
|
| 510 |
else:
|
|
|
|
| 517 |
continue
|
| 518 |
self.reasoning_trace.append(f"Tools used: {tool_names}")
|
| 519 |
self.reasoning_trace.append(f"Final answer: {answer}")
|
| 520 |
+
return gaia_normalize_answer(answer), self.reasoning_trace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
|
| 522 |
# --- Basic Agent Definition (now wraps ModularGAIAAgent) ---
|
| 523 |
class BasicAgent:
|