easyResearchBigData / api_main.py
hzjanuary's picture
Upload 21 files
fa89805 verified
"""
Production-Grade FastAPI Backend for easyResearch RAG System.
Features:
- Hybrid search with re-ranking
- Qdrant connection retry with exponential backoff
- Groq API rate limiting with retry
- Batch processing endpoints
- Full observability integration
- OpenAPI documentation
"""
from __future__ import annotations
import asyncio
import functools
import os
import shutil
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Literal
import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from config import (
Config,
QDRANT_HOST,
QDRANT_PORT,
API_HOST,
API_PORT,
DEVICE,
)
# Import core modules
from core.rag_engine import query_rag, RetrievalConfig
from core.pipeline import run_pipeline, run_pipeline_async, get_pipeline_status, PipelineConfig
from core.embedder import (
add_to_vector_db,
get_all_notebooks,
get_notebook_stats,
delete_notebook,
delete_file_from_notebook,
get_total_db_size,
check_qdrant_health,
)
from core.loader import load_and_split_document
from core.observability import (
rag_logger,
get_current_metrics,
get_recent_traces,
clear_logs,
MetricsCalculator,
)
# ──────────────────────────────────────────────────────────────────────────────
# Configuration & Constants
# ──────────────────────────────────────────────────────────────────────────────
# Version string reported in the OpenAPI schema, the startup log line and /health.
API_VERSION = "0.1.0"
# Title passed to FastAPI(title=...) and printed in the startup log line.
API_TITLE = "easyResearch RAG API"
# Markdown text passed to FastAPI(description=...).
API_DESCRIPTION = """
Production-grade RAG API with hybrid search, re-ranking, and Big Data pipelines.
## Features
- **Hybrid Search**: Dense vectors + BM25 sparse retrieval
- **Cross-Encoder Re-ranking**: ms-marco-MiniLM-L-6-v2
- **Metadata Enrichment**: LLM-extracted tags and summaries
- **Full Observability**: Tracing, metrics, and logging
## Workspaces
Each workspace has isolated documents and chat history.
"""
# ──────────────────────────────────────────────────────────────────────────────
# Retry Decorators for Resilience
# ──────────────────────────────────────────────────────────────────────────────
class QdrantConnectionError(Exception):
    """Raised when a Qdrant operation fails with a connection/timeout style error.

    Produced by ``with_qdrant_retry`` and translated into an HTTP 503 by the
    exception handler registered on the app.
    """
class GroqRateLimitError(Exception):
    """Raised when a call keeps hitting a rate limit after all retries.

    Produced by ``with_groq_rate_limit`` and translated into an HTTP 429 by the
    exception handler registered on the app.
    """
def with_qdrant_retry(func):
    """Wrap an async callable so connection-like failures are retried.

    Any exception whose message contains "connection" or "timeout"
    (case-insensitive) is logged and re-raised as ``QdrantConnectionError``.
    Tenacity then retries on ``ConnectionError``, ``TimeoutError`` and
    ``QdrantConnectionError``: at most 3 attempts, exponential wait clamped to
    1-10 seconds, and the last exception is re-raised (``reraise=True``).
    All other exceptions propagate immediately.
    """
    @functools.wraps(func)
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((ConnectionError, TimeoutError, QdrantConnectionError)),
        reraise=True,
    )
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as exc:
            lowered = str(exc).lower()
            looks_like_connection_issue = "connection" in lowered or "timeout" in lowered
            if not looks_like_connection_issue:
                # Not a transport problem: let it propagate without a retry.
                raise
            rag_logger.warning(f"Qdrant connection error, retrying: {exc}")
            raise QdrantConnectionError(str(exc)) from exc

    return wrapper
def with_groq_rate_limit(func):
    """Wrap an async callable with retry-on-rate-limit behaviour.

    An exception is treated as a rate-limit hit when its lower-cased message
    contains "rate", "429" or "limit". Such failures are retried with an
    exponential delay (2s, 4s, ...) for up to 3 attempts in total; any other
    exception is re-raised unchanged on the spot.

    Raises:
        GroqRateLimitError: when the last attempt is still rate limited. It is
            raised immediately, without sleeping first, because no further
            attempt will be made.
    """
    max_retries = 3
    base_delay = 2.0
    rate_limit_markers = ("rate", "429", "limit")

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        for attempt in range(max_retries):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                error_msg = str(e).lower()
                if not any(marker in error_msg for marker in rate_limit_markers):
                    raise
                if attempt == max_retries - 1:
                    # Out of attempts: fail now instead of waiting for a retry
                    # that will never happen.
                    raise GroqRateLimitError(
                        f"Rate limit exceeded after {max_retries} retries"
                    ) from e
                delay = base_delay * (2 ** attempt)
                rag_logger.warning(
                    f"Rate limited, waiting {delay}s before retry {attempt + 1}/{max_retries}"
                )
                await asyncio.sleep(delay)
        # Only reachable if max_retries is ever set below 1.
        raise GroqRateLimitError("Max retries exceeded")

    return wrapper
# ──────────────────────────────────────────────────────────────────────────────
# Pydantic Models
# ──────────────────────────────────────────────────────────────────────────────
class AskRequest(BaseModel):
    """Request model for RAG queries."""

    # Pydantic v2 style configuration. The nested ``class Config`` form is
    # deprecated in v2; this file already relies on v2-only names such as
    # ``json_schema_extra``. The example is shown in the generated OpenAPI docs.
    model_config = {
        "json_schema_extra": {
            "example": {
                "question": "What is RPC in distributed systems?",
                "collection_name": "network_programming",
                "k_target": 10,
            }
        }
    }

    # Required, 1-4000 characters.
    question: str = Field(..., min_length=1, max_length=4000, description="User's question")
    collection_name: str = Field(default="Default_Project", description="Workspace/collection name")
    chat_history: list[dict] = Field(default_factory=list, description="Previous conversation messages")
    # Bounded to 1..50; /ask also uses it as the re-rank top-k.
    k_target: int = Field(default=10, ge=1, le=50, description="Number of documents to retrieve")
    format_filter: str | None = Field(default=None, description="Filter by document format")
    source_filter: str | None = Field(default=None, description="Filter by source filename")
class AskResponse(BaseModel):
    """Response model for RAG queries."""
    # The /ask endpoint fills these from the dict returned by query_rag():
    # "answer" is required there, the other keys are read with .get().
    answer: str
    sources: list[str]
    standalone_question: str | None = None
    pipeline_info: dict = Field(default_factory=dict)
    raw_docs: list[dict] | None = None
class UploadRequest(BaseModel):
    """Request model for file upload configuration."""

    # Target workspace; same default as the /upload form field.
    collection_name: str = Field("Default_Project")
    use_parent_retrieval: bool = Field(True, description="Enable parent-child chunking")
class UploadResponse(BaseModel):
    """Response model for file upload."""
    # As filled by /upload: the client-supplied filename, the number of chunks
    # produced for the document, and the collection they were added to.
    filename: str
    chunks: int
    collection: str
    # /upload stores {"parent_retrieval": <bool>} here.
    metadata: dict | None = None
class PipelineRequest(BaseModel):
    """Request model for ingestion pipeline."""

    collection_name: str = Field("Default_Project")
    # When omitted, /pipeline/start falls back to the workspace directory of
    # ``collection_name``.
    source_dir: str | None = Field(None, description="Source directory path")
    # Chunking / batching knobs forwarded to PipelineConfig.
    chunk_size: int = Field(400, ge=100, le=4000)
    chunk_overlap: int = Field(80, ge=0, le=500)
    batch_size: int = Field(32, ge=1, le=128)
    enable_enrichment: bool = Field(True, description="Enable LLM metadata enrichment")
    # NOTE(review): defaults to True, i.e. the existing collection is reset
    # unless the caller opts out.
    reset_db: bool = Field(True, description="Reset existing collection")
class PipelineResponse(BaseModel):
    """Response model for pipeline status."""
    # /pipeline/status builds this with PipelineResponse(**get_pipeline_status()),
    # so the field names must match the keys of that status dict.
    stage: str
    progress: float
    message: str
    error: str | None = None
    # Counters default to 0 so /pipeline/start can return a fresh "starting"
    # status that sets only stage, progress and message.
    docs_cleaned: int = 0
    docs_enriched: int = 0
    chunks_created: int = 0
    chunks_embedded: int = 0
    elapsed: float = 0.0
class WorkspaceStats(BaseModel):
    """Workspace statistics model."""
    # Filled by GET /workspaces/{workspace_name} from get_notebook_stats();
    # missing keys fall back to 0, [] and 0.0 there.
    name: str
    chunks: int
    files: list[str]
    size_mb: float
    # Not populated by the endpoint in this file.
    metadata: dict | None = None
class HealthResponse(BaseModel):
    """Health check response model."""
    # /health currently reports only "ok" or "degraded" (when Qdrant is not ok).
    status: Literal["ok", "degraded", "error"]
    version: str
    device: str
    # GPU fields are None when torch.cuda.is_available() is False.
    gpu_name: str | None
    # Total memory of CUDA device 0, computed as bytes / 1024**2.
    gpu_memory_mb: float | None
    # Raw dict returned by check_qdrant_health().
    qdrant: dict
    # Value returned by get_total_db_size().
    db_size_mb: float
class MetricsResponse(BaseModel):
    """RAG metrics response model."""
    # /metrics builds this with MetricsResponse(**get_current_metrics()), so
    # every field is required and the names must match that dict's keys.
    hit_rate: float
    mrr: float
    total_queries: int
    successful_queries: int
    failed_queries: int
    avg_retrieval_time_ms: float
    avg_rerank_time_ms: float
    avg_generation_time_ms: float
    avg_total_time_ms: float
    p95_total_time_ms: float
    avg_docs_retrieved: float
    avg_context_length: float
# ──────────────────────────────────────────────────────────────────────────────
# Application Lifecycle
# ──────────────────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle manager.

    On startup, logs the API title/version and configured device, runs a Qdrant
    health check (a failure is logged as a warning, it does not abort startup)
    and logs the name and total memory of CUDA device 0 when CUDA is available.
    On shutdown, logs a single message.
    """
    # Startup
    rag_logger.info(f"Starting {API_TITLE} v{API_VERSION}")
    rag_logger.info(f"Device: {DEVICE}")

    # Check Qdrant connection
    status = check_qdrant_health()
    if status.get("status") == "ok":
        rag_logger.info(f"✅ Qdrant connected: {QDRANT_HOST}:{QDRANT_PORT}")
    else:
        rag_logger.warning(f"⚠️ Qdrant health check failed: {status.get('error')}")

    # Log GPU info
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        # Bytes -> GiB for the log line.
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        rag_logger.info(f"🖥️ GPU: {gpu_name} ({gpu_mem:.1f}GB)")

    yield

    # Shutdown
    rag_logger.info("Shutting down API...")
# Create FastAPI application
# Title, version and description come from the constants above; the
# documentation routes are spelled out explicitly.
app = FastAPI(
    lifespan=lifespan,
    title=API_TITLE,
    description=API_DESCRIPTION,
    version=API_VERSION,
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)
# CORS middleware
# A wildcard origin must not be combined with credentialed requests: the
# FastAPI CORS documentation requires explicit origins when
# allow_credentials=True. Origins are therefore read from the comma-separated
# CORS_ALLOW_ORIGINS environment variable (default "*"), and credentials are
# enabled only when every configured origin is explicit.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ──────────────────────────────────────────────────────────────────────────────
# Error Handlers
# ──────────────────────────────────────────────────────────────────────────────
@app.exception_handler(QdrantConnectionError)
async def qdrant_error_handler(request, exc: QdrantConnectionError):
    """Translate ``QdrantConnectionError`` into an HTTP 503 JSON response.

    The retry hint is sent both in the JSON body (``retry_after``) and as a
    standard ``Retry-After`` header so generic HTTP clients can honour it.
    """
    retry_after_seconds = 5
    rag_logger.error(f"Qdrant connection error: {exc}")
    return JSONResponse(
        status_code=503,
        content={
            "error": "Database connection error",
            "detail": str(exc),
            "retry_after": retry_after_seconds,
        },
        headers={"Retry-After": str(retry_after_seconds)},
    )
@app.exception_handler(GroqRateLimitError)
async def rate_limit_handler(request, exc: GroqRateLimitError):
    """Translate ``GroqRateLimitError`` into an HTTP 429 JSON response.

    The retry hint is sent both in the JSON body (``retry_after``) and as a
    standard ``Retry-After`` header so generic HTTP clients can honour it.
    """
    retry_after_seconds = 60
    rag_logger.warning(f"Rate limit error: {exc}")
    return JSONResponse(
        status_code=429,
        content={
            "error": "Rate limit exceeded",
            "detail": str(exc),
            "retry_after": retry_after_seconds,
        },
        headers={"Retry-After": str(retry_after_seconds)},
    )
# ──────────────────────────────────────────────────────────────────────────────
# Query Endpoints
# ──────────────────────────────────────────────────────────────────────────────
@app.post("/ask", response_model=AskResponse, tags=["Query"])
@with_groq_rate_limit
async def ask_question(req: AskRequest):
    """
    Execute a RAG query with hybrid search and re-ranking.

    The query goes through:
    1. Question contextualization (if chat history exists)
    2. Dense vector search
    3. BM25 sparse ranking
    4. Reciprocal Rank Fusion
    5. Cross-encoder re-ranking
    6. LLM response generation
    """
    try:
        config = RetrievalConfig(rerank_top_k=req.k_target)
        # query_rag is a synchronous call; run it in a worker thread so the
        # event loop keeps serving other requests (e.g. /health) meanwhile.
        result = await asyncio.to_thread(
            query_rag,
            question=req.question,
            collection_name=req.collection_name,
            chat_history=req.chat_history,
            k_target=req.k_target,
            format_filter=req.format_filter,
            source_filter=req.source_filter,
            retrieval_config=config,
        )
        return AskResponse(
            answer=result["answer"],
            sources=result.get("sources", []),
            standalone_question=result.get("standalone_question"),
            pipeline_info=result.get("pipeline_info", {}),
            raw_docs=result.get("raw_docs"),
        )
    except Exception as e:
        rag_logger.exception("Query failed")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/search/{collection_name}", tags=["Query"])
async def search_documents(
    collection_name: str,
    q: str = Query(..., min_length=1, description="Search query"),
    k: int = Query(default=10, ge=1, le=50, description="Number of results"),
    format_filter: str | None = Query(default=None, description="Filter by format"),
):
    """
    Perform semantic search without LLM generation.

    Returns raw document matches with scores.
    """
    from core.rag_engine import hybrid_search

    try:
        filter_dict = {"format": format_filter} if format_filter else None
        config = RetrievalConfig(rerank_top_k=k)
        # hybrid_search is synchronous; keep it off the event loop.
        results = await asyncio.to_thread(
            hybrid_search, collection_name, q, config=config, filter_dict=filter_dict
        )
        return {
            "query": q,
            "count": len(results),
            "results": [
                {
                    "source": doc.metadata.get("source", "Unknown"),
                    # float() guards against non-builtin numeric scalars (the
                    # score type is not visible here) so the value always
                    # serialises as a plain JSON number.
                    "score": round(float(score), 4),
                    # Preview only: first 500 characters of the chunk.
                    "content": doc.page_content[:500],
                    # Drop parent_content from the returned metadata.
                    "metadata": {
                        key: value
                        for key, value in doc.metadata.items()
                        if key != "parent_content"
                    },
                }
                for doc, score in results
            ],
        }
    except Exception as e:
        rag_logger.exception("Search failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ──────────────────────────────────────────────────────────────────────────────
# Ingestion Endpoints
# ──────────────────────────────────────────────────────────────────────────────
@app.post("/upload", response_model=UploadResponse, tags=["Ingestion"])
@with_qdrant_retry
async def upload_document(
    file: UploadFile = File(...),
    collection_name: str = Form(default="Default_Project"),
    use_parent_retrieval: bool = Form(default=True),
):
    """
    Upload and index a single document.

    Supports: PDF, DOCX, TXT, PY, JS, JSON, CSV
    """
    filename = file.filename or "unknown"
    suffix = Path(filename).suffix.lower()
    allowed_extensions = {".pdf", ".docx", ".txt", ".py", ".js", ".json", ".csv"}
    if suffix not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            # sorted() keeps the error message stable between runs.
            detail=f"Unsupported file type: {suffix}. Allowed: {', '.join(sorted(allowed_extensions))}",
        )

    upload_dir = Config.get_workspace_dir(collection_name)
    tmp_path: str | None = None
    try:
        # with_qdrant_retry may invoke this function again with the same
        # UploadFile; rewind so a retry re-reads the whole payload instead of
        # an already exhausted stream.
        await file.seek(0)
        content = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=str(upload_dir)) as tmp:
            # Record the path before writing so `finally` can clean up even
            # when the write itself fails.
            tmp_path = tmp.name
            tmp.write(content)

        # NOTE(review): the loader only sees the temporary path; confirm that
        # chunk metadata ends up with the original filename if that matters.
        # Both calls are synchronous, so they run in a worker thread.
        chunks = await asyncio.to_thread(
            load_and_split_document, tmp_path, use_parent_retrieval=use_parent_retrieval
        )
        await asyncio.to_thread(add_to_vector_db, chunks, collection_name=collection_name)
        return UploadResponse(
            filename=filename,
            chunks=len(chunks),
            collection=collection_name,
            metadata={"parent_retrieval": use_parent_retrieval},
        )
    except Exception as e:
        rag_logger.exception(f"Upload failed: {file.filename}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
@app.post("/pipeline/start", response_model=PipelineResponse, tags=["Ingestion"])
async def start_pipeline(
    req: PipelineRequest,
    background_tasks: BackgroundTasks,
):
    """
    Start the ingestion pipeline in the background.

    Stages:
    1. Document cleaning and text extraction
    2. LLM metadata enrichment (optional)
    3. Chunking with deduplication
    4. CUDA-accelerated embedding
    """
    # Map the request fields onto the pipeline's own configuration object.
    pipeline_config = PipelineConfig(
        chunk_size=req.chunk_size,
        chunk_overlap=req.chunk_overlap,
        batch_size=req.batch_size,
        enable_llm_enrichment=req.enable_enrichment,
        reset_db=req.reset_db,
    )

    # An explicit source_dir wins; otherwise use the collection's workspace dir.
    if req.source_dir:
        source_dir = Path(req.source_dir)
    else:
        source_dir = Config.get_workspace_dir(req.collection_name)

    # Schedule the run; the response below is returned without waiting for it.
    background_tasks.add_task(
        run_pipeline,
        source_dir=source_dir,
        collection_name=req.collection_name,
        config=pipeline_config,
    )

    started_message = f"Pipeline started for {req.collection_name}"
    return PipelineResponse(stage="starting", progress=0.0, message=started_message)
@app.get("/pipeline/status", response_model=PipelineResponse, tags=["Ingestion"])
async def pipeline_status():
    """Get current pipeline execution status."""
    # The status dict's keys are expanded directly into the response model.
    current = get_pipeline_status()
    response = PipelineResponse(**current)
    return response
# ──────────────────────────────────────────────────────────────────────────────
# Workspace Management Endpoints
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/workspaces", tags=["Workspace"])
async def list_workspaces():
    """List all available workspaces."""
    # Return the notebook listing together with its length.
    names = get_all_notebooks()
    payload = {"workspaces": names, "count": len(names)}
    return payload
@app.get("/workspaces/{workspace_name}", response_model=WorkspaceStats, tags=["Workspace"])
async def get_workspace(workspace_name: str):
    """Get detailed statistics for a workspace."""
    raw = get_notebook_stats(workspace_name)
    # Keys absent from the stats dict fall back to empty/zero values.
    chunk_count = raw.get("chunks", 0)
    file_names = raw.get("files", [])
    size_in_mb = raw.get("size_mb", 0.0)
    return WorkspaceStats(
        name=workspace_name,
        chunks=chunk_count,
        files=file_names,
        size_mb=size_in_mb,
    )
@app.delete("/workspaces/{workspace_name}", tags=["Workspace"])
async def remove_workspace(workspace_name: str):
    """Delete a workspace and all its data."""
    # A falsy result from delete_notebook is reported as "not found".
    if not delete_notebook(workspace_name):
        raise HTTPException(status_code=404, detail="Workspace not found")

    # Also clean up upload directory (best effort: errors are ignored).
    workspace_dir = Config.get_workspace_dir(workspace_name)
    if workspace_dir.exists():
        shutil.rmtree(workspace_dir, ignore_errors=True)

    return {"deleted": workspace_name, "status": "success"}
@app.delete("/workspaces/{workspace_name}/files/{filename}", tags=["Workspace"])
async def remove_file_from_workspace(workspace_name: str, filename: str):
    """Remove a specific file from a workspace."""
    removed = delete_file_from_notebook(workspace_name, filename)
    # A zero count means nothing was deleted for this file.
    if removed == 0:
        raise HTTPException(status_code=404, detail="File not found in workspace")

    summary = {
        "deleted": filename,
        "chunks_removed": removed,
        "workspace": workspace_name,
    }
    return summary
# ──────────────────────────────────────────────────────────────────────────────
# Observability Endpoints
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/health", response_model=HealthResponse, tags=["Observability"])
async def health_check():
    """
    Comprehensive health check.

    Checks Qdrant connection, GPU availability, and system resources.
    """
    qdrant_status = check_qdrant_health()

    # GPU details are reported for CUDA device 0 only; memory is bytes -> MiB.
    cuda_ready = torch.cuda.is_available()
    gpu_name = torch.cuda.get_device_name(0) if cuda_ready else None
    gpu_memory = (
        torch.cuda.get_device_properties(0).total_memory / (1024**2)
        if cuda_ready
        else None
    )

    # Anything other than an "ok" Qdrant status downgrades the overall status.
    overall = "ok" if qdrant_status.get("status") == "ok" else "degraded"

    return HealthResponse(
        status=overall,
        version=API_VERSION,
        device=DEVICE,
        gpu_name=gpu_name,
        gpu_memory_mb=gpu_memory,
        qdrant=qdrant_status,
        db_size_mb=get_total_db_size(),
    )
@app.get("/metrics", response_model=MetricsResponse, tags=["Observability"])
async def get_metrics():
    """
    Get RAG performance metrics.

    Includes hit rate, MRR, latency percentiles, and volume metrics.
    """
    # The metrics dict is expanded key-for-key into the response model.
    snapshot = get_current_metrics()
    response = MetricsResponse(**snapshot)
    return response
@app.get("/traces", tags=["Observability"])
async def get_traces(
    limit: int = Query(default=50, ge=1, le=500, description="Number of traces to return"),
):
    """Get recent RAG pipeline traces."""
    # `limit` is validated to 1..500 by the Query declaration above.
    recent = get_recent_traces(limit=limit)
    payload = {"count": len(recent), "traces": recent}
    return payload
@app.delete("/logs", tags=["Observability"])
async def clear_all_logs():
    """Clear all log files (traces, metrics)."""
    # Delegates to the observability module, then acknowledges.
    clear_logs()
    acknowledgement = {"status": "cleared"}
    return acknowledgement
# ──────────────────────────────────────────────────────────────────────────────
# Batch Processing Endpoints
# ──────────────────────────────────────────────────────────────────────────────
class BatchQueryRequest(BaseModel):
    """Request for batch query processing."""

    # Between 1 and 20 questions per batch.
    questions: list[str] = Field(..., min_length=1, max_length=20)
    collection_name: str = Field(default="Default_Project")
    # Same 1..50 bound as AskRequest.k_target, so a batch cannot request a
    # retrieval size that a single /ask call would reject.
    k_target: int = Field(default=5, ge=1, le=50)
@app.post("/batch/query", tags=["Batch"])
async def batch_query(req: BatchQueryRequest):
    """
    Process multiple queries in batch.

    Useful for evaluation and testing.
    Limited to 20 questions per batch.
    """
    pause_between_queries = 0.5
    results = []
    for i, question in enumerate(req.questions):
        # Rate limit between queries: pause before every query except the
        # first, so pacing also applies after a failure and nothing is waited
        # for once the last query has finished.
        if i:
            await asyncio.sleep(pause_between_queries)
        try:
            # query_rag is synchronous; run it in a worker thread so the event
            # loop stays responsive while the batch is processed.
            result = await asyncio.to_thread(
                query_rag,
                question=question,
                collection_name=req.collection_name,
                k_target=req.k_target,
            )
            results.append({
                "index": i,
                "question": question,
                "answer": result["answer"],
                "sources": result.get("sources", []),
                "success": True,
            })
        except Exception as e:
            # One failing question must not abort the batch; record it and go on.
            results.append({
                "index": i,
                "question": question,
                "error": str(e),
                "success": False,
            })
    return {
        "total": len(req.questions),
        "successful": sum(1 for r in results if r["success"]),
        "results": results,
    }
# ──────────────────────────────────────────────────────────────────────────────
# Entry Point
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Serve the app by import string, which uvicorn needs for reload=True.
    # NOTE(review): reload=True is a development setting — confirm it is
    # intended for this entry point.
    uvicorn.run(
        "api_main:app",
        port=API_PORT,
        host=API_HOST,
        log_level="info",
        reload=True,
    )