lwant commited on
Commit
8332d06
Β·
1 Parent(s): 2d02aba

Add support for handling `additional_file` and `additional_file_path` in workflow agents, update `setup` method, and adapt prompt template accordingly

Browse files
src/gaia_solving_agent/agent.py CHANGED
@@ -34,7 +34,7 @@ def get_llm(model_name=model_name):
34
  class QueryEvent(Event):
35
  query: str
36
  additional_file: Any | None
37
- additional_file_path: str = ""
38
  plan: str
39
 
40
  class AnswerEvent(Event):
@@ -47,11 +47,17 @@ class GaiaWorkflow(Workflow):
47
  async def setup(self, ev: StartEvent) -> QueryEvent:
48
  llm = get_llm()
49
  prompt_template = RichPromptTemplate(PLANING_PROMPT)
 
50
  plan = llm.complete(prompt_template.format(
51
- user_request=ev.query,
52
- additional_file_extension=Path(ev.additional_file_name).suffix,
53
  ))
54
- return QueryEvent(query=ev.query, additional_file=ev.additional_file, plan=plan.text)
 
 
 
 
 
55
 
56
  @step()
57
  async def multi_agent_process(self, ev: QueryEvent) -> AnswerEvent:
@@ -59,7 +65,12 @@ class GaiaWorkflow(Workflow):
59
  from llama_index.core.memory import ChatMemoryBuffer
60
  memory = ChatMemoryBuffer.from_defaults(token_limit=100000)
61
 
62
- agent_output = await gaia_solving_agent.run(user_msg=ev.plan, memory=memory)
 
 
 
 
 
63
  return AnswerEvent(plan=ev.plan, answer=str(agent_output))
64
 
65
  @step
 
34
  class QueryEvent(Event):
35
  query: str
36
  additional_file: Any | None
37
+ additional_file_path: str | Path | None = None
38
  plan: str
39
 
40
  class AnswerEvent(Event):
 
47
  async def setup(self, ev: StartEvent) -> QueryEvent:
48
  llm = get_llm()
49
  prompt_template = RichPromptTemplate(PLANING_PROMPT)
50
+ file_extension = Path(ev.additional_file_path).suffix if ev.additional_file_path else ""
51
  plan = llm.complete(prompt_template.format(
52
+ user_request=ev.user_msg,
53
+ additional_file_extension=file_extension,
54
  ))
55
+ return QueryEvent(
56
+ query=ev.user_msg,
57
+ additional_file=ev.additional_file,
58
+ additional_file_path=ev.additional_file_path,
59
+ plan=plan.text,
60
+ )
61
 
62
  @step()
63
  async def multi_agent_process(self, ev: QueryEvent) -> AnswerEvent:
 
65
  from llama_index.core.memory import ChatMemoryBuffer
66
  memory = ChatMemoryBuffer.from_defaults(token_limit=100000)
67
 
68
+ agent_output = await gaia_solving_agent.run(
69
+ user_msg=ev.plan,
70
+ memory=memory,
71
+ additional_file=ev.additional_file,
72
+ additional_file_path=ev.additional_file_path,
73
+ )
74
  return AnswerEvent(plan=ev.plan, answer=str(agent_output))
75
 
76
  @step
src/gaia_solving_agent/hf_submission_api.py CHANGED
@@ -82,9 +82,17 @@ async def run_agent(agent, questions_data):
82
  continue
83
  try:
84
  if isinstance(agent, BaseWorkflowAgent):
85
- submitted_answer = await agent.run(user_msg=question_text)
 
 
 
 
86
  elif isinstance(agent, Workflow):
87
- submitted_answer = await agent.run(query=question_text)
 
 
 
 
88
  else:
89
  raise NotImplementedError(f"Invalid agent type: {type(agent)}")
90
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
 
82
  continue
83
  try:
84
  if isinstance(agent, BaseWorkflowAgent):
85
+ submitted_answer = await agent.run(
86
+ user_msg=question_text,
87
+ additional_file=additional_file,
88
+ additional_file_path=additional_file_path,
89
+ )
90
  elif isinstance(agent, Workflow):
91
+ submitted_answer = await agent.run(
92
+ user_msg=question_text,
93
+ additional_file=additional_file,
94
+ additional_file_path=additional_file_path,
95
+ )
96
  else:
97
  raise NotImplementedError(f"Invalid agent type: {type(agent)}")
98
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
src/gaia_solving_agent/prompts.py CHANGED
@@ -32,7 +32,7 @@ Respond with exactly:
32
  3. Sub-tasks:
33
  {% if additional_file_extension|length -%}
34
  - There is an additional file to analyse. It's file extension is {{ additional_file_extension }}. You must analyse it to find out
35
- {%- else %}
36
  - xxxxxxx
37
  - xxxxxxx
38
  '''
 
32
  3. Sub-tasks:
33
  {% if additional_file_extension|length -%}
34
  - There is an additional file to analyse. It's file extension is {{ additional_file_extension }}. You must analyse it to find out
35
+ {%- endif %}
36
  - xxxxxxx
37
  - xxxxxxx
38
  '''