ChatWB / graph.py
Levin-Aleksey's picture
Initial commit
eecf451
import os
from typing import Annotated, Literal, TypedDict
from pydantic import BaseModel
from langchain_core.messages import BaseMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import create_react_agent
# Импорты из соседних файлов
from api.ads import ads_tools
from api.analytics import analytics_tools
from api.finance import finance_tools
from api.logistics import logistics_tools
from prompts import SUPERVISOR_PROMPT, ADS_ANALYST_PROMPT, ANALYTICS_ANALYST_PROMPT, FINANCE_ANALYST_PROMPT, LOGISTICS_ANALYST_PROMPT
# ==========================================
# 1. СТРУКТУРА СОСТОЯНИЯ (STATE)
# ==========================================
class AgentState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
next_node: str
agent_responded: bool
# ==========================================
# 2. ИНИЦИАЛИЗАЦИЯ LLM (OpenRouter)
# ==========================================
llm = ChatOpenAI(
model="google/gemini-2.0-flash-001",
openai_api_key=os.getenv("OPENROUTER_API_KEY"),
openai_api_base="https://openrouter.ai/api/v1",
temperature=0.2
)
# ==========================================
# 3. УЗЕЛ: АНАЛИТИК РЕКЛАМЫ (Суб-агент)
# ==========================================
# Оставляем ТОЛЬКО модель и инструменты. Никаких модификаторов промпта.
ads_agent_runnable = create_react_agent(
model=llm,
tools=ads_tools
)
async def ads_node(state: AgentState):
# Промпт склеиваем с сообщениями вручную перед вызовом агента
messages_with_prompt = [SystemMessage(content=ADS_ANALYST_PROMPT)] + state["messages"]
# Отдаем агенту готовый массив
result = await ads_agent_runnable.ainvoke({"messages": messages_with_prompt})
return {"messages": [result["messages"][-1]], "agent_responded": True}
# ==========================================
# 4. УЗЕЛ: АНАЛИТИК ПРОДАЖ (Суб-агент)
# ==========================================
analytics_agent_runnable = create_react_agent(
model=llm,
tools=analytics_tools
)
async def analytics_node(state: AgentState):
messages_with_prompt = [SystemMessage(content=ANALYTICS_ANALYST_PROMPT)] + state["messages"]
result = await analytics_agent_runnable.ainvoke({"messages": messages_with_prompt})
return {"messages": [result["messages"][-1]], "agent_responded": True}
# ==========================================
# 5. УЗЕЛ: ФИНАНСОВЫЙ АНАЛИТИК (Суб-агент)
# ==========================================
finance_agent_runnable = create_react_agent(
model=llm,
tools=finance_tools
)
async def finance_node(state: AgentState):
messages_with_prompt = [SystemMessage(content=FINANCE_ANALYST_PROMPT)] + state["messages"]
result = await finance_agent_runnable.ainvoke({"messages": messages_with_prompt})
return {"messages": [result["messages"][-1]], "agent_responded": True}
# ==========================================
# 6. УЗЕЛ: АНАЛИТИК ЛОГИСТИКИ (Суб-агент)
# ==========================================
logistics_agent_runnable = create_react_agent(
model=llm,
tools=logistics_tools
)
async def logistics_node(state: AgentState):
messages_with_prompt = [SystemMessage(content=LOGISTICS_ANALYST_PROMPT)] + state["messages"]
result = await logistics_agent_runnable.ainvoke({"messages": messages_with_prompt})
return {"messages": [result["messages"][-1]], "agent_responded": True}
# ==========================================
# 7. УЗЕЛ: СУПЕРВАЙЗЕР (Оркестратор)
# ==========================================
class RouterOutput(BaseModel):
next_node: Literal["Ads_Analyst", "Analytics_Analyst", "Finance_Analyst", "Logistics_Analyst", "FINISH"]
async def supervisor_node(state: AgentState):
# Если агент уже ответил — детерминированно завершаем, без вызова LLM
if state.get("agent_responded"):
return {"next_node": "FINISH", "agent_responded": False}
messages = [SystemMessage(content=SUPERVISOR_PROMPT)] + state["messages"]
response = await llm.with_structured_output(RouterOutput).ainvoke(messages)
return {"next_node": response.next_node, "agent_responded": False}
# ==========================================
# 8. СБОРКА ГРАФА СОСТОЯНИЙ
# ==========================================
workflow = StateGraph(AgentState)
workflow.add_node("Supervisor", supervisor_node)
workflow.add_node("Ads_Analyst", ads_node)
workflow.add_node("Analytics_Analyst", analytics_node)
workflow.add_node("Finance_Analyst", finance_node)
workflow.add_node("Logistics_Analyst", logistics_node)
def router(state: AgentState) -> str:
if state["next_node"] == "FINISH":
return END
return state["next_node"]
workflow.add_edge(START, "Supervisor")
workflow.add_conditional_edges("Supervisor", router)
workflow.add_edge("Ads_Analyst", "Supervisor")
workflow.add_edge("Analytics_Analyst", "Supervisor")
workflow.add_edge("Finance_Analyst", "Supervisor")
workflow.add_edge("Logistics_Analyst", "Supervisor")
# Граф компилируется в app.py после инициализации checkpointer