from typing import Any, Callable, List, Literal import yaml from langchain.agents.agent import AgentExecutor from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage from langchain_groq import ChatGroq from langchain_openai import ChatOpenAI from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.prebuilt import create_react_agent from tools import ( add, arxiv_search, create_handoff_tool, div, internet_search, mod, mult, retriever_tool, sub, wiki_search, ) from utils import pretty_print_messages def load_prompt(name: str) -> str: with open("prompts.yaml", "r") as f: prompts = yaml.safe_load(f) return prompts[name] def create_llm( model: Literal["groq", "openai", "openai-nano"] = "openai", ) -> BaseChatModel: match (model): case "groq": return ChatGroq(model="qwen-qwq-32b", temperature=0) case "openai": return ChatOpenAI(model="gpt-4.1", temperature=0) case "openai-nano": return ChatOpenAI(model="gpt-4.1-nano", temperature=0) def create_agent( llm: BaseChatModel, tools: List[Any], prompt_name: str, name: str ) -> AgentExecutor: return create_react_agent( model=llm, tools=tools, prompt=load_prompt(prompt_name), name=name ) def create_supervisor_agent(llm: BaseChatModel) -> AgentExecutor: assign_to_retriever_agent = create_handoff_tool( agent_name="retriever_agent", description="Assign task to a retriever agent for searching through documents.", ) assign_to_research_agent = create_handoff_tool( agent_name="research_agent", description="Assign task to a researcher agent.", ) assign_to_math_agent = create_handoff_tool( agent_name="math_agent", description="Assign task to a math agent.", ) return create_agent( llm=llm, tools=[ assign_to_retriever_agent, assign_to_research_agent, assign_to_math_agent, ], prompt_name="supervisor_prompt", name="supervisor", ) def create_workflow() -> Callable: llm = create_llm() retriever_agent = create_agent( llm=create_llm("openai-nano"), tools=[retriever_tool], prompt_name="retriever_prompt", name="retriever_agent", ) research_agent = create_agent( llm=llm, tools=[internet_search, wiki_search, arxiv_search], prompt_name="web_research_prompt", name="research_agent", ) math_agent = create_agent( llm=llm, tools=[add, sub, mult, div, mod], prompt_name="math_prompt", name="math_agent", ) supervisor_agent = create_supervisor_agent(llm) workflow = StateGraph(MessagesState) workflow.add_node( supervisor_agent, destinations=("retriever_agent", "research_agent", "math_agent", END), ) workflow.add_node(retriever_agent) workflow.add_node(research_agent) workflow.add_node(math_agent) workflow.add_edge(START, "supervisor") workflow.add_edge("retriever_agent", "supervisor") workflow.add_edge("research_agent", "supervisor") workflow.add_edge("math_agent", "supervisor") return workflow.compile() class BasicAgent: def __init__(self) -> None: print("BasicAgent initialized.") self.graph = create_workflow() def __call__(self, question: str) -> str: print(f"Agent received question (first 50 chars): {question[:50]}...") initial_messages = [HumanMessage(content=question)] final_messages = None for chunk in self.graph.stream({"messages": initial_messages}): pretty_print_messages(chunk) final_messages = chunk if final_messages is None: raise RuntimeError("No messages were generated during processing") return final_messages["supervisor"]["messages"][-1].content