Cortex / api /main.py
aditya-joshi-05's picture
Updates and fixes
27edbb8
"""
Cortex RAG β€” FastAPI Application
Endpoints
─────────
GET /health β†’ system health check
POST /ingest β†’ trigger ingestion pipeline
POST /query β†’ blocking query (JSON response)
POST /query/stream β†’ streaming query (Server-Sent Events)
Phase 1 uses dense-only retrieval.
Later phases will add routing, graph, BM25, and CRAG via the same endpoint.
"""
from __future__ import annotations
import json
import logging
import sys
import os
from pathlib import Path
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from contextlib import asynccontextmanager
from typing import AsyncGenerator, List
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from api.schemas import (
HealthResponse,
IngestRequest,
IngestResponse,
ModelInfo,
ProviderInfo,
ProvidersResponse,
QueryRequest,
QueryResponse,
ChunkResponse,
CitationResponse,
)
from config import get_settings
from generation.generator import PROVIDERS, Generator, GenerationRequest
from generation.crag import CRAGGate
from evaluation.store import EvalStore, QueryLogEntry
from evaluation.ragas_eval import RAGASEvaluator, EvalInput
from retrieval.cache import CachedRetriever
from ingestion.pipeline import IngestionPipeline
from retrieval.dense import MilvusStore
from retrieval.embedder import Embedder
from retrieval.bm25 import BM25Retriever
from retrieval.orchestrator import MultiStrategyRetriever
logger = logging.getLogger(__name__)
# ── Shared singletons ──────────────────────────────────────────
# Created once on startup, shared across requests
_embedder: Embedder = None
_store: MilvusStore = None
_bm25: BM25Retriever = None
_retriever: MultiStrategyRetriever = None
_crag: CRAGGate = None
_eval_store: EvalStore = None
_evaluator: RAGASEvaluator = None
_generator: Generator = None
_pipeline: IngestionPipeline = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialise shared resources on startup, clean up on shutdown."""
global _embedder, _store, _bm25, _retriever, _crag, _generator, _pipeline, _eval_store, _evaluator
logger.info("Cortex starting up...")
cfg = get_settings()
_embedder = Embedder()
_store = MilvusStore(embedder=_embedder)
_bm25 = BM25Retriever()
_retriever = MultiStrategyRetriever(embedder=_embedder, store=_store, bm25=_bm25)
_crag = CRAGGate()
_eval_store = EvalStore(db_path=cfg.eval_db_path)
_evaluator = RAGASEvaluator(store=_eval_store)
_generator = Generator()
# Wrap retriever with Redis cache (degrades gracefully if Redis is absent)
_retriever = CachedRetriever(_retriever)
_pipeline = IngestionPipeline(embedder=_embedder, store=_store, bm25=_bm25)
# Warm up: trigger model load immediately so first request is fast
_ = _embedder.model
logger.info("Cortex ready.")
yield
logger.info("Cortex shutting down.")
# ── App factory ────────────────────────────────────────────────
def create_app() -> FastAPI:
global cfg
cfg = get_settings()
app = FastAPI(
title="Cortex RAG API",
description=(
"Production-grade Retrieval-Augmented Generation system "
"with multi-strategy retrieval, CRAG, and RAGAS evaluation."
),
version="1.0.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # tighten in production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
return app
app = create_app()
# Mount the SPA β€” served at / and all sub-paths not matched by API routes
_STATIC_DIR = Path(__file__).parent.parent / "ui" / "static"
if _STATIC_DIR.exists():
app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static")
# Temporary directory for browser-uploaded files (auto-created)
_UPLOAD_DIR = Path(cfg.upload_dir)
_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
@app.get("/", include_in_schema=False)
async def serve_spa():
return FileResponse(str(_STATIC_DIR / "index.html"))
# ── Routes ─────────────────────────────────────────────────────
@app.get("/health", response_model=HealthResponse, tags=["system"])
async def health() -> HealthResponse:
"""
Returns the health of all system components.
Use this to verify Milvus is reachable and the model is loaded.
"""
milvus_status = "ok"
collection_stats = {}
try:
collection_stats = _store.collection_stats()
except Exception as exc:
milvus_status = f"error: {exc}"
embedder_status = "loaded" if _embedder and _embedder._model else "not_loaded"
graph_stats = {}
try:
graph_stats = _retriever.graph_builder.stats()
except Exception:
pass
return HealthResponse(
status="ok" if milvus_status == "ok" else "degraded",
milvus=milvus_status,
embedder=embedder_status,
collection_stats=collection_stats,
graph_stats=graph_stats,
)
@app.post("/ingest", response_model=IngestResponse, tags=["ingestion"])
async def ingest(req: IngestRequest) -> IngestResponse:
"""
Trigger the ingestion pipeline for a file or directory.
- Deduplicates by doc_id (SHA-256 of file path)
- Returns counts for documents processed, chunks created, and errors
"""
import os
path = req.path
if not os.path.exists(path):
raise HTTPException(status_code=404, detail=f"Path not found: {path}")
try:
if os.path.isfile(path):
stats = _pipeline.ingest_file(path)
else:
stats = _pipeline.ingest_directory(path, recursive=req.recursive)
except Exception as exc:
logger.exception("Ingestion error")
raise HTTPException(status_code=500, detail=str(exc)) from exc
return IngestResponse(**stats)
@app.post("/ingest/upload", response_model=IngestResponse, tags=["ingestion"])
async def ingest_upload(files: List[UploadFile] = File(...)) -> IngestResponse:
"""
Upload files directly from the browser and ingest them.
Accepts one or more files (PDF, HTML, TXT, Markdown).
Files are saved to data/uploads/<original_filename> and then
passed through the same ingestion pipeline as /ingest.
Duplicate filenames are overwritten β€” re-uploading the same
file will be deduplicated at the chunk level by doc_id.
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided.")
saved_paths: list[Path] = []
save_errors: list[dict] = []
for upload in files:
# Sanitise filename β€” strip any path components the browser may include
safe_name = Path(upload.filename).name
if not safe_name:
continue
dest = _UPLOAD_DIR / safe_name
try:
content_bytes = await upload.read()
dest.write_bytes(content_bytes)
saved_paths.append(dest)
logger.info("Uploaded: %s (%d bytes)", safe_name, len(content_bytes))
except Exception as exc:
logger.warning("Failed to save %s: %s", safe_name, exc)
save_errors.append({"source": safe_name, "error": str(exc)})
finally:
await upload.close()
if not saved_paths:
raise HTTPException(status_code=400, detail="No files could be saved.")
# Run ingestion on each saved file
merged: dict = {
"documents_processed": 0,
"documents_skipped": 0,
"chunks_created": 0,
"chunks_stored": 0,
"bm25_indexed": 0,
"graph_entities": 0,
"graph_triples": 0,
"errors": save_errors,
}
for path in saved_paths:
try:
stats = _pipeline.ingest_file(path)
for key in ("documents_processed", "documents_skipped",
"chunks_created", "chunks_stored", "bm25_indexed",
"graph_entities", "graph_triples"):
merged[key] += stats.get(key, 0)
merged["errors"].extend(stats.get("errors", []))
except Exception as exc:
logger.exception("Ingestion error for %s", path.name)
merged["errors"].append({"source": path.name, "error": str(exc)})
return IngestResponse(**merged)
@app.get("/metrics", tags=["evaluation"])
async def get_metrics(limit: int = 100, days: int = 7):
"""
Query performance metrics and RAGAS scores for the dashboard.
Returns summary stats, recent query logs, and hourly timeseries.
"""
return {
"summary": _eval_store.get_summary_stats(),
"recent": _eval_store.get_recent_queries(limit=limit),
"timeseries": _eval_store.get_metric_timeseries(days=days),
"cache": _retriever.cache_stats(),
}
@app.post("/cache/flush", tags=["system"])
async def flush_cache():
"""Flush all Redis retrieval cache entries."""
deleted = _retriever.flush_all()
return {"deleted": deleted}
@app.get("/providers", response_model=ProvidersResponse, tags=["system"])
async def get_providers() -> ProvidersResponse:
"""
Returns the full provider/model catalogue and which providers are
configured (i.e. have an API key in .env).
"""
cfg = get_settings()
infos: list[ProviderInfo] = []
for pid, pdata in PROVIDERS.items():
env_key = pdata["env_key"]
key_set = bool(getattr(cfg, env_key, "") or getattr(cfg, "groq_api_key", ""))
infos.append(ProviderInfo(
id=pid,
label=pdata["label"],
base_url=pdata["base_url"],
models=[ModelInfo(id=m["id"], label=m["label"]) for m in pdata["models"]],
configured=key_set,
))
return ProvidersResponse(
providers=infos,
default_provider=getattr(cfg, "default_provider", "groq"),
default_model=getattr(cfg, "groq_model", "llama-3.3-70b-versatile"),
)
@app.post("/query", response_model=QueryResponse, tags=["retrieval"])
async def query(req: QueryRequest) -> QueryResponse:
"""
Blocking query endpoint.
Retrieves top-k chunks and returns a complete cited answer.
"""
cfg = get_settings()
k = req.top_k or cfg.retrieval_top_k
import time as _time
_t0 = _time.perf_counter()
try:
retrieval = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
except Exception as exc:
logger.exception("Retrieval error")
raise HTTPException(status_code=500, detail=f"Retrieval failed: {exc}")
if retrieval.empty:
return QueryResponse(
query=req.query,
answer="No relevant documents found in the knowledge base.",
citations=[],
retrieved_chunks=[],
model="",
usage={},
)
final_chunks = retrieval.chunks
# CRAG gate: grade, rewrite if POOR, web-search fallback if ABSENT
crag_result = _crag.evaluate(
query=req.query,
chunks=final_chunks,
retriever_fn=lambda q: _retriever.retrieve(q).chunks,
)
final_chunks = crag_result.final_chunks
llm = req.llm or {}
llm_provider = getattr(llm, 'provider', None) if hasattr(llm, 'provider') else None
llm_model = getattr(llm, 'model', None) if hasattr(llm, 'model') else None
llm_api_key = getattr(llm, 'api_key', None) if hasattr(llm, 'api_key') else None
llm_base_url = getattr(llm, 'base_url', None) if hasattr(llm, 'base_url') else None
try:
result = _generator.generate(
GenerationRequest(
query=req.query, chunks=final_chunks,
provider=llm_provider, model=llm_model,
api_key=llm_api_key, base_url=llm_base_url,
)
)
except Exception as exc:
logger.exception("Generation error")
raise HTTPException(status_code=500, detail=f"Generation failed: {exc}")
latency_ms = (_time.perf_counter() - _t0) * 1000
log_id = _eval_store.log_query(QueryLogEntry(
query=req.query,
intent=retrieval.decision.intent.value,
strategies=retrieval.decision.strategies,
retriever_hits=retrieval.retriever_hits,
crag_grade=crag_result.grade.value,
crag_rewritten=bool(crag_result.rewritten_query),
web_search_used=crag_result.web_search_used,
num_chunks=len(final_chunks),
top_chunk_score=final_chunks[0].score if final_chunks else 0.0,
latency_ms=latency_ms,
model=result.model,
))
if cfg.eval_enabled:
_evaluator.evaluate_async(EvalInput(
query_log_id=log_id,
query=req.query,
answer=result.answer,
chunks=final_chunks,
))
return QueryResponse(
query=req.query,
answer=result.answer,
citations=[
CitationResponse(
number=c.number,
title=c.title,
source=c.source,
chunk_id=c.chunk_id,
score=c.score,
)
for c in result.citations
],
retrieved_chunks=[
ChunkResponse(
chunk_id=ch.chunk_id,
doc_id=ch.doc_id,
source=ch.source,
title=ch.title,
text=ch.text,
score=ch.score,
)
for ch in final_chunks
],
model=result.model,
usage=result.usage,
)
@app.post("/query/stream", tags=["retrieval"])
async def query_stream(req: QueryRequest):
"""
Streaming query endpoint using Server-Sent Events (SSE).
Event types emitted:
- data: {"type": "chunk_meta", "chunks": [...]} β€” retrieved chunks
- data: {"type": "token", "text": "..."} β€” answer tokens
- data: {"type": "sources", "text": "..."} β€” sources block
- data: {"type": "done"} β€” stream complete
- data: {"type": "error", "message": "..."} β€” error event
"""
cfg = get_settings()
k = req.top_k or cfg.retrieval_top_k
print(req)
async def event_stream() -> AsyncGenerator[str, None]:
try:
# 1. Retrieve
# 1. Multi-strategy retrieval: router β†’ dense+BM25 β†’ RRF β†’ cross-encoder
result = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
final_chunks = result.chunks
# 2. Emit chunk metadata + routing decision so UI shows sources + strategy info immediately
chunk_meta = [
{
"chunk_id": c.chunk_id,
"title": c.title,
"source": c.source,
"score": round(c.score, 4),
"retriever": c.retriever,
"text_snippet": c.text[:200],
}
for c in final_chunks
]
yield _sse_event({
"type": "chunk_meta",
"chunks": chunk_meta,
"routing": {
"intent": result.decision.intent.value,
"strategies": result.decision.strategies,
"retriever_hits": result.retriever_hits,
"reasoning": result.decision.reasoning,
},
})
if not final_chunks:
yield _sse_event({
"type": "token",
"text": "No relevant documents found in the knowledge base.",
})
yield _sse_event({"type": "done"})
return
# 3. CRAG gate β€” grade, optionally rewrite + re-retrieve
crag_result = _crag.evaluate(
query=req.query,
chunks=final_chunks,
retriever_fn=lambda q: _retriever.retrieve(q).chunks,
)
final_chunks = crag_result.final_chunks
# Emit CRAG event if something interesting happened
if crag_result.grade.value != "GOOD" or crag_result.web_search_used:
yield _sse_event({
"type": "crag_update",
"grade": crag_result.grade.value,
"rewritten_query": crag_result.rewritten_query,
"web_search_used": crag_result.web_search_used,
"reasoning": crag_result.reasoning,
})
# 4. Stream answer tokens
_llm = req.llm or {}
gen_request = GenerationRequest(
query=req.query, chunks=final_chunks, stream=True,
provider=getattr(_llm, 'provider', None),
model=getattr(_llm, 'model', None),
api_key=getattr(_llm, 'api_key', None),
base_url=getattr(_llm, 'base_url', None),
)
for token in _generator.stream(gen_request):
yield _sse_event({"type": "token", "text": token})
# 4. Emit sources block
sources = _generator.build_sources_block(final_chunks)
yield _sse_event({"type": "sources", "text": sources})
# 5. Signal completion
yield _sse_event({"type": "done"})
except Exception as exc:
logger.exception("Streaming error")
yield _sse_event({"type": "error", "message": str(exc)})
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # disable nginx buffering
},
)
# ── SSE helper ─────────────────────────────────────────────────
def _sse_event(data: dict) -> str:
"""Format a dict as a Server-Sent Event string."""
return f"data: {json.dumps(data)}\n\n"
# ── Dev server entry point ─────────────────────────────────────
if __name__ == "__main__":
import uvicorn
cfg = get_settings()
logging.basicConfig(
level=getattr(logging, cfg.log_level),
format="%(asctime)s %(levelname)-7s %(name)s β€” %(message)s",
)
uvicorn.run(
"api.main:app",
host=cfg.api_host,
port=cfg.api_port,
reload=cfg.api_reload,
)