Spaces:
Sleeping
Sleeping
| import time | |
| import math | |
| import json | |
| import asyncio | |
| from fastapi import APIRouter, HTTPException, Depends, Request | |
| from sse_starlette.sse import EventSourceResponse | |
| from pydantic import BaseModel | |
| from app.agents.graph import app_graph | |
| from app.agents.state import AgentState | |
| from app.core.auth import get_current_user | |
| from app.core.database import db | |
| from typing import Dict, Any, cast, List, AsyncGenerator | |
# Module-level FastAPI router object for this file.
router = APIRouter()
| # --- 1. SCHEMAS --- | |
class VerificationRequest(BaseModel):
    """Request body accepted by the verification endpoint."""

    # Free-text question from the user. The streaming endpoint skips loading
    # chat history when the stripped text starts with "/axm ..".
    question: str

    # Filenames the question applies to. The streaming endpoint uses the first
    # entry (or "vault" when the list is empty) to find the chat's document row.
    filenames: List[str]
class VerificationResponse(BaseModel):
    """Non-streaming result shape: answer text, status and numeric metrics.

    NOTE(review): nothing in the visible part of this file constructs or
    returns this model; confirm whether it is still used elsewhere.
    """

    # Generated answer text.
    answer: str

    # Status label for the run (allowed values are not visible here).
    status: str

    # Integer count of evidence items.
    evidence_count: int

    # Metric name -> float value.
    metrics: Dict[str, float]
def sanitize_float(val: Any) -> float:
    """Coerce *val* to a finite float, returning 0.0 when that is impossible.

    Args:
        val: Any value; numbers and numeric strings are converted with
            ``float()``.

    Returns:
        The converted value when it is finite. ``0.0`` when *val* cannot be
        converted (wrong type, malformed string, integer too large for a
        float) or when the result is NaN or +/-infinity.
    """
    try:
        f_val = float(val)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() raises it for ints that do not fit in a
        # double (e.g. 10 ** 400); treat that like any other bad value.
        return 0.0
    return f_val if math.isfinite(f_val) else 0.0
| # --- 2. STREAMING ENDPOINT --- | |
async def run_verification(
    payload: VerificationRequest,
    user_id: str = Depends(get_current_user),
):
    """Run the agent graph for one question and stream progress as SSE.

    Each emitted event carries a JSON-encoded ``data`` field:

    * ``connected``      - sent first, padded so proxies flush immediately.
    * ``node_update``    - a graph node listed in the UI map became active.
    * ``clear``          - a generating node restarted after text was already
      streamed; the client should discard that text.
    * ``token``          - a piece of generated answer text.
    * ``audit_complete`` - final answer plus sanitised metrics.
    * ``error``          - the stream failed; ``detail`` holds the message.

    Args:
        payload: The question and the filenames it applies to.
        user_id: Caller identity, resolved by ``get_current_user``.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.

    NOTE(review): no route decorator is visible on this function in the
    reviewed copy, and its indentation had been stripped. Block nesting was
    reconstructed from syntax; the two ambiguous spots are marked below.
    """
    # Nodes whose chat-model tokens are forwarded to the client.
    generating_nodes = ("Architect", "Strategist")

    # Graph event name -> label shown in the UI.
    ui_node_map = {
        "retrieve_node": "Librarian", "Librarian": "Librarian",
        "distill_node": "Editor", "Editor": "Editor",
        "strategist_node": "Strategist", "Strategist": "Strategist",
        "generate_node": "Architect", "Architect": "Architect",
        "grade_generation_node": "Prosecutor", "Prosecutor": "Prosecutor",
    }

    async def event_generator() -> AsyncGenerator[Dict[str, Any], None]:
        try:
            # FIX 1: Nginx/Vercel buffer flush.
            # Cloud proxies hold back small SSE events, so the first event
            # carries 2 KiB of padding to overflow the buffer and open the
            # stream in real time.
            yield {
                "event": "connected",
                "data": json.dumps({"status": "established", "padding": " " * 2048}),
            }

            # Monotonic clock: the latency must not be skewed by wall-clock
            # adjustments that happen while the stream is running.
            started = time.monotonic()
            print(f"--- STREAM STARTED FOR: {payload.question[:30]}... ---")

            # The db calls are synchronous; running them on the default
            # executor keeps the event loop free for other streams.
            loop = asyncio.get_running_loop()

            history_buffer: List[Dict[str, str]] = []
            is_root_reset = payload.question.strip().startswith("/axm ..")
            if db and not is_root_reset:
                try:
                    history_buffer = await loop.run_in_executor(
                        None, _load_history, payload.filenames, user_id
                    )
                except Exception as e:
                    # Best effort: a history failure must not abort the run.
                    print(f"AXM-MEM: History hydration failed: {e}")

            initial_state: AgentState = {
                "question": payload.question,
                "user_id": user_id,
                "filenames": payload.filenames,
                "history": history_buffer,
                "command": None,
                "comparison_map": {},
                "documents": [],
                "generation": "",
                "hallucination_score": 0.0,
                "metrics": {},
                "status": "thinking",
                "retry_count": 0,
                "active_node": None,
            }

            full_generation = ""
            final_metrics: Dict[str, Any] = {}
            current_active_node = "System"

            async for event in app_graph.astream_events(initial_state, version="v1"):
                kind = event["event"]
                name = event["name"]

                if kind == "on_chain_start" and name in ui_node_map:
                    current_active_node = ui_node_map[name]
                    if current_active_node in generating_nodes and full_generation:
                        # A generating node is starting again after text was
                        # already streamed: drop it and tell the client.
                        full_generation = ""
                        yield {"event": "clear", "data": json.dumps({"message": "retry_triggered"})}
                    yield {
                        "event": "node_update",
                        "data": json.dumps({"node": current_active_node, "status": "active"}),
                    }

                elif kind == "on_chat_model_stream":
                    if current_active_node in generating_nodes:
                        chunk = event["data"].get("chunk")
                        content = ""
                        if chunk:
                            # Chunks are read either as objects with a
                            # ``content`` attribute or as plain dicts.
                            if hasattr(chunk, "content"):
                                content = str(chunk.content)
                            elif isinstance(chunk, dict) and "content" in chunk:
                                content = str(chunk["content"])
                        if content:
                            full_generation += content
                            yield {"event": "token", "data": json.dumps({"text": content})}

                elif kind == "on_chain_end" and name in ("generate_node", "Architect"):
                    node_output = _dict_or_empty(event["data"].get("output"))
                    # Fallback for runs where no tokens were streamed: emit
                    # the node's final text as a single token.
                    if not full_generation and "generation" in node_output:
                        full_generation = str(node_output["generation"])
                        yield {"event": "token", "data": json.dumps({"text": full_generation})}

                elif kind == "on_chain_end" and name in ("grade_generation_node", "Prosecutor"):
                    eval_output = _dict_or_empty(event["data"].get("output"))
                    final_metrics = _dict_or_empty(eval_output.get("metrics"))

            # NOTE(review): with the original indentation lost, this fallback
            # is placed after the event loop (run once, at the end) rather
            # than inside the grading branch - confirm.
            if not full_generation.strip():
                full_generation = "Verification Failed: Audit logic rejected the draft."
                yield {"event": "token", "data": json.dumps({"text": full_generation})}

            actual_latency = round(time.monotonic() - started, 2)
            safe_metrics = {k: sanitize_float(v) for k, v in final_metrics.items()}

            if db:
                try:
                    await loop.run_in_executor(
                        None,
                        _persist_exchange,
                        payload.filenames,
                        user_id,
                        payload.question,
                        full_generation,
                        safe_metrics,
                        actual_latency,
                    )
                except Exception as log_err:
                    # Best effort: the client still gets the final event.
                    print(f"SSE DB ERROR: {log_err}")

            yield {
                "event": "audit_complete",
                "data": json.dumps({"answer": full_generation, "metrics": safe_metrics}),
            }

        except Exception as e:
            # Top-level boundary: report the failure in-band, since the HTTP
            # status line has already been sent by the time this can fire.
            error_msg = str(e)
            print(f"❌ MASTER STREAM CRASH: {error_msg}")
            yield {
                "event": "error",
                "data": json.dumps({"detail": f"Backend Engine Disconnected: {error_msg}"}),
            }

    # The headers ask intermediaries not to buffer or cache the response, so
    # small node_update events are streamed instead of being held in a buffer.
    return EventSourceResponse(
        event_generator(),
        ping=10,
        headers={
            "X-Accel-Buffering": "no",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )


def _dict_or_empty(value: Any) -> Dict[str, Any]:
    """Return *value* when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _find_document_id(filenames: List[str], user_id: str) -> Any:
    """Return the id of the user's primary document row, or None if absent.

    The primary document is the first filename, or "vault" when none is given.
    """
    primary_file = filenames[0] if filenames else "vault"
    doc_res = (
        db.table("documents")
        .select("id")
        .eq("filename", primary_file)
        .eq("user_id", user_id)
        .execute()
    )
    doc_rows = cast(List[Dict[str, Any]], doc_res.data)
    return doc_rows[0]["id"] if doc_rows else None


def _load_history(filenames: List[str], user_id: str) -> List[Dict[str, str]]:
    """Return up to five latest chat messages for the document, oldest first."""
    doc_id = _find_document_id(filenames, user_id)
    if doc_id is None:
        return []
    hist_res = (
        db.table("chat_messages")
        .select("role, content")
        .eq("document_id", doc_id)
        .eq("user_id", user_id)
        .order("created_at", desc=True)
        .limit(5)
        .execute()
    )
    raw_hist = cast(List[Dict[str, str]], hist_res.data)
    # Rows arrive newest-first; reverse them into chronological order.
    return raw_hist[::-1]


def _persist_exchange(
    filenames: List[str],
    user_id: str,
    question: str,
    answer: str,
    metrics: Dict[str, float],
    latency: float,
) -> None:
    """Store the question/answer pair and an audit row for this run."""
    doc_id = _find_document_id(filenames, user_id)
    if doc_id is not None:
        db.table("chat_messages").insert(
            {"document_id": doc_id, "user_id": user_id, "role": "user", "content": question}
        ).execute()
        db.table("chat_messages").insert(
            {
                "document_id": doc_id,
                "user_id": user_id,
                "role": "assistant",
                "content": answer,
                "metrics": metrics,
            }
        ).execute()
    # NOTE(review): the audit row does not use doc_id, so it is written whether
    # or not a document matched; the original nesting was lost - confirm.
    db.table("audit_logs").insert(
        {
            "user_id": user_id,
            "question": question,
            "faithfulness": metrics.get("faithfulness", 0.0),
            "latency": latency,
        }
    ).execute()