|
|
import os |
|
|
import time |
|
|
import hashlib |
|
|
import uuid |
|
|
from pathlib import Path |
|
|
from contextlib import asynccontextmanager |
|
|
from collections import defaultdict |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Request, Depends, Form |
|
|
from typing import Optional, List |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.security import APIKeyHeader |
|
|
from pydantic import BaseModel, Field |
|
|
from pinecone import Pinecone |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv("rag/.env") |
|
|
|
|
|
from rag.utils import ( |
|
|
get_gemini_client, |
|
|
generate_query_embedding, |
|
|
generate_answer |
|
|
) |
|
|
from rag.ingest import ( |
|
|
get_pinecone_client, |
|
|
get_pinecone_index, |
|
|
ingest_single_pdf, |
|
|
PINECONE_INDEX, |
|
|
DATA_DIR |
|
|
) |
|
|
|
|
|
# --- Runtime configuration, read once at import time -------------------------

# Shared secret expected in the X-API-Key header. When unset or empty,
# verify_api_key lets every request through.
API_KEY = os.environ.get("API_KEY")

# Sliding-window rate limit: at most RATE_LIMIT_REQUESTS calls per client
# within RATE_LIMIT_WINDOW seconds (see check_rate_limit).
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "30"))
RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60"))

# Comma-separated list of CORS origins. Entries are stripped so that a value
# such as "https://a.example, https://b.example" does not produce an origin
# with a leading space (which would never match a browser's Origin header);
# blank entries are dropped.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
]

# --- Process-wide mutable state ----------------------------------------------

# Both clients are populated by the lifespan handler and stay None when their
# initialisation fails.
gemini_client = None
pinecone_index = None

# client key -> timestamps (time.time()) of that client's recent requests.
rate_limit_store = defaultdict(list)

# session id -> chronological list of {"role": ..., "content": ...} messages.
conversation_sessions = defaultdict(list)
|
|
|
|
|
|
|
|
def get_client_ip(request: Request) -> str:
    """Return the address used to identify the caller of *request*.

    The first entry of the ``X-Forwarded-For`` header is preferred. When the
    header is missing, or its first entry is blank (e.g. ``", 10.0.0.1"``),
    the socket peer address ``request.client.host`` is used instead, and the
    literal ``"unknown"`` when the request has no peer information.

    NOTE(review): the header is taken at face value; unless a trusted proxy
    overwrites it, callers can choose their own identity — confirm deployment.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",", 1)[0].strip()
        # A blank first entry used to be returned as "", lumping every such
        # request under one empty key; fall through to the peer address.
        if first_hop:
            return first_hop
    return request.client.host if request.client else "unknown"
|
|
|
|
|
|
|
|
def check_rate_limit(request: Request):
    """Record this request against the caller's sliding-window budget.

    Timestamps older than ``RATE_LIMIT_WINDOW`` seconds are discarded, then
    the request is either counted or rejected.

    Raises:
        HTTPException: 429 when the caller already has
            ``RATE_LIMIT_REQUESTS`` requests inside the current window.
    """
    caller = get_client_ip(request)
    current = time.time()

    # Keep only the hits that are still inside the window.
    recent_hits = [
        stamp
        for stamp in rate_limit_store[caller]
        if current - stamp < RATE_LIMIT_WINDOW
    ]
    rate_limit_store[caller] = recent_hits

    if len(recent_hits) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds."
        )

    # Rejected requests are not recorded; only accepted ones consume budget.
    recent_hits.append(current)
|
|
|
|
|
|
|
|
# auto_error=False: a missing header reaches verify_api_key as None instead of
# being rejected by the security scheme itself.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


async def verify_api_key(api_key: str = Depends(api_key_header)):
    """Dependency that checks the ``X-API-Key`` header against ``API_KEY``.

    When ``API_KEY`` is unset or empty the check is skipped entirely.

    Returns:
        The header value as received (``None`` when the header was absent).

    Raises:
        HTTPException: 403 when a key is configured and the supplied value is
            missing or different.
    """
    if API_KEY:
        # Function-scope import keeps this block self-contained.
        import secrets

        # Constant-time comparison so response timing does not reveal how much
        # of the key matched. Compare bytes: compare_digest rejects non-ASCII
        # str, and a missing header (None) is treated as an empty key.
        supplied = (api_key or "").encode("utf-8")
        if not secrets.compare_digest(supplied, API_KEY.encode("utf-8")):
            raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up the Gemini and Pinecone clients at startup.

    Failures are reported on stdout and leave the corresponding global as it
    was, so the app still starts; only a ``ValueError`` from the Gemini
    factory is tolerated, while any exception from the Pinecone setup is.
    """
    global gemini_client, pinecone_index

    print("Starting Nigerian Tax Law RAG API...")
    print(
        "API Key authentication enabled"
        if API_KEY
        else "Warning: No API_KEY set - API is unprotected"
    )

    try:
        gemini_client = get_gemini_client()
    except ValueError as config_error:
        print(f"Warning: {config_error}")
    else:
        print("Gemini client initialized")

    try:
        # The global is assigned before the stats call, so a failing stats
        # request still leaves the index handle in place.
        pinecone_index = get_pinecone_index()
        vector_total = pinecone_index.describe_index_stats().total_vector_count
        print(f"Pinecone initialized ({vector_total} vectors)")
    except Exception as startup_error:
        print(f"Warning: Pinecone error: {startup_error}")

    yield

    print("Shutting down RAG API...")
|
|
|
|
|
|
|
|
# The ASGI application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="Nigerian Tax Law RAG API",
    description="Query Nigerian tax laws and legal documents using AI-powered retrieval",
    version="1.0.0",
    lifespan=lifespan,
)

# Cross-origin access: origins come from the ALLOWED_ORIGINS setting, only
# GET and POST are permitted, and any request header is accepted.
# NOTE(review): ALLOWED_ORIGINS defaults to "*", which is combined here with
# allow_credentials=True — confirm a wildcard origin with credentials is
# intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AskResponse(BaseModel):
    # Response body of POST /ask.

    # Generated answer text, or a fixed "couldn't find" message when the
    # vector search returned no matches.
    answer: str
    # One entry per retrieved chunk with the keys "document", "chunk_index"
    # and "relevance_score".
    sources: list[dict]
    # Number of retrieved chunks that were used as context.
    chunks_used: int
    # Conversation identifier; generated by the endpoint when the caller did
    # not supply one.
    session_id: str
|
|
|
|
|
|
|
|
class IngestResponse(BaseModel):
    # Response body of POST /ingest.

    # Human-readable outcome of the ingestion.
    message: str
    # Sanitised file name under which the upload was stored.
    filename: str
    # Chunk count reported by the ingestion routine (0 is reported to the
    # caller as "Document already exists").
    chunks_added: int
|
|
|
|
|
|
|
|
class StatsResponse(BaseModel):
    # Response body of GET /stats, filled from the Pinecone index statistics.

    # Total number of vectors reported by the index.
    total_vectors: int
    # Vector dimension reported by the index.
    dimension: int
    # Name of the index (the PINECONE_INDEX setting).
    index_name: str
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    # Response body of GET /health.

    # "healthy" when both backends are available, otherwise "degraded".
    status: str
    # True when the Gemini client was initialised.
    gemini_connected: bool
    # True when the Pinecone index is initialised and answered a stats call.
    pinecone_connected: bool
    # Vector count from the index statistics; 0 when unavailable.
    vectors_indexed: int
|
|
|
|
|
|
|
|
class YearlyWrapRequest(BaseModel):
    # Request schema for a yearly wrap.
    # NOTE(review): the /yearly-wrap endpoint visible in this file reads
    # `year` as a form field and does not reference this model — confirm
    # whether it is still needed.

    # Reporting year; defaults to 2024 and must lie within 2000-2030.
    year: int = Field(default=2024, ge=2000, le=2030)
|
|
|
|
|
|
|
|
class YearlyWrapResponse(BaseModel):
    # Response body of POST /yearly-wrap.

    # Result of the financial-document analysis step.
    analysis: dict
    # Generated video script; None when script creation did not succeed.
    video_script: Optional[dict]
    # Location of the rendered video; None when no video was produced.
    video_url: Optional[str]
    # One of "completed", "script_only" or "analysis_only".
    status: str
    # Human-readable explanation of the outcome.
    message: str
|
|
|
|
|
|
|
|
@app.get("/", response_model=dict)
async def root():
    """Return the API name, version and an index of the available endpoints."""
    endpoints = {
        "POST /ask": "Ask a question about Nigerian tax law",
        "POST /ingest": "Upload and index a new PDF document",
        "GET /stats": "Get database statistics",
        "GET /health": "Health check",
        # Previously missing from the index although the route is registered.
        "POST /yearly-wrap": "Create a yearly financial wrap from uploaded documents",
    }
    return {
        "name": "Nigerian Tax Law RAG API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the Gemini client and the Pinecone index are usable.

    The status is "healthy" only when both backends are available; otherwise
    it is "degraded". The endpoint itself never fails because of a backend.
    """
    gemini_ok = gemini_client is not None
    pinecone_ok = pinecone_index is not None
    vectors = 0

    if pinecone_ok:
        try:
            vectors = pinecone_index.describe_index_stats().total_vector_count
        except Exception:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt,
            # SystemExit and task cancellation. Any ordinary failure of the
            # stats call still just marks Pinecone as unavailable.
            pinecone_ok = False

    return HealthResponse(
        status="healthy" if (gemini_ok and pinecone_ok) else "degraded",
        gemini_connected=gemini_ok,
        pinecone_connected=pinecone_ok,
        vectors_indexed=vectors
    )
|
|
|
|
|
|
|
|
# Whitelists and limits applied to the optional /ask uploads.
_ASK_IMAGE_TYPES = ["image/jpeg", "image/png", "image/gif", "image/webp"]
_ASK_DOCUMENT_EXTS = [".pdf", ".doc", ".docx", ".txt"]
_ASK_MAX_IMAGE_BYTES = 10 * 1024 * 1024
_ASK_MAX_DOCUMENT_BYTES = 50 * 1024 * 1024
# Maximum number of messages kept per conversation session.
_ASK_MAX_HISTORY = 20
_ASK_NO_MATCH_ANSWER = "I couldn't find any relevant information in the indexed documents."


async def _ask_read_image(image):
    """Validate the image upload and return ``(raw_bytes, mime_type)``."""
    if image.content_type not in _ASK_IMAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image type. Allowed: {', '.join(_ASK_IMAGE_TYPES)}"
        )
    if image.size and image.size > _ASK_MAX_IMAGE_BYTES:
        raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")

    raw = await image.read()
    # `size` may be unset on the upload, so enforce the cap on what was read.
    if len(raw) > _ASK_MAX_IMAGE_BYTES:
        raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
    return raw, image.content_type


async def _ask_read_document_text(document):
    """Validate the document upload and return its labelled text ("" if none)."""
    filename = document.filename
    if not filename.lower().endswith(tuple(_ASK_DOCUMENT_EXTS)):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid document type. Allowed: {', '.join(_ASK_DOCUMENT_EXTS)}"
        )
    if document.size and document.size > _ASK_MAX_DOCUMENT_BYTES:
        raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")

    payload = await document.read()
    if len(payload) > _ASK_MAX_DOCUMENT_BYTES:
        raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")

    try:
        import tempfile

        from rag.ingest import extract_text_from_file

        # The extractor works on a path, so spill the upload to a temp file
        # that keeps the original extension.
        tmp_file = tempfile.NamedTemporaryFile(
            delete=False, suffix=os.path.splitext(filename)[1]
        )
        try:
            with tmp_file:
                tmp_file.write(payload)
            text = extract_text_from_file(Path(tmp_file.name))
        finally:
            # Runs even when the write fails, so no temp file is left behind.
            os.unlink(tmp_file.name)
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Error processing document: {str(exc)}"
        ) from exc

    if text:
        return f"[User Uploaded Document: {filename}]\n{text}"
    return ""


def _ask_remember_exchange(session_id, question, answer):
    """Append one question/answer pair to the session and cap its length."""
    history = conversation_sessions[session_id]
    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": answer})
    if len(history) > _ASK_MAX_HISTORY:
        conversation_sessions[session_id] = history[-_ASK_MAX_HISTORY:]


@app.post("/ask", response_model=AskResponse)
async def ask_question(
    req: Request,
    question: str = Form(..., min_length=3, max_length=2000),
    top_k: int = Form(default=5, ge=1, le=20),
    model: str = Form(default="gemini-2.5-flash"),
    session_id: Optional[str] = Form(default=None),
    image: Optional[UploadFile] = File(default=None),
    document: Optional[UploadFile] = File(default=None),
    api_key: str = Depends(verify_api_key)
):
    """Answer a question using chunks retrieved from the vector index.

    An optional image is forwarded to the model and an optional document is
    converted to text and placed ahead of the retrieved context. Each
    exchange is stored under the session id, which is generated when the
    caller does not send one.
    """
    check_rate_limit(req)

    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )
    if pinecone_index is None:
        raise HTTPException(status_code=503, detail="Pinecone not initialized.")

    if not session_id:
        session_id = str(uuid.uuid4())

    image_data = None
    image_mime_type = None
    document_text = ""

    if image and image.filename:
        image_data, image_mime_type = await _ask_read_image(image)
    if document and document.filename:
        document_text = await _ask_read_document_text(document)

    try:
        query_embedding = generate_query_embedding(gemini_client, question)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating query embedding: {str(e)}")

    try:
        results = pinecone_index.query(
            vector=query_embedding,
            top_k=top_k,
            include_metadata=True
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying Pinecone: {str(e)}")

    if not results.matches:
        # Goes through the same helper as the success path so this branch is
        # also subject to the history cap (it previously was not).
        _ask_remember_exchange(session_id, question, _ASK_NO_MATCH_ANSWER)
        return AskResponse(
            answer=_ASK_NO_MATCH_ANSWER,
            sources=[],
            chunks_used=0,
            session_id=session_id
        )

    context_parts = []
    sources = []
    for match in results.matches:
        # A match without metadata must not abort the whole request.
        meta = match.metadata or {}
        source_name = meta.get("source", "Unknown")
        chunk_idx = meta.get("chunk_index", 0)
        text = meta.get("text", "")

        context_parts.append(f"[Source: {source_name}, Chunk {chunk_idx + 1}]\n{text}")
        sources.append({
            "document": source_name,
            "chunk_index": chunk_idx,
            "relevance_score": round(match.score, 4)
        })

    context = "\n\n---\n\n".join(context_parts)

    if document_text:
        context = f"""[Tax Document Analysis - User Uploaded File]
{document_text}

--- Tax Law Reference Context ---
{context}

[TAX ANALYSIS INSTRUCTIONS]
- Analyze the uploaded document for tax-relevant information, forms, and declarations
- Identify tax amounts, deadlines, compliance requirements, and filing obligations
- Cross-reference with Nigerian tax laws from the retrieved context
- Provide specific guidance on tax declarations, calculations, and compliance
- Highlight any missing information or additional documents needed"""

    conversation_history = conversation_sessions.get(session_id, [])

    try:
        answer = generate_answer(
            gemini_client,
            question,
            context,
            model=model,
            image_data=image_data,
            image_mime_type=image_mime_type,
            conversation_history=conversation_history
        )
    except Exception as e:
        error_msg = str(e)
        # Surface upstream overload as 503 so clients can retry.
        if "overloaded" in error_msg.lower() or "503" in error_msg:
            raise HTTPException(status_code=503, detail=error_msg)
        raise HTTPException(status_code=500, detail=f"Error generating answer: {error_msg}")

    _ask_remember_exchange(session_id, question, answer)

    return AskResponse(
        answer=answer,
        sources=sources,
        chunks_used=len(results.matches),
        session_id=session_id
    )
|
|
|
|
|
|
|
|
@app.post("/ingest", response_model=IngestResponse)
async def ingest_document(
    req: Request,
    file: UploadFile = File(...),
    force: bool = False,
    api_key: str = Depends(verify_api_key)
):
    """Store an uploaded PDF in the data directory and index it.

    The file is saved under a sanitised version of its name and handed to the
    ingestion routine; ``force`` is passed through to that routine.
    """
    check_rate_limit(req)

    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )
    if pinecone_index is None:
        raise HTTPException(status_code=503, detail="Pinecone not initialized.")

    max_bytes = 50 * 1024 * 1024

    # An upload without a file name used to raise AttributeError (HTTP 500);
    # it is now rejected like any other non-PDF.
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
    if file.size and file.size > max_bytes:
        raise HTTPException(status_code=400, detail="File too large. Max 50MB.")

    # Keep only alphanumerics and "._- " so the name cannot contain a path
    # separator and always stays inside DATA_DIR.
    safe_filename = "".join(c for c in file.filename if c.isalnum() or c in "._- ")
    file_path = DATA_DIR / safe_filename

    try:
        contents = await file.read()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}")

    # `size` may be unset on the upload, so enforce the cap on what was read.
    if len(contents) > max_bytes:
        raise HTTPException(status_code=400, detail="File too large. Max 50MB.")

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    try:
        with open(file_path, "wb") as f:
            f.write(contents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}")

    try:
        chunks_added, _ = ingest_single_pdf(
            file_path,
            pinecone_index,
            gemini_client,
            force=force
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error ingesting document: {str(e)}")

    return IngestResponse(
        message="Document ingested successfully" if chunks_added > 0 else "Document already exists",
        filename=safe_filename,
        chunks_added=chunks_added
    )
|
|
|
|
|
|
|
|
@app.get("/stats", response_model=StatsResponse)
async def get_stats(api_key: str = Depends(verify_api_key)):
    # Reports vector count and dimension of the Pinecone index.
    # 503 when the index was never initialised; any failure while fetching or
    # packaging the statistics is reported as 500.
    if pinecone_index is None:
        raise HTTPException(status_code=503, detail="Pinecone not initialized.")

    try:
        index_stats = pinecone_index.describe_index_stats()
        payload = StatsResponse(
            total_vectors=index_stats.total_vector_count,
            dimension=index_stats.dimension,
            index_name=PINECONE_INDEX
        )
        return payload
    except Exception as stats_error:
        raise HTTPException(status_code=500, detail=f"Error getting stats: {str(stats_error)}")
|
|
|
|
|
|
|
|
# Whitelists and limits applied to /yearly-wrap uploads.
_WRAP_DOCUMENT_EXTS = (".pdf", ".doc", ".docx", ".txt")
_WRAP_IMAGE_TYPES = ("image/jpeg", "image/png", "image/gif", "image/webp")
_WRAP_MAX_DOCUMENT_BYTES = 50 * 1024 * 1024
_WRAP_MAX_IMAGE_BYTES = 10 * 1024 * 1024


def _wrap_split_uploads(form):
    """Split the form's file uploads into ``(documents, images)`` by field name."""
    documents = []
    images = []
    # multi_items() yields every (name, value) pair. items() keeps only the
    # last value per name, which silently dropped all but one file whenever
    # several uploads shared the field name "documents" or "images".
    for field_name, field_value in form.multi_items():
        if not getattr(field_value, "filename", None):
            continue  # plain form field, or a file part without a name
        if field_name.startswith("documents"):
            documents.append(field_value)
        elif field_name.startswith("images"):
            images.append(field_value)
    return documents, images


async def _wrap_document_text(doc):
    """Validate one document upload; return its labelled text or None if empty."""
    if not doc.filename.lower().endswith(_WRAP_DOCUMENT_EXTS):
        raise HTTPException(
            status_code=400,
            detail=f"Document type not supported: {doc.filename}"
        )
    if doc.size and doc.size > _WRAP_MAX_DOCUMENT_BYTES:
        raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")

    payload = await doc.read()
    # `size` may be unset on the upload, so enforce the cap on what was read.
    if len(payload) > _WRAP_MAX_DOCUMENT_BYTES:
        raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")

    try:
        import tempfile

        from rag.ingest import extract_text_from_file

        # The extractor works on a path, so spill the upload to a temp file
        # that keeps the original extension.
        tmp_file = tempfile.NamedTemporaryFile(
            delete=False, suffix=os.path.splitext(doc.filename)[1]
        )
        try:
            with tmp_file:
                tmp_file.write(payload)
            text = extract_text_from_file(Path(tmp_file.name))
        finally:
            # Runs even when the write fails, so no temp file is left behind.
            os.unlink(tmp_file.name)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing document {doc.filename}: {str(e)}"
        ) from e

    return f"[Document: {doc.filename}]\n{text}" if text else None


async def _wrap_image_bytes(img):
    """Validate one image upload and return its raw bytes."""
    if img.content_type not in _WRAP_IMAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Image type not supported: {img.filename}"
        )
    if img.size and img.size > _WRAP_MAX_IMAGE_BYTES:
        raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")

    raw = await img.read()
    if len(raw) > _WRAP_MAX_IMAGE_BYTES:
        raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
    return raw


@app.post("/yearly-wrap", response_model=YearlyWrapResponse)
async def create_yearly_wrap(
    request: Request,
    year: int = Form(..., ge=2000, le=2030),
    api_key: str = Depends(verify_api_key)
):
    """
    Create a yearly financial wrap video from account statements and financial documents.

    Upload your bank statements, investment reports, tax documents, and financial images
    to generate a personalized yearly financial summary video.
    """
    # Same per-client budget as /ask and /ingest; this endpoint was the only
    # model-backed route without it.
    check_rate_limit(request)

    form = await request.form()
    documents, images = _wrap_split_uploads(form)

    if not documents and not images:
        raise HTTPException(
            status_code=400,
            detail="Please upload at least one financial document or image"
        )

    document_texts = []
    for doc in documents:
        labelled_text = await _wrap_document_text(doc)
        if labelled_text:
            document_texts.append(labelled_text)

    image_data_list = [await _wrap_image_bytes(img) for img in images]

    try:
        from rag.utils import analyze_financial_documents

        analysis = analyze_financial_documents(document_texts, image_data_list)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error analyzing financial data: {str(e)}")

    video_script = None
    video_url = None

    # Script and video are best-effort: a failure here downgrades the status
    # instead of failing the request, and the analysis is always returned.
    try:
        from rag.utils import create_video_script, generate_yearly_wrap_video

        video_script = create_video_script(analysis)
        video_url = generate_yearly_wrap_video(analysis)

        if video_url:
            status = "completed"
            message = f"Yearly financial wrap for {year} created successfully with professional video script and animation!"
        else:
            status = "script_only"
            message = "Yearly financial wrap script created successfully! Video generation is temporarily unavailable due to service limitations, but you have a complete professional script ready."
    except Exception as e:
        error_msg = str(e)
        # The error text is the only signal available here for telling a
        # video failure apart from a script failure.
        if "Video generation" in error_msg:
            status = "script_only"
            message = f"Financial analysis and video script created, but video generation failed: {error_msg}"
        else:
            status = "analysis_only"
            message = f"Financial analysis completed, but script/video generation failed: {error_msg}"

    return YearlyWrapResponse(
        analysis=analysis,
        video_script=video_script,
        video_url=video_url,
        status=status,
        message=message
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000,
    # with auto-reload enabled.
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
|
|
|