| | """ |
| | AI Tax Reform API - Backend Service |
| | |
| | A comprehensive Flask API for Nigerian tax calculations and AI-powered Q&A |
| | about Nigerian tax law based on the Nigeria Tax Act 2025. |
| | """ |
| |
|
| | from flask import Flask, request, jsonify, g |
| | from flask_cors import CORS |
| | from flask_limiter import Limiter |
| | from flask_limiter.util import get_remote_address |
| | from dotenv import load_dotenv |
| | from pathlib import Path |
| | import os |
| | import time |
| | import logging |
| | import re |
| | from functools import wraps |
| | from typing import Any, Callable, Dict, Optional, Tuple |
| |
|
| | |
# Configure process-wide logging once, at import time, so that every logger
# created afterwards shares the same timestamped line format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load environment settings via python-dotenv before any os.getenv() call below.
load_dotenv()
| |
|
| | |
| | |
| | |
| |
|
# Flask application object shared by every route and handler in this module.
app = Flask(__name__)
app.config.update(
    JSON_SORT_KEYS=False,
    MAX_CONTENT_LENGTH=1 * 1024 * 1024,  # reject request bodies larger than 1 MiB
)

# CORS: a comma-separated CORS_ORIGINS env var overrides the local-dev defaults.
allowed_origins = os.getenv("CORS_ORIGINS", "").strip()
origins = (
    [entry.strip() for entry in allowed_origins.split(",") if entry.strip()]
    if allowed_origins
    else ["http://localhost:3000", "http://localhost:7860"]
)

CORS(app, origins=origins, supports_credentials=True)

# Rate limiting keyed by client address; counters are kept in process memory.
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    default_limits=["200 per day", "50 per hour"],
    storage_uri="memory://",
)
| |
|
| | |
| | |
| | |
| |
|
| | from src.tax_calculator import calculate_tax, get_tax_summary, TaxCalculationError |
| | from scripts.query_qa import load_vectorstore, query |
| | from scripts.qa_service import generate_answer, verify_answer |
| | import threading |
| |
|
| | |
| | |
| | |
| |
|
# Process-wide vectorstore cache. `_vectorstore_lock` guards both the cache
# slot and the `_vectorstore_loading` flag.
_vectorstore_cache: Optional[Tuple[Any, Any]] = None
_vectorstore_loading = False
_vectorstore_lock = threading.Lock()


def preload_vectorstore() -> None:
    """Preload the vectorstore from a background thread.

    Only the caller that actually claims the load (under the lock) performs
    it; any other concurrent caller returns immediately. Failures are logged
    and swallowed on purpose: preloading is best-effort and
    ``get_vectorstore`` will retry on demand.
    """
    global _vectorstore_cache, _vectorstore_loading

    # Claim the load atomically. Deciding from a local result (rather than
    # re-reading the shared flag afterwards) guarantees a single loader and
    # ensures only the owner resets the flag.
    with _vectorstore_lock:
        if _vectorstore_cache is not None or _vectorstore_loading:
            return
        _vectorstore_loading = True

    try:
        logger.info("Background loading vectorstore...")
        loaded = load_vectorstore()
        with _vectorstore_lock:
            # Do not clobber a store that a request thread published meanwhile.
            if _vectorstore_cache is None:
                _vectorstore_cache = loaded
        logger.info("Vectorstore preloaded successfully (using HF API for embeddings)")
    except Exception as e:
        logger.error(f"Background preload failed: {e}")
    finally:
        with _vectorstore_lock:
            _vectorstore_loading = False


def get_vectorstore() -> Tuple[Any, Any]:
    """Return the cached vectorstore, loading it on first use.

    Returns:
        Whatever ``load_vectorstore()`` returned (callers unpack it as
        ``index, docs``).

    Raises:
        Exception: any error from ``load_vectorstore()`` is logged and
            re-raised; nothing is cached, so the next call retries.
    """
    global _vectorstore_cache

    # Fast path: once published, the cache is never reset to None.
    cached = _vectorstore_cache
    if cached is not None:
        return cached

    # Slow path: serialise first-time loads so that concurrent requests wait
    # for one load instead of each loading their own copy.
    with _vectorstore_lock:
        if _vectorstore_cache is None:
            logger.info("Loading vectorstore...")
            try:
                _vectorstore_cache = load_vectorstore()
            except Exception as e:
                logger.error(f"Failed to load vectorstore: {e}")
                raise
            logger.info("Vectorstore loaded successfully")
        return _vectorstore_cache
| |
|
| |
|
| | |
| | |
| | |
| |
|
def sanitize_string(text: str, max_length: int = 2000) -> str:
    """Return a cleaned copy of a user-supplied string.

    Non-string input yields an empty string. Otherwise ASCII control
    characters (except tab, newline and carriage return) are removed, the
    result is cut to ``max_length`` characters, and surrounding whitespace is
    stripped — in that order.
    """
    if isinstance(text, str):
        without_controls = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        truncated = without_controls[:max_length]
        return truncated.strip()
    return ""
| |
|
| |
|
def validate_numeric(value: Any, field_name: str, min_val: float = 0, max_val: float = 1e15) -> float:
    """Convert ``value`` to a float and check that it lies in a closed range.

    Args:
        value: Raw input (number or numeric string).
        field_name: Name used in error messages.
        min_val: Smallest accepted value (inclusive).
        max_val: Largest accepted value (inclusive).

    Returns:
        The converted float.

    Raises:
        ValueError: if ``value`` is missing, is not a number (including NaN),
            or falls outside ``[min_val, max_val]``. Each case carries its
            own message.
    """
    import math

    if value is None:
        raise ValueError(f"'{field_name}' is required")

    range_message = f"'{field_name}' must be between {min_val:,.0f} and {max_val:,.0f}"

    # Only the conversion lives inside the try, so the range error raised
    # further down is not swallowed and replaced by the generic message.
    try:
        num = float(value)
    except OverflowError:
        # An int too large for a float is, by definition, out of range.
        raise ValueError(range_message) from None
    except (TypeError, ValueError):
        raise ValueError(f"'{field_name}' must be a valid number") from None

    # NaN compares false to everything and would slip through the range check.
    if math.isnan(num):
        raise ValueError(f"'{field_name}' must be a valid number")

    if num < min_val or num > max_val:
        raise ValueError(range_message)
    return num
| |
|
| |
|
def validate_positive_int(value: Any, field_name: str, min_val: int = 1, max_val: int = 20) -> int:
    """Convert ``value`` to an int and check that it lies in a closed range.

    Args:
        value: Raw input (int, integral float, or integer string).
        field_name: Name used in error messages.
        min_val: Smallest accepted value (inclusive).
        max_val: Largest accepted value (inclusive).

    Returns:
        The converted integer.

    Raises:
        ValueError: if ``value`` is not an integer (booleans, fractional or
            non-finite floats, non-numeric strings, ``None``) or is outside
            ``[min_val, max_val]``. Each case carries its own message.
    """
    invalid_message = f"'{field_name}' must be a valid integer"

    # bool is an int subclass; True/False are not meaningful counts.
    if isinstance(value, bool):
        raise ValueError(invalid_message)
    # Refuse to silently truncate 3.7 -> 3. is_integer() is also False for
    # inf and NaN, which int() would reject with OverflowError/ValueError.
    if isinstance(value, float) and not value.is_integer():
        raise ValueError(invalid_message)

    # Only the conversion lives inside the try, so the range error raised
    # below is not swallowed and replaced by the generic message.
    try:
        num = int(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(invalid_message) from None

    if num < min_val or num > max_val:
        raise ValueError(f"'{field_name}' must be between {min_val} and {max_val}")
    return num
| |
|
| |
|
| | |
| | |
| | |
| |
|
class APIError(Exception):
    """Application error that carries the HTTP status to respond with.

    Attributes:
        message: Client-facing error text (also the exception's ``str()``).
        status_code: HTTP status code for the response (default 400).
        details: Optional extra diagnostic text.
    """

    def __init__(self, message: str, status_code: int = 400, details: Optional[str] = None):
        super().__init__(message)
        self.message, self.status_code, self.details = message, status_code, details
| |
|
| |
|
@app.errorhandler(APIError)
def handle_api_error(error: APIError):
    """Render an ``APIError`` as JSON using the error's own status code.

    The ``details`` field is only included when the app runs in debug mode.
    """
    body = {"error": error.message}
    expose_details = bool(error.details) and app.debug
    if expose_details:
        body["details"] = error.details
    return jsonify(body), error.status_code
| |
|
| |
|
@app.errorhandler(429)
def handle_rate_limit(e):
    """Return a JSON body for HTTP 429 (rate limit exceeded)."""
    body = {
        "error": "Rate limit exceeded",
        "message": "Too many requests. Please try again later.",
    }
    return jsonify(body), 429
| |
|
| |
|
@app.errorhandler(413)
def handle_large_request(e):
    """Return a JSON body for HTTP 413 (request payload too large)."""
    body = {
        "error": "Request too large",
        "message": "The request payload exceeds the maximum allowed size.",
    }
    return jsonify(body), 413
| |
|
| |
|
@app.errorhandler(404)
def handle_not_found(e):
    """Return a JSON body for HTTP 404 that lists the endpoints the API offers."""
    known_endpoints = ["/health", "/calculate", "/retrieve", "/qa", "/aqa"]
    body = {
        "error": "Endpoint not found",
        "message": "The requested endpoint does not exist. Check the API documentation at the root endpoint.",
        "available_endpoints": known_endpoints,
    }
    return jsonify(body), 404
| |
|
| |
|
@app.errorhandler(500)
def handle_internal_error(e):
    """Log the active exception and return a generic JSON body for HTTP 500."""
    logger.exception("Internal server error")
    body = {
        "error": "Internal server error",
        "message": "An unexpected error occurred. Please try again later.",
    }
    return jsonify(body), 500
| |
|
| |
|
@app.errorhandler(Exception)
def handle_unexpected_error(e):
    """Catch-all handler for exceptions that no other handler claimed.

    HTTP errors raised by Flask/werkzeug (405 Method Not Allowed, 400 Bad
    Request, 415 Unsupported Media Type, ...) also reach a handler registered
    for ``Exception`` when no handler exists for their specific code. Those
    keep their own status code instead of being reported as a 500. Anything
    else is logged with its traceback and answered with a generic 500; the
    exception text is only exposed in debug mode.
    """
    # Function-scope import: werkzeug is installed as a dependency of Flask,
    # and this is the only place in the file that needs it.
    from werkzeug.exceptions import HTTPException

    if isinstance(e, HTTPException):
        return jsonify({
            "error": e.name,
            "message": e.description
        }), e.code or 500

    logger.exception(f"Unhandled exception: {e}")
    return jsonify({
        "error": "An unexpected error occurred",
        "message": str(e) if app.debug else "Please try again later."
    }), 500
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.before_request
def before_request():
    """Record the request's start time on ``g`` so its duration can be logged later."""
    started_at = time.time()
    g.start_time = started_at
| |
|
| |
|
@app.after_request
def after_request(response):
    """Log method, path, status and elapsed milliseconds, then pass the response through."""
    # No start time recorded for this request: nothing to measure.
    if not hasattr(g, 'start_time'):
        return response

    duration_ms = (time.time() - g.start_time) * 1000
    logger.info(f"{request.method} {request.path} - {response.status_code} - {duration_ms:.2f}ms")
    return response
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.route("/health", methods=["GET"])
@limiter.exempt
def health():
    """Liveness probe: returns a static status payload plus the current Unix timestamp.

    Exempt from rate limiting.
    """
    status_report = {
        "status": "healthy",
        "service": "AI Tax Reform API",
        "version": "2.0.0",
        "timestamp": time.time(),
    }
    return jsonify(status_report), 200
| |
|
| |
|
@app.route("/calculate", methods=["POST"])
@limiter.limit("30 per minute")
def calculate_endpoint():
    """
    Calculate Nigerian personal income tax.

    Request JSON:
    - income (float, required): Gross annual income in NGN
    - allowances (float, optional): Non-taxable allowances
    - reliefs (float, optional): Tax reliefs
    - pension (float, optional): Pension contribution
    - include_cra (bool, optional): Include Consolidated Relief Allowance (default: true)

    Returns:
        JSON with tax calculation breakdown

    Raises:
        APIError: 400 for a non-object body, invalid field values or a
            ``TaxCalculationError``; 500 for any other failure.
    """
    # silent=True: a malformed body or wrong Content-Type yields None instead
    # of raising an HTTP 400/415 that the broad handler below would turn into
    # a 500. A missing body then fails validation with a normal 400.
    data = request.get_json(silent=True)
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise APIError("Request body must be a JSON object", 400)

    try:
        # Validate inputs
        income = validate_numeric(data.get("income"), "income")
        allowances = validate_numeric(data.get("allowances", 0), "allowances", min_val=0)
        reliefs = validate_numeric(data.get("reliefs", 0), "reliefs", min_val=0)
        pension = validate_numeric(data.get("pension", 0), "pension", min_val=0)
        include_cra = bool(data.get("include_cra", True))

        result = calculate_tax(
            annual_income=income,
            allowances=allowances,
            reliefs=reliefs,
            pension_contribution=pension,
            include_cra=include_cra
        )

        return jsonify(get_tax_summary(result)), 200

    except (ValueError, TaxCalculationError) as e:
        # Invalid input or a domain-level rejection: the client can fix it.
        raise APIError(str(e), 400) from e
    except Exception as e:
        logger.exception("Tax calculation failed")
        raise APIError("Tax calculation failed", 500) from e
| |
|
| |
|
@app.route("/retrieve", methods=["POST"])
@limiter.limit("20 per minute")
def retrieve():
    """
    Retrieve relevant document chunks from the tax law knowledge base.

    Request JSON:
    - query (string, required): Search query
    - top_k (int, optional): Number of results to return (1-20, default: 5)

    Returns:
        JSON with matching document chunks

    Raises:
        APIError: 400 for a non-object body, a too-short query or an invalid
            ``top_k``; 500 when retrieval itself fails.
    """
    # silent=True keeps a malformed body / wrong Content-Type from raising an
    # HTTP error here; the query check below then answers with a 400.
    payload = request.get_json(silent=True)
    if payload is None:
        payload = {}
    if not isinstance(payload, dict):
        raise APIError("Request body must be a JSON object", 400)

    query_text = sanitize_string(payload.get("query", ""))
    if len(query_text) < 2:
        raise APIError("Query must be at least 2 characters", 400)

    # Validation errors are the client's fault: report 400, not 500.
    try:
        top_k = validate_positive_int(payload.get("top_k", 5), "top_k", min_val=1, max_val=20)
    except ValueError as e:
        raise APIError(str(e), 400) from e

    try:
        index, docs = get_vectorstore()
        results = query(index, docs, query_text, top_k=top_k)

        return jsonify({
            "query": query_text,
            "count": len(results),
            "results": results
        }), 200

    except Exception as e:
        logger.exception("Retrieval failed")
        raise APIError("Document retrieval failed", 500) from e
| |
|
| |
|
@app.route("/qa", methods=["POST"])
@limiter.limit("15 per minute")
def qa():
    """
    Answer questions about Nigerian tax law using RAG (Retrieval-Augmented Generation).

    Request JSON:
    - query (string, required): Question to answer
    - top_k (int, optional): Number of context documents (1-8, default: 3)
    - prefer_grok (bool, optional): Prefer Groq/Grok model (default: true)
    - fast_mode (bool, optional): Return sources without LLM generation (default: false)

    Returns:
        JSON with AI-generated answer and source documents

    Raises:
        APIError: 400 for a non-object body, a too-short query or an invalid
            ``top_k``; 503 when the vectorstore lookup fails; 500 for any
            other failure.
    """
    # silent=True keeps a malformed body / wrong Content-Type from raising an
    # HTTP error here; the query check below then answers with a 400.
    payload = request.get_json(silent=True)
    if payload is None:
        payload = {}
    if not isinstance(payload, dict):
        raise APIError("Request body must be a JSON object", 400)

    query_text = sanitize_string(payload.get("query", ""))
    if len(query_text) < 2:
        raise APIError("Query must be at least 2 characters", 400)

    # Validation errors are the client's fault: report 400, not 500.
    try:
        top_k = validate_positive_int(payload.get("top_k", 3), "top_k", min_val=1, max_val=8)
    except ValueError as e:
        raise APIError(str(e), 400) from e
    prefer_grok = bool(payload.get("prefer_grok", True))
    fast_mode = bool(payload.get("fast_mode", False))

    try:
        # Retrieve context documents.
        try:
            index, docs = get_vectorstore()
            results = query(index, docs, query_text, top_k=top_k)
        except Exception as ve:
            logger.error(f"Vectorstore query failed: {ve}")
            raise APIError("Search service temporarily unavailable", 503) from ve

        if not results:
            return jsonify({
                "query": query_text,
                "answer": "I couldn't find relevant information about this topic in the tax documentation. Please try rephrasing your question.",
                "model": "none",
                "sources": []
            }), 200

        # Fast mode: skip the LLM and return an excerpt of the best match.
        if fast_mode:
            top_text = results[0].get("text", "")[:800]
            return jsonify({
                "query": query_text,
                "answer": f"**From the Nigeria Tax Act 2025:**\n\n{top_text}\n\n---\n*[Fast mode - showing direct excerpt from source documents]*",
                "model": "fast",
                "sources": results
            }), 200

        # Generate the answer; degrade to "sources only" if generation fails.
        try:
            answer, model_used, _ = generate_answer(query_text, results, prefer_grok=prefer_grok, timeout=20)
        except Exception as ge:
            logger.error(f"Answer generation failed: {ge}")
            return jsonify({
                "query": query_text,
                "answer": "I found relevant documents but couldn't generate a complete answer. Here are the key sections from the tax law that may help:",
                "model": "fallback",
                "sources": results
            }), 200

        return jsonify({
            "query": query_text,
            "answer": answer,
            "model": model_used,
            "sources": results
        }), 200

    except APIError:
        raise
    except Exception as e:
        logger.exception("QA processing failed")
        raise APIError("Question answering failed. Please try again.", 500) from e
| |
|
| |
|
def _text_claims_accurate(raw: str) -> bool:
    """Heuristic: does a free-text verdict say the answer is accurate?"""
    lowered = raw.lower()
    # Word boundaries keep "inaccurate" from matching; also reject simple
    # negations such as "not accurate" / "isn't entirely accurate".
    if re.search(r"(?:\bnot|n't)\s+(?:\w+\s+)?accurate\b", lowered):
        return False
    return re.search(r"\baccurate\b", lowered) is not None


def _interpret_verification(verification: Any) -> Tuple[Any, bool]:
    """Normalise a ``verify_answer`` result into ``(payload, verified)``.

    A dict (or a string holding a JSON object) is verified when its numeric
    ``score`` is at least 0.7. Any other string is wrapped as
    ``{"raw": ..., "score": 0}`` and judged by the free-text heuristic.
    Anything else is returned unchanged and unverified.
    """
    import json

    if isinstance(verification, str):
        raw = verification
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            parsed = None
        if not isinstance(parsed, dict):
            return {"raw": raw, "score": 0}, _text_claims_accurate(raw)
        verification = parsed

    if isinstance(verification, dict):
        score = verification.get("score", 0)
        return verification, isinstance(score, (int, float)) and score >= 0.7

    return verification, False


@app.route("/aqa", methods=["POST"])
@limiter.limit("10 per minute")
def aqa():
    """
    Answer questions with verification (Assured QA).

    Same as /qa but includes answer verification step for higher accuracy.

    Request JSON:
    - query (string, required): Question to answer
    - top_k (int, optional): Number of context documents (1-10, default: 5)
    - prefer_grok (bool, optional): Prefer Groq/Grok model (default: true)

    Returns:
        JSON with AI-generated answer, verification result, and source documents

    Raises:
        APIError: 400 for a non-object body, a too-short query or an invalid
            ``top_k``; 500 for any other failure.
    """
    # silent=True keeps a malformed body / wrong Content-Type from raising an
    # HTTP error here; the query check below then answers with a 400.
    payload = request.get_json(silent=True)
    if payload is None:
        payload = {}
    if not isinstance(payload, dict):
        raise APIError("Request body must be a JSON object", 400)

    query_text = sanitize_string(payload.get("query", ""))
    if len(query_text) < 2:
        raise APIError("Query must be at least 2 characters", 400)

    # Validation errors are the client's fault: report 400, not 500.
    try:
        top_k = validate_positive_int(payload.get("top_k", 5), "top_k", min_val=1, max_val=10)
    except ValueError as e:
        raise APIError(str(e), 400) from e
    prefer_grok = bool(payload.get("prefer_grok", True))

    try:
        # Retrieve context documents.
        index, docs = get_vectorstore()
        results = query(index, docs, query_text, top_k=top_k)

        if not results:
            return jsonify({
                "query": query_text,
                "answer": "I couldn't find relevant information about this topic in the tax documentation.",
                "model": "none",
                "verification": {"score": 0, "reason": "No relevant documents found"},
                "verified": False,
                "sources": []
            }), 200

        # Generate the answer.
        answer, model_used, _ = generate_answer(query_text, results, prefer_grok=prefer_grok)

        # Verify it; a verifier failure downgrades to "unverified", not an error.
        try:
            verification, verified = _interpret_verification(
                verify_answer(answer, query_text, results, prefer_grok=prefer_grok)
            )
        except Exception as ve:
            logger.warning(f"Verification failed: {ve}")
            verification = {"error": "Verification unavailable"}
            verified = False

        return jsonify({
            "query": query_text,
            "answer": answer,
            "model": model_used,
            "verification": verification,
            "verified": verified,
            "sources": results
        }), 200

    except APIError:
        raise
    except Exception as e:
        logger.exception("AQA processing failed")
        raise APIError("Verified question answering failed. Please try again.", 500) from e
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.route("/", methods=["GET"])
@limiter.exempt
def api_docs():
    """Serve a JSON overview of the API: name, version, endpoints and docs link.

    Exempt from rate limiting.
    """
    endpoint_summaries = {
        "GET /health": "Health check",
        "POST /calculate": "Calculate personal income tax",
        "POST /retrieve": "Retrieve relevant tax documents",
        "POST /qa": "Ask questions about tax law",
        "POST /aqa": "Ask questions with answer verification",
    }
    overview = {
        "name": "AI Tax Reform API",
        "version": "2.0.0",
        "description": "AI-powered Nigerian tax calculator and Q&A service",
        "endpoints": endpoint_summaries,
        "documentation": "https://github.com/your-repo/AI-TAX-REFORM#readme",
    }
    return jsonify(overview), 200
| |
|
| |
|
| | |
| | |
| | |
| |
|
def start_background_tasks():
    """Spawn a daemon thread that preloads the vectorstore shortly after startup."""

    def delayed_preload():
        import time
        # Wait two seconds before starting the preload.
        time.sleep(2)
        preload_vectorstore()

    # daemon=True: the preload thread never keeps the process alive on shutdown.
    threading.Thread(target=delayed_preload, daemon=True).start()
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: read runtime settings from the environment.
    debug = os.getenv("FLASK_ENV") == "development"
    port = int(os.getenv("PORT", 7860))

    logger.info(f"Starting AI Tax Reform API v2.0.0 on port {port}")
    logger.info(f"Allowed origins: {origins}")

    # Start the vectorstore preload thread before handing control to the server.
    start_background_tasks()

    # Listens on all interfaces (0.0.0.0).
    app.run(host="0.0.0.0", port=port, debug=debug)
| |
|