Spaces:
Sleeping
Sleeping
| """ | |
| Production-Grade FastAPI Backend for easyResearch RAG System. | |
| Features: | |
| - Hybrid search with re-ranking | |
| - Qdrant connection retry with exponential backoff | |
| - Groq API rate limiting with retry | |
| - Batch processing endpoints | |
| - Full observability integration | |
| - OpenAPI documentation | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import functools | |
| import os | |
| import shutil | |
| import tempfile | |
| import time | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from typing import Any, Literal | |
| import torch | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type | |
| from config import ( | |
| Config, | |
| QDRANT_HOST, | |
| QDRANT_PORT, | |
| API_HOST, | |
| API_PORT, | |
| DEVICE, | |
| ) | |
| # Import core modules | |
| from core.rag_engine import query_rag, RetrievalConfig | |
| from core.pipeline import run_pipeline, run_pipeline_async, get_pipeline_status, PipelineConfig | |
| from core.embedder import ( | |
| add_to_vector_db, | |
| get_all_notebooks, | |
| get_notebook_stats, | |
| delete_notebook, | |
| delete_file_from_notebook, | |
| get_total_db_size, | |
| check_qdrant_health, | |
| ) | |
| from core.loader import load_and_split_document | |
| from core.observability import ( | |
| rag_logger, | |
| get_current_metrics, | |
| get_recent_traces, | |
| clear_logs, | |
| MetricsCalculator, | |
| ) | |
# ──────────────────────────────────────────────────────────────────────────────
# Configuration & Constants
# ──────────────────────────────────────────────────────────────────────────────
# Metadata passed to FastAPI (OpenAPI schema) and written to the startup log.
API_VERSION = "0.1.0"
API_TITLE = "easyResearch RAG API"
# Markdown rendered at the top of the /docs and /redoc pages.
API_DESCRIPTION = (
    "\n"
    "Production-grade RAG API with hybrid search, re-ranking, and Big Data pipelines.\n"
    "## Features\n"
    "- **Hybrid Search**: Dense vectors + BM25 sparse retrieval\n"
    "- **Cross-Encoder Re-ranking**: ms-marco-MiniLM-L-6-v2\n"
    "- **Metadata Enrichment**: LLM-extracted tags and summaries\n"
    "- **Full Observability**: Tracing, metrics, and logging\n"
    "## Workspaces\n"
    "Each workspace has isolated documents and chat history.\n"
)
# ──────────────────────────────────────────────────────────────────────────────
# Retry Decorators for Resilience
# ──────────────────────────────────────────────────────────────────────────────
class QdrantConnectionError(Exception):
    """Signals a connection- or timeout-related failure while talking to Qdrant.

    ``with_qdrant_retry`` raises it (chained to the underlying exception) and
    ``qdrant_error_handler`` renders it as an HTTP 503 JSON response.
    """
class GroqRateLimitError(Exception):
    """Signals that a Groq API call kept hitting rate limits.

    ``with_groq_rate_limit`` raises it once its retries are used up and
    ``rate_limit_handler`` renders it as an HTTP 429 JSON response.
    """
def with_qdrant_retry(func=None, *, max_attempts: int = 3, base_delay: float = 0.5):
    """Wrap an async Qdrant operation with retries and exponential backoff.

    A failure counts as a connection problem when its message mentions
    "connection" or "timeout". Such failures are retried up to
    ``max_attempts`` attempts in total, sleeping ``base_delay * 2**attempt``
    seconds between attempts. Once the attempts are exhausted,
    ``QdrantConnectionError`` is raised, chained to the last error. Every
    other exception propagates unchanged on its first occurrence.

    Usable bare (``@with_qdrant_retry``) or with options
    (``@with_qdrant_retry(max_attempts=5)``). Only wrap operations that are
    safe to repeat, since a timed-out call may have partially succeeded.

    Raises:
        ValueError: if ``max_attempts`` is smaller than 1.
    """
    if func is None:
        # Called with options only: return a decorator bound to them.
        return functools.partial(
            with_qdrant_retry, max_attempts=max_attempts, base_delay=base_delay
        )
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        for attempt in range(max_attempts):
            try:
                return await func(*args, **kwargs)
            except Exception as exc:
                message = str(exc).lower()
                if "connection" not in message and "timeout" not in message:
                    raise
                if attempt == max_attempts - 1:
                    # Out of attempts: surface the dedicated error type.
                    raise QdrantConnectionError(str(exc)) from exc
                delay = base_delay * (2 ** attempt)
                rag_logger.warning(
                    f"Qdrant connection error, retrying in {delay}s "
                    f"({attempt + 1}/{max_attempts - 1}): {exc}"
                )
                await asyncio.sleep(delay)

    return wrapper
# Message fragments that identify a rate-limit failure. Bare "rate"/"limit"
# would also match words like "generate" or "token limit".
_RATE_LIMIT_MARKERS = ("rate limit", "rate_limit", "ratelimit", "429", "too many requests")


def _looks_rate_limited(exc: Exception) -> bool:
    """Return True when ``exc`` looks like an HTTP 429 / rate-limit error."""
    if "ratelimit" in type(exc).__name__.lower():
        return True
    message = str(exc).lower()
    return any(marker in message for marker in _RATE_LIMIT_MARKERS)


def with_groq_rate_limit(func=None, *, max_retries: int = 3, base_delay: float = 2.0):
    """Wrap an async Groq API call with rate-limit-aware retries.

    Rate-limit failures (see ``_looks_rate_limited``) are attempted up to
    ``max_retries`` times in total, waiting ``base_delay * 2**attempt``
    seconds between attempts (2s, 4s, ... by default). The last failure is
    raised as ``GroqRateLimitError`` without a pointless final sleep. Any
    other exception propagates immediately.

    Usable bare (``@with_groq_rate_limit``) or with options
    (``@with_groq_rate_limit(max_retries=5)``).

    Raises:
        ValueError: if ``max_retries`` is smaller than 1.
    """
    if func is None:
        # Called with options only: return a decorator bound to them.
        return functools.partial(
            with_groq_rate_limit, max_retries=max_retries, base_delay=base_delay
        )
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        for attempt in range(max_retries):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                if not _looks_rate_limited(e):
                    raise
                if attempt == max_retries - 1:
                    raise GroqRateLimitError(
                        f"Rate limit exceeded after {max_retries} retries"
                    ) from e
                delay = base_delay * (2 ** attempt)
                rag_logger.warning(
                    f"Rate limited, waiting {delay}s before retry "
                    f"{attempt + 1}/{max_retries - 1}"
                )
                await asyncio.sleep(delay)
        # Defensive: the loop always returns or raises when max_retries >= 1.
        raise GroqRateLimitError("Max retries exceeded")

    return wrapper
# ──────────────────────────────────────────────────────────────────────────────
# Pydantic Models
# ──────────────────────────────────────────────────────────────────────────────
class AskRequest(BaseModel):
    """Request model for RAG queries."""

    question: str = Field(..., min_length=1, max_length=4000, description="User's question")
    collection_name: str = Field(default="Default_Project", description="Workspace/collection name")
    chat_history: list[dict] = Field(default_factory=list, description="Previous conversation messages")
    k_target: int = Field(default=10, ge=1, le=50, description="Number of documents to retrieve")
    format_filter: str | None = Field(default=None, description="Filter by document format")
    source_filter: str | None = Field(default=None, description="Filter by source filename")

    # Pydantic v2 configuration (``json_schema_extra`` is a v2 key). This
    # replaces the deprecated inner ``class Config``, which also shadowed the
    # module-level ``Config`` import inside the class body.
    model_config = {
        "json_schema_extra": {
            "example": {
                "question": "What is RPC in distributed systems?",
                "collection_name": "network_programming",
                "k_target": 10,
            }
        }
    }
class AskResponse(BaseModel):
    """Response model for RAG queries."""

    # Generated answer text.
    answer: str
    # Sources cited for the answer.
    sources: list[str]
    # Question as rewritten by the engine, when it produced one.
    standalone_question: str | None = None
    # Free-form pipeline details reported by the engine.
    pipeline_info: dict = Field(default_factory=dict)
    # Retrieved documents, when the engine returns them.
    raw_docs: list[dict] | None = None
class UploadRequest(BaseModel):
    """Request model for file upload configuration."""

    # NOTE(review): the upload handler reads these values from form fields,
    # so this model looks unused here; confirm before relying on it.
    collection_name: str = "Default_Project"
    use_parent_retrieval: bool = Field(True, description="Enable parent-child chunking")
class UploadResponse(BaseModel):
    """Response model for file upload."""

    filename: str  # name supplied by the client
    chunks: int  # number of chunks indexed
    collection: str  # target workspace/collection
    metadata: dict | None = None  # extra upload details (e.g. chunking options)
class PipelineRequest(BaseModel):
    """Request model for ingestion pipeline."""

    collection_name: str = "Default_Project"
    # When omitted, the workspace directory of ``collection_name`` is used.
    source_dir: str | None = Field(None, description="Source directory path")
    # Chunking and embedding batch parameters, with hard bounds.
    chunk_size: int = Field(400, ge=100, le=4000)
    chunk_overlap: int = Field(80, ge=0, le=500)
    batch_size: int = Field(32, ge=1, le=128)
    enable_enrichment: bool = Field(True, description="Enable LLM metadata enrichment")
    reset_db: bool = Field(True, description="Reset existing collection")
class PipelineResponse(BaseModel):
    """Response model for pipeline status."""

    # Current stage label, fractional progress and human-readable message.
    stage: str
    progress: float
    message: str
    error: str | None = None
    # Per-stage counters; zero until the pipeline reports them.
    docs_cleaned: int = 0
    docs_enriched: int = 0
    chunks_created: int = 0
    chunks_embedded: int = 0
    elapsed: float = 0.0
class WorkspaceStats(BaseModel):
    """Workspace statistics model."""

    name: str  # workspace/collection name
    chunks: int  # indexed chunk count
    files: list[str]  # files present in the workspace
    size_mb: float  # reported size in megabytes
    metadata: dict | None = None  # optional extra details
class HealthResponse(BaseModel):
    """Health check response model."""

    # "degraded" is used when Qdrant does not report status "ok".
    status: Literal["ok", "degraded", "error"]
    version: str
    device: str
    # Both GPU fields are required but null when CUDA is unavailable;
    # memory is total device memory in MiB.
    gpu_name: str | None
    gpu_memory_mb: float | None
    # Raw payload from the Qdrant health probe.
    qdrant: dict
    db_size_mb: float
class MetricsResponse(BaseModel):
    """RAG metrics response model."""

    # Retrieval quality.
    hit_rate: float
    mrr: float
    # Query volume.
    total_queries: int
    successful_queries: int
    failed_queries: int
    # Latency, in milliseconds.
    avg_retrieval_time_ms: float
    avg_rerank_time_ms: float
    avg_generation_time_ms: float
    avg_total_time_ms: float
    p95_total_time_ms: float
    # Retrieval volume per query.
    avg_docs_retrieved: float
    avg_context_length: float
# ──────────────────────────────────────────────────────────────────────────────
# Application Lifecycle
# ──────────────────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log startup diagnostics, serve requests, then log shutdown.

    Startup logs the API title/version and compute device, probes Qdrant via
    ``check_qdrant_health()`` (a failed probe only logs a warning and does
    not abort startup), and logs the first CUDA device when one exists.
    Wrapped with ``asynccontextmanager`` as FastAPI's ``lifespan`` expects;
    passing a bare async generator is deprecated in Starlette.
    """
    rag_logger.info(f"Starting {API_TITLE} v{API_VERSION}")
    rag_logger.info(f"Device: {DEVICE}")

    qdrant_status = check_qdrant_health()
    if qdrant_status.get("status") == "ok":
        rag_logger.info(f"✅ Qdrant connected: {QDRANT_HOST}:{QDRANT_PORT}")
    else:
        rag_logger.warning(f"⚠️ Qdrant health check failed: {qdrant_status.get('error')}")

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        rag_logger.info(f"🖥️ GPU: {gpu_name} ({gpu_mem_gb:.1f}GB)")

    try:
        yield
    finally:
        # Runs even if the server exits with an error.
        rag_logger.info("Shutting down API...")
# Create FastAPI application; ``lifespan`` runs the startup/shutdown hooks.
app = FastAPI(
    title=API_TITLE,
    version=API_VERSION,
    description=API_DESCRIPTION,
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
)

# CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (default "*", as before). Credentialed CORS is enabled only for an
# explicit origin list. The Fetch spec forbids "*" together with credentials,
# and Starlette answers cookie-bearing requests by echoing the caller's Origin
# when origins are "*". Wildcard + credentials would therefore let any website
# make credentialed calls.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_allow_all = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=not _cors_allow_all,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ──────────────────────────────────────────────────────────────────────────────
# Error Handlers
# ──────────────────────────────────────────────────────────────────────────────
# NOTE(review): no ``@app.exception_handler`` registration is visible for this
# handler in this listing; confirm it is wired up.
async def qdrant_error_handler(request, exc: QdrantConnectionError):
    """Turn a ``QdrantConnectionError`` into an HTTP 503 JSON response.

    The body carries a generic label, the exception text and a
    ``retry_after`` hint of 5 (presumably seconds).
    """
    rag_logger.error(f"Qdrant connection error: {exc}")
    payload = {
        "error": "Database connection error",
        "detail": str(exc),
        "retry_after": 5,
    }
    return JSONResponse(content=payload, status_code=503)
# NOTE(review): no ``@app.exception_handler`` registration is visible for this
# handler in this listing; confirm it is wired up.
async def rate_limit_handler(request, exc: GroqRateLimitError):
    """Turn a ``GroqRateLimitError`` into an HTTP 429 JSON response.

    The body carries a generic label, the exception text and a
    ``retry_after`` hint of 60 (presumably seconds).
    """
    rag_logger.warning(f"Rate limit error: {exc}")
    payload = {
        "error": "Rate limit exceeded",
        "detail": str(exc),
        "retry_after": 60,
    }
    return JSONResponse(content=payload, status_code=429)
# ──────────────────────────────────────────────────────────────────────────────
# Query Endpoints
# ──────────────────────────────────────────────────────────────────────────────
async def ask_question(req: AskRequest):
    """
    Execute a RAG query with hybrid search and re-ranking.
    The query goes through:
    1. Question contextualization (if chat history exists)
    2. Dense vector search
    3. BM25 sparse ranking
    4. Reciprocal Rank Fusion
    5. Cross-encoder re-ranking
    6. LLM response generation
    Any failure is logged and returned as HTTP 500 with the error text.
    """
    try:
        config = RetrievalConfig(rerank_top_k=req.k_target)
        # query_rag is synchronous and performs retrieval, re-ranking and the
        # LLM call. Running it in a worker thread keeps the event loop free to
        # serve other requests (e.g. health checks) meanwhile.
        result = await asyncio.to_thread(
            query_rag,
            question=req.question,
            collection_name=req.collection_name,
            chat_history=req.chat_history,
            k_target=req.k_target,
            format_filter=req.format_filter,
            source_filter=req.source_filter,
            retrieval_config=config,
        )
        return AskResponse(
            answer=result["answer"],
            sources=result.get("sources", []),
            standalone_question=result.get("standalone_question"),
            pipeline_info=result.get("pipeline_info", {}),
            raw_docs=result.get("raw_docs"),
        )
    except Exception as e:
        rag_logger.exception("Query failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def search_documents(
    collection_name: str,
    q: str = Query(..., min_length=1, description="Search query"),
    k: int = Query(default=10, ge=1, le=50, description="Number of results"),
    format_filter: str | None = Query(default=None, description="Filter by format"),
):
    """
    Perform semantic search without LLM generation.
    Returns raw document matches with scores, a 500-character content
    preview and metadata (without the bulky `parent_content` field).
    """
    from core.rag_engine import hybrid_search

    try:
        filter_dict = {"format": format_filter} if format_filter else None
        config = RetrievalConfig(rerank_top_k=k)
        # hybrid_search is synchronous; keep it off the event loop.
        results = await asyncio.to_thread(
            hybrid_search, collection_name, q, config=config, filter_dict=filter_dict
        )
        return {
            "query": q,
            "count": len(results),
            "results": [
                {
                    "source": doc.metadata.get("source", "Unknown"),
                    # float() normalizes numpy scalar scores (e.g. float32
                    # from a re-ranker), which may not be JSON-serializable.
                    "score": round(float(score), 4),
                    "content": doc.page_content[:500],
                    "metadata": {
                        meta_key: meta_value
                        for meta_key, meta_value in doc.metadata.items()
                        if meta_key != "parent_content"
                    },
                }
                for doc, score in results
            ],
        }
    except Exception as e:
        rag_logger.exception("Search failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ──────────────────────────────────────────────────────────────────────────────
# Ingestion Endpoints
# ──────────────────────────────────────────────────────────────────────────────
async def upload_document(
    file: UploadFile = File(...),
    collection_name: str = Form(default="Default_Project"),
    use_parent_retrieval: bool = Form(default=True),
):
    """
    Upload and index a single document.
    Supports: PDF, DOCX, TXT, PY, JS, JSON, CSV
    Unsupported extensions are rejected with HTTP 400; indexing failures
    return HTTP 500 with the error text.
    """
    filename = file.filename or "unknown"
    suffix = Path(filename).suffix.lower()
    allowed_extensions = {".pdf", ".docx", ".txt", ".py", ".js", ".json", ".csv"}
    if suffix not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            # sorted() keeps the message stable (set order is arbitrary).
            detail=f"Unsupported file type: {suffix}. Allowed: {', '.join(sorted(allowed_extensions))}",
        )

    # NOTE(review): collection_name comes straight from the client and is used
    # to build a filesystem path; confirm Config.get_workspace_dir rejects
    # traversal such as "../". Upload size is also unbounded here.
    upload_dir = Config.get_workspace_dir(collection_name)
    tmp_path: str | None = None
    try:
        # The temp file is created inside try/finally so it is removed even
        # if reading the upload fails part-way.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=str(upload_dir)) as tmp:
            tmp_path = tmp.name
            tmp.write(await file.read())

        # Loading/splitting and embedding are synchronous; run them in a
        # worker thread so the event loop stays responsive.
        # NOTE(review): the document is loaded from the temp name, so any
        # source metadata derived from the path may not be `filename`; verify.
        chunks = await asyncio.to_thread(
            load_and_split_document, tmp_path, use_parent_retrieval=use_parent_retrieval
        )
        await asyncio.to_thread(add_to_vector_db, chunks, collection_name=collection_name)
        return UploadResponse(
            filename=filename,
            chunks=len(chunks),
            collection=collection_name,
            metadata={"parent_retrieval": use_parent_retrieval},
        )
    except Exception as e:
        rag_logger.exception(f"Upload failed: {filename}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
async def start_pipeline(
    req: PipelineRequest,
    background_tasks: BackgroundTasks,
):
    """
    Start the ingestion pipeline in the background.
    Stages:
    1. Document cleaning and text extraction
    2. LLM metadata enrichment (optional)
    3. Chunking with deduplication
    4. CUDA-accelerated embedding
    Returns immediately with a "starting" status; poll the pipeline status
    endpoint for progress. Requests whose `chunk_overlap` is not smaller
    than `chunk_size` are rejected with HTTP 422.
    """
    # A chunker that slides by (chunk_size - chunk_overlap) cannot advance
    # when that step is <= 0. Reject such settings here instead of letting
    # the background task fail later.
    if req.chunk_overlap >= req.chunk_size:
        raise HTTPException(
            status_code=422,
            detail=(
                f"chunk_overlap ({req.chunk_overlap}) must be smaller than "
                f"chunk_size ({req.chunk_size})"
            ),
        )

    config = PipelineConfig(
        chunk_size=req.chunk_size,
        chunk_overlap=req.chunk_overlap,
        batch_size=req.batch_size,
        enable_llm_enrichment=req.enable_enrichment,
        reset_db=req.reset_db,
    )
    # NOTE(review): source_dir is taken verbatim from the request, so a client
    # can point ingestion at any directory readable by the server process.
    # Consider restricting it to an allow-listed root.
    source_dir = Path(req.source_dir) if req.source_dir else Config.get_workspace_dir(req.collection_name)

    # Runs after the response has been sent.
    background_tasks.add_task(
        run_pipeline,
        source_dir=source_dir,
        collection_name=req.collection_name,
        config=config,
    )
    return PipelineResponse(
        stage="starting",
        progress=0.0,
        message=f"Pipeline started for {req.collection_name}",
    )
async def pipeline_status():
    """Get current pipeline execution status."""
    # get_pipeline_status() returns a dict whose keys map onto PipelineResponse.
    return PipelineResponse(**get_pipeline_status())
# ──────────────────────────────────────────────────────────────────────────────
# Workspace Management Endpoints
# ──────────────────────────────────────────────────────────────────────────────
async def list_workspaces():
    """List all available workspaces."""
    found = get_all_notebooks()
    return {"workspaces": found, "count": len(found)}
async def get_workspace(workspace_name: str):
    """Get detailed statistics for a workspace."""
    stats = get_notebook_stats(workspace_name)
    # Missing keys fall back to an empty workspace.
    chunk_count = stats.get("chunks", 0)
    file_names = stats.get("files", [])
    size_mb = stats.get("size_mb", 0.0)
    return WorkspaceStats(
        name=workspace_name,
        chunks=chunk_count,
        files=file_names,
        size_mb=size_mb,
    )
async def remove_workspace(workspace_name: str):
    """Delete a workspace and all its data."""
    if not delete_notebook(workspace_name):
        raise HTTPException(status_code=404, detail="Workspace not found")

    # Best-effort removal of the uploaded-files directory as well.
    # NOTE(review): workspace_name comes from the URL; confirm
    # Config.get_workspace_dir cannot resolve outside the uploads root
    # before it is handed to rmtree.
    workspace_dir = Config.get_workspace_dir(workspace_name)
    if workspace_dir.exists():
        shutil.rmtree(workspace_dir, ignore_errors=True)

    return {"deleted": workspace_name, "status": "success"}
async def remove_file_from_workspace(workspace_name: str, filename: str):
    """Remove a specific file from a workspace."""
    removed = delete_file_from_notebook(workspace_name, filename)
    # A count of zero means no chunk belonged to that file.
    if removed == 0:
        raise HTTPException(status_code=404, detail="File not found in workspace")
    return dict(
        deleted=filename,
        chunks_removed=removed,
        workspace=workspace_name,
    )
# ──────────────────────────────────────────────────────────────────────────────
# Observability Endpoints
# ──────────────────────────────────────────────────────────────────────────────
async def health_check():
    """
    Comprehensive health check.
    Checks Qdrant connection, GPU availability, and system resources.
    """
    qdrant_status = check_qdrant_health()

    cuda_ready = torch.cuda.is_available()
    gpu_name = torch.cuda.get_device_name(0) if cuda_ready else None
    # Total memory of the first CUDA device, in MiB.
    gpu_memory = (
        torch.cuda.get_device_properties(0).total_memory / (1024**2)
        if cuda_ready
        else None
    )

    # The API keeps serving without Qdrant, so report "degraded" rather than "error".
    overall = "ok" if qdrant_status.get("status") == "ok" else "degraded"
    return HealthResponse(
        status=overall,
        version=API_VERSION,
        device=DEVICE,
        gpu_name=gpu_name,
        gpu_memory_mb=gpu_memory,
        qdrant=qdrant_status,
        db_size_mb=get_total_db_size(),
    )
async def get_metrics():
    """
    Get RAG performance metrics.
    Includes hit rate, MRR, latency percentiles, and volume metrics.
    """
    # get_current_metrics() supplies every MetricsResponse field by name.
    return MetricsResponse(**get_current_metrics())
async def get_traces(
    limit: int = Query(default=50, ge=1, le=500, description="Number of traces to return"),
):
    """Get recent RAG pipeline traces."""
    recent = get_recent_traces(limit=limit)
    return dict(count=len(recent), traces=recent)
async def clear_all_logs():
    """Clear all log files (traces, metrics)."""
    clear_logs()
    return dict(status="cleared")
# ──────────────────────────────────────────────────────────────────────────────
# Batch Processing Endpoints
# ──────────────────────────────────────────────────────────────────────────────
class BatchQueryRequest(BaseModel):
    """Request for batch query processing."""

    questions: list[str] = Field(..., min_length=1, max_length=20, description="Questions to answer (1-20)")
    collection_name: str = Field(default="Default_Project", description="Workspace/collection name")
    # Same bounds as AskRequest.k_target, so a batch cannot request more
    # documents per question than a single query is allowed to.
    k_target: int = Field(default=5, ge=1, le=50, description="Number of documents to retrieve per question")
async def batch_query(req: BatchQueryRequest):
    """
    Process multiple queries in batch.
    Useful for evaluation and testing.
    Limited to 20 questions per batch.
    Questions run sequentially with a short pause between them; a failing
    question is reported in its result entry instead of failing the batch.
    """
    results = []
    for i, question in enumerate(req.questions):
        if i:
            # Pace consecutive queries (whether or not the previous one
            # succeeded) to ease pressure on the LLM provider's rate limits.
            await asyncio.sleep(0.5)
        try:
            # query_rag is synchronous; keep it off the event loop.
            result = await asyncio.to_thread(
                query_rag,
                question=question,
                collection_name=req.collection_name,
                k_target=req.k_target,
            )
            results.append({
                "index": i,
                "question": question,
                "answer": result["answer"],
                "sources": result.get("sources", []),
                "success": True,
            })
        except Exception as e:
            rag_logger.warning(f"Batch query {i} failed: {e}")
            results.append({
                "index": i,
                "question": question,
                "error": str(e),
                "success": False,
            })

    return {
        "total": len(req.questions),
        "successful": sum(1 for r in results if r["success"]),
        "results": results,
    }
# ──────────────────────────────────────────────────────────────────────────────
# Entry Point
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Development server. Auto-reload needs the app given as an import string.
    # NOTE(review): "api_main:app" assumes this file is named api_main.py.
    run_options = dict(host=API_HOST, port=API_PORT, reload=True, log_level="info")
    uvicorn.run("api_main:app", **run_options)