# NOTE: Hugging Face Spaces page header ("Spaces: Sleeping") removed — it was
# web-page residue from extraction, not part of the source file.
| """REST API endpoints for FreeRAG.""" | |
| from fastapi import FastAPI, HTTPException, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import Optional, List | |
| import tempfile | |
| import os | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| # Request/Response Models | |
class QueryRequest(BaseModel):
    """Request model for querying the RAG system."""

    # Natural-language question; min_length=1 rejects empty strings at validation.
    question: str = Field(..., description="The question to ask", min_length=1)
    # Number of documents to retrieve, clamped to 1..10 by the validators.
    top_k: int = Field(default=3, description="Number of documents to retrieve", ge=1, le=10)
    # Optional chat-history session; when omitted the query endpoint creates one.
    session_id: Optional[str] = Field(default=None, description="Session ID for chat history tracking")
class QueryResponse(BaseModel):
    """Response model for RAG queries."""

    question: str                    # echo of the question that was asked
    answer: str                      # generated (or cached) answer text
    sources: List[dict]              # retrieved source documents backing the answer
    cached: bool                     # True when the answer came from the cache
    match_type: str                  # e.g. "generated", or "error" for the empty-index reply
    session_id: Optional[str] = None # session used/created for chat-history tracking
class UploadResponse(BaseModel):
    """Response model for file uploads."""

    success: bool        # True when at least one file was processed
    message: str         # human-readable summary, including per-file errors
    files_processed: int # count of files successfully ingested
    chunks_added: int    # total chunks added to the vector store
class HealthResponse(BaseModel):
    """Response model for health check."""

    status: str          # "healthy" when the pipeline is reachable
    documents_count: int # number of documents in the vector store
    model: str           # name of the generation model in use
class StatsResponse(BaseModel):
    """Response model for stats."""

    # Field names mirror the dict returned by the pipeline's get_stats(),
    # which is splatted directly into this model by the stats endpoint.
    documents_count: int
    collection_name: str
    model: str
    embedding_model: str
# Create FastAPI app
api = FastAPI(
    title="FreeRAG API",
    description="REST API for the FreeRAG Retrieval-Augmented Generation system",
    version="1.0.0"
)

# Add CORS middleware so browser clients on other origins can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wider than the CORS spec allows for credentialed requests — confirm whether
# credentials are actually needed, or pin the origin list before production.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_rag_pipeline():
    """Return the shared RAG pipeline exposed by app.py."""
    # Deferred import — presumably to dodge a circular dependency between
    # this module and app.py (confirm against app.py's imports).
    from app import get_pipeline

    pipeline = get_pipeline()
    return pipeline
@api.get("/")
async def root():
    """Root endpoint with API info.

    Fix: the handler was defined but never registered on the app — without
    the route decorator FastAPI never serves it (the endpoint map it returns
    documents these exact paths).
    """
    return {
        "name": "FreeRAG API",
        "version": "1.0.0",
        "endpoints": {
            "query": "POST /api/query",
            "upload": "POST /api/upload",
            "stats": "GET /api/stats",
            "health": "GET /api/health",
        },
    }
@api.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Check if the system is healthy and ready.

    Returns:
        HealthResponse with status and basic stats when the pipeline is reachable.

    Raises:
        HTTPException: 503 when the pipeline cannot be reached.

    Fix: registered the route (no decorator existed, so the endpoint was
    unreachable) and log the full traceback via logger.exception.
    """
    try:
        pipe = get_rag_pipeline()
        stats = pipe.get_stats()
        return HealthResponse(
            status="healthy",
            documents_count=stats["documents_count"],
            model=stats["model"]
        )
    except Exception as e:
        # Keep the client-facing detail generic; the traceback goes to the log.
        logger.exception("Health check failed: %s", e)
        raise HTTPException(status_code=503, detail="Service unavailable")
@api.get("/api/stats", response_model=StatsResponse)
async def get_stats():
    """Get system statistics.

    Returns:
        StatsResponse built directly from the pipeline's get_stats() dict —
        assumes its keys match the model fields exactly (documents_count,
        collection_name, model, embedding_model).

    Raises:
        HTTPException: 500 on any pipeline failure.

    Fix: registered the route (no decorator existed, so the endpoint was
    unreachable) and log the full traceback via logger.exception.
    """
    try:
        pipe = get_rag_pipeline()
        stats = pipe.get_stats()
        return StatsResponse(**stats)
    except Exception as e:
        logger.exception("Failed to get stats: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@api.post("/api/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """
    Query the RAG system with a question.

    The system will:
    1. Check cache for exact match
    2. Check for semantically similar questions
    3. Search uploaded documents
    4. Generate answer using AI if needed

    Raises:
        HTTPException: 500 on any pipeline failure.

    Fix: registered the route (no decorator existed, so the endpoint was
    unreachable) and log the full traceback via logger.exception.
    """
    try:
        pipe = get_rag_pipeline()

        # With an empty index there is nothing to retrieve — reply with
        # guidance instead of failing the request.
        if pipe.vector_store.get_count() == 0:
            return QueryResponse(
                question=request.question,
                answer="No documents uploaded yet. Please upload documents first using /api/upload",
                sources=[],
                cached=False,
                match_type="error"
            )

        # Get or create a session so chat history can be tracked across calls.
        session_id = request.session_id
        if not session_id:
            from src.session import get_session_manager
            session_id = get_session_manager().create_session()

        result = pipe.query(request.question, top_k=request.top_k, session_id=session_id)
        return QueryResponse(
            question=result["question"],
            answer=result["answer"],
            sources=result.get("sources", []),
            cached=result.get("cached", False),
            match_type=result.get("match_type", "generated"),
            session_id=session_id
        )
    except Exception as e:
        logger.exception("Query failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@api.post("/api/upload", response_model=UploadResponse)
async def upload_files(files: List[UploadFile] = File(...)):
    """
    Upload documents to the RAG system.

    Supported formats: PDF, DOCX, TXT, MD
    Maximum file size: 10MB per file

    Each file is staged to a temp file, ingested, and the temp file is always
    removed. Per-file failures are collected into the response message rather
    than aborting the whole batch.

    Fixes: registered the route (no decorator existed, so the endpoint was
    unreachable); guard against file.filename being None, which would crash
    os.path.splitext on malformed multipart parts.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
    ALLOWED_EXTENSIONS = {'.pdf', '.docx', '.txt', '.md'}

    pipe = get_rag_pipeline()
    total_chunks = 0
    processed_count = 0
    errors = []

    for file in files:
        try:
            # file.filename can be None for malformed parts — default to "".
            filename = file.filename or ""
            ext = os.path.splitext(filename)[1].lower()
            if ext not in ALLOWED_EXTENSIONS:
                errors.append(f"{filename}: Unsupported format")
                continue

            content = await file.read()

            # Enforce the per-file size cap before touching the disk.
            if len(content) > MAX_FILE_SIZE:
                errors.append(f"{filename}: File too large (max 10MB)")
                continue

            # Stage to a real path (the ingestor works on filesystem paths),
            # then always remove the temp file, even when ingestion fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp.write(content)
                tmp_path = tmp.name
            try:
                chunks = pipe.ingest_file(tmp_path)
                total_chunks += chunks
                processed_count += 1
            finally:
                os.unlink(tmp_path)
        except Exception as e:
            errors.append(f"{file.filename}: {str(e)}")

    message = f"Processed {processed_count} file(s), {total_chunks} chunks added"
    if errors:
        message += f". Errors: {'; '.join(errors)}"

    return UploadResponse(
        success=processed_count > 0,
        message=message,
        files_processed=processed_count,
        chunks_added=total_chunks
    )
# NOTE(review): this route's path is not listed by the root endpoint —
# confirm the path clients expect before relying on it.
@api.delete("/api/documents")
async def clear_documents():
    """Clear all uploaded documents from the system.

    Returns:
        dict with success flag and a confirmation message.

    Raises:
        HTTPException: 500 when the vector store cannot be cleared.

    Fix: registered the route (no decorator existed, so the endpoint was
    unreachable) and log the full traceback via logger.exception.
    """
    try:
        pipe = get_rag_pipeline()
        pipe.vector_store.clear()
        return {"success": True, "message": "All documents cleared"}
    except Exception as e:
        logger.exception("Failed to clear documents: %s", e)
        raise HTTPException(status_code=500, detail=str(e))