| import os |
| import tempfile |
| from pathlib import Path |
| from contextlib import asynccontextmanager |
| from typing import Optional |
|
|
| from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| import chromadb |
|
|
| from utils import ( |
| get_gemini_client, |
| generate_query_embedding, |
| generate_answer |
| ) |
| from ingest import ( |
| get_chroma_client, |
| get_or_create_collection, |
| ingest_single_pdf, |
| COLLECTION_NAME, |
| DATA_DIR |
| ) |
|
|
|
|
# Module-level singletons populated by the lifespan handler at startup.
# Both remain None until startup completes; endpoints treat None as
# "service unavailable" and answer 503 accordingly.
gemini_client = None
chroma_collection = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: initialize the Gemini client and ChromaDB.

    Populates the module-level ``gemini_client`` and ``chroma_collection``
    singletons before the app starts serving requests. A missing API key is
    tolerated (the app starts in degraded mode); ChromaDB setup is required.
    """
    global gemini_client, chroma_collection

    print("Starting Nigerian Tax Law RAG API...")

    # Gemini is optional at startup: without a key, endpoints return 503
    # instead of the whole app failing to boot.
    try:
        gemini_client = get_gemini_client()
        print("Gemini client initialized")
    except ValueError as e:
        print(f"Warning: {e}")
        print("The API will not work until GEMINI_API_KEY is set.")

    vector_store_client = get_chroma_client()
    chroma_collection = get_or_create_collection(vector_store_client)
    print(f"ChromaDB initialized ({chroma_collection.count()} chunks indexed)")

    yield  # application serves requests here

    print("Shutting down RAG API...")
|
|
|
|
# FastAPI application instance; the lifespan handler above wires up the
# Gemini client and ChromaDB collection before requests are served.
app = FastAPI(
    title="Nigerian Tax Law RAG API",
    description="Query Nigerian tax laws and legal documents using AI-powered retrieval",
    version="1.0.0",
    lifespan=lifespan
)
|
|
# CORS: wide open, presumably for development convenience.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# forbidden by the CORS spec — browsers reject credentialed responses with a
# wildcard origin, and Starlette will not echo the origin in this mode.
# List explicit origins (or drop credentials) before production — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
class AskRequest(BaseModel):
    """Request body for POST /ask."""
    # Natural-language question; bounds reject empty or abusive input.
    question: str = Field(..., min_length=3, max_length=2000)
    # Number of chunks to retrieve from the vector store.
    top_k: int = Field(default=5, ge=1, le=20)
    # Gemini model name passed through to answer generation.
    model: str = Field(default="gemini-2.0-flash")
|
|
|
|
class AskResponse(BaseModel):
    """Response body for POST /ask."""
    # Generated answer text from Gemini.
    answer: str
    # One entry per retrieved chunk: document, chunk_index, relevance_score.
    sources: list[dict]
    # Number of chunks that were supplied to the model as context.
    chunks_used: int
|
|
|
|
class IngestResponse(BaseModel):
    """Response body for POST /ingest."""
    # Human-readable outcome ("ingested" vs. "already exists").
    message: str
    # Name of the uploaded PDF as received.
    filename: str
    # Chunks added to the index; 0 when the document was already indexed.
    chunks_added: int
|
|
|
|
class StatsResponse(BaseModel):
    """Response body for GET /stats."""
    # Total chunks across all indexed documents.
    total_chunks: int
    # Number of distinct source documents.
    total_documents: int
    # One entry per document: {"name": ..., "chunks": ...}.
    documents: list[dict]
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    # "healthy" when both backends are up, otherwise "degraded".
    status: str
    gemini_connected: bool
    chroma_connected: bool
    # Chunk count in the collection (0 when ChromaDB is unavailable).
    chunks_indexed: int
|
|
|
|
| @app.get("/", response_model=dict) |
async def root():
    """Describe the API and enumerate its available endpoints."""
    endpoint_summary = {
        "POST /ask": "Ask a question about Nigerian tax law",
        "POST /ingest": "Upload and index a new PDF document",
        "GET /stats": "Get database statistics",
        "GET /health": "Health check"
    }
    return {
        "name": "Nigerian Tax Law RAG API",
        "version": "1.0.0",
        "endpoints": endpoint_summary,
    }
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report connectivity of the Gemini client and the vector store."""
    has_gemini = gemini_client is not None
    has_chroma = chroma_collection is not None
    indexed = chroma_collection.count() if has_chroma else 0

    # Degraded when either backend failed to initialize at startup.
    overall = "healthy" if (has_gemini and has_chroma) else "degraded"
    return HealthResponse(
        status=overall,
        gemini_connected=has_gemini,
        chroma_connected=has_chroma,
        chunks_indexed=indexed,
    )
|
|
|
|
@app.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a tax-law question via retrieval-augmented generation.

    Pipeline: embed the question, retrieve the ``top_k`` nearest chunks from
    ChromaDB, then ask Gemini to answer using those chunks as context.

    Raises:
        HTTPException 503: Gemini or the vector database is not initialized.
        HTTPException 404: the index contains no documents.
        HTTPException 500: embedding, retrieval, or generation failed.
    """
    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )

    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    if chroma_collection.count() == 0:
        raise HTTPException(
            status_code=404,
            detail="No documents indexed. Please ingest documents first using: python ingest.py"
        )

    try:
        query_embedding = generate_query_embedding(gemini_client, request.question)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating query embedding: {str(e)}")

    try:
        results = chroma_collection.query(
            query_embeddings=[query_embedding],
            n_results=request.top_k,
            include=["documents", "metadatas", "distances"]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying vector database: {str(e)}")

    # Chroma returns one result list per query embedding; we sent exactly one.
    documents = results["documents"][0] if results["documents"] else []
    metadatas = results["metadatas"][0] if results["metadatas"] else []
    distances = results["distances"][0] if results["distances"] else []

    if not documents:
        return AskResponse(
            answer="I couldn't find any relevant information in the indexed documents.",
            sources=[],
            chunks_used=0
        )

    context_parts = []
    sources = []

    # Fix: original used enumerate() but never read the index.
    for doc, meta, dist in zip(documents, metadatas, distances):
        source_name = meta.get("source", "Unknown")
        chunk_idx = meta.get("chunk_index", 0)

        # Label each chunk so the model can cite its source.
        context_parts.append(f"[Source: {source_name}, Chunk {chunk_idx + 1}]\n{doc}")
        sources.append({
            "document": source_name,
            "chunk_index": chunk_idx,
            # Distance -> similarity; assumes a [0, 1]-range distance metric
            # (e.g. cosine) — TODO confirm the collection's distance function.
            "relevance_score": round(1 - dist, 4)
        })

    context = "\n\n---\n\n".join(context_parts)

    try:
        answer = generate_answer(
            gemini_client,
            request.question,
            context,
            model=request.model
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")

    return AskResponse(
        answer=answer,
        sources=sources,
        chunks_used=len(documents)
    )
|
|
|
|
@app.post("/ingest", response_model=IngestResponse)
async def ingest_document(file: UploadFile = File(...), force: bool = False):
    """Save an uploaded PDF into DATA_DIR and index its chunks into ChromaDB.

    Args:
        file: Uploaded PDF (multipart form field).
        force: Re-ingest even if the document is already indexed.

    Raises:
        HTTPException 503: Gemini or the vector database is not initialized.
        HTTPException 400: the upload is not a PDF.
        HTTPException 500: the file could not be saved or ingested.
    """
    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )

    # Fix: guard missing collection (consistent with /ask and /stats) instead
    # of crashing with AttributeError inside ingest_single_pdf.
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    # Fix: UploadFile.filename may be None on malformed multipart requests;
    # the original would raise AttributeError on .lower().
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    file_path = DATA_DIR / file.filename

    try:
        contents = await file.read()
        with open(file_path, "wb") as f:
            f.write(contents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}")

    try:
        chunks_added, _ = ingest_single_pdf(
            file_path,
            chroma_collection,
            gemini_client,
            force=force
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error ingesting document: {str(e)}")

    return IngestResponse(
        message="Document ingested successfully" if chunks_added > 0 else "Document already exists",
        filename=file.filename,
        chunks_added=chunks_added
    )
|
|
|
|
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Summarize the index: total chunks plus per-document chunk counts."""
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    total = chroma_collection.count()

    if total == 0:
        return StatsResponse(total_chunks=0, total_documents=0, documents=[])

    records = chroma_collection.get(limit=total, include=["metadatas"])

    # Tally how many chunks each source document contributed.
    per_doc = {}
    for meta in records["metadatas"]:
        if not meta:
            continue
        name = meta.get("source", "Unknown")
        per_doc[name] = per_doc.get(name, 0) + 1

    doc_list = [
        {"name": name, "chunks": chunks}
        for name, chunks in sorted(per_doc.items())
    ]

    return StatsResponse(
        total_chunks=total,
        total_documents=len(per_doc),
        documents=doc_list,
    )
|
|
|
|
@app.delete("/documents/{document_name}")
async def delete_document(document_name: str):
    """Remove every indexed chunk belonging to *document_name*.

    Raises:
        HTTPException 503: the vector database is not initialized.
        HTTPException 404: no chunks match the given document name.
    """
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    # Fix: only the matching ids are needed; ids are always returned
    # regardless of `include`, so skip fetching metadata payloads.
    results = chroma_collection.get(
        where={"source": document_name},
        include=[]
    )

    if not results["ids"]:
        raise HTTPException(status_code=404, detail=f"Document '{document_name}' not found in index.")

    chroma_collection.delete(ids=results["ids"])

    return {
        "message": f"Document '{document_name}' deleted successfully",
        "chunks_deleted": len(results["ids"])
    }
|
|
|
|
# Dev entry point: serve on all interfaces with auto-reload.
# NOTE(review): the "main:app" import string assumes this file is named
# main.py — confirm, since reload mode re-imports the app by that path.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
|