Spaces:
Configuration error
Configuration error
Add .gitignore and clean tracked files
Browse files- .gitignore +74 -0
- app.py +96 -0
- dockerfile +24 -0
- graph.py +143 -0
- requirements.txt +97 -0
- state.py +14 -0
- tools/__init__.py +5 -0
- tools/calculator.py +20 -0
- tools/file_parser.py +38 -0
- tools/image_parser.py +66 -0
- tools/retriever.py +80 -0
- tools/search.py +68 -0
.gitignore
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python virtual environments
|
| 2 |
+
venv/
|
| 3 |
+
venv311/
|
| 4 |
+
*.venv/
|
| 5 |
+
|
| 6 |
+
# Python cache files
|
| 7 |
+
__pycache__/
|
| 8 |
+
*.py[cod]
|
| 9 |
+
*$py.class
|
| 10 |
+
|
| 11 |
+
# Environment variables
|
| 12 |
+
.env
|
| 13 |
+
*.env
|
| 14 |
+
|
| 15 |
+
# IDE and editor files
|
| 16 |
+
.vscode/
|
| 17 |
+
.idea/
|
| 18 |
+
*.sublime-project
|
| 19 |
+
*.sublime-workspace
|
| 20 |
+
|
| 21 |
+
# macOS system files
|
| 22 |
+
.DS_Store
|
| 23 |
+
.AppleDouble
|
| 24 |
+
.LSOverride
|
| 25 |
+
|
| 26 |
+
# Jupyter Notebook checkpoints
|
| 27 |
+
.ipynb_checkpoints/
|
| 28 |
+
|
| 29 |
+
# Python package installation
|
| 30 |
+
*.egg
|
| 31 |
+
*.egg-info/
|
| 32 |
+
dist/
|
| 33 |
+
build/
|
| 34 |
+
eggs/
|
| 35 |
+
*.whl
|
| 36 |
+
|
| 37 |
+
# Testing and coverage
|
| 38 |
+
.coverage
|
| 39 |
+
coverage.xml
|
| 40 |
+
*.cover
|
| 41 |
+
*.py,cover
|
| 42 |
+
.tox/
|
| 43 |
+
.pytest_cache/
|
| 44 |
+
|
| 45 |
+
# Logs and temporary files
|
| 46 |
+
*.log
|
| 47 |
+
*.log.*
|
| 48 |
+
*.tmp
|
| 49 |
+
temp/
|
| 50 |
+
|
| 51 |
+
# Dependency directories
|
| 52 |
+
pip-wheel-metadata/
|
| 53 |
+
.pip_cache/
|
| 54 |
+
.wheels/
|
| 55 |
+
|
| 56 |
+
# Byte-compiled / optimized / DLL files
|
| 57 |
+
*.so
|
| 58 |
+
*.pyd
|
| 59 |
+
*.dll
|
| 60 |
+
|
| 61 |
+
# Hugging Face Space specific
|
| 62 |
+
*.ipynb
|
| 63 |
+
*.parquet
|
| 64 |
+
*.feather
|
| 65 |
+
*.pickle
|
| 66 |
+
*.pkl
|
| 67 |
+
*.h5
|
| 68 |
+
*.joblib
|
| 69 |
+
|
| 70 |
+
# Miscellaneous
|
| 71 |
+
*.swp
|
| 72 |
+
*~
|
| 73 |
+
*.bak
|
| 74 |
+
*.old
|
app.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import aiohttp
import asyncio
from graph import graph
from state import JARVISState
from pydantic import BaseModel
from typing import List
import json
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
# Debug: Verify environment variables
print(f"OPENAI_API_KEY loaded: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
print(f"LANGFUSE_PUBLIC_KEY loaded: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")

# Fail fast when a credential needed at runtime is missing.
required_env_vars = ["OPENAI_API_KEY", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
for var in required_env_vars:
    if not os.getenv(var):
        raise ValueError(f"Environment variable {var} is not set")


# Pydantic Models for Submission
class Answer(BaseModel):
    # One answer for a single GAIA task.
    task_id: str
    submitted_answer: str


class Submission(BaseModel):
    # Payload shape expected by the GAIA submission endpoint.
    username: str
    agent_code: str
    answers: List[Answer]


async def fetch_questions() -> List[dict]:
    """Fetch the list of GAIA questions from the benchmark API."""
    async with aiohttp.ClientSession() as session:
        async with session.get("https://api.gaia-benchmark.com/questions") as resp:
            return await resp.json()


async def download_file(task_id: str, file_path: str) -> bool:
    """Download the file attached to *task_id* into *file_path*.

    Returns True on success, False on any non-200 response (e.g. the task
    has no attachment).
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(f"https://api.gaia-benchmark.com/files/{task_id}") as resp:
            if resp.status == 200:
                with open(file_path, "wb") as f:
                    f.write(await resp.read())
                return True
            return False


async def process_question(question: dict) -> Answer:
    """Run a single GAIA question through the agent graph and return its Answer."""
    # Heuristically guess the attachment's file type from keywords in the
    # question text. NOTE(review): this is keyword sniffing, not content
    # inspection — the benchmark does not announce the type up front.
    q_lower = question["question"].lower()
    file_type = "jpg" if "image" in q_lower else "txt"
    if "menu" in q_lower or "report" in q_lower or "document" in q_lower:
        file_type = "pdf"  # Prioritize PDF for reports/documents
    elif "data" in q_lower:
        file_type = "csv"

    file_path = f"temp_{question['task_id']}.{file_type}"
    downloaded = await download_file(question["task_id"], file_path)
    if not downloaded:
        # Not fatal: many tasks have no attachment; the graph handles the
        # missing file via its tool nodes.
        print(f"No file downloaded for task {question['task_id']}")

    state = JARVISState(
        task_id=question["task_id"],
        question=question["question"],
        tools_needed=[],
        web_results=[],
        file_results="",
        image_results="",
        calculation_results="",
        document_results="",
        messages=[],
        answer=""
    )
    # Use a unique thread_id so the graph's checkpointer keeps per-task memory.
    result = await graph.ainvoke(state, config={"thread_id": question["task_id"]})
    return Answer(task_id=question["task_id"], submitted_answer=result["answer"])


async def submit_answers(answers: List[Answer], username: str, agent_code: str):
    """POST the collected answers to the GAIA submission endpoint and return its JSON reply."""
    submission = Submission(
        username=username,
        agent_code=agent_code,
        answers=answers
    )
    async with aiohttp.ClientSession() as session:
        # pydantic v2 (pinned in requirements.txt) deprecated .dict() in
        # favor of model_dump().
        async with session.post("https://api.gaia-benchmark.com/submit", json=submission.model_dump()) as resp:
            return await resp.json()


async def main():
    """Entry point: fetch questions, answer them sequentially, submit the batch."""
    username = "onisj"  # Your Hugging Face username
    agent_code = "https://huggingface.co/spaces/onisj/jarvis_gaia_agent/tree/main"
    questions = await fetch_questions()
    answers = []
    for question in questions[:20]:  # Process 20 questions
        answer = await process_question(question)
        answers.append(answer)
    result = await submit_answers(answers, username, agent_code)
    print("Submission result:", json.dumps(result, indent=2))


if __name__ == "__main__":
    asyncio.run(main())
|
dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
#  - libgl1-mesa-glx / libglib2.0-0: shared libs commonly required by imaging stacks
#  - tesseract-ocr / libtesseract-dev: OCR engine used via pytesseract
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1-mesa-glx \
    libglib2.0-0 \
    tesseract-ocr \
    libtesseract-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies BEFORE copying source, so the (slow) pip layer
# is cached and only rebuilt when requirements.txt itself changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application source (edits here no longer invalidate the pip layer)
COPY app.py .
COPY graph.py .
COPY state.py .
COPY tools/ tools/

# Run the application
CMD ["python", "app.py"]
|
graph.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from state import JARVISState
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
# Langfuse 2.x exposes its LangChain handler as CallbackHandler; the name
# LangfuseCallbackHandler does not exist and raised ImportError at startup.
from langfuse.callback import CallbackHandler
import json
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
# Debug: Verify environment variables
print(f"OPENAI_API_KEY loaded in graph.py: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
print(f"LANGFUSE_PUBLIC_KEY loaded in graph.py: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")

# Initialize LLM and Langfuse
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
llm = ChatOpenAI(model="gpt-4o", api_key=api_key)
langfuse = CallbackHandler(
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    host=os.getenv("LANGFUSE_HOST")
)
memory = MemorySaver()


def _extract_json_list(raw: str) -> list:
    """Parse a JSON list out of LLM output, tolerating markdown code fences.

    Returns [] for anything that is not valid JSON or not a list, so a
    malformed model reply degrades to "no tools" instead of crashing the run.
    """
    raw = raw.strip()
    if raw.startswith("```"):
        raw = raw.strip("`").strip()
        if raw.lower().startswith("json"):
            raw = raw[4:].strip()
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return []
    return parsed if isinstance(parsed, list) else []


# Question Parser Node
async def parse_question(state: JARVISState) -> JARVISState:
    """Ask the LLM which tools the question requires and record them in state."""
    question = state["question"]
    prompt = f"""Analyze this GAIA question: {question}
Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
Return a JSON list of tool names."""
    response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    # LLMs frequently wrap JSON in ``` fences; parse defensively.
    tools_needed = _extract_json_list(response.content)
    return {"messages": state["messages"] + [response], "tools_needed": tools_needed}


# Web Search Agent Node
async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run single-shot and/or multi-hop web searches when requested."""
    results = []
    if "web_search" in state["tools_needed"]:
        result = await search_tool.arun(state["question"])
        results.append(result)
    if "multi_hop_search" in state["tools_needed"]:
        result = await multi_hop_search_tool.aparse(state["question"], steps=3)
        results.append(result)
    return {"web_results": results}


# File Parser Agent Node
async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's downloaded file when the planner asked for it."""
    if "file_parser" in state["tools_needed"]:
        result = await file_parser_tool.aparse(state["task_id"])
        return {"file_results": result}
    return {"file_results": ""}


# Image Parser Agent Node
async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Analyze the task image; uses semantic matching for fruit questions."""
    if "image_parser" in state["tools_needed"]:
        # NOTE(review): "fruits" is a hard-coded heuristic for one known GAIA
        # question family — confirm it generalizes.
        task = "match" if "fruits" in state["question"].lower() else "describe"
        match_query = "fruits" if task == "match" else ""
        result = await image_parser_tool.aparse(
            f"temp_{state['task_id']}.jpg", task=task, match_query=match_query
        )
        return {"image_results": result}
    return {"image_results": ""}


# Calculator Agent Node
async def calculator_agent(state: JARVISState) -> JARVISState:
    """Extract an arithmetic expression via the LLM, then evaluate it."""
    if "calculator" in state["tools_needed"]:
        prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
        response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
        expression = response.content
        result = await calculator_tool.aparse(expression)
        return {"calculation_results": result}
    return {"calculation_results": ""}


# Document Retriever Agent Node
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Semantic-search the task's document for text relevant to the question."""
    if "document_retriever" in state["tools_needed"]:
        # Keyword-based guess at the downloaded file's type (mirrors app.py).
        file_type = "txt" if "menu" in state["question"].lower() else "csv"
        if "report" in state["question"].lower() or "document" in state["question"].lower():
            file_type = "pdf"
        result = await document_retriever_tool.aparse(
            state["task_id"], state["question"], file_type=file_type
        )
        return {"document_results": result}
    return {"document_results": ""}


# Reasoning Agent Node
async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Synthesize all tool outputs into a single exact-match GAIA answer."""
    prompt = f"""Question: {state['question']}
Web Results: {state['web_results']}
File Results: {state['file_results']}
Image Results: {state['image_results']}
Calculation Results: {state['calculation_results']}
Document Results: {state['document_results']}

Synthesize an exact-match answer for the GAIA benchmark.
Output only the answer (e.g., '90', 'White;5876')."""
    response = await llm.ainvoke(
        [
            SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
            HumanMessage(content=prompt)
        ],
        config={"callbacks": [langfuse]}
    )
    return {"answer": response.content, "messages": state["messages"] + [response]}


# Conditional Edge Router
def router(state: JARVISState) -> str:
    """Route to the tool fan-out when any tools were selected, else straight to reasoning."""
    if state["tools_needed"]:
        return "tools"
    return "reasoning"


# Build Graph
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("web_search", web_search_agent)
workflow.add_node("file_parser", file_parser_agent)
workflow.add_node("image_parser", image_parser_agent)
workflow.add_node("calculator", calculator_agent)
workflow.add_node("document_retriever", document_retriever_agent)
workflow.add_node("reasoning", reasoning_agent)

workflow.set_entry_point("parse")
# NOTE(review): mapping "tools" to a list fans out to every tool node at
# once (each node no-ops when its tool wasn't requested) — confirm this
# langgraph version accepts list targets in a conditional-edge path map.
workflow.add_conditional_edges(
    "parse",
    router,
    {
        "tools": ["web_search", "file_parser", "image_parser", "calculator", "document_retriever"],
        "reasoning": "reasoning"
    }
)
workflow.add_edge("web_search", "reasoning")
workflow.add_edge("file_parser", "reasoning")
workflow.add_edge("image_parser", "reasoning")
workflow.add_edge("calculator", "reasoning")
workflow.add_edge("document_retriever", "reasoning")
workflow.add_edge("reasoning", END)

graph = workflow.compile(checkpointer=memory)
|
requirements.txt
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohappyeyeballs==2.6.1
|
| 2 |
+
aiohttp==3.12.2
|
| 3 |
+
aiosignal==1.3.2
|
| 4 |
+
annotated-types==0.7.0
|
| 5 |
+
anyio==4.9.0
|
| 6 |
+
attrs==25.3.0
|
| 7 |
+
backoff==2.2.1
|
| 8 |
+
certifi==2025.4.26
|
| 9 |
+
charset-normalizer==3.4.2
|
| 10 |
+
click==8.2.1
|
| 11 |
+
dataclasses-json==0.6.7
|
| 12 |
+
distro==1.9.0
|
| 13 |
+
duckduckgo_search==8.0.2
|
| 14 |
+
filelock==3.18.0
|
| 15 |
+
frozenlist==1.6.0
|
| 16 |
+
fsspec==2025.5.1
|
| 17 |
+
greenlet==3.2.2
|
| 18 |
+
h11==0.16.0
|
| 19 |
+
hf-xet==1.1.2
|
| 20 |
+
httpcore==1.0.9
|
| 21 |
+
httpx==0.28.1
|
| 22 |
+
httpx-sse==0.4.0
|
| 23 |
+
huggingface-hub==0.24.5
|
| 24 |
+
idna==3.10
|
| 25 |
+
Jinja2==3.1.6
|
| 26 |
+
jiter==0.10.0
|
| 27 |
+
joblib==1.5.1
|
| 28 |
+
jsonpatch==1.33
|
| 29 |
+
jsonpointer==3.0.0
|
| 30 |
+
langchain==0.3.25
|
| 31 |
+
langchain-community==0.3.24
|
| 32 |
+
langchain-core==0.3.62
|
| 33 |
+
langchain-openai==0.2.0
|
| 34 |
+
langchain-text-splitters==0.3.8
|
| 35 |
+
langfuse==2.44.0
|
| 36 |
+
langgraph==0.4.7
|
| 37 |
+
langgraph-checkpoint==2.0.26
|
| 38 |
+
langgraph-prebuilt==0.2.1
|
| 39 |
+
langgraph-sdk==0.1.70
|
| 40 |
+
langsmith==0.1.147
|
| 41 |
+
lxml==5.4.0
|
| 42 |
+
markdown-it-py==3.0.0
|
| 43 |
+
MarkupSafe==3.0.2
|
| 44 |
+
marshmallow==3.26.1
|
| 45 |
+
mdurl==0.1.2
|
| 46 |
+
mpmath==1.3.0
|
| 47 |
+
msgpack==1.1.0
|
| 48 |
+
multidict==6.4.4
|
| 49 |
+
mypy_extensions==1.1.0
|
| 50 |
+
networkx==3.4.2
|
| 51 |
+
numpy==1.26.4
|
| 52 |
+
openai==1.40.0
|
| 53 |
+
orjson==3.10.18
|
| 54 |
+
ormsgpack==1.10.0
|
| 55 |
+
packaging==23.2
|
| 56 |
+
pandas==2.2.3
|
| 57 |
+
pillow==11.0.0
|
| 58 |
+
primp==0.15.0
|
| 59 |
+
propcache==0.3.1
|
| 60 |
+
pydantic==2.8.2
|
| 61 |
+
pydantic-settings==2.9.1
|
| 62 |
+
pydantic_core==2.20.1
|
| 63 |
+
Pygments==2.19.1
|
| 64 |
+
PyPDF2==3.0.1
|
| 65 |
+
pytesseract==0.3.10
|
| 66 |
+
python-dateutil==2.9.0.post0
|
| 67 |
+
python-dotenv==1.0.1
|
| 68 |
+
pytz==2025.2
|
| 69 |
+
PyYAML==6.0.2
|
| 70 |
+
regex==2024.11.6
|
| 71 |
+
requests==2.32.3
|
| 72 |
+
requests-toolbelt==1.0.0
|
| 73 |
+
rich==14.0.0
|
| 74 |
+
safetensors==0.5.3
|
| 75 |
+
scikit-learn==1.6.1
|
| 76 |
+
scipy==1.15.3
|
| 77 |
+
sentence-transformers==3.0.1
|
| 78 |
+
six==1.17.0
|
| 79 |
+
smolagents==1.17.0
|
| 80 |
+
sniffio==1.3.1
|
| 81 |
+
SQLAlchemy==2.0.41
|
| 82 |
+
sympy==1.14.0
|
| 83 |
+
tenacity==8.5.0
|
| 84 |
+
threadpoolctl==3.6.0
|
| 85 |
+
tiktoken==0.9.0
|
| 86 |
+
tokenizers==0.19.1
|
| 87 |
+
torch==2.2.2
|
| 88 |
+
tqdm==4.67.1
|
| 89 |
+
transformers==4.42.4
|
| 90 |
+
typing-inspect==0.9.0
|
| 91 |
+
typing-inspection==0.4.1
|
| 92 |
+
typing_extensions==4.13.2
|
| 93 |
+
tzdata==2025.2
|
| 94 |
+
urllib3==2.4.0
|
| 95 |
+
wrapt==1.17.2
|
| 96 |
+
xxhash==3.5.0
|
| 97 |
+
yarl==1.20.0
|
state.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict, List
from langchain_core.messages import AnyMessage

class JARVISState(TypedDict):
    """Shared state flowing between the JARVIS LangGraph nodes for one GAIA task."""
    # GAIA task identifier; also used to name downloaded temp files.
    task_id: str
    # Raw question text from the benchmark.
    question: str
    # Tool names selected by the parse node (e.g. "web_search", "calculator").
    tools_needed: List[str]
    # Accumulated outputs from the web / multi-hop search node.
    web_results: List[str]
    # Parsed contents of the downloaded CSV/TXT file ("" when unused).
    file_results: str
    # Output of the image-analysis node ("" when unused).
    image_results: str
    # Output of the calculator node ("" when unused).
    calculation_results: str
    # Text found by the semantic document retriever ("" when unused).
    document_results: str
    # LLM message history accumulated across nodes.
    messages: List[AnyMessage]
    # Final exact-match answer produced by the reasoning node.
    answer: str
|
tools/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool exports for the JARVIS GAIA agent.

Each submodule defines a module-level singleton tool instance; re-export
them here so callers can simply write ``from tools import search_tool, ...``.
"""

from .search import search_tool, multi_hop_search_tool
from .file_parser import file_parser_tool
from .image_parser import image_parser_tool
from .calculator import calculator_tool
from .retriever import document_retriever_tool
|
tools/calculator.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
import operator
from typing import Dict


class CalculatorTool:
    """Safely evaluates arithmetic expressions extracted from GAIA tasks.

    The previous implementation used ``eval`` on LLM-derived text, which is
    arbitrary code execution; this version walks the AST and permits only
    numeric literals, arithmetic operators, and a small function whitelist.
    """

    # AST operator node -> implementing function.
    _BIN_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    _UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Only these callables may appear in an expression (same set the old
    # eval() exposed).
    _FUNCS = {"abs": abs, "round": round}

    def __init__(self):
        self.name = "calculator"
        self.description = "Evaluates mathematical expressions."
        self.inputs = {
            "expression": {"type": "string", "description": "Mathematical expression to evaluate"}
        }
        self.output_type = str

    def _safe_eval(self, node):
        """Recursively evaluate a whitelisted AST node; raise ValueError otherwise."""
        if isinstance(node, ast.Expression):
            return self._safe_eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in self._BIN_OPS:
            return self._BIN_OPS[type(node.op)](
                self._safe_eval(node.left), self._safe_eval(node.right)
            )
        if isinstance(node, ast.UnaryOp) and type(node.op) in self._UNARY_OPS:
            return self._UNARY_OPS[type(node.op)](self._safe_eval(node.operand))
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in self._FUNCS
            and not node.keywords
        ):
            return self._FUNCS[node.func.id](*[self._safe_eval(a) for a in node.args])
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    async def aparse(self, expression: str) -> str:
        """Evaluate *expression* and return the result (or an error message) as a string."""
        try:
            tree = ast.parse(expression, mode="eval")
            result = self._safe_eval(tree)
            return str(result)
        except Exception as e:
            return f"Error calculating expression: {str(e)}"


calculator_tool = CalculatorTool()
|
tools/file_parser.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
import requests
import asyncio
import os


class FileParserTool:
    """Downloads a GAIA task file and returns its contents as text.

    Fixes two defects in the original: ``await requests.get(...)`` awaited a
    non-awaitable Response (TypeError on every call), and the ``finally``
    clause referenced ``file_path`` before assignment when that exception
    fired (NameError).
    """

    def __init__(self):
        self.name = "file_parser"
        self.description = "Downloads and parses CSV or text files for GAIA tasks."
        self.inputs = {
            "task_id": {"type": "string", "description": "GAIA task ID"},
            "file_type": {"type": "string", "description": "File type (csv, txt, default: csv)"}
        }
        self.output_type = str

    async def aparse(self, task_id: str, file_type: str = "csv") -> str:
        """Download the file for *task_id*, return its text (or an error string).

        The temp file is always removed before returning.
        """
        # Assign before the try so the finally-clause cleanup can never hit
        # an unbound name.
        file_path = f"temp_{task_id}.{file_type}"
        try:
            url = f"https://api.gaia-benchmark.com/files/{task_id}"
            # requests is synchronous; run it in a worker thread so we do not
            # block the event loop.
            response = await asyncio.to_thread(requests.get, url)
            if response.status_code == 200:
                with open(file_path, "wb") as f:
                    f.write(response.content)
                if file_type == "csv":
                    df = pd.read_csv(file_path)
                    return df.to_string()
                elif file_type == "txt":
                    with open(file_path, "r") as f:
                        return f.read()
                else:
                    return f"Unsupported file type: {file_type}"
            return f"Error downloading file for task ID {task_id}"
        except Exception as e:
            return f"Error: {str(e)}"
        finally:
            # Clean up the temp file regardless of outcome.
            if os.path.exists(file_path):
                os.remove(file_path)


file_parser_tool = FileParserTool()
|
tools/image_parser.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sentence_transformers import SentenceTransformer, util
import pytesseract
from PIL import Image
import base64
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
# Debug: Verify OPENAI_API_KEY
if not os.getenv("OPENAI_API_KEY"):
    print("Error: OPENAI_API_KEY not loaded in image_parser.py")


class ImageParserTool:
    """Analyzes an image file via OCR, a vision LLM description, or semantic matching."""

    def __init__(self):
        self.name = "image_parser"
        self.description = "Analyzes images to extract text, identify objects, or match descriptions."
        self.inputs = {
            "image_path": {"type": "string", "description": "Path to image file"},
            "task": {"type": "string", "description": "Task type (ocr, describe, match)"},
            "match_query": {"type": "string", "description": "Query for semantic matching (optional)"}
        }
        self.output_type = str
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")
        self.vlm = ChatOpenAI(model="gpt-4o", api_key=api_key)
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")

    def _vision_message(self, image_data: str, instruction: str) -> HumanMessage:
        """Build a multimodal chat message carrying the base64 image plus an instruction.

        ChatOpenAI expects content parts INSIDE a HumanMessage, with the
        image under an {"url": ...} dict — the original passed bare
        content-part dicts as messages, which is not a valid input.
        """
        return HumanMessage(content=[
            {"type": "text", "text": instruction},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
        ])

    async def aparse(self, image_path: str, task: str = "describe", match_query: str = "") -> str:
        """Run the requested analysis *task* on the image at *image_path*."""
        try:
            # Read image once as base64 for the VLM and as a PIL image for OCR.
            with open(image_path, "rb") as f:
                image_data = base64.b64encode(f.read()).decode()
            img = Image.open(image_path)

            if task == "ocr":
                # Extract text with Tesseract
                text = pytesseract.image_to_string(img)
                return text if text.strip() else "No text found in image."
            elif task == "describe":
                # Describe image with VLM
                response = await self.vlm.ainvoke(
                    [self._vision_message(image_data, "Describe objects in the image in detail.")]
                )
                return response.content
            elif task == "match" and match_query:
                # Semantic matching: have the VLM list objects, then pick the
                # one closest to match_query in embedding space.
                description = await self.vlm.ainvoke(
                    [self._vision_message(image_data, "List objects in the image.")]
                )
                # NOTE(review): assumes the VLM answers as a comma-separated
                # list — confirm; a prose answer degrades the split.
                objects = description.content.split(", ")
                query_embedding = self.embedder.encode(match_query, convert_to_tensor=True)
                object_embeddings = self.embedder.encode(objects, convert_to_tensor=True)
                similarities = util.cos_sim(query_embedding, object_embeddings)[0]
                best_match = objects[similarities.argmax()]
                return f"Best match for '{match_query}': {best_match}"
            else:
                return "Invalid task or missing match_query for matching."
        except Exception as e:
            return f"Error analyzing image: {str(e)}"


image_parser_tool = ImageParserTool()
|
tools/retriever.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer, util
import numpy as np
import pandas as pd
import PyPDF2
import os
from typing import List, Dict


class DocumentRetrieverTool:
    """Semantic-search retriever over a GAIA task's previously downloaded file."""

    def __init__(self):
        self.name = "document_retriever"
        self.description = "Retrieves relevant text from GAIA text-heavy files (CSV, TXT, PDF) using semantic search."
        self.inputs = {
            "task_id": {"type": "string", "description": "GAIA task ID for the file"},
            "query": {"type": "string", "description": "Question or query to search for"},
            "file_type": {"type": "string", "description": "File type (csv, txt, pdf, default: txt)"}
        }
        self.output_type = str
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            length_function=len
        )
        # Cache of the most recent call's chunks and their embeddings.
        # NOTE(review): shared mutable state on a module-level singleton —
        # concurrent aparse() calls would clobber each other; confirm the
        # graph invokes this tool serially.
        self.chunks: List[str] = []
        self.embeddings: np.ndarray = None

    def _load_text(self, file_path: str, file_type: str):
        """Read *file_path* as plain text; returns None for an unsupported type."""
        if file_type == "csv":
            return pd.read_csv(file_path).to_string()
        if file_type == "txt":
            with open(file_path, "r", encoding="utf-8") as f:
                return f.read()
        if file_type == "pdf":
            with open(file_path, "rb") as f:
                pdf = PyPDF2.PdfReader(f)
                # extract_text() may return None for image-only pages.
                return "".join(page.extract_text() or "" for page in pdf.pages)
        return None

    async def aparse(self, task_id: str, query: str, file_type: str = "txt") -> str:
        """
        Loads a GAIA file, splits it into chunks, embeds them, and retrieves
        the text most relevant to *query*. Supports CSV, TXT, and PDF files.
        """
        try:
            file_path = f"temp_{task_id}.{file_type}"
            if not os.path.exists(file_path):
                return f"File not found for task ID {task_id}"

            # Load and preprocess file
            text = self._load_text(file_path, file_type)
            if text is None:
                return f"Unsupported file type: {file_type}"
            if not text.strip():
                return "No extractable text found in file."

            # Split text into overlapping chunks
            self.chunks = self.text_splitter.split_text(text)
            if not self.chunks:
                return "No content found in file."

            # Embed chunks and query, then rank by cosine similarity
            self.embeddings = self.embedder.encode(self.chunks, convert_to_tensor=True)
            query_embedding = self.embedder.encode(query, convert_to_tensor=True)
            similarities = util.cos_sim(query_embedding, self.embeddings)[0]

            # Return the top 3 most relevant chunks joined together
            top_k = min(3, len(self.chunks))
            top_indices = similarities.argsort(descending=True)[:top_k]
            relevant_chunks = [self.chunks[idx] for idx in top_indices]
            return "\n\n".join(relevant_chunks)
        except Exception as e:
            return f"Error retrieving documents: {str(e)}"


document_retriever_tool = DocumentRetrieverTool()
|
tools/search.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from duckduckgo_search import DDGS
import asyncio
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")


@tool
async def web_search(query: str) -> str:
    """
    Performs a web search using DuckDuckGo and returns a string of results.

    Args:
        query (str): The search query string.

    Returns:
        str: A string containing the search results.
    """
    try:
        # duckduckgo_search 8.x (pinned in requirements.txt) has no async
        # `atext` method — DDGS.text() is synchronous. The old
        # `await ddgs.atext(...)` raised AttributeError on every search.
        # Run the sync client in a worker thread to keep the loop free.
        def _run() -> list:
            with DDGS() as ddgs:
                return ddgs.text(keywords=query, max_results=5)

        results = await asyncio.to_thread(_run)
        return "\n".join(f"{r['title']}: {r['body']}" for r in results)
    except Exception as e:
        return f"Error performing web search: {str(e)}"


search_tool = web_search


class MultiHopSearchTool:
    """Iteratively searches the web, letting an LLM refine the query each hop."""

    def __init__(self):
        self.name = "multi_hop_search"
        self.description = "Performs iterative web searches to refine results for complex queries."
        self.inputs = {
            "query": {"type": "string", "description": "Initial search query"},
            "steps": {"type": "integer", "description": "Number of search iterations (default: 3)"}
        }
        self.output_type = str
        self.llm = ChatOpenAI(
            model="gpt-4o",
            api_key=api_key,
            temperature=0,
            http_client=None  # Explicitly disable custom HTTP client to avoid proxies
        )

    async def aparse(self, query: str, steps: int = 3) -> str:
        """Run *steps* search rounds, refining the query between rounds; returns all results joined."""
        try:
            current_query = query
            results = []
            for _ in range(steps):
                # web_search is an async tool: use ainvoke (awaiting .invoke
                # on an async-def tool does not work).
                search_result = await web_search.ainvoke({"query": current_query})
                results.append(search_result)

                # Refine query using LLM
                prompt = f"""Based on the query: {current_query}
And the search results: {search_result}
Generate a refined search query to get more precise results."""
                response = await self.llm.ainvoke(prompt)
                current_query = response.content

            return "\n\n".join(results)
        except Exception as e:
            return f"Error in multi-hop search: {str(e)}"


multi_hop_search_tool = MultiHopSearchTool()
|