Buckets:
| import json | |
| import logging | |
| import os | |
| import time | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from threading import Lock | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
| from pydantic import BaseModel, Field | |
| from agent.engine import AgentEngine | |
| from model_server.llm_client import get_llm | |
| from rag.evaluator import evaluate_rag_json | |
| from rag.pipeline import query_rag | |
| from rag.reranker import rerank | |
| from rag.vector_store import build_index, index_exists, retrieve | |
# Load variables from a .env file before the os.getenv lookups further down.
load_dotenv()

# Message-only format keeps the JSON payloads logged by the handlers parseable.
logging.basicConfig(format="%(message)s", level="INFO")
logger = logging.getLogger(__name__)

app = FastAPI(title="Insurance RAG API", version="1.0.0")

# Directory that holds this module; the frontend folder is resolved against it.
BASE_DIR = Path(__file__).resolve().parent
FRONTEND_DIR = BASE_DIR / "frontend"

# Conversation window size; the session helpers keep MAX_HISTORY * 2 messages.
MAX_HISTORY = int(os.getenv("MAX_HISTORY", 12))

# Only the value "true" (case-insensitive) switches automatic index builds on.
AUTO_BUILD_INDEX = os.getenv("AUTO_BUILD_INDEX", "false").lower() == "true"

# Static assets are mounted only when the frontend directory is present.
if FRONTEND_DIR.exists():
    app.mount(
        "/static",
        StaticFiles(directory=str(FRONTEND_DIR)),
        name="static",
    )
class QueryRequest(BaseModel):
    """Request body for a single retrieval-augmented question."""

    # Required; at least 3 characters.
    question: str = Field(default=..., min_length=3)
    # Only written to the structured log line by rag_query.
    session_id: str = "default"
    # Passed through to query_rag as its use_reranker flag.
    use_reranker: bool = True
class ChatRequest(BaseModel):
    """Request body for the direct (non-RAG) chat handlers."""

    # Required; must not be empty.
    message: str = Field(default=..., min_length=1)
    # Accepted for symmetry with the other request models; chat and
    # chat_stream in this file do not read it.
    session_id: str = "default"
class ContinuousChatRequest(BaseModel):
    """Request body for the session-aware chat handlers."""

    # Required; must not be empty.
    message: str = Field(default=..., min_length=1)
    # Key under which the conversation history is stored and looked up.
    session_id: str = "default"
    # When true, retrieved policy excerpts are added to the prompt.
    use_rag_context: bool = True
    # When true (and more than one chunk was retrieved), chunks are reranked.
    use_reranker: bool = True
class BenchmarkRequest(BaseModel):
    """Request body for the RAG latency benchmark."""

    # Required; at least 3 characters.
    question: str = Field(default=..., min_length=3)
    # Number of timed repetitions, constrained to 1..100.
    runs: int = Field(default=5, ge=1, le=100)
    # Passed through to query_rag as its use_reranker flag.
    use_reranker: bool = True
class JsonEvalRequest(BaseModel):
    """Request body for evaluating the RAG pipeline against a JSON dataset."""

    # Handed to evaluate_rag_json together with BASE_DIR.
    dataset_path: str = "../rag.json"
    # Upper bound on evaluated questions, constrained to 1..200.
    max_questions: int = Field(default=10, ge=1, le=200)
    # Passed through to evaluate_rag_json as its use_reranker flag.
    use_reranker: bool = True
class AgentChatRequest(BaseModel):
    """Request body for the tool-calling agent chat handler."""

    # Required; must not be empty.
    message: str = Field(default=..., min_length=1)
    # Session key forwarded to the agent engine.
    session_id: str = "default-agent"
    # Forwarded to AgentEngine.chat as its use_reranker flag.
    use_reranker: bool = True
    # Forwarded to AgentEngine.chat; constrained to 1..6.
    max_turns: int = Field(default=6, ge=1, le=6)
# Conversation messages per session id; accessed under session_lock.
session_store: dict[str, list] = defaultdict(list)
session_lock = Lock()

# Latency samples in milliseconds, grouped by endpoint label.
metrics_store: dict[str, list[int]] = defaultdict(list)

# Single shared agent engine; TOP_K falls back to 4 when the variable is unset.
agent_engine = AgentEngine(
    max_history_pairs=MAX_HISTORY,
    default_top_k=int(os.getenv("TOP_K", 4)),
)
def _record_metric(endpoint: str, latency_ms: int) -> None:
    """Store one latency sample for *endpoint*, keeping the newest 200.

    Args:
        endpoint: Label the sample is filed under.
        latency_ms: Measured latency in milliseconds.
    """
    samples = metrics_store[endpoint]
    samples.append(latency_ms)
    # Rebind to a trimmed copy so the window never holds more than 200 entries.
    metrics_store[endpoint] = samples[-200:]
def _require_llm_ready() -> None:
    """Abort the request when no LLM model id is configured.

    The model id comes from ``HF_LLM_MODEL`` and falls back to a built-in
    default, so the check only fails when the variable is set to a blank
    value.

    Raises:
        HTTPException: 400 when the resolved model id is empty after stripping.
    """
    configured_model = os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct").strip()
    if configured_model:
        return
    raise HTTPException(
        status_code=400,
        detail="HF_LLM_MODEL is missing. Set it in .env before using LLM endpoints.",
    )
def _llm_configured() -> bool:
    """Report whether ``HF_LLM_MODEL`` resolves to a non-blank model id.

    Returns:
        True when the variable is unset (the default model id applies) or
        holds non-whitespace text; False when it is set to a blank string.
    """
    raw_value = os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
    return raw_value.strip() != ""
def _percentile(values: list[int], p: int) -> int:
    """Return the *p*-th percentile of *values* using a floor-rank lookup.

    The values are sorted and the element at index ``floor(p% * (n - 1))``
    is returned; no interpolation is performed.

    Args:
        values: Samples to summarise; may be empty and need not be sorted.
        p: Percentile in the range 0..100. Values outside that range are
            clamped so the lookup can never index out of bounds or wrap
            around to the wrong end of the list.

    Returns:
        The selected sample, or 0 when *values* is empty.
    """
    if not values:
        return 0
    ordered = sorted(values)
    last_index = len(ordered) - 1
    bounded_p = min(max(p, 0), 100)
    # Multiply before dividing and use floor division: the float form
    # int((p / 100) * last_index) can land one slot low because of rounding
    # (p=29 with 101 samples evaluates to 28.999... and truncates to 28).
    return ordered[int(bounded_p * last_index // 100)]
def _session_history(session_id: str) -> list:
    """Return the most recent messages stored for *session_id*.

    At most ``MAX_HISTORY * 2`` messages are returned (each exchange is
    stored as a human message followed by an AI message). The result is a
    new list, so callers may extend it without touching the store.

    Args:
        session_id: Key of the conversation to look up.

    Returns:
        The trailing window of stored messages; empty for an unknown session
        or when the configured window is zero or negative.
    """
    window = MAX_HISTORY * 2
    with session_lock:
        history = session_store.get(session_id, [])
        # A zero window must be handled explicitly: history[-0:] is the same
        # as history[0:] and would return everything instead of nothing.
        if window <= 0:
            return []
        return history[-window:]
def _append_session(session_id: str, user_message: str, assistant_message: str) -> None:
    """Append one user/assistant exchange to a session and trim the history.

    Args:
        session_id: Key of the conversation to update.
        user_message: Text stored as a ``HumanMessage``.
        assistant_message: Text stored as an ``AIMessage``.
    """
    window = MAX_HISTORY * 2
    with session_lock:
        # Work on a copy so the stored list is replaced in one assignment.
        history = list(session_store.get(session_id, []))
        history.append(HumanMessage(content=user_message))
        history.append(AIMessage(content=assistant_message))
        # history[-0:] would keep the whole list, so a zero (or negative)
        # window is mapped to "retain nothing" rather than "retain everything".
        session_store[session_id] = history[-window:] if window > 0 else []
def startup_event() -> None:
    """Build the index when auto-build is enabled and no index exists yet."""
    if not AUTO_BUILD_INDEX:
        return
    if index_exists():
        return
    logger.info("No FAISS index found. Building index automatically...")
    build_index()
def serve_frontend():
    """Return the frontend's ``index.html`` as a file response.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    index_page = FRONTEND_DIR / "index.html"
    if not index_page.exists():
        raise HTTPException(status_code=404, detail="Frontend not found")
    return FileResponse(index_page)
def health():
    """Report service status, index availability and LLM configuration."""
    model_name = os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
    report = {"status": "ok"}
    report["index_ready"] = index_exists()
    report["llm_configured"] = _llm_configured()
    # The raw (unstripped) environment value is echoed back.
    report["llm_model"] = model_name
    return report
def index_build():
    """Rebuild the index and report the elapsed time and index type."""
    began = time.time()
    index = build_index()
    elapsed_ms = int((time.time() - began) * 1000)
    return {
        "status": "rebuilt",
        "latency_ms": elapsed_ms,
        # Class name of whatever object build_index returned.
        "index_type": type(index).__name__,
    }
def rag_query(req: QueryRequest):
    """Answer a question through the RAG pipeline and attach the latency.

    Raises:
        HTTPException: 400 when the index is missing or the LLM is not
            configured; 502 when ``query_rag`` raises.
    """
    if not index_exists():
        raise HTTPException(
            status_code=400,
            detail="FAISS index not found. Call /index/build first.",
        )
    _require_llm_ready()

    began = time.time()
    try:
        answer = query_rag(req.question, use_reranker=req.use_reranker)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"RAG generation failed: {exc}") from exc
    elapsed_ms = int((time.time() - began) * 1000)

    _record_metric("/rag/query", elapsed_ms)
    # One JSON object per request keeps the log stream machine-readable.
    logger.info(
        json.dumps(
            {
                "endpoint": "/rag/query",
                "session_id": req.session_id,
                "latency_ms": elapsed_ms,
                "chunks_used": answer["chunks_used"],
                "sources": answer["sources"],
            }
        )
    )
    answer["latency_ms"] = elapsed_ms
    return answer
def chat(req: ChatRequest):
    """Send one message straight to the LLM, without retrieval or history.

    Raises:
        HTTPException: 400 when the LLM is not configured; 502 when the LLM
            call raises.
    """
    _require_llm_ready()
    began = time.time()
    prompt = [HumanMessage(content=req.message)]
    try:
        reply = get_llm().invoke(prompt)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Chat call failed: {exc}") from exc
    elapsed_ms = int((time.time() - began) * 1000)
    _record_metric("/chat", elapsed_ms)
    return {"response": reply.content, "latency_ms": elapsed_ms}
def chat_stream(req: ChatRequest):
    """Stream the LLM reply to a single message as server-sent events.

    Each non-empty token is sent as ``data: {"token": ...}``. A failure is
    reported as a ``data: {"error": ...}`` event, and the stream always ends
    with ``data: [DONE]``.
    """
    _require_llm_ready()
    llm = get_llm(streaming=True)
    prompt = [HumanMessage(content=req.message)]

    async def generate():
        try:
            async for piece in llm.astream(prompt):
                text = piece.content or ""
                if not text:
                    # Skip chunks that carry no text.
                    continue
                yield f"data: {json.dumps({'token': text})}\n\n"
        except Exception as exc:
            yield f"data: {json.dumps({'error': f'Chat stream failed: {exc}'})}\n\n"
        # Sent after both the success and the error path.
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
# System prompt shared by the blocking and streaming continuous-chat handlers.
_CONTINUOUS_SYSTEM_PROMPT = (
    "You are an insurance broker assistant for InsureCo. "
    "Use prior conversation context when relevant. "
    "If policy context is provided, stay grounded in it and cite document + section."
)


def _gather_policy_context(message: str, use_reranker: bool) -> tuple[list, str]:
    """Retrieve policy chunks for *message* and render them for the prompt.

    Args:
        message: Text used as the retrieval (and rerank) query.
        use_reranker: Rerank the retrieved chunks when more than one came back.

    Returns:
        ``(sources, context)``: a list of per-chunk source descriptors and
        the chunks joined into one prompt-ready string.

    Raises:
        HTTPException: 400 when the index does not exist.
    """
    if not index_exists():
        raise HTTPException(
            status_code=400,
            detail="FAISS index not found. Call /index/build first.",
        )
    chunks = retrieve(message)
    # Reranking needs at least two candidates; at most three are kept.
    if use_reranker and len(chunks) > 1:
        chunks = rerank(message, chunks, top_k=min(3, len(chunks)))
    sources = [
        {
            "file": c["source"],
            "section": c.get("section", "Unknown section"),
            "score": c["score"],
            "rerank_score": c.get("rerank_score"),
        }
        for c in chunks
    ]
    context = "\n\n---\n\n".join(
        f"[Document: {c['source']} | Section: {c.get('section', 'Unknown section')} | Similarity: {c['score']}]\n{c['content']}"
        for c in chunks
    )
    return sources, context


def _compose_continuous_messages(session_id: str, message: str, context: str) -> list:
    """Build the LLM message list: system prompt, stored history, new turn.

    When *context* is non-empty the new turn is wrapped with the policy
    excerpts and a citation instruction; otherwise the raw message is used.
    """
    user_prompt = message
    if context:
        user_prompt = (
            f"POLICY EXCERPTS:\n{context}\n\n"
            f"BROKER MESSAGE: {message}\n\n"
            "Respond clearly and cite source document/section when you use the excerpts."
        )
    messages = [SystemMessage(content=_CONTINUOUS_SYSTEM_PROMPT)]
    messages.extend(_session_history(session_id))
    messages.append(HumanMessage(content=user_prompt))
    return messages


def chat_continuous(req: ContinuousChatRequest):
    """Session-aware multi-turn chat with optional RAG grounding.

    The raw user message (not the excerpt-augmented prompt) and the answer
    are appended to the session after a successful LLM call.

    Raises:
        HTTPException: 400 when the LLM is not configured or RAG context is
            requested without an index; 502 when the LLM call raises.
    """
    _require_llm_ready()
    start = time.time()
    sources: list = []
    context = ""
    if req.use_rag_context:
        sources, context = _gather_policy_context(req.message, req.use_reranker)
    messages = _compose_continuous_messages(req.session_id, req.message, context)
    try:
        response = get_llm().invoke(messages)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Continuous chat failed: {exc}") from exc
    answer = response.content
    _append_session(req.session_id, req.message, answer)
    latency = int((time.time() - start) * 1000)
    _record_metric("/chat/continuous", latency)
    return {
        "response": answer,
        "session_id": req.session_id,
        "sources": sources,
        "latency_ms": latency,
    }


def chat_continuous_stream(req: ContinuousChatRequest):
    """Streaming multi-turn chat endpoint with optional RAG grounding.

    Tokens are sent as ``data: {"token": ...}`` events. On success a final
    ``data: {"meta": ...}`` event carries the latency and sources before
    ``data: [DONE]``. On failure an ``error`` event is sent, followed by
    ``[DONE]``, and the session is left unchanged.

    Raises:
        HTTPException: 400 when the LLM is not configured or RAG context is
            requested without an index (raised before streaming starts).
    """
    _require_llm_ready()
    sources: list = []
    context = ""
    if req.use_rag_context:
        sources, context = _gather_policy_context(req.message, req.use_reranker)
    messages = _compose_continuous_messages(req.session_id, req.message, context)
    llm = get_llm(streaming=True)

    async def generate():
        # Latency covers only the streaming phase; retrieval already happened.
        start = time.time()
        collected = []
        try:
            async for chunk in llm.astream(messages):
                token = chunk.content or ""
                if token:
                    collected.append(token)
                    yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception as exc:
            yield f"data: {json.dumps({'error': f'Continuous stream failed: {exc}'})}\n\n"
            yield "data: [DONE]\n\n"
            # Do not persist a partial answer or record a metric on failure.
            return
        answer = "".join(collected)
        _append_session(req.session_id, req.message, answer)
        latency = int((time.time() - start) * 1000)
        _record_metric("/chat/continuous/stream", latency)
        yield f"data: {json.dumps({'meta': {'latency_ms': latency, 'sources': sources}})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
def agent_chat(req: AgentChatRequest):
    """Run one user message through the agent engine and attach the latency.

    Raises:
        HTTPException: 400 when the LLM is not configured; 502 when the
            agent engine raises.
    """
    _require_llm_ready()
    began = time.time()
    try:
        outcome = agent_engine.chat(
            session_id=req.session_id,
            user_message=req.message,
            use_reranker=req.use_reranker,
            max_turns=req.max_turns,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Agent chat failed: {exc}") from exc
    elapsed_ms = int((time.time() - began) * 1000)
    _record_metric("/agent/chat", elapsed_ms)

    # Counters default to 0 when the engine result does not include them.
    log_record = {
        "endpoint": "/agent/chat",
        "session_id": req.session_id,
        "latency_ms": elapsed_ms,
        "tool_calls_made": outcome.get("tool_calls_made", 0),
        "turns_used": outcome.get("turns_used", 0),
    }
    logger.info(json.dumps(log_record))

    outcome["latency_ms"] = elapsed_ms
    return outcome
def list_agent_sessions():
    """Return the sessions reported by the agent engine."""
    known_sessions = agent_engine.list_sessions()
    return {"sessions": known_sessions}
def clear_agent_session(session_id: str):
    """Ask the agent engine to clear *session_id* and confirm the request."""
    agent_engine.clear_session(session_id)
    confirmation = {"status": "cleared"}
    confirmation["session_id"] = session_id
    return confirmation
def list_sessions():
    """Return the ids of all chat sessions currently held in memory."""
    with session_lock:
        # Snapshot the keys while holding the lock.
        session_ids = list(session_store)
    return {"sessions": session_ids}
def clear_session(session_id: str):
    """Drop the stored history for *session_id*; unknown ids are ignored."""
    with session_lock:
        # pop with a default never raises for a missing session.
        session_store.pop(session_id, None)
    return {"status": "cleared", "session_id": session_id}
def rag_benchmark(req: BenchmarkRequest):
    """Run the same RAG query repeatedly and summarise the latencies.

    Each run is timed with ``time.perf_counter``, a monotonic
    high-resolution clock, so system clock adjustments during the benchmark
    cannot produce negative or skewed samples.

    Returns:
        A dict with the question, the run count, the per-run latencies in
        milliseconds and their average, p50, p95, minimum and maximum.

    Raises:
        HTTPException: 400 when the index is missing or the LLM is not
            configured; 502 when any run raises (all samples are discarded).
    """
    if not index_exists():
        raise HTTPException(
            status_code=400,
            detail="FAISS index not found. Call /index/build first.",
        )
    _require_llm_ready()
    latencies = []
    try:
        for _ in range(req.runs):
            start = time.perf_counter()
            query_rag(req.question, use_reranker=req.use_reranker)
            latencies.append(int((time.perf_counter() - start) * 1000))
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Benchmark failed: {exc}") from exc
    # req.runs is validated to be >= 1, so latencies is never empty here.
    avg = int(sum(latencies) / len(latencies))
    return {
        "question": req.question,
        "runs": req.runs,
        "latencies_ms": latencies,
        "avg_ms": avg,
        "p50_ms": _percentile(latencies, 50),
        "p95_ms": _percentile(latencies, 95),
        "min_ms": min(latencies),
        "max_ms": max(latencies),
    }
def rag_evaluate_json(req: JsonEvalRequest):
    """Evaluate the RAG pipeline against a JSON dataset and attach the latency.

    Raises:
        HTTPException: 400 when the index is missing, the LLM is not
            configured, or ``evaluate_rag_json`` raises.
    """
    if not index_exists():
        raise HTTPException(
            status_code=400,
            detail="FAISS index not found. Call /index/build first.",
        )
    _require_llm_ready()

    began = time.time()
    # NOTE(review): dataset_path comes straight from the request body and is
    # passed on unchanged; confirm evaluate_rag_json restricts it to allowed
    # locations.
    eval_kwargs = {
        "dataset_path": req.dataset_path,
        "base_dir": BASE_DIR,
        "max_questions": req.max_questions,
        "use_reranker": req.use_reranker,
    }
    try:
        report = evaluate_rag_json(**eval_kwargs)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"JSON evaluation failed: {exc}") from exc
    elapsed_ms = int((time.time() - began) * 1000)

    _record_metric("/rag/evaluate/json", elapsed_ms)
    report["latency_ms"] = elapsed_ms
    return report
def metrics():
    """Summarise the recorded latency samples for every endpoint label.

    Returns:
        A dict mapping each label to its sample count, average, p50, p95 and
        most recent latency in milliseconds (zeros when there are no samples).
    """
    return {
        endpoint: {
            "count": len(samples),
            "avg_ms": int(sum(samples) / len(samples)) if samples else 0,
            "p50_ms": _percentile(samples, 50),
            "p95_ms": _percentile(samples, 95),
            "latest_ms": samples[-1] if samples else 0,
        }
        for endpoint, samples in metrics_store.items()
    }
def reset_metrics():
    """Discard every recorded latency sample and confirm the reset."""
    metrics_store.clear()
    acknowledgement = {"status": "reset"}
    return acknowledgement
Xet Storage Details
- Size:
- 15.6 kB
- Xet hash:
- aa63dc2a52076324bc82edde25afd6196e68b4cb60e1491e1d17c38491d610b8
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.