File size: 4,023 Bytes
c908185
379f247
c908185
 
 
379f247
c908185
 
379f247
 
 
c908185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f77561
 
 
 
 
 
 
 
 
 
379f247
 
c908185
 
 
 
 
 
379f247
 
c908185
4c94f7a
 
 
 
 
c908185
 
 
 
379f247
c908185
 
 
 
379f247
c908185
 
4c94f7a
 
 
 
 
c908185
 
379f247
 
 
c908185
 
379f247
4c94f7a
3f77561
4c94f7a
 
 
 
 
c908185
 
4c94f7a
c908185
 
 
379f247
c908185
 
 
 
 
379f247
 
c908185
379f247
c908185
379f247
 
4c94f7a
 
379f247
4c94f7a
c908185
 
379f247
4c94f7a
c908185
 
379f247
 
 
 
 
 
 
c908185
379f247
 
c908185
 
 
 
 
 
 
 
379f247
c908185
 
379f247
c908185
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from typing import Any, Callable, List, Literal

import yaml
from langchain.agents.agent import AgentExecutor
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import create_react_agent

from tools import (
    add,
    arxiv_search,
    create_handoff_tool,
    div,
    internet_search,
    mod,
    mult,
    retriever_tool,
    sub,
    wiki_search,
)
from utils import pretty_print_messages


def load_prompt(name: str, path: str = "prompts.yaml") -> str:
    """Return the prompt stored under ``name`` in a YAML prompt file.

    Args:
        name: Top-level key of the prompt in the YAML mapping.
        path: YAML file to read. Defaults to ``"prompts.yaml"``, which is
            resolved relative to the current working directory.

    Returns:
        The value stored under ``name``.

    Raises:
        KeyError: If the file does not define ``name`` (the message names the
            missing key and the file that was read).
    """
    # Explicit UTF-8 so prompt text decodes identically regardless of locale.
    with open(path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no prompts".
        prompts = yaml.safe_load(f) or {}
    try:
        return prompts[name]
    except KeyError:
        raise KeyError(f"Prompt {name!r} not found in {path}") from None


def create_llm(
    model: Literal["groq", "openai", "openai-nano"] = "openai",
) -> BaseChatModel:
    match (model):
        case "groq":
            return ChatGroq(model="qwen-qwq-32b", temperature=0)
        case "openai":
            return ChatOpenAI(model="gpt-4.1", temperature=0)
        case "openai-nano":
            return ChatOpenAI(model="gpt-4.1-nano", temperature=0)


def create_agent(
    llm: BaseChatModel, tools: List[Any], prompt_name: str, name: str
) -> CompiledStateGraph:
    """Build a named ReAct agent from a chat model, its tools and a stored prompt.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call.
        prompt_name: Key of the agent's prompt, loaded via ``load_prompt``.
        name: Agent name. When the agent is added to a parent ``StateGraph``
            with ``add_node(agent)``, this name is used as its node name.

    Returns:
        The compiled agent graph produced by ``create_react_agent``. The
        previous ``AgentExecutor`` annotation did not match the returned type.
    """
    return create_react_agent(
        model=llm, tools=tools, prompt=load_prompt(prompt_name), name=name
    )


def create_supervisor_agent(llm: BaseChatModel) -> CompiledStateGraph:
    """Build the supervisor agent whose only tools are handoffs to the workers.

    Each ``agent_name`` below mirrors a worker node name used in
    ``create_workflow`` (``retriever_agent``, ``research_agent``,
    ``math_agent``). Keep them in sync when adding or renaming workers.

    Args:
        llm: Chat model that drives the supervisor's routing decisions.

    Returns:
        The compiled ReAct agent graph named ``"supervisor"``.
    """
    # (worker node name, tool description shown to the supervisor model)
    handoff_specs = (
        (
            "retriever_agent",
            "Assign task to a retriever agent for searching through documents.",
        ),
        ("research_agent", "Assign task to a researcher agent."),
        ("math_agent", "Assign task to a math agent."),
    )
    handoff_tools = [
        create_handoff_tool(agent_name=agent_name, description=description)
        for agent_name, description in handoff_specs
    ]

    return create_agent(
        llm=llm,
        tools=handoff_tools,
        prompt_name="supervisor_prompt",
        name="supervisor",
    )


def create_workflow() -> CompiledStateGraph:
    """Assemble and compile the supervisor/worker multi-agent graph.

    Topology: START -> supervisor. The supervisor may route to any worker
    or to END, and every worker routes back to the supervisor.

    Returns:
        The compiled graph. The previous ``Callable`` annotation was
        inaccurate: ``StateGraph.compile()`` returns a compiled graph object.
    """
    # Shared default model for the supervisor, research and math agents.
    llm = create_llm()

    # The retriever runs on the smaller nano model.
    retriever_agent = create_agent(
        llm=create_llm("openai-nano"),
        tools=[retriever_tool],
        prompt_name="retriever_prompt",
        name="retriever_agent",
    )

    research_agent = create_agent(
        llm=llm,
        tools=[internet_search, wiki_search, arxiv_search],
        prompt_name="web_research_prompt",
        name="research_agent",
    )

    math_agent = create_agent(
        llm=llm,
        tools=[add, sub, mult, div, mod],
        prompt_name="math_prompt",
        name="math_agent",
    )

    supervisor_agent = create_supervisor_agent(llm)

    workflow = StateGraph(MessagesState)

    # Each agent's node name is the ``name`` it was created with. The
    # supervisor declares its possible routing targets via ``destinations``.
    workflow.add_node(
        supervisor_agent,
        destinations=("retriever_agent", "research_agent", "math_agent", END),
    )
    workflow.add_node(retriever_agent)
    workflow.add_node(research_agent)
    workflow.add_node(math_agent)

    workflow.add_edge(START, "supervisor")
    # Workers always report back to the supervisor.
    workflow.add_edge("retriever_agent", "supervisor")
    workflow.add_edge("research_agent", "supervisor")
    workflow.add_edge("math_agent", "supervisor")

    return workflow.compile()


class BasicAgent:
    """Callable question-answering agent backed by the supervisor workflow."""

    def __init__(self) -> None:
        print("BasicAgent initialized.")
        # Build the graph once; every call reuses the compiled workflow.
        self.graph = create_workflow()

    def __call__(self, question: str) -> str:
        """Run ``question`` through the workflow and return the final answer text.

        Each streamed update chunk is pretty-printed as it arrives.

        Args:
            question: The user question, sent as a single human message.

        Returns:
            Content of the last message in the final update chunk.

        Raises:
            RuntimeError: If the graph streamed no chunks, or its final chunk
                carried no messages.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        initial_messages = [HumanMessage(content=question)]
        # Holds the most recent stream chunk: a mapping of node name -> update.
        last_chunk = None

        for chunk in self.graph.stream({"messages": initial_messages}):
            pretty_print_messages(chunk)
            last_chunk = chunk

        if last_chunk is None:
            raise RuntimeError("No messages were generated during processing")

        return self._final_content(last_chunk)

    @staticmethod
    def _final_content(chunk: dict[str, Any]) -> str:
        """Return the content of the last message in a stream update chunk.

        ``chunk`` maps node names to that node's state update. The supervisor's
        update is preferred because every worker edges back to it, so it is the
        expected terminal node. Other nodes' updates are a fallback, so an
        unexpected terminal update does not surface as a bare ``KeyError``.

        Raises:
            RuntimeError: If no update in ``chunk`` contains messages.
        """
        candidates = [chunk.get("supervisor")] + [
            update for node, update in chunk.items() if node != "supervisor"
        ]
        for update in candidates:
            # Guard against a None update or one without a non-empty message list.
            if update and update.get("messages"):
                return update["messages"][-1].content
        raise RuntimeError("Final graph update contained no messages")