from typing import Any, Callable, List, Literal
import yaml
from langchain.agents.agent import AgentExecutor
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import create_react_agent
from tools import (
add,
arxiv_search,
create_handoff_tool,
div,
internet_search,
mod,
mult,
retriever_tool,
sub,
wiki_search,
)
from utils import pretty_print_messages
def load_prompt(name: str) -> str:
    """Load a named prompt template from prompts.yaml.

    Args:
        name: Key of the prompt entry in prompts.yaml.

    Returns:
        The prompt text stored under ``name``.

    Raises:
        FileNotFoundError: If prompts.yaml does not exist.
        KeyError: If ``name`` is not a key in the file.
    """
    # Explicit encoding avoids the platform-dependent default (e.g. cp1252
    # on Windows) mangling non-ASCII characters in prompt text.
    with open("prompts.yaml", "r", encoding="utf-8") as f:
        prompts = yaml.safe_load(f)
    return prompts[name]
def create_llm(
model: Literal["groq", "openai", "openai-nano"] = "openai",
) -> BaseChatModel:
match (model):
case "groq":
return ChatGroq(model="qwen-qwq-32b", temperature=0)
case "openai":
return ChatOpenAI(model="gpt-4.1", temperature=0)
case "openai-nano":
return ChatOpenAI(model="gpt-4.1-nano", temperature=0)
def create_agent(
    llm: BaseChatModel, tools: List[Any], prompt_name: str, name: str
) -> AgentExecutor:
    """Construct a named ReAct agent from a model, a toolset, and a stored prompt."""
    system_prompt = load_prompt(prompt_name)
    agent = create_react_agent(
        model=llm,
        tools=tools,
        prompt=system_prompt,
        name=name,
    )
    return agent
def create_supervisor_agent(llm: BaseChatModel) -> AgentExecutor:
    """Build the supervisor agent that delegates work to the specialist agents."""
    # (agent name, handoff-tool description) for every specialist the
    # supervisor is allowed to hand a task to.
    handoff_specs = [
        (
            "retriever_agent",
            "Assign task to a retriever agent for searching through documents.",
        ),
        ("research_agent", "Assign task to a researcher agent."),
        ("math_agent", "Assign task to a math agent."),
    ]
    handoff_tools = [
        create_handoff_tool(agent_name=agent_name, description=description)
        for agent_name, description in handoff_specs
    ]
    return create_agent(
        llm=llm,
        tools=handoff_tools,
        prompt_name="supervisor_prompt",
        name="supervisor",
    )
def create_workflow() -> Callable:
    """Assemble and compile the supervisor/specialist agent graph.

    The supervisor fans work out to three specialist agents; every specialist
    reports back to the supervisor, which alone may end the run.
    """
    default_llm = create_llm()
    specialists = [
        create_agent(
            llm=create_llm("openai-nano"),
            tools=[retriever_tool],
            prompt_name="retriever_prompt",
            name="retriever_agent",
        ),
        create_agent(
            llm=default_llm,
            tools=[internet_search, wiki_search, arxiv_search],
            prompt_name="web_research_prompt",
            name="research_agent",
        ),
        create_agent(
            llm=default_llm,
            tools=[add, sub, mult, div, mod],
            prompt_name="math_prompt",
            name="math_agent",
        ),
    ]
    supervisor = create_supervisor_agent(default_llm)

    graph = StateGraph(MessagesState)
    # The supervisor may hand off to any specialist or finish the run.
    graph.add_node(
        supervisor,
        destinations=("retriever_agent", "research_agent", "math_agent", END),
    )
    for specialist in specialists:
        graph.add_node(specialist)
    # Every run starts at the supervisor, and every specialist returns to it.
    graph.add_edge(START, "supervisor")
    for specialist_name in ("retriever_agent", "research_agent", "math_agent"):
        graph.add_edge(specialist_name, "supervisor")
    return graph.compile()
class BasicAgent:
    """Thin callable wrapper around the compiled multi-agent workflow."""

    def __init__(self) -> None:
        print("BasicAgent initialized.")
        # Compiled langgraph state graph; built once and reused per question.
        self.graph = create_workflow()

    def __call__(self, question: str) -> str:
        """Run the workflow on ``question`` and return the supervisor's answer.

        Args:
            question: The user's question, sent as a single human message.

        Returns:
            The content of the supervisor's last message.

        Raises:
            RuntimeError: If the stream produced no supervisor output.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        initial_messages = [HumanMessage(content=question)]
        # Keep the most recent chunk that actually came from the supervisor:
        # the final answer lives there, and blindly taking the very last chunk
        # would raise KeyError if another node happened to emit last.
        supervisor_update = None
        for chunk in self.graph.stream({"messages": initial_messages}):
            pretty_print_messages(chunk)
            if "supervisor" in chunk:
                supervisor_update = chunk
        if supervisor_update is None:
            raise RuntimeError("No messages were generated during processing")
        return supervisor_update["supervisor"]["messages"][-1].content