# sabitax / app.py
# Uploaded by nexusbert in commit d43d504 ("Upload 14 files").
import os
import time
import hashlib
import uuid
from pathlib import Path
from contextlib import asynccontextmanager
from collections import defaultdict
from fastapi import FastAPI, HTTPException, UploadFile, File, Request, Depends, Form
from typing import Optional, List
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
from pinecone import Pinecone
from dotenv import load_dotenv
load_dotenv("rag/.env")
from rag.utils import (
get_gemini_client,
generate_query_embedding,
generate_answer
)
from rag.ingest import (
get_pinecone_client,
get_pinecone_index,
ingest_single_pdf,
PINECONE_INDEX,
DATA_DIR
)
# --- Runtime configuration (environment variables / rag/.env) ---
# Optional shared secret; when unset, verify_api_key lets every request through.
API_KEY = os.environ.get("API_KEY")
# Sliding-window rate limit: at most RATE_LIMIT_REQUESTS accepted requests per
# client IP within any RATE_LIMIT_WINDOW seconds (enforced by check_rate_limit).
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "30"))
RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60"))
# Comma-separated CORS origins. Entries are stripped so "a.com, b.com" works,
# and empty entries (e.g. from a trailing comma) are dropped instead of being
# passed through as "".
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
]

# Clients created by the startup hook (lifespan); None until then, or if
# initialisation failed.
gemini_client = None
pinecone_index = None
# client IP -> timestamps of recently accepted requests.
rate_limit_store = defaultdict(list)
# session id -> list of {"role": ..., "content": ...} messages passed to
# generate_answer as conversation history.
conversation_sessions = defaultdict(list)
def get_client_ip(request: Request) -> str:
    """Return the address used to identify the caller for rate limiting.

    If an ``X-Forwarded-For`` header is present and non-empty, its first
    comma-separated entry (whitespace-stripped) is used. Otherwise the socket
    peer's host is returned, or ``"unknown"`` when no client info exists.

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
    proxy overwrites it, so a caller can rotate it to dodge the rate limit.
    """
    forwarded_for = request.headers.get("X-Forwarded-For")
    if not forwarded_for:
        peer = request.client
        return peer.host if peer else "unknown"
    first_hop, _sep, _rest = forwarded_for.partition(",")
    return first_hop.strip()
def check_rate_limit(request: Request):
    """Apply the per-IP sliding-window rate limit to ``request``.

    For each client IP (as returned by ``get_client_ip``), the timestamps of
    accepted requests are kept. Entries older than ``RATE_LIMIT_WINDOW``
    seconds are discarded on each call. A request is rejected once the client
    already has ``RATE_LIMIT_REQUESTS`` recent timestamps. Rejected requests
    are not recorded, so access returns as soon as old timestamps age out.

    Raises:
        HTTPException: 429 when the limit is exceeded.
    """
    # Store size above which IPs with no recent requests are swept out.
    max_tracked_ips = 10_000

    client_ip = get_client_ip(request)
    now = time.time()

    # Without a sweep, every IP ever seen keeps an entry forever and the store
    # grows without bound. Timestamps are appended in (wall-clock) order, so
    # an IP is stale once its newest timestamp has left the window.
    if len(rate_limit_store) > max_tracked_ips:
        stale_ips = [
            ip
            for ip, stamps in rate_limit_store.items()
            if not stamps or now - stamps[-1] >= RATE_LIMIT_WINDOW
        ]
        for ip in stale_ips:
            del rate_limit_store[ip]

    recent = [t for t in rate_limit_store[client_ip] if now - t < RATE_LIMIT_WINDOW]
    rate_limit_store[client_ip] = recent
    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds."
        )
    recent.append(now)
# Reads the optional "X-API-Key" header. auto_error=False yields None when the
# header is missing instead of an automatic error, so verify_api_key decides
# (authentication is off when API_KEY is unset).
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


async def verify_api_key(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that enforces the ``X-API-Key`` header.

    When ``API_KEY`` is configured, the supplied key must equal it. The
    comparison is constant-time so response timing does not reveal how much
    of a guessed key was correct. When ``API_KEY`` is unset, every request
    passes.

    Returns:
        The header value as received (``None`` if the header was absent).

    Raises:
        HTTPException: 403 when ``API_KEY`` is set and the key is missing or wrong.
    """
    import hmac

    if API_KEY:
        # compare_digest needs bytes (or ASCII-only str); encode both sides and
        # treat a missing header as an empty key, which can never match.
        supplied = (api_key or "").encode("utf-8")
        if not hmac.compare_digest(supplied, API_KEY.encode("utf-8")):
            raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook that wires up the external clients.

    On startup it fills the module-level ``gemini_client`` and
    ``pinecone_index``. Problems are printed rather than aborting startup:
    - A ``ValueError`` from ``get_gemini_client`` (presumably a missing
      GEMINI_API_KEY) leaves ``gemini_client`` as None.
    - Any Pinecone error is only reported. If the index handle was obtained
      but the stats call then failed, the handle is still kept.
    Endpoints that need a missing client answer 503.
    """
    global gemini_client, pinecone_index

    print("Starting Nigerian Tax Law RAG API...")
    print(
        "API Key authentication enabled"
        if API_KEY
        else "Warning: No API_KEY set - API is unprotected"
    )

    try:
        gemini_client = get_gemini_client()
    except ValueError as err:
        print(f"Warning: {err}")
    else:
        print("Gemini client initialized")

    try:
        pinecone_index = get_pinecone_index()
        vector_count = pinecone_index.describe_index_stats().total_vector_count
    except Exception as err:
        print(f"Warning: Pinecone error: {err}")
    else:
        print(f"Pinecone initialized ({vector_count} vectors)")

    yield  # the application serves requests while suspended here

    print("Shutting down RAG API...")
app = FastAPI(
    title="Nigerian Tax Law RAG API",
    description="Query Nigerian tax laws and legal documents using AI-powered retrieval",
    version="1.0.0",
    lifespan=lifespan
)

# Authentication uses the X-API-Key header, not cookies, so credentialed CORS
# is only enabled for an explicit origin allow-list. With the wildcard origin
# ("*", the default) plus allow_credentials=True, Starlette's CORSMiddleware
# reflects the caller's Origin for credentialed requests, which would let any
# site make credentialed calls.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials="*" not in ALLOWED_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
class AskResponse(BaseModel):
    """Response body of ``POST /ask``."""

    answer: str  # Model-generated answer text.
    sources: list[dict]  # One entry per retrieved chunk: document, chunk_index, relevance_score.
    chunks_used: int  # Number of retrieved chunks (0 when nothing matched).
    session_id: str  # Conversation id; send it back to continue the same session.
class IngestResponse(BaseModel):
    """Response body of ``POST /ingest``."""

    message: str  # Human-readable outcome of the ingestion.
    filename: str  # Sanitised name under which the PDF was saved.
    chunks_added: int  # Chunks written to the index; 0 is reported as "already exists".
class StatsResponse(BaseModel):
    """Response body of ``GET /stats`` (Pinecone index statistics)."""

    total_vectors: int  # Vectors currently stored in the index.
    dimension: int  # Vector dimension reported by the index.
    index_name: str  # Configured Pinecone index name (PINECONE_INDEX).
class HealthResponse(BaseModel):
    """Response body of ``GET /health``."""

    status: str  # "healthy" when both backends are usable, otherwise "degraded".
    gemini_connected: bool  # Whether a Gemini client was created at startup.
    pinecone_connected: bool  # Whether the index exists and answered a stats call.
    vectors_indexed: int  # Total vectors in the index (0 when unavailable).
class YearlyWrapRequest(BaseModel):
    """Body model for a yearly-wrap request.

    NOTE(review): in the visible code ``POST /yearly-wrap`` takes ``year`` as a
    form field and never uses this model; confirm whether it is still needed.
    """

    year: int = Field(default=2024, ge=2000, le=2030)  # Year to summarise, 2000-2030 inclusive.
class YearlyWrapResponse(BaseModel):
    """Response body of ``POST /yearly-wrap``."""

    analysis: dict  # Result of analyze_financial_documents.
    video_script: Optional[dict]  # Result of create_video_script, or None if that step failed.
    video_url: Optional[str]  # Generated video URL, or None when no video was produced.
    status: str  # "completed", "script_only" or "analysis_only".
    message: str  # Human-readable summary of what was produced.
@app.get("/", response_model=dict)
async def root():
    """Describe the API and list every public endpoint it exposes."""
    return {
        "name": "Nigerian Tax Law RAG API",
        "version": "1.0.0",
        "endpoints": {
            "POST /ask": "Ask a question about Nigerian tax law",
            "POST /ingest": "Upload and index a new PDF document",
            "GET /stats": "Get database statistics",
            "GET /health": "Health check",
            # Previously missing from this listing although the route exists.
            "POST /yearly-wrap": "Generate a yearly financial wrap from uploaded statements and images",
        }
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the Gemini and Pinecone backends are usable.

    Gemini counts as connected when a client was created at startup.
    Pinecone counts as connected only when the index handle exists *and* a
    ``describe_index_stats`` call succeeds now; that call also provides the
    vector count.
    """
    import asyncio

    gemini_ok = gemini_client is not None
    pinecone_ok = pinecone_index is not None
    vectors = 0
    if pinecone_ok:
        try:
            # Synchronous client call; run it in a worker thread so the event
            # loop keeps serving other requests meanwhile.
            stats = await asyncio.to_thread(pinecone_index.describe_index_stats)
            vectors = stats.total_vector_count
        except Exception:
            # Not a bare ``except:``, which would also swallow task
            # cancellation and KeyboardInterrupt.
            pinecone_ok = False
    return HealthResponse(
        status="healthy" if (gemini_ok and pinecone_ok) else "degraded",
        gemini_connected=gemini_ok,
        pinecone_connected=pinecone_ok,
        vectors_indexed=vectors
    )
@app.post("/ask", response_model=AskResponse)
async def ask_question(
    req: Request,
    question: str = Form(..., min_length=3, max_length=2000),
    top_k: int = Form(default=5, ge=1, le=20),
    model: str = Form(default="gemini-2.5-flash"),
    session_id: Optional[str] = Form(default=None),
    image: Optional[UploadFile] = File(default=None),
    document: Optional[UploadFile] = File(default=None),
    api_key: str = Depends(verify_api_key)
):
    """Answer a Nigerian tax-law question with retrieval-augmented generation.

    The question is embedded and used to fetch the ``top_k`` nearest chunks
    from Pinecone. These form the context, optionally preceded by the text of
    an uploaded document. Gemini then answers using the caller's session
    history, of which at most the 20 newest messages are kept.

    Form fields:
        question: The question, 3-2000 characters.
        top_k: Number of chunks to retrieve (1-20).
        model: Gemini model name forwarded to ``generate_answer``.
        session_id: Existing conversation id; a new UUID is issued when omitted.
        image: Optional JPEG/PNG/GIF/WebP image (max 10 MB) sent to the model.
        document: Optional .pdf/.doc/.docx/.txt file (max 50 MB) whose text is
            added to the context.

    Returns:
        AskResponse with the answer, the retrieved sources and the session id.

    Raises:
        HTTPException:
            - 400 for an invalid or oversized upload.
            - 429 when rate limited.
            - 503 when a backend is missing or overloaded.
            - 500 for other processing errors.
    """
    import asyncio
    import tempfile

    max_image_bytes = 10 * 1024 * 1024
    max_document_bytes = 50 * 1024 * 1024
    max_history_messages = 20
    no_match_answer = "I couldn't find any relevant information in the indexed documents."

    check_rate_limit(req)
    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )
    if pinecone_index is None:
        raise HTTPException(status_code=503, detail="Pinecone not initialized.")
    if not session_id:
        session_id = str(uuid.uuid4())

    def remember(answer_text):
        """Record one question/answer exchange, keeping only the newest messages."""
        history = conversation_sessions[session_id]
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": answer_text})
        if len(history) > max_history_messages:
            conversation_sessions[session_id] = history[-max_history_messages:]

    # --- Optional image: validated, then passed to the model as raw bytes.
    image_data = None
    image_mime_type = None
    if image and image.filename:
        allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
        if image.content_type not in allowed_types:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid image type. Allowed: {', '.join(allowed_types)}"
            )
        if image.size and image.size > max_image_bytes:
            raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
        image_data = await image.read()
        # ``size`` can be None, so also enforce the limit on the bytes received.
        if len(image_data) > max_image_bytes:
            raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
        image_mime_type = image.content_type

    # --- Optional document: its extracted text is prepended to the context.
    document_text = ""
    if document and document.filename:
        allowed_exts = (".pdf", ".doc", ".docx", ".txt")
        if not document.filename.lower().endswith(allowed_exts):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid document type. Allowed: {', '.join(allowed_exts)}"
            )
        if document.size and document.size > max_document_bytes:
            raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")
        doc_content = await document.read()
        if len(doc_content) > max_document_bytes:
            raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")
        try:
            from rag.ingest import extract_text_from_file

            # The extractor takes a path. Spill the upload to a temp file that
            # keeps its extension, and remove it once extraction is done.
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=os.path.splitext(document.filename)[1]
            ) as tmp_file:
                tmp_file.write(doc_content)
                tmp_file_path = tmp_file.name
            try:
                extracted = await asyncio.to_thread(
                    extract_text_from_file, Path(tmp_file_path)
                )
            finally:
                os.unlink(tmp_file_path)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error processing document: {str(e)}"
            ) from e
        if extracted:
            document_text = f"[User Uploaded Document: {document.filename}]\n{extracted}"

    # --- Retrieval. The client calls are synchronous, so each runs in a worker
    # thread instead of blocking the event loop for every other request.
    try:
        query_embedding = await asyncio.to_thread(
            generate_query_embedding, gemini_client, question
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error generating query embedding: {str(e)}"
        ) from e
    try:
        results = await asyncio.to_thread(
            pinecone_index.query,
            vector=query_embedding,
            top_k=top_k,
            include_metadata=True,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying Pinecone: {str(e)}") from e

    # Short-circuit only when there is nothing at all to answer from. A
    # user-supplied document or image can still be analysed without
    # retrieved context.
    if not results.matches and not document_text and image_data is None:
        remember(no_match_answer)
        return AskResponse(
            answer=no_match_answer,
            sources=[],
            chunks_used=0,
            session_id=session_id
        )

    context_parts = []
    sources = []
    for match in results.matches:
        meta = match.metadata or {}
        source_name = meta.get("source", "Unknown")
        chunk_idx = meta.get("chunk_index", 0)
        # Pinecone may return numeric metadata as floats (e.g. 3.0); show
        # integral values as ints so the context reads "Chunk 4", not "Chunk 4.0".
        if isinstance(chunk_idx, float) and chunk_idx.is_integer():
            chunk_idx = int(chunk_idx)
        text = meta.get("text", "")
        context_parts.append(f"[Source: {source_name}, Chunk {chunk_idx + 1}]\n{text}")
        sources.append({
            "document": source_name,
            "chunk_index": chunk_idx,
            "relevance_score": round(match.score, 4)
        })
    context = "\n\n---\n\n".join(context_parts)

    if document_text:
        context = (
            "[Tax Document Analysis - User Uploaded File]\n"
            f"{document_text}\n"
            "--- Tax Law Reference Context ---\n"
            f"{context}\n"
            "[TAX ANALYSIS INSTRUCTIONS]\n"
            "- Analyze the uploaded document for tax-relevant information, forms, and declarations\n"
            "- Identify tax amounts, deadlines, compliance requirements, and filing obligations\n"
            "- Cross-reference with Nigerian tax laws from the retrieved context\n"
            "- Provide specific guidance on tax declarations, calculations, and compliance\n"
            "- Highlight any missing information or additional documents needed"
        )

    # Pass a snapshot: the generation runs in another thread while other
    # requests for the same session may append to the stored list.
    conversation_history = list(conversation_sessions.get(session_id, []))
    try:
        answer = await asyncio.to_thread(
            generate_answer,
            gemini_client,
            question,
            context,
            model=model,
            image_data=image_data,
            image_mime_type=image_mime_type,
            conversation_history=conversation_history,
        )
    except Exception as e:
        error_msg = str(e)
        # Surface upstream overload as 503 so clients know a retry may succeed.
        if "overloaded" in error_msg.lower() or "503" in error_msg:
            raise HTTPException(status_code=503, detail=error_msg) from e
        raise HTTPException(status_code=500, detail=f"Error generating answer: {error_msg}") from e

    remember(answer)
    return AskResponse(
        answer=answer,
        sources=sources,
        chunks_used=len(results.matches),
        session_id=session_id
    )
@app.post("/ingest", response_model=IngestResponse)
async def ingest_document(
    req: Request,
    file: UploadFile = File(...),
    force: bool = False,
    api_key: str = Depends(verify_api_key)
):
    """Store an uploaded PDF in ``DATA_DIR`` and index its chunks in Pinecone.

    Args:
        file: The PDF upload (``.pdf`` extension, at most 50 MB).
        force: Query parameter forwarded to ``ingest_single_pdf``; presumably
            re-ingests a document that is already indexed (confirm in rag.ingest).

    Returns:
        IngestResponse. ``chunks_added == 0`` is reported as "Document already exists".

    Raises:
        HTTPException:
            - 400 for a missing, non-PDF or oversized file.
            - 429 when rate limited.
            - 503 when Gemini or Pinecone is unavailable.
            - 500 when saving or ingesting fails.
    """
    import asyncio

    max_bytes = 50 * 1024 * 1024

    check_rate_limit(req)
    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )
    if pinecone_index is None:
        raise HTTPException(status_code=503, detail="Pinecone not initialized.")
    # ``filename`` can be None; treat that like a non-PDF instead of crashing.
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
    if file.size and file.size > max_bytes:
        raise HTTPException(status_code=400, detail="File too large. Max 50MB.")

    # Keep only alphanumerics and "._- ". Path separators in the client's
    # filename therefore cannot escape DATA_DIR.
    safe_filename = "".join(c for c in file.filename if c.isalnum() or c in "._- ")
    file_path = DATA_DIR / safe_filename

    try:
        contents = await file.read()
        # ``size`` may be unknown, so enforce the limit on the bytes actually
        # read, before anything touches the disk.
        if len(contents) > max_bytes:
            raise HTTPException(status_code=400, detail="File too large. Max 50MB.")
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        file_path.write_bytes(contents)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}") from e

    try:
        # ingest_single_pdf is a synchronous call; run it in a worker thread so
        # the event loop keeps serving other requests.
        chunks_added, _ = await asyncio.to_thread(
            ingest_single_pdf,
            file_path,
            pinecone_index,
            gemini_client,
            force=force,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error ingesting document: {str(e)}") from e

    return IngestResponse(
        message="Document ingested successfully" if chunks_added > 0 else "Document already exists",
        filename=safe_filename,
        chunks_added=chunks_added
    )
@app.get("/stats", response_model=StatsResponse)
async def get_stats(api_key: str = Depends(verify_api_key)):
    """Return the vector count and dimension of the configured Pinecone index.

    Raises:
        HTTPException: 503 if Pinecone was not initialised; 500 if the stats
        call, or building the response from it, fails.
    """
    import asyncio

    if pinecone_index is None:
        raise HTTPException(status_code=503, detail="Pinecone not initialized.")
    try:
        # Synchronous client call; keep it off the event loop.
        stats = await asyncio.to_thread(pinecone_index.describe_index_stats)
        return StatsResponse(
            total_vectors=stats.total_vector_count,
            dimension=stats.dimension,
            index_name=PINECONE_INDEX
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting stats: {str(e)}") from e
@app.post("/yearly-wrap", response_model=YearlyWrapResponse)
async def create_yearly_wrap(
    request: Request,
    year: int = Form(..., ge=2000, le=2030),
    api_key: str = Depends(verify_api_key)
):
    """
    Create a yearly financial wrap video from account statements and financial documents.
    Upload your bank statements, investment reports, tax documents, and financial images
    to generate a personalized yearly financial summary video.
    """
    # Implementation notes. They are kept out of the docstring, which FastAPI
    # publishes as the OpenAPI description.
    # - Files come from multipart fields whose names start with "documents"
    #   (.pdf/.doc/.docx/.txt, max 50 MB each) or "images" (JPEG/PNG/GIF/WebP,
    #   max 10 MB each). At least one file is required.
    # - status values:
    #     "completed"     - a video was produced.
    #     "script_only"   - a script was produced but no video.
    #     "analysis_only" - script generation failed.
    # - NOTE(review): ``year`` only appears in the success message; the
    #   analysis and video helpers do not receive it.
    import asyncio
    import tempfile

    max_document_bytes = 50 * 1024 * 1024
    max_image_bytes = 10 * 1024 * 1024
    allowed_exts = (".pdf", ".doc", ".docx", ".txt")
    allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]

    # Rate-limit like /ask and /ingest.
    check_rate_limit(request)

    form = await request.form()
    documents = []
    images = []
    # multi_items() yields every part. items() keeps only the last value per
    # field name, which silently dropped all but one of several uploads sent
    # under the same "documents" (or "images") field.
    for field_name, field_value in form.multi_items():
        if not getattr(field_value, "filename", None):
            continue  # plain form values (e.g. "year") and empty file parts
        if field_name.startswith("documents"):
            documents.append(field_value)
        elif field_name.startswith("images"):
            images.append(field_value)

    if not documents and not images:
        raise HTTPException(
            status_code=400,
            detail="Please upload at least one financial document or image"
        )

    document_texts = []
    for doc in documents:
        if not doc.filename.lower().endswith(allowed_exts):
            raise HTTPException(
                status_code=400,
                detail=f"Document type not supported: {doc.filename}"
            )
        if doc.size and doc.size > max_document_bytes:
            raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")
        doc_content = await doc.read()
        # ``size`` may be None; re-check the bytes actually received.
        if len(doc_content) > max_document_bytes:
            raise HTTPException(status_code=400, detail="Document too large. Max 50MB.")
        try:
            from rag.ingest import extract_text_from_file

            # The extractor takes a path. Spill the upload to a temp file that
            # keeps its extension, and remove it once extraction is done.
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=os.path.splitext(doc.filename)[1]
            ) as tmp_file:
                tmp_file.write(doc_content)
                tmp_file_path = tmp_file.name
            try:
                text = await asyncio.to_thread(extract_text_from_file, Path(tmp_file_path))
            finally:
                os.unlink(tmp_file_path)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error processing document {doc.filename}: {str(e)}"
            ) from e
        if text:
            document_texts.append(f"[Document: {doc.filename}]\n{text}")

    image_data_list = []
    for img in images:
        if img.content_type not in allowed_types:
            raise HTTPException(
                status_code=400,
                detail=f"Image type not supported: {img.filename}"
            )
        if img.size and img.size > max_image_bytes:
            raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
        img_data = await img.read()
        if len(img_data) > max_image_bytes:
            raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
        image_data_list.append(img_data)

    # The rag.utils helpers are synchronous; run them in worker threads so
    # the event loop keeps serving other requests.
    try:
        from rag.utils import analyze_financial_documents

        analysis = await asyncio.to_thread(
            analyze_financial_documents, document_texts, image_data_list
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error analyzing financial data: {str(e)}"
        ) from e

    video_script = None
    video_url = None
    try:
        from rag.utils import create_video_script, generate_yearly_wrap_video

        # Step 1: create the video script.
        video_script = await asyncio.to_thread(create_video_script, analysis)
        # Step 2: generate the video.
        # NOTE(review): the video is generated from ``analysis``, not from
        # ``video_script``; confirm that is intended.
        video_url = await asyncio.to_thread(generate_yearly_wrap_video, analysis)
    except Exception as e:
        error_msg = str(e)
        # Classify by what actually succeeded, not by the error's wording.
        # A script that was produced is then always reported as such.
        if video_script is not None:
            status = "script_only"
            message = f"Financial analysis and video script created, but video generation failed: {error_msg}"
        else:
            status = "analysis_only"
            message = f"Financial analysis completed, but script/video generation failed: {error_msg}"
    else:
        if video_url:
            status = "completed"
            message = f"Yearly financial wrap for {year} created successfully with professional video script and animation!"
        else:
            status = "script_only"
            message = "Yearly financial wrap script created successfully! Video generation is temporarily unavailable due to service limitations, but you have a complete professional script ready."

    return YearlyWrapResponse(
        analysis=analysis,
        video_script=video_script,
        video_url=video_url,
        status=status,
        message=message
    )
if __name__ == "__main__":
    import uvicorn

    # Development entry point. APP_HOST / APP_PORT / APP_RELOAD override the
    # defaults (0.0.0.0:8000 with auto-reload) without editing the file.
    uvicorn.run(
        "app:app",
        host=os.environ.get("APP_HOST", "0.0.0.0"),
        port=int(os.environ.get("APP_PORT", "8000")),
        reload=os.environ.get("APP_RELOAD", "true").strip().lower() in ("1", "true", "yes"),
    )