"""FastAPI router with SSE streaming endpoint for the knowledge copilot.""" from __future__ import annotations import asyncio import json from datetime import datetime from uuid import uuid4 from src.utils.logger import Timer, get_logger as _get_logger from typing import AsyncGenerator from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse from agent.graph import graph from agent.models import KnowledgeGraphState, QueryInput from src.auth.deps import get_current_user logger = _get_logger(__name__) router = APIRouter(prefix="/agent", tags=["agent"]) HISTORY_KEY = "gs:queries" TOPICS_KEY = "gs:topics" async def _store_query_event( query_input: QueryInput, duration_ms: int, success: bool, agent_results: dict | None = None, guardrail_score: float | None = None, escalated: bool = False, answer_text: str = "", ) -> None: """Persist query event to Redis for analytics and workspace history.""" try: import redis.asyncio as aioredis from src.config import settings brief = answer_text[:500].rstrip() if answer_text else "" if answer_text and len(answer_text) > 500: brief += "…" event = { "id": str(uuid4()), "query": query_input.query, "session_id": query_input.session_id, "team_id": query_input.team_id, "created_at": datetime.utcnow().isoformat(), "success": success, "duration_ms": duration_ms, "answer_brief": brief, # Per-agent retrieval metrics — populated from final graph state "agents": { agent: { "confidence": result.retrieval_confidence, "chunk_count": len(result.chunks), } for agent, result in (agent_results or {}).items() }, "guardrail_score": guardrail_score, "escalated": escalated, } r = aioredis.from_url(settings.redis_url, decode_responses=True) try: # Push to history list (newest first), keep last 1000 await r.lpush(HISTORY_KEY, json.dumps(event)) await r.ltrim(HISTORY_KEY, 0, 999) # Track topic words (naive: split query into words, skip short ones) for word in query_input.query.lower().split(): if len(word) > 4 and word.isalpha(): await r.zincrby(TOPICS_KEY, 1, word) # Write escalation record when the guardrail flagged the answer if escalated: escalation = { "id": event["id"], "query": query_input.query, "frequency": 1, "last_seen": event["created_at"], "teams": [query_input.team_id], "status": "open", "gap_type": "missing_knowledge", "guardrail_score": guardrail_score, } await r.lpush("gs:escalations", json.dumps(escalation)) await r.ltrim("gs:escalations", 0, 499) # Persist to Supabase for time-series anomaly detection. # Fire-and-forget: never allowed to fail the SSE stream. try: import asyncio as _asyncio from src.anomaly.db import async_upsert_query_event _asyncio.ensure_future(async_upsert_query_event(event)) except Exception: pass finally: await r.aclose() except Exception as exc: logger.warning("query_store_failed", extra={"error": str(exc)}) async def _event_generator( query_input: QueryInput, queue: asyncio.Queue, ) -> AsyncGenerator[str, None]: _SENTINEL = object() async def run_graph() -> None: initial_state = KnowledgeGraphState( query_input=query_input, sse_queue=queue, ) with Timer() as t: try: final_state = await graph.ainvoke(initial_state) logger.info( "query_complete", extra={"session_id": query_input.session_id, "duration_ms": t.ms}, ) await _store_query_event( query_input, t.ms, success=True, agent_results=final_state.get("agent_results", {}), guardrail_score=final_state.get("guardrail_score"), escalated=final_state.get("escalate", False), answer_text=final_state.get("final_answer") or "", ) except Exception as exc: logger.exception( "query_error", extra={"session_id": query_input.session_id, "duration_ms": t.ms, "error": str(exc)}, ) await _store_query_event(query_input, t.ms, success=False) await queue.put({"event": "error", "data": {"message": str(exc)}}) finally: await queue.put(_SENTINEL) task = asyncio.create_task(run_graph()) try: while True: item = await queue.get() if item is _SENTINEL: break event_name = item.get("event", "message") data_str = json.dumps(item.get("data", {})) yield f"event: {event_name}\ndata: {data_str}\n\n" yield "event: done\ndata: {}\n\n" except asyncio.CancelledError: logger.info("SSE stream cancelled for session=%s", query_input.session_id) task.cancel() raise finally: if not task.done(): task.cancel() @router.post("/query") async def query_endpoint( query_input: QueryInput, user: dict = Depends(get_current_user), ) -> StreamingResponse: # Enforce server-side team_id and channel IDs — never trust the client body. # Admins bypass RBAC channel filtering so they can search the full knowledge base. is_admin = user.get("role") in ("admin", "org_admin") query_input = query_input.model_copy(update={ "team_id": user.get("team_id", query_input.team_id), "allowed_channel_ids": [] if is_admin else user.get("allowed_channel_ids", []), }) queue: asyncio.Queue = asyncio.Queue() logger.info( "query_start", extra={ "session_id": query_input.session_id, "team_id": query_input.team_id, "channels": len(query_input.allowed_channel_ids), "query_len": len(query_input.query), }, ) return StreamingResponse( _event_generator(query_input, queue), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Accel-Buffering": "no", }, )