Spaces:
Configuration error
Configuration error
feat(tools): add more tools to extend the functionality of jarvis
Browse files- .gitignore +1 -0
- README.md +4 -0
- app.py +326 -344
- project_structure.txt +22 -0
- requirements.txt +5 -2
- result.txt +0 -0
- retriever.py +165 -15
- state.py +83 -6
- test.py +231 -8
- tools/__init__.py +3 -1
- tools/answer_generator.py +129 -0
- tools/calculator.py +28 -8
- tools/document_retriever.py +39 -22
- tools/duckduckgo_search.py +95 -2
- tools/file_fetcher.py +42 -0
- tools/file_parser.py +93 -17
- tools/guest_info.py +40 -13
- tools/hub_stats.py +43 -6
- tools/image_parser.py +34 -16
- tools/search.py +82 -85
- tools/weather_info.py +33 -6
.gitignore
CHANGED
|
@@ -41,6 +41,7 @@ coverage.xml
|
|
| 41 |
*.py,cover
|
| 42 |
.tox/
|
| 43 |
.pytest_cache/
|
|
|
|
| 44 |
|
| 45 |
# Logs and temporary files
|
| 46 |
*.log
|
|
|
|
| 41 |
*.py,cover
|
| 42 |
.tox/
|
| 43 |
.pytest_cache/
|
| 44 |
+
cache/
|
| 45 |
|
| 46 |
# Logs and temporary files
|
| 47 |
*.log
|
README.md
CHANGED
|
@@ -74,6 +74,10 @@ jarvis_gaia_agent/
|
|
| 74 |
- `SERPAPI_API_KEY`: SERPAPI key for web searches.
|
| 75 |
- `OPENWEATHERMAP_API_KEY`: OpenWeatherMap key for weather queries.
|
| 76 |
- `SPACE_ID`: `onisj/jarvis_gaia_agent`.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
## Setup and Local Testing
|
| 79 |
|
|
|
|
| 74 |
- `SERPAPI_API_KEY`: SERPAPI key for web searches.
|
| 75 |
- `OPENWEATHERMAP_API_KEY`: OpenWeatherMap key for weather queries.
|
| 76 |
- `SPACE_ID`: `onisj/jarvis_gaia_agent`.
|
| 77 |
+
- Install dependencies:
|
| 78 |
+
```bash
|
| 79 |
+
pip install -r requirements.txt
|
| 80 |
+
```
|
| 81 |
|
| 82 |
## Setup and Local Testing
|
| 83 |
|
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import json
|
|
| 3 |
import logging
|
| 4 |
import asyncio
|
| 5 |
import aiohttp
|
|
|
|
| 6 |
import nest_asyncio
|
| 7 |
import requests
|
| 8 |
import pandas as pd
|
|
@@ -10,18 +11,25 @@ from typing import Dict, Any, List
|
|
| 10 |
from langchain_core.prompts import ChatPromptTemplate
|
| 11 |
from langchain_core.messages import SystemMessage, HumanMessage
|
| 12 |
from langgraph.graph import StateGraph, END
|
|
|
|
| 13 |
from sentence_transformers import SentenceTransformer
|
| 14 |
import gradio as gr
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from huggingface_hub import InferenceClient
|
| 17 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 18 |
import together
|
| 19 |
-
from state import JARVISState
|
| 20 |
-
from tools import
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Setup logging
|
| 27 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -33,10 +41,10 @@ nest_asyncio.apply()
|
|
| 33 |
# Load environment variables
|
| 34 |
load_dotenv()
|
| 35 |
SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
|
| 36 |
-
GAIA_API_URL = "https://agents-course-unit4-
|
| 37 |
-
GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
|
| 38 |
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
|
| 39 |
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
|
|
|
| 40 |
|
| 41 |
# Verify environment variables
|
| 42 |
if not SPACE_ID:
|
|
@@ -45,6 +53,8 @@ if not HF_API_TOKEN:
|
|
| 45 |
raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
|
| 46 |
if not TOGETHER_API_KEY:
|
| 47 |
raise ValueError("TOGETHER_API_KEY not set")
|
|
|
|
|
|
|
| 48 |
logger.info(f"SPACE_ID: {SPACE_ID}")
|
| 49 |
|
| 50 |
# Model configuration
|
|
@@ -56,23 +66,20 @@ HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
|
|
| 56 |
|
| 57 |
# Initialize LLM clients
|
| 58 |
def initialize_llm():
|
| 59 |
-
# Try Together AI models
|
| 60 |
for model in TOGETHER_MODELS:
|
| 61 |
try:
|
| 62 |
together.api_key = TOGETHER_API_KEY
|
| 63 |
client = together.Together()
|
| 64 |
-
# Test the model
|
| 65 |
response = client.chat.completions.create(
|
| 66 |
model=model,
|
| 67 |
messages=[{"role": "user", "content": "Test"}],
|
| 68 |
max_tokens=10
|
| 69 |
)
|
| 70 |
logger.info(f"Initialized Together AI model: {model}")
|
| 71 |
-
return client, "together"
|
| 72 |
except Exception as e:
|
| 73 |
logger.warning(f"Failed to initialize Together AI model {model}: {e}")
|
| 74 |
|
| 75 |
-
# Fallback to Hugging Face Inference API
|
| 76 |
try:
|
| 77 |
client = InferenceClient(
|
| 78 |
model=HF_MODEL,
|
|
@@ -80,381 +87,355 @@ def initialize_llm():
|
|
| 80 |
timeout=30
|
| 81 |
)
|
| 82 |
logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
|
| 83 |
-
return client, "hf_api"
|
| 84 |
except Exception as e:
|
| 85 |
logger.warning(f"Failed to initialize HF Inference API: {e}")
|
| 86 |
|
| 87 |
-
# Fallback to local Hugging Face model
|
| 88 |
try:
|
| 89 |
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
|
| 90 |
model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
|
| 91 |
logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
|
| 92 |
-
return (model, tokenizer), "hf_local"
|
| 93 |
except Exception as e:
|
| 94 |
logger.error(f"Failed to initialize local HF model: {e}")
|
| 95 |
raise Exception("No LLM could be initialized")
|
| 96 |
|
| 97 |
-
llm_client, llm_type = initialize_llm()
|
| 98 |
|
| 99 |
# Initialize embedder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
try:
|
| 101 |
-
embedder =
|
| 102 |
-
logger.info("Sentence transformer initialized")
|
| 103 |
except Exception as e:
|
| 104 |
logger.error(f"Failed to initialize embedder: {e}")
|
| 105 |
embedder = None
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
# Parse question to select tools
|
| 129 |
async def parse_question(state: JARVISState) -> JARVISState:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
try:
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
tools_needed = ["search_tool"]
|
| 134 |
-
|
|
|
|
| 135 |
if llm_client:
|
| 136 |
prompt = ChatPromptTemplate.from_messages([
|
| 137 |
SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
|
| 138 |
-
|
|
|
|
|
|
|
| 139 |
Rules:
|
| 140 |
-
-
|
| 141 |
-
-
|
| 142 |
-
-
|
| 143 |
-
-
|
| 144 |
-
-
|
| 145 |
-
-
|
| 146 |
-
-
|
| 147 |
-
-
|
| 148 |
-
-
|
| 149 |
-
-
|
|
|
|
| 150 |
- Output ONLY valid JSON."""),
|
| 151 |
HumanMessage(content=f"Query: {question}")
|
| 152 |
])
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
if any(
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
if
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
tools_needed.append("file_parser_tool")
|
| 223 |
-
|
| 224 |
tools_needed.append("image_parser_tool")
|
| 225 |
-
|
| 226 |
tools_needed.append("document_retriever_tool")
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
state["tools_needed"] = list(set(tools_needed))
|
| 231 |
-
logger.info(f"Task {task_id}: Selected tools: {tools_needed}")
|
| 232 |
return state
|
| 233 |
except Exception as e:
|
| 234 |
-
logger.error(f"
|
| 235 |
state["error"] = f"Parse question failed: {str(e)}"
|
| 236 |
state["tools_needed"] = ["search_tool"]
|
| 237 |
return state
|
| 238 |
|
| 239 |
# Tool dispatcher
|
| 240 |
async def tool_dispatcher(state: JARVISState) -> JARVISState:
|
|
|
|
| 241 |
try:
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
file_type = "xlsx"
|
| 248 |
-
|
| 249 |
-
for tool in updated_state["tools_needed"]:
|
| 250 |
-
try:
|
| 251 |
-
if tool == "search_tool":
|
| 252 |
-
result = search_tool(updated_state["question"])
|
| 253 |
-
updated_state["web_results"].extend([str(r) for r in result])
|
| 254 |
-
elif tool == "multi_hop_search_tool":
|
| 255 |
-
result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3, "llm_client": llm_client, "llm_type": llm_type})
|
| 256 |
-
updated_state["multi_hop_results"].extend([r["content"] for r in result])
|
| 257 |
-
await asyncio.sleep(2)
|
| 258 |
-
elif tool == "file_parser_tool":
|
| 259 |
-
for ext in ["txt", "csv", "xlsx"]:
|
| 260 |
-
file_path = await download_file(updated_state["task_id"], ext)
|
| 261 |
-
if file_path:
|
| 262 |
-
result = file_parser_tool(file_path)
|
| 263 |
-
updated_state["file_results"] = str(result)
|
| 264 |
-
break
|
| 265 |
-
elif tool == "image_parser_tool":
|
| 266 |
-
file_path = await download_file(updated_state["task_id"], "jpg")
|
| 267 |
-
if file_path:
|
| 268 |
-
result = image_parser_tool(file_path)
|
| 269 |
-
updated_state["image_results"] = str(result)
|
| 270 |
-
elif tool == "calculator_tool":
|
| 271 |
-
result = calculator_tool(updated_state["question"])
|
| 272 |
-
updated_state["calculation_results"] = str(result)
|
| 273 |
-
elif tool == "document_retriever_tool":
|
| 274 |
-
file_path = await download_file(updated_state["task_id"], "pdf")
|
| 275 |
-
if file_path:
|
| 276 |
-
result = document_retriever_tool({"task_id": updated_state["task_id"], "query": updated_state["question"], "file_type": "pdf"})
|
| 277 |
-
updated_state["document_results"] = str(result)
|
| 278 |
-
elif tool == "duckduckgo_search_tool":
|
| 279 |
-
result = duckduckgo_search_tool(updated_state["question"])
|
| 280 |
-
updated_state["web_results"].append(str(result))
|
| 281 |
-
elif tool == "weather_info_tool":
|
| 282 |
-
location = updated_state["question"].split("weather in ")[1].split()[0] if "weather in" in updated_state["question"].lower() else "Unknown"
|
| 283 |
-
result = weather_info_tool({"location": location})
|
| 284 |
-
updated_state["web_results"].append(str(result))
|
| 285 |
-
elif tool == "hub_stats_tool":
|
| 286 |
-
author = updated_state["question"].split("by ")[1].split()[0] if "by" in updated_state["question"].lower() else "Unknown"
|
| 287 |
-
result = hub_stats_tool({"author": author})
|
| 288 |
-
updated_state["web_results"].append(str(result))
|
| 289 |
-
elif tool == "guest_info_retriever_tool":
|
| 290 |
-
query = updated_state["question"].split("about ")[1] if "about" in updated_state["question"].lower() else updated_state["question"]
|
| 291 |
-
result = guest_info_retriever_tool({"query": query})
|
| 292 |
-
updated_state["web_results"].append(str(result))
|
| 293 |
-
updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_executed": True}
|
| 294 |
-
except Exception as e:
|
| 295 |
-
logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
|
| 296 |
-
updated_state["error"] = f"Tool {tool} failed: {str(e)}"
|
| 297 |
-
updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_error": str(e)}
|
| 298 |
-
|
| 299 |
-
logger.info(f"Task {updated_state['task_id']}: Tool results: {updated_state}")
|
| 300 |
-
return updated_state
|
| 301 |
-
except Exception as e:
|
| 302 |
-
logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
|
| 303 |
-
updated_state["error"] = f"Tool dispatch failed: {str(e)}"
|
| 304 |
-
return updated_state
|
| 305 |
-
|
| 306 |
-
# Reasoning
|
| 307 |
-
async def reasoning(state: JARVISState) -> Dict[str, Any]:
|
| 308 |
-
try:
|
| 309 |
-
prompt = ChatPromptTemplate.from_messages([
|
| 310 |
-
SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
|
| 311 |
-
HumanMessage(content="""Task: {task_id}
|
| 312 |
-
Question: {question}
|
| 313 |
-
Web results: {web_results}
|
| 314 |
-
Multi-hop results: {multi_hop_results}
|
| 315 |
-
File results: {file_results}
|
| 316 |
-
Image results: {image_results}
|
| 317 |
-
Calculation results: {calculation_results}
|
| 318 |
-
Document results: {document_results}""")
|
| 319 |
-
])
|
| 320 |
-
messages = [
|
| 321 |
-
{"role": "system", "content": prompt[0].content},
|
| 322 |
-
{"role": "user", "content": prompt[1].content.format(
|
| 323 |
-
task_id=state["task_id"],
|
| 324 |
-
question=state["question"],
|
| 325 |
-
web_results="\n".join(state["web_results"]),
|
| 326 |
-
multi_hop_results="\n".join(state["multi_hop_results"]),
|
| 327 |
-
file_results=state["file_results"],
|
| 328 |
-
image_results=state["image_results"],
|
| 329 |
-
calculation_results=state["calculation_results"],
|
| 330 |
-
document_results=state["document_results"]
|
| 331 |
-
)}
|
| 332 |
-
]
|
| 333 |
-
for attempt in range(3):
|
| 334 |
try:
|
| 335 |
-
if
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
)
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
-
|
| 369 |
-
|
| 370 |
except Exception as e:
|
| 371 |
-
logger.warning(f"
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
except Exception as e:
|
| 376 |
-
logger.error(f"
|
| 377 |
-
state["error"] = f"
|
| 378 |
-
return
|
| 379 |
-
|
| 380 |
-
# Router
|
| 381 |
-
def router(state: JARVISState) -> str:
|
| 382 |
-
if state["tools_needed"]:
|
| 383 |
-
return "tool_dispatcher"
|
| 384 |
-
return "reasoning"
|
| 385 |
|
| 386 |
# Define StateGraph
|
| 387 |
workflow = StateGraph(JARVISState)
|
| 388 |
-
workflow.add_node("
|
| 389 |
workflow.add_node("tool_dispatcher", tool_dispatcher)
|
| 390 |
-
workflow.
|
| 391 |
-
workflow.
|
| 392 |
-
workflow.
|
| 393 |
-
"parse",
|
| 394 |
-
router,
|
| 395 |
-
{
|
| 396 |
-
"tool_dispatcher": "tool_dispatcher",
|
| 397 |
-
"reasoning": "reasoning"
|
| 398 |
-
}
|
| 399 |
-
)
|
| 400 |
-
workflow.add_edge("tool_dispatcher", "reasoning")
|
| 401 |
-
workflow.add_edge("reasoning", END)
|
| 402 |
graph = workflow.compile()
|
| 403 |
|
| 404 |
# Agent class
|
| 405 |
class JARVISAgent:
|
| 406 |
def __init__(self):
|
| 407 |
-
self.state =
|
| 408 |
-
|
| 409 |
-
question="",
|
| 410 |
-
tools_needed=[],
|
| 411 |
-
web_results=[],
|
| 412 |
-
file_results="",
|
| 413 |
-
image_results="",
|
| 414 |
-
calculation_results="",
|
| 415 |
-
document_results="",
|
| 416 |
-
multi_hop_results=[],
|
| 417 |
-
messages=[],
|
| 418 |
-
answer="",
|
| 419 |
-
results_table=[],
|
| 420 |
-
status_output="",
|
| 421 |
-
error=None,
|
| 422 |
-
metadata={}
|
| 423 |
-
)
|
| 424 |
logger.info("JARVISAgent initialized.")
|
| 425 |
|
| 426 |
async def process_question(self, task_id: str, question: str) -> str:
|
| 427 |
-
state =
|
| 428 |
-
task_id=task_id,
|
| 429 |
-
question=question,
|
| 430 |
-
tools_needed=["search_tool"],
|
| 431 |
-
web_results=[],
|
| 432 |
-
file_results="",
|
| 433 |
-
image_results="",
|
| 434 |
-
calculation_results="",
|
| 435 |
-
document_results="",
|
| 436 |
-
multi_hop_results=[],
|
| 437 |
-
messages=[HumanMessage(content=question)],
|
| 438 |
-
answer="",
|
| 439 |
-
results_table=[],
|
| 440 |
-
status_output="",
|
| 441 |
-
error=None,
|
| 442 |
-
metadata={}
|
| 443 |
-
)
|
| 444 |
try:
|
| 445 |
result = await graph.ainvoke(state)
|
| 446 |
-
answer = result
|
| 447 |
-
logger.info(f"Task {task_id}
|
| 448 |
-
self.state
|
| 449 |
-
self.state
|
| 450 |
return answer
|
| 451 |
except Exception as e:
|
| 452 |
logger.error(f"Error processing task {task_id}: {e}")
|
| 453 |
-
self.state
|
| 454 |
-
self.state
|
| 455 |
return f"Error: {str(e)}"
|
| 456 |
finally:
|
| 457 |
-
for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
|
| 458 |
file_path = f"temp/{task_id}.{ext}"
|
| 459 |
if os.path.exists(file_path):
|
| 460 |
try:
|
|
@@ -466,25 +447,26 @@ class JARVISAgent:
|
|
| 466 |
async def process_all_questions(self, profile: gr.OAuthProfile | None):
|
| 467 |
if not profile:
|
| 468 |
logger.error("User not logged in.")
|
| 469 |
-
self.state
|
| 470 |
-
return pd.DataFrame(self.state
|
| 471 |
|
| 472 |
-
username =
|
| 473 |
logger.info(f"User logged in: {username}")
|
| 474 |
questions_url = f"{GAIA_API_URL}/questions"
|
| 475 |
submit_url = f"{GAIA_API_URL}/submit"
|
| 476 |
agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
|
| 477 |
|
| 478 |
try:
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
|
|
|
| 482 |
logger.info(f"Fetched {len(questions)} questions.")
|
| 483 |
except Exception as e:
|
| 484 |
logger.error(f"Error fetching questions: {e}")
|
| 485 |
-
self.state
|
| 486 |
-
self.state
|
| 487 |
-
return pd.DataFrame(self.state
|
| 488 |
|
| 489 |
answers_payload = []
|
| 490 |
for item in questions:
|
|
@@ -498,33 +480,34 @@ class JARVISAgent:
|
|
| 498 |
|
| 499 |
if not answers_payload:
|
| 500 |
logger.error("No answers generated.")
|
| 501 |
-
self.state
|
| 502 |
-
self.state
|
| 503 |
-
return pd.DataFrame(self.state
|
| 504 |
|
| 505 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
| 506 |
try:
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
|
|
|
| 511 |
f"Submission Successful!\n"
|
| 512 |
f"User: {result_data.get('username')}\n"
|
| 513 |
f"Overall Score: {result_data.get('score', 'N/A')}% "
|
| 514 |
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
| 515 |
f"Message: {result_data.get('message', 'No message received.')}"
|
| 516 |
)
|
| 517 |
-
self.state
|
| 518 |
except Exception as e:
|
| 519 |
logger.error(f"Submission failed: {e}")
|
| 520 |
-
self.state
|
| 521 |
-
self.state
|
| 522 |
|
| 523 |
-
return pd.DataFrame(self.state.results_table), self.state
|
| 524 |
|
| 525 |
# Gradio interface
|
| 526 |
with gr.Blocks() as demo:
|
| 527 |
-
gr.Markdown("#
|
| 528 |
gr.Markdown(
|
| 529 |
"""
|
| 530 |
**Instructions:**
|
|
@@ -539,7 +522,6 @@ with gr.Blocks() as demo:
|
|
| 539 |
)
|
| 540 |
with gr.Row():
|
| 541 |
gr.LoginButton(value="Login to Hugging Face")
|
| 542 |
-
# Removed gr.LogoutButton due to deprecation
|
| 543 |
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
| 544 |
status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
|
| 545 |
results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
|
|
|
|
| 3 |
import logging
|
| 4 |
import asyncio
|
| 5 |
import aiohttp
|
| 6 |
+
import ssl
|
| 7 |
import nest_asyncio
|
| 8 |
import requests
|
| 9 |
import pandas as pd
|
|
|
|
| 11 |
from langchain_core.prompts import ChatPromptTemplate
|
| 12 |
from langchain_core.messages import SystemMessage, HumanMessage
|
| 13 |
from langgraph.graph import StateGraph, END
|
| 14 |
+
import torch
|
| 15 |
from sentence_transformers import SentenceTransformer
|
| 16 |
import gradio as gr
|
| 17 |
from dotenv import load_dotenv
|
| 18 |
from huggingface_hub import InferenceClient
|
| 19 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 20 |
import together
|
| 21 |
+
from state import JARVISState, validate_state, reset_state
|
| 22 |
+
from tools.answer_generator import generate_answer, preprocess_question
|
| 23 |
+
from tools.file_fetcher import fetch_task_file
|
| 24 |
+
from tools.search import search_tool, multi_hop_search_tool
|
| 25 |
+
from tools.file_parser import file_parser_tool
|
| 26 |
+
from tools.image_parser import image_parser_tool
|
| 27 |
+
from tools.calculator import calculator_tool
|
| 28 |
+
from tools.document_retriever import document_retriever_tool
|
| 29 |
+
from tools.duckduckgo_search import duckduckgo_search_tool
|
| 30 |
+
from tools.weather_info import weather_info_tool
|
| 31 |
+
from tools.hub_stats import hub_stats_tool
|
| 32 |
+
from tools.guest_info import guest_info_retriever_tool
|
| 33 |
|
| 34 |
# Setup logging
|
| 35 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 41 |
# Load environment variables
|
| 42 |
load_dotenv()
|
| 43 |
SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
|
| 44 |
+
GAIA_API_URL = "https://agents-course-unit4-api-1.hf.space/api"
|
|
|
|
| 45 |
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
|
| 46 |
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 47 |
+
OPENWEATHERMAP_API_KEY = os.getenv("OPENWEATHERMAP_API_KEY")
|
| 48 |
|
| 49 |
# Verify environment variables
|
| 50 |
if not SPACE_ID:
|
|
|
|
| 53 |
raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
|
| 54 |
if not TOGETHER_API_KEY:
|
| 55 |
raise ValueError("TOGETHER_API_KEY not set")
|
| 56 |
+
if not OPENWEATHERMAP_API_KEY:
|
| 57 |
+
logger.warning("OPENWEATHERMAP_API_KEY not set; weather_info_tool may fail")
|
| 58 |
logger.info(f"SPACE_ID: {SPACE_ID}")
|
| 59 |
|
| 60 |
# Model configuration
|
|
|
|
| 66 |
|
| 67 |
# Initialize LLM clients
|
| 68 |
def initialize_llm():
|
|
|
|
| 69 |
for model in TOGETHER_MODELS:
|
| 70 |
try:
|
| 71 |
together.api_key = TOGETHER_API_KEY
|
| 72 |
client = together.Together()
|
|
|
|
| 73 |
response = client.chat.completions.create(
|
| 74 |
model=model,
|
| 75 |
messages=[{"role": "user", "content": "Test"}],
|
| 76 |
max_tokens=10
|
| 77 |
)
|
| 78 |
logger.info(f"Initialized Together AI model: {model}")
|
| 79 |
+
return client, "together", model
|
| 80 |
except Exception as e:
|
| 81 |
logger.warning(f"Failed to initialize Together AI model {model}: {e}")
|
| 82 |
|
|
|
|
| 83 |
try:
|
| 84 |
client = InferenceClient(
|
| 85 |
model=HF_MODEL,
|
|
|
|
| 87 |
timeout=30
|
| 88 |
)
|
| 89 |
logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
|
| 90 |
+
return client, "hf_api", HF_MODEL
|
| 91 |
except Exception as e:
|
| 92 |
logger.warning(f"Failed to initialize HF Inference API: {e}")
|
| 93 |
|
|
|
|
| 94 |
try:
|
| 95 |
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
|
| 96 |
model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
|
| 97 |
logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
|
| 98 |
+
return (model, tokenizer), "hf_local", HF_MODEL
|
| 99 |
except Exception as e:
|
| 100 |
logger.error(f"Failed to initialize local HF model: {e}")
|
| 101 |
raise Exception("No LLM could be initialized")
|
| 102 |
|
| 103 |
+
llm_client, llm_type, llm_model = initialize_llm()
|
| 104 |
|
| 105 |
# Initialize embedder
|
| 106 |
+
_embedder = None
|
| 107 |
+
|
| 108 |
+
def get_embedder():
    """Return the shared SentenceTransformer, creating it on first use.

    The instance is cached in the module-level ``_embedder`` so the model is
    only constructed once per process. CUDA is used when
    ``torch.cuda.is_available()`` reports it, otherwise the CPU.

    Returns:
        The cached ``SentenceTransformer`` instance.

    Raises:
        RuntimeError: If the model cannot be constructed. The underlying
            exception is attached as ``__cause__``.
    """
    global _embedder

    # Fast path: a previous call already built the model.
    if _embedder is not None:
        return _embedder

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _embedder = SentenceTransformer(
            "all-MiniLM-L6-v2",
            device=device,
            cache_folder="./cache",
        )
        logger.info(f"SentenceTransformer initialized on {device.upper()}")
    except Exception as e:
        logger.error(f"Failed to initialize SentenceTransformer: {e}")
        # Chain explicitly so the traceback shows the real root cause.
        raise RuntimeError(f"Embedder initialization failed: {e}") from e
    return _embedder
|
| 123 |
+
|
| 124 |
try:
|
| 125 |
+
embedder = get_embedder()
|
|
|
|
| 126 |
except Exception as e:
|
| 127 |
logger.error(f"Failed to initialize embedder: {e}")
|
| 128 |
embedder = None
|
| 129 |
|
| 130 |
+
# Log device
|
| 131 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 132 |
+
logger.info(f"Using device: {device}")
|
| 133 |
+
|
| 134 |
+
# HTTP session with SSL handling
|
| 135 |
+
async def create_http_session(*, verify_ssl: bool = False, timeout_seconds: float = 30):
    """Create an aiohttp client session with a configurable TLS policy.

    Args:
        verify_ssl: When False (the default, matching the previous behavior)
            hostname checking and certificate verification are both disabled.
            When True the default verifying context from
            ``ssl.create_default_context()`` is used unchanged.
        timeout_seconds: Total per-request timeout passed to
            ``aiohttp.ClientTimeout``.

    Returns:
        A new ``aiohttp.ClientSession``. The caller is responsible for
        closing it.
    """
    ssl_context = ssl.create_default_context()
    if not verify_ssl:
        # SECURITY: disabling verification exposes requests to
        # man-in-the-middle attacks. Kept as the default only for backward
        # compatibility; pass verify_ssl=True wherever certificates validate.
        # check_hostname must be cleared before verify_mode can be CERT_NONE.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl_context),
        timeout=aiohttp.ClientTimeout(total=timeout_seconds),
    )
|
| 143 |
+
|
| 144 |
+
# Tool registration
|
| 145 |
+
tools = {
|
| 146 |
+
"search_tool": search_tool,
|
| 147 |
+
"multi_hop_search_tool": multi_hop_search_tool,
|
| 148 |
+
"file_parser_tool": file_parser_tool,
|
| 149 |
+
"image_parser_tool": image_parser_tool,
|
| 150 |
+
"calculator_tool": calculator_tool,
|
| 151 |
+
"document_retriever_tool": document_retriever_tool,
|
| 152 |
+
"duckduckgo_search_tool": duckduckgo_search_tool,
|
| 153 |
+
"weather_info_tool": weather_info_tool,
|
| 154 |
+
"hub_stats_tool": hub_stats_tool,
|
| 155 |
+
"guest_info_retriever_tool": guest_info_retriever_tool,
|
| 156 |
+
}
|
| 157 |
|
| 158 |
# Parse question to select tools
|
| 159 |
# System prompt used to ask the LLM which tools a question needs.
_TOOL_SELECTION_PROMPT = """Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].

Return a JSON list of all relevant tools, e.g., ["search_tool", "duckduckgo_search_tool"].

Rules:
- Include "search_tool" for web-based questions unless purely computational or file-based.
- Include "multi_hop_search_tool" for questions with >20 words or requiring multiple steps.
- Include "file_parser_tool" for 'data', 'table', 'excel', 'csv', 'txt', 'mp3', or file extensions.
- Include "image_parser_tool" for 'image', 'video', 'picture', or 'painting'.
- Include "calculator_tool" for 'calculate', 'math', 'sum', 'average', 'total', or numerical operations.
- Include "document_retriever_tool" for 'document', 'pdf', 'report', or 'paper'.
- Include "duckduckgo_search_tool" for 'search', 'wikipedia', 'online', or general knowledge.
- Include "weather_info_tool" for 'weather', 'temperature', or 'forecast'.
- Include "hub_stats_tool" for 'model', 'huggingface', or 'dataset'.
- Include "guest_info_retriever_tool" for 'guest', 'name', 'relation', or 'person'.
- Select multiple tools if the question spans multiple domains (e.g., web and file).
- Output ONLY valid JSON."""

# Keyword fallback table: (tool name, substrings that trigger it).
_KEYWORD_TOOL_RULES = (
    ("file_parser_tool", ("excel", "csv", "mp3", "data", "table", "xlsx")),
    ("image_parser_tool", ("image", "video", "picture", "painting")),
    ("calculator_tool", ("calculate", "math", "sum", "average", "total")),
    ("document_retriever_tool", ("document", "pdf", "report", "paper")),
    ("duckduckgo_search_tool", ("search", "wikipedia", "online")),
    ("weather_info_tool", ("weather", "temperature", "forecast")),
    ("hub_stats_tool", ("model", "huggingface", "dataset")),
    ("guest_info_retriever_tool", ("guest", "name", "relation", "person")),
)


def _select_tools_by_keyword(question: str) -> list:
    """Return the tools whose trigger keywords occur in *question*."""
    question_lower = question.lower()
    selected = [
        tool_name
        for tool_name, keywords in _KEYWORD_TOOL_RULES
        if any(kw in question_lower for kw in keywords)
    ]
    if len(question.split()) > 20 or "multiple" in question_lower:
        selected.append("multi_hop_search_tool")
    return selected


def _request_tool_selection(chat_messages: list) -> str:
    """Send *chat_messages* to the configured LLM backend and return its text reply."""
    if llm_type == "hf_local":
        model, tokenizer = llm_client
        inputs = tokenizer.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        outputs = model.generate(inputs, max_new_tokens=100, temperature=0.5)
        # Decode only the newly generated tokens; the full sequence starts
        # with the prompt, which would never parse as a JSON tool list.
        generated = outputs[0][inputs.shape[-1]:]
        return tokenizer.decode(generated, skip_special_tokens=True).strip()

    if llm_type == "together":
        completion = llm_client.chat.completions.create(
            model=llm_model,
            messages=chat_messages,
            max_tokens=100,
            temperature=0.5,
        )
        return completion.choices[0].message.content.strip()

    # hf_api: the client was constructed with its model already bound.
    completion = llm_client.chat.completions.create(
        messages=chat_messages,
        max_tokens=100,
        temperature=0.5,
    )
    return completion.choices[0].message.content.strip()


async def parse_question(state: JARVISState) -> JARVISState:
    """Select the tools needed to answer the question held in *state*.

    Steps:
      1. Preprocess the question and store the rewritten text in the state.
      2. Ask the LLM (up to three attempts) for a JSON list of tool names and
         accept it only if every entry is a key of the ``tools`` registry.
      3. If the selection is still just ``["search_tool"]``, add tools whose
         trigger keywords appear in the question.
      4. Fetch any files attached to the task, write them under ``temp/`` and
         add the matching file/image/document tool.

    Args:
        state (JARVISState): Input state containing ``task_id`` and ``question``.

    Returns:
        JARVISState: The state with ``tools_needed`` (de-duplicated, order
        preserved) and, when a file was fetched, ``metadata["file_ext"]`` and
        ``metadata["file_path"]``. On failure ``error`` is set and
        ``tools_needed`` falls back to ``["search_tool"]``.
    """
    state = validate_state(state)
    task_id = state["task_id"]
    question = state["question"]

    logger.info(f"Task {task_id} Parsing question: {question}")
    try:
        # Preprocess question
        processed_question = await preprocess_question(question)
        if processed_question != question:
            logger.info(f"Task {task_id} Preprocessed question: {processed_question}")
            state["question"] = processed_question
            question = processed_question

        # Default to search_tool; only a validated LLM answer replaces it.
        tools_needed = ["search_tool"]

        # LLM-based tool selection
        if llm_client:
            # Built once: the messages do not change between retries.
            chat_messages = [
                {"role": "system", "content": _TOOL_SELECTION_PROMPT},
                {"role": "user", "content": f"Query: {question}"},
            ]
            for attempt in range(3):  # Retry up to 3 times
                try:
                    response = _request_tool_selection(chat_messages)
                    logger.info(f"Task {task_id} LLM tool selection response: {response}")
                    candidate = json.loads(response)
                    is_valid = isinstance(candidate, list) and all(
                        isinstance(t, str) and t in tools for t in candidate
                    )
                    if not is_valid:
                        raise ValueError("Invalid tool list format")
                    tools_needed = candidate
                    break  # Valid response, exit retry loop
                except json.JSONDecodeError as e:
                    logger.warning(f"Task {task_id}: Invalid JSON (attempt {attempt + 1}): {e}")
                except Exception as e:
                    logger.warning(f"Task {task_id} Tool selection failed (attempt {attempt + 1}): {e}")

        # Keyword-based enrichment when the LLM gave nothing beyond the default.
        # No negative keyword guard here: such a guard would skip this branch
        # for exactly the questions its keyword rules are meant to catch.
        if tools_needed == ["search_tool"]:
            tools_needed.extend(_select_tools_by_keyword(question))

        # Integrate file-based tools
        file_results = await fetch_task_file(task_id, question)
        for ext, content in file_results.items():
            if not content:
                continue
            os.makedirs("temp", exist_ok=True)
            file_path = f"temp/{task_id}.{ext}"
            # Opened in binary mode, so content is presumably bytes — confirm
            # against fetch_task_file.
            with open(file_path, "wb") as f:
                f.write(content)
            state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
            if ext in ["txt", "csv", "xlsx", "mp3"] and "file_parser_tool" not in tools_needed:
                tools_needed.append("file_parser_tool")
            elif ext in ["jpg", "png"] and "image_parser_tool" not in tools_needed:
                tools_needed.append("image_parser_tool")
            elif ext == "pdf" and "document_retriever_tool" not in tools_needed:
                tools_needed.append("document_retriever_tool")

        # Remove duplicates while keeping a deterministic order.
        state["tools_needed"] = list(dict.fromkeys(tools_needed))
        logger.info(f"Task {task_id} Selected tools: {state['tools_needed']}")
        return state
    except Exception as e:
        logger.error(f"Task {task_id} Tool selection failed: {e}")
        state["error"] = f"Parse question failed: {str(e)}"
        state["tools_needed"] = ["search_tool"]
        return state
|
| 301 |
|
| 302 |
# Tool dispatcher
|
| 303 |
def _extract_weather_location(question: str) -> str:
    """Return the last word of a weather question, stripped of punctuation.

    Falls back to "Unknown" when the question does not mention "weather" or
    when no usable word is left after stripping.
    """
    if "weather" not in question.lower():
        return "Unknown"
    words = question.split()
    # "What is the weather in London?" must yield "London", not "London?".
    location = words[-1].strip("?.!,;:\"'") if words else ""
    return location or "Unknown"


def _extract_hub_author(question: str) -> str:
    """Return the word following a standalone "by" in the question, else "Unknown"."""
    import re  # local import keeps this block independent of the file's import list

    # Word boundary + IGNORECASE: matches "by"/"By" as a word, never "nearby",
    # and cannot raise when "by" is the final word of the question.
    match = re.search(r"\bby\s+(\S+)", question, flags=re.IGNORECASE)
    if match is None:
        return "Unknown"
    return match.group(1).strip("?.!,;:\"'") or "Unknown"


async def tool_dispatcher(state: JARVISState) -> JARVISState:
    """
    Run every tool listed in ``state["tools_needed"]`` and generate the answer.

    Each tool runs inside its own try/except: a failing tool is logged, recorded
    in ``state["metadata"]`` as ``"<tool>_error"``, and the loop continues with
    the next tool. After the loop, the collected results are passed to
    ``generate_answer`` and the outcome is stored in ``state["answer"]``.

    Args:
        state: Workflow state; it is validated with ``validate_state`` first.

    Returns:
        The updated state. If anything outside the per-tool handling fails,
        ``state["error"]`` is set and the state is returned without an answer.

    Raises:
        Whatever ``validate_state`` raises, since that call happens before the
        outer try block.
    """
    state = validate_state(state)
    try:
        task_id = state["task_id"]
        question = state["question"]
        tools_needed = state["tools_needed"]

        for tool_name in tools_needed:
            try:
                if tool_name == "search_tool":
                    result = await tools["search_tool"].ainvoke({"query": question})
                    state["web_results"].extend(
                        [str(r) for r in result] if result else ["No results from search_tool"]
                    )
                elif tool_name == "multi_hop_search_tool":
                    result = await tools["multi_hop_search_tool"].ainvoke({
                        "query": question,
                        "steps": 3,
                        "llm_client": llm_client,
                        "llm_type": llm_type,
                        "llm_model": llm_model
                    })
                    state["multi_hop_results"].extend(
                        [r["content"] if isinstance(r, dict) else str(r) for r in result]
                        if result else ["No results from multi_hop_search_tool"]
                    )
                elif tool_name == "file_parser_tool":
                    file_path = state["metadata"].get("file_path")
                    file_ext = state["metadata"].get("file_ext")
                    if file_path and os.path.exists(file_path) and file_ext:
                        result = await tools["file_parser_tool"].ainvoke({
                            "task_id": task_id,
                            "file_type": file_ext,
                            "file_path": file_path,
                            "query": question
                        })
                        state["file_results"] = str(result) if result else "No file results"
                    else:
                        state["file_results"] = "No file available"
                elif tool_name == "image_parser_tool":
                    file_path = state["metadata"].get("file_path")
                    if file_path and os.path.exists(file_path) and file_path.split('.')[-1] in ["jpg", "png"]:
                        result = await tools["image_parser_tool"].ainvoke({"task_id": task_id, "file_path": file_path})
                        state["image_results"] = str(result) if result else "No image results"
                    else:
                        state["image_results"] = "No image available"
                elif tool_name == "calculator_tool":
                    result = await tools["calculator_tool"].ainvoke({"expression": question})
                    state["calculation_results"] = str(result) if result else "No calculation results"
                elif tool_name == "document_retriever_tool":
                    file_path = state["metadata"].get("file_path")
                    if file_path and os.path.exists(file_path) and file_path.split('.')[-1] == "pdf":
                        result = await tools["document_retriever_tool"].ainvoke({
                            "task_id": task_id,
                            "query": question,
                            "file_path": file_path
                        })
                        state["document_results"] = str(result) if result else "No document results"
                    else:
                        state["document_results"] = "No document available"
                elif tool_name == "duckduckgo_search_tool":
                    result = await tools["duckduckgo_search_tool"].ainvoke({
                        "query": question,
                        "original_query": question,
                        "embedder": embedder
                    })
                    # A list is extended as-is (an empty list adds nothing); any
                    # other truthy value is wrapped; a falsy one yields the placeholder.
                    state["web_results"].extend(
                        result if isinstance(result, list)
                        else ([str(result)] if result else ["No results from duckduckgo_search_tool"])
                    )
                elif tool_name == "weather_info_tool":
                    location = _extract_weather_location(question)
                    result = await tools["weather_info_tool"].ainvoke({"location": location})
                    state["web_results"].append(str(result) if result else "No weather results")
                elif tool_name == "hub_stats_tool":
                    author = _extract_hub_author(question)
                    result = await tools["hub_stats_tool"].ainvoke({"author": author})
                    state["web_results"].append(str(result) if result else "No hub stats results")
                elif tool_name == "guest_info_retriever_tool":
                    result = await tools["guest_info_retriever_tool"].ainvoke({"query": question})
                    state["web_results"].append(str(result) if result else "No guest info results")

                # Reached for every tool name that did not raise, including names
                # that match no branch above.
                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_executed": True}
                logger.info(f"Task {task_id}: Executed {tool_name}")
            except Exception as e:
                # Best effort: one failing tool must not stop the remaining ones.
                logger.warning(f"Tool {tool_name} failed for task {task_id}: {e}")
                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_error": str(e)}

        # Ensure results are populated.
        # These placeholders only apply when a key is missing altogether; an
        # existing (even empty) value is kept.
        state.setdefault("web_results", ["No web results found"])
        state.setdefault("file_results", "No file results found")
        state.setdefault("image_results", "No image results found")
        state.setdefault("document_results", "No document results found")
        state.setdefault("calculation_results", "No calculation results found")

        state["answer"] = await generate_answer(
            task_id=task_id,
            question=question,
            search_results=state.get("web_results", []) + [
                r["content"] if isinstance(r, dict) else str(r) for r in state.get("multi_hop_results", [])
            ],
            file_results=state.get("file_results", "") + state.get("document_results", "") + state.get("image_results", "") + state.get("calculation_results", ""),
            llm_client=llm_client
        )

        logger.info(f"Task {task_id}: Generated answer: {state['answer']}")
        return state
    except Exception as e:
        logger.error(f"Tool dispatch failed: {e}")
        state["error"] = f"Tool dispatch failed: {e}"
        return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
# Define StateGraph
# Linear two-node pipeline: parse_question -> tool_dispatcher -> END.
workflow = StateGraph(JARVISState)
workflow.add_node("parse_question", parse_question)
workflow.add_node("tool_dispatcher", tool_dispatcher)
# Every run starts at parse_question, which fills state["tools_needed"].
workflow.set_entry_point("parse_question")
workflow.add_edge("parse_question", "tool_dispatcher")
# tool_dispatcher is the last node; the run ends once it returns.
workflow.add_edge("tool_dispatcher", END)
# Compiled graph invoked by JARVISAgent.process_question via graph.ainvoke.
graph = workflow.compile()
|
| 415 |
|
| 416 |
# Agent class
|
| 417 |
class JARVISAgent:
|
| 418 |
def __init__(self):
|
| 419 |
+
self.state = reset_state(task_id="init", question="Agent initialized")
|
| 420 |
+
self.state["results_table"] = [] # Initialize as empty list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
logger.info("JARVISAgent initialized.")
|
| 422 |
|
| 423 |
async def process_question(self, task_id: str, question: str) -> str:
|
| 424 |
+
state = reset_state(task_id=task_id, question=question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
try:
|
| 426 |
result = await graph.ainvoke(state)
|
| 427 |
+
answer = result.get("answer", "Unknown")
|
| 428 |
+
logger.info(f"Task {task_id} Final answer: {answer}")
|
| 429 |
+
self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": answer})
|
| 430 |
+
self.state["metadata"] = {"last_task_id": task_id, "answer": answer}
|
| 431 |
return answer
|
| 432 |
except Exception as e:
|
| 433 |
logger.error(f"Error processing task {task_id}: {e}")
|
| 434 |
+
self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
|
| 435 |
+
self.state["error"] = f"Task {task_id} failed: {str(e)}"
|
| 436 |
return f"Error: {str(e)}"
|
| 437 |
finally:
|
| 438 |
+
for ext in ["txt", "csv", "xlsx", "mp3", "jpg", "png", "pdf"]:
|
| 439 |
file_path = f"temp/{task_id}.{ext}"
|
| 440 |
if os.path.exists(file_path):
|
| 441 |
try:
|
|
|
|
| 447 |
async def process_all_questions(self, profile: gr.OAuthProfile | None):
|
| 448 |
if not profile:
|
| 449 |
logger.error("User not logged in.")
|
| 450 |
+
self.state["status_output"] = "Please Login to Hugging Face."
|
| 451 |
+
return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
|
| 452 |
|
| 453 |
+
username = profile.username
|
| 454 |
logger.info(f"User logged in: {username}")
|
| 455 |
questions_url = f"{GAIA_API_URL}/questions"
|
| 456 |
submit_url = f"{GAIA_API_URL}/submit"
|
| 457 |
agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
|
| 458 |
|
| 459 |
try:
|
| 460 |
+
async with await create_http_session() as session:
|
| 461 |
+
async with session.get(questions_url) as response:
|
| 462 |
+
response.raise_for_status()
|
| 463 |
+
questions = await response.json()
|
| 464 |
logger.info(f"Fetched {len(questions)} questions.")
|
| 465 |
except Exception as e:
|
| 466 |
logger.error(f"Error fetching questions: {e}")
|
| 467 |
+
self.state["status_output"] = f"Error fetching questions: {e}"
|
| 468 |
+
self.state["error"] = f"Fetch questions failed: {str(e)}"
|
| 469 |
+
return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
|
| 470 |
|
| 471 |
answers_payload = []
|
| 472 |
for item in questions:
|
|
|
|
| 480 |
|
| 481 |
if not answers_payload:
|
| 482 |
logger.error("No answers generated.")
|
| 483 |
+
self.state["status_output"] = "No answers to submit."
|
| 484 |
+
self.state["error"] = "No answers generated"
|
| 485 |
+
return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
|
| 486 |
|
| 487 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
| 488 |
try:
|
| 489 |
+
async with await create_http_session() as session:
|
| 490 |
+
async with session.post(submit_url, json=submission_data) as response:
|
| 491 |
+
response.raise_for_status()
|
| 492 |
+
result_data = await response.json()
|
| 493 |
+
self.state["status_output"] = (
|
| 494 |
f"Submission Successful!\n"
|
| 495 |
f"User: {result_data.get('username')}\n"
|
| 496 |
f"Overall Score: {result_data.get('score', 'N/A')}% "
|
| 497 |
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
| 498 |
f"Message: {result_data.get('message', 'No message received.')}"
|
| 499 |
)
|
| 500 |
+
self.state["metadata"] = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
|
| 501 |
except Exception as e:
|
| 502 |
logger.error(f"Submission failed: {e}")
|
| 503 |
+
self.state["status_output"] = f"Submission Failed: {e}"
|
| 504 |
+
self.state["error"] = f"Submission failed: {str(e)}"
|
| 505 |
|
| 506 |
+
return pd.DataFrame(self.state["results_table"] if self.state["results_table"] else [], columns=["Task ID", "Question", "Answer"]), self.state["status_output"]
|
| 507 |
|
| 508 |
# Gradio interface
|
| 509 |
with gr.Blocks() as demo:
|
| 510 |
+
gr.Markdown("# JARVIS GAIA Agent")
|
| 511 |
gr.Markdown(
|
| 512 |
"""
|
| 513 |
**Instructions:**
|
|
|
|
| 522 |
)
|
| 523 |
with gr.Row():
|
| 524 |
gr.LoginButton(value="Login to Hugging Face")
|
|
|
|
| 525 |
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
| 526 |
status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
|
| 527 |
results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
|
project_structure.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.
|
| 2 |
+
├── app.py
|
| 3 |
+
├── dockerfile
|
| 4 |
+
├── README.md
|
| 5 |
+
├── requirements.txt
|
| 6 |
+
├── retriever.py
|
| 7 |
+
├── state.py
|
| 8 |
+
└── tools
|
| 9 |
+
├── __init__.py
|
| 10 |
+
├── answer_generator.py
|
| 11 |
+
├── calculator.py
|
| 12 |
+
├── document_retriever.py
|
| 13 |
+
├── duckduckgo_search.py
|
| 14 |
+
├── file_fetcher.py
|
| 15 |
+
├── file_parser.py
|
| 16 |
+
├── guest_info.py
|
| 17 |
+
├── hub_stats.py
|
| 18 |
+
├── image_parser.py
|
| 19 |
+
├── search.py
|
| 20 |
+
└── weather_info.py
|
| 21 |
+
|
| 22 |
+
3 directories, 18 files
|
requirements.txt
CHANGED
|
@@ -20,8 +20,11 @@ transformers
|
|
| 20 |
asyncio
|
| 21 |
serpapi
|
| 22 |
duckduckgo-search
|
| 23 |
-
torch
|
| 24 |
together
|
| 25 |
google-search-results
|
| 26 |
beautifulsoup4
|
| 27 |
-
gradio[oauth]
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
asyncio
|
| 21 |
serpapi
|
| 22 |
duckduckgo-search
|
| 23 |
+
torch==2.2.2
|
| 24 |
together
|
| 25 |
google-search-results
|
| 26 |
beautifulsoup4
|
| 27 |
+
gradio[oauth]
|
| 28 |
+
nltk
|
| 29 |
+
speechrecognition
|
| 30 |
+
rank_bm25
|
result.txt
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
retriever.py
CHANGED
|
@@ -1,25 +1,109 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
try:
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
docs = [
|
| 10 |
Document(
|
| 11 |
page_content="\n".join([
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
]),
|
| 17 |
-
metadata={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
)
|
| 19 |
-
for guest in guest_dataset
|
| 20 |
]
|
|
|
|
|
|
|
|
|
|
| 21 |
except Exception as e:
|
| 22 |
-
|
|
|
|
| 23 |
docs = [
|
| 24 |
Document(
|
| 25 |
page_content="\n".join([
|
|
@@ -28,7 +112,73 @@ def load_guest_dataset():
|
|
| 28 |
"Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
|
| 29 |
"Email: nikola.tesla@gmail.com"
|
| 30 |
]),
|
| 31 |
-
metadata={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
]
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
import torch
|
| 5 |
+
from typing import List
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
try:
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
except ImportError:
|
| 11 |
+
load_dataset = None
|
| 12 |
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
def get_device() -> str:
    """
    Determine the appropriate device for PyTorch.

    CUDA is preferred, then Apple's MPS backend, then the CPU.

    Returns:
        str: Device name ('cuda', 'mps', or 'cpu').
    """
    if torch.cuda.is_available():
        return "cuda"
    # Not every torch build exposes torch.backends.mps; look it up defensively
    # so such builds fall through to "cpu" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
|
| 27 |
+
|
| 28 |
+
def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
|
| 29 |
+
"""
|
| 30 |
+
Load guest dataset from a local JSON file or Hugging Face dataset.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
dataset_path (str): Path to local JSON file or Hugging Face dataset name.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
List[Document]: List of Document objects with guest information.
|
| 37 |
+
"""
|
| 38 |
try:
|
| 39 |
+
# Try loading from Hugging Face dataset if datasets library is available
|
| 40 |
+
if load_dataset and not os.path.exists(dataset_path):
|
| 41 |
+
logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
|
| 42 |
+
guest_dataset = load_dataset(dataset_path, split="train")
|
| 43 |
+
docs = [
|
| 44 |
+
Document(
|
| 45 |
+
page_content="\n".join([
|
| 46 |
+
f"Name: {guest['name']}",
|
| 47 |
+
f"Relation: {guest['relation']}",
|
| 48 |
+
f"Description: {guest['description']}",
|
| 49 |
+
f"Email: {guest['email']}"
|
| 50 |
+
]),
|
| 51 |
+
metadata={
|
| 52 |
+
"name": guest["name"],
|
| 53 |
+
"relation": guest["relation"],
|
| 54 |
+
"description": guest["description"],
|
| 55 |
+
"email": guest["email"]
|
| 56 |
+
}
|
| 57 |
+
)
|
| 58 |
+
for guest in guest_dataset
|
| 59 |
+
]
|
| 60 |
+
logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
|
| 61 |
+
return docs
|
| 62 |
+
|
| 63 |
+
# Try loading from local JSON file
|
| 64 |
+
if os.path.exists(dataset_path):
|
| 65 |
+
logger.info(f"Loading guest dataset from local path: {dataset_path}")
|
| 66 |
+
with open(dataset_path, 'r') as f:
|
| 67 |
+
guests = json.load(f)
|
| 68 |
+
docs = [
|
| 69 |
+
Document(
|
| 70 |
+
page_content=guest.get('description', ''),
|
| 71 |
+
metadata={
|
| 72 |
+
'name': guest.get('name', ''),
|
| 73 |
+
'relation': guest.get('relation', ''),
|
| 74 |
+
'description': guest.get('description', ''),
|
| 75 |
+
'email': guest.get('email', '') # Optional email field
|
| 76 |
+
}
|
| 77 |
+
)
|
| 78 |
+
for guest in guests
|
| 79 |
+
]
|
| 80 |
+
logger.info(f"Loaded {len(docs)} guests from local JSON")
|
| 81 |
+
return docs
|
| 82 |
+
|
| 83 |
+
# Fallback to mock dataset if both fail
|
| 84 |
+
logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
|
| 85 |
docs = [
|
| 86 |
Document(
|
| 87 |
page_content="\n".join([
|
| 88 |
+
"Name: Dr. Nikola Tesla",
|
| 89 |
+
"Relation: old friend from university days",
|
| 90 |
+
"Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
|
| 91 |
+
"Email: nikola.tesla@gmail.com"
|
| 92 |
]),
|
| 93 |
+
metadata={
|
| 94 |
+
"name": "Dr. Nikola Tesla",
|
| 95 |
+
"relation": "old friend from university days",
|
| 96 |
+
"description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
|
| 97 |
+
"email": "nikola.tesla@gmail.com"
|
| 98 |
+
}
|
| 99 |
)
|
|
|
|
| 100 |
]
|
| 101 |
+
logger.info("Loaded mock dataset with 1 guest")
|
| 102 |
+
return docs
|
| 103 |
+
|
| 104 |
except Exception as e:
|
| 105 |
+
logger.error(f"Failed to load guest dataset: {e}")
|
| 106 |
+
# Return mock dataset as final fallback
|
| 107 |
docs = [
|
| 108 |
Document(
|
| 109 |
page_content="\n".join([
|
|
|
|
| 112 |
"Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
|
| 113 |
"Email: nikola.tesla@gmail.com"
|
| 114 |
]),
|
| 115 |
+
metadata={
|
| 116 |
+
"name": "Dr. Nikola Tesla",
|
| 117 |
+
"relation": "old friend from university days",
|
| 118 |
+
"description": "Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
|
| 119 |
+
"email": "nikola.tesla@gmail.com"
|
| 120 |
+
}
|
| 121 |
)
|
| 122 |
]
|
| 123 |
+
logger.info("Loaded mock dataset with 1 guest due to error")
|
| 124 |
+
return docs
|
| 125 |
+
|
| 126 |
+
class BM25Retriever:
    """
    Keyword (BM25) retriever over the guest dataset.

    A SentenceTransformer model is loaded in ``__init__`` and kept on
    ``self.model``; ``search`` itself ranks with BM25 and does not use it.
    """
    def __init__(self, dataset_path: str):
        """
        Initialize the retriever with a SentenceTransformer model.

        Args:
            dataset_path (str): Path to the dataset for retrieval.

        Raises:
            Exception: If embedder initialization fails.
        """
        try:
            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
            self.dataset_path = dataset_path
            logger.info("Initialized SentenceTransformer")
        except Exception as e:
            logger.error(f"Failed to initialize embedder: {e}")
            raise

    def search(self, query: str) -> List[dict]:
        """
        Search the dataset for relevant guest information.

        The dataset is reloaded and re-indexed on every call.

        Args:
            query (str): Search query (e.g., guest name or relation).

        Returns:
            List[dict]: Metadata of up to three guests, best BM25 match first.
            An empty list is returned when no documents are available or when
            any step fails (the failure is logged).
        """
        try:
            # Load dataset
            docs = load_guest_dataset(self.dataset_path)
            if not docs:
                logger.warning("No documents available for search")
                return []

            # Aliased so the langchain class does not shadow this class's name.
            from langchain_community.retrievers import BM25Retriever as LangchainBM25Retriever

            # Index text -> metadata so each hit maps back by exact text in O(1).
            # The first document wins when two guests produce identical text.
            metadata_by_text = {}
            for doc in docs:
                text = f"{doc.metadata['name']} {doc.metadata['relation']} {doc.metadata['description']}"
                metadata_by_text.setdefault(text, doc.metadata)

            retriever = LangchainBM25Retriever.from_texts(list(metadata_by_text))
            retriever.k = 3  # Limit to top 3 results

            # Perform search
            results = retriever.invoke(query)
            # Walk the hits (not the dataset) so the retriever's ranking is preserved.
            matches = [
                metadata_by_text[hit.page_content]
                for hit in results
                if hit.page_content in metadata_by_text
            ]
            logger.info(f"Found {len(matches)} matches for query: {query}")
            return matches[:3]  # Return top 3 matches

        except Exception as e:
            logger.error(f"Search failed for query '{query}': {e}")
            return []
|
state.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
-
from typing import TypedDict, List, Dict, Optional, Any
|
| 2 |
from langchain_core.messages import BaseMessage
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
class JARVISState(TypedDict):
|
| 5 |
"""
|
|
@@ -10,11 +13,11 @@ class JARVISState(TypedDict):
|
|
| 10 |
question: The question text to be answered.
|
| 11 |
tools_needed: List of tool names to be used for the task.
|
| 12 |
web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
|
| 13 |
-
file_results: Parsed content from text, CSV, or
|
| 14 |
image_results: OCR or description results from image files.
|
| 15 |
calculation_results: Results from mathematical calculations.
|
| 16 |
-
document_results: Extracted content from PDF documents.
|
| 17 |
-
multi_hop_results: Results from iterative multi-hop searches.
|
| 18 |
messages: List of messages for LLM context (e.g., user prompts, system instructions).
|
| 19 |
answer: Final answer for the task, formatted for GAIA submission.
|
| 20 |
results_table: List of task results for Gradio display (Task ID, Question, Answer).
|
|
@@ -30,10 +33,84 @@ class JARVISState(TypedDict):
|
|
| 30 |
image_results: str
|
| 31 |
calculation_results: str
|
| 32 |
document_results: str
|
| 33 |
-
multi_hop_results: List[str]
|
| 34 |
messages: List[BaseMessage]
|
| 35 |
answer: str
|
| 36 |
results_table: List[Dict[str, str]]
|
| 37 |
status_output: str
|
| 38 |
error: Optional[str]
|
| 39 |
-
metadata: Optional[Dict[str, Any]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict, List, Dict, Optional, Any, Union
|
| 2 |
from langchain_core.messages import BaseMessage
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
|
| 7 |
class JARVISState(TypedDict):
|
| 8 |
"""
|
|
|
|
| 13 |
question: The question text to be answered.
|
| 14 |
tools_needed: List of tool names to be used for the task.
|
| 15 |
web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
|
| 16 |
+
file_results: Parsed content from text, CSV, Excel, or audio files.
|
| 17 |
image_results: OCR or description results from image files.
|
| 18 |
calculation_results: Results from mathematical calculations.
|
| 19 |
+
document_results: Extracted content from PDF or text documents.
|
| 20 |
+
multi_hop_results: Results from iterative multi-hop searches (supports strings or dicts).
|
| 21 |
messages: List of messages for LLM context (e.g., user prompts, system instructions).
|
| 22 |
answer: Final answer for the task, formatted for GAIA submission.
|
| 23 |
results_table: List of task results for Gradio display (Task ID, Question, Answer).
|
|
|
|
| 33 |
image_results: str
|
| 34 |
calculation_results: str
|
| 35 |
document_results: str
|
| 36 |
+
multi_hop_results: List[Union[str, Dict[str, Any]]]
|
| 37 |
messages: List[BaseMessage]
|
| 38 |
answer: str
|
| 39 |
results_table: List[Dict[str, str]]
|
| 40 |
status_output: str
|
| 41 |
error: Optional[str]
|
| 42 |
+
metadata: Optional[Dict[str, Any]]
|
| 43 |
+
|
| 44 |
+
def validate_state(state: JARVISState) -> JARVISState:
    """
    Validate and initialize JARVISState fields.

    The state is modified in place: every optional field that is missing or
    set to ``None`` receives a fresh default value.

    Args:
        state: Input state dictionary.

    Returns:
        The same state object, validated and initialized.

    Raises:
        ValueError: If ``task_id`` or ``question`` is missing or empty.
    """
    # Each failure is logged exactly once, right where it is detected.
    for required in ("task_id", "question"):
        if not state.get(required):
            message = f"{required} is required"
            logger.error(message)
            raise ValueError(message)

    # Built on every call so no list/dict default is shared between states.
    defaults = {
        "tools_needed": ["search_tool"],
        "web_results": [],
        "file_results": "",
        "image_results": "",
        "calculation_results": "",
        "document_results": "",
        "multi_hop_results": [],
        "messages": [],
        "answer": "",
        "results_table": [],
        "status_output": "",
        "error": None,
        "metadata": {}
    }
    for key, default in defaults.items():
        # Covers both "key absent" and "key present but None".
        if state.get(key) is None:
            state[key] = default

    logger.debug(f"Validated state for task {state['task_id']}")
    return state
|
| 87 |
+
|
| 88 |
+
def reset_state(task_id: str, question: str) -> JARVISState:
    """
    Build a brand-new JARVISState for one task.

    Every result field starts empty and ``tools_needed`` starts with only the
    search tool; the state is then passed through ``validate_state``.

    Args:
        task_id: Task identifier.
        question: Question text.

    Returns:
        The freshly initialized state, as returned by ``validate_state``.
    """
    fresh: JARVISState = {
        "task_id": task_id,
        "question": question,
        "tools_needed": ["search_tool"],
        "web_results": [],
        "file_results": "",
        "image_results": "",
        "calculation_results": "",
        "document_results": "",
        "multi_hop_results": [],
        "messages": [],
        "answer": "",
        "results_table": [],
        "status_output": "",
        "error": None,
        "metadata": {},
    }
    return validate_state(fresh)
|
test.py
CHANGED
|
@@ -1,10 +1,233 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
}
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from app import JARVISAgent, llm_client, llm_type, llm_model, embedder
|
| 7 |
+
from tools.search import search_tool, multi_hop_search_tool
|
| 8 |
+
from tools.file_parser import file_parser_tool
|
| 9 |
+
from tools.image_parser import image_parser_tool
|
| 10 |
+
from tools.calculator import calculator_tool
|
| 11 |
+
from tools.document_retriever import document_retriever_tool
|
| 12 |
+
from tools.duckduckgo_search import duckduckgo_search_tool
|
| 13 |
+
from tools.weather_info import weather_info_tool
|
| 14 |
+
from tools.hub_stats import hub_stats_tool
|
| 15 |
+
from tools.guest_info import guest_info_retriever_tool
|
| 16 |
+
from tools.file_fetcher import fetch_task_file
|
| 17 |
+
from tools.answer_generator import preprocess_question, filter_results
|
| 18 |
+
from state import validate_state, reset_state, JARVISState
|
| 19 |
|
| 20 |
+
# Setup logging
|
| 21 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 22 |
+
logger = logging.getLogger(__name__)
|
|
|
|
| 23 |
|
| 24 |
+
async def test_tools():
    """Smoke-test every tool once and log the outcome.

    Each tool is checked independently: a failure in one check is logged and
    does not stop the remaining checks. Checks that need an API key are
    skipped with a warning when the matching environment variable is unset.
    Temporary fixture files are removed even when the tool under test raises.
    """
    await _check_tool(
        "Search Tool (SerpAPI)",
        "Search",
        lambda: search_tool.ainvoke({"query": "What is the capital of France?"}),
        required_env="SERPAPI_API_KEY",
    )
    await _check_tool(
        "Multi-Hop Search Tool",
        "Multi-Hop Search",
        lambda: multi_hop_search_tool.ainvoke({
            "query": "What is the population of France's capital?",
            "steps": 2,
            "llm_client": llm_client,
            "llm_type": llm_type,
            "llm_model": llm_model,
        }),
    )
    await _check_tool(
        "DuckDuckGo Search Tool",
        "DuckDuckGo",
        lambda: duckduckgo_search_tool.ainvoke({
            "query": "What is the capital of France?",
            "original_query": "What is the capital of France?",
            "embedder": embedder,
        }),
    )
    await _check_tool(
        "Weather Info Tool",
        "Weather",
        lambda: weather_info_tool.ainvoke({"location": "London"}),
        required_env="OPENWEATHERMAP_API_KEY",
    )
    await _check_file_tool(
        "Document Retriever Tool",
        "Document Retriever",
        ".pdf",
        _write_sample_pdf,
        lambda path: document_retriever_tool.ainvoke({
            "task_id": "test_task",
            "query": "Sample question",
            "file_path": path,
        }),
    )
    await _check_file_tool(
        "Image Parser Tool",
        "Image Parser",
        ".png",
        _write_sample_png,
        lambda path: image_parser_tool.ainvoke({"task_id": "test_task", "file_path": path}),
    )
    await _check_file_tool(
        "File Parser Tool",
        "File Parser",
        ".txt",
        lambda tmp: tmp.write(b"Sample text file content"),
        lambda path: file_parser_tool.ainvoke({
            "task_id": "test_task",
            "file_type": "txt",
            "file_path": path,
            "query": "What is in the file?",
        }),
    )
    await _check_tool(
        "Calculator Tool",
        "Calculator",
        lambda: calculator_tool.ainvoke({"expression": "2 + 2"}),
    )
    await _check_tool(
        "Hub Stats Tool",
        "Hub Stats",
        lambda: hub_stats_tool.ainvoke({"author": "meta-llama"}),
        required_env="HUGGINGFACEHUB_API_TOKEN",
    )
    await _check_tool(
        "Guest Info Retriever Tool",
        "Guest Info",
        lambda: guest_info_retriever_tool.ainvoke({"query": "Who is the guest named John?"}),
    )


async def _check_tool(title, label, invoke, required_env=None):
    """Run one tool check: log the start, await ``invoke()``, log result or error.

    ``invoke`` is a zero-argument callable so that building the request (and
    any name lookup it needs) happens inside the ``try`` and is logged on
    failure instead of aborting the whole run.
    """
    logger.info(f"Testing {title}...")
    try:
        if required_env and not os.getenv(required_env):
            logger.warning(f"{label} Warning: {required_env} not set")
            return
        result = await invoke()
        logger.info(f"{label} Result: {result}")
    except Exception as e:
        logger.error(f"{label} Error: {e}")


async def _check_file_tool(title, label, suffix, write_fixture, invoke):
    """Run one tool check against a temporary fixture file and always delete it."""
    logger.info(f"Testing {title}...")
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = tmp.name
            write_fixture(tmp)
        result = await invoke(tmp_path)
        logger.info(f"{label} Result: {result}")
    except Exception as e:
        logger.error(f"{label} Error: {e}")
    finally:
        # Cleanup lives in ``finally`` so a failing tool no longer leaks the file.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(f"{label} cleanup failed for {tmp_path}: {cleanup_error}")


def _write_sample_pdf(tmp):
    """Write a one-page PDF fixture into the open binary file ``tmp``."""
    from PyPDF2 import PdfWriter
    from PyPDF2.generic import NameObject, create_string_object

    writer = PdfWriter()
    page = writer.add_blank_page(width=72, height=72)
    page[NameObject("/Contents")] = create_string_object("Sample document content for testing.")
    writer.write(tmp)


def _write_sample_png(tmp):
    """Write a 1x1 white PNG fixture into the open binary file ``tmp``."""
    from PIL import Image

    # Saving through the open handle avoids reopening the path while it is
    # still held open, which is not possible on every platform.
    Image.new('RGB', (1, 1), color='white').save(tmp, 'PNG')
|
| 143 |
+
|
| 144 |
+
async def test_file_fetcher():
    """Call ``fetch_task_file`` for one hard-coded task id and log the outcome."""
    logger.info("Testing File Fetcher...")
    try:
        fetched = await fetch_task_file(
            "8e867cd7-cff9-4e6c-867a-ff5ddc2550be",
            "Sample question with data",
        )
        logger.info(f"File Fetcher Result: {fetched}")
    except Exception as exc:
        logger.error(f"File Fetcher Error: {exc}")
|
| 152 |
+
|
| 153 |
+
async def test_answer_generator():
    """Exercise ``preprocess_question`` and ``filter_results`` and log what they return."""
    logger.info("Testing Preprocess Question...")
    try:
        cleaned = await preprocess_question("What's the weather in Paris?")
        logger.info(f"Preprocess Question Result: {cleaned}")
    except Exception as exc:
        logger.error(f"Preprocess Question Error: {exc}")

    logger.info("Testing Filter Results...")
    try:
        candidates = [
            "Paris is the capital of France.",
            "Florida is a state.",
            "Paris is in Texas.",
        ]
        kept = filter_results(candidates, "What is the capital of France?")
        logger.info(f"Filter Results: {kept}")
    except Exception as exc:
        logger.error(f"Filter Results Error: {exc}")
|
| 169 |
+
|
| 170 |
+
async def test_state_management():
    """Exercise ``reset_state`` and ``validate_state`` and log the outcomes.

    The second stanza feeds ``validate_state`` empty fields and treats a
    ``ValueError`` as the expected result; any other exception type is not
    caught there and propagates to the caller.
    """
    logger.info("Testing Reset State...")
    try:
        fresh = reset_state("test_task", "What is the capital of France?")
        logger.info(f"Reset State Result: {fresh}")
    except Exception as exc:
        logger.error(f"Reset State Error: {exc}")

    logger.info("Testing Validate State...")
    try:
        validate_state({"task_id": "", "question": ""})
        logger.error("Validate State should have failed")
    except ValueError as exc:
        logger.info(f"Validate State Error (expected): {exc}")

    try:
        checked = validate_state(reset_state("test_task", "Sample question"))
        logger.info(f"Validate State Result: {checked}")
    except Exception as exc:
        logger.error(f"Validate State Error: {exc}")
|
| 193 |
+
|
| 194 |
+
async def test_agent():
    """Run ``JARVISAgent.process_question`` on a simple and an empty question.

    A fresh agent is built for each case. An exception on the empty question
    is logged at info level because it is the expected outcome.
    """
    logger.info("Testing JARVISAgent (Simple Question)...")
    try:
        simple_answer = await JARVISAgent().process_question(
            "test_task", "What is the capital of France?"
        )
        logger.info(f"JARVISAgent Answer: {simple_answer}")
    except Exception as exc:
        logger.error(f"JARVISAgent Error: {exc}")

    logger.info("Testing JARVISAgent (Edge Case: Empty Question)...")
    try:
        empty_answer = await JARVISAgent().process_question("test_task", "")
        logger.info(f"JARVISAgent Empty Question Answer: {empty_answer}")
    except Exception as exc:
        logger.info(f"JARVISAgent Empty Question Error (expected): {exc}")
|
| 211 |
+
|
| 212 |
+
async def main():
    """Warn about unset API-key variables, then run every test group in order."""
    for variable in (
        "HUGGINGFACEHUB_API_TOKEN",
        "TOGETHER_API_KEY",
        "OPENWEATHERMAP_API_KEY",
        "SERPAPI_API_KEY",
    ):
        if os.getenv(variable):
            continue
        logger.warning(f"{variable} not set, some tools may fail")

    # Groups run sequentially, in the same order as before.
    await test_tools()
    await test_file_fetcher()
    await test_answer_generator()
    await test_state_management()
    await test_agent()
|
| 228 |
+
|
| 229 |
+
if __name__ == "__main__":
    # Script entry point: drive the async test run and log, rather than
    # re-raise, anything that escapes it.
    try:
        asyncio.run(main())
    except Exception as exc:
        logger.error(f"Test script failed: {exc}")
|
tools/__init__.py
CHANGED
|
@@ -6,4 +6,6 @@ from .document_retriever import document_retriever_tool
|
|
| 6 |
from .duckduckgo_search import duckduckgo_search_tool
|
| 7 |
from .weather_info import weather_info_tool
|
| 8 |
from .hub_stats import hub_stats_tool
|
| 9 |
-
from .guest_info import guest_info_retriever_tool
|
|
|
|
|
|
|
|
|
| 6 |
from .duckduckgo_search import duckduckgo_search_tool
|
| 7 |
from .weather_info import weather_info_tool
|
| 8 |
from .hub_stats import hub_stats_tool
|
| 9 |
+
from .guest_info import guest_info_retriever_tool
|
| 10 |
+
from .file_fetcher import fetch_task_file
|
| 11 |
+
from .answer_generator import generate_answer, preprocess_question#, filter_results, get_embedder
|
tools/answer_generator.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nltk
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List, Any
|
| 5 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 6 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
|
| 9 |
+
# Setup logging
|
| 10 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Best-effort fetch of the NLTK resources at import time; a failure is only
# logged so the module can still be imported.
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception as exc:
    logger.warning(f"NLTK data download failed: {exc}")

# Shared SentenceTransformer instance; stays None until get_embedder() builds it.
_embedder = None
|
| 22 |
+
|
| 23 |
+
def get_embedder():
    """Return the shared SentenceTransformer, creating it on first use.

    The model ("all-MiniLM-L6-v2", CPU, cached under ``./cache``) is built
    once and stored in the module-level ``_embedder``; later calls return
    that same instance.

    Raises:
        RuntimeError: if the model cannot be initialized. The underlying
            exception is chained as ``__cause__`` so the traceback is kept.
    """
    global _embedder
    if _embedder is not None:
        return _embedder
    try:
        _embedder = SentenceTransformer(
            "all-MiniLM-L6-v2",
            device="cpu",
            cache_folder="./cache",
        )
    except Exception as e:
        logger.error(f"Failed to initialize SentenceTransformer: {e}")
        raise RuntimeError(f"Embedder initialization failed: {e}") from e
    logger.info("SentenceTransformer initialized")
    return _embedder
|
| 37 |
+
|
| 38 |
+
def filter_results(search_results: List[str], question: str) -> List[str]:
    """Keep the search results whose embedding scores above 0.5 against the question.

    Args:
        search_results: Candidate result strings.
        question: Question the results are scored against.

    Returns:
        ``search_results`` unchanged when either argument is empty; otherwise
        the non-blank results whose dot-product score with the question
        embedding exceeds 0.5. If nothing passes, or scoring fails, the first
        three results are returned instead.
    """
    try:
        if not (search_results and question):
            return search_results

        model = get_embedder()
        question_vec = model.encode([question], convert_to_numpy=True)
        result_vecs = model.encode(search_results, convert_to_numpy=True)

        # One score per result: dot product with the single question vector.
        scores = np.dot(result_vecs, question_vec.T).flatten()
        kept = [
            text
            for text, score in zip(search_results, scores)
            if score > 0.5 and text.strip()
        ]
        return kept or search_results[:3]
    except Exception as exc:
        logger.warning(f"Result filtering failed: {exc}")
        return search_results[:3]
|
| 57 |
+
|
| 58 |
+
async def preprocess_question(question: str) -> str:
    """Normalize a question: trim it, lowercase it, and make it end with "?".

    If normalization raises, the error is logged and ``question`` is returned
    as it stands at that point.
    """
    try:
        question = question.strip().lower()
        question = question if question.endswith("?") else question + "?"
        logger.debug(f"Preprocessed question: {question}")
        return question
    except Exception as exc:
        logger.error(f"Error preprocessing question: {exc}")
        return question
|
| 69 |
+
|
| 70 |
+
async def generate_answer(
    task_id: str,
    question: str,
    search_results: List[str],
    file_results: str,
    llm_client: Any
) -> str:
    """Generate an answer with the LLM, using search and file results as context.

    Args:
        task_id: Task identifier, used only in log messages.
        question: Question to answer.
        search_results: Search snippets; a placeholder is used when empty.
        file_results: Parsed file content; a placeholder is used when empty.
        llm_client: Either a ``(model, tokenizer)`` tuple for a local model,
            or a client object exposing ``chat.completions.create``.

    Returns:
        The stripped model answer, "Insufficient information to answer." when
        the model returns nothing (or "none"), or "Error generating answer."
        when any step raises.
    """
    try:
        if not search_results:
            search_results = ["No search results available."]
        if not file_results:
            file_results = "No file results available."

        # str() keeps the concatenation safe if a caller passes non-string file results.
        context = "\n".join(str(r) for r in search_results) + "\n" + str(file_results)
        system_prompt = (
            "You are an assistant answering questions using provided context.\n"
            "- Use ONLY the context to formulate a concise, accurate answer.\n"
            "- If the context is insufficient, state: 'Insufficient information to answer.'\n"
            "- Do NOT generate or assume information beyond the context.\n"
            "- Return a single, clear sentence or phrase as the answer."
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
        ]

        if isinstance(llm_client, tuple):  # hf_local
            model, tokenizer = llm_client
            inputs = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,  # end the prompt where the assistant turn starts
                return_tensors="pt",
            ).to(model.device)
            outputs = model.generate(inputs, max_new_tokens=100, temperature=0.7)
            # generate() returns prompt + continuation; decode only the new
            # tokens so the prompt is not echoed back as part of the answer.
            new_tokens = outputs[0][inputs.shape[-1]:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        elif hasattr(llm_client, "chat"):  # together
            # NOTE(review): any client exposing `.chat` takes this branch, so an
            # hf_api client with a `.chat` attribute would never reach the
            # `else` below - confirm against how llm_client is constructed.
            completion = llm_client.chat.completions.create(
                model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                messages=messages,
                max_tokens=100,
                temperature=0.7,
                top_p=0.9,
                frequency_penalty=0.5
            )
            response = completion.choices[0].message.content.strip()
        else:  # hf_api
            completion = llm_client.chat.completions.create(
                messages=messages,
                max_tokens=100,
                temperature=0.7
            )
            response = completion.choices[0].message.content.strip()

        answer = response.strip()
        if not answer or answer.lower() == "none":
            answer = "Insufficient information to answer."
        logger.info(f"Task {task_id}: Generated answer: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Task {task_id}: Answer generation failed: {e}")
        return "Error generating answer."
|
tools/calculator.py
CHANGED
|
@@ -1,15 +1,35 @@
|
|
| 1 |
-
from langchain_core.tools import tool
|
| 2 |
-
from sympy import sympify
|
| 3 |
import logging
|
|
|
|
|
|
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
try:
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
return str(result)
|
| 13 |
except Exception as e:
|
| 14 |
-
logger.error(f"
|
| 15 |
-
return f"Error: {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from langchain_core.tools import StructuredTool
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
| 7 |
+
# Argument schema for calculator_tool: a single free-text expression.
class CalculatorInput(BaseModel):
    expression: str = Field(
        description="Mathematical expression to evaluate",
    )
|
| 9 |
+
|
| 10 |
+
async def calculator_func(expression: str) -> str:
    """
    Evaluate an arithmetic expression and return the result as a string.

    Only numeric literals, parentheses, unary +/- and the binary operators
    + - * / // % ** are accepted. The expression is parsed with ``ast`` and
    evaluated node by node instead of being passed to ``eval``: ``eval`` with
    empty ``__builtins__`` can still be escaped through attribute access, and
    the expression text comes from outside this module.

    Args:
        expression (str): Mathematical expression (e.g., '2 + 2').

    Returns:
        str: Result of the expression, or "Error: <reason>" when it cannot
        be evaluated. A float result is formatted with two decimals when the
        expression text contains "USD" (e.g. in a trailing comment).
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        # type() rather than isinstance() so that bool literals are rejected.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    try:
        logger.info(f"Evaluating expression: {expression}")
        # strip(): ast.parse rejects leading whitespace that eval() tolerated.
        result = evaluate(ast.parse(expression.strip(), mode="eval"))
        if isinstance(result, float):
            return f"{result:.2f}" if "USD" in expression else str(result)
        return str(result)
    except Exception as e:
        logger.error(f"Calculator error: {e}")
        return f"Error: {e}"
|
| 29 |
+
|
| 30 |
+
# LangChain tool wrapper around calculator_func.
# NOTE(review): `func` is given the same async function as `coroutine` -
# confirm that synchronous invocation of this tool is never relied on.
calculator_tool = StructuredTool.from_function(
    name="calculator_tool",
    args_schema=CalculatorInput,
    func=calculator_func,
    coroutine=calculator_func,
)
|
tools/document_retriever.py
CHANGED
|
@@ -1,30 +1,47 @@
|
|
| 1 |
-
from langchain_core.tools import tool
|
| 2 |
-
from langchain_community.document_loaders import TextLoader, CSVLoader, PyPDFLoader
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
try:
|
| 12 |
-
file_path
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
else:
|
| 24 |
-
return f"Unsupported file type: {file_type}"
|
| 25 |
-
|
| 26 |
-
docs = loader.load()
|
| 27 |
-
return "\n".join(doc.page_content for doc in docs)
|
| 28 |
except Exception as e:
|
| 29 |
logger.error(f"Error retrieving document for task {task_id}: {e}")
|
| 30 |
-
return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
+
from langchain_core.tools import StructuredTool
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
+
from typing import Optional
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
+
# Argument schema for document_retriever_tool; file_path may be omitted.
class DocumentRetrieverInput(BaseModel):
    task_id: str = Field(description="Task identifier")
    query: str = Field(description="Search query")
    file_path: Optional[str] = Field(
        default=None,
        description="Path to document file",
    )
|
| 13 |
+
|
| 14 |
+
async def document_retriever_func(task_id: str, query: str, file_path: Optional[str] = None) -> str:
    """
    Return the first 500 characters of text extracted from a PDF document.

    Args:
        task_id (str): Task identifier, used in log messages.
        query (str): Search query. Accepted for interface compatibility; the
            extraction below does not use it.
        file_path (Optional[str]): Path to the document file.

    Returns:
        str: Extracted text (at most 500 characters), "No text extracted",
        "Unsupported file format" for non-PDF paths, "Document not found"
        when the path is missing or does not exist, or "Error: <reason>".
    """
    try:
        if not file_path or not os.path.exists(file_path):
            logger.warning(f"No valid documents found for task {task_id}")
            return "Document not found"

        logger.info(f"Retrieving document from {file_path} for task {task_id}")
        # Compare the extension case-insensitively so "REPORT.PDF" is accepted too.
        if not file_path.lower().endswith(".pdf"):
            return "Unsupported file format"

        import PyPDF2
        with open(file_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            # extract_text() may return None for a page; treat that as empty.
            text = "".join(page.extract_text() or "" for page in reader.pages)
        return text[:500] if text else "No text extracted"
    except Exception as e:
        logger.error(f"Error retrieving document for task {task_id}: {e}")
        return f"Error: {str(e)}"
|
| 41 |
+
|
| 42 |
+
# LangChain tool wrapper around document_retriever_func.
document_retriever_tool = StructuredTool.from_function(
    name="document_retriever_tool",
    args_schema=DocumentRetrieverInput,
    func=document_retriever_func,
    coroutine=document_retriever_func,
)
|
tools/duckduckgo_search.py
CHANGED
|
@@ -1,6 +1,99 @@
|
|
| 1 |
-
from smolagents import Tool, DuckDuckGoSearchTool
|
| 2 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
logger = logging.getLogger(__name__)
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import os
|
| 3 |
+
import asyncio
|
| 4 |
+
from langchain_core.tools import StructuredTool
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
from duckduckgo_search import DDGS
|
| 8 |
+
from serpapi import GoogleSearch
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
+
# Argument schema for duckduckgo_search_tool; the embedder is optional.
class DuckDuckGoSearchInput(BaseModel):
    query: str = Field(description="Search query")
    original_query: str = Field(description="Original query for context")
    embedder: Optional[object] = Field(
        default=None,
        description="SentenceTransformer embedder",
    )
|
| 16 |
+
|
| 17 |
+
async def duckduckgo_search_func(query: str, original_query: str, embedder: Optional[object] = None) -> List[str]:
    """
    Search DuckDuckGo with retries and fall back to SerpAPI when it yields nothing.

    The fallback runs both when DuckDuckGo returns no results and when it
    fails with an exception. Previously a DuckDuckGo exception skipped the
    fallback and produced an empty list.

    Args:
        query (str): Search query.
        original_query (str): Original query, used to rank results.
        embedder (Optional[object]): SentenceTransformer for result ranking.

    Returns:
        List[str]: Up to three result snippets, ranked by cosine similarity
        to ``original_query`` when an embedder is given; ``[]`` on failure.
    """
    async def try_duckduckgo(query: str, max_retries: int = 3) -> List[str]:
        """Query DuckDuckGo, retrying rate-limit errors; re-raise other failures."""
        for attempt in range(max_retries):
            try:
                logger.info(f"DuckDuckGo search attempt {attempt + 1} for query: {query}")
                with DDGS() as ddgs:
                    return [r["body"] for r in ddgs.text(query, max_results=5) if r.get("body")]
            except Exception as e:
                if "Ratelimit" in str(e) and attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # Exponential backoff: 1s, then 2s
                    logger.warning(f"DuckDuckGo rate limit hit, retrying in {wait_time}s: {e}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"DuckDuckGo search failed for query '{query}': {e}")
                    raise
        return []

    async def try_serpapi(query: str, max_retries: int = 3) -> List[str]:
        """Query SerpAPI with retries; return [] when no key is set or all attempts fail."""
        api_key = os.getenv("SERPAPI_API_KEY")
        if not api_key:
            logger.warning("SERPAPI_API_KEY not set, cannot use SerpAPI fallback")
            return []
        for attempt in range(max_retries):
            try:
                logger.info(f"SerpAPI search attempt {attempt + 1} for query: {query}")
                search = GoogleSearch({"q": query, "api_key": api_key, "num": 5})
                organic = search.get_dict().get("organic_results", [])
                return [item.get("snippet", "") for item in organic if "snippet" in item]
            except Exception as e:
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # Exponential backoff: 1s, then 2s
                    logger.warning(f"SerpAPI search failed, retrying in {wait_time}s: {e}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"SerpAPI search failed for query '{query}': {e}")
        return []

    try:
        logger.info(f"Executing DuckDuckGo search for query: {query}")
        try:
            results = await try_duckduckgo(query)
        except Exception:
            # Already logged inside try_duckduckgo; treat the failure as
            # "no results" so the SerpAPI fallback below still runs.
            results = []

        if not results:
            logger.info(f"DuckDuckGo returned no results, falling back to SerpAPI for query: {query}")
            results = await try_serpapi(query)

        # Rank results if embedder is provided
        if embedder and results:
            from sentence_transformers import util
            query_embedding = embedder.encode(original_query, convert_to_tensor=True)
            result_embeddings = embedder.encode(results, convert_to_tensor=True)
            scores = util.cos_sim(query_embedding, result_embeddings)[0]
            order = scores.argsort(descending=True).tolist()
            return [results[i] for i in order][:3]

        return results[:3] if results else []
    except Exception as e:
        logger.error(f"Search failed for query '{query}': {e}")
        return []
|
| 93 |
+
|
| 94 |
+
# LangChain tool wrapper around duckduckgo_search_func.
duckduckgo_search_tool = StructuredTool.from_function(
    name="duckduckgo_search_tool",
    args_schema=DuckDuckGoSearchInput,
    func=duckduckgo_search_func,
    coroutine=duckduckgo_search_func,
)
|
tools/file_fetcher.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import ssl
|
| 3 |
+
import aiohttp
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict
|
| 6 |
+
from urllib.parse import urljoin
|
| 7 |
+
|
| 8 |
+
# Setup logging
|
| 9 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
async def fetch_task_file(task_id: str, question: str) -> Dict[str, bytes]:
    """
    Fetch the files associated with a task from the GAIA API.

    One URL per known extension is probed
    (``<base>/<task_id>/<task_id>.<ext>``); every HTTP 200 response is kept.

    Args:
        task_id: Task identifier used to build the file URLs.
        question: Accepted for interface compatibility; not used here.

    Returns:
        Dictionary mapping file extension to the downloaded bytes. Extensions
        that are missing or fail to download are logged and left out.
    """
    results: Dict[str, bytes] = {}
    base_url = "https://gaia-benchmark-api.hf.space/files/"  # Updated URL
    extensions = ["xlsx", "csv", "pdf", "txt", "mp3", "jpg", "png"]

    # TLS certificates are verified with aiohttp's defaults. The previous
    # context disabled hostname and certificate checks, which left downloaded
    # files (later parsed locally) open to tampering in transit.
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=30)
    ) as session:
        for ext in extensions:
            file_url = urljoin(base_url, f"{task_id}/{task_id}.{ext}")
            try:
                async with session.get(file_url) as response:
                    if response.status == 200:
                        results[ext] = await response.read()
                        logger.info(f"Fetched {ext} for task {task_id}")
                    else:
                        logger.warning(f"No {ext} for task {task_id}: HTTP {response.status}")
            except Exception as e:
                # Best effort per extension: one failed probe must not stop the rest.
                logger.warning(f"Error fetching {ext} for task {task_id}: {str(e)}")

    return results
|
tools/file_parser.py
CHANGED
|
@@ -1,36 +1,112 @@
|
|
| 1 |
-
from langchain_core.tools import tool
|
| 2 |
-
import pandas as pd
|
| 3 |
-
import PyPDF2
|
| 4 |
import logging
|
| 5 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
try:
|
| 13 |
-
file_path = f"temp_{task_id}.{file_type}"
|
| 14 |
if not os.path.exists(file_path):
|
| 15 |
logger.warning(f"File not found: {file_path}")
|
| 16 |
return "File not found"
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
df = pd.read_csv(file_path)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
elif file_type == "pdf":
|
| 25 |
with open(file_path, "rb") as f:
|
| 26 |
reader = PyPDF2.PdfReader(f)
|
| 27 |
-
text = "".join(page.extract_text() for page in reader.pages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
return text
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
else:
|
|
|
|
| 33 |
return f"Unsupported file type: {file_type}"
|
|
|
|
| 34 |
except Exception as e:
|
| 35 |
logger.error(f"Error parsing file for task {task_id}: {e}")
|
| 36 |
-
return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import PyPDF2
|
| 5 |
+
import speech_recognition as sr
|
| 6 |
+
import re
|
| 7 |
+
from langchain_core.tools import StructuredTool
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
from typing import Optional
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
+
# Argument schema for file_parser_tool; query is optional.
class FileParserInput(BaseModel):
    task_id: str = Field(description="Task identifier")
    file_type: str = Field(description="File extension (e.g., pdf, csv)")
    file_path: str = Field(description="Path to the file")
    query: Optional[str] = Field(
        default=None,
        description="Query related to the file",
    )
|
| 18 |
+
|
| 19 |
+
async def file_parser_func(task_id: str, file_type: str, file_path: str, query: Optional[str] = None) -> str:
    """
    Parse a file and return its content, or a query-specific extract of it.

    Args:
        task_id (str): Task identifier, used in log messages.
        file_type (str): File extension (e.g., 'xlsx', 'csv', 'pdf', 'txt',
            'mp3'). Matching ignores case and a leading dot.
        file_path (str): Path to the file.
        query (Optional[str]): Question context. "sum"/"total" returns the
            sum of the first numeric column of a spreadsheet; "page number"
            returns the integers found in the text, sorted ascending.

    Returns:
        str: Parsed content, the requested extract, or an error message.
    """
    try:
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}")
            return "File not found"

        logger.info(f"Parsing file: {file_path} for task {task_id}")

        kind = file_type.lower().lstrip(".")
        asked = query.lower() if query else ""
        wants_total = "sum" in asked or "total" in asked
        wants_pages = "page number" in asked

        if kind in ("xlsx", "xls"):
            # openpyxl cannot read legacy .xls files; let pandas pick the engine for those.
            df = pd.read_excel(file_path, engine="openpyxl" if kind == "xlsx" else None)
            if wants_total:
                return _sum_first_numeric_column(df, food_only="food" in asked)
            return df.to_string(index=False)

        if kind == "csv":
            df = pd.read_csv(file_path)
            if wants_total:
                return _sum_first_numeric_column(df)
            return df.to_string(index=False)

        if kind == "pdf":
            with open(file_path, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                text = "".join(page.extract_text() or "" for page in reader.pages)
            if wants_pages:
                return _sorted_numbers(text, "No page numbers found")
            return text.strip() or "No text extracted"

        if kind == "txt":
            with open(file_path, "r", encoding="utf-8") as f:
                text = f.read()
            if wants_pages:
                return _sorted_numbers(text, "No page numbers found")
            return text.strip()

        if kind == "mp3":
            recognizer = sr.Recognizer()
            # NOTE(review): speech_recognition documents AudioFile as accepting
            # WAV/AIFF/FLAC input, so a real MP3 likely fails here and is
            # reported by the outer handler - confirm, then convert upstream.
            with sr.AudioFile(file_path) as source:
                audio = recognizer.record(source)
            try:
                text = recognizer.recognize_google(audio)
                logger.debug(f"Transcribed audio: {text}")
                if wants_pages:
                    return _sorted_numbers(text, "No page numbers provided")
                return text
            except sr.UnknownValueError:
                logger.error("Could not understand audio")
                return "No text transcribed from audio"
            except Exception as e:
                logger.error(f"Audio parsing failed: {e}")
                return "Error transcribing audio"

        logger.warning(f"Unsupported file type: {file_type}")
        return f"Unsupported file type: {file_type}"

    except Exception as e:
        logger.error(f"Error parsing file for task {task_id}: {e}")
        return f"Error: {str(e)}"


def _sum_first_numeric_column(df, food_only: bool = False) -> str:
    """Return the sum of the first float64/int64 column, formatted to two decimals."""
    numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns
    if numerical_cols.empty:
        return "No numerical data found"
    first = numerical_cols[0]
    if food_only:
        # A row counts as a "food" row when its string rendering mentions
        # "food"; fall back to the whole column when no row does.
        food_rows = df[df.apply(lambda x: "food" in str(x).lower(), axis=1)]
        if not food_rows.empty and first in food_rows:
            return f"{food_rows[first].sum():.2f}"
    return f"{df[first].sum():.2f}"


def _sorted_numbers(text: str, empty_message: str) -> str:
    """Return the integers found in ``text`` sorted ascending, comma-separated."""
    numbers = re.findall(r'\b\d+\b', text)
    return ", ".join(sorted(numbers, key=int)) if numbers else empty_message
|
| 106 |
+
|
| 107 |
+
# LangChain tool wrapper around file_parser_func.
file_parser_tool = StructuredTool.from_function(
    name="file_parser_tool",
    args_schema=FileParserInput,
    func=file_parser_func,
    coroutine=file_parser_func,
)
|
tools/guest_info.py
CHANGED
|
@@ -1,20 +1,47 @@
|
|
| 1 |
-
from langchain_core.tools import tool
|
| 2 |
-
from retriever import load_guest_dataset
|
| 3 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
try:
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
except Exception as e:
|
| 19 |
logger.error(f"Error retrieving guest info for query '{query}': {e}")
|
| 20 |
-
return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from langchain_core.tools import StructuredTool
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
from rank_bm25 import BM25Okapi
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
+
class GuestInfoInput(BaseModel):
    # Argument schema for guest_info_retriever_tool: a single required
    # free-text query.
    query: str = Field(..., description="Query about guest information")
|
| 11 |
+
|
| 12 |
+
async def guest_info_func(query: str) -> str:
    """
    Retrieve the best-matching guest for a free-text query.

    Loads the ``agents-course/unit3-invitees`` dataset, builds a BM25 index
    over "<name> <relation>" for every row, and returns the highest-scoring
    guest.

    Args:
        query (str): Query about guest information.

    Returns:
        str: "Guest: <name>, Relation: <relation>" for the best match,
            "No matching guest found" when nothing scores above zero (or the
            dataset is empty), or "Error: <message>" if anything fails.
    """
    try:
        logger.info(f"Retrieving guest info for query: {query}")
        dataset = load_dataset("agents-course/unit3-invitees", split="train")
        logger.info(f"Loaded {len(dataset)} guests from Hugging Face dataset")

        # One lower-cased, whitespace-tokenised document per guest.
        tokenized_docs = [
            f"{row['name']} {row['relation']}".lower().split() for row in dataset
        ]
        if not tokenized_docs:
            # BM25Okapi divides by the corpus size; report "no match" instead
            # of surfacing a division error for an empty dataset.
            return "No matching guest found"

        scores = BM25Okapi(tokenized_docs).get_scores(query.lower().split())

        # argmax() returns a NumPy integer; convert it so row indexing does not
        # depend on the dataset accepting NumPy scalar keys.
        best_idx = int(scores.argmax())

        if scores[best_idx] > 0:
            guest = dataset[best_idx]  # fetch the row once
            return f"Guest: {guest['name']}, Relation: {guest['relation']}"
        return "No matching guest found"
    except Exception as e:
        logger.error(f"Error retrieving guest info for query '{query}': {e}")
        return f"Error: {str(e)}"
|
| 41 |
+
|
| 42 |
+
# Expose guest_info_func to the agent as a structured LangChain tool whose
# arguments are validated against GuestInfoInput.
# NOTE(review): guest_info_func is a coroutine function registered as both the
# sync (`func`) and async (`coroutine`) entry point, so a synchronous call
# would hand back an un-awaited coroutine — confirm callers only use the async
# path.
guest_info_retriever_tool = StructuredTool.from_function(
    name="guest_info_retriever_tool",
    func=guest_info_func,
    coroutine=guest_info_func,
    args_schema=GuestInfoInput,
)
|
tools/hub_stats.py
CHANGED
|
@@ -1,17 +1,54 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
@tool
|
| 8 |
async def hub_stats_tool(author: str) -> str:
|
| 9 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
try:
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
model = models[0]
|
| 14 |
-
return f"The most downloaded model by {author} is {model
|
| 15 |
return f"No models found for author {author}."
|
| 16 |
except Exception as e:
|
| 17 |
logger.error(f"Error fetching models for {author}: {e}")
|
|
|
|
| 1 |
+
import aiohttp
|
| 2 |
+
import ssl
|
| 3 |
import logging
|
| 4 |
+
from langchain_core.tools import tool
|
| 5 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
+
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
async def fetch_hf_models(author: str) -> Optional[list]:
    """
    Query the Hugging Face Hub API for an author's most downloaded model.

    Requests the models endpoint sorted by downloads (descending) and limited
    to one entry. The call is retried up to three times with exponential
    back-off by the ``retry`` decorator.

    Args:
        author (str): Hugging Face author username.

    Returns:
        Optional[list]: The decoded JSON response body. The caller in this
            module treats it as a list of model entries.

    Raises:
        aiohttp.ClientError: Logged and re-raised on any request failure
            (including non-2xx responses) so the retry policy can act on it.
    """
    url = "https://huggingface.co/api/models"
    # Pass the query as params so aiohttp URL-encodes the author name instead
    # of splicing raw user input into the query string.
    params = {"author": author, "sort": "downloads", "direction": "-1", "limit": "1"}
    ssl_context = ssl.create_default_context()
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params, ssl=ssl_context) as response:
                response.raise_for_status()
                return await response.json()
    except aiohttp.ClientError as e:
        logger.error(f"Failed to fetch models for {author}: {e}")
        raise
|
| 24 |
+
|
| 25 |
@tool
|
| 26 |
async def hub_stats_tool(author: str) -> str:
|
| 27 |
+
"""
|
| 28 |
+
Fetch the most downloaded model from a specific author on Hugging Face Hub.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
author (str): Hugging Face author username.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
str: Model information or error message.
|
| 35 |
+
"""
|
| 36 |
try:
|
| 37 |
+
# Check local cache
|
| 38 |
+
cache_file = f"temp/hf_cache_{author}.json"
|
| 39 |
+
if os.path.exists(cache_file):
|
| 40 |
+
with open(cache_file, "r") as f:
|
| 41 |
+
models = json.load(f)
|
| 42 |
+
logger.debug(f"Loaded cached models for {author}")
|
| 43 |
+
else:
|
| 44 |
+
models = await fetch_hf_models(author)
|
| 45 |
+
os.makedirs("temp", exist_ok=True)
|
| 46 |
+
with open(cache_file, "w") as f:
|
| 47 |
+
json.dump(models, f)
|
| 48 |
+
|
| 49 |
+
if models and isinstance(models, list) and models:
|
| 50 |
model = models[0]
|
| 51 |
+
return f"The most downloaded model by {author} is {model['id']} with {model.get('downloads', 0):,} downloads."
|
| 52 |
return f"No models found for author {author}."
|
| 53 |
except Exception as e:
|
| 54 |
logger.error(f"Error fetching models for {author}: {e}")
|
tools/image_parser.py
CHANGED
|
@@ -1,25 +1,43 @@
|
|
| 1 |
-
from langchain_core.tools import tool
|
| 2 |
-
import easyocr
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
-
reader = easyocr.Reader(['en'])
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
try:
|
| 13 |
if not os.path.exists(file_path):
|
| 14 |
-
logger.warning(f"Image not found: {file_path}")
|
| 15 |
-
return "Image not found"
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
return text
|
| 23 |
except Exception as e:
|
| 24 |
-
logger.error(f"Error parsing image {
|
| 25 |
-
return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
+
from langchain_core.tools import StructuredTool
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
+
import easyocr
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
|
|
|
| 8 |
|
| 9 |
+
class ImageParserInput(BaseModel):
    # Argument schema for image_parser_tool; both fields are required.
    task_id: str = Field(..., description="Task identifier")
    file_path: str = Field(..., description="Path to the image file")
|
| 12 |
+
|
| 13 |
+
# Lazily-created EasyOCR reader shared by every image_parser_func call.
_ocr_reader = None


def _get_ocr_reader():
    """Return the shared English EasyOCR reader, creating it on first use."""
    global _ocr_reader
    if _ocr_reader is None:
        _ocr_reader = easyocr.Reader(['en'], model_storage_directory='./cache')
    return _ocr_reader


async def image_parser_func(task_id: str, file_path: str) -> str:
    """
    Parse text from an image file using OCR.

    The EasyOCR reader is built once and reused, rather than being
    re-created (and its models re-loaded) on every call.

    Args:
        task_id (str): Task identifier (used for logging only).
        file_path (str): Path to the image file.

    Returns:
        str: The recognised text joined with single spaces,
            "Image file not found" if the path does not exist,
            "No text extracted from image" if OCR yields nothing, or
            "Error: <message>" if parsing fails.
    """
    try:
        if not os.path.exists(file_path):
            logger.warning(f"Image file not found: {file_path}")
            return "Image file not found"

        logger.info(f"Parsing image: {file_path} for task {task_id}")
        # detail=0 makes readtext return plain strings instead of
        # (bounding box, text, confidence) entries.
        result = _get_ocr_reader().readtext(file_path, detail=0)
        text = " ".join(result).strip()
        return text if text else "No text extracted from image"
    except Exception as e:
        logger.error(f"Error parsing image for task {task_id}: {e}")
        return f"Error: {str(e)}"
|
| 37 |
+
|
| 38 |
+
# Expose image_parser_func to the agent as a structured LangChain tool whose
# arguments are validated against ImageParserInput.
# NOTE(review): image_parser_func is a coroutine function registered as both
# the sync (`func`) and async (`coroutine`) entry point, so a synchronous call
# would hand back an un-awaited coroutine — confirm callers only use the async
# path.
image_parser_tool = StructuredTool.from_function(
    name="image_parser_tool",
    func=image_parser_func,
    coroutine=image_parser_func,
    args_schema=ImageParserInput,
)
|
tools/search.py
CHANGED
|
@@ -1,106 +1,103 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
-
import
|
| 3 |
-
import
|
| 4 |
-
|
| 5 |
-
from
|
| 6 |
-
from langchain.tools import Tool
|
| 7 |
-
from typing import List, Dict, Any
|
| 8 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
-
from langchain_core.messages import SystemMessage, HumanMessage
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
-
Perform a web search using
|
| 14 |
|
| 15 |
Args:
|
| 16 |
-
query:
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
-
List of search result snippets.
|
| 20 |
-
|
| 21 |
-
Raises:
|
| 22 |
-
Exception: If search fails after retries.
|
| 23 |
"""
|
| 24 |
-
|
| 25 |
-
"
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
print(f"INFO - SERPAPI retry {attempt + 1}/3 due to: {e}")
|
| 38 |
-
asyncio.sleep(2)
|
| 39 |
-
|
| 40 |
-
raise Exception("SERPAPI failed after retries")
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
-
Perform
|
| 45 |
|
| 46 |
Args:
|
| 47 |
-
query:
|
| 48 |
-
steps: Number of search
|
| 49 |
-
llm_client: LLM client for query refinement.
|
| 50 |
-
llm_type: Type of LLM
|
|
|
|
| 51 |
|
| 52 |
Returns:
|
| 53 |
-
List
|
| 54 |
"""
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
prompt = ChatPromptTemplate.from_messages([
|
| 67 |
-
SystemMessage(content="""Refine the following query to dig deeper into the topic, focusing on missing details or related aspects. Return ONLY the refined query as plain text, no explanations."""),
|
| 68 |
-
HumanMessage(content=f"Original query: {current_query}\nPrevious results: {json.dumps(search_results[:2], indent=2)}")
|
| 69 |
-
])
|
| 70 |
messages = [
|
| 71 |
-
{"role": "system", "content":
|
| 72 |
-
{"role": "user", "content": prompt
|
| 73 |
]
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
current_query = refined_query if refined_query else f"more details on {current_query}"
|
| 91 |
-
except Exception as e:
|
| 92 |
-
print(f"INFO - Query refinement failed at step {step + 1}: {e}")
|
| 93 |
-
current_query = f"more details on {current_query}"
|
| 94 |
-
|
| 95 |
-
await asyncio.sleep(1) # Rate limit
|
| 96 |
-
except Exception as e:
|
| 97 |
-
print(f"INFO - Multi-hop search step {step + 1} failed: {e}")
|
| 98 |
-
break
|
| 99 |
-
|
| 100 |
-
return results
|
| 101 |
|
| 102 |
-
multi_hop_search_tool =
|
| 103 |
-
func=
|
| 104 |
name="multi_hop_search_tool",
|
| 105 |
-
|
|
|
|
| 106 |
)
|
|
|
|
| 1 |
+
import logging
|
| 2 |
import os
|
| 3 |
+
from langchain_core.tools import StructuredTool
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
from serpapi import GoogleSearch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class SearchInput(BaseModel):
    # Argument schema for search_tool: a single required query string.
    query: str = Field(..., description="Search query")
|
| 12 |
+
|
| 13 |
+
async def search_func(query: str) -> List[str]:
    """
    Perform a web search using SerpAPI and return relevant snippets.

    Args:
        query (str): The search query to execute.

    Returns:
        List[str]: Non-empty snippets of the organic results (up to 10 are
            requested). An empty list is returned when the API key is
            missing, SerpAPI reports an error, or the request fails; the
            reason is logged in each case.
    """
    api_key = os.getenv("SERPAPI_API_KEY")
    if not api_key:
        # Without a key the request can only fail; say so explicitly instead
        # of silently returning nothing after a wasted round trip.
        logger.error("SERPAPI_API_KEY not set; skipping SerpAPI search")
        return []

    try:
        logger.info(f"Executing SerpAPI search for query: {query}")
        search = GoogleSearch({"q": query, "api_key": api_key, "num": 10})
        # NOTE(review): get_dict() is a blocking network call inside a
        # coroutine — consider running it in an executor if the event loop
        # must stay responsive.
        payload = search.get_dict()

        if payload.get("error"):
            logger.error(f"SerpAPI returned an error for query '{query}': {payload['error']}")
            return []

        organic = payload.get("organic_results", [])
        return [entry["snippet"] for entry in organic if entry.get("snippet")]
    except Exception as e:
        logger.error(f"SerpAPI search failed for query '{query}': {e}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
# Expose search_func to the agent as a structured LangChain tool whose
# arguments are validated against SearchInput.
# NOTE(review): search_func is a coroutine function registered as both the
# sync (`func`) and async (`coroutine`) entry point, so a synchronous call
# would hand back an un-awaited coroutine — confirm callers only use the async
# path.
search_tool = StructuredTool.from_function(
    name="search_tool",
    func=search_func,
    coroutine=search_func,
    args_schema=SearchInput,
)
|
| 43 |
+
|
| 44 |
+
class MultiHopSearchInput(BaseModel):
    # Argument schema for multi_hop_search_tool. `query` and `steps` are
    # required; the three LLM settings are optional and mirror the defaults of
    # multi_hop_search_func.
    query: str = Field(..., description="Multi-hop search query")
    # Bounded to 1-3 hops by the schema.
    steps: int = Field(..., description="Number of search steps", ge=1, le=3)
    llm_client: Optional[object] = Field(default=None, description="LLM client")
    llm_type: Optional[str] = Field(default="together", description="LLM type")
    llm_model: Optional[str] = Field(
        default="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        description="LLM model",
    )
|
| 50 |
+
|
| 51 |
+
async def multi_hop_search_func(query: str, steps: int, llm_client: Optional[object] = None, llm_type: Optional[str] = "together", llm_model: Optional[str] = "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free") -> List[str]:
    """
    Perform a multi-hop web search using SerpAPI with iterative query refinement.

    Each hop runs ``search_func`` on the current query. Between hops the LLM
    is asked for a follow-up query; if no usable LLM is available, or the
    refinement fails or comes back empty, the search stops and whatever has
    been collected so far is returned.

    Args:
        query (str): The initial multi-hop search query.
        steps (int): Number of search steps to perform (1 to 3).
        llm_client (Optional[object]): LLM client for query refinement; only
            its ``chat.completions.create`` method is used.
        llm_type (Optional[str]): Type of LLM; only 'together' is supported
            for refinement.
        llm_model (Optional[str]): LLM model name.

    Returns:
        List[str]: At most five distinct snippets, taken round-robin across
            the hops so later hops are represented; ["No results found"] when
            every hop is empty; ["Error: <message>"] on unexpected failure.
    """
    max_results = 5
    try:
        logger.info(f"Executing multi-hop search for query: {query}, steps: {steps}")
        per_step: List[List[str]] = []
        current_query = query

        for step in range(steps):
            logger.info(f"Multi-hop step {step + 1}: {current_query}")
            step_results = await search_func(current_query)
            per_step.append(step_results)

            if step == steps - 1:
                break  # final hop: no follow-up query is needed

            if not llm_client or llm_type != "together":
                # Without refinement the next hop would repeat the same query.
                logger.warning("LLM not configured for multi-hop refinement")
                break

            prompt = f"Given the query '{current_query}' and results: {step_results[:3]}, generate a follow-up search query to refine or expand the search."
            messages = [
                {"role": "system", "content": "Generate a single search query as a string."},
                {"role": "user", "content": prompt}
            ]
            try:
                response = llm_client.chat.completions.create(
                    model=llm_model,
                    messages=messages,
                    max_tokens=50,
                    temperature=0.7
                )
                # Content may be None; models also tend to wrap the query in
                # quotes, which would force an exact-phrase search.
                refined = (response.choices[0].message.content or "").strip().strip('"').strip()
            except Exception as e:
                # Keep the snippets already gathered rather than failing the
                # whole search because one refinement call went wrong.
                logger.warning(f"Query refinement failed at step {step + 1}: {e}")
                break

            if not refined:
                logger.warning(f"Query refinement returned an empty query at step {step + 1}")
                break
            current_query = refined

        # Interleave hops (best result of each hop first) and drop duplicates,
        # so the cap below does not keep first-hop snippets only.
        merged: List[str] = []
        seen = set()
        deepest = max((len(hop) for hop in per_step), default=0)
        for rank in range(deepest):
            for hop in per_step:
                if rank < len(hop) and hop[rank] not in seen:
                    seen.add(hop[rank])
                    merged.append(hop[rank])

        return merged[:max_results] if merged else ["No results found"]
    except Exception as e:
        logger.error(f"Multi-hop search failed for query '{query}': {e}")
        return [f"Error: {str(e)}"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
# Expose multi_hop_search_func to the agent as a structured LangChain tool
# whose arguments are validated against MultiHopSearchInput.
# NOTE(review): multi_hop_search_func is a coroutine function registered as
# both the sync (`func`) and async (`coroutine`) entry point, so a synchronous
# call would hand back an un-awaited coroutine — confirm callers only use the
# async path.
multi_hop_search_tool = StructuredTool.from_function(
    name="multi_hop_search_tool",
    func=multi_hop_search_func,
    coroutine=multi_hop_search_func,
    args_schema=MultiHopSearchInput,
)
|
tools/weather_info.py
CHANGED
|
@@ -1,23 +1,50 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
|
|
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
load_dotenv()
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
@tool
|
| 11 |
-
async def weather_info_tool(location: str) -> str:
|
| 12 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
try:
|
| 14 |
api_key = os.getenv("OPENWEATHERMAP_API_KEY")
|
| 15 |
if not api_key:
|
| 16 |
logger.error("OPENWEATHERMAP_API_KEY not set")
|
| 17 |
return "Weather unavailable: API key missing"
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
if response.get("cod") == 200:
|
| 22 |
condition = response["weather"][0]["description"]
|
| 23 |
temp = response["main"]["temp"]
|
|
|
|
| 1 |
+
import aiohttp
|
| 2 |
+
import ssl
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
+
from langchain_core.tools import tool
|
| 6 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
load_dotenv()
|
| 11 |
|
| 12 |
+
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
async def fetch_weather(location: str, api_key: str) -> dict:
    """
    Fetch current weather for a location from OpenWeatherMap (metric units).

    The call is retried up to three times with exponential back-off by the
    ``retry`` decorator.

    Args:
        location (str): City or location name.
        api_key (str): OpenWeatherMap API key.

    Returns:
        dict: The decoded JSON response body.

    Raises:
        aiohttp.ClientError: Logged and re-raised on any request failure
            (including non-2xx responses) so the retry policy can act on it.
    """
    # HTTPS keeps the API key out of cleartext traffic and makes the TLS
    # context below actually apply.
    url = "https://api.openweathermap.org/data/2.5/weather"
    # Let aiohttp URL-encode the location instead of splicing it into the URL.
    params = {"q": location, "appid": api_key, "units": "metric"}
    ssl_context = ssl.create_default_context()
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params, ssl=ssl_context) as response:
                response.raise_for_status()
                return await response.json()
    except aiohttp.ClientResponseError as e:
        # The exception's string form embeds the request URL, which carries
        # the API key; log only the status and reason.
        logger.error(f"Failed to fetch weather for {location}: HTTP {e.status} {e.message}")
        raise
    except aiohttp.ClientError as e:
        logger.error(f"Failed to fetch weather for {location}: {e}")
        raise
|
| 24 |
+
|
| 25 |
@tool
|
| 26 |
+
async def weather_info_tool(location: str, query_type: str = "current") -> str:
|
| 27 |
+
"""
|
| 28 |
+
Fetch weather information for a given location.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
location (str): City or location name.
|
| 32 |
+
query_type (str): Type of weather query ('current', 'forecast'; default: 'current').
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
str: Weather information or error message.
|
| 36 |
+
"""
|
| 37 |
try:
|
| 38 |
api_key = os.getenv("OPENWEATHERMAP_API_KEY")
|
| 39 |
if not api_key:
|
| 40 |
logger.error("OPENWEATHERMAP_API_KEY not set")
|
| 41 |
return "Weather unavailable: API key missing"
|
| 42 |
|
| 43 |
+
if query_type != "current":
|
| 44 |
+
logger.warning(f"Query type '{query_type}' not supported; using current weather")
|
| 45 |
+
query_type = "current"
|
| 46 |
+
|
| 47 |
+
response = await fetch_weather(location, api_key)
|
| 48 |
if response.get("cod") == 200:
|
| 49 |
condition = response["weather"][0]["description"]
|
| 50 |
temp = response["main"]["temp"]
|