meet4150's picture
download
raw
15.6 kB
import json
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from threading import Lock
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from agent.engine import AgentEngine
from model_server.llm_client import get_llm
from rag.evaluator import evaluate_rag_json
from rag.pipeline import query_rag
from rag.reranker import rerank
from rag.vector_store import build_index, index_exists, retrieve
# Load environment settings (the error messages below point users at a .env file).
load_dotenv()

# Log bare messages only, so the JSON payloads emitted by the endpoints stay parseable.
logging.basicConfig(format="%(message)s", level="INFO")
logger = logging.getLogger(__name__)

app = FastAPI(title="Insurance RAG API", version="1.0.0")

# The frontend is expected in a "frontend" folder next to this module.
BASE_DIR = Path(__file__).resolve().parent
FRONTEND_DIR = BASE_DIR / "frontend"

# Number of conversation pairs to keep; the history helpers retain twice this many messages.
MAX_HISTORY = int(os.getenv("MAX_HISTORY", 12))
# Only the (case-insensitive) value "true" enables building the index at startup.
AUTO_BUILD_INDEX = "true" == os.getenv("AUTO_BUILD_INDEX", "false").lower()

# Static assets are mounted only when the frontend directory is present.
if FRONTEND_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")
class QueryRequest(BaseModel):
    """Request body for ``POST /rag/query``."""

    # Question to answer; required, at least 3 characters.
    question: str = Field(..., min_length=3)
    # Session label; the endpoint only echoes it into its log payload.
    session_id: str = Field(default="default")
    # Forwarded to query_rag as its use_reranker flag.
    use_reranker: bool = Field(default=True)
class ChatRequest(BaseModel):
    """Request body for the single-turn ``/chat`` and ``/chat/stream`` endpoints."""

    # Message sent to the LLM; required, at least 1 character.
    message: str = Field(..., min_length=1)
    # Accepted for API symmetry; the single-turn endpoints do not read it.
    session_id: str = Field(default="default")
class ContinuousChatRequest(BaseModel):
    """Request body for ``/chat/continuous`` and ``/chat/continuous/stream``."""

    # Message for this turn; required, at least 1 character.
    message: str = Field(..., min_length=1)
    # Key under which the conversation history is stored and looked up.
    session_id: str = Field(default="default")
    # When true, retrieved policy excerpts are added to the prompt.
    use_rag_context: bool = Field(default=True)
    # When true (and more than one chunk was retrieved), chunks are reranked.
    use_reranker: bool = Field(default=True)
class BenchmarkRequest(BaseModel):
    """Request body for ``POST /rag/benchmark``."""

    # Question to run repeatedly; required, at least 3 characters.
    question: str = Field(..., min_length=3)
    # How many timed runs to execute; between 1 and 100 inclusive.
    runs: int = Field(default=5, ge=1, le=100)
    # Forwarded to query_rag as its use_reranker flag.
    use_reranker: bool = Field(default=True)
class JsonEvalRequest(BaseModel):
    """Request body for ``POST /rag/evaluate/json``."""

    # Dataset location handed to evaluate_rag_json together with BASE_DIR
    # (presumably resolved relative to that directory; confirm in rag.evaluator).
    dataset_path: str = Field(default="../rag.json")
    # Upper bound on evaluated questions; between 1 and 200 inclusive.
    max_questions: int = Field(default=10, ge=1, le=200)
    # Forwarded to evaluate_rag_json as its use_reranker flag.
    use_reranker: bool = Field(default=True)
class AgentChatRequest(BaseModel):
    """Request body for ``POST /agent/chat``."""

    # User message for the agent; required, at least 1 character.
    message: str = Field(..., min_length=1)
    # Session identifier passed to the agent engine.
    session_id: str = Field(default="default-agent")
    # Forwarded to the agent engine as its use_reranker flag.
    use_reranker: bool = Field(default=True)
    # Turn budget passed to the agent engine; between 1 and 6 inclusive.
    max_turns: int = Field(default=6, ge=1, le=6)
# Conversation transcripts (HumanMessage/AIMessage lists) keyed by session id.
session_store: dict[str, list] = defaultdict(list)
# Held by the session helpers and session endpoints while touching session_store.
session_lock = Lock()
# Latency samples in milliseconds, keyed by endpoint path.
metrics_store: dict[str, list[int]] = defaultdict(list)
# Engine behind the /agent/* endpoints; TOP_K (default 4) sets its default_top_k.
agent_engine = AgentEngine(
    max_history_pairs=MAX_HISTORY,
    default_top_k=int(os.getenv("TOP_K", 4)),
)
def _record_metric(endpoint: str, latency_ms: int) -> None:
    """Store a latency sample for *endpoint*, keeping only the 200 newest samples.

    Args:
        endpoint: Key under which the sample is filed (the route path).
        latency_ms: Measured latency in milliseconds.
    """
    samples = metrics_store[endpoint]
    samples.append(latency_ms)
    # Rebind to a trimmed copy so the per-endpoint list never exceeds 200 entries.
    metrics_store[endpoint] = samples[-200:]
def _require_llm_ready() -> None:
    """Abort the request with HTTP 400 when no LLM model id is configured.

    The id is read from ``HF_LLM_MODEL`` with a built-in fallback, so the check
    only fails when the variable is set to an empty or whitespace-only value.

    Raises:
        HTTPException: Status 400 if the stripped model id is empty.
    """
    configured_model = os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct").strip()
    if configured_model:
        return
    raise HTTPException(
        status_code=400,
        detail="HF_LLM_MODEL is missing. Set it in .env before using LLM endpoints.",
    )
def _llm_configured() -> bool:
    """Return True when ``HF_LLM_MODEL`` (or its built-in default) is a non-blank id."""
    return os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct").strip() != ""
def _percentile(values: list[int], p: int) -> int:
    """Return the *p*-th percentile of *values* using a floor-rank rule.

    The values are sorted and the element at index ``floor(p * (n - 1) / 100)``
    is returned, so ``p=0`` yields the minimum and ``p=100`` the maximum.

    Args:
        values: Samples to summarise; may be empty.
        p: Percentile in the range 0-100. Out-of-range values are clamped.

    Returns:
        The selected sample, or 0 when *values* is empty.
    """
    if not values:
        return 0
    ordered = sorted(values)
    last_index = len(ordered) - 1
    # Clamp so p > 100 cannot index past the end and p < 0 cannot wrap around.
    bounded_p = min(max(p, 0), 100)
    # Multiply before dividing and use floor division: computing p / 100 first
    # goes through binary floating point, where e.g. 0.29 * 100 evaluates to
    # 28.999999999999996 and truncates to the wrong rank.
    rank = int((bounded_p * last_index) // 100)
    return ordered[rank]
def _session_history(session_id: str) -> list:
    """Return a copy of the most recent messages stored for *session_id*.

    At most ``MAX_HISTORY * 2`` messages are returned (one human and one
    assistant message per remembered exchange). Unknown sessions yield an empty
    list and are not created as a side effect.

    Args:
        session_id: Key of the conversation in ``session_store``.

    Returns:
        A new list holding the tail of the stored transcript.
    """
    limit = MAX_HISTORY * 2
    # A non-positive limit must mean "no history": history[-0:] would instead
    # return the entire transcript because -0 == 0.
    if limit <= 0:
        return []
    with session_lock:
        # Slice while holding the lock so a concurrent append cannot interleave.
        return session_store.get(session_id, [])[-limit:]
def _append_session(session_id: str, user_message: str, assistant_message: str) -> None:
    """Record one exchange for *session_id* and trim the stored transcript.

    The exchange is stored as a ``HumanMessage`` followed by an ``AIMessage``;
    afterwards only the newest ``MAX_HISTORY * 2`` messages are kept.

    Args:
        session_id: Key of the conversation in ``session_store``.
        user_message: Text the user sent this turn.
        assistant_message: Text the assistant replied with.
    """
    limit = MAX_HISTORY * 2
    new_turn = [HumanMessage(content=user_message), AIMessage(content=assistant_message)]
    with session_lock:
        transcript = session_store.get(session_id, []) + new_turn
        # Guard the slice: with limit == 0, transcript[-0:] keeps everything,
        # which would silently disable trimming and let the store grow unbounded.
        session_store[session_id] = transcript[-limit:] if limit > 0 else []
@app.on_event("startup")
def startup_event() -> None:
    """Build the FAISS index at startup when auto-build is enabled and none exists."""
    # index_exists() is only consulted when auto-build is switched on.
    if not AUTO_BUILD_INDEX or index_exists():
        return
    logger.info("No FAISS index found. Building index automatically...")
    build_index()
@app.get("/", include_in_schema=False)
def serve_frontend():
    """Serve ``frontend/index.html``.

    Raises:
        HTTPException: Status 404 when the page file does not exist.
    """
    index_page = FRONTEND_DIR / "index.html"
    if not index_page.exists():
        raise HTTPException(status_code=404, detail="Frontend not found")
    return FileResponse(index_page)
@app.get("/health")
def health():
    """Report service status, index availability and the configured LLM model."""
    index_ready = index_exists()
    llm_ok = _llm_configured()
    model_name = os.getenv("HF_LLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
    return {
        "status": "ok",
        "index_ready": index_ready,
        "llm_configured": llm_ok,
        "llm_model": model_name,
    }
@app.post("/index/build")
def index_build():
    """Rebuild the index and report how long it took and the index object's type name."""
    started_at = time.time()
    index = build_index()
    elapsed_ms = int((time.time() - started_at) * 1000)
    return {
        "status": "rebuilt",
        "latency_ms": elapsed_ms,
        "index_type": type(index).__name__,
    }
@app.post("/rag/query")
def rag_query(req: QueryRequest):
    """Main RAG endpoint — retrieve + answer with source attribution.

    Returns the dict produced by ``query_rag`` with an added ``latency_ms`` key.

    Raises:
        HTTPException: 400 when the index or LLM configuration is missing,
            502 when ``query_rag`` raises.
    """
    if not index_exists():
        raise HTTPException(
            status_code=400,
            detail="FAISS index not found. Call /index/build first.",
        )
    _require_llm_ready()

    started_at = time.time()
    try:
        answer = query_rag(req.question, use_reranker=req.use_reranker)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"RAG generation failed: {exc}") from exc
    elapsed_ms = int((time.time() - started_at) * 1000)

    _record_metric("/rag/query", elapsed_ms)
    # One JSON line per request so logs can be parsed downstream.
    logger.info(
        json.dumps(
            {
                "endpoint": "/rag/query",
                "session_id": req.session_id,
                "latency_ms": elapsed_ms,
                "chunks_used": answer["chunks_used"],
                "sources": answer["sources"],
            }
        )
    )
    answer["latency_ms"] = elapsed_ms
    return answer
@app.post("/chat")
def chat(req: ChatRequest):
    """Single-turn direct LLM chat (no RAG).

    Raises:
        HTTPException: 400 when the LLM is not configured, 502 when the call fails.
    """
    _require_llm_ready()
    started_at = time.time()
    prompt = [HumanMessage(content=req.message)]
    try:
        reply = get_llm().invoke(prompt)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Chat call failed: {exc}") from exc
    elapsed_ms = int((time.time() - started_at) * 1000)
    _record_metric("/chat", elapsed_ms)
    return {"response": reply.content, "latency_ms": elapsed_ms}
@app.post("/chat/stream")
def chat_stream(req: ChatRequest):
    """SSE streaming endpoint — tokens arrive progressively.

    Each event is a ``data:`` line carrying a JSON object with a ``token`` key
    (or an ``error`` key if streaming fails), followed by a final ``[DONE]`` marker.
    """
    _require_llm_ready()
    streaming_llm = get_llm(streaming=True)

    def _sse(payload: dict) -> str:
        # Format one server-sent event carrying a JSON payload.
        return f"data: {json.dumps(payload)}\n\n"

    async def generate():
        try:
            async for piece in streaming_llm.astream([HumanMessage(content=req.message)]):
                text = piece.content or ""
                if not text:
                    # Skip empty chunks rather than emitting blank token events.
                    continue
                yield _sse({"token": text})
        except Exception as exc:
            # The response has already started, so report the failure in-band.
            yield _sse({"error": f"Chat stream failed: {exc}"})
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
def _build_continuous_messages(req: ContinuousChatRequest) -> tuple[list, list]:
    """Assemble the LLM messages and source metadata for a continuous-chat turn.

    Shared by ``/chat/continuous`` and ``/chat/continuous/stream`` so both build
    identical prompts. When ``req.use_rag_context`` is set, chunks are retrieved
    for the message (and reranked to at most 3 when requested and more than one
    chunk came back) and embedded in the user prompt.

    Args:
        req: The validated request body.

    Returns:
        ``(messages, sources)``: the system prompt, stored session history and
        the new user prompt, plus one metadata dict per chunk used (empty when
        RAG context is disabled).

    Raises:
        HTTPException: Status 400 when RAG context is requested but no index exists.
    """
    sources: list = []
    context = ""
    if req.use_rag_context:
        if not index_exists():
            raise HTTPException(
                status_code=400,
                detail="FAISS index not found. Call /index/build first.",
            )
        chunks = retrieve(req.message)
        if req.use_reranker and len(chunks) > 1:
            chunks = rerank(req.message, chunks, top_k=min(3, len(chunks)))
        sources = [
            {
                "file": c["source"],
                "section": c.get("section", "Unknown section"),
                "score": c["score"],
                "rerank_score": c.get("rerank_score"),
            }
            for c in chunks
        ]
        context = "\n\n---\n\n".join(
            f"[Document: {c['source']} | Section: {c.get('section', 'Unknown section')} | Similarity: {c['score']}]\n{c['content']}"
            for c in chunks
        )

    system_prompt = (
        "You are an insurance broker assistant for InsureCo. "
        "Use prior conversation context when relevant. "
        "If policy context is provided, stay grounded in it and cite document + section."
    )
    user_prompt = req.message
    if context:
        user_prompt = (
            f"POLICY EXCERPTS:\n{context}\n\n"
            f"BROKER MESSAGE: {req.message}\n\n"
            "Respond clearly and cite source document/section when you use the excerpts."
        )

    messages = [SystemMessage(content=system_prompt)]
    messages.extend(_session_history(req.session_id))
    messages.append(HumanMessage(content=user_prompt))
    return messages, sources


@app.post("/chat/continuous")
def chat_continuous(req: ContinuousChatRequest):
    """Session-aware multi-turn chat with optional RAG grounding.

    Raises:
        HTTPException: 400 when the LLM or (if RAG is requested) the index is
            missing, 502 when the LLM call fails.
    """
    _require_llm_ready()
    start = time.time()
    messages, sources = _build_continuous_messages(req)
    try:
        response = get_llm().invoke(messages)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Continuous chat failed: {exc}") from exc
    answer = response.content
    # Store the raw user message, not the excerpt-augmented prompt.
    _append_session(req.session_id, req.message, answer)
    latency = int((time.time() - start) * 1000)
    _record_metric("/chat/continuous", latency)
    return {
        "response": answer,
        "session_id": req.session_id,
        "sources": sources,
        "latency_ms": latency,
    }


@app.post("/chat/continuous/stream")
def chat_continuous_stream(req: ContinuousChatRequest):
    """Streaming multi-turn chat endpoint with optional RAG grounding.

    Emits ``token`` events while the model streams, then a ``meta`` event with
    the latency and sources, then ``[DONE]``. On failure an ``error`` event is
    sent instead and the exchange is not saved to the session.
    """
    _require_llm_ready()
    messages, sources = _build_continuous_messages(req)
    llm = get_llm(streaming=True)

    async def generate():
        # Latency covers the streaming phase only; retrieval ran before this point.
        start = time.time()
        collected = []
        try:
            async for chunk in llm.astream(messages):
                token = chunk.content or ""
                if token:
                    collected.append(token)
                    yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception as exc:
            yield f"data: {json.dumps({'error': f'Continuous stream failed: {exc}'})}\n\n"
            yield "data: [DONE]\n\n"
            return
        answer = "".join(collected)
        _append_session(req.session_id, req.message, answer)
        latency = int((time.time() - start) * 1000)
        _record_metric("/chat/continuous/stream", latency)
        yield f"data: {json.dumps({'meta': {'latency_ms': latency, 'sources': sources}})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
@app.post("/agent/chat")
def agent_chat(req: AgentChatRequest):
    """Hand-rolled multi-turn tool-calling agent endpoint.

    Returns the dict produced by ``agent_engine.chat`` with an added
    ``latency_ms`` key.

    Raises:
        HTTPException: 400 when the LLM is not configured, 502 when the agent fails.
    """
    _require_llm_ready()
    started_at = time.time()
    try:
        outcome = agent_engine.chat(
            session_id=req.session_id,
            user_message=req.message,
            use_reranker=req.use_reranker,
            max_turns=req.max_turns,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Agent chat failed: {exc}") from exc
    elapsed_ms = int((time.time() - started_at) * 1000)
    _record_metric("/agent/chat", elapsed_ms)

    # Counters default to 0 when the engine result omits them.
    log_entry = {
        "endpoint": "/agent/chat",
        "session_id": req.session_id,
        "latency_ms": elapsed_ms,
        "tool_calls_made": outcome.get("tool_calls_made", 0),
        "turns_used": outcome.get("turns_used", 0),
    }
    logger.info(json.dumps(log_entry))

    outcome["latency_ms"] = elapsed_ms
    return outcome
@app.get("/agent/sessions")
def list_agent_sessions():
    """Return whatever ``agent_engine.list_sessions()`` reports, under a ``sessions`` key."""
    known_sessions = agent_engine.list_sessions()
    return {"sessions": known_sessions}
@app.delete("/agent/session/{session_id}")
def clear_agent_session(session_id: str):
    """Ask the agent engine to clear *session_id* and acknowledge the request."""
    agent_engine.clear_session(session_id)
    acknowledgement = {"status": "cleared", "session_id": session_id}
    return acknowledgement
@app.get("/chat/sessions")
def list_sessions():
    """List the ids of all chat sessions currently held in ``session_store``."""
    with session_lock:
        # Copy the keys while holding the lock so the snapshot is consistent.
        session_ids = list(session_store)
    return {"sessions": session_ids}
@app.delete("/chat/session/{session_id}")
def clear_session(session_id: str):
    """Drop the stored history for *session_id*; unknown ids are ignored."""
    with session_lock:
        # pop with a default so clearing a missing session is not an error.
        session_store.pop(session_id, None)
    acknowledgement = {"status": "cleared", "session_id": session_id}
    return acknowledgement
@app.post("/rag/benchmark")
def rag_benchmark(req: BenchmarkRequest):
    """Run the same RAG query ``req.runs`` times and summarise the latencies.

    Raises:
        HTTPException: 400 when the index or LLM configuration is missing,
            502 when any run fails.
    """
    if not index_exists():
        raise HTTPException(
            status_code=400,
            detail="FAISS index not found. Call /index/build first.",
        )
    _require_llm_ready()

    samples = []
    try:
        for _ in range(req.runs):
            tick = time.time()
            query_rag(req.question, use_reranker=req.use_reranker)
            samples.append(int((time.time() - tick) * 1000))
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Benchmark failed: {exc}") from exc

    # req.runs is validated to be >= 1, so samples is non-empty here.
    mean_ms = int(sum(samples) / len(samples))
    return {
        "question": req.question,
        "runs": req.runs,
        "latencies_ms": samples,
        "avg_ms": mean_ms,
        "p50_ms": _percentile(samples, 50),
        "p95_ms": _percentile(samples, 95),
        "min_ms": min(samples),
        "max_ms": max(samples),
    }
@app.post("/rag/evaluate/json")
def rag_evaluate_json(req: JsonEvalRequest):
    """Evaluate the RAG pipeline against a JSON dataset and return the report.

    Returns the dict produced by ``evaluate_rag_json`` with an added
    ``latency_ms`` key.

    Raises:
        HTTPException: 400 when the index or LLM configuration is missing, and
            also 400 when the evaluation itself raises.
    """
    if not index_exists():
        raise HTTPException(
            status_code=400,
            detail="FAISS index not found. Call /index/build first.",
        )
    _require_llm_ready()

    started_at = time.time()
    # NOTE(review): dataset_path comes straight from the request body; confirm
    # that evaluate_rag_json restricts which files it is willing to open.
    eval_kwargs = {
        "dataset_path": req.dataset_path,
        "base_dir": BASE_DIR,
        "max_questions": req.max_questions,
        "use_reranker": req.use_reranker,
    }
    try:
        report = evaluate_rag_json(**eval_kwargs)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"JSON evaluation failed: {exc}") from exc
    elapsed_ms = int((time.time() - started_at) * 1000)
    _record_metric("/rag/evaluate/json", elapsed_ms)
    report["latency_ms"] = elapsed_ms
    return report
@app.get("/metrics")
def metrics():
    """Summarise the recorded latencies (count, mean, p50, p95, latest) per endpoint."""

    def _summarise(samples: list) -> dict:
        # Every statistic falls back to 0 for an empty sample list.
        return {
            "count": len(samples),
            "avg_ms": int(sum(samples) / len(samples)) if samples else 0,
            "p50_ms": _percentile(samples, 50),
            "p95_ms": _percentile(samples, 95),
            "latest_ms": samples[-1] if samples else 0,
        }

    return {endpoint: _summarise(samples) for endpoint, samples in metrics_store.items()}
@app.post("/metrics/reset")
def reset_metrics():
    """Discard every recorded latency sample for all endpoints."""
    metrics_store.clear()
    acknowledgement = {"status": "reset"}
    return acknowledgement

Xet Storage Details

Size:
15.6 kB
·
Xet hash:
aa63dc2a52076324bc82edde25afd6196e68b4cb60e1491e1d17c38491d610b8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.