GodSpeed / agent /graph.py
AdithyaVardan's picture
feat: add confluence/slack search tools, chat history, cloud Qdrant support, sync trigger fixes
68af3c5
"""LangGraph graph definition — nodes, edges, and parallel execution."""
from __future__ import annotations

import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Any

from langgraph.graph import END, StateGraph

from agent.agents.guardrail import run_guardrail
from agent.agents.planner import run_planner
from agent.agents.synthesiser import stream_synthesis
from agent.models import AgentResult, KnowledgeGraphState, RetrievedChunk
from agent.tools.confluence_search import run_confluence_search
from agent.tools.doc_search import compute_retrieval_confidence, run_doc_search
from agent.tools.live_docs import run_live_docs
from agent.tools.slack_search import run_slack_search
from agent.tools.sql_query import run_sql_query
from agent.tools.ticket_lookup import run_ticket_lookup
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
async def _push_event(queue: asyncio.Queue, event: str, data: Any) -> None:
    """Put an ``{"event": ..., "data": ...}`` message on ``queue``.

    Does nothing when ``queue`` is ``None``, so callers can pass the state's
    queue through without checking it first.

    Args:
        queue: Destination queue, or ``None`` to skip emitting.
        event: Event name stored under the ``"event"`` key.
        data: Payload stored under the ``"data"`` key.
    """
    if queue is None:
        return
    message = {"event": event, "data": data}
    await queue.put(message)
async def planner_node(state: KnowledgeGraphState) -> dict:
    """Run the planner on the query input and publish the resulting plan.

    Emits a ``plan_ready`` event carrying the dumped tasks and the planner's
    reasoning, then returns the plan as the ``execution_plan`` state update.
    """
    sse_queue = state.sse_queue
    execution_plan = await run_planner(state.query_input)

    payload = {
        "tasks": [task.model_dump() for task in execution_plan.tasks],
        "reasoning": execution_plan.reasoning,
    }
    await _push_event(sse_queue, "plan_ready", payload)

    return {"execution_plan": execution_plan}
async def _run_retrieval_node(
    state: KnowledgeGraphState,
    agent: str,
    search: Callable[[str], Awaitable[list[RetrievedChunk]]],
) -> dict:
    """Run one retrieval tool and record its outcome under ``agent``.

    Shared body of every retrieval node: emits ``agent_started``, resolves the
    task input from the execution plan (falling back to the raw user query),
    awaits ``search``, wraps the outcome in an ``AgentResult`` and emits
    ``agent_done`` with the computed retrieval confidence.

    Args:
        state: Current graph state.
        agent: Agent name; used as the plan lookup key, as the key in
            ``agent_results`` and as the ``agent`` field of emitted events.
        search: Callable taking the task input and returning an awaitable
            that yields the retrieved chunks.

    Returns:
        A state update whose ``agent_results`` is the existing mapping plus
        this agent's result.
    """
    queue = state.sse_queue
    await _push_event(queue, "agent_started", {"agent": agent})

    task_input = _find_task_input(state, agent) or state.query_input.query
    chunks: list[RetrievedChunk] = []
    error: str | None = None
    try:
        chunks = await search(task_input)
    except Exception as exc:
        # Best-effort: a failing tool is recorded on its result instead of
        # aborting the whole graph run.
        logger.exception("%s_node error", agent)
        # Some exceptions (e.g. a bare TimeoutError) stringify to ""; fall
        # back to the class name so a failure never looks like "no error".
        error = str(exc) or type(exc).__name__

    confidence = compute_retrieval_confidence(chunks)
    result = AgentResult(
        agent=agent,
        chunks=chunks,
        retrieval_confidence=confidence,
        error=error,
    )
    await _push_event(
        queue,
        "agent_done",
        {"agent": agent, "retrieval_confidence": confidence},
    )
    # NOTE(review): every retrieval node returns a full copy of agent_results
    # plus its own entry; when several run in the same step this presumably
    # relies on a merging reducer declared on KnowledgeGraphState — confirm.
    return {"agent_results": {**state.agent_results, agent: result}}


async def doc_search_node(state: KnowledgeGraphState) -> dict:
    """Run ``run_doc_search`` with the team id and allowed channel ids.

    An empty ``allowed_channel_ids`` is passed to the tool as ``None``.
    """
    return await _run_retrieval_node(
        state,
        "doc_search",
        lambda task_input: run_doc_search(
            task_input,
            state.query_input.team_id,
            state.query_input.allowed_channel_ids or None,
        ),
    )


async def ticket_lookup_node(state: KnowledgeGraphState) -> dict:
    """Run ``run_ticket_lookup`` for the team and record its result."""
    return await _run_retrieval_node(
        state,
        "ticket_lookup",
        lambda task_input: run_ticket_lookup(task_input, state.query_input.team_id),
    )


async def confluence_search_node(state: KnowledgeGraphState) -> dict:
    """Run ``run_confluence_search`` for the team and record its result."""
    return await _run_retrieval_node(
        state,
        "confluence_search",
        lambda task_input: run_confluence_search(task_input, state.query_input.team_id),
    )


async def slack_search_node(state: KnowledgeGraphState) -> dict:
    """Run ``run_slack_search`` for the team and record its result."""
    return await _run_retrieval_node(
        state,
        "slack_search",
        lambda task_input: run_slack_search(task_input, state.query_input.team_id),
    )


async def live_docs_node(state: KnowledgeGraphState) -> dict:
    """Run ``run_live_docs`` for the team and record its result."""
    return await _run_retrieval_node(
        state,
        "live_docs",
        lambda task_input: run_live_docs(task_input, state.query_input.team_id),
    )


async def sql_query_node(state: KnowledgeGraphState) -> dict:
    """Run ``run_sql_query`` for the team and record its result."""
    return await _run_retrieval_node(
        state,
        "sql_query",
        lambda task_input: run_sql_query(task_input, state.query_input.team_id),
    )
async def synthesiser_node(state: KnowledgeGraphState) -> dict:
    """Stream the synthesised answer and collect de-duplicated citations.

    Emits ``synthesis_started``, one ``answer_chunk`` event per streamed
    token, and finally a ``citations`` event. Chunks are de-duplicated by
    ``chunk_id`` across all agent results, keeping the first occurrence.

    Returns:
        State update with the joined ``final_answer`` and the ``citations``.
    """
    sse_queue = state.sse_queue
    await _push_event(sse_queue, "synthesis_started", {})

    tokens: list[str] = []
    async for token in stream_synthesis(state.query_input.query, state.agent_results):
        tokens.append(token)
        await _push_event(sse_queue, "answer_chunk", {"chunk": token})
    final_answer = "".join(tokens)

    # Dicts preserve insertion order and setdefault keeps the first chunk seen
    # for each id, so citation order follows first appearance.
    unique_chunks: dict[str, RetrievedChunk] = {}
    for agent_result in state.agent_results.values():
        for chunk in agent_result.chunks:
            unique_chunks.setdefault(chunk.chunk_id, chunk)
    citations = list(unique_chunks.values())

    await _push_event(
        sse_queue,
        "citations",
        {"chunks": [chunk.model_dump() for chunk in citations]},
    )
    return {"final_answer": final_answer, "citations": citations}
async def join_node(state: KnowledgeGraphState) -> dict:
    """Convergence point between the retrieval nodes and the synthesiser.

    Emits an ``agent_started`` event for the synthesiser and makes no change
    to the state.
    """
    sse_queue = state.sse_queue
    await _push_event(sse_queue, "agent_started", {"agent": "synthesiser"})
    no_update: dict = {}
    return no_update
async def guardrail_node(state: KnowledgeGraphState) -> dict:
    """Run the guardrail on the final answer and publish its verdict.

    A missing ``final_answer`` is passed to the guardrail as an empty string.
    Emits ``guardrail_result`` with the score and escalation flag.

    Returns:
        State update with ``guardrail_passed`` (the negation of the
        escalation flag), ``guardrail_score`` and ``escalate``.
    """
    sse_queue = state.sse_queue
    answer = state.final_answer or ""
    score, escalate = await run_guardrail(answer, state.citations)

    verdict = {"score": score, "escalate": escalate}
    await _push_event(sse_queue, "guardrail_result", verdict)

    return {
        "guardrail_passed": not escalate,
        "guardrail_score": score,
        "escalate": escalate,
    }
def _find_task_input(state: KnowledgeGraphState, agent: str) -> str | None:
    """Return the input of the first planned task assigned to ``agent``.

    Returns ``None`` when there is no execution plan or no task for the agent.
    """
    plan = state.execution_plan
    if plan is None:
        return None
    matching_inputs = (task.input for task in plan.tasks if task.agent == agent)
    return next(matching_inputs, None)
def _plan_includes(state: KnowledgeGraphState, agent: str) -> bool:
    """Tell whether the execution plan contains a task for ``agent``.

    Returns ``False`` when no execution plan is set.
    """
    plan = state.execution_plan
    if plan is None:
        return False
    for task in plan.tasks:
        if task.agent == agent:
            return True
    return False
def _route_after_planner(state: KnowledgeGraphState) -> list[str]:
    """Choose the nodes to run right after the planner.

    Every planned task without dependencies maps to ``"<agent>_node"``.
    Targets are de-duplicated in first-seen order, so a plan with several
    tasks for the same agent schedules that node once. Falls back to the
    synthesiser when there is no plan or no dependency-free task.

    Args:
        state: Graph state holding the (possibly missing) execution plan.

    Returns:
        Node names to fan out to; never empty.
    """
    plan = state.execution_plan
    if plan is None:
        return ["synthesiser_node"]

    # NOTE(review): tasks with a non-empty depends_on are skipped here and no
    # routing for them is visible in this module — confirm they are meant to
    # be dropped.
    # dict.fromkeys removes duplicates while preserving order.
    immediate = list(
        dict.fromkeys(
            f"{task.agent}_node" for task in plan.tasks if not task.depends_on
        )
    )
    # If nothing is immediate (shouldn't happen), fall back to synthesiser.
    return immediate or ["synthesiser_node"]
def _route_after_guardrail(state: KnowledgeGraphState) -> str:
    """Return ``"escalate"`` when the state's escalate flag is set, else ``END``."""
    if state.escalate:
        return "escalate"
    return END
def build_graph() -> Any:
    """Assemble and compile the graph.

    Layout: planner -> (selected retrieval nodes) -> join -> synthesiser ->
    guardrail -> END. The planner may also route straight to the synthesiser.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    # Retrieval nodes in registration order; the same table drives node
    # registration, the planner's routing map and the edges into join_node.
    retrieval_nodes = {
        "doc_search_node": doc_search_node,
        "ticket_lookup_node": ticket_lookup_node,
        "confluence_search_node": confluence_search_node,
        "slack_search_node": slack_search_node,
        "live_docs_node": live_docs_node,
        "sql_query_node": sql_query_node,
    }

    builder = StateGraph(KnowledgeGraphState)
    builder.add_node("planner_node", planner_node)
    for node_name, node_fn in retrieval_nodes.items():
        builder.add_node(node_name, node_fn)
    builder.add_node("join_node", join_node)
    builder.add_node("synthesiser_node", synthesiser_node)
    builder.add_node("guardrail_node", guardrail_node)
    builder.set_entry_point("planner_node")

    # Each retrieval node routes to itself; both spellings of the final
    # stage ("summariser_node" and "synthesiser_node") reach the synthesiser.
    planner_targets = {node_name: node_name for node_name in retrieval_nodes}
    planner_targets["summariser_node"] = "synthesiser_node"
    planner_targets["synthesiser_node"] = "synthesiser_node"
    builder.add_conditional_edges("planner_node", _route_after_planner, planner_targets)

    # Every retrieval node has its own edge into join_node.
    # NOTE(review): with separate single-source edges, join_node presumably
    # runs once after the retrieval nodes scheduled in the same step have
    # finished, not after all six — confirm against the LangGraph docs.
    for node_name in retrieval_nodes:
        builder.add_edge(node_name, "join_node")
    builder.add_edge("join_node", "synthesiser_node")
    builder.add_edge("synthesiser_node", "guardrail_node")

    # Both guardrail outcomes currently terminate the graph.
    guardrail_targets = {END: END, "escalate": END}
    builder.add_conditional_edges("guardrail_node", _route_after_guardrail, guardrail_targets)

    return builder.compile()
# Compiled graph instance, built once when this module is imported.
graph = build_graph()