| """ | |
| Cortex RAG β FastAPI Application | |
| Endpoints | |
| βββββββββ | |
| GET /health β system health check | |
| POST /ingest β trigger ingestion pipeline | |
| POST /query β blocking query (JSON response) | |
| POST /query/stream β streaming query (Server-Sent Events) | |
| Phase 1 uses dense-only retrieval. | |
| Later phases will add routing, graph, BM25, and CRAG via the same endpoint. | |
| """ | |
from __future__ import annotations

import json
import logging
import os
import sys
import time
from pathlib import Path

# Add project root to path so absolute imports resolve when run as a script
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from contextlib import asynccontextmanager
from typing import AsyncGenerator, List

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from api.schemas import (
    HealthResponse,
    IngestRequest,
    IngestResponse,
    ModelInfo,
    ProviderInfo,
    ProvidersResponse,
    QueryRequest,
    QueryResponse,
    ChunkResponse,
    CitationResponse,
)
from config import get_settings
from generation.generator import PROVIDERS, Generator, GenerationRequest
from generation.crag import CRAGGate
from evaluation.store import EvalStore, QueryLogEntry
from evaluation.ragas_eval import RAGASEvaluator, EvalInput
from retrieval.cache import CachedRetriever
from ingestion.pipeline import IngestionPipeline
from retrieval.dense import MilvusStore
from retrieval.embedder import Embedder
from retrieval.bm25 import BM25Retriever
from retrieval.orchestrator import MultiStrategyRetriever

logger = logging.getLogger(__name__)
# ── Shared singletons ──────────────────────────────────────────
# Created once on startup, shared across requests
_embedder: Embedder | None = None
_store: MilvusStore | None = None
_bm25: BM25Retriever | None = None
_retriever: MultiStrategyRetriever | CachedRetriever | None = None
_crag: CRAGGate | None = None
_eval_store: EvalStore | None = None
_evaluator: RAGASEvaluator | None = None
_generator: Generator | None = None
_pipeline: IngestionPipeline | None = None

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Initialise shared resources on startup, clean up on shutdown."""
    global _embedder, _store, _bm25, _retriever, _crag, _generator, _pipeline, _eval_store, _evaluator
    logger.info("Cortex starting up...")
    cfg = get_settings()
    _embedder = Embedder()
    _store = MilvusStore(embedder=_embedder)
    _bm25 = BM25Retriever()
    _retriever = MultiStrategyRetriever(embedder=_embedder, store=_store, bm25=_bm25)
    _crag = CRAGGate()
    _eval_store = EvalStore(db_path=cfg.eval_db_path)
    _evaluator = RAGASEvaluator(store=_eval_store)
    _generator = Generator()
    # Wrap retriever with Redis cache (degrades gracefully if Redis is absent)
    _retriever = CachedRetriever(_retriever)
    _pipeline = IngestionPipeline(embedder=_embedder, store=_store, bm25=_bm25)
    # Warm up: trigger the embedding model load now so the first request is fast
    _ = _embedder.model
    logger.info("Cortex ready.")
    yield
    logger.info("Cortex shutting down.")

# ── App factory ────────────────────────────────────────────────
def create_app() -> FastAPI:
    app = FastAPI(
        title="Cortex RAG API",
        description=(
            "Production-grade Retrieval-Augmented Generation system "
            "with multi-strategy retrieval, CRAG, and RAGAS evaluation."
        ),
        version="1.0.0",
        lifespan=lifespan,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # tighten in production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app

app = create_app()
cfg = get_settings()  # module-level settings, shared by the routes below

# Mount SPA assets - static files under /static, index.html served at /
_STATIC_DIR = Path(__file__).parent.parent / "ui" / "static"
if _STATIC_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static")

# Temporary directory for browser-uploaded files (auto-created)
_UPLOAD_DIR = Path(cfg.upload_dir)
_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)


@app.get("/", include_in_schema=False)
async def serve_spa() -> FileResponse:
    return FileResponse(str(_STATIC_DIR / "index.html"))

# ── Routes ─────────────────────────────────────────────────────
@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
    """
    Returns the health of all system components.

    Use this to verify Milvus is reachable and the model is loaded.
    """
    milvus_status = "ok"
    collection_stats = {}
    try:
        collection_stats = _store.collection_stats()
    except Exception as exc:
        milvus_status = f"error: {exc}"
    # Check the private attribute so the probe never triggers a model load
    embedder_status = "loaded" if _embedder and _embedder._model else "not_loaded"
    graph_stats = {}
    try:
        graph_stats = _retriever.graph_builder.stats()
    except Exception:
        pass
    return HealthResponse(
        status="ok" if milvus_status == "ok" else "degraded",
        milvus=milvus_status,
        embedder=embedder_status,
        collection_stats=collection_stats,
        graph_stats=graph_stats,
    )

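# Illustrative /health response (field names from HealthResponse; values are examples):
#   {"status": "ok", "milvus": "ok", "embedder": "loaded",
#    "collection_stats": {...}, "graph_stats": {...}}
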
@app.post("/ingest", response_model=IngestResponse)
async def ingest(req: IngestRequest) -> IngestResponse:
    """
    Trigger the ingestion pipeline for a file or directory.

    - Deduplicates by doc_id (SHA-256 of file path)
    - Returns counts for documents processed, chunks created, and errors
    """
    path = req.path
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail=f"Path not found: {path}")
    try:
        if os.path.isfile(path):
            stats = _pipeline.ingest_file(path)
        else:
            stats = _pipeline.ingest_directory(path, recursive=req.recursive)
    except Exception as exc:
        logger.exception("Ingestion error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return IngestResponse(**stats)

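# Example request (fields from IngestRequest; the path value is illustrative):
#   curl -X POST localhost:8000/ingest -H 'Content-Type: application/json' \
#        -d '{"path": "data/docs", "recursive": true}'
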
@app.post("/ingest/upload", response_model=IngestResponse)  # path assumed; not named in the module docstring
async def ingest_upload(files: List[UploadFile] = File(...)) -> IngestResponse:
    """
    Upload files directly from the browser and ingest them.

    Accepts one or more files (PDF, HTML, TXT, Markdown).
    Files are saved to data/uploads/<original_filename> and then
    passed through the same ingestion pipeline as /ingest.
    Duplicate filenames are overwritten - re-uploading the same
    file will be deduplicated at the chunk level by doc_id.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided.")
    saved_paths: list[Path] = []
    save_errors: list[dict] = []
    for upload in files:
        # Sanitise filename - strip any path components the browser may include
        safe_name = Path(upload.filename or "").name
        if not safe_name:
            continue
        dest = _UPLOAD_DIR / safe_name
        try:
            content_bytes = await upload.read()
            dest.write_bytes(content_bytes)
            saved_paths.append(dest)
            logger.info("Uploaded: %s (%d bytes)", safe_name, len(content_bytes))
        except Exception as exc:
            logger.warning("Failed to save %s: %s", safe_name, exc)
            save_errors.append({"source": safe_name, "error": str(exc)})
        finally:
            await upload.close()
    if not saved_paths:
        raise HTTPException(status_code=400, detail="No files could be saved.")
    # Run ingestion on each saved file, merging the per-file stats
    merged: dict = {
        "documents_processed": 0,
        "documents_skipped": 0,
        "chunks_created": 0,
        "chunks_stored": 0,
        "bm25_indexed": 0,
        "graph_entities": 0,
        "graph_triples": 0,
        "errors": save_errors,
    }
    for path in saved_paths:
        try:
            stats = _pipeline.ingest_file(path)
            for key in ("documents_processed", "documents_skipped",
                        "chunks_created", "chunks_stored", "bm25_indexed",
                        "graph_entities", "graph_triples"):
                merged[key] += stats.get(key, 0)
            merged["errors"].extend(stats.get("errors", []))
        except Exception as exc:
            logger.exception("Ingestion error for %s", path.name)
            merged["errors"].append({"source": path.name, "error": str(exc)})
    return IngestResponse(**merged)

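# Example upload (multipart/form-data; the /ingest/upload path is the
# assumption noted on the decorator above):
#   curl -X POST localhost:8000/ingest/upload -F 'files=@report.pdf' -F 'files=@notes.md'
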
@app.get("/metrics")  # path assumed; not named in the module docstring
async def get_metrics(limit: int = 100, days: int = 7):
    """
    Query performance metrics and RAGAS scores for the dashboard.

    Returns summary stats, recent query logs, and hourly timeseries.
    """
    return {
        "summary": _eval_store.get_summary_stats(),
        "recent": _eval_store.get_recent_queries(limit=limit),
        "timeseries": _eval_store.get_metric_timeseries(days=days),
        "cache": _retriever.cache_stats(),
    }


@app.post("/cache/flush")  # path assumed; not named in the module docstring
async def flush_cache():
    """Flush all Redis retrieval cache entries."""
    deleted = _retriever.flush_all()
    return {"deleted": deleted}

@app.get("/providers", response_model=ProvidersResponse)  # path assumed; not named in the module docstring
async def get_providers() -> ProvidersResponse:
    """
    Returns the full provider/model catalogue and which providers are
    configured (i.e. have an API key in .env).
    """
    cfg = get_settings()
    infos: list[ProviderInfo] = []
    for pid, pdata in PROVIDERS.items():
        env_key = pdata["env_key"]
        # NOTE: falls back to the Groq key, so every provider reports
        # "configured" whenever GROQ_API_KEY is set
        key_set = bool(getattr(cfg, env_key, "") or getattr(cfg, "groq_api_key", ""))
        infos.append(ProviderInfo(
            id=pid,
            label=pdata["label"],
            base_url=pdata["base_url"],
            models=[ModelInfo(id=m["id"], label=m["label"]) for m in pdata["models"]],
            configured=key_set,
        ))
    return ProvidersResponse(
        providers=infos,
        default_provider=getattr(cfg, "default_provider", "groq"),
        default_model=getattr(cfg, "groq_model", "llama-3.3-70b-versatile"),
    )

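# Illustrative /providers response entry (shape follows ProviderInfo; values vary):
#   {"id": "groq", "label": "...", "base_url": "...",
#    "models": [{"id": "llama-3.3-70b-versatile", "label": "..."}], "configured": true}
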
@app.post("/query", response_model=QueryResponse)
async def query(req: QueryRequest) -> QueryResponse:
    """
    Blocking query endpoint.

    Retrieves top-k chunks and returns a complete cited answer.
    """
    cfg = get_settings()
    k = req.top_k or cfg.retrieval_top_k
    t0 = time.perf_counter()
    try:
        retrieval = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
    except Exception as exc:
        logger.exception("Retrieval error")
        raise HTTPException(status_code=500, detail=f"Retrieval failed: {exc}") from exc
    if retrieval.empty:
        return QueryResponse(
            query=req.query,
            answer="No relevant documents found in the knowledge base.",
            citations=[],
            retrieved_chunks=[],
            model="",
            usage={},
        )
    final_chunks = retrieval.chunks
    # CRAG gate: grade, rewrite if POOR, web-search fallback if ABSENT
    crag_result = _crag.evaluate(
        query=req.query,
        chunks=final_chunks,
        retriever_fn=lambda q: _retriever.retrieve(q).chunks,
    )
    final_chunks = crag_result.final_chunks
    # Optional per-request LLM override; getattr handles a missing/None req.llm
    llm = req.llm
    llm_provider = getattr(llm, "provider", None)
    llm_model = getattr(llm, "model", None)
    llm_api_key = getattr(llm, "api_key", None)
    llm_base_url = getattr(llm, "base_url", None)
    try:
        result = _generator.generate(
            GenerationRequest(
                query=req.query, chunks=final_chunks,
                provider=llm_provider, model=llm_model,
                api_key=llm_api_key, base_url=llm_base_url,
            )
        )
    except Exception as exc:
        logger.exception("Generation error")
        raise HTTPException(status_code=500, detail=f"Generation failed: {exc}") from exc
    latency_ms = (time.perf_counter() - t0) * 1000
    log_id = _eval_store.log_query(QueryLogEntry(
        query=req.query,
        intent=retrieval.decision.intent.value,
        strategies=retrieval.decision.strategies,
        retriever_hits=retrieval.retriever_hits,
        crag_grade=crag_result.grade.value,
        crag_rewritten=bool(crag_result.rewritten_query),
        web_search_used=crag_result.web_search_used,
        num_chunks=len(final_chunks),
        top_chunk_score=final_chunks[0].score if final_chunks else 0.0,
        latency_ms=latency_ms,
        model=result.model,
    ))
    if cfg.eval_enabled:
        # Fire-and-forget RAGAS scoring; results land in the eval store
        _evaluator.evaluate_async(EvalInput(
            query_log_id=log_id,
            query=req.query,
            answer=result.answer,
            chunks=final_chunks,
        ))
    return QueryResponse(
        query=req.query,
        answer=result.answer,
        citations=[
            CitationResponse(
                number=c.number,
                title=c.title,
                source=c.source,
                chunk_id=c.chunk_id,
                score=c.score,
            )
            for c in result.citations
        ],
        retrieved_chunks=[
            ChunkResponse(
                chunk_id=ch.chunk_id,
                doc_id=ch.doc_id,
                source=ch.source,
                title=ch.title,
                text=ch.text,
                score=ch.score,
            )
            for ch in final_chunks
        ],
        model=result.model,
        usage=result.usage,
    )

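# Example /query body (the llm override is optional; its fields mirror the
# getattr calls above):
#   {"query": "How does the CRAG gate work?", "top_k": 8,
#    "llm": {"provider": "groq", "model": "llama-3.3-70b-versatile"}}
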
@app.post("/query/stream")
async def query_stream(req: QueryRequest) -> StreamingResponse:
    """
    Streaming query endpoint using Server-Sent Events (SSE).

    Event types emitted:
    - data: {"type": "chunk_meta", "chunks": [...]}    → retrieved chunks
    - data: {"type": "crag_update", ...}               → CRAG grade / rewrite info
    - data: {"type": "token", "text": "..."}           → answer tokens
    - data: {"type": "sources", "text": "..."}         → sources block
    - data: {"type": "done"}                           → stream complete
    - data: {"type": "error", "message": "..."}        → error event
    """
    cfg = get_settings()
    k = req.top_k or cfg.retrieval_top_k
    logger.debug("Stream request: %s", req)

    async def event_stream() -> AsyncGenerator[str, None]:
        try:
            # 1. Multi-strategy retrieval: router → dense+BM25 → RRF → cross-encoder
            result = _retriever.retrieve(req.query, top_k_candidates=k, final_top_k=cfg.final_top_k)
            final_chunks = result.chunks
            # 2. Emit chunk metadata + routing decision so the UI shows sources
            #    and strategy info immediately
            chunk_meta = [
                {
                    "chunk_id": c.chunk_id,
                    "title": c.title,
                    "source": c.source,
                    "score": round(c.score, 4),
                    "retriever": c.retriever,
                    "text_snippet": c.text[:200],
                }
                for c in final_chunks
            ]
            yield _sse_event({
                "type": "chunk_meta",
                "chunks": chunk_meta,
                "routing": {
                    "intent": result.decision.intent.value,
                    "strategies": result.decision.strategies,
                    "retriever_hits": result.retriever_hits,
                    "reasoning": result.decision.reasoning,
                },
            })
            if not final_chunks:
                yield _sse_event({
                    "type": "token",
                    "text": "No relevant documents found in the knowledge base.",
                })
                yield _sse_event({"type": "done"})
                return
            # 3. CRAG gate - grade, optionally rewrite + re-retrieve
            crag_result = _crag.evaluate(
                query=req.query,
                chunks=final_chunks,
                retriever_fn=lambda q: _retriever.retrieve(q).chunks,
            )
            final_chunks = crag_result.final_chunks
            # Emit a CRAG event only if something interesting happened
            if crag_result.grade.value != "GOOD" or crag_result.web_search_used:
                yield _sse_event({
                    "type": "crag_update",
                    "grade": crag_result.grade.value,
                    "rewritten_query": crag_result.rewritten_query,
                    "web_search_used": crag_result.web_search_used,
                    "reasoning": crag_result.reasoning,
                })
            # 4. Stream answer tokens
            llm = req.llm
            gen_request = GenerationRequest(
                query=req.query, chunks=final_chunks, stream=True,
                provider=getattr(llm, "provider", None),
                model=getattr(llm, "model", None),
                api_key=getattr(llm, "api_key", None),
                base_url=getattr(llm, "base_url", None),
            )
            for token in _generator.stream(gen_request):
                yield _sse_event({"type": "token", "text": token})
            # 5. Emit sources block
            sources = _generator.build_sources_block(final_chunks)
            yield _sse_event({"type": "sources", "text": sources})
            # 6. Signal completion
            yield _sse_event({"type": "done"})
        except Exception as exc:
            logger.exception("Streaming error")
            yield _sse_event({"type": "error", "message": str(exc)})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable nginx buffering
        },
    )

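# A minimal client sketch for the stream above. httpx is an assumption here -
# any HTTP client that reads the body incrementally works, since SSE frames
# are newline-delimited (see _sse_event below):
#
#   import httpx, json
#   with httpx.stream("POST", "http://localhost:8000/query/stream",
#                     json={"query": "What is CRAG?"}, timeout=None) as r:
#       for line in r.iter_lines():
#           if line.startswith("data: "):
#               event = json.loads(line[len("data: "):])
#               if event["type"] == "token":
#                   print(event["text"], end="", flush=True)
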
# ── SSE helper ─────────────────────────────────────────────────
def _sse_event(data: dict) -> str:
    """Format a dict as a Server-Sent Event string."""
    return f"data: {json.dumps(data)}\n\n"
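# For example, _sse_event({"type": "done"}) returns 'data: {"type": "done"}\n\n';
# the trailing blank line is what delimits events for SSE clients.
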
# ── Dev server entry point ─────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    cfg = get_settings()
    logging.basicConfig(
        level=getattr(logging, cfg.log_level),
        format="%(asctime)s %(levelname)-7s %(name)s - %(message)s",
    )
    uvicorn.run(
        "api.main:app",
        host=cfg.api_host,
        port=cfg.api_port,
        reload=cfg.api_reload,
    )