lwant committed on
Commit
e04414b
Β·
1 Parent(s): ea9851f

Refactor LLM selection to use specific models for reasoning, parsing, and balanced tasks; improve code clarity.

Browse files
Files changed (1) hide show
  1. src/gaia_solving_agent/agent.py +8 -7
src/gaia_solving_agent/agent.py CHANGED
@@ -15,11 +15,12 @@ from gaia_solving_agent.prompts import PLANING_PROMPT, FORMAT_ANSWER
15
  from gaia_solving_agent.tools import tavily_search_web, wikipedia_tool_spec
16
 
17
  # Choice of the model
18
- model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
19
- # model_name = "deepseek-ai/DeepSeek-R1-0528"
20
- # model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" # For VLM needs
 
21
 
22
- def get_llm(model_name=model_name):
23
  return NebiusLLM(
24
  model=model_name,
25
  api_key=NEBIUS_API_KEY,
@@ -45,7 +46,7 @@ class AnswerEvent(Event):
45
  class GaiaWorkflow(Workflow):
46
  @step
47
  async def setup(self, ev: StartEvent) -> QueryEvent:
48
- llm = get_llm()
49
  prompt_template = RichPromptTemplate(PLANING_PROMPT)
50
  file_extension = Path(ev.additional_file_path).suffix if ev.additional_file_path else ""
51
  plan = llm.complete(prompt_template.format(
@@ -75,7 +76,7 @@ class GaiaWorkflow(Workflow):
75
 
76
  @step
77
  async def parse_answer(self, ev: AnswerEvent) -> StopEvent:
78
- llm = get_llm()
79
  prompt_template = RichPromptTemplate(FORMAT_ANSWER)
80
  pattern = r"<Question> :\s*(.*)[\n$]"
81
  search = re.search(pattern, ev.plan)
@@ -145,7 +146,7 @@ gaia_solving_agent = FunctionAgent(
145
  *simple_web_page_reader_toolspec.to_tool_list(),
146
  *RequestsToolSpec().to_tool_list(),
147
  ],
148
- llm=get_llm(),
149
  system_prompt="""
150
  You are a helpful assistant that uses tools to browse additional information and resources on the web to answer questions.
151
  """,
 
15
  from gaia_solving_agent.tools import tavily_search_web, wikipedia_tool_spec
16
 
17
  # Choice of the model
18
+ cheap_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
19
+ balanced_model_name = "meta-llama/Meta-Llama-3.1-70B-Instruct"
20
+ reasoning_model_name = "deepseek-ai/DeepSeek-R1-0528"
21
+ vlm_model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" # For VLM needs
22
 
23
+ def get_llm(model_name=cheap_model_name):
24
  return NebiusLLM(
25
  model=model_name,
26
  api_key=NEBIUS_API_KEY,
 
46
  class GaiaWorkflow(Workflow):
47
  @step
48
  async def setup(self, ev: StartEvent) -> QueryEvent:
49
+ llm = get_llm(reasoning_model_name)
50
  prompt_template = RichPromptTemplate(PLANING_PROMPT)
51
  file_extension = Path(ev.additional_file_path).suffix if ev.additional_file_path else ""
52
  plan = llm.complete(prompt_template.format(
 
76
 
77
  @step
78
  async def parse_answer(self, ev: AnswerEvent) -> StopEvent:
79
+ llm = get_llm(balanced_model_name)
80
  prompt_template = RichPromptTemplate(FORMAT_ANSWER)
81
  pattern = r"<Question> :\s*(.*)[\n$]"
82
  search = re.search(pattern, ev.plan)
 
146
  *simple_web_page_reader_toolspec.to_tool_list(),
147
  *RequestsToolSpec().to_tool_list(),
148
  ],
149
+ llm=get_llm(balanced_model_name),
150
  system_prompt="""
151
  You are a helpful assistant that uses tools to browse additional information and resources on the web to answer questions.
152
  """,