LeomordKaly's picture
deploy: phase 3 BYOK backend (Dockerfile.hf, FastAPI on 7860)
2a83c3b verified
"""LangGraph graph compilation and execution."""
from __future__ import annotations
import asyncio
import contextlib
import sys
import time
from typing import TYPE_CHECKING, Any
# psycopg's async driver does not support the Proactor event loop, which is
# the Windows default. Switch to the Selector policy at import time so every
# loop the process creates afterwards (asyncio.run / asyncio.new_event_loop)
# picks it up. The branch is never taken on non-Windows platforms. This must
# run before any other code in this project creates an event loop.
if sys.platform == "win32":
    # Best effort: any error raised while installing the policy is ignored.
    with contextlib.suppress(Exception):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from config.settings import settings
from core.agents.evaluator import evaluate_response
from core.agents.faithfulness import check_faithfulness
from core.agents.guardrails import guardrails_check, guardrails_gate
from core.agents.retriever import grade_documents, retrieve_documents, should_retry
from core.agents.router import rewrite_query, route_query
from core.agents.security import check_security, security_gate
from core.agents.synthesizer import synthesize_answer
from core.state import GraphState
from utils.logging import get_logger
from utils.metrics import record_pipeline_run
from utils.observability import trace_graph_execution
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from ingestion.metadata import UserContext
# Module logger; call sites in this file pass an event name plus keyword fields.
logger = get_logger(__name__)
# Module-level checkpointer cache. ``None`` until a graph is first built; the
# resolver functions below store whichever saver they create here. Despite the
# annotation, the async Postgres / SQLite savers are cached here as well.
_checkpointer: MemorySaver | None = None
def _running_inside_event_loop() -> bool:
"""Return True if we are already inside an active asyncio loop.
Async checkpointers (aiosqlite, psycopg async) bind their connection to
the loop that opened it. Constructing one with ``asyncio.run`` while
another loop is already running raises RuntimeError. We detect that
condition and fall back to MemorySaver so tests / nest_asyncio harnesses
don't fail; production startup paths create the graph from a fresh
synchronous context and get the real persistent saver.
"""
try:
asyncio.get_running_loop()
except RuntimeError:
return False
return True
def _try_async_postgres_saver():
    """Build an ``AsyncPostgresSaver`` backed by a small async connection pool.

    The pool is opened and the saver's ``setup()`` is executed through a
    fresh ``asyncio.run`` call, so this helper is only usable from
    synchronous code.

    Returns:
        The saver on success, or ``None`` if the optional extras are not
        installed, the caller is already inside a running event loop, or
        opening the pool / running ``setup()`` fails.
    """
    if _running_inside_event_loop():
        logger.info("postgres_checkpointer_skipped", reason="inside_running_loop")
        return None

    try:
        from langgraph.checkpoint.postgres.aio import (  # type: ignore[import-not-found]
            AsyncPostgresSaver,
        )
        from psycopg_pool import AsyncConnectionPool  # type: ignore[import-not-found]
    except ImportError:
        logger.warning(
            "postgres_checkpointer_not_available",
            hint="pip install langgraph-checkpoint-postgres 'psycopg[binary,pool]'",
        )
        return None

    async def _open() -> Any:
        pool = AsyncConnectionPool(
            settings.postgres_url,
            min_size=1,
            max_size=5,
            kwargs={"autocommit": True, "prepare_threshold": 0},
            # Opened explicitly on the next line, matching the async resolver
            # in this module instead of relying on open-in-constructor.
            open=False,
        )
        await pool.open()
        try:
            saver = AsyncPostgresSaver(pool)
            await saver.setup()
        except Exception:
            # Do not leave an opened pool behind when setup fails; keep the
            # original error as the one that is reported.
            with contextlib.suppress(Exception):
                await pool.close()
            raise
        return saver

    # Windows event-loop policy is already pinned at module import time,
    # so a fresh ``asyncio.run(_open())`` here gets the selector loop.
    # NOTE(review): asyncio.run() closes its loop when it returns; confirm the
    # pool stays usable from the loop that later executes the graph.
    try:
        saver = asyncio.run(_open())
    except Exception as exc:
        logger.error("postgres_checkpointer_failed", error=str(exc))
        return None

    # Log only the database name. A URL without a path, or a key=value
    # conninfo string, would otherwise put credentials into the log line.
    db_label = settings.postgres_url.rsplit("/", 1)[-1].split("?", 1)[0]
    if "@" in db_label or "=" in db_label:
        db_label = "<redacted>"
    logger.info("postgres_checkpointer_initialized", db=db_label)
    return saver
def _try_async_sqlite_saver():
    """Build an ``AsyncSqliteSaver`` for local persistent checkpointing.

    The database file lives at ``settings.checkpoint_db_path``; its parent
    directory is created when missing. The connection is opened through a
    fresh ``asyncio.run`` call, so this helper is only usable from
    synchronous code.

    Returns:
        The saver on success or ``None`` on any failure (missing deps,
        inside a running loop, I/O error, etc.).
    """
    if _running_inside_event_loop():
        logger.info("sqlite_checkpointer_skipped", reason="inside_running_loop")
        return None

    try:
        import pathlib

        import aiosqlite
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
    except ImportError:
        logger.warning(
            "sqlite_checkpointer_not_available",
            hint="pip install langgraph-checkpoint-sqlite aiosqlite",
        )
        return None

    async def _open(path: str) -> Any:
        conn = await aiosqlite.connect(path, check_same_thread=False)
        try:
            saver = AsyncSqliteSaver(conn)
            await saver.setup()
        except Exception:
            # Close the connection when setup fails so it is not left open;
            # keep the original error as the one that is reported.
            with contextlib.suppress(Exception):
                await conn.close()
            raise
        return saver

    try:
        # Path handling sits inside the try so a filesystem error (for
        # example an unwritable parent directory) yields ``None`` as
        # documented instead of escaping to the caller.
        db_path = pathlib.Path(settings.checkpoint_db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        saver = asyncio.run(_open(str(db_path)))
    except Exception as exc:
        logger.error("sqlite_checkpointer_failed", error=str(exc))
        return None

    logger.info("sqlite_checkpointer_initialized", path=str(db_path))
    return saver
def _get_checkpointer():
    """Return the cached LangGraph checkpointer, creating it on first use.

    When ``settings.use_persistent_checkpointer`` is true the backends are
    tried in this order and the first one that builds wins:

    1. ``AsyncPostgresSaver`` - only attempted when ``settings.postgres_url``
       is set.
    2. ``AsyncSqliteSaver`` against ``settings.checkpoint_db_path``.
    3. ``MemorySaver`` (conversations lost on restart).

    When the setting is false, ``MemorySaver`` is used directly so the graph
    compiles without external dependencies. Both async builders decline to
    run from inside a running event loop, which also ends in ``MemorySaver``.

    Returns:
        Configured checkpointer instance.
    """
    global _checkpointer

    if _checkpointer is not None:
        return _checkpointer

    if settings.use_persistent_checkpointer:
        # Ordered list of persistent backends to attempt.
        factories = []
        if settings.postgres_url:
            factories.append(_try_async_postgres_saver)
        factories.append(_try_async_sqlite_saver)

        for factory in factories:
            candidate = factory()
            if candidate is not None:
                _checkpointer = candidate
                return _checkpointer
        fallback_reason = "all_persistent_paths_failed"
    else:
        # Persistence is opt-in; the default stays in memory.
        fallback_reason = "persistence_opt_in_disabled"

    _checkpointer = MemorySaver()
    logger.info("memory_checkpointer_initialized", reason=fallback_reason)
    return _checkpointer
async def _get_async_checkpointer():
    """Async variant of ``_get_checkpointer``, safe inside a running loop.

    The async ``AsyncPostgresSaver`` / ``AsyncSqliteSaver`` cannot be opened
    via ``asyncio.run()`` from within another loop, so when the pipeline is
    invoked from an already-running loop the saver is opened natively here
    and cached in the module-level ``_checkpointer``.

    A cached persistent saver is returned as-is. A cached ``MemorySaver`` is
    treated as a fallback: when persistence is enabled, the persistent
    backends are attempted again on each call, and the existing
    ``MemorySaver`` is kept (not replaced) if they still fail, so checkpoints
    already held in memory are not discarded.

    Returns:
        The cached or newly created checkpointer.
    """
    global _checkpointer

    if _checkpointer is not None and not isinstance(_checkpointer, MemorySaver):
        return _checkpointer

    if not settings.use_persistent_checkpointer:
        # Reuse the existing in-memory saver; creating a new one per call
        # would drop every checkpoint between pipeline invocations.
        if _checkpointer is None:
            _checkpointer = MemorySaver()
        return _checkpointer

    if settings.postgres_url:
        try:
            from langgraph.checkpoint.postgres.aio import (  # type: ignore[import-not-found]
                AsyncPostgresSaver,
            )
            from psycopg_pool import AsyncConnectionPool  # type: ignore[import-not-found]

            pool = AsyncConnectionPool(
                settings.postgres_url,
                min_size=1,
                max_size=5,
                kwargs={"autocommit": True, "prepare_threshold": 0},
                open=False,
            )
            await pool.open()
            try:
                saver = AsyncPostgresSaver(pool)
                await saver.setup()
            except Exception:
                # This path can run again on the next call, so an opened
                # pool must not be left behind when setup fails.
                with contextlib.suppress(Exception):
                    await pool.close()
                raise
            _checkpointer = saver
            # Log only the database name; a URL without a path or a
            # key=value conninfo string would otherwise expose credentials.
            db_label = settings.postgres_url.rsplit("/", 1)[-1].split("?", 1)[0]
            if "@" in db_label or "=" in db_label:
                db_label = "<redacted>"
            logger.info("postgres_checkpointer_initialized_async", db=db_label)
            return _checkpointer
        except ImportError:
            logger.warning(
                "postgres_checkpointer_not_available",
                hint="pip install langgraph-checkpoint-postgres 'psycopg[binary,pool]'",
            )
        except Exception as exc:
            logger.error("postgres_checkpointer_failed_async", error=str(exc))

    try:
        import pathlib

        import aiosqlite
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

        db_path = pathlib.Path(settings.checkpoint_db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        conn = await aiosqlite.connect(str(db_path), check_same_thread=False)
        try:
            saver = AsyncSqliteSaver(conn)
            await saver.setup()
        except Exception:
            # Same reasoning as the pool above: close what was opened.
            with contextlib.suppress(Exception):
                await conn.close()
            raise
        _checkpointer = saver
        logger.info("sqlite_checkpointer_initialized_async", path=str(db_path))
        return _checkpointer
    except ImportError:
        logger.warning(
            "sqlite_checkpointer_not_available",
            hint="pip install langgraph-checkpoint-sqlite aiosqlite",
        )
    except Exception as exc:
        logger.error("sqlite_checkpointer_failed_async", error=str(exc))

    # Every persistent path failed: fall back to memory, keeping any saver
    # that is already cached.
    if _checkpointer is None:
        _checkpointer = MemorySaver()
    return _checkpointer
async def build_rag_graph_async() -> StateGraph:
    """Compile the RAG workflow using the async checkpointer resolver.

    Counterpart of :func:`build_rag_graph` for callers that are already
    inside an event loop and want the persistent (Postgres / aiosqlite)
    saver rather than the ``MemorySaver`` fallback the sync builder ends up
    with in that situation.

    Returns:
        The compiled workflow with the resolved checkpointer attached.
    """
    workflow = _compose_workflow()
    compiled = workflow.compile(checkpointer=await _get_async_checkpointer())
    node_names = list(workflow.nodes)
    logger.info("rag_graph_compiled_async", nodes=node_names)
    return compiled
def _compose_workflow() -> StateGraph:
    """Assemble the agent graph topology without attaching a checkpointer."""
    graph = StateGraph(GraphState)

    # Node name -> callable, registered in this order.
    node_table = {
        "router": route_query,
        "guardrails": guardrails_check,
        "security": check_security,
        "retriever": retrieve_documents,
        "grader": grade_documents,
        "rewriter": rewrite_query,
        "synthesizer": synthesize_answer,
        "faithfulness": check_faithfulness,
        "evaluator": evaluate_response,
    }
    for node_name, node_fn in node_table.items():
        graph.add_node(node_name, node_fn)

    # Wiring, applied top to bottom. Two-item rows are plain edges;
    # three-item rows are (source, routing function, outcome -> target).
    wiring = (
        (START, "router"),
        ("router", "guardrails"),
        ("guardrails", guardrails_gate, {"proceed": "security", "blocked": END}),
        ("security", security_gate, {"proceed": "retriever", "blocked": END}),
        ("retriever", "grader"),
        ("grader", should_retry, {"rewrite": "rewriter", "generate": "synthesizer"}),
        ("rewriter", "retriever"),
        # Faithfulness sits between synth and evaluator so the evaluator's
        # confidence math can read faithfulness_ratio directly. When the gate
        # is disabled the node is a no-op pass-through.
        ("synthesizer", "faithfulness"),
        ("faithfulness", "evaluator"),
        ("evaluator", END),
    )
    for link in wiring:
        if len(link) == 2:
            graph.add_edge(*link)
        else:
            graph.add_conditional_edges(*link)

    return graph
def build_rag_graph() -> StateGraph:
    """Build and compile the multi-agent RAG workflow graph.

    The topology comes from :func:`_compose_workflow`:

        START -> router -> guardrails -> [proceed: security | blocked: END]
        security -> [proceed: retriever | blocked: END]
        retriever -> grader -> [rewrite: rewriter -> retriever | generate: synthesizer]
        synthesizer -> faithfulness -> evaluator -> END

    Uses the sync checkpointer resolver, which falls back to ``MemorySaver``
    when called from inside a running event loop. Async callers should use
    :func:`build_rag_graph_async` instead so the persistent Postgres /
    aiosqlite saver can be opened natively in the running loop.

    Returns:
        The compiled workflow, ready for invocation.
    """
    workflow = _compose_workflow()
    compiled = workflow.compile(checkpointer=_get_checkpointer())
    node_names = list(workflow.nodes)
    logger.info("rag_graph_compiled", nodes=node_names)
    return compiled
def create_initial_state(
    query: str,
    user_context: UserContext,
    prefer_cloud: bool = False,
    override_provider: str = "",
    persona_style: str = "",
    byok_session_id: str = "",
) -> GraphState:
    """Assemble the initial state dict handed to the graph.

    Args:
        query: The user's natural language query.
        user_context: Authenticated user context for RBAC; stored as the
            result of ``user_context.model_dump()``.
        prefer_cloud: Whether the caller is willing to route LOW/MEDIUM
            sensitivity work to cloud providers. HIGH sensitivity always
            stays local regardless.
        override_provider: Explicit provider override ("ollama" / "groq" /
            "openai" / "anthropic"). Bypasses the sensitivity routing and is
            intended for admin/debug. Empty string means no override.
        persona_style: Stored unchanged under ``"persona_style"``.
        byok_session_id: Stored unchanged under ``"byok_session_id"``.

    Returns:
        GraphState dict ready to pass to graph.invoke() or graph.ainvoke().
    """
    # Values supplied by the caller.
    request_fields = {
        "query": query,
        "user_context": user_context.model_dump(),
        "prefer_cloud": prefer_cloud,
        "override_provider": override_provider,
        "persona_style": persona_style,
        "byok_session_id": byok_session_id,
    }
    # Starting values for everything the pipeline fills in. Built per call so
    # the list / dict defaults are never shared between states.
    pipeline_defaults = {
        "_stream": False,
        "query_type": "",
        "rewritten_query": "",
        "query_sensitivity": "low",
        "guardrails_passed": False,
        "guardrails_reason": "",
        "security_passed": False,
        "security_message": "",
        "documents": [],
        "relevant_documents": [],
        "relevance_ratio": 0.0,
        "retry_count": 0,
        "max_retries": settings.max_retries,
        "generation": "",
        "citations": [],
        "confidence_score": 0.0,
        "synth_provider": "",
        "synth_model": "",
        "synth_usage": {},
        "synth_latency_ms": 0.0,
        "needs_human_review": False,
        "evaluation_notes": "",
        "faithfulness_ratio": 1.0,
        "faithfulness_unsupported": [],
        "audit_trail": [],
    }
    return {**request_fields, **pipeline_defaults}
def _build_timeout_state(
    query: str,
    user_context: UserContext,
    elapsed_ms: float,
    prefer_cloud: bool,
    override_provider: str,
) -> GraphState:
    """Produce the final-state dict for a request that hit the SLO deadline.

    Starts from a fresh initial state so the result has the same keys as a
    normal final state, then fills in a user-facing timeout message, flags
    the result for human review and records a single ``deadline`` entry in
    the audit trail.

    Args:
        query: The user's original query.
        user_context: Authenticated user context for the request.
        elapsed_ms: Wall-clock time spent before the deadline fired.
        prefer_cloud: Forwarded to :func:`create_initial_state`.
        override_provider: Forwarded to :func:`create_initial_state`.

    Returns:
        The populated timeout state.
    """
    timeout_state = create_initial_state(
        query, user_context, prefer_cloud=prefer_cloud, override_provider=override_provider
    )
    deadline_entry = {
        "node": "deadline",
        "action": "timeout",
        "elapsed_ms": elapsed_ms,
        "budget_s": settings.request_timeout_s,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    timeout_state.update(
        generation=(
            "Request exceeded the configured wall-clock budget and was cancelled. "
            "Try a shorter query, disable streaming, or raise SAR_REQUEST_TIMEOUT_S."
        ),
        citations=[],
        confidence_score=0.0,
        needs_human_review=True,
        evaluation_notes="request_timeout",
        audit_trail=[deadline_entry],
    )
    return timeout_state
async def run_rag_pipeline(
    query: str,
    user_context: UserContext,
    thread_id: str = "default",
    prefer_cloud: bool = False,
    override_provider: str = "",
    persona_style: str = "",
    byok_session_id: str = "",
) -> GraphState:
    """Execute the full RAG pipeline and return the final state.

    Builds the graph, creates the initial state and invokes the workflow
    with checkpointing enabled. The invocation is bounded by
    ``settings.request_timeout_s`` when that value is positive: on deadline
    a graceful timeout state with ``needs_human_review=True`` is returned
    rather than blocking indefinitely.

    Args:
        query: The user's natural language query.
        user_context: Authenticated user context for RBAC filtering.
        thread_id: Thread identifier for checkpointing/session tracking.
        prefer_cloud: Forwarded to :func:`create_initial_state`.
        override_provider: Forwarded to :func:`create_initial_state`.
        persona_style: Forwarded to :func:`create_initial_state`.
        byok_session_id: Forwarded to :func:`create_initial_state`.

    Returns:
        Final GraphState dict with generation, citations, confidence, etc.
    """
    logger.info(
        "running_rag_pipeline",
        query_len=len(query),
        user_id=user_context.user_id,
        thread_id=thread_id,
    )
    start_time = time.perf_counter()

    graph = await build_rag_graph_async()
    initial_state = create_initial_state(
        query,
        user_context,
        prefer_cloud=prefer_cloud,
        override_provider=override_provider,
        persona_style=persona_style,
        byok_session_id=byok_session_id,
    )
    config = {"configurable": {"thread_id": thread_id}}

    budget = settings.request_timeout_s
    # A missing or non-positive budget means "no deadline".
    deadline_scope = (
        asyncio.timeout(budget) if budget and budget > 0 else contextlib.nullcontext()
    )
    try:
        async with deadline_scope:
            final_state = await graph.ainvoke(initial_state, config=config)
    except TimeoutError:
        # NOTE(review): this also catches a TimeoutError raised by a node
        # itself, not only the deadline above - confirm that is intended.
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        logger.error(
            "rag_pipeline_timeout",
            budget_s=budget,
            elapsed_ms=elapsed_ms,
            user_id=user_context.user_id,
            thread_id=thread_id,
        )
        timeout_state = _build_timeout_state(
            query, user_context, elapsed_ms, prefer_cloud, override_provider
        )
        # Keep the request's persona / BYOK session on the timeout state, as
        # the streaming path does; the helper only forwards the two
        # provider-routing options.
        timeout_state["persona_style"] = persona_style
        timeout_state["byok_session_id"] = byok_session_id
        record_pipeline_run(timeout_state, elapsed_ms)
        return timeout_state

    elapsed_ms = (time.perf_counter() - start_time) * 1000

    # Extract executed nodes from audit trail
    nodes_executed = [
        entry["node"] for entry in final_state.get("audit_trail", []) if "node" in entry
    ]
    trace_graph_execution(
        query=query,
        nodes_executed=nodes_executed,
        total_latency_ms=elapsed_ms,
        final_confidence=final_state.get("confidence_score", 0.0),
        retries=final_state.get("retry_count", 0),
    )
    record_pipeline_run(final_state, elapsed_ms)
    logger.info(
        "rag_pipeline_completed",
        confidence_score=final_state.get("confidence_score", 0.0),
        needs_review=final_state.get("needs_human_review", False),
        generation_len=len(final_state.get("generation", "")),
        latency_ms=elapsed_ms,
    )
    return final_state
def _apply_audit(state: dict, entries: list[dict] | None) -> None:
"""Append audit entries to mutable state['audit_trail'] in place."""
if not entries:
return
state.setdefault("audit_trail", []).extend(entries)
def _merge_update(state: dict, update: dict) -> None:
"""Merge a node's partial update into state.
Mirrors LangGraph's reducer semantics: audit_trail is appended,
every other field is overwritten.
"""
if not update:
return
audit_extra = update.pop("audit_trail", None)
state.update(update)
if audit_extra:
_apply_audit(state, audit_extra)
async def run_rag_pipeline_stream(
    query: str,
    user_context: UserContext,
    thread_id: str = "default",
    prefer_cloud: bool = False,
    override_provider: str = "",
    persona_style: str = "",
    byok_session_id: str = "",
) -> AsyncGenerator[dict, None]:
    """Execute the full RAG pipeline with real token-by-token streaming.

    Single source of truth: runs the same compiled LangGraph workflow the
    non-streaming path uses via ``graph.astream(stream_mode=["updates",
    "custom"])``. Node updates become ``phase`` events; dict payloads from
    the ``custom`` stream are relayed unchanged (the synthesizer's token
    events). Blocked gates and timeouts are detected from the merged state,
    with no parallel hand-walked graph.

    The wall-clock budget (``settings.request_timeout_s``, when positive) is
    enforced only while waiting for the next chunk from the graph, against a
    single absolute deadline. No ``yield`` happens inside the timeout scope,
    so an expiring deadline cannot cancel the consumer while this generator
    is suspended.

    Event types yielded:
        {"type": "phase", "name": str, "state": dict}  - after each node
        {"type": "blocked", "message": str, "state": dict, "latency_ms": float}
            - a guardrails / security gate failed, or the deadline was hit
        {"type": "token", "text": str}  - synthesis token
        {"type": "final", "state": dict, "latency_ms": float}
            - emitted when the stream ends without a timeout

    Args:
        query: Natural language query.
        user_context: Authenticated user context for RBAC.
        thread_id: Thread identifier for audit/log correlation.
        prefer_cloud: Caller opts into cloud providers for LOW/MEDIUM.
        override_provider: Admin-only provider pin.
        persona_style: Forwarded to :func:`create_initial_state`.
        byok_session_id: Forwarded to :func:`create_initial_state`.

    Yields:
        Event dicts as described above.
    """
    logger.info(
        "running_rag_pipeline_stream",
        query_len=len(query),
        user_id=user_context.user_id,
        thread_id=thread_id,
    )
    start_time = time.perf_counter()
    budget = settings.request_timeout_s

    graph = await build_rag_graph_async()
    initial_state = create_initial_state(
        query,
        user_context,
        prefer_cloud=prefer_cloud,
        override_provider=override_provider,
        persona_style=persona_style,
        byok_session_id=byok_session_id,
    )
    # Opt the synthesizer into the streaming dispatch path. The flag is
    # local to this run and is not part of the public state contract - it
    # exists so the synthesizer can deterministically choose call_llm_stream
    # over call_llm_with_decision without sniffing framework internals.
    initial_state["_stream"] = True
    config = {"configurable": {"thread_id": thread_id}}

    # Track the merged state as it grows. LangGraph's "updates" stream
    # yields one partial dict per node; we apply them locally so we can
    # detect blocked gates without waiting for the entire graph.
    state: dict = dict(initial_state)
    emitted_blocked = False

    # Absolute deadline on the loop clock; ``None`` disables the timeout.
    deadline = (
        asyncio.get_running_loop().time() + budget if budget and budget > 0 else None
    )

    try:
        # aclosing() finalizes the graph stream even when the consumer stops
        # iterating early or the deadline fires.
        async with contextlib.aclosing(
            graph.astream(initial_state, config=config, stream_mode=["updates", "custom"])
        ) as stream:
            while True:
                try:
                    # Only the wait for the next chunk is inside the timeout
                    # scope; every yield below happens outside of it.
                    async with asyncio.timeout_at(deadline):
                        chunk = await anext(stream)
                except StopAsyncIteration:
                    break

                # LangGraph yields (mode, payload) tuples when stream_mode
                # is a list.
                if not isinstance(chunk, tuple) or len(chunk) != 2:
                    continue
                mode, payload = chunk

                if mode == "custom":
                    # Synthesizer pushes {"type": "token", "text": ...}
                    # through the writer; relay verbatim.
                    if isinstance(payload, dict):
                        yield payload
                    continue

                # `updates` payload is {node_name: partial_state}. Apply
                # the partial to our local state and emit a phase event.
                if mode != "updates" or not isinstance(payload, dict):
                    continue

                for node_name, partial in payload.items():
                    if isinstance(partial, dict):
                        _merge_update(state, dict(partial))
                    yield {"type": "phase", "name": node_name, "state": dict(state)}

                    # Detect blocked gates as soon as they fire.
                    if (
                        node_name == "guardrails"
                        and state.get("guardrails_passed") is False
                        and not emitted_blocked
                    ):
                        emitted_blocked = True
                        yield {
                            "type": "blocked",
                            "message": (
                                "Blocked by guardrails: "
                                f"{state.get('guardrails_reason', 'prompt_injection')}"
                            ),
                            "state": dict(state),
                            "latency_ms": (time.perf_counter() - start_time) * 1000,
                        }
                    elif (
                        node_name == "security"
                        and state.get("security_passed") is False
                        and not emitted_blocked
                    ):
                        emitted_blocked = True
                        yield {
                            "type": "blocked",
                            "message": state.get("security_message", "Blocked by security policy."),
                            "state": dict(state),
                            "latency_ms": (time.perf_counter() - start_time) * 1000,
                        }
    except TimeoutError:
        # NOTE(review): this also catches a TimeoutError raised by a node
        # itself, not only the deadline above - confirm that is intended.
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        logger.error(
            "rag_pipeline_stream_timeout",
            budget_s=budget,
            elapsed_ms=elapsed_ms,
            user_id=user_context.user_id,
            thread_id=thread_id,
        )
        _apply_audit(
            state,
            [
                {
                    "node": "deadline",
                    "action": "timeout",
                    "elapsed_ms": elapsed_ms,
                    "budget_s": budget,
                    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                }
            ],
        )
        state["needs_human_review"] = True
        state["evaluation_notes"] = "request_timeout"
        record_pipeline_run(state, elapsed_ms)
        # The budget may be unset (None) if the TimeoutError came from a
        # node; format a number either way instead of failing here.
        budget_s = float(budget or 0)
        yield {
            "type": "blocked",
            "message": (
                f"Request exceeded the configured wall-clock budget ({budget_s:.1f}s) "
                "and was cancelled."
            ),
            "state": dict(state),
            "latency_ms": elapsed_ms,
        }
        return

    elapsed_ms = (time.perf_counter() - start_time) * 1000
    nodes_executed = [entry["node"] for entry in state.get("audit_trail", []) if "node" in entry]
    trace_graph_execution(
        query=query,
        nodes_executed=nodes_executed,
        total_latency_ms=elapsed_ms,
        final_confidence=state.get("confidence_score", 0.0),
        retries=state.get("retry_count", 0),
    )
    record_pipeline_run(state, elapsed_ms)
    logger.info(
        "rag_pipeline_stream_completed",
        confidence_score=state.get("confidence_score", 0.0),
        needs_review=state.get("needs_human_review", False),
        generation_len=len(state.get("generation", "")),
        latency_ms=elapsed_ms,
    )
    yield {"type": "final", "state": dict(state), "latency_ms": elapsed_ms}