# tax / app.py
# Mozzicstar
# Deploy AI Tax Reform API
# 18ab7fd
"""
AI Tax Reform API - Backend Service
A comprehensive Flask API for Nigerian tax calculations and AI-powered Q&A
about Nigerian tax law based on the Nigeria Tax Act 2025.
"""
from flask import Flask, request, jsonify, g
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from dotenv import load_dotenv
from pathlib import Path
import os
import time
import logging
import re
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple
# Root logging setup: INFO and above, formatted as
# "timestamp - logger name - level - message".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Called before the os.getenv() lookups further down this module
# (presumably populates the environment from a .env file - see python-dotenv).
load_dotenv()
# ============================================================================
# App Configuration
# ============================================================================
app = Flask(__name__)
app.config.update(
    JSON_SORT_KEYS=False,
    MAX_CONTENT_LENGTH=1 * 1024 * 1024,  # 1MB max request size
)

# CORS configuration: a comma-separated CORS_ORIGINS environment variable
# overrides the localhost defaults.
allowed_origins = os.getenv("CORS_ORIGINS", "").strip()
if not allowed_origins:
    origins = ["http://localhost:3000", "http://localhost:7860"]
else:
    origins = [
        part for part in map(str.strip, allowed_origins.split(",")) if part
    ]
CORS(app, origins=origins, supports_credentials=True)

# Rate limiting keyed on the client's remote address, counters held in memory.
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    storage_uri="memory://",
    default_limits=["200 per day", "50 per hour"],
)
# ============================================================================
# Imports (after app initialization)
# ============================================================================
from src.tax_calculator import calculate_tax, get_tax_summary, TaxCalculationError
from scripts.query_qa import load_vectorstore, query
from scripts.qa_service import generate_answer, verify_answer
import threading
# ============================================================================
# Vectorstore Cache
# ============================================================================
# Cached return value of load_vectorstore(); None until a load succeeds.
_vectorstore_cache: Optional[Tuple[Any, Any]] = None
# True only while preload_vectorstore() is in the middle of a load.
_vectorstore_loading = False
# Serialises loads so the vectorstore is not loaded twice by preloading.
_vectorstore_lock = threading.Lock()


def preload_vectorstore():
    """Load the vectorstore into the module cache, intended for a background thread.

    The whole load runs while holding ``_vectorstore_lock``, so a second
    caller (or any other code that acquires the same lock) waits for the
    running load instead of starting a duplicate one, and then finds the
    cache already populated.

    Failures are logged and swallowed on purpose: preloading is best-effort
    and the cache simply stays empty. ``_vectorstore_loading`` is always
    reset, whether the load succeeded or not.
    """
    global _vectorstore_cache, _vectorstore_loading
    with _vectorstore_lock:
        # Nothing to do if a previous load already filled the cache.
        if _vectorstore_cache is not None or _vectorstore_loading:
            return
        _vectorstore_loading = True
        try:
            logger.info("Background loading vectorstore...")
            _vectorstore_cache = load_vectorstore()
            logger.info("Vectorstore preloaded successfully (using HF API for embeddings)")
        except Exception as e:
            # logger.exception keeps the traceback for diagnosing startup issues.
            logger.exception(f"Background preload failed: {e}")
        finally:
            _vectorstore_loading = False
def get_vectorstore() -> Tuple[Any, Any]:
    """Return the cached vectorstore, loading it on first use.

    The cache is checked once without the lock (cheap path for every request
    after the first) and again after acquiring ``_vectorstore_lock``, so
    concurrent first requests trigger a single ``load_vectorstore()`` call
    and the others reuse its result.

    Returns:
        Whatever ``load_vectorstore()`` returned (unpacked by callers as
        ``index, docs``).

    Raises:
        Exception: any error from ``load_vectorstore()`` is logged and
            re-raised; the cache stays empty so a later call can retry.
    """
    global _vectorstore_cache
    if _vectorstore_cache is not None:
        return _vectorstore_cache
    with _vectorstore_lock:
        # Re-check: another thread may have finished loading while we waited.
        if _vectorstore_cache is None:
            logger.info("Loading vectorstore...")
            try:
                _vectorstore_cache = load_vectorstore()
            except Exception as e:
                logger.error(f"Failed to load vectorstore: {e}")
                raise
            logger.info("Vectorstore loaded successfully")
        return _vectorstore_cache
# ============================================================================
# Input Validation & Security
# ============================================================================
def sanitize_string(text: str, max_length: int = 2000) -> str:
    """Return a cleaned copy of a user-supplied string.

    Non-string input yields ``""``. Otherwise null bytes and control
    characters are removed (tab, newline and carriage return are kept), the
    result is cut to ``max_length`` characters, and surrounding whitespace is
    stripped last.
    """
    if not isinstance(text, str):
        return ""
    without_controls = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
    clipped = without_controls[:max_length]
    return clipped.strip()
def validate_numeric(value: Any, field_name: str, min_val: float = 0, max_val: float = 1e15) -> float:
    """Convert ``value`` to a float and check it lies in ``[min_val, max_val]``.

    Args:
        value: Raw input (typically a decoded JSON value).
        field_name: Name used in error messages.
        min_val: Smallest accepted value (inclusive).
        max_val: Largest accepted value (inclusive).

    Returns:
        The value as a float.

    Raises:
        ValueError: with "is required" when ``value`` is None, "must be a
            valid number" when it cannot be interpreted as a finite number
            (including booleans, NaN and integers too large for a float),
            or "must be between ..." when it is out of range.
    """
    import math

    if value is None:
        raise ValueError(f"'{field_name}' is required")
    not_a_number = f"'{field_name}' must be a valid number"
    # bool is an int subclass; JSON true/false is not a number.
    if isinstance(value, bool):
        raise ValueError(not_a_number)
    try:
        num = float(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(not_a_number) from None
    # NaN compares False against every bound, so it must be rejected explicitly.
    if math.isnan(num):
        raise ValueError(not_a_number)
    # The range check sits outside the try block so its message is not
    # replaced by the generic "valid number" one.
    if num < min_val or num > max_val:
        raise ValueError(f"'{field_name}' must be between {min_val:,.0f} and {max_val:,.0f}")
    return num
def validate_positive_int(value: Any, field_name: str, min_val: int = 1, max_val: int = 20) -> int:
    """Convert ``value`` to an int and check it lies in ``[min_val, max_val]``.

    Args:
        value: Raw input (typically a decoded JSON value).
        field_name: Name used in error messages.
        min_val: Smallest accepted value (inclusive).
        max_val: Largest accepted value (inclusive).

    Returns:
        The value as an int.

    Raises:
        ValueError: with "must be a valid integer" when ``value`` is missing,
            a boolean, a non-integral or non-finite float, or otherwise not
            convertible; with "must be between ..." when it is out of range.
    """
    not_an_integer = f"'{field_name}' must be a valid integer"
    # bool is an int subclass; JSON true/false is not a count.
    if isinstance(value, bool):
        raise ValueError(not_an_integer)
    # Reject 3.7, NaN and infinity instead of silently truncating or
    # letting int() raise OverflowError.
    if isinstance(value, float) and not value.is_integer():
        raise ValueError(not_an_integer)
    try:
        num = int(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(not_an_integer) from None
    # The range check sits outside the try block so its message is not
    # replaced by the generic "valid integer" one.
    if num < min_val or num > max_val:
        raise ValueError(f"'{field_name}' must be between {min_val} and {max_val}")
    return num
# ============================================================================
# Error Handlers
# ============================================================================
class APIError(Exception):
    """Application error that carries the HTTP status code to respond with.

    Attributes are plain instance fields: ``message`` (also passed to
    ``Exception``), ``status_code`` (defaults to 400) and optional
    ``details``.
    """

    def __init__(self, message: str, status_code: int = 400, details: Optional[str] = None):
        super().__init__(message)
        self.details = details
        self.status_code = status_code
        self.message = message
@app.errorhandler(APIError)
def handle_api_error(error: APIError):
    """Turn an APIError into a JSON body with the error's own status code.

    ``details`` is included only when it is set and the app runs in debug mode.
    """
    payload = {"error": error.message}
    show_details = bool(error.details) and app.debug
    if show_details:
        payload["details"] = error.details
    return jsonify(payload), error.status_code
@app.errorhandler(429)
def handle_rate_limit(e):
    """Respond with a JSON 429 when a rate limit is exceeded."""
    payload = {
        "error": "Rate limit exceeded",
        "message": "Too many requests. Please try again later.",
    }
    return jsonify(payload), 429
@app.errorhandler(413)
def handle_large_request(e):
    """Respond with a JSON 413 when the request payload is too large."""
    payload = {
        "error": "Request too large",
        "message": "The request payload exceeds the maximum allowed size.",
    }
    return jsonify(payload), 413
@app.errorhandler(404)
def handle_not_found(e):
    """Respond with a JSON 404 that lists the API's endpoints."""
    known_endpoints = ["/health", "/calculate", "/retrieve", "/qa", "/aqa"]
    payload = {
        "error": "Endpoint not found",
        "message": "The requested endpoint does not exist. Check the API documentation at the root endpoint.",
        "available_endpoints": known_endpoints,
    }
    return jsonify(payload), 404
@app.errorhandler(500)
def handle_internal_error(e):
    """Log the failure with its traceback and respond with a generic JSON 500."""
    logger.exception("Internal server error")
    payload = {
        "error": "Internal server error",
        "message": "An unexpected error occurred. Please try again later.",
    }
    return jsonify(payload), 500
@app.errorhandler(Exception)
def handle_unexpected_error(e):
    """Catch-all handler for exceptions no other handler claimed.

    A handler registered for ``Exception`` also receives HTTP errors that
    have no dedicated handler (for example 405 Method Not Allowed). Those
    are ordinary client errors, so they are answered with their own status
    code as JSON instead of being logged as crashes and reported as 500.

    Anything else is logged with its traceback and answered with a 500; the
    exception text is only exposed when the app runs in debug mode.
    """
    # Werkzeug's HTTP exceptions expose an integer `code` plus
    # `get_response()`; duck-typing on both avoids importing werkzeug here.
    status = getattr(e, "code", None)
    looks_like_http_error = (
        isinstance(status, int)
        and 400 <= status < 600
        and callable(getattr(e, "get_response", None))
    )
    if looks_like_http_error:
        return jsonify({
            "error": getattr(e, "name", "HTTP error"),
            "message": getattr(e, "description", "") or "",
        }), status

    logger.exception(f"Unhandled exception: {e}")
    return jsonify({
        "error": "An unexpected error occurred",
        "message": str(e) if app.debug else "Please try again later."
    }), 500
# ============================================================================
# Request Logging Middleware
# ============================================================================
@app.before_request
def before_request():
    """Stamp the request context with its wall-clock start time."""
    now = time.time()
    g.start_time = now
@app.after_request
def after_request(response):
    """Log method, path, status and elapsed milliseconds, then pass the response on.

    Nothing is logged when no start time was recorded for this request.
    """
    if not hasattr(g, 'start_time'):
        return response
    elapsed = (time.time() - g.start_time) * 1000
    logger.info(f"{request.method} {request.path} - {response.status_code} - {elapsed:.2f}ms")
    return response
# ============================================================================
# API Routes
# ============================================================================
@app.route("/health", methods=["GET"])
@limiter.exempt
def health():
    """Liveness probe: report a static status plus the current timestamp."""
    status_report = {
        "status": "healthy",
        "service": "AI Tax Reform API",
        "version": "2.0.0",
        "timestamp": time.time(),
    }
    return jsonify(status_report), 200
@app.route("/calculate", methods=["POST"])
@limiter.limit("30 per minute")
def calculate_endpoint():
    """
    Calculate Nigerian personal income tax.

    Request JSON:
    - income (float, required): Gross annual income in NGN
    - allowances (float, optional): Non-taxable allowances
    - reliefs (float, optional): Tax reliefs
    - pension (float, optional): Pension contribution
    - include_cra (bool, optional): Include Consolidated Relief Allowance (default: true)

    Returns:
        JSON with tax calculation breakdown

    Raises:
        APIError: 400 for a body that is not a JSON object, for invalid field
            values and for TaxCalculationError; 500 for any other failure.
    """
    # silent=True: a missing, malformed or non-JSON body yields None instead
    # of raising, so it is reported below as a 400 validation error rather
    # than being swallowed by the generic handler and returned as a 500.
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        raise APIError("Request body must be a JSON object", 400)
    try:
        # Validate inputs
        income = validate_numeric(data.get("income"), "income")
        allowances = validate_numeric(data.get("allowances", 0), "allowances", min_val=0)
        reliefs = validate_numeric(data.get("reliefs", 0), "reliefs", min_val=0)
        pension = validate_numeric(data.get("pension", 0), "pension", min_val=0)
        include_cra = bool(data.get("include_cra", True))
        # Calculate tax using the improved calculator
        result = calculate_tax(
            annual_income=income,
            allowances=allowances,
            reliefs=reliefs,
            pension_contribution=pension,
            include_cra=include_cra
        )
        return jsonify(get_tax_summary(result)), 200
    except (ValueError, TaxCalculationError) as e:
        # Bad input: surface the validator's / calculator's message.
        raise APIError(str(e), 400) from e
    except Exception as e:
        logger.exception("Tax calculation failed")
        raise APIError("Tax calculation failed", 500) from e
@app.route("/retrieve", methods=["POST"])
@limiter.limit("20 per minute")
def retrieve():
    """
    Retrieve relevant document chunks from the tax law knowledge base.

    Request JSON:
    - query (string, required): Search query
    - top_k (int, optional): Number of results to return (1-20, default: 5)

    Returns:
        JSON with matching document chunks

    Raises:
        APIError: 400 for a body that is not a JSON object, a query shorter
            than 2 characters or an invalid top_k; 500 when retrieval fails.
    """
    try:
        # silent=True: malformed or non-JSON bodies become an empty payload
        # and are rejected by the query check below with a 400.
        payload = request.get_json(silent=True) or {}
        if not isinstance(payload, dict):
            raise APIError("Request body must be a JSON object", 400)
        query_text = sanitize_string(payload.get("query", ""))
        if not query_text or len(query_text) < 2:
            raise APIError("Query must be at least 2 characters", 400)
        try:
            top_k = validate_positive_int(payload.get("top_k", 5), "top_k", min_val=1, max_val=20)
        except ValueError as e:
            # Invalid top_k is a client error, not a retrieval failure.
            raise APIError(str(e), 400) from e
        index, docs = get_vectorstore()
        results = query(index, docs, query_text, top_k=top_k)
        return jsonify({
            "query": query_text,
            "count": len(results),
            "results": results
        }), 200
    except APIError:
        raise
    except Exception as e:
        logger.exception("Retrieval failed")
        raise APIError("Document retrieval failed", 500) from e
@app.route("/qa", methods=["POST"])
@limiter.limit("15 per minute")
def qa():
    """
    Answer questions about Nigerian tax law using RAG (Retrieval-Augmented Generation).

    Request JSON:
    - query (string, required): Question to answer
    - top_k (int, optional): Number of context documents (1-8, default: 3)
    - prefer_grok (bool, optional): Prefer Groq/Grok model (default: true)
    - fast_mode (bool, optional): Return sources without LLM generation (default: false)

    Returns:
        JSON with AI-generated answer and source documents. Every response
        echoes the sanitized question under "query".

    Raises:
        APIError: 400 for a body that is not a JSON object, a query shorter
            than 2 characters or an invalid top_k; 503 when the vectorstore
            lookup fails; 500 for any other failure.
    """
    try:
        # silent=True: malformed or non-JSON bodies become an empty payload
        # and are rejected by the query check below with a 400.
        payload = request.get_json(silent=True) or {}
        if not isinstance(payload, dict):
            raise APIError("Request body must be a JSON object", 400)
        query_text = sanitize_string(payload.get("query", ""))
        if not query_text or len(query_text) < 2:
            raise APIError("Query must be at least 2 characters", 400)
        # Default to 3 docs instead of 5 for faster response
        try:
            top_k = validate_positive_int(payload.get("top_k", 3), "top_k", min_val=1, max_val=8)
        except ValueError as e:
            # Invalid top_k is a client error, not a processing failure.
            raise APIError(str(e), 400) from e
        prefer_grok = bool(payload.get("prefer_grok", True))
        fast_mode = bool(payload.get("fast_mode", False))

        # Retrieve relevant context
        try:
            index, docs = get_vectorstore()
            results = query(index, docs, query_text, top_k=top_k)
        except Exception as ve:
            logger.error(f"Vectorstore query failed: {ve}")
            raise APIError("Search service temporarily unavailable", 503) from ve

        if not results:
            return jsonify({
                "query": query_text,
                "answer": "I couldn't find relevant information about this topic in the tax documentation. Please try rephrasing your question.",
                "model": "none",
                "sources": []
            }), 200

        # Fast mode: return sources with excerpt instead of calling slow LLM
        if fast_mode:
            top_text = results[0].get("text", "")[:800]
            return jsonify({
                "query": query_text,
                "answer": f"**From the Nigeria Tax Act 2025:**\n\n{top_text}\n\n---\n*[Fast mode - showing direct excerpt from source documents]*",
                "model": "fast",
                "sources": results
            }), 200

        # Generate answer with shorter timeout
        try:
            answer, model_used, _ = generate_answer(query_text, results, prefer_grok=prefer_grok, timeout=20)
        except Exception as ge:
            logger.error(f"Answer generation failed: {ge}")
            # Return sources even if generation fails
            return jsonify({
                "query": query_text,
                "answer": "I found relevant documents but couldn't generate a complete answer. Here are the key sections from the tax law that may help:",
                "model": "fallback",
                "sources": results
            }), 200

        return jsonify({
            "query": query_text,
            "answer": answer,
            "model": model_used,
            "sources": results
        }), 200
    except APIError:
        raise
    except Exception as e:
        logger.exception("QA processing failed")
        raise APIError("Question answering failed. Please try again.", 500) from e
def _interpret_verification(verification: Any) -> Tuple[Any, bool]:
    """Normalise the verifier's reply and decide whether the answer passed.

    Returns ``(verification, verified)``:
    - dict (or a string holding a JSON object): verified when its numeric
      "score" is at least 0.7;
    - string that is not JSON: wrapped as ``{"raw": ..., "score": 0}`` and
      verified only when it contains the standalone word "accurate" not
      directly preceded by "not " (so "inaccurate" no longer counts);
    - anything else: returned unchanged and not verified.

    Raises:
        ValueError: when a string parses as JSON but not to an object.
    """
    import json

    if isinstance(verification, str):
        raw = verification
        try:
            verification = json.loads(raw)
        except json.JSONDecodeError:
            looks_accurate = re.search(r"(?<!not )\baccurate\b", raw.lower()) is not None
            return {"raw": raw, "score": 0}, looks_accurate
        if not isinstance(verification, dict):
            raise ValueError("Verifier returned JSON that is not an object")
    if isinstance(verification, dict):
        score = verification.get("score", 0)
        return verification, isinstance(score, (int, float)) and score >= 0.7
    return verification, False


@app.route("/aqa", methods=["POST"])
@limiter.limit("10 per minute")
def aqa():
    """
    Answer questions with verification (Assured QA).
    Same as /qa but includes answer verification step for higher accuracy.

    Request JSON:
    - query (string, required): Question to answer
    - top_k (int, optional): Number of context documents (1-10, default: 5)
    - prefer_grok (bool, optional): Prefer Groq/Grok model (default: true)

    Returns:
        JSON with AI-generated answer, verification result, and source
        documents. Every response echoes the sanitized question under "query".

    Raises:
        APIError: 400 for a body that is not a JSON object, a query shorter
            than 2 characters or an invalid top_k; 500 for any other failure.
    """
    try:
        # silent=True: malformed or non-JSON bodies become an empty payload
        # and are rejected by the query check below with a 400.
        payload = request.get_json(silent=True) or {}
        if not isinstance(payload, dict):
            raise APIError("Request body must be a JSON object", 400)
        query_text = sanitize_string(payload.get("query", ""))
        if not query_text or len(query_text) < 2:
            raise APIError("Query must be at least 2 characters", 400)
        try:
            top_k = validate_positive_int(payload.get("top_k", 5), "top_k", min_val=1, max_val=10)
        except ValueError as e:
            # Invalid top_k is a client error, not a processing failure.
            raise APIError(str(e), 400) from e
        prefer_grok = bool(payload.get("prefer_grok", True))

        # Retrieve relevant context
        index, docs = get_vectorstore()
        results = query(index, docs, query_text, top_k=top_k)
        if not results:
            return jsonify({
                "query": query_text,
                "answer": "I couldn't find relevant information about this topic in the tax documentation.",
                "model": "none",
                "verification": {"score": 0, "reason": "No relevant documents found"},
                "verified": False,
                "sources": []
            }), 200

        # Generate answer
        answer, model_used, _ = generate_answer(query_text, results, prefer_grok=prefer_grok)

        # Verify answer; a verifier failure degrades to "unverified" instead
        # of failing the whole request.
        try:
            verification = verify_answer(answer, query_text, results, prefer_grok=prefer_grok)
            verification, verified = _interpret_verification(verification)
        except Exception as ve:
            logger.warning(f"Verification failed: {ve}")
            verification = {"error": "Verification unavailable"}
            verified = False

        return jsonify({
            "query": query_text,
            "answer": answer,
            "model": model_used,
            "verification": verification,
            "verified": verified,
            "sources": results
        }), 200
    except APIError:
        raise
    except Exception as e:
        logger.exception("AQA processing failed")
        raise APIError("Verified question answering failed. Please try again.", 500) from e
# ============================================================================
# API Documentation Endpoint
# ============================================================================
@app.route("/", methods=["GET"])
@limiter.exempt
def api_docs():
    """Serve a static JSON overview of the API: name, version and endpoints."""
    endpoint_summaries = {
        "GET /health": "Health check",
        "POST /calculate": "Calculate personal income tax",
        "POST /retrieve": "Retrieve relevant tax documents",
        "POST /qa": "Ask questions about tax law",
        "POST /aqa": "Ask questions with answer verification",
    }
    overview = {
        "name": "AI Tax Reform API",
        "version": "2.0.0",
        "description": "AI-powered Nigerian tax calculator and Q&A service",
        "endpoints": endpoint_summaries,
        "documentation": "https://github.com/your-repo/AI-TAX-REFORM#readme",
    }
    return jsonify(overview), 200
# ============================================================================
# Application Startup
# ============================================================================
def start_background_tasks():
    """Kick off the vectorstore preload on a daemon thread.

    The thread sleeps for 2 seconds before calling preload_vectorstore();
    this function returns immediately after starting it.
    """
    def delayed_preload():
        # Short pause before the preload starts.
        time.sleep(2)
        preload_vectorstore()

    threading.Thread(target=delayed_preload, daemon=True).start()
if __name__ == "__main__":
    # Runtime settings come from the environment: PORT (default 7860) and
    # FLASK_ENV ("development" turns on Flask's debug mode).
    debug = os.getenv("FLASK_ENV") == "development"
    port = int(os.getenv("PORT", 7860))

    logger.info(f"Starting AI Tax Reform API v2.0.0 on port {port}")
    logger.info(f"Allowed origins: {origins}")

    # Start background preloading, then serve on all interfaces.
    start_background_tasks()
    app.run(host="0.0.0.0", port=port, debug=debug)