Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Backend for Multi-PDF RAG System with Per-User Storage | |
| =============================================================== | |
| Secure multi-user API with: | |
| - API key authentication | |
| - Per-user storage isolation | |
| - PDF upload and management | |
| - RAG-based question answering | |
| - HF persistent storage | |
| """ | |
| import os | |
| import time | |
| import asyncio | |
| from typing import List, Optional, Dict | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, BackgroundTasks, Header, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from threading import Lock | |
| from rag_engine import RAGEngine | |
| from hf_storage import create_hf_storage_manager | |
| from auth import get_current_user | |
| from user_management import create_hf_user_manager | |
| from job_worker import job_manager | |
| from dotenv import load_dotenv | |
| # ============================================ | |
| # CONFIGURATION | |
| # ============================================ | |
# Pull variables from a local .env file (python-dotenv) before reading them.
load_dotenv()

# Secrets and repository settings, all read from the environment.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
HF_REPO = os.getenv("HF_REPO", "Hamza4100/multi-pdf-storage")

# The Gemini key is mandatory: importing this module fails without it.
if not GEMINI_API_KEY:
    raise RuntimeError("β GEMINI_API_KEY not set")

# Storage manager and user manager are configured from the same token/repo.
hf_storage = create_hf_storage_manager(hf_repo=HF_REPO, hf_token=HF_TOKEN)
hf_user_manager = create_hf_user_manager(hf_repo=HF_REPO, hf_token=HF_TOKEN)

app = FastAPI(
    version="2.0.0",
    title="Multi-PDF RAG System",
    description="Secure multi-user RAG API with persistent storage",
)

# NOTE(review): wildcard origins together with allow_credentials=True — the
# CORS spec forbids "*" with credentials, so Starlette presumably reflects the
# caller's Origin instead, i.e. any site may send credentialed requests. Auth
# here is via the X-API-KEY header, so credentials may not be needed — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| # ============================================ | |
| # PER-USER RAG ENGINE MANAGER | |
| # ============================================ | |
class UserRAGManager:
    """Manages per-user RAG engine instances with lazy loading.

    Each user gets one cached ``RAGEngine`` (in ``self.engines``) and one
    ``threading.Lock`` (from ``get_user_lock``) that serializes that user's
    engine initialization. Because it is a *threading* lock it must never be
    held by a coroutine across an ``await``: another coroutine blocking in
    ``acquire()`` would stall the event-loop thread, and the holder could then
    never resume to release it. ``get_engine`` therefore only takes the lock
    inside a worker thread.
    """

    def __init__(self):
        # user_id -> ready-to-use engine, filled lazily by get_engine().
        self.engines: Dict[str, RAGEngine] = {}
        # user_id -> per-user lock, created on first use by get_user_lock().
        self.locks: Dict[str, Lock] = {}
        # Guards creation of entries in self.locks.
        self.global_lock = Lock()

    def get_user_lock(self, user_id: str) -> Lock:
        """Return the per-user lock for *user_id*, creating it on first use."""
        with self.global_lock:
            if user_id not in self.locks:
                self.locks[user_id] = Lock()
            return self.locks[user_id]

    async def get_engine(self, user_id: str) -> "RAGEngine":
        """Get or create the RAG engine for *user_id* (lazy loading).

        Cached engines are returned directly on the event loop. Otherwise the
        blocking initialization (HF sync + engine construction, under the
        per-user lock) runs in a worker thread, so the loop keeps serving
        other requests while one user's engine is being built.
        """
        engine = self.engines.get(user_id)
        if engine is not None:
            return engine
        return await asyncio.to_thread(self._load_engine, user_id)

    def _load_engine(self, user_id: str) -> "RAGEngine":
        """Blocking: sync the user's storage from HF and build their engine.

        Must run off the event loop (see get_engine). The cache is re-checked
        under the per-user lock so concurrent first requests build only once;
        if construction raises, nothing is cached and a later call retries.
        """
        with self.get_user_lock(user_id):
            engine = self.engines.get(user_id)
            if engine is not None:
                return engine
            print(f"π§ Initializing RAG for user {user_id}...")
            # Sync from HF
            hf_storage.sync_storage_from_hf(user_id)
            # User-specific paths
            base_dir = os.path.dirname(os.path.abspath(__file__))
            user_storage_dir = os.path.join(base_dir, "users", user_id, "storage")
            # Initialize engine
            engine = RAGEngine(
                gemini_api_key=GEMINI_API_KEY,
                storage_dir=user_storage_dir
            )
            self.engines[user_id] = engine
            print(f"β RAG ready for user {user_id}")
            return engine


rag_manager = UserRAGManager()
| # ============================================ | |
| # MODELS | |
| # ============================================ | |
class UploadResponse(BaseModel):
    # Reply to an upload request. In this file upload_pdf returns it as soon
    # as the file is queued for background processing.
    document_id: str  # upload_pdf fills this with the background job's id
    filename: str
    status: str  # upload_pdf sends "processing"
    message: str
    pages: Optional[int] = None  # not set by upload_pdf in this file
    chunks: Optional[int] = None  # not set by upload_pdf in this file
class QueryRequest(BaseModel):
    # Request body for a RAG query; fields are forwarded to engine.ask().
    question: str  # passed as query=...
    top_k: Optional[int] = 5  # passed as top_k=...; presumably number of retrieved chunks — not range-checked here
    doc_id: Optional[str] = None  # passed as doc_id=...; None presumably means "search all documents"
class QueryResponse(BaseModel):
    # RAG answer plus the source records returned by engine.ask().
    answer: str
    sources: List[dict]  # passed through as-is; per-source schema is defined by RAGEngine (not visible here)
class DocumentInfo(BaseModel):
    # One entry of a user's document listing; fields mirror the keys of the
    # dicts returned by engine.get_all_documents().
    doc_id: str
    filename: str
    upload_timestamp: str  # timestamp format is set by RAGEngine — TODO confirm
    num_chunks: int
    num_pages: int
class StatsResponse(BaseModel):
    # Per-user summary; fields mirror the keys of engine.get_stats().
    total_documents: int
    total_chunks: int
    index_size: int  # presumably the number of vectors in the index — confirm in RAGEngine
class DeleteResponse(BaseModel):
    # Reply to a successful document deletion.
    status: str  # "success" when returned by delete_document
    message: str  # message text supplied by engine.delete_document()
| # ============================================ | |
| # STARTUP | |
| # ============================================ | |
async def startup_event():
    """Print configuration at startup and launch the background job manager.

    Both the HF user count and the job-manager start are best-effort: a
    failure is printed rather than aborting server startup.
    """
    print("π Multi-PDF RAG System v2.0")
    print(f"π¦ HF Storage: {'Enabled' if hf_storage.enabled else 'Disabled'}")
    # Load and display user count from HF
    if hf_user_manager.enabled:
        try:
            users_count = len(hf_user_manager.get_all_users())
        except Exception as e:
            # The count is informational only; a transient HF error must not
            # stop the API from starting.
            print(f"User count unavailable: {e}")
        else:
            print(f"π₯ Loaded {users_count} user(s) from HF repository")
        print("β User management: Enabled (HF-based)")
    else:
        print("β οΈ User management: Disabled (check HF_TOKEN and HF_REPO)")
    print("β Server ready (per-user lazy loading)")
    # Start persistent job manager
    try:
        job_manager.start()
        print("π Background job manager started")
    except Exception as e:
        print(f"β οΈ Failed to start job manager: {e}")
| # ============================================ | |
| # ENDPOINTS | |
| # ============================================ | |
async def health_check():
    """Liveness probe; no API key is needed to call it."""
    payload = {"status": "ok"}
    return payload
async def upload_pdf(
    request: Request,
    file: UploadFile = File(...),
    user_id: str = Depends(get_current_user),
    x_api_key: Optional[str] = Header(None, alias="X-API-KEY"),
    background_tasks: BackgroundTasks = None
):
    """
    Upload PDF for authenticated user.
    Requires: X-API-KEY header

    The file is validated (extension, MIME type, 10MB cap), written to a
    temporary path and handed to the background job manager. The returned
    document_id is the job id, so clients can poll for processing status.
    """
    max_bytes = 10 * 1024 * 1024

    # Validate PDF. filename can be missing on a malformed multipart part.
    if not file.filename or not file.filename.lower().endswith('.pdf'):
        raise HTTPException(400, "Only PDF files allowed")
    if file.content_type and file.content_type not in ['application/pdf']:
        raise HTTPException(400, "Invalid MIME type")

    # Read at most one byte past the limit: enough to detect an oversized
    # file without loading all of it into a bytes object.
    content = await file.read(max_bytes + 1)
    if len(content) > max_bytes:
        raise HTTPException(413, "File too large (max 10MB)")

    try:
        _log_upload_headers(request)
        tmp_path = _save_upload_to_tmp(content, file.filename)
        username = _lookup_username(x_api_key)
        try:
            job = job_manager.create_job(user_id=user_id, filename=file.filename, file_path=tmp_path, username=username)
        except Exception:
            # No job will ever consume the file, so don't leave it behind.
            _remove_quietly(tmp_path)
            raise
        return UploadResponse(
            document_id=job['id'],
            filename=file.filename,
            status="processing",
            message="Upload accepted and is being processed in background"
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"β Upload enqueue error (user {user_id}): {e}")
        raise HTTPException(500, "Failed to enqueue upload")


def _log_upload_headers(request: Request) -> None:
    """Best-effort debug print of a few upload headers, with credentials redacted."""
    try:
        shown = {}
        for key in ("x-api-key", "authorization", "host", "user-agent"):
            value = request.headers.get(key)
            if value:
                # Credentials must never reach the logs; record only presence.
                shown[key] = "<redacted>" if key in ("x-api-key", "authorization") else value
        print(f"π Incoming upload headers (filtered): {shown}")
    except Exception:
        # Diagnostics only; never let logging break an upload.
        pass


def _save_upload_to_tmp(content: bytes, filename: str) -> str:
    """Write *content* under ``tmp_uploads/`` next to this module; return the path.

    Only the basename of the client-supplied *filename* is used, and a random
    token is added, so crafted names cannot escape the directory and two
    uploads of the same name in the same second cannot overwrite each other.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    tmp_dir = os.path.join(base_dir, 'tmp_uploads')
    os.makedirs(tmp_dir, exist_ok=True)
    # Treat backslashes as separators too, then drop any directory part.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    tmp_path = os.path.join(tmp_dir, f"{int(time.time())}_{os.urandom(4).hex()}_{safe_name}")
    with open(tmp_path, 'wb') as fh:
        fh.write(content)
    return tmp_path


def _lookup_username(x_api_key: Optional[str]) -> Optional[str]:
    """Return the username for *x_api_key* via the HF user manager, or None.

    Best-effort: any lookup failure yields None. Assumes get_user_by_api_key
    returns a sequence whose first item is the username — TODO confirm.
    """
    try:
        if hf_user_manager and hf_user_manager.enabled and x_api_key:
            found = hf_user_manager.get_user_by_api_key(x_api_key)
            if found:
                return found[0]
    except Exception:
        pass
    return None


def _remove_quietly(path: str) -> None:
    """Delete *path*, ignoring OS errors (e.g. it is already gone)."""
    try:
        os.remove(path)
    except OSError:
        pass
async def upload_status(job_id: str, user_id: str = Depends(get_current_user)):
    """
    Return the background-processing record for an upload job.
    Requires: X-API-KEY header
    """
    job = job_manager.get_job(job_id)
    if not job:
        raise HTTPException(404, "Job not found")
    # Jobs are created with user_id=..., so they presumably record their owner
    # under "user_id" — verify against job_worker. Another user's job is
    # reported as "not found" so job ids cannot be used to probe other users'
    # uploads. If no owner is recorded, the previous unchecked behaviour holds.
    try:
        owner = job["user_id"]
    except (KeyError, IndexError, TypeError):
        owner = None
    if owner is not None and owner != user_id:
        raise HTTPException(404, "Job not found")
    return job
async def query_documents(
    request: QueryRequest,
    user_id: str = Depends(get_current_user)
):
    """
    Answer a question from the authenticated user's own documents (RAG).
    Requires: X-API-KEY header
    """
    try:
        rag = await rag_manager.get_engine(user_id)
        # engine.ask is offloaded to a worker thread so the loop stays free.
        outcome = await asyncio.to_thread(
            rag.ask,
            query=request.question,
            top_k=request.top_k,
            doc_id=request.doc_id,
        )
        print(f"β Query success for user {user_id}")
        answer_text = outcome["answer"]
        cited = outcome.get("sources", [])
        return QueryResponse(answer=answer_text, sources=cited)
    except Exception as exc:
        # Every failure (engine init, the query itself, a malformed result)
        # surfaces to the client as a generic 500.
        print(f"β Query error (user {user_id}): {exc}")
        raise HTTPException(500, "Query failed")
async def get_documents(user_id: str = Depends(get_current_user)):
    """
    List every document stored for the authenticated user.
    Requires: X-API-KEY header
    """
    try:
        rag = await rag_manager.get_engine(user_id)
        records = await asyncio.to_thread(rag.get_all_documents)
        listing = []
        for record in records:
            listing.append(
                DocumentInfo(
                    doc_id=record["doc_id"],
                    filename=record["filename"],
                    upload_timestamp=record["upload_timestamp"],
                    num_chunks=record["num_chunks"],
                    num_pages=record["num_pages"],
                )
            )
        return listing
    except Exception as exc:
        # Missing keys, engine errors and validation errors all become a 500.
        print(f"β Get documents error (user {user_id}): {exc}")
        raise HTTPException(500, "Failed to retrieve documents")
async def delete_document(
    doc_id: str,
    user_id: str = Depends(get_current_user)
):
    """
    Delete document for authenticated user.
    Requires: X-API-KEY header
    """
    try:
        engine = await rag_manager.get_engine(user_id)
        user_lock = rag_manager.get_user_lock(user_id)
        # user_lock is a threading.Lock. Holding it on the event-loop thread
        # across an await would let a second concurrent delete for the same
        # user block the loop thread inside acquire(); the first request could
        # then never resume to release it (server-wide hang). The lock is
        # therefore only acquired inside the worker thread.
        result = await asyncio.to_thread(
            _delete_and_push_locked, engine, user_lock, user_id, doc_id
        )
        if result["status"] != "success":
            raise HTTPException(404, result["message"])
        print(f"β Delete success for user {user_id}: {doc_id}")
        return DeleteResponse(
            status="success",
            message=result["message"]
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"β Delete error (user {user_id}): {e}")
        raise HTTPException(500, "Deletion failed")


def _delete_and_push_locked(engine, user_lock: Lock, user_id: str, doc_id: str):
    """Blocking helper for delete_document; runs entirely under *user_lock*.

    Deletes *doc_id* via ``engine.delete_document`` and, only when the result's
    ``status`` is ``"success"``, pushes the user's storage to HF. Returns the
    engine's result unchanged; exceptions propagate to the caller.
    """
    with user_lock:
        result = engine.delete_document(doc_id)
        if result["status"] == "success":
            hf_storage.push_storage_to_hf(user_id, f"Delete {doc_id}")
    return result
async def get_stats(user_id: str = Depends(get_current_user)):
    """
    Report summary statistics for the authenticated user's document index.
    Requires: X-API-KEY header
    """
    try:
        rag = await rag_manager.get_engine(user_id)
        raw = await asyncio.to_thread(rag.get_stats)
        wanted = ("total_documents", "total_chunks", "index_size")
        return StatsResponse(**{key: raw[key] for key in wanted})
    except Exception as exc:
        print(f"β Stats error (user {user_id}): {exc}")
        raise HTTPException(500, "Failed to retrieve stats")
| # ============================================ | |
| # MAIN | |
| # ============================================ | |
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn, auto-reload on.
    # NOTE(review): "main:app" assumes this file is main.py — confirm.
    import uvicorn

    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=listen_port,
        reload=True,
    )