# --- Hub page metadata (scraping residue, not part of the module; kept as comments) ---
# lwant's picture
# Update `setup` method to handle additional files with extensions, adapt prompt template for new variable
# 2d02aba
# raw
# history blame
# 4.44 kB
import re
from pathlib import Path
from typing import Any
from llama_index.core.agent.workflow import FunctionAgent, AgentWorkflow
from llama_index.core.prompts import RichPromptTemplate
from llama_index.llms.nebius import NebiusLLM
from llama_index.tools.requests import RequestsToolSpec
from llama_index.tools.wikipedia import WikipediaToolSpec
from workflows import Workflow, step
from workflows.events import StartEvent, Event, StopEvent
from gaia_solving_agent import NEBIUS_API_KEY
from gaia_solving_agent.prompts import PLANING_PROMPT, FORMAT_ANSWER
from gaia_solving_agent.tools import tavily_search_web, wikipedia_tool_spec
# Choice of the model served by Nebius; alternatives kept below for quick switching.
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# model_name = "deepseek-ai/DeepSeek-R1-0528"
# model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" # For VLM needs
def get_llm(model_name=model_name):
    """Build a ``NebiusLLM`` client for *model_name*.

    Defaults to the module-level ``model_name`` constant so every agent in
    this file shares the same model unless explicitly overridden.
    """
    llm_settings = {
        "model": model_name,
        "api_key": NEBIUS_API_KEY,
        "is_function_calling_model": True,
        "max_completion_tokens": 10000,
        # max = 128000 for "meta-llama/Meta-Llama-3.1-8B-Instruct"
        "context_window": 80000,
        "temperature": 0.1,
        "max_retries": 5,
    }
    return NebiusLLM(**llm_settings)
class QueryEvent(Event):
    """Event emitted by ``GaiaWorkflow.setup``: the user query plus a solving plan."""
    # Original user question passed through the workflow.
    query: str
    # Optional attachment supplied with the question (content type not visible
    # here — presumably raw file data; confirm against the caller).
    additional_file: Any | None
    # Filesystem path of the attachment; empty string when none is given.
    additional_file_path: str = ""
    # Solving plan produced by the planning LLM call in ``setup``.
    plan: str
class AnswerEvent(Event):
    """Event emitted by ``GaiaWorkflow.multi_agent_process``: plan plus raw agent answer."""
    # Plan carried over from the ``QueryEvent`` (used later to re-extract the question).
    plan: str
    # Stringified output of the multi-agent run.
    answer: str
class GaiaWorkflow(Workflow):
    """Three-step workflow for GAIA questions.

    1. ``setup`` — ask the LLM for a solving plan.
    2. ``multi_agent_process`` — run the multi-agent system on that plan.
    3. ``parse_answer`` — re-extract the question and format the final answer.
    """

    @step
    async def setup(self, ev: StartEvent) -> QueryEvent:
        """Build a solving plan from the user query and the extension of any
        attached file.

        Returns a ``QueryEvent`` carrying the query, the attachment and the plan.
        """
        llm = get_llm()
        prompt_template = RichPromptTemplate(PLANING_PROMPT)
        # StartEvent may arrive without `additional_file_name`; fall back to ""
        # so Path(...) never receives None (which raises TypeError).
        file_name = getattr(ev, "additional_file_name", "") or ""
        plan = llm.complete(prompt_template.format(
            user_request=ev.query,
            additional_file_extension=Path(file_name).suffix,
        ))
        return QueryEvent(query=ev.query, additional_file=ev.additional_file, plan=plan.text)

    @step
    async def multi_agent_process(self, ev: QueryEvent) -> AnswerEvent:
        """Run the module-level ``gaia_solving_agent`` workflow on the plan."""
        # Cheap trick to avoid Error 400 errors from OpenAPI
        from llama_index.core.memory import ChatMemoryBuffer
        memory = ChatMemoryBuffer.from_defaults(token_limit=100000)
        agent_output = await gaia_solving_agent.run(user_msg=ev.plan, memory=memory)
        return AnswerEvent(plan=ev.plan, answer=str(agent_output))

    @step
    async def parse_answer(self, ev: AnswerEvent) -> StopEvent:
        """Extract the original question from the plan and ask the LLM to
        format the final answer via ``FORMAT_ANSWER``.
        """
        llm = get_llm()
        prompt_template = RichPromptTemplate(FORMAT_ANSWER)
        # Match up to a newline OR end of string.  The previous pattern used
        # "[\n$]", where "$" inside a character class is a LITERAL dollar sign,
        # so a question sitting on the last line of the plan (no trailing
        # newline) was silently never matched.
        pattern = r"Question :\s*(.*)(?:\n|$)"
        search = re.search(pattern, ev.plan)
        question = search.group(1) if search else ""
        result = llm.complete(prompt_template.format(question=question))
        return StopEvent(result=result)
# Root agent: turns the user's need into focused web searches.
tavily_search_engine = FunctionAgent(
    name="search_engine_agent",
    description="Agent that makes web searches to answer questions.",
    llm=get_llm(),
    tools=[tavily_search_web],
    can_handoff_to=["visit_web_page_agent"],
    system_prompt="""
You are a helpful assistant that does web searches.
Convert the user need into one or multiple web searches.
Each web search should aim for one specific topic.
A topic is defined as one to few words.
If the user needs to search for multiple topics, make multiple searches.
""",
)
# Agent that fetches a URL and summarizes the page for the caller.
visit_website = FunctionAgent(
    name="visit_web_page_agent",
    description="Agent that visit a web page and return a summary of the page.",
    llm=get_llm(),
    tools=[*RequestsToolSpec().to_tool_list()],
    system_prompt="""
You are a helpful assistant that visit a website.
Given a url, you should visit the web page and return a summary of the page.
The summary should answer the concerns of the user.
If the url is invalid, return "Invalid URL".
If the url is not a web page, return "Not a web page".
If the url is not reachable, return "Not reachable".
""",
)
# Agent dedicated to Wikipedia search and page reading.
wikipedia_agent = FunctionAgent(
    name="wikipedia_agent",
    description="Agent that searches Wikipedia and visit Wikipedia pages.",
    llm=get_llm(),
    tools=[*WikipediaToolSpec().to_tool_list()],
    system_prompt="""
You are a helpful assistant that searches Wikipedia and visit Wikipedia pages.
""",
)
# Multi-agent workflow used by GaiaWorkflow.multi_agent_process; the search
# engine agent is the entry point and may hand off to the page-visiting agent.
gaia_solving_agent = AgentWorkflow(
    agents=[tavily_search_engine, visit_website, wikipedia_agent],
    root_agent=tavily_search_engine.name,
    initial_state={},
    handoff_prompt=None,
    handoff_output_prompt=None,
    state_prompt=None,
    num_concurrent_runs=1,
)