Spaces:
Sleeping
Sleeping
| """FastAPI router with SSE streaming endpoint for the knowledge copilot.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| from datetime import datetime | |
| from uuid import uuid4 | |
| from src.utils.logger import Timer, get_logger as _get_logger | |
| from typing import AsyncGenerator | |
| from fastapi import APIRouter, Depends | |
| from fastapi.responses import StreamingResponse | |
| from agent.graph import graph | |
| from agent.models import KnowledgeGraphState, QueryInput | |
| from src.auth.deps import get_current_user | |
# Redis keys written by _store_query_event.
HISTORY_KEY: str = "gs:queries"  # list of JSON query events, pushed newest-first
TOPICS_KEY: str = "gs:topics"  # sorted set: topic word -> ZINCRBY hit count

logger = _get_logger(__name__)

# Routes registered on this router get the "/agent" prefix and "agent" tag.
router = APIRouter(prefix="/agent", tags=["agent"])
# Strong references to in-flight fire-and-forget persistence tasks. The event
# loop only keeps weak references to tasks, so a task nobody holds can be
# garbage-collected before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


def _topic_words(query: str) -> list[str]:
    """Return the lower-cased words of *query* that are counted as topics.

    A word qualifies when, after stripping surrounding punctuation, it is
    purely alphabetic and longer than four characters. Stripping matters for
    natural-language questions: without it the trailing word of
    "What is the deployment process?" ("process?") would never be counted.
    """
    import string

    words = []
    for raw in query.lower().split():
        word = raw.strip(string.punctuation)
        if len(word) > 4 and word.isalpha():
            words.append(word)
    return words


def _on_background_task_done(task: asyncio.Task) -> None:
    """Drop *task* from the live set and log (instead of losing) any failure."""
    _BACKGROUND_TASKS.discard(task)
    if task.cancelled():
        return
    exc = task.exception()  # retrieving it also avoids "exception never retrieved"
    if exc is not None:
        logger.warning("query_event_upsert_failed", extra={"error": str(exc)})


async def _store_query_event(
    query_input: QueryInput,
    duration_ms: int,
    success: bool,
    agent_results: dict | None = None,
    guardrail_score: float | None = None,
    escalated: bool = False,
    answer_text: str = "",
) -> None:
    """Persist a query event to Redis for analytics and workspace history.

    Writes, in order:

    * the event JSON onto ``HISTORY_KEY`` (newest first, trimmed to 1000),
    * one ``ZINCRBY`` on ``TOPICS_KEY`` per topic word in the query,
    * an escalation record onto ``gs:escalations`` (trimmed to 500) when
      *escalated* is true,
    * a background upsert via ``src.anomaly.db.async_upsert_query_event``
      that is scheduled but not awaited.

    Any ``Exception`` raised while building or writing the event is logged as
    ``query_store_failed`` and swallowed, so analytics cannot break the SSE
    stream.

    Args:
        query_input: The query that was answered.
        duration_ms: Graph run time in milliseconds.
        success: Whether the graph run completed without raising.
        agent_results: Per-agent results from the final graph state; each
            value is read for ``retrieval_confidence`` and ``chunks``.
        guardrail_score: Guardrail score from the final state, if any.
        escalated: Whether the answer was flagged for escalation.
        answer_text: Final answer; a brief of at most 500 characters (plus an
            ellipsis when truncated) is stored.
    """
    try:
        from datetime import timezone

        import redis.asyncio as aioredis
        from src.config import settings

        brief = answer_text[:500].rstrip() if answer_text else ""
        if answer_text and len(answer_text) > 500:
            brief += "…"
        event = {
            "id": str(uuid4()),
            "query": query_input.query,
            "session_id": query_input.session_id,
            "team_id": query_input.team_id,
            # Same naive-UTC ISO string utcnow() produced, without the
            # deprecated call.
            "created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "success": success,
            "duration_ms": duration_ms,
            "answer_brief": brief,
            # Per-agent retrieval metrics — populated from final graph state
            "agents": {
                agent: {
                    "confidence": result.retrieval_confidence,
                    "chunk_count": len(result.chunks),
                }
                for agent, result in (agent_results or {}).items()
            },
            "guardrail_score": guardrail_score,
            "escalated": escalated,
        }
        r = aioredis.from_url(settings.redis_url, decode_responses=True)
        try:
            # Push to history list (newest first), keep last 1000
            await r.lpush(HISTORY_KEY, json.dumps(event))
            await r.ltrim(HISTORY_KEY, 0, 999)
            for word in _topic_words(query_input.query):
                await r.zincrby(TOPICS_KEY, 1, word)
            # Write escalation record when the guardrail flagged the answer
            if escalated:
                escalation = {
                    "id": event["id"],
                    "query": query_input.query,
                    "frequency": 1,
                    "last_seen": event["created_at"],
                    "teams": [query_input.team_id],
                    "status": "open",
                    "gap_type": "missing_knowledge",
                    "guardrail_score": guardrail_score,
                }
                await r.lpush("gs:escalations", json.dumps(escalation))
                await r.ltrim("gs:escalations", 0, 499)
            # Persist to Supabase for time-series anomaly detection.
            # Fire-and-forget: never allowed to fail the SSE stream, but keep a
            # reference to the task and surface its failures in the logs.
            try:
                from src.anomaly.db import async_upsert_query_event

                task = asyncio.ensure_future(async_upsert_query_event(event))
            except Exception:
                logger.debug("query_event_upsert_skipped", exc_info=True)
            else:
                _BACKGROUND_TASKS.add(task)
                task.add_done_callback(_on_background_task_done)
        finally:
            await r.aclose()
    except Exception as exc:
        logger.warning("query_store_failed", extra={"error": str(exc)})
async def _event_generator(
    query_input: QueryInput,
    queue: asyncio.Queue,
) -> AsyncGenerator[str, None]:
    """Run the agent graph in a background task and relay its events as SSE frames.

    The graph receives *queue* as ``sse_queue``. Every dict placed on it is
    emitted as ``event: <event>`` / ``data: <json(data)>``, with ``event``
    defaulting to ``"message"`` and ``data`` to ``{}``. When the graph run
    ends, a final ``done`` frame is sent. If the graph raises, an ``error``
    frame carrying the exception message comes first. If the stream is
    cancelled or closed early, the background task is cancelled.
    """
    sentinel = object()  # private end-of-stream marker; never equal to a real event

    async def run_graph() -> None:
        initial_state = KnowledgeGraphState(
            query_input=query_input,
            sse_queue=queue,
        )
        final_state = None
        error: Exception | None = None
        try:
            with Timer() as t:
                try:
                    final_state = await graph.ainvoke(initial_state)
                except Exception as exc:
                    error = exc
            # Read the duration only after the timer has exited. This works
            # whether Timer exposes a live elapsed value or only sets ``ms``
            # in ``__exit__`` (reading it inside the block would fail there).
            duration_ms = t.ms

            if error is None:
                logger.info(
                    "query_complete",
                    extra={"session_id": query_input.session_id, "duration_ms": duration_ms},
                )
                await _store_query_event(
                    query_input,
                    duration_ms,
                    success=True,
                    agent_results=final_state.get("agent_results", {}),
                    guardrail_score=final_state.get("guardrail_score"),
                    escalated=final_state.get("escalate", False),
                    answer_text=final_state.get("final_answer") or "",
                )
            else:
                # exc_info=error keeps the traceback now that we are outside
                # the except clause.
                logger.error(
                    "query_error",
                    extra={
                        "session_id": query_input.session_id,
                        "duration_ms": duration_ms,
                        "error": str(error),
                    },
                    exc_info=error,
                )
                await _store_query_event(query_input, duration_ms, success=False)
                # NOTE(review): the raw exception text is forwarded to the client;
                # consider a generic message if it can contain internal details.
                await queue.put({"event": "error", "data": {"message": str(error)}})
        finally:
            # Always signal end-of-stream to the consumer loop below.
            await queue.put(sentinel)

    task = asyncio.create_task(run_graph())
    try:
        while True:
            item = await queue.get()
            if item is sentinel:
                break
            event_name = item.get("event", "message")
            data_str = json.dumps(item.get("data", {}))
            yield f"event: {event_name}\ndata: {data_str}\n\n"
        yield "event: done\ndata: {}\n\n"
    except asyncio.CancelledError:
        logger.info("SSE stream cancelled for session=%s", query_input.session_id)
        task.cancel()
        raise
    finally:
        # Also covers non-cancellation early exits (e.g. the generator being
        # closed by its consumer): never leave the graph running unobserved.
        if not task.done():
            task.cancel()
# NOTE(review): no route decorator (e.g. ``@router.post(...)``) is visible on
# this function in this copy of the file; confirm it is registered on ``router``.
async def query_endpoint(
    query_input: QueryInput,
    user: dict = Depends(get_current_user),
) -> StreamingResponse:
    """Answer *query_input* as a Server-Sent Events stream scoped to *user*.

    The body's ``team_id`` and ``allowed_channel_ids`` are overwritten from
    the authenticated user before the graph runs. Admins (role ``admin`` or
    ``org_admin``) get an empty channel list, which is how RBAC channel
    filtering is bypassed for them.

    Raises:
        HTTPException: 403 when a non-admin user has no allowed channel IDs.
            Such a user would otherwise receive the same empty list that
            grants admins unfiltered search.
    """
    from fastapi import HTTPException

    # Enforce server-side team_id and channel IDs — never trust the client body.
    # Admins bypass RBAC channel filtering so they can search the full knowledge base.
    is_admin = user.get("role") in ("admin", "org_admin")
    allowed_channel_ids = [] if is_admin else list(user.get("allowed_channel_ids") or [])
    if not is_admin and not allowed_channel_ids:
        # Fail closed: an empty list means "no channel filter", so a
        # non-admin without channel scope must not reach the graph.
        raise HTTPException(status_code=403, detail="No accessible channels for this user.")

    query_input = query_input.model_copy(update={
        # NOTE(review): falls back to the client-supplied team_id when the user
        # record has none, which contradicts the comment above — confirm.
        "team_id": user.get("team_id", query_input.team_id),
        "allowed_channel_ids": allowed_channel_ids,
    })
    queue: asyncio.Queue = asyncio.Queue()
    logger.info(
        "query_start",
        extra={
            "session_id": query_input.session_id,
            "team_id": query_input.team_id,
            "channels": len(query_input.allowed_channel_ids),
            "query_len": len(query_input.query),
        },
    )
    return StreamingResponse(
        _event_generator(query_input, queue),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )