File size: 7,816 Bytes
2d02aba
5ab1e4f
88b51a3
daf4564
88b51a3
 
449333b
 
5ab1e4f
88b51a3
 
449333b
88b51a3
5ab1e4f
 
 
 
 
449333b
 
 
470fd47
5ab1e4f
 
88b51a3
 
e04414b
e777bbb
e04414b
30af36c
e04414b
30af36c
449333b
 
 
 
 
30af36c
 
 
 
 
 
449333b
3abb547
 
 
 
 
 
 
88b51a3
e04414b
88b51a3
 
 
 
 
 
 
 
 
 
3abb547
5ab1e4f
 
 
5336c20
88b51a3
3abb547
88b51a3
5ab1e4f
88b51a3
3abb547
88b51a3
 
 
 
 
 
 
5ab1e4f
9c4e981
 
 
5ab1e4f
 
 
5336c20
9c4e981
 
5ab1e4f
30af36c
2d02aba
5ab1e4f
 
 
8332d06
 
88b51a3
5ab1e4f
 
 
 
5336c20
 
5ab1e4f
 
 
5336c20
5ab1e4f
 
 
5336c20
5ab1e4f
 
 
5336c20
 
 
5ab1e4f
 
 
9c4e981
5ab1e4f
 
 
 
 
 
 
 
5336c20
5ab1e4f
9c4e981
 
 
5ab1e4f
 
 
88b51a3
5ab1e4f
9c4e981
 
5ab1e4f
9c4e981
 
 
5ab1e4f
 
 
 
 
 
 
 
 
 
88b51a3
 
 
 
daf4564
9c4e981
8332d06
daf4564
8332d06
daf4564
8332d06
5ab1e4f
88b51a3
 
5ab1e4f
e04414b
88b51a3
9c4e981
5ab1e4f
 
 
 
88b51a3
 
c4d1642
3abb547
449333b
5ab1e4f
c4d1642
470fd47
c4d1642
5ab1e4f
449333b
 
c4d1642
30af36c
c4d1642
 
3abb547
5ab1e4f
 
 
 
3abb547
5ab1e4f
449333b
 
5ab1e4f
3abb547
5ab1e4f
 
abc07a2
c4d1642
 
ea9851f
daf4564
ea9851f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
from pathlib import Path
from typing import Literal

from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.prompts import RichPromptTemplate
from llama_index.llms.nebius import NebiusLLM
from llama_index.llms.mistralai import MistralAI
from llama_index.llms.openai import OpenAI
from workflows import Workflow, step, Context
from workflows.events import StartEvent, Event, StopEvent

from gaia_solving_agent import NEBIUS_API_KEY, MISTRAL_API_KEY, OPENAI_API_KEY
from gaia_solving_agent.prompts import PLANING_PROMPT, FORMAT_ANSWER
from gaia_solving_agent.tools import (
    tavily_search_web,
    simple_web_page_reader_toolspec,
    vllm_ask_image_tool,
    youtube_transcript_reader_toolspec,
    text_content_analysis,
    research_paper_reader_toolspec,
    get_text_representation_of_additional_file,
    wikipedia_toolspec,
)
from gaia_solving_agent.utils import extract_pattern

# Choice of the model
# Nebius-hosted model identifiers, ordered roughly by cost/capability.
# They are consumed by get_llm() below; only some are referenced in this file.
cheap_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
light_model_name = "Qwen/Qwen2.5-32B-Instruct"
balanced_model_name = "meta-llama/Meta-Llama-3.1-70B-Instruct"
reasoning_model_name = "Qwen/Qwen3-235B-A22B"
vlm_model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # For VLM needs

# Module-level LLM clients shared by the workflow steps.
# openai_reasoning is used by GaiaWorkflow.make_plan for plan generation.
openai_reasoning = OpenAI(
    model="gpt-4.1",
    api_key=OPENAI_API_KEY,
    temperature=.1,
    max_retries=5,
)
# Smaller OpenAI client; not referenced in this file — presumably used by
# other modules importing it from here. TODO(review): confirm it is still needed.
openai_llm = OpenAI(
    model="gpt-4.1-nano",
    api_key=OPENAI_API_KEY,
    temperature=.1,
    max_retries=5,
)
# Mistral client; also unreferenced here — kept for external importers.
mistral_llm = MistralAI(
    model="mistral-small-latest",
    api_key=MISTRAL_API_KEY,
    temperature=.1,
    max_retries=5,
    # is_function_calling_model=True,
)


def get_llm(model_name=cheap_model_name):
    """Build a Nebius-hosted LLM client configured for function calling.

    Defaults to the cheap Llama-3.1-8B model when no model name is given.
    """
    nebius_kwargs = {
        "model": model_name,
        "api_key": NEBIUS_API_KEY,
        "is_function_calling_model": True,
        "max_completion_tokens": 10000,
        # max = 128000 for "meta-llama/Meta-Llama-3.1-8B-Instruct"
        "context_window": 80000,
        "temperature": 0.1,
        "max_retries": 5,
    }
    return NebiusLLM(**nebius_kwargs)


class PlanEvent(Event):
    """Event driving GaiaWorkflow.make_plan (initial plan, re-format, or replan)."""

    # What make_plan should do with this event; "Replan" is currently a no-op placeholder.
    to_do: Literal["Initialize", "Format", "Replan"] = "Initialize"
    # Previous (mis-formatted) plan text, only set when to_do == "Format".
    plan: str | None = None
    # Number of re-formatting attempts so far; make_plan aborts past a limit.
    n_retries: int = 0


class QueryEvent(Event):
    """Signals that a valid plan is stored in the context and execution can start."""
    pass


class AnswerEvent(Event):
    """Carries the agent's raw answer plus the plan it followed, for final formatting."""

    # The plan text produced by make_plan (returned as StopEvent.reasoning).
    plan: str
    # Raw answer produced by the tool-using agent, before format normalization.
    answer: str


class GaiaWorkflow(Workflow):
    """Three-phase workflow for solving a GAIA task.

    1. ``make_plan`` asks a reasoning LLM for a structured plan
       (question / known facts / sub-tasks), retrying when the output
       does not match the expected tag format.
    2. ``multi_agent_process`` hands the plan to the tool-using
       ``gaia_solving_agent``.
    3. ``parse_answer`` normalizes the agent's answer into the GAIA
       submission format.
    """

    @step
    async def setup(self, ctx: Context, ev: StartEvent) -> PlanEvent:
        """Stash the user request and optional attachment in the context store."""
        await ctx.store.set("user_msg", ev.user_msg)
        await ctx.store.set("additional_file", ev.additional_file)
        await ctx.store.set("additional_file_path", ev.additional_file_path)
        return PlanEvent()

    @step
    async def make_plan(self, ctx: Context, ev: PlanEvent) -> PlanEvent | QueryEvent | StopEvent:
        """Generate a structured plan, or re-format a previously malformed one.

        Returns a QueryEvent when the plan parses, a PlanEvent(to_do="Format")
        to retry when it does not, or a StopEvent after too many retries.
        """
        additional_file_path = await ctx.store.get("additional_file_path")
        user_msg = await ctx.store.get("user_msg")

        llm = openai_reasoning
        prompt_template = RichPromptTemplate(PLANING_PROMPT)
        # The planner only needs the attachment's extension to pick tools.
        file_extension = Path(additional_file_path).suffix if additional_file_path else ""
        prompt = prompt_template.format(
            user_request=user_msg,
            additional_file_extension=file_extension,
        )

        if ev.to_do == "Replan":
            ...
            # TODO : Placeholder for future update
        elif ev.to_do == "Format":
            # Give up after a few attempts instead of looping forever.
            if ev.n_retries > 3:
                return StopEvent(result="Cannot provide a plan. Format may be wrong.", reasoning=ev.plan)
            prompt = f"""
The original plan is not in the correct format.

______________
There is the query and constraints you must respect :
{prompt}

______________
There is the original plan you must reformat :
{ev.plan}

______________
Ask yourself what you did wrong and fix it.
Stick strictly to the formatting constraints !
"""

        # Use the async completion API so the event loop is not blocked.
        plan = await llm.acomplete(prompt)
        await ctx.store.set("plan", plan.text)

        # The plan must contain all three tagged sections to be usable.
        question = extract_pattern(pattern=r"<Question> :\s*([\s\S]*?)\s*</Question>", text=plan.text)
        known_facts = extract_pattern(pattern=r"<Known facts> :\s*([\s\S]*?)\s*</Known facts>", text=plan.text)
        sub_tasks = extract_pattern(pattern=r"<Sub-tasks> :\s*([\s\S]*?)\s*<\/Sub-tasks>", text=plan.text)
        if any(
            extracted is None
            for extracted in [question, known_facts, sub_tasks]
        ):
            return PlanEvent(to_do="Format", plan=plan.text, n_retries=ev.n_retries + 1)

        # All three sections are guaranteed non-None past the guard above.
        await ctx.store.set("question", question)
        await ctx.store.set("known_facts", known_facts)
        await ctx.store.set("sub_tasks", sub_tasks)
        return QueryEvent()

    @step
    async def multi_agent_process(self, ctx: Context, ev: QueryEvent) -> AnswerEvent:
        """Run the tool-using agent on the parsed plan and collect its answer."""
        plan = await ctx.store.get("plan")
        additional_file = await ctx.store.get("additional_file")

        question = await ctx.store.get("question")
        known_facts = await ctx.store.get("known_facts")
        sub_tasks = await ctx.store.get("sub_tasks")
        prompt = f"""
The question is : {question}

The known facts are :
{known_facts}

The sub-tasks are :
{sub_tasks}
"""

        # Cheap trick to avoid Error 400 errors from OpenAPI
        # (a fresh, generously-sized memory buffer per run).
        from llama_index.core.memory import ChatMemoryBuffer
        memory = ChatMemoryBuffer.from_defaults(token_limit=100000)

        # Fresh agent context per run, carrying the attachment for the
        # agent's file tools (see system prompt of gaia_solving_agent).
        agent_ctx = Context(gaia_solving_agent)
        await agent_ctx.store.set("additional_file", additional_file)
        agent_output = await gaia_solving_agent.run(
            user_msg=prompt,
            memory=memory,
            ctx=agent_ctx,
        )
        return AnswerEvent(plan=plan, answer=str(agent_output))

    @step
    async def parse_answer(self, ctx: Context, ev: AnswerEvent) -> StopEvent:
        """Reformat the raw agent answer into the expected GAIA answer format."""
        llm = get_llm(balanced_model_name)
        prompt_template = RichPromptTemplate(FORMAT_ANSWER)
        question = await ctx.store.get("question")
        prompt = prompt_template.format(question=question, answer=ev.answer)
        # Async completion keeps the workflow's event loop responsive.
        result = await llm.acomplete(prompt)

        return StopEvent(result=result.text, reasoning=ev.plan)


# Tool-using agent invoked by GaiaWorkflow.multi_agent_process.
# Tools are a mix of plain functions and tool-specs expanded via to_tool_list().
gaia_solving_agent = FunctionAgent(
    tools=[
        get_text_representation_of_additional_file,
        vllm_ask_image_tool,
        tavily_search_web,
        *wikipedia_toolspec,
        *simple_web_page_reader_toolspec.to_tool_list(),
        *youtube_transcript_reader_toolspec.to_tool_list(),
        *research_paper_reader_toolspec.to_tool_list(),
        text_content_analysis,
    ],
    # The strongest (reasoning) Nebius model drives tool selection.
    llm=get_llm(reasoning_model_name),
    system_prompt="""
    You are a helpful assistant that uses tools to browse additional information and resources on the web to answer questions.

    Tools you have are of three types:
    - External resources getter: get text, images, video, etc. from the internet
    - Resource querier and transformer: query, summarize or transform a resource into a more digestible format.
    - Analyse or compute : specialized tools to provide a specific analysis or computation.

    Try to get resources before querying them.
    If it is an additional file, you can access its content through the get_text_representation_of_additional_file tool.
    If you need the original Document, you can use the llamaindex context with ctx.store.get("additional_file").
    If the analysis require a new external resource get it first.(e.g. a set of rules or a process)

    You will be provided a question, some known facts summarizing the user provided context and some sub-tasks to complete.
    You should follow the order of the sub-tasks.
    If the tools provides facts that go against your knowledge, you should not use them.
    """,
    name="gaia_solving_agent",
    description="Agent that browse additional information and resources on the web.",
    # Sequential tool calls keep the run deterministic and easier to trace.
    allow_parallel_tool_calls=False,
)