lwant's picture
Add support for handling `additional_file` and `additional_file_path` in workflow agents, update `setup` method, and adapt prompt template accordingly
8332d06
raw
history blame
4.78 kB
import re
from pathlib import Path
from typing import Any
from llama_index.core.agent.workflow import FunctionAgent, AgentWorkflow
from llama_index.core.prompts import RichPromptTemplate
from llama_index.llms.nebius import NebiusLLM
from llama_index.tools.requests import RequestsToolSpec
from llama_index.tools.wikipedia import WikipediaToolSpec
from workflows import Workflow, step
from workflows.events import StartEvent, Event, StopEvent
from gaia_solving_agent import NEBIUS_API_KEY
from gaia_solving_agent.prompts import PLANING_PROMPT, FORMAT_ANSWER
from gaia_solving_agent.tools import tavily_search_web, wikipedia_tool_spec
# Model selection: the Nebius-hosted model id used by get_llm() for every
# agent and workflow step in this module.
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Alternatives (uncomment to switch):
# model_name = "deepseek-ai/DeepSeek-R1-0528"
# model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" # For VLM needs
def get_llm(model_name=model_name):
    """Build a NebiusLLM client configured for function calling.

    Defaults to the module-level ``model_name``; pass another model id to
    override per call. Returns a fresh client on every invocation.
    """
    # context_window max is 128000 for "meta-llama/Meta-Llama-3.1-8B-Instruct";
    # we stay below it to leave headroom for completions.
    llm_settings = {
        "model": model_name,
        "api_key": NEBIUS_API_KEY,
        "is_function_calling_model": True,
        "max_completion_tokens": 10000,
        "context_window": 80000,
        "temperature": 0.1,
        "max_retries": 5,
    }
    return NebiusLLM(**llm_settings)
class QueryEvent(Event):
    """Carries the user query, its optional attachment, and the generated
    plan from the setup step to the multi-agent step."""

    # query: the original user request text
    query: str
    # additional_file: in-memory attachment, if any. Defaulted to None for
    # consistency with additional_file_path (both are optional).
    additional_file: Any | None = None
    # additional_file_path: on-disk path of the attachment, if any
    additional_file_path: str | Path | None = None
    # plan: the LLM-generated resolution plan for the query
    plan: str
class AnswerEvent(Event):
    """Result of the multi-agent run, consumed by the answer-formatting step."""

    # plan: the resolution plan that was executed (also holds the question line)
    # answer: the raw, stringified agent output
    plan: str
    answer: str
class GaiaWorkflow(Workflow):
    """Three-step workflow: plan the request, run the multi-agent solver on
    the plan, then extract/format the final answer."""

    @step
    async def setup(self, ev: StartEvent) -> QueryEvent:
        """Ask the LLM for a resolution plan before handing off to agents.

        Reads ``user_msg``, ``additional_file`` and ``additional_file_path``
        from the start event; the attachment's file extension is surfaced to
        the planning prompt so the plan can account for the file type.
        """
        llm = get_llm()
        prompt_template = RichPromptTemplate(PLANING_PROMPT)
        file_extension = Path(ev.additional_file_path).suffix if ev.additional_file_path else ""
        plan = llm.complete(prompt_template.format(
            user_request=ev.user_msg,
            additional_file_extension=file_extension,
        ))
        return QueryEvent(
            query=ev.user_msg,
            additional_file=ev.additional_file,
            additional_file_path=ev.additional_file_path,
            plan=plan.text,
        )

    @step
    async def multi_agent_process(self, ev: QueryEvent) -> AnswerEvent:
        """Run the multi-agent workflow on the plan and capture its answer."""
        # Fresh, large memory buffer per run — cheap trick to avoid
        # Error 400 errors from OpenAPI. Import kept local on purpose.
        from llama_index.core.memory import ChatMemoryBuffer
        memory = ChatMemoryBuffer.from_defaults(token_limit=100000)
        agent_output = await gaia_solving_agent.run(
            user_msg=ev.plan,
            memory=memory,
            additional_file=ev.additional_file,
            additional_file_path=ev.additional_file_path,
        )
        return AnswerEvent(plan=ev.plan, answer=str(agent_output))

    @step
    async def parse_answer(self, ev: AnswerEvent) -> StopEvent:
        """Extract the original question from the plan and ask the LLM to
        produce the final formatted answer."""
        llm = get_llm()
        prompt_template = RichPromptTemplate(FORMAT_ANSWER)
        # Fix: the previous pattern ended in `[\n$]`, a character CLASS that
        # matches a literal '$' or a newline — not end-of-string — so a plan
        # whose "Question :" line was last and unterminated never matched.
        # `.` already stops at end-of-line, so no terminator is needed.
        pattern = r"Question :\s*(.*)"
        match = re.search(pattern, ev.plan)
        question = match.group(1) if match else ""
        # NOTE(review): ev.answer is never passed to the formatting prompt —
        # verify whether FORMAT_ANSWER expects an `answer` template variable.
        result = llm.complete(prompt_template.format(question=question))
        return StopEvent(result=result)
# Agent that converts a user need into one or more focused Tavily web
# searches; may hand off to the page-visiting agent for in-depth reading.
# (Fixed PEP 8 E251: no spaces around '=' in keyword arguments.)
tavily_search_engine = FunctionAgent(
    tools=[tavily_search_web],
    llm=get_llm(),
    system_prompt="""
    You are a helpful assistant that does web searches.
    Convert the user need into one or multiple web searches.
    Each web search should aim for one specific topic.
    A topic is defined as one to few words.
    If the user needs to search for multiple topics, make multiple searches.
    """,
    name="search_engine_agent",
    can_handoff_to=["visit_web_page_agent"],
    description="Agent that makes web searches to answer questions.",
)
visit_website = FunctionAgent(
tools=[
*RequestsToolSpec().to_tool_list(),
],
llm=get_llm(),
system_prompt="""
You are a helpful assistant that visit a website.
Given a url, you should visit the web page and return a summary of the page.
The summary should answer the concerns of the user.
If the url is invalid, return "Invalid URL".
If the url is not a web page, return "Not a web page".
If the url is not reachable, return "Not reachable".
""",
name="visit_web_page_agent",
description="Agent that visit a web page and return a summary of the page."
)
wikipedia_agent = FunctionAgent(
tools=[*WikipediaToolSpec().to_tool_list()],
llm=get_llm(),
system_prompt="""
You are a helpful assistant that searches Wikipedia and visit Wikipedia pages.
""",
name="wikipedia_agent",
description="Agent that searches Wikipedia and visit Wikipedia pages."
)
# Top-level multi-agent workflow: web search -> page visit -> Wikipedia, with
# the search agent as the entry point and a single sequential run at a time.
# (Fixed PEP 8 E251 kwarg spacing; `dict()` replaced by the `{}` literal.)
gaia_solving_agent = AgentWorkflow(
    agents=[tavily_search_engine, visit_website, wikipedia_agent],
    initial_state={},
    root_agent=tavily_search_engine.name,
    handoff_prompt=None,
    handoff_output_prompt=None,
    state_prompt=None,
    num_concurrent_runs=1,
)