sabitax / rag /main.py
nexusbert's picture
Upload 14 files
4b1d477 verified
import os
import tempfile
from pathlib import Path
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import chromadb
from utils import (
get_gemini_client,
generate_query_embedding,
generate_answer
)
from ingest import (
get_chroma_client,
get_or_create_collection,
ingest_single_pdf,
COLLECTION_NAME,
DATA_DIR
)
# Process-wide handles shared by the request handlers. Both start as None and
# are assigned inside lifespan() at startup; gemini_client stays None when
# get_gemini_client() raises ValueError there.
gemini_client = None
chroma_collection = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up the shared Gemini client and Chroma collection, then yield.

    A ValueError from ``get_gemini_client`` is reported on stdout and
    swallowed, leaving ``gemini_client`` as ``None`` so the app still starts.
    Errors from the Chroma setup are not caught here and propagate.
    """
    global gemini_client, chroma_collection

    print("Starting Nigerian Tax Law RAG API...")

    try:
        gemini_client = get_gemini_client()
        print("Gemini client initialized")
    except ValueError as config_error:
        print(f"Warning: {config_error}")
        print("The API will not work until GEMINI_API_KEY is set.")

    chroma_collection = get_or_create_collection(get_chroma_client())
    indexed = chroma_collection.count()
    print(f"ChromaDB initialized ({indexed} chunks indexed)")

    # Everything above runs before the app serves requests; what follows the
    # yield runs on shutdown.
    yield

    print("Shutting down RAG API...")
# ASGI application object. The title/description/version are passed to FastAPI
# as API metadata; lifespan wires in the startup/shutdown routine defined above.
app = FastAPI(
    title="Nigerian Tax Law RAG API",
    description="Query Nigerian tax laws and legal documents using AI-powered retrieval",
    version="1.0.0",
    lifespan=lifespan
)
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination -- consider listing explicit origins before this is
# exposed beyond development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body for POST /ask.
# (Documented with comments, not a class docstring: pydantic would publish a
# docstring as the schema description.)
class AskRequest(BaseModel):
    # The user's question; required, 3 to 2000 characters.
    question: str = Field(..., min_length=3, max_length=2000)
    # Number of chunks to retrieve from the vector store; 1 to 20, default 5.
    top_k: int = Field(default=5, ge=1, le=20)
    # Model name forwarded to generate_answer().
    model: str = Field(default="gemini-2.0-flash")
# Response body for POST /ask.
class AskResponse(BaseModel):
    # Generated answer text (or a fixed fallback message when nothing matched).
    answer: str
    # One entry per retrieved chunk; ask_question() fills each dict with the
    # keys "document", "chunk_index" and "relevance_score".
    sources: list[dict]
    # Number of retrieved chunks that were used as context.
    chunks_used: int
# Response body for POST /ingest.
class IngestResponse(BaseModel):
    # Human-readable outcome of the ingestion.
    message: str
    # Name of the uploaded file.
    filename: str
    # Chunk count reported by ingest_single_pdf(); 0 is reported to the client
    # as "Document already exists".
    chunks_added: int
# Response body for GET /stats.
class StatsResponse(BaseModel):
    # Total number of chunks in the collection.
    total_chunks: int
    # Number of distinct "source" values found in the chunk metadata.
    total_documents: int
    # One {"name": ..., "chunks": ...} dict per source, sorted by name.
    documents: list[dict]
# Response body for GET /health.
class HealthResponse(BaseModel):
    # "healthy" when both dependencies are available, otherwise "degraded".
    status: str
    # True when the module-level gemini_client has been initialised.
    gemini_connected: bool
    # True when the module-level chroma_collection has been initialised.
    chroma_connected: bool
    # Chunk count of the collection, or 0 when it is not initialised.
    chunks_indexed: int
@app.get("/", response_model=dict)
async def root():
    # Landing endpoint: returns the service name, version and a short index of
    # the main routes. Documented with comments rather than a docstring because
    # FastAPI publishes handler docstrings as the OpenAPI description.
    endpoint_index = {
        "POST /ask": "Ask a question about Nigerian tax law",
        "POST /ingest": "Upload and index a new PDF document",
        "GET /stats": "Get database statistics",
        "GET /health": "Health check",
    }
    return {
        "name": "Nigerian Tax Law RAG API",
        "version": "1.0.0",
        "endpoints": endpoint_index,
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    # Reports whether the two module-level dependencies were initialised and
    # how many chunks are indexed. Documented with comments rather than a
    # docstring because FastAPI publishes handler docstrings as the OpenAPI
    # description.
    has_gemini = gemini_client is not None
    has_chroma = chroma_collection is not None

    # Only touch the collection when it exists.
    if has_chroma:
        indexed = chroma_collection.count()
    else:
        indexed = 0

    if has_gemini and has_chroma:
        overall = "healthy"
    else:
        overall = "degraded"

    return HealthResponse(
        status=overall,
        gemini_connected=has_gemini,
        chroma_connected=has_chroma,
        chunks_indexed=indexed,
    )
@app.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a question using the indexed documents as context.

    The question is embedded, the closest chunks are retrieved from the
    vector store, and those chunks are passed as context to the model named
    in the request.

    Raises:
        HTTPException: 503 when the Gemini client or the collection is not
            initialised, 404 when nothing has been indexed yet, and 500 when
            embedding, retrieval or answer generation fails.
    """
    # Function-scope import so this handler is self-contained.
    from fastapi.concurrency import run_in_threadpool

    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    indexed = chroma_collection.count()
    if indexed == 0:
        raise HTTPException(
            status_code=404,
            detail="No documents indexed. Please ingest documents first using: python ingest.py"
        )

    # The embedding, query and generation helpers are called synchronously, so
    # they are pushed to worker threads to keep the event loop free for other
    # requests while they run.
    try:
        query_embedding = await run_in_threadpool(
            generate_query_embedding, gemini_client, request.question
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error generating query embedding: {str(e)}"
        ) from e

    try:
        results = await run_in_threadpool(
            chroma_collection.query,
            query_embeddings=[query_embedding],
            # Never ask for more neighbours than there are chunks.
            n_results=min(request.top_k, indexed),
            include=["documents", "metadatas", "distances"],
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error querying vector database: {str(e)}"
        ) from e

    # Results are nested one list per query embedding; only one was sent.
    documents = results["documents"][0] if results["documents"] else []
    metadatas = results["metadatas"][0] if results["metadatas"] else []
    distances = results["distances"][0] if results["distances"] else []

    if not documents:
        return AskResponse(
            answer="I couldn't find any relevant information in the indexed documents.",
            sources=[],
            chunks_used=0
        )

    context_parts = []
    sources = []
    for doc, meta, dist in zip(documents, metadatas, distances):
        # A chunk stored without metadata comes back as None; treat it as
        # empty, matching the guard used by the /stats handler.
        meta = meta or {}
        source_name = meta.get("source", "Unknown")
        chunk_idx = meta.get("chunk_index", 0)
        # Chunk numbers are shown 1-based in the prompt, 0-based in the response.
        context_parts.append(f"[Source: {source_name}, Chunk {chunk_idx + 1}]\n{doc}")
        sources.append({
            "document": source_name,
            "chunk_index": chunk_idx,
            # NOTE(review): 1 - distance is only a similarity in [0, 1] if the
            # collection uses cosine distance -- confirm the metric in ingest.py.
            "relevance_score": round(1 - dist, 4)
        })

    context = "\n\n---\n\n".join(context_parts)

    try:
        answer = await run_in_threadpool(
            generate_answer,
            gemini_client,
            request.question,
            context,
            model=request.model,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error generating answer: {str(e)}"
        ) from e

    return AskResponse(
        answer=answer,
        sources=sources,
        chunks_used=len(documents)
    )
@app.post("/ingest", response_model=IngestResponse)
async def ingest_document(file: UploadFile = File(...), force: bool = False):
    """Save an uploaded PDF into the data directory and index it.

    Args:
        file: The uploaded PDF.
        force: Forwarded to ``ingest_single_pdf``.

    Raises:
        HTTPException: 503 when the Gemini client or the collection is not
            initialised, 400 when the upload is not a ``.pdf`` file, and 500
            when saving or ingesting the file fails.
    """
    # Function-scope import so this handler is self-contained.
    from fastapi.concurrency import run_in_threadpool

    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )
    # Same guard as /ask and /stats: ingestion needs the collection.
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    # The filename is client-controlled and may be missing or contain directory
    # parts such as "../"; keep only the final path component so the file can
    # never be written outside DATA_DIR. Backslashes are normalised first so
    # Windows-style separators are stripped as well.
    raw_name = file.filename or ""
    safe_name = Path(raw_name.replace("\\", "/")).name
    if not safe_name.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    file_path = DATA_DIR / safe_name

    try:
        contents = await file.read()
        with open(file_path, "wb") as f:
            f.write(contents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}") from e

    # ingest_single_pdf is called synchronously; run it in a worker thread so
    # the event loop stays free while the document is processed.
    try:
        chunks_added, _ = await run_in_threadpool(
            ingest_single_pdf,
            file_path,
            chroma_collection,
            gemini_client,
            force=force,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error ingesting document: {str(e)}"
        ) from e

    return IngestResponse(
        message="Document ingested successfully" if chunks_added > 0 else "Document already exists",
        filename=safe_name,
        chunks_added=chunks_added
    )
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    # Summarises the collection: total chunk count plus a per-source chunk
    # tally. Responds 503 when the collection was never initialised.
    # Documented with comments rather than a docstring because FastAPI
    # publishes handler docstrings as the OpenAPI description.
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    total = chroma_collection.count()
    if not total:
        return StatsResponse(total_chunks=0, total_documents=0, documents=[])

    fetched = chroma_collection.get(limit=total, include=["metadatas"])

    # Count chunks per "source"; falsy metadata entries are skipped.
    per_source = {}
    for entry in filter(None, fetched["metadatas"]):
        label = entry.get("source", "Unknown")
        per_source[label] = per_source.get(label, 0) + 1

    listing = []
    for label in sorted(per_source):
        listing.append({"name": label, "chunks": per_source[label]})

    return StatsResponse(
        total_chunks=total,
        total_documents=len(per_source),
        documents=listing,
    )
@app.delete("/documents/{document_name}")
async def delete_document(document_name: str):
    # Removes every chunk whose "source" metadata equals document_name.
    # Responds 503 when the collection is not initialised and 404 when no
    # chunk matches. Documented with comments rather than a docstring because
    # FastAPI publishes handler docstrings as the OpenAPI description.
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    matches = chroma_collection.get(
        where={"source": document_name},
        include=["metadatas"],
    )
    chunk_ids = matches["ids"]

    if not chunk_ids:
        raise HTTPException(status_code=404, detail=f"Document '{document_name}' not found in index.")

    chroma_collection.delete(ids=chunk_ids)
    removed = len(chunk_ids)

    return {
        "message": f"Document '{document_name}' deleted successfully",
        "chunks_deleted": removed,
    }
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
# NOTE(review): reload=True looks like a development setting -- confirm it is
# not used for production runs.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)