|
|
from llama_index.llms.gemini import Gemini |
|
|
from llama_index.tools.arxiv import ArxivToolSpec |
|
|
from llama_index.tools.wikipedia import WikipediaToolSpec |
|
|
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec |
|
|
from llama_index.core.tools import FunctionTool |
|
|
from llama_index.core.agent.workflow import AgentWorkflow |
|
|
from gradio import ChatMessage |
|
|
from llama_index.core.base.llms.types import ChatMessage as llama_index_chat_message |
|
|
|
|
|
from tools import interpret_python_math_code, image_understanding, convert_audio_to_text, video_understanding, read_csv_file, read_xlsx_file |
|
|
from gaia_system_prompt import CUSTOM_SYSTEM_PROMPT |
|
|
|
|
|
import os |
|
|
import asyncio |
|
|
|
|
|
# Maximum number of seconds the AgentWorkflow may run before timing out.
TIMEOUT = 180

# Gemini API key read from the environment; None when the variable is unset
# (the Gemini client will then fail at request time, not at import time).
GEMINI_API_KEY = os.getenv("GEMINI_TOKEN")

# Pinned Gemini model revision used for all agent reasoning and tool calls.
GEMINI_MODEL_NAME = "gemini-2.5-flash-preview-04-17"
|
|
|
|
|
|
|
|
class FinalAgent:
    """GAIA-style question-answering agent.

    Wraps a Gemini LLM in a llama-index ``AgentWorkflow`` equipped with:
    local function tools (math interpreter, image/audio/video understanding,
    CSV/XLSX readers) plus arXiv, Wikipedia and DuckDuckGo search tool specs.
    Instances are awaitable via ``__call__`` and return the answer text,
    stripped of ``<final_answer>`` tags when present.
    """

    def __init__(self):
        """Build the LLM, assemble the tool list and create the workflow."""
        self.llm = Gemini(model=GEMINI_MODEL_NAME, api_key=GEMINI_API_KEY)

        # Local function tools as (callable, exposed-tool-name) pairs.
        # Each tool's description comes from the function's own docstring,
        # so the docstrings in tools.py double as the LLM-facing tool docs.
        _local_tools = [
            (interpret_python_math_code, "InterpretPythonMathCode"),
            (image_understanding, "ImageUnderstanding"),
            (convert_audio_to_text, "ConvertAudioToText"),
            (video_understanding, "VideoUnderstanding"),
            (read_csv_file, "ReadCSVFile"),
            (read_xlsx_file, "ReadXLSXFile"),
        ]
        self.tools = [
            FunctionTool.from_defaults(fn=fn, name=name, description=fn.__doc__)
            for fn, name in _local_tools
        ]

        # External search/lookup tool specs: arXiv, Wikipedia, DuckDuckGo.
        for spec in (ArxivToolSpec(), WikipediaToolSpec(), DuckDuckGoSearchToolSpec()):
            self.tools.extend(spec.to_tool_list())

        self.agent = AgentWorkflow.from_tools_or_functions(
            tools_or_functions=self.tools,
            llm=self.llm,
            system_prompt=CUSTOM_SYSTEM_PROMPT,
            timeout=TIMEOUT,
        )

        print("FinalAgent initialized.")

    @staticmethod
    def _extract_final_answer(response_str: str) -> str:
        """Return the text between <final_answer> tags, stripped of whitespace.

        If either tag is missing, the input is returned unchanged (a warning
        is printed), matching the prompt contract in CUSTOM_SYSTEM_PROMPT.
        """
        if "<final_answer>" in response_str and "</final_answer>" in response_str:
            start_index = response_str.index("<final_answer>") + len("<final_answer>")
            end_index = response_str.index("</final_answer>")
            return response_str[start_index:end_index].strip()
        print("Warning: No <final_answer> tags found in the response.")
        return response_str

    async def __call__(self, question: str) -> str:
        """Run the workflow on *question* and return the final answer string.

        Never raises: agent failures are caught and returned as an
        ``"Agent error: ..."`` string so callers always get text back.
        """
        print(f"Agent received question: {question}")

        response_str = ""
        try:
            agent_chat_response = await self.agent.run(question)
            print(agent_chat_response)

            potential_response_obj = agent_chat_response.response

            # Gradio's ChatMessage and llama-index's ChatMessage both expose
            # .role and .content (content may legitimately be None), so the
            # two cases are handled by one branch.
            if isinstance(potential_response_obj, (ChatMessage, llama_index_chat_message)):
                print(f"DEBUG: Response object is ChatMessage. Role: {potential_response_obj.role}")
                response_str = potential_response_obj.content
                if response_str is None:
                    print("DEBUG: ChatMessage content is None, defaulting to empty string.")
                    response_str = ""
            elif isinstance(potential_response_obj, str):
                print("DEBUG: Response object is str.")
                response_str = potential_response_obj
            else:
                # Fallback: coerce anything unexpected to text rather than fail.
                print(f"Warning: Agent response was of unexpected type: {type(potential_response_obj)}. Converting to string.")
                response_str = str(potential_response_obj)

        except Exception as e:
            print(f"Error during agent execution with LLM {self.llm.__class__.__name__}: {e}")
            response_str = f"Agent error: {e}"

        return self._extract_final_answer(response_str)
|
|
|
|
|
|
|
|
async def main():
    """Smoke test: run FinalAgent once on a sample GAIA-style question."""
    agent = FinalAgent()
    question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    # Alternative sample exercising the video-understanding tool:
    # "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?"
    answer = await agent(question)
    print(f"Final answer: {answer}")
|
|
|
|
|
# Script entry point: run the sample question through the agent once.
if __name__ == "__main__":


    asyncio.run(main())