Lucas-C-R commited on
Commit
379f247
·
1 Parent(s): c65bd6f

feat: create agents

Browse files
Files changed (2) hide show
  1. services/__init__.py +3 -0
  2. services/agent_services.py +113 -0
services/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from services.agent_services import BasicAgent
2
+
3
+ __all__ = ["BasicAgent"]
services/agent_services.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Literal
3
+
4
+ from langchain_core.messages import HumanMessage
5
+ from langchain_openai.chat_models import ChatOpenAI
6
+ from langgraph.graph import END, START, MessagesState, StateGraph
7
+ from langgraph.prebuilt import create_react_agent
8
+ from langgraph.types import Command
9
+ from typing_extensions import TypedDict
10
+
11
+ from tools import add, arxiv_search, div, internet_search, mod, mult, sub, wiki_search
12
+
13
+ llm = ChatOpenAI(model="gpt-4.1-mini", api_key=os.getenv("OPENAI_API_KEY"))
14
+
15
+
16
+ class State(MessagesState):
17
+ next: str
18
+
19
+
20
+ def create_supervisor_node(
21
+ members: List[str],
22
+ ) -> Command[Literal["math", "web_search", "__end__"]]:
23
+ options = ["FINISH"] + members
24
+
25
+ class Router(TypedDict):
26
+ next: Literal[*options]
27
+
28
+ def supervisor_node(
29
+ state: State,
30
+ ) -> Command[Literal["math", "web_search", "__end__"]]:
31
+ prompt = f"""
32
+ You are a supervisor tasked with managing a conversation between the
33
+ following workers: {members}. Given the following user request, respond with the worker to act next.
34
+ Each worker will perform a task and respond with their results and status. When finished, respond with FINISH.
35
+
36
+ Guidelines:
37
+ - The final answer must be either a number, a single string, or a comma-separated list of numbers or strings.
38
+ - Do not include units (e.g. %, $, km) or commas inside numbers unless explicitly requested.
39
+ - If you use abbreviations in strings, write out the full expression in parentheses the first time the word appears.
40
+ - Write digits in full words only if asked.
41
+ """
42
+
43
+ messages = [{"role": "system", "content": prompt}] + state["messages"]
44
+
45
+ response = llm.with_structured_output(Router).invoke(messages)
46
+
47
+ goto = response["next"]
48
+ if goto == "FINISH":
49
+ goto = END
50
+
51
+ return Command(goto=goto, update={"next": goto})
52
+
53
+ return supervisor_node
54
+
55
+
56
+ def math_node(state: State) -> Command[Literal["supervisor"]]:
57
+ math_agent = create_react_agent(model=llm, tools=[add, sub, mult, div, mod])
58
+
59
+ result = math_agent.invoke(state)
60
+
61
+ return Command(
62
+ update={
63
+ "messages": [
64
+ HumanMessage(content=result["messages"][-1].content, name="math")
65
+ ]
66
+ },
67
+ goto="supervisor",
68
+ )
69
+
70
+
71
+ def web_search_node(state: State) -> Command[Literal["supervisor"]]:
72
+ search_agent = create_react_agent(
73
+ model=llm, tools=[internet_search, wiki_search, arxiv_search]
74
+ )
75
+
76
+ result = search_agent.invoke(state)
77
+
78
+ return Command(
79
+ update={
80
+ "messages": [
81
+ HumanMessage(content=result["messages"][-1].content, name="web_search")
82
+ ]
83
+ },
84
+ goto="supervisor",
85
+ )
86
+
87
+
88
+ def build_worflow() -> StateGraph:
89
+ workflow = StateGraph(State)
90
+
91
+ workflow.add_node(
92
+ "supervisor", create_supervisor_node(members=["math", "web_search"])
93
+ )
94
+ workflow.add_node("math", math_node)
95
+ workflow.add_node("web_search", web_search_node)
96
+
97
+ workflow.add_edge(START, "supervisor")
98
+
99
+ return workflow.compile()
100
+
101
+
102
+ class BasicAgent:
103
+ def __init__(self) -> None:
104
+ print("BasicAgent initialized.")
105
+ self.graph = build_worflow()
106
+
107
+ def __call__(self, question: str) -> str:
108
+ print(f"Agent received the question: {question[:50]}...")
109
+
110
+ messages = [HumanMessage(content=question)]
111
+ messages = self.graph.invoke({"messages": messages})
112
+
113
+ return messages["messages"][-1].content