Katya Beresneva commited on
Commit ·
523e34e
1
Parent(s): b75609c
fix
Browse files- .gitattributes +0 -17
- agent.py +122 -111
- app.py +207 -44
- requirements.txt +12 -5
- tools.py +373 -170
- utils.py +6 -21
.gitattributes
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
asyncer
|
| 2 |
-
anyio
|
| 3 |
-
arxiv
|
| 4 |
-
gradio
|
| 5 |
-
httpx
|
| 6 |
-
requests
|
| 7 |
-
langgraph==0.0.12
|
| 8 |
-
langchain-google-genai
|
| 9 |
-
langchain-community
|
| 10 |
-
langchain-tavily
|
| 11 |
-
openpyxl
|
| 12 |
-
smolagents
|
| 13 |
-
tavily-python
|
| 14 |
-
wikipedia-api
|
| 15 |
-
wikipedia
|
| 16 |
-
duckduckgo-search
|
| 17 |
-
python-dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent.py
CHANGED
|
@@ -1,134 +1,145 @@
|
|
| 1 |
import os
|
| 2 |
-
from typing import Optional, List, Dict, Any
|
| 3 |
from langchain_core.messages import HumanMessage
|
| 4 |
-
from
|
| 5 |
-
from
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
from utils import get_llm
|
| 24 |
|
|
|
|
| 25 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
|
| 26 |
if not GOOGLE_API_KEY:
|
| 27 |
raise ValueError("GOOGLE_API_KEY environment variable is not set.")
|
| 28 |
|
| 29 |
AGENT_MODEL_NAME = os.getenv("AGENT_MODEL_NAME", "gemini-2.0-flash")
|
| 30 |
|
| 31 |
-
|
| 32 |
-
You are
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
-
|
| 43 |
-
-
|
| 44 |
-
-
|
| 45 |
-
- Code
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
"""
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
llm_provider_api_key=GOOGLE_API_KEY,
|
| 57 |
-
model_name=model_name
|
| 58 |
)
|
| 59 |
-
self.tools = self._get_tools()
|
| 60 |
-
self.agent_executor = self._create_agent_executor()
|
| 61 |
|
| 62 |
-
def _get_tools(self)
|
| 63 |
-
"""Convert all tools to LangChain Tool format"""
|
| 64 |
tools = [
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
]
|
| 78 |
-
return tools
|
| 79 |
-
|
| 80 |
-
def _wrap_tool(self, tool: Any) -> Tool:
|
| 81 |
-
"""Convert any tool to LangChain Tool format"""
|
| 82 |
-
return Tool(
|
| 83 |
-
name=tool.name,
|
| 84 |
-
description=tool.description,
|
| 85 |
-
func=tool._run,
|
| 86 |
-
coroutine=tool._arun,
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
def _create_agent_executor(self) -> AgentExecutor:
|
| 90 |
-
"""Create the agent executor with React strategy"""
|
| 91 |
-
prompt = ChatPromptTemplate.from_template(KATE_AGENT_PROMPT)
|
| 92 |
-
agent = create_react_agent(self.llm, self.tools, prompt)
|
| 93 |
-
return AgentExecutor(
|
| 94 |
-
agent=agent,
|
| 95 |
-
tools=self.tools,
|
| 96 |
-
handle_parsing_errors=True,
|
| 97 |
-
max_iterations=10,
|
| 98 |
-
verbose=True
|
| 99 |
-
)
|
| 100 |
|
| 101 |
async def __call__(
|
| 102 |
-
self,
|
| 103 |
-
task_id: str,
|
| 104 |
-
question: str,
|
| 105 |
-
file_name: Optional[str] = None
|
| 106 |
) -> str:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
recursion_limit=64,
|
| 110 |
-
configurable={"thread_id":
|
| 111 |
)
|
| 112 |
|
| 113 |
-
if not
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
from langchain_core.messages import HumanMessage
|
| 3 |
+
from langchain_core.runnables.config import RunnableConfig
|
| 4 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 5 |
+
from langchain.globals import set_debug
|
| 6 |
+
from langchain.globals import set_verbose
|
| 7 |
+
from langgraph.prebuilt import create_react_agent
|
| 8 |
+
from langgraph.prebuilt import ToolNode
|
| 9 |
+
from langgraph.prebuilt.chat_agent_executor import AgentState
|
| 10 |
+
|
| 11 |
+
from smolagents import DuckDuckGoSearchTool
|
| 12 |
+
from smolagents import PythonInterpreterTool
|
| 13 |
+
from tools import analyze_audio
|
| 14 |
+
from tools import analyze_excel
|
| 15 |
+
from tools import analyze_image
|
| 16 |
+
from tools import analyze_video
|
| 17 |
+
from tools import download_file_for_task
|
| 18 |
+
from tools import read_file_contents
|
| 19 |
+
from tools import search_arxiv
|
| 20 |
+
from tools import search_tavily
|
| 21 |
+
from tools import search_wikipedia
|
| 22 |
+
from tools import SmolagentToolWrapper
|
| 23 |
+
from tools import tavily_extract_tool
|
| 24 |
from utils import get_llm
|
| 25 |
|
| 26 |
+
|
| 27 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
|
| 28 |
if not GOOGLE_API_KEY:
|
| 29 |
raise ValueError("GOOGLE_API_KEY environment variable is not set.")
|
| 30 |
|
| 31 |
AGENT_MODEL_NAME = os.getenv("AGENT_MODEL_NAME", "gemini-2.0-flash")
|
| 32 |
|
| 33 |
+
MULTIMODAL_TASK_SOLVER_PROMPT = """
|
| 34 |
+
You are a specialized multimodal task-solving AI assistant capable of handling complex data analysis and information retrieval tasks.
|
| 35 |
+
Core Operating Guidelines:
|
| 36 |
+
- Employ systematic analysis: Break down problems into logical steps
|
| 37 |
+
- Maintain brevity: Provide answers in the most concise format possible - raw numbers, single words, or comma-delimited lists
|
| 38 |
+
- Format compliance:
|
| 39 |
+
* Numbers: No commas, units, or currency symbols
|
| 40 |
+
* Lists: Pure comma-separated values without additional text
|
| 41 |
+
* Text: Bare minimum words, no sentences or explanations
|
| 42 |
+
- Tool utilization:
|
| 43 |
+
* For multimedia content (images, audio, video) - use dedicated analysis tools
|
| 44 |
+
* For data processing (Excel, structured data) - use appropriate parsers
|
| 45 |
+
* For information retrieval - leverage search tools
|
| 46 |
+
- Verification principle: Never guess - use available tools to verify information
|
| 47 |
+
- Code usage: Implement Python code for calculations and data transformations
|
| 48 |
+
- Answer format: Always prefix final answers with 'FINAL ANSWER: '
|
| 49 |
+
- Counting queries: Return only the numerical count
|
| 50 |
+
- Listing queries: Return only the comma-separated items
|
| 51 |
+
- Sorting queries: Return only the ordered list
|
| 52 |
+
|
| 53 |
+
Sample Responses:
|
| 54 |
+
Q: Current Bitcoin price in USD? A: 47392
|
| 55 |
+
Q: Sort these colors: blue, red, azure A: azure, blue, red
|
| 56 |
+
Q: Capital of France? A: Paris
|
| 57 |
+
Q: Count vowels in 'hello' A: 2
|
| 58 |
+
Q: Temperature scale used in USA? A: Fahrenheit
|
| 59 |
+
Q: List prime numbers under 10 A: 2, 3, 5, 7
|
| 60 |
+
Q: Most streamed artist 2023? A: Taylor Swift
|
| 61 |
"""
|
| 62 |
|
| 63 |
+
#set_debug(True)
|
| 64 |
+
#set_verbose(True)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MultiModalTaskState(AgentState):
|
| 68 |
+
task_identifier: str
|
| 69 |
+
query_text: str
|
| 70 |
+
input_file_path: str
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class MultiModalAgent:
|
| 74 |
+
def __init__(self, model_name: str | None = None):
|
| 75 |
+
if model_name is None:
|
| 76 |
+
model_name = AGENT_MODEL_NAME
|
| 77 |
+
llm = self._get_llm(model_name)
|
| 78 |
+
tools = self._get_tools()
|
| 79 |
+
self.agent = create_react_agent(
|
| 80 |
+
llm,
|
| 81 |
+
tools=tools,
|
| 82 |
+
state_schema=MultiModalTaskState,
|
| 83 |
+
state_modifier=MULTIMODAL_TASK_SOLVER_PROMPT,
|
| 84 |
+
checkpointer = MemorySaver()
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def _get_llm(self, model_name: str):
|
| 88 |
+
return get_llm(
|
| 89 |
llm_provider_api_key=GOOGLE_API_KEY,
|
| 90 |
+
model_name=model_name,
|
| 91 |
)
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
def _get_tools(self):
|
|
|
|
| 94 |
tools = [
|
| 95 |
+
SmolagentToolWrapper(DuckDuckGoSearchTool()),
|
| 96 |
+
SmolagentToolWrapper(PythonInterpreterTool()),
|
| 97 |
+
download_file_for_task,
|
| 98 |
+
read_file_contents,
|
| 99 |
+
analyze_audio,
|
| 100 |
+
analyze_image,
|
| 101 |
+
analyze_excel,
|
| 102 |
+
analyze_video,
|
| 103 |
+
search_arxiv,
|
| 104 |
+
search_tavily,
|
| 105 |
+
search_wikipedia,
|
| 106 |
+
tavily_extract_tool,
|
| 107 |
]
|
| 108 |
+
return ToolNode(tools)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
async def __call__(
|
| 111 |
+
self, task_identifier: str, query_text: str, input_file_path: str | None = None
|
|
|
|
|
|
|
|
|
|
| 112 |
) -> str:
|
| 113 |
+
|
| 114 |
+
execution_config = RunnableConfig(
|
| 115 |
recursion_limit=64,
|
| 116 |
+
configurable={ "thread_id": task_identifier }
|
| 117 |
)
|
| 118 |
|
| 119 |
+
if not input_file_path:
|
| 120 |
+
input_file_path = "None - no file present"
|
| 121 |
+
|
| 122 |
+
user_input = HumanMessage(
|
| 123 |
+
content=
|
| 124 |
+
[
|
| 125 |
+
{
|
| 126 |
+
"type": "text",
|
| 127 |
+
"text": f"Task Id: {task_identifier}, Question: {query_text}, Filename: {input_file_path}. If a filename is present (and is not 'None'), download the file for the task that's referenced in the question. If there isn't a filename present, please use tools where applicable."
|
| 128 |
+
}
|
| 129 |
+
]
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
response = await self.agent.ainvoke(
|
| 133 |
+
{
|
| 134 |
+
"messages": [user_input],
|
| 135 |
+
"question": query_text,
|
| 136 |
+
"task_id": task_identifier,
|
| 137 |
+
"file_name": input_file_path
|
| 138 |
+
}, execution_config)
|
| 139 |
+
|
| 140 |
+
final_response = response['messages'][-1].content
|
| 141 |
+
if "FINAL ANSWER: " in final_response:
|
| 142 |
+
return final_response.split("FINAL ANSWER: ", 1)[1].strip()
|
| 143 |
+
else:
|
| 144 |
+
return final_response
|
| 145 |
+
|
app.py
CHANGED
|
@@ -3,56 +3,219 @@ import os
|
|
| 3 |
import gradio as gr
|
| 4 |
import requests
|
| 5 |
import pandas as pd
|
| 6 |
-
from agent import
|
| 7 |
-
|
| 8 |
-
agent = KateMultiModalAgent()
|
| 9 |
|
|
|
|
|
|
|
| 10 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 11 |
-
AGENT_NAME = "Kate's Advanced Agent"
|
| 12 |
|
| 13 |
async def run_agent(
|
| 14 |
-
agt:
|
| 15 |
item: dict
|
| 16 |
) -> str | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
try:
|
| 18 |
-
|
| 19 |
-
question_text = item.get("question")
|
| 20 |
-
file_name = item.get("file_name", None)
|
| 21 |
-
|
| 22 |
-
if not task_id or question_text is None:
|
| 23 |
-
print(f"Skipping invalid item: {item}")
|
| 24 |
-
return None
|
| 25 |
-
|
| 26 |
-
print(f"Processing task {task_id}...")
|
| 27 |
-
submitted_answer = await agt(task_id, question_text, file_name)
|
| 28 |
-
return {
|
| 29 |
-
"task_id": task_id,
|
| 30 |
-
"question": question_text,
|
| 31 |
-
"submitted_answer": submitted_answer
|
| 32 |
-
}
|
| 33 |
except Exception as e:
|
| 34 |
-
print(f"Error
|
| 35 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
return "Please login with your Hugging Face account (Kate Berasneva)", None
|
| 40 |
-
|
| 41 |
-
username = profile.username
|
| 42 |
-
print(f"Kate's Agent running for user: {username}")
|
| 43 |
-
|
| 44 |
-
with gr.Blocks(title="Kate's Agent Evaluation Runner") as demo:
|
| 45 |
-
gr.Markdown("# Kate's Advanced Agent Evaluation")
|
| 46 |
-
gr.Markdown("""
|
| 47 |
-
**Welcome to Kate Berasneva's Agent Solution!**
|
| 48 |
-
|
| 49 |
-
This enhanced agent features:
|
| 50 |
-
- Improved error handling
|
| 51 |
-
- Better tool integration
|
| 52 |
-
- Custom prompt engineering
|
| 53 |
-
- Efficient task processing
|
| 54 |
-
|
| 55 |
-
1. Login with your HF account
|
| 56 |
-
2. Click Run Evaluation
|
| 57 |
-
3. View your results!
|
| 58 |
-
""")
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import requests
|
| 5 |
import pandas as pd
|
| 6 |
+
from agent import MultiModalAgent
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
# (Keep Constants as is)
|
| 9 |
+
# --- Constants ---
|
| 10 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
|
|
| 11 |
|
| 12 |
async def run_agent(
|
| 13 |
+
agt: MultiModalAgent,
|
| 14 |
item: dict
|
| 15 |
) -> str | None:
|
| 16 |
+
task_id = item.get("task_id")
|
| 17 |
+
question_text = item.get("question")
|
| 18 |
+
file_name = item.get("file_name", None)
|
| 19 |
+
|
| 20 |
+
if not task_id or question_text is None:
|
| 21 |
+
print(f"Skipping item with missing task_id or question: {item}")
|
| 22 |
+
return None
|
| 23 |
+
|
| 24 |
+
submitted_answer = await agt(task_id, question_text, file_name)
|
| 25 |
+
|
| 26 |
+
return {
|
| 27 |
+
"task_id": task_id,
|
| 28 |
+
"question": question_text,
|
| 29 |
+
"submitted_answer": submitted_answer
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 34 |
+
"""
|
| 35 |
+
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 36 |
+
and displays the results.
|
| 37 |
+
"""
|
| 38 |
+
# --- Determine HF Space Runtime URL and Repo URL ---
|
| 39 |
+
space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
|
| 40 |
+
|
| 41 |
+
if profile:
|
| 42 |
+
username= f"{profile.username}"
|
| 43 |
+
print(f"User logged in: {username}")
|
| 44 |
+
else:
|
| 45 |
+
print("User not logged in.")
|
| 46 |
+
return "Please Login to Hugging Face with the button.", None
|
| 47 |
+
|
| 48 |
+
api_url = DEFAULT_API_URL
|
| 49 |
+
questions_url = f"{api_url}/questions"
|
| 50 |
+
submit_url = f"{api_url}/submit"
|
| 51 |
+
|
| 52 |
+
# 1. Instantiate Agent ( modify this part to create your agent)
|
| 53 |
try:
|
| 54 |
+
agent = MultiModalAgent()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
except Exception as e:
|
| 56 |
+
print(f"Error instantiating agent: {e}")
|
| 57 |
+
return f"Error initializing agent: {e}", None
|
| 58 |
+
# In the case of an app running as a hugging Face space, this link points toward your codebase ( usefull for others so please keep it public)
|
| 59 |
+
agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
|
| 60 |
+
print(agent_code)
|
| 61 |
+
|
| 62 |
+
# 2. Fetch Questions
|
| 63 |
+
print(f"Fetching questions from: {questions_url}")
|
| 64 |
+
try:
|
| 65 |
+
response = requests.get(questions_url, timeout=15)
|
| 66 |
+
response.raise_for_status()
|
| 67 |
+
questions_data = response.json()
|
| 68 |
+
if not questions_data:
|
| 69 |
+
print("Fetched questions list is empty.")
|
| 70 |
+
return "Fetched questions list is empty or invalid format.", None
|
| 71 |
+
print(f"Fetched {len(questions_data)} questions.")
|
| 72 |
+
except requests.exceptions.RequestException as e:
|
| 73 |
+
print(f"Error fetching questions: {e}")
|
| 74 |
+
return f"Error fetching questions: {e}", None
|
| 75 |
+
except requests.exceptions.JSONDecodeError as e:
|
| 76 |
+
print(f"Error decoding JSON response from questions endpoint: {e}")
|
| 77 |
+
print(f"Response text: {response.text[:500]}")
|
| 78 |
+
return f"Error decoding server response for questions: {e}", None
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"An unexpected error occurred fetching questions: {e}")
|
| 81 |
+
return f"An unexpected error occurred fetching questions: {e}", None
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
#see if there is a loop already running. If there is, reuse it.
|
| 85 |
+
loop = asyncio.get_running_loop()
|
| 86 |
+
except RuntimeError:
|
| 87 |
+
# Create new event loop if one is not running
|
| 88 |
+
loop = asyncio.new_event_loop()
|
| 89 |
+
asyncio.set_event_loop(loop)
|
| 90 |
+
|
| 91 |
+
# 3. Run your Agent
|
| 92 |
+
results_log = []
|
| 93 |
+
answers_payload = []
|
| 94 |
+
print(f"Running agent on {len(questions_data)} questions...")
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
results = loop.run_until_complete(
|
| 98 |
+
asyncio.gather(*(run_agent(agent, item) for item in questions_data))
|
| 99 |
+
)
|
| 100 |
+
answers_payload = [{key: value for key, value in item.items() if key != "question"}
|
| 101 |
+
for item in results]
|
| 102 |
+
|
| 103 |
+
for item in results:
|
| 104 |
+
results_log.append(
|
| 105 |
+
{
|
| 106 |
+
"Task ID": item['task_id'],
|
| 107 |
+
"Question": item['question'],
|
| 108 |
+
"Submitted Answer": item['submitted_answer']
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
+
finally:
|
| 112 |
+
# Clean up
|
| 113 |
+
loop.close()
|
| 114 |
+
|
| 115 |
+
if not answers_payload:
|
| 116 |
+
print("Agent did not produce any answers to submit.")
|
| 117 |
+
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
| 118 |
+
|
| 119 |
+
# 4. Prepare Submission
|
| 120 |
+
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
| 121 |
+
status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
|
| 122 |
+
print(status_update)
|
| 123 |
+
|
| 124 |
+
# 5. Submit
|
| 125 |
+
print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
|
| 126 |
+
try:
|
| 127 |
+
response = requests.post(submit_url, json=submission_data, timeout=60)
|
| 128 |
+
response.raise_for_status()
|
| 129 |
+
result_data = response.json()
|
| 130 |
+
final_status = (
|
| 131 |
+
f"Submission Successful!\n"
|
| 132 |
+
f"User: {result_data.get('username')}\n"
|
| 133 |
+
f"Overall Score: {result_data.get('score', 'N/A')}% "
|
| 134 |
+
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
| 135 |
+
f"Message: {result_data.get('message', 'No message received.')}"
|
| 136 |
+
)
|
| 137 |
+
print("Submission successful.")
|
| 138 |
+
results_df = pd.DataFrame(results_log)
|
| 139 |
+
return final_status, results_df
|
| 140 |
+
except requests.exceptions.HTTPError as e:
|
| 141 |
+
error_detail = f"Server responded with status {e.response.status_code}."
|
| 142 |
+
try:
|
| 143 |
+
error_json = e.response.json()
|
| 144 |
+
error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
|
| 145 |
+
except requests.exceptions.JSONDecodeError:
|
| 146 |
+
error_detail += f" Response: {e.response.text[:500]}"
|
| 147 |
+
status_message = f"Submission Failed: {error_detail}"
|
| 148 |
+
print(status_message)
|
| 149 |
+
results_df = pd.DataFrame(results_log)
|
| 150 |
+
return status_message, results_df
|
| 151 |
+
except requests.exceptions.Timeout:
|
| 152 |
+
status_message = "Submission Failed: The request timed out."
|
| 153 |
+
print(status_message)
|
| 154 |
+
results_df = pd.DataFrame(results_log)
|
| 155 |
+
return status_message, results_df
|
| 156 |
+
except requests.exceptions.RequestException as e:
|
| 157 |
+
status_message = f"Submission Failed: Network error - {e}"
|
| 158 |
+
print(status_message)
|
| 159 |
+
results_df = pd.DataFrame(results_log)
|
| 160 |
+
return status_message, results_df
|
| 161 |
+
except Exception as e:
|
| 162 |
+
status_message = f"An unexpected error occurred during submission: {e}"
|
| 163 |
+
print(status_message)
|
| 164 |
+
results_df = pd.DataFrame(results_log)
|
| 165 |
+
return status_message, results_df
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# --- Build Gradio Interface using Blocks ---
|
| 169 |
+
with gr.Blocks() as demo:
|
| 170 |
+
gr.Markdown("# Basic Agent Evaluation Runner")
|
| 171 |
+
gr.Markdown(
|
| 172 |
+
"""
|
| 173 |
+
**Instructions:**
|
| 174 |
+
|
| 175 |
+
1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
|
| 176 |
+
2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
|
| 177 |
+
3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
**Disclaimers:**
|
| 181 |
+
Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
|
| 182 |
+
This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async.
|
| 183 |
+
"""
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
gr.LoginButton()
|
| 187 |
+
|
| 188 |
+
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
| 189 |
+
|
| 190 |
+
status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
|
| 191 |
+
# Removed max_rows=10 from DataFrame constructor
|
| 192 |
+
results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
|
| 193 |
+
|
| 194 |
+
run_button.click(
|
| 195 |
+
fn=run_and_submit_all,
|
| 196 |
+
outputs=[status_output, results_table]
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
print("\n" + "-"*30 + " App Starting " + "-"*30)
|
| 201 |
+
# Check for SPACE_HOST and SPACE_ID at startup for information
|
| 202 |
+
space_host_startup = os.getenv("SPACE_HOST")
|
| 203 |
+
space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
|
| 204 |
+
|
| 205 |
+
if space_host_startup:
|
| 206 |
+
print(f"✅ SPACE_HOST found: {space_host_startup}")
|
| 207 |
+
print(f" Runtime URL should be: https://{space_host_startup}.hf.space")
|
| 208 |
+
else:
|
| 209 |
+
print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
|
| 210 |
+
|
| 211 |
+
if space_id_startup: # Print repo URLs if SPACE_ID is found
|
| 212 |
+
print(f"✅ SPACE_ID found: {space_id_startup}")
|
| 213 |
+
print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
|
| 214 |
+
print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
|
| 215 |
+
else:
|
| 216 |
+
print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
|
| 217 |
+
|
| 218 |
+
print("-"*(60 + len(" App Starting ")) + "\n")
|
| 219 |
|
| 220 |
+
print("Launching Gradio Interface for Basic Agent Evaluation...")
|
| 221 |
+
demo.launch(debug=True, share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,9 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
gradio
|
|
|
|
| 2 |
requests
|
| 3 |
langgraph
|
| 4 |
-
langchain-
|
| 5 |
-
langchain-community
|
| 6 |
-
|
| 7 |
-
|
| 8 |
smolagents
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
asyncer
|
| 2 |
+
anyio
|
| 3 |
+
arxiv
|
| 4 |
gradio
|
| 5 |
+
httpx
|
| 6 |
requests
|
| 7 |
langgraph
|
| 8 |
+
langchain-google-genai
|
| 9 |
+
langchain-community
|
| 10 |
+
langchain-tavily
|
| 11 |
+
openpyxl
|
| 12 |
smolagents
|
| 13 |
+
tavily-python
|
| 14 |
+
wikipedia-api
|
| 15 |
+
wikipedia
|
| 16 |
+
duckduckgo-search
|
tools.py
CHANGED
|
@@ -1,219 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from langchain.tools import tool
|
| 2 |
-
from
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
query: str,
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
Args:
|
| 31 |
-
query: Search query
|
| 32 |
-
max_results: Max results to return
|
| 33 |
-
|
| 34 |
-
Returns:
|
| 35 |
-
dict: Combined results from Tavily, Wikipedia and Arxiv
|
| 36 |
-
"""
|
| 37 |
-
pass
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
"""
|
| 42 |
-
Kate's enhanced Excel analyzer with better data processing.
|
| 43 |
-
|
| 44 |
-
Args:
|
| 45 |
-
state: Current state dictionary
|
| 46 |
-
file_path: Path to the Excel file
|
| 47 |
-
|
| 48 |
-
Returns:
|
| 49 |
-
str: Analysis results or error message
|
| 50 |
-
|
| 51 |
-
Features:
|
| 52 |
-
- Improved data validation
|
| 53 |
-
- Support for larger files
|
| 54 |
-
- Better error messages
|
| 55 |
-
"""
|
| 56 |
-
try:
|
| 57 |
-
pass
|
| 58 |
-
except Exception as e:
|
| 59 |
-
return f"ERROR: {str(e)}"
|
| 60 |
-
|
| 61 |
-
@tool("analyze-audio")
|
| 62 |
-
async def analyze_audio(file_path: str) -> str:
|
| 63 |
-
"""
|
| 64 |
-
Analyze audio file content.
|
| 65 |
-
|
| 66 |
-
Args:
|
| 67 |
-
file_path: Path to the audio file
|
| 68 |
-
|
| 69 |
-
Returns:
|
| 70 |
-
str: Analysis results of the audio content
|
| 71 |
-
"""
|
| 72 |
-
try:
|
| 73 |
-
return "Audio analysis placeholder"
|
| 74 |
-
except Exception as e:
|
| 75 |
-
return f"ERROR: {str(e)}"
|
| 76 |
|
| 77 |
-
@tool("analyze-image")
|
| 78 |
-
async def analyze_image(file_path: str) -> str:
|
| 79 |
-
"""
|
| 80 |
-
Analyze image file content.
|
| 81 |
-
|
| 82 |
Args:
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
| 85 |
Returns:
|
| 86 |
-
str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
"""
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"""
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
Args:
|
| 99 |
-
|
| 100 |
-
|
|
|
|
| 101 |
Returns:
|
| 102 |
-
str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
"""
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
@tool("
|
| 110 |
-
async def
|
| 111 |
"""
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
Args:
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
Returns:
|
| 118 |
-
str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
"""
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
"""
|
| 128 |
-
Download file
|
| 129 |
-
|
| 130 |
Args:
|
| 131 |
-
|
| 132 |
-
|
|
|
|
| 133 |
Returns:
|
| 134 |
-
|
| 135 |
"""
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
async def read_file_contents(file_path: str) -> str:
|
| 143 |
"""
|
| 144 |
-
Read
|
| 145 |
-
|
| 146 |
Args:
|
| 147 |
-
file_path:
|
| 148 |
-
|
| 149 |
Returns:
|
| 150 |
-
|
| 151 |
"""
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
except Exception as e:
|
| 155 |
-
return f"ERROR: {str(e)}"
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
|
|
|
| 159 |
"""
|
| 160 |
-
|
| 161 |
-
|
| 162 |
Args:
|
| 163 |
-
|
| 164 |
-
|
| 165 |
Returns:
|
| 166 |
-
|
| 167 |
"""
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
"""
|
| 176 |
-
|
| 177 |
-
|
| 178 |
Args:
|
| 179 |
-
|
| 180 |
-
|
| 181 |
Returns:
|
| 182 |
-
|
| 183 |
"""
|
| 184 |
-
try:
|
| 185 |
-
return {"results": "Tavily search placeholder"}
|
| 186 |
-
except Exception as e:
|
| 187 |
-
return {"error": str(e)}
|
| 188 |
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
"""
|
| 192 |
-
|
| 193 |
-
|
| 194 |
Args:
|
| 195 |
-
|
| 196 |
-
|
| 197 |
Returns:
|
| 198 |
-
|
| 199 |
"""
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
"""
|
| 208 |
-
|
| 209 |
-
|
| 210 |
Args:
|
| 211 |
-
|
| 212 |
-
|
| 213 |
Returns:
|
| 214 |
-
|
| 215 |
"""
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import dotenv
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import tempfile
|
| 5 |
+
import typing
|
| 6 |
+
|
| 7 |
+
from base64 import b64encode
|
| 8 |
+
from io import StringIO
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
+
from anyio import Path
|
| 13 |
+
from asyncer import asyncify
|
| 14 |
+
from langchain_community.document_loaders import ArxivLoader
|
| 15 |
+
from langchain_community.document_loaders import WikipediaLoader
|
| 16 |
+
from langchain_core.messages import HumanMessage
|
| 17 |
+
from langchain_tavily import TavilyExtract
|
| 18 |
+
from langchain_tavily import TavilySearch
|
| 19 |
+
from langgraph.prebuilt import create_react_agent
|
| 20 |
+
from langgraph.prebuilt import InjectedState
|
| 21 |
+
from langchain.tools import BaseTool
|
| 22 |
from langchain.tools import tool
|
| 23 |
+
from pydantic import Field
|
| 24 |
+
from typing_extensions import Annotated
|
| 25 |
+
|
| 26 |
+
from utils import get_llm
|
| 27 |
+
|
| 28 |
+
dotenv.load_dotenv()
|
| 29 |
+
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
|
| 30 |
+
if not GOOGLE_API_KEY:
|
| 31 |
+
raise ValueError("GOOGLE_API_KEY environment variable is not set.")
|
| 32 |
+
|
| 33 |
+
AGENT_MODEL_NAME = os.getenv("AGENT_MODEL_NAME", "gemini-2.0-flash")
|
| 34 |
+
|
| 35 |
+
MULTIMODAL_FILE_ANALYZER_PROMPT = """
|
| 36 |
+
You are a specialized file analysis AI assistant focused on extracting information from various file formats including images, videos, audio, and structured data.
|
| 37 |
+
Core Analysis Guidelines:
|
| 38 |
+
- Systematic processing: Analyze file contents step by step
|
| 39 |
+
- Precise responses: Provide answers in the most concise format - raw numbers, single words, or comma-delimited lists
|
| 40 |
+
- Format requirements:
|
| 41 |
+
* Numbers: No formatting (no commas, units, or symbols)
|
| 42 |
+
* Lists: Pure comma-separated values
|
| 43 |
+
* Text: Minimal words, no explanations
|
| 44 |
+
- Analysis approach:
|
| 45 |
+
* Images: Focus on visual elements, objects, text, and scene composition
|
| 46 |
+
* Audio: Identify sounds, speech, music, and audio characteristics
|
| 47 |
+
* Video: Analyze visual content, motion, and temporal elements
|
| 48 |
+
* Excel/CSV: Extract relevant data points and patterns
|
| 49 |
+
- Verification focus: Base answers solely on file contents
|
| 50 |
+
- Answer format: Always prefix with 'FINAL ANSWER: '
|
| 51 |
+
- Counting tasks: Return only the count
|
| 52 |
+
- Listing tasks: Return only the items
|
| 53 |
+
- Sorting tasks: Return only the ordered list
|
| 54 |
+
|
| 55 |
+
Example Responses:
|
| 56 |
+
Q: Count people in image? A: 3
|
| 57 |
+
Q: List colors in logo? A: blue, red, white
|
| 58 |
+
Q: Main topic of audio? A: weather forecast
|
| 59 |
+
Q: Excel total sales? A: 15420
|
| 60 |
+
Q: Video duration? A: 45
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class SmolagentToolWrapper(BaseTool):
    """Adapter exposing a smolagents tool as a LangChain ``BaseTool``.

    The wrapped tool's ``name`` and ``description`` are copied onto this
    instance so LangChain/LangGraph can present it to the model; invocation
    is delegated to the wrapped callable.
    """

    # The underlying smolagents tool; declared as a pydantic field so it can
    # be passed through BaseTool's constructor.
    wrapped_tool: object = Field(description="Smolagents tool (wrapped)")

    def __init__(self, tool):
        # Mirror the smolagents tool's metadata; return_direct=False so the
        # agent keeps reasoning over the tool output instead of returning it
        # verbatim.
        super().__init__(
            name=tool.name,
            description=tool.description,
            return_direct=False,
            wrapped_tool=tool,
        )

    def _run(self, query: str) -> str:
        # Errors are returned as text (not raised) so the agent loop can
        # observe the failure and recover.
        try:
            return self.wrapped_tool(query)
        except Exception as e:
            return f"Error using SmolagentToolWrapper: {str(e)}"

    def _arun(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        """Async version of the tool"""
        # asyncify runs the sync _run in a worker thread; the returned
        # coroutine is awaited by the caller (LangChain awaits _arun's result).
        return asyncify(self._run, cancellable=True)(*args, **kwargs)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Page-content extraction tool (Tavily Extract), instantiated at import time.
# NOTE(review): not referenced elsewhere in this chunk — presumably registered
# with the agent in another module; confirm before removing.
tavily_extract_tool = TavilyExtract()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@tool("search-tavily-tool", parse_docstring=True)
async def search_tavily(
    query: str,
    state: Annotated[dict, InjectedState],
    included_domains: list[str] | None = None,
    max_results: int = 5,
) -> dict[str, str]:
    """
    Search the web using Tavily API with optional domain filtering.

    This function performs a search using the Tavily search engine and returns formatted results.
    You can specify domains to include in the search results for more targeted information.

    Args:
        query (str): The search query to search the web for
        included_domains (list[str], optional): List of domains to include in search results
            (e.g., ["wikipedia.org", "cnn.com"]). Defaults to None.
        max_results (int, optional): Maximum number of results to return. Defaults to 5.

    Returns:
        dict[str, str]: A dictionary with key 'tavily_results' containing formatted search results,
            plus 'tavily_answer' when Tavily returns a synthesized answer.
            Each result includes document source, page information, and content.

    Example:
        results = await search_tavily("How many albums did Michael Jackson produce", included_domains=[])
        # Returns filtered results about Michael Jackson
    """
    # Configure Tavily search with provided parameters
    tavily_search_tool = TavilySearch(
        max_results=max_results,
        topic="general",
        include_domains=included_domains if included_domains else None,
        search_depth="advanced",
        include_answer="advanced",
    )

    # Bug fix: search the model-crafted `query`. The previous version searched
    # state["question"] directly, silently ignoring the `query` argument.
    # Fall back to the raw question only if the query is empty.
    search_query = query or state.get("question", "")
    search_docs = await tavily_search_tool.arun(search_query)

    # Format each hit as a pseudo-XML <Document> block for the agent.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.get("url", "No URL")}"/>{doc.get("title", "No Title")}\n{doc.get("content", "")}\n</Document>'
            for doc in search_docs.get("results", [])
        ]
    )

    results = {"tavily_results": formatted_search_docs}

    # Surface Tavily's synthesized answer when available.
    answer = search_docs.get("answer", None)
    if answer:
        results["tavily_answer"] = answer

    return results
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@tool("search-arxiv-tool", parse_docstring=True)
async def search_arxiv(query: str, max_num_result: int = 5) -> dict[str, str]:
    """
    Search arXiv for academic papers matching the provided query.
    This function queries the arXiv database for scholarly articles related to the
    search query and returns a formatted collection of the results.

    Args:
        query (str): The search query to find relevant academic papers.
        max_num_result (int, optional): Maximum number of results to return. Defaults to 5.

    Returns:
        dict[str, str]: A dictionary with key 'arxiv_results' containing formatted search results.
            Each result includes document source, page information, and content.

    Example:
        results = await search_arxiv("quantum computing", 3)
        # Returns dictionary with up to 3 formatted arXiv papers about quantum computing
    """
    search_docs = await ArxivLoader(query=query, load_max_docs=max_num_result).aload()
    # Render each paper as a <Document> block with its source metadata.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )
    # Bug fix: key was misspelled "arvix_results", contradicting the documented
    # 'arxiv_results' contract in the docstring above.
    return {"arxiv_results": formatted_search_docs}
|
| 175 |
+
|
| 176 |
|
| 177 |
+
@tool("search-wikipedia-tool", parse_docstring=True)
async def search_wikipedia(query: str, max_num_result: int = 5) -> dict[str, str]:
    """
    Search Wikipedia for articles matching the provided query.
    This function queries the Wikipedia database for articles related to the
    search term and returns a formatted collection of the results.

    Args:
        query (str): The search query to find relevant Wikipedia articles.
        max_num_result (int, optional): Maximum number of results to return. Defaults to 5.

    Returns:
        dict[str, str]: A dictionary with key 'wikipedia_results' containing formatted search results.
            Each result includes document source, page information, and content.

    Example:
        results = await search_wikipedia("neural networks", 3)
        # Returns dictionary with up to 3 formatted Wikipedia articles about neural networks
    """
    loader = WikipediaLoader(
        query=query,
        load_max_docs=max_num_result,
        load_all_available_meta=True,
        doc_content_chars_max=128000,
    )
    pages = await loader.aload()

    # Render every article as a <Document> block carrying its source metadata.
    rendered = []
    for page in pages:
        src = page.metadata["source"]
        page_no = page.metadata.get("page", "")
        rendered.append(
            f'<Document source="{src}" page="{page_no}"/>\n{page.page_content}\n</Document>'
        )

    return {"wikipedia_results": "\n\n---\n\n".join(rendered)}
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@tool("download-file-for-task-tool", parse_docstring=True)
async def download_file_for_task(task_id: str, filename: str | None = None) -> str:
    """
    Download a file for task_id, save to a temporary file, and return path

    Args:
        task_id: The task id file to download
        filename: Optional filename (will be generated if not provided)

    Returns:
        String path to the downloaded file
    """
    # Default the on-disk name to the task id itself.
    target_name = filename if filename is not None else task_id
    destination = Path(tempfile.gettempdir()) / target_name

    url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            # NOTE: Path.open here is awaited — assumes an async Path
            # (anyio-style), not pathlib; stream the body to disk in chunks.
            async with await destination.open("wb") as out:
                async for chunk in response.aiter_bytes(chunk_size=4096):
                    await out.write(chunk)

    return str(destination)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@tool("read-file-contents-tool", parse_docstring=True)
async def read_file_contents(file_path: str) -> str:
    """
    Read a file and return its contents

    Args:
        file_path: String path to file to read

    Returns:
        Contents of the file at file_path
    """
    # read_text is awaited — Path here is an async Path (anyio-style).
    return await Path(file_path).read_text()
|
|
|
|
|
|
|
| 256 |
|
| 257 |
+
|
| 258 |
+
@tool("analyze-image-tool", parse_docstring=True)
async def analyze_image(state: Annotated[dict, InjectedState], image_path: str) -> str:
    """
    Analyze the image at image_path

    Args:
        image_path: String path where the image file is located on disk

    Returns:
        Answer to the question about the image file
    """
    # Read the image and base64-encode it for the multimodal model.
    async with await Path(image_path).open("rb") as handle:
        encoded_image = b64encode(await handle.read()).decode("utf-8")

    # Spin up a tool-less ReAct agent dedicated to file analysis.
    analyzer = create_react_agent(
        get_llm(
            llm_provider_api_key=GOOGLE_API_KEY,
            model_name=AGENT_MODEL_NAME,
        ),
        tools=[],
        state_modifier=MULTIMODAL_FILE_ANALYZER_PROMPT,
    )

    # Pair the original question with the inline image payload.
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": state["question"]},
            {
                "type": "image",
                "source_type": "base64",
                "mime_type": "image/png",
                "data": encoded_image,
            },
        ]
    )

    outcome = await analyzer.ainvoke({"messages": [prompt]})
    return outcome["messages"][-1].content
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@tool("analyze-excel-tool", parse_docstring=True)
async def analyze_excel(state: Annotated[dict, InjectedState], excel_path: str) -> str:
    """
    Analyze the excel file at excel_path

    Args:
        excel_path: String path where the excel file is located on disk

    Returns:
        Answer to the question about the excel file
    """
    # Convert the workbook to CSV text, then base64-encode it so it can be
    # attached to the model message as an inline file.
    buffer = StringIO()
    pd.read_excel(excel_path).to_csv(buffer, index=False)
    encoded_csv = b64encode(buffer.getvalue().encode("utf-8")).decode("utf-8")

    # Tool-less ReAct agent dedicated to file analysis.
    analyzer = create_react_agent(
        get_llm(
            llm_provider_api_key=GOOGLE_API_KEY,
            model_name=AGENT_MODEL_NAME,
        ),
        tools=[],
        state_modifier=MULTIMODAL_FILE_ANALYZER_PROMPT,
    )

    # Pair the original question with the CSV attachment.
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": state["question"]},
            {
                "type": "file",
                "source_type": "base64",
                "mime_type": "text/csv",
                "data": encoded_csv,
            },
        ],
    )

    outcome = await analyzer.ainvoke({"messages": [prompt]})
    return outcome["messages"][-1].content
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
@tool("analyze-audio-tool", parse_docstring=True)
async def analyze_audio(state: Annotated[dict, InjectedState], audio_path: str) -> str:
    """
    Analyze the audio at audio_path

    Args:
        audio_path: String path where the audio file is located on disk

    Returns:
        Answer to the question about the audio file
    """
    audio_mime_type = "audio/mpeg"

    # Read the audio file and base64-encode it for the multimodal model.
    async with await Path(audio_path).open("rb") as handle:
        encoded_audio = b64encode(await handle.read()).decode("utf-8")

    # Tool-less ReAct agent dedicated to file analysis.
    analyzer = create_react_agent(
        get_llm(
            llm_provider_api_key=GOOGLE_API_KEY,
            model_name=AGENT_MODEL_NAME,
        ),
        tools=[],
        state_modifier=MULTIMODAL_FILE_ANALYZER_PROMPT,
    )

    # Pair the original question with the inline audio payload.
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": state["question"]},
            {"type": "media", "data": encoded_audio, "mime_type": audio_mime_type},
        ],
    )

    outcome = await analyzer.ainvoke({"messages": [prompt]})
    return outcome["messages"][-1].content
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@tool("analyze-video-tool", parse_docstring=True)
async def analyze_video(state: Annotated[dict, InjectedState], video_url: str) -> str:
    """
    Analyze the video at video_url

    Args:
        video_url: URL where the video is located

    Returns:
        Answer to the question about the video url
    """
    # Tool-less ReAct agent dedicated to file analysis.
    analyzer = create_react_agent(
        get_llm(
            llm_provider_api_key=GOOGLE_API_KEY,
            model_name=AGENT_MODEL_NAME,
        ),
        tools=[],
        state_modifier=MULTIMODAL_FILE_ANALYZER_PROMPT,
    )

    # The video is passed by URI; Gemini fetches it server-side.
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": state["question"]},
            {
                "type": "media",
                "mime_type": "video/mp4",
                "file_uri": video_url,
            },
        ],
    )

    outcome = await analyzer.ainvoke({"messages": [prompt]})
    return outcome["messages"][-1].content
|
utils.py
CHANGED
|
@@ -1,28 +1,13 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 3 |
|
|
|
|
| 4 |
def get_llm(
|
| 5 |
llm_provider_api_key: str,
|
| 6 |
-
model_name:
|
| 7 |
-
|
| 8 |
-
max_tokens: Optional[int] = None
|
| 9 |
-
) -> ChatGoogleGenerativeAI:
|
| 10 |
-
"""
|
| 11 |
-
Initialize and return a Google Generative AI language model.
|
| 12 |
-
|
| 13 |
-
Args:
|
| 14 |
-
llm_provider_api_key: Google API key
|
| 15 |
-
model_name: Name of the model to use (default: None)
|
| 16 |
-
temperature: Sampling temperature (default: 0.7)
|
| 17 |
-
max_tokens: Maximum number of tokens to generate (default: None)
|
| 18 |
-
|
| 19 |
-
Returns:
|
| 20 |
-
ChatGoogleGenerativeAI: Initialized language model
|
| 21 |
-
"""
|
| 22 |
return ChatGoogleGenerativeAI(
|
| 23 |
google_api_key=llm_provider_api_key,
|
|
|
|
|
|
|
| 24 |
model=model_name,
|
| 25 |
-
|
| 26 |
-
max_output_tokens=max_tokens,
|
| 27 |
-
convert_system_message_to_human=True
|
| 28 |
-
)
|
|
|
|
|
|
|
| 1 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 2 |
|
| 3 |
+
|
| 4 |
def get_llm(
    llm_provider_api_key: str,
    model_name: str = "gemini-2.0-flash",  # Default model aligned with AGENT_MODEL_NAME
) -> ChatGoogleGenerativeAI:
    """
    Initialize and return a Google Generative AI chat model.

    Args:
        llm_provider_api_key: Google API key used to authenticate.
        model_name: Gemini model identifier. Defaults to "gemini-2.0-flash".

    Returns:
        ChatGoogleGenerativeAI: Chat model configured with temperature 0.7 and
            up to 5 retries on transient API failures.
    """
    return ChatGoogleGenerativeAI(
        google_api_key=llm_provider_api_key,
        temperature=0.7,
        max_retries=5,
        model=model_name,
    )
|
|
|
|
|
|
|
|
|