# NewFreeRag / api.py
# Uploaded via huggingface_hub by nivakaran (commit 2cf4b57, verified)
"""REST API endpoints for FreeRAG."""
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List
import tempfile
import os
import logging
logger = logging.getLogger(__name__)
# Request/Response Models
class QueryRequest(BaseModel):
    """Request model for querying the RAG system."""
    # Natural-language question; pydantic rejects empty strings (min_length=1).
    question: str = Field(..., description="The question to ask", min_length=1)
    # Number of documents to retrieve from the vector store, clamped to 1..10.
    top_k: int = Field(default=3, description="Number of documents to retrieve", ge=1, le=10)
    # Optional session identifier; when omitted the /api/query handler creates one.
    session_id: Optional[str] = Field(default=None, description="Session ID for chat history tracking")
class QueryResponse(BaseModel):
    """Response model for RAG queries."""
    # Echo of the question that was asked.
    question: str
    # The generated (or cached) answer text.
    answer: str
    # Source documents backing the answer; shape of each dict comes from the pipeline.
    sources: List[dict]
    # True when the answer was served from cache rather than freshly generated.
    cached: bool
    # How the answer was produced (e.g. "generated", cache hit type, or "error").
    match_type: str
    # Session used for chat-history tracking; None only on early-error responses.
    session_id: Optional[str] = None
class UploadResponse(BaseModel):
    """Response model for file uploads."""
    # True when at least one file was ingested successfully.
    success: bool
    # Human-readable summary, including any per-file errors.
    message: str
    # Count of files that were ingested without error.
    files_processed: int
    # Total chunks added to the vector store across all processed files.
    chunks_added: int
class HealthResponse(BaseModel):
    """Response model for health check."""
    # "healthy" when the pipeline responds; unhealthy states return HTTP 503 instead.
    status: str
    # Number of documents currently in the vector store.
    documents_count: int
    # Name of the LLM used for generation (reported by the pipeline).
    model: str
class StatsResponse(BaseModel):
    """Response model for stats."""
    # Number of documents currently in the vector store.
    documents_count: int
    # Name of the backing vector-store collection.
    collection_name: str
    # Name of the LLM used for generation.
    model: str
    # Name of the embedding model used for retrieval.
    embedding_model: str
# Create the FastAPI application; metadata surfaces in the auto-generated docs.
api = FastAPI(
    title="FreeRAG API",
    description="REST API for the FreeRAG Retrieval-Augmented Generation system",
    version="1.0.0"
)

# Allow cross-origin requests from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers refuse wildcard origins on credentialed
# requests) — confirm whether credentials are actually needed, or pin origins.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_rag_pipeline():
    """Return the shared RAG pipeline owned by app.py.

    The import is deferred to call time so this module can be loaded
    without immediately pulling in app.py and its startup cost.
    """
    from app import get_pipeline
    pipeline = get_pipeline()
    return pipeline
@api.get("/", tags=["Health"])
async def root():
    """Root endpoint with API info."""
    endpoint_map = {
        "query": "POST /api/query",
        "upload": "POST /api/upload",
        "stats": "GET /api/stats",
        "health": "GET /api/health",
    }
    return {
        "name": "FreeRAG API",
        "version": "1.0.0",
        "endpoints": endpoint_map,
    }
@api.get("/api/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Check if the system is healthy and ready.

    Returns:
        HealthResponse with status, document count, and the LLM model name.

    Raises:
        HTTPException: 503 when the pipeline cannot be reached or queried.
    """
    try:
        pipe = get_rag_pipeline()
        stats = pipe.get_stats()
        return HealthResponse(
            status="healthy",
            documents_count=stats["documents_count"],
            model=stats["model"]
        )
    except Exception as e:
        # Lazy %-args: the message is only built if the record is emitted.
        logger.error("Health check failed: %s", e)
        # Chain the cause so the original traceback survives for debugging.
        raise HTTPException(status_code=503, detail="Service unavailable") from e
@api.get("/api/stats", response_model=StatsResponse, tags=["System"])
async def get_stats():
    """Get system statistics.

    Returns:
        StatsResponse built directly from the pipeline's stats dict
        (keys must match the StatsResponse fields).

    Raises:
        HTTPException: 500 when the pipeline fails to report stats.
    """
    try:
        pipe = get_rag_pipeline()
        stats = pipe.get_stats()
        return StatsResponse(**stats)
    except Exception as e:
        # Lazy %-args: the message is only built if the record is emitted.
        logger.error("Failed to get stats: %s", e)
        # Chain the cause so the original traceback survives for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
@api.post("/api/query", response_model=QueryResponse, tags=["RAG"])
async def query(request: QueryRequest):
    """
    Query the RAG system with a question.

    The system will:
    1. Check cache for exact match
    2. Check for semantically similar questions
    3. Search uploaded documents
    4. Generate answer using AI if needed

    When the vector store is empty, a match_type="error" response is returned
    with HTTP 200 rather than raising, so clients always get a QueryResponse.

    Raises:
        HTTPException: 500 when the pipeline itself fails.
    """
    try:
        pipe = get_rag_pipeline()
        # No documents ingested yet: short-circuit with an instructive answer.
        if pipe.vector_store.get_count() == 0:
            return QueryResponse(
                question=request.question,
                answer="No documents uploaded yet. Please upload documents first using /api/upload",
                sources=[],
                cached=False,
                match_type="error"
            )
        # Reuse the caller's session when provided; otherwise mint a new one
        # so chat history is tracked from the very first question.
        session_id = request.session_id
        if not session_id:
            from src.session import get_session_manager
            session_id = get_session_manager().create_session()
        result = pipe.query(request.question, top_k=request.top_k, session_id=session_id)
        return QueryResponse(
            question=result["question"],
            answer=result["answer"],
            sources=result.get("sources", []),
            cached=result.get("cached", False),
            match_type=result.get("match_type", "generated"),
            session_id=session_id
        )
    except Exception as e:
        # Lazy %-args: the message is only built if the record is emitted.
        logger.error("Query failed: %s", e)
        # Chain the cause so the original traceback survives for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
@api.post("/api/upload", response_model=UploadResponse, tags=["Documents"])
async def upload_files(files: List[UploadFile] = File(...)):
    """
    Upload documents to the RAG system.

    Supported formats: PDF, DOCX, TXT, MD
    Maximum file size: 10MB per file

    Per-file failures (bad extension, oversize, ingestion error) are collected
    into the response message instead of aborting the whole batch; `success`
    is True when at least one file was ingested.

    Raises:
        HTTPException: 400 when no files are provided.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")
    MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
    ALLOWED_EXTENSIONS = {'.pdf', '.docx', '.txt', '.md'}
    pipe = get_rag_pipeline()
    total_chunks = 0
    processed_count = 0
    errors = []
    for file in files:
        # UploadFile.filename may be None for raw multipart parts — guard so
        # os.path.splitext and the error messages don't blow up on None.
        filename = file.filename or "unnamed"
        try:
            # Reject unsupported extensions before reading any bytes.
            ext = os.path.splitext(filename)[1].lower()
            if ext not in ALLOWED_EXTENSIONS:
                errors.append(f"{filename}: Unsupported format")
                continue
            # Read the upload into memory so the size can be validated.
            content = await file.read()
            if len(content) > MAX_FILE_SIZE:
                errors.append(f"{filename}: File too large (max 10MB)")
                continue
            # The pipeline ingests by path, so spill to a temp file first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp.write(content)
                tmp_path = tmp.name
            try:
                chunks = pipe.ingest_file(tmp_path)
                total_chunks += chunks
                processed_count += 1
            finally:
                # Always remove the temp file, even when ingestion raises.
                os.unlink(tmp_path)
        except Exception as e:
            errors.append(f"{filename}: {str(e)}")
    message = f"Processed {processed_count} file(s), {total_chunks} chunks added"
    if errors:
        message += f". Errors: {'; '.join(errors)}"
    return UploadResponse(
        success=processed_count > 0,
        message=message,
        files_processed=processed_count,
        chunks_added=total_chunks
    )
@api.delete("/api/clear", tags=["Documents"])
async def clear_documents():
    """Clear all uploaded documents from the system.

    Returns:
        A {"success": True, "message": ...} dict on success.

    Raises:
        HTTPException: 500 when the vector store fails to clear.
    """
    try:
        pipe = get_rag_pipeline()
        pipe.vector_store.clear()
        return {"success": True, "message": "All documents cleared"}
    except Exception as e:
        # Lazy %-args: the message is only built if the record is emitted.
        logger.error("Failed to clear documents: %s", e)
        # Chain the cause so the original traceback survives for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e