Upload 14 files
Browse files- RAG_SYSTEM_PLAN.md +243 -0
- app.py +370 -0
- dockerfile +42 -0
- rag/.env +9 -0
- rag/__init__.py +0 -0
- rag/__pycache__/__init__.cpython-312.pyc +0 -0
- rag/__pycache__/ingest.cpython-312.pyc +0 -0
- rag/__pycache__/utils.cpython-312.pyc +0 -0
- rag/ingest.py +211 -0
- rag/main.py +293 -0
- rag/requirements.txt +10 -0
- rag/utils.py +236 -0
- requirements.txt +10 -0
RAG_SYSTEM_PLAN.md
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Nigerian Tax Law RAG System
|
| 2 |
+
|
| 3 |
+
A lightweight, scalable Retrieval-Augmented Generation (RAG) system for querying Nigerian tax and legal documents.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
This system uses:
|
| 8 |
+
- **FastAPI** - Backend API server
|
| 9 |
+
- **Gemini API** - Embeddings + answer generation
|
| 10 |
+
- **Pinecone** - Vector database for semantic search (this plan originally specified ChromaDB; the shipped code uses Pinecone)
|
| 11 |
+
- **pdfplumber** - PDF text extraction
|
| 12 |
+
- **tiktoken** - Text chunking with token counting
|
| 13 |
+
|
| 14 |
+
## Architecture
|
| 15 |
+
|
| 16 |
+
```
|
| 17 |
+
┌─────────────────────────────┐
|
| 18 |
+
│ Client │
|
| 19 |
+
└───────────────┬─────────────┘
|
| 20 |
+
│ /ask
|
| 21 |
+
┌───────▼────────┐
|
| 22 |
+
│ FastAPI API │
|
| 23 |
+
└───────┬────────┘
|
| 24 |
+
│
|
| 25 |
+
│ Query → Gemini Embedding
|
| 26 |
+
┌───────▼──────────┐
|
| 27 |
+
│ Vector DB │
|
| 28 |
+
│ (Chroma) │
|
| 29 |
+
└───────┬──────────┘
|
| 30 |
+
│
|
| 31 |
+
│ Retrieved Chunks
|
| 32 |
+
┌───────▼──────────┐
|
| 33 |
+
│ Gemini Model │
|
| 34 |
+
│ (RAG Completion) │
|
| 35 |
+
└───────┬──────────┘
|
| 36 |
+
│
|
| 37 |
+
┌───────▼──────────┐
|
| 38 |
+
│ Final Answer │
|
| 39 |
+
└───────────────────┘
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## File Structure
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
tax/
|
| 46 |
+
├── docs/ # Your PDF documents
|
| 47 |
+
│ ├── Nigeria-Tax-Act-2025.pdf
|
| 48 |
+
│ └── ... (other tax/legal PDFs)
|
| 49 |
+
└── rag/
|
| 50 |
+
├── RAG_SYSTEM_PLAN.md # This file
|
| 51 |
+
├── main.py # FastAPI server
|
| 52 |
+
├── ingest.py # PDF → ChromaDB pipeline
|
| 53 |
+
├── utils.py # Chunking + embedding functions
|
| 54 |
+
├── requirements.txt # Python dependencies
|
| 55 |
+
└── db/ # ChromaDB vector database (auto-created)
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
## Installation
|
| 59 |
+
|
| 60 |
+
1. **Create a virtual environment** (recommended):
|
| 61 |
+
```bash
|
| 62 |
+
cd rag
|
| 63 |
+
python -m venv venv
|
| 64 |
+
source venv/bin/activate # Linux/Mac
|
| 65 |
+
# or: venv\Scripts\activate # Windows
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
2. **Install dependencies**:
|
| 69 |
+
```bash
|
| 70 |
+
pip install -r requirements.txt
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
3. **Set your Gemini API key**:
|
| 74 |
+
```bash
|
| 75 |
+
export GEMINI_API_KEY='your-api-key-here'
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
Get a free API key at: https://aistudio.google.com/apikey
|
| 79 |
+
|
| 80 |
+
## Usage
|
| 81 |
+
|
| 82 |
+
### Step 1: Ingest Documents
|
| 83 |
+
|
| 84 |
+
Index your PDF documents into the vector database:
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
cd rag
|
| 88 |
+
python ingest.py
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Options:
|
| 92 |
+
- `--force` or `-f`: Re-ingest all documents (update embeddings)
|
| 93 |
+
- `--clear`: Clear the database before ingesting
|
| 94 |
+
- `--stats`: Show database statistics only
|
| 95 |
+
- `--data-dir PATH`: Use a different PDF directory
|
| 96 |
+
|
| 97 |
+
### Step 2: Start the API Server
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
uvicorn main:app --reload
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
The API will be available at `http://localhost:8000`
|
| 104 |
+
|
| 105 |
+
### Step 3: Query Documents
|
| 106 |
+
|
| 107 |
+
**Ask a question:**
|
| 108 |
+
```bash
|
| 109 |
+
curl -X POST "http://localhost:8000/ask" \
|
| 110 |
+
-H "Content-Type: application/json" \
|
| 111 |
+
-d '{"question": "What are the tax rates for personal income in Nigeria?"}'
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
**Check API health:**
|
| 115 |
+
```bash
|
| 116 |
+
curl http://localhost:8000/health
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
**Get statistics:**
|
| 120 |
+
```bash
|
| 121 |
+
curl http://localhost:8000/stats
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
## API Endpoints
|
| 125 |
+
|
| 126 |
+
| Method | Endpoint | Description |
|
| 127 |
+
|--------|----------|-------------|
|
| 128 |
+
| `GET` | `/` | API information |
|
| 129 |
+
| `GET` | `/health` | Health check |
|
| 130 |
+
| `POST` | `/ask` | Ask a question |
|
| 131 |
+
| `POST` | `/ingest` | Upload a new PDF |
|
| 132 |
+
| `GET` | `/stats` | Database statistics |
|
| 133 |
+
| `DELETE` | `/documents/{name}` | Remove a document |
|
| 134 |
+
|
| 135 |
+
### POST /ask
|
| 136 |
+
|
| 137 |
+
Request body:
|
| 138 |
+
```json
|
| 139 |
+
{
|
| 140 |
+
"question": "What is the penalty for late tax filing?",
|
| 141 |
+
"top_k": 5,
|
| 142 |
+
"model": "gemini-2.0-flash"
|
| 143 |
+
}
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
Response:
|
| 147 |
+
```json
|
| 148 |
+
{
|
| 149 |
+
"answer": "According to the Nigeria Tax Act 2025...",
|
| 150 |
+
"sources": [
|
| 151 |
+
{
|
| 152 |
+
"document": "Nigeria-Tax-Act-2025.pdf",
|
| 153 |
+
"chunk_index": 42,
|
| 154 |
+
"relevance_score": 0.8532
|
| 155 |
+
}
|
| 156 |
+
],
|
| 157 |
+
"chunks_used": 5
|
| 158 |
+
}
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### POST /ingest
|
| 162 |
+
|
| 163 |
+
Upload a PDF file to add to the index:
|
| 164 |
+
```bash
|
| 165 |
+
curl -X POST "http://localhost:8000/ingest" \
|
| 166 |
+
-F "file=@new-document.pdf"
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
## Configuration
|
| 170 |
+
|
| 171 |
+
Key settings in `ingest.py`:
|
| 172 |
+
- `CHUNK_SIZE = 500` - Tokens per chunk
|
| 173 |
+
- `CHUNK_OVERLAP = 50` - Overlap between chunks
|
| 174 |
+
- `DATA_DIR` - PDF source directory (`../docs`)
|
| 175 |
+
- `DB_DIR` - ChromaDB storage (`./db`)
|
| 176 |
+
|
| 177 |
+
## Components
|
| 178 |
+
|
| 179 |
+
### Data Ingestion (`ingest.py`)
|
| 180 |
+
|
| 181 |
+
1. Extracts text from PDFs using pdfplumber
|
| 182 |
+
2. Chunks into ~500 tokens using tiktoken
|
| 183 |
+
3. Generates embeddings with Gemini (`text-embedding-004`)
|
| 184 |
+
4. Stores in ChromaDB with metadata
|
| 185 |
+
|
| 186 |
+
### Retrieval & Answer Generation (`main.py`)
|
| 187 |
+
|
| 188 |
+
1. Converts query to embedding
|
| 189 |
+
2. Searches ChromaDB for top-K similar chunks
|
| 190 |
+
3. Sends context + question to Gemini
|
| 191 |
+
4. Returns grounded answer with sources
|
| 192 |
+
|
| 193 |
+
### Utilities (`utils.py`)
|
| 194 |
+
|
| 195 |
+
- `chunk_text()` - Split text into token-based chunks
|
| 196 |
+
- `generate_embedding()` - Create document embeddings
|
| 197 |
+
- `generate_query_embedding()` - Create query embeddings
|
| 198 |
+
- `generate_answer()` - RAG completion with Gemini
|
| 199 |
+
- `clean_text()` - Clean extracted PDF text
|
| 200 |
+
|
| 201 |
+
## Models Used
|
| 202 |
+
|
| 203 |
+
- **Embeddings**: `text-embedding-004` (768 dimensions)
|
| 204 |
+
- **Generation**: `gemini-2.5-flash` (default, fast)
  - Can also use `gemini-2.5-pro` for complex reasoning
|
| 206 |
+
|
| 207 |
+
## Security Considerations
|
| 208 |
+
|
| 209 |
+
- API keys loaded via environment variables
|
| 210 |
+
- Input validation on all endpoints
|
| 211 |
+
- CORS middleware configured (restrict in production)
|
| 212 |
+
- Consider adding JWT authentication for production
|
| 213 |
+
|
| 214 |
+
## Troubleshooting
|
| 215 |
+
|
| 216 |
+
**"GEMINI_API_KEY not set"**
|
| 217 |
+
```bash
|
| 218 |
+
export GEMINI_API_KEY='your-key'
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
**"No documents indexed"**
|
| 222 |
+
```bash
|
| 223 |
+
python ingest.py
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
**"Error extracting text"**
|
| 227 |
+
- Check if PDF is not corrupted
|
| 228 |
+
- Some PDFs may be image-based (need OCR)
|
| 229 |
+
|
| 230 |
+
**Slow ingestion**
|
| 231 |
+
- Embedding generation is batched (100 texts at a time)
|
| 232 |
+
- Large PDFs with many pages take longer
|
| 233 |
+
|
| 234 |
+
## Future Improvements
|
| 235 |
+
|
| 236 |
+
- [ ] Admin dashboard for document management
|
| 237 |
+
- [ ] Streaming responses
|
| 238 |
+
- [ ] Multi-collection support
|
| 239 |
+
- [ ] Document summaries
|
| 240 |
+
- [ ] Caching layer for frequent queries
|
| 241 |
+
- [ ] OCR support for scanned PDFs
|
| 242 |
+
- [ ] JWT authentication
|
| 243 |
+
|
app.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import hashlib
|
| 4 |
+
import uuid
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from contextlib import asynccontextmanager
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File, Request, Depends, Form
|
| 10 |
+
from typing import Optional
|
| 11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from fastapi.security import APIKeyHeader
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
from pinecone import Pinecone
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
|
| 17 |
+
load_dotenv("rag/.env")
|
| 18 |
+
|
| 19 |
+
from rag.utils import (
|
| 20 |
+
get_gemini_client,
|
| 21 |
+
generate_query_embedding,
|
| 22 |
+
generate_answer
|
| 23 |
+
)
|
| 24 |
+
from rag.ingest import (
|
| 25 |
+
get_pinecone_client,
|
| 26 |
+
get_pinecone_index,
|
| 27 |
+
ingest_single_pdf,
|
| 28 |
+
PINECONE_INDEX,
|
| 29 |
+
DATA_DIR
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
API_KEY = os.environ.get("API_KEY")
|
| 33 |
+
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "30"))
|
| 34 |
+
RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60"))
|
| 35 |
+
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
|
| 36 |
+
|
| 37 |
+
gemini_client = None
|
| 38 |
+
pinecone_index = None
|
| 39 |
+
rate_limit_store = defaultdict(list)
|
| 40 |
+
conversation_sessions = defaultdict(list)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_client_ip(request: Request) -> str:
    """Best-effort client IP for rate limiting.

    Prefers the first hop of ``X-Forwarded-For``; falls back to the socket
    peer, or ``"unknown"`` when no client info is available.

    NOTE(review): X-Forwarded-For is client-controlled unless a trusted
    reverse proxy sets it — confirm the deployment sits behind one before
    relying on this for rate limiting.
    """
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        return forwarded_for.split(",")[0].strip()
    if request.client:
        return request.client.host
    return "unknown"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def check_rate_limit(request: Request):
    """Sliding-window rate limiter keyed by client IP.

    Prunes timestamps older than RATE_LIMIT_WINDOW seconds, raises 429 when
    the caller already has RATE_LIMIT_REQUESTS hits in the window, and
    records the current request otherwise.
    """
    ip = get_client_ip(request)
    now = time.time()

    # Keep only hits still inside the current window.
    recent = [hit for hit in rate_limit_store[ip] if now - hit < RATE_LIMIT_WINDOW]
    rate_limit_store[ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Max {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds."
        )

    # `recent` is the same list object stored above, so this records the hit.
    recent.append(now)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
async def verify_api_key(api_key: str = Depends(api_key_header)):
    """Validate the ``X-API-Key`` header against the configured API_KEY.

    Fix: the original compared the secret with ``!=``, which short-circuits
    on the first differing byte and leaks key prefixes via response timing.
    ``hmac.compare_digest`` compares in constant time.

    When no API_KEY is configured the check is intentionally a no-op
    (matches the startup warning that the API is unprotected).

    Raises:
        HTTPException: 403 when a key is configured and the header is
            missing or does not match.
    """
    if API_KEY:
        import hmac  # local import keeps this fix self-contained

        if api_key is None or not hmac.compare_digest(str(api_key), API_KEY):
            raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: wire up Gemini and Pinecone at startup.

    Initialization failures are reported as warnings instead of raised, so
    the server can still boot in a degraded state — the endpoints return 503
    until the corresponding client is configured.
    """
    global gemini_client, pinecone_index

    print("Starting Nigerian Tax Law RAG API...")

    if API_KEY:
        print("API Key authentication enabled")
    else:
        print("Warning: No API_KEY set - API is unprotected")

    # Gemini: only a missing/invalid key raises ValueError here.
    try:
        gemini_client = get_gemini_client()
        print("Gemini client initialized")
    except ValueError as e:
        print(f"Warning: {e}")

    # Pinecone: any connection/config problem leaves pinecone_index as None.
    try:
        pinecone_index = get_pinecone_index()
        stats = pinecone_index.describe_index_stats()
        print(f"Pinecone initialized ({stats.total_vector_count} vectors)")
    except Exception as e:
        print(f"Warning: Pinecone error: {e}")

    yield

    print("Shutting down RAG API...")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
app = FastAPI(
|
| 107 |
+
title="Nigerian Tax Law RAG API",
|
| 108 |
+
description="Query Nigerian tax laws and legal documents using AI-powered retrieval",
|
| 109 |
+
version="1.0.0",
|
| 110 |
+
lifespan=lifespan
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
app.add_middleware(
|
| 114 |
+
CORSMiddleware,
|
| 115 |
+
allow_origins=ALLOWED_ORIGINS,
|
| 116 |
+
allow_credentials=True,
|
| 117 |
+
allow_methods=["GET", "POST"],
|
| 118 |
+
allow_headers=["*"],
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class AskResponse(BaseModel):
|
| 125 |
+
answer: str
|
| 126 |
+
sources: list[dict]
|
| 127 |
+
chunks_used: int
|
| 128 |
+
session_id: str
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class IngestResponse(BaseModel):
|
| 132 |
+
message: str
|
| 133 |
+
filename: str
|
| 134 |
+
chunks_added: int
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class StatsResponse(BaseModel):
|
| 138 |
+
total_vectors: int
|
| 139 |
+
dimension: int
|
| 140 |
+
index_name: str
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class HealthResponse(BaseModel):
|
| 144 |
+
status: str
|
| 145 |
+
gemini_connected: bool
|
| 146 |
+
pinecone_connected: bool
|
| 147 |
+
vectors_indexed: int
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@app.get("/", response_model=dict)
|
| 151 |
+
async def root():
|
| 152 |
+
return {
|
| 153 |
+
"name": "Nigerian Tax Law RAG API",
|
| 154 |
+
"version": "1.0.0",
|
| 155 |
+
"endpoints": {
|
| 156 |
+
"POST /ask": "Ask a question about Nigerian tax law",
|
| 157 |
+
"POST /ingest": "Upload and index a new PDF document",
|
| 158 |
+
"GET /stats": "Get database statistics",
|
| 159 |
+
"GET /health": "Health check"
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@app.get("/health", response_model=HealthResponse)
|
| 165 |
+
async def health_check():
|
| 166 |
+
gemini_ok = gemini_client is not None
|
| 167 |
+
pinecone_ok = pinecone_index is not None
|
| 168 |
+
vectors = 0
|
| 169 |
+
|
| 170 |
+
if pinecone_ok:
|
| 171 |
+
try:
|
| 172 |
+
stats = pinecone_index.describe_index_stats()
|
| 173 |
+
vectors = stats.total_vector_count
|
| 174 |
+
except:
|
| 175 |
+
pinecone_ok = False
|
| 176 |
+
|
| 177 |
+
return HealthResponse(
|
| 178 |
+
status="healthy" if (gemini_ok and pinecone_ok) else "degraded",
|
| 179 |
+
gemini_connected=gemini_ok,
|
| 180 |
+
pinecone_connected=pinecone_ok,
|
| 181 |
+
vectors_indexed=vectors
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@app.post("/ask", response_model=AskResponse)
|
| 186 |
+
async def ask_question(
|
| 187 |
+
req: Request,
|
| 188 |
+
question: str = Form(..., min_length=3, max_length=2000),
|
| 189 |
+
top_k: int = Form(default=5, ge=1, le=20),
|
| 190 |
+
model: str = Form(default="gemini-2.5-flash"),
|
| 191 |
+
session_id: Optional[str] = Form(default=None),
|
| 192 |
+
image: Optional[UploadFile] = File(default=None),
|
| 193 |
+
api_key: str = Depends(verify_api_key)
|
| 194 |
+
):
|
| 195 |
+
check_rate_limit(req)
|
| 196 |
+
|
| 197 |
+
if gemini_client is None:
|
| 198 |
+
raise HTTPException(
|
| 199 |
+
status_code=503,
|
| 200 |
+
detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
if pinecone_index is None:
|
| 204 |
+
raise HTTPException(status_code=503, detail="Pinecone not initialized.")
|
| 205 |
+
|
| 206 |
+
if not session_id:
|
| 207 |
+
session_id = str(uuid.uuid4())
|
| 208 |
+
|
| 209 |
+
image_data = None
|
| 210 |
+
image_mime_type = None
|
| 211 |
+
|
| 212 |
+
if image and image.filename:
|
| 213 |
+
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
| 214 |
+
if image.content_type not in allowed_types:
|
| 215 |
+
raise HTTPException(
|
| 216 |
+
status_code=400,
|
| 217 |
+
detail=f"Invalid image type. Allowed: {', '.join(allowed_types)}"
|
| 218 |
+
)
|
| 219 |
+
if image.size and image.size > 10 * 1024 * 1024:
|
| 220 |
+
raise HTTPException(status_code=400, detail="Image too large. Max 10MB.")
|
| 221 |
+
|
| 222 |
+
image_data = await image.read()
|
| 223 |
+
image_mime_type = image.content_type
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
query_embedding = generate_query_embedding(gemini_client, question)
|
| 227 |
+
except Exception as e:
|
| 228 |
+
raise HTTPException(status_code=500, detail=f"Error generating query embedding: {str(e)}")
|
| 229 |
+
|
| 230 |
+
try:
|
| 231 |
+
results = pinecone_index.query(
|
| 232 |
+
vector=query_embedding,
|
| 233 |
+
top_k=top_k,
|
| 234 |
+
include_metadata=True
|
| 235 |
+
)
|
| 236 |
+
except Exception as e:
|
| 237 |
+
raise HTTPException(status_code=500, detail=f"Error querying Pinecone: {str(e)}")
|
| 238 |
+
|
| 239 |
+
if not results.matches:
|
| 240 |
+
conversation_sessions[session_id].append({"role": "user", "content": question})
|
| 241 |
+
conversation_sessions[session_id].append({"role": "assistant", "content": "I couldn't find any relevant information in the indexed documents."})
|
| 242 |
+
|
| 243 |
+
return AskResponse(
|
| 244 |
+
answer="I couldn't find any relevant information in the indexed documents.",
|
| 245 |
+
sources=[],
|
| 246 |
+
chunks_used=0,
|
| 247 |
+
session_id=session_id
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
context_parts = []
|
| 251 |
+
sources = []
|
| 252 |
+
|
| 253 |
+
for match in results.matches:
|
| 254 |
+
meta = match.metadata
|
| 255 |
+
source_name = meta.get("source", "Unknown")
|
| 256 |
+
chunk_idx = meta.get("chunk_index", 0)
|
| 257 |
+
text = meta.get("text", "")
|
| 258 |
+
|
| 259 |
+
context_parts.append(f"[Source: {source_name}, Chunk {chunk_idx + 1}]\n{text}")
|
| 260 |
+
sources.append({
|
| 261 |
+
"document": source_name,
|
| 262 |
+
"chunk_index": chunk_idx,
|
| 263 |
+
"relevance_score": round(match.score, 4)
|
| 264 |
+
})
|
| 265 |
+
|
| 266 |
+
context = "\n\n---\n\n".join(context_parts)
|
| 267 |
+
|
| 268 |
+
conversation_history = conversation_sessions.get(session_id, [])
|
| 269 |
+
|
| 270 |
+
try:
|
| 271 |
+
answer = generate_answer(
|
| 272 |
+
gemini_client,
|
| 273 |
+
question,
|
| 274 |
+
context,
|
| 275 |
+
model=model,
|
| 276 |
+
image_data=image_data,
|
| 277 |
+
image_mime_type=image_mime_type,
|
| 278 |
+
conversation_history=conversation_history
|
| 279 |
+
)
|
| 280 |
+
except Exception as e:
|
| 281 |
+
error_msg = str(e)
|
| 282 |
+
if "overloaded" in error_msg.lower() or "503" in error_msg:
|
| 283 |
+
raise HTTPException(status_code=503, detail=error_msg)
|
| 284 |
+
raise HTTPException(status_code=500, detail=f"Error generating answer: {error_msg}")
|
| 285 |
+
|
| 286 |
+
conversation_sessions[session_id].append({"role": "user", "content": question})
|
| 287 |
+
conversation_sessions[session_id].append({"role": "assistant", "content": answer})
|
| 288 |
+
|
| 289 |
+
if len(conversation_sessions[session_id]) > 20:
|
| 290 |
+
conversation_sessions[session_id] = conversation_sessions[session_id][-20:]
|
| 291 |
+
|
| 292 |
+
return AskResponse(
|
| 293 |
+
answer=answer,
|
| 294 |
+
sources=sources,
|
| 295 |
+
chunks_used=len(results.matches),
|
| 296 |
+
session_id=session_id
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@app.post("/ingest", response_model=IngestResponse)
|
| 301 |
+
async def ingest_document(
|
| 302 |
+
req: Request,
|
| 303 |
+
file: UploadFile = File(...),
|
| 304 |
+
force: bool = False,
|
| 305 |
+
api_key: str = Depends(verify_api_key)
|
| 306 |
+
):
|
| 307 |
+
check_rate_limit(req)
|
| 308 |
+
|
| 309 |
+
if gemini_client is None:
|
| 310 |
+
raise HTTPException(
|
| 311 |
+
status_code=503,
|
| 312 |
+
detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if pinecone_index is None:
|
| 316 |
+
raise HTTPException(status_code=503, detail="Pinecone not initialized.")
|
| 317 |
+
|
| 318 |
+
if not file.filename.lower().endswith(".pdf"):
|
| 319 |
+
raise HTTPException(status_code=400, detail="Only PDF files are supported.")
|
| 320 |
+
|
| 321 |
+
if file.size and file.size > 50 * 1024 * 1024:
|
| 322 |
+
raise HTTPException(status_code=400, detail="File too large. Max 50MB.")
|
| 323 |
+
|
| 324 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 325 |
+
safe_filename = "".join(c for c in file.filename if c.isalnum() or c in "._- ")
|
| 326 |
+
file_path = DATA_DIR / safe_filename
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
contents = await file.read()
|
| 330 |
+
with open(file_path, "wb") as f:
|
| 331 |
+
f.write(contents)
|
| 332 |
+
except Exception as e:
|
| 333 |
+
raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}")
|
| 334 |
+
|
| 335 |
+
try:
|
| 336 |
+
chunks_added, _ = ingest_single_pdf(
|
| 337 |
+
file_path,
|
| 338 |
+
pinecone_index,
|
| 339 |
+
gemini_client,
|
| 340 |
+
force=force
|
| 341 |
+
)
|
| 342 |
+
except Exception as e:
|
| 343 |
+
raise HTTPException(status_code=500, detail=f"Error ingesting document: {str(e)}")
|
| 344 |
+
|
| 345 |
+
return IngestResponse(
|
| 346 |
+
message="Document ingested successfully" if chunks_added > 0 else "Document already exists",
|
| 347 |
+
filename=safe_filename,
|
| 348 |
+
chunks_added=chunks_added
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
@app.get("/stats", response_model=StatsResponse)
|
| 353 |
+
async def get_stats(api_key: str = Depends(verify_api_key)):
|
| 354 |
+
if pinecone_index is None:
|
| 355 |
+
raise HTTPException(status_code=503, detail="Pinecone not initialized.")
|
| 356 |
+
|
| 357 |
+
try:
|
| 358 |
+
stats = pinecone_index.describe_index_stats()
|
| 359 |
+
return StatsResponse(
|
| 360 |
+
total_vectors=stats.total_vector_count,
|
| 361 |
+
dimension=stats.dimension,
|
| 362 |
+
index_name=PINECONE_INDEX
|
| 363 |
+
)
|
| 364 |
+
except Exception as e:
|
| 365 |
+
raise HTTPException(status_code=500, detail=f"Error getting stats: {str(e)}")
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
|
| 369 |
+
import uvicorn
|
| 370 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
|
dockerfile
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
ARG HF_TOKEN
|
| 4 |
+
|
| 5 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 6 |
+
PYTHONUNBUFFERED=1 \
|
| 7 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 8 |
+
HF_TOKEN=${HF_TOKEN}
|
| 9 |
+
|
| 10 |
+
WORKDIR /code
|
| 11 |
+
|
| 12 |
+
# System Dependencies
|
| 13 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 14 |
+
build-essential \
|
| 15 |
+
git \
|
| 16 |
+
curl \
|
| 17 |
+
libopenblas-dev \
|
| 18 |
+
libomp-dev \
|
| 19 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 20 |
+
|
| 21 |
+
# Copy requirements and install Python dependencies
|
| 22 |
+
COPY requirements.txt .
|
| 23 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 24 |
+
|
| 25 |
+
# Hugging Face dependencies
|
| 26 |
+
RUN pip install --no-cache-dir huggingface-hub sentencepiece
|
| 27 |
+
|
| 28 |
+
# Hugging Face cache environment
|
| 29 |
+
ENV HF_HOME=/data/huggingface \
|
| 30 |
+
HUGGINGFACE_HUB_CACHE=/data/huggingface \
|
| 31 |
+
HF_HUB_CACHE=/data/huggingface \
|
| 32 |
+
API_PORT=7860
|
| 33 |
+
|
| 34 |
+
# Create cache dir and set permissions
|
| 35 |
+
RUN mkdir -p /data/huggingface && chmod -R 777 /data
|
| 36 |
+
|
| 37 |
+
# Copy project files
|
| 38 |
+
COPY . .
|
| 39 |
+
|
| 40 |
+
EXPOSE 7860
|
| 41 |
+
|
| 42 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
rag/.env
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SECURITY: real credentials were committed in this file and must be
# considered compromised — rotate the Gemini and Pinecone keys immediately,
# add rag/.env to .gitignore, and inject secrets at deploy time instead.
GEMINI_API_KEY=REPLACE_WITH_ROTATED_KEY
PINECONE_API_KEY=REPLACE_WITH_ROTATED_KEY
PINECONE_INDEX=sabitax

# Security
API_KEY=REPLACE_WITH_ROTATED_KEY
RATE_LIMIT_REQUESTS=30
RATE_LIMIT_WINDOW=60
ALLOWED_ORIGINS=*
|
rag/__init__.py
ADDED
|
File without changes
|
rag/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (135 Bytes). View file
|
|
|
rag/__pycache__/ingest.cpython-312.pyc
ADDED
|
Binary file (8.75 kB). View file
|
|
|
rag/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
rag/ingest.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from hashlib import md5
|
| 4 |
+
|
| 5 |
+
import pdfplumber
|
| 6 |
+
from pinecone import Pinecone
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from .utils import (
|
| 10 |
+
get_gemini_client,
|
| 11 |
+
chunk_text,
|
| 12 |
+
clean_text,
|
| 13 |
+
generate_batch_embeddings,
|
| 14 |
+
count_tokens
|
| 15 |
+
)
|
| 16 |
+
except ImportError:
|
| 17 |
+
from utils import (
|
| 18 |
+
get_gemini_client,
|
| 19 |
+
chunk_text,
|
| 20 |
+
clean_text,
|
| 21 |
+
generate_batch_embeddings,
|
| 22 |
+
count_tokens
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
DATA_DIR = Path(__file__).parent.parent / "docs"
|
| 27 |
+
PINECONE_INDEX = os.environ.get("PINECONE_INDEX", "sabitax")
|
| 28 |
+
CHUNK_SIZE = 500
|
| 29 |
+
CHUNK_OVERLAP = 50
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_pinecone_client():
    """Build a Pinecone client from the PINECONE_API_KEY environment variable.

    Raises:
        ValueError: when the key is not configured.
    """
    api_key = os.environ.get("PINECONE_API_KEY")
    if api_key:
        return Pinecone(api_key=api_key)
    raise ValueError("PINECONE_API_KEY environment variable is not set.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_pinecone_index(pc=None):
    """Return a handle to the configured index.

    Args:
        pc: optional pre-built Pinecone client; a new one is created from
            the environment when omitted.
    """
    if pc is None:
        pc = get_pinecone_client()
    return pc.Index(PINECONE_INDEX)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def extract_text_from_pdf(pdf_path: Path) -> str:
    """Extract all page text from a PDF, tag each page, and clean the result.

    Returns an empty string on any extraction failure (e.g. corrupt or
    image-only PDFs — no OCR is attempted; the error is printed, not raised).
    """
    pages = []
    try:
        with pdfplumber.open(pdf_path) as pdf:
            for page_num, page in enumerate(pdf.pages, 1):
                content = page.extract_text()
                if content:
                    # Page markers let answers cite a page, not just a chunk.
                    pages.append(f"[Page {page_num}]\n{content}")
    except Exception as e:
        print(f" Error extracting text from {pdf_path.name}: {e}")
        return ""

    return clean_text("\n\n".join(pages))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def generate_chunk_id(doc_name: str, chunk_index: int) -> str:
    """Deterministic vector ID for one chunk: md5 hex of "<doc>_<index>".

    md5 is used purely as a stable, collision-unlikely ID scheme here,
    not for any security purpose.
    """
    return md5(f"{doc_name}_{chunk_index}".encode()).hexdigest()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def ingest_single_pdf(
    pdf_path: Path,
    index,
    gemini_client,
    force: bool = False
) -> tuple[int, int]:
    """Chunk, embed, and upsert one PDF into the Pinecone index.

    Args:
        pdf_path: PDF file to ingest.
        index: Pinecone index handle (see get_pinecone_index).
        gemini_client: Gemini client used for embedding generation.
        force: re-ingest even if the document already appears in the index.

    Returns:
        (chunks_added, documents_skipped) — documents_skipped is 1 when the
        document was already present and force is False, otherwise 0.
    """
    doc_name = pdf_path.name

    # Cheap dedup probe: chunk IDs are deterministic, so the presence of
    # chunk 0 means this document was ingested before.
    if not force:
        test_id = generate_chunk_id(doc_name, 0)
        result = index.fetch(ids=[test_id])
        if result.vectors:
            print(f" Skipping {doc_name} (already ingested)")
            return 0, 1

    print(f" Processing: {doc_name}")

    text = extract_text_from_pdf(pdf_path)
    if not text:
        print(f" No text extracted from {doc_name}")
        return 0, 0

    total_tokens = count_tokens(text)
    print(f" Extracted {total_tokens:,} tokens")

    chunks = chunk_text(text, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    print(f" Created {len(chunks)} chunks")

    if not chunks:
        return 0, 0

    print(f" Generating embeddings...")
    embeddings = generate_batch_embeddings(gemini_client, chunks)

    # Build one Pinecone record per chunk; only a 1,000-character preview of
    # the chunk text is kept (presumably to bound metadata size — confirm).
    vectors = []
    for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
        vectors.append({
            "id": generate_chunk_id(doc_name, i),
            "values": embedding,
            "metadata": {
                "source": doc_name,
                "chunk_index": i,
                "total_chunks": len(chunks),
                "text": chunk[:1000]
            }
        })

    # Upsert in batches of 100 to keep individual requests small.
    batch_size = 100
    for i in range(0, len(vectors), batch_size):
        batch = vectors[i:i + batch_size]
        index.upsert(vectors=batch)

    print(f" Added {len(chunks)} chunks to Pinecone")
    return len(chunks), 0
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def ingest_all_documents(data_dir: Path = DATA_DIR, force: bool = False):
    """Ingest every *.pdf under data_dir into Pinecone, printing progress.

    Args:
        data_dir: directory scanned (non-recursively) for PDF files.
        force: when True, re-ingest documents that are already indexed.
    """
    print("\nStarting document ingestion pipeline\n")
    print(f"Data directory: {data_dir}")
    print(f"Pinecone index: {PINECONE_INDEX}\n")

    pdf_files = list(data_dir.glob("*.pdf"))

    if not pdf_files:
        print(f"No PDF files found in {data_dir}")
        return

    print(f"Found {len(pdf_files)} PDF files\n")

    print("Connecting to Gemini API...")
    gemini_client = get_gemini_client()

    print("Connecting to Pinecone...")
    index = get_pinecone_index()
    stats = index.describe_index_stats()
    print(f"Current index size: {stats.total_vector_count} vectors\n")
    print("-" * 60)

    total_added = 0
    total_skipped = 0

    # Sorted order keeps the console log deterministic across runs.
    for pdf_path in sorted(pdf_files):
        added, skipped = ingest_single_pdf(
            pdf_path,
            index,
            gemini_client,
            force=force
        )
        total_added += added
        total_skipped += skipped

    print("-" * 60)
    # Re-read stats so the final count reflects the upserts above.
    stats = index.describe_index_stats()
    print(f"\nIngestion complete!")
    print(f" Chunks added: {total_added}")
    print(f" Documents skipped: {total_skipped}")
    print(f" Total index size: {stats.total_vector_count} vectors\n")
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def clear_index():
    """Delete every vector in the Pinecone index (best-effort; errors printed)."""
    print("Clearing Pinecone index...")
    try:
        get_pinecone_index().delete(delete_all=True)
    except Exception as e:
        print(f"Error clearing index: {e}")
    else:
        print("Index cleared successfully")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def show_stats():
    """Print index name, vector count, and dimensionality to the console."""
    print("\nPinecone Index Statistics\n")

    try:
        summary = get_pinecone_index().describe_index_stats()
    except Exception as e:
        print(f" Error: {e}")
    else:
        print(f" Index: {PINECONE_INDEX}")
        print(f" Total vectors: {summary.total_vector_count}")
        print(f" Dimensions: {summary.dimension}")

    print()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
    # CLI entry point: --stats prints index statistics, --clear wipes the
    # index and re-ingests everything, otherwise a normal ingest runs.
    import argparse
    from dotenv import load_dotenv
    load_dotenv()

    parser = argparse.ArgumentParser(description="Ingest PDF documents into Pinecone for RAG")
    parser.add_argument("--force", "-f", action="store_true")
    parser.add_argument("--clear", action="store_true")
    parser.add_argument("--stats", action="store_true")
    parser.add_argument("--data-dir", type=Path, default=DATA_DIR)

    args = parser.parse_args()

    if args.stats:
        show_stats()
    elif args.clear:
        # Fix: the old inner "if not args.stats:" guard was redundant —
        # args.stats is always False in this branch. A full forced re-ingest
        # always follows a clear.
        clear_index()
        ingest_all_documents(data_dir=args.data_dir, force=True)
    else:
        ingest_all_documents(data_dir=args.data_dir, force=args.force)
|
rag/main.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from contextlib import asynccontextmanager
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File, BackgroundTasks
|
| 8 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
import chromadb
|
| 11 |
+
|
| 12 |
+
from utils import (
|
| 13 |
+
get_gemini_client,
|
| 14 |
+
generate_query_embedding,
|
| 15 |
+
generate_answer
|
| 16 |
+
)
|
| 17 |
+
from ingest import (
|
| 18 |
+
get_chroma_client,
|
| 19 |
+
get_or_create_collection,
|
| 20 |
+
ingest_single_pdf,
|
| 21 |
+
COLLECTION_NAME,
|
| 22 |
+
DATA_DIR
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
gemini_client = None
|
| 27 |
+
chroma_collection = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: set up shared clients at startup, log at shutdown.

    Populates the module-level gemini_client and chroma_collection globals
    that the route handlers read.
    """
    global gemini_client, chroma_collection

    print("Starting Nigerian Tax Law RAG API...")

    # Gemini is optional at startup: the app boots in a degraded state and
    # /ask and /ingest return 503 until GEMINI_API_KEY is configured.
    try:
        gemini_client = get_gemini_client()
        print("Gemini client initialized")
    except ValueError as e:
        print(f"Warning: {e}")
        print("The API will not work until GEMINI_API_KEY is set.")

    # NOTE(review): get_chroma_client / get_or_create_collection are imported
    # from ingest, but the sibling ingest.py in this package is Pinecone-based
    # and does not define them — confirm which vector backend is intended.
    chroma_client = get_chroma_client()
    chroma_collection = get_or_create_collection(chroma_client)
    print(f"ChromaDB initialized ({chroma_collection.count()} chunks indexed)")

    yield

    print("Shutting down RAG API...")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Application object; the lifespan hook wires up the Gemini and vector-store
# clients before the first request is served.
app = FastAPI(
    title="Nigerian Tax Law RAG API",
    description="Query Nigerian tax laws and legal documents using AI-powered retrieval",
    version="1.0.0",
    lifespan=lifespan
)

# NOTE(review): wildcard CORS origins combined with allow_credentials=True is
# very permissive — tighten allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class AskRequest(BaseModel):
    """Body of POST /ask."""
    # Natural-language question; length bounds keep prompts a sane size.
    question: str = Field(..., min_length=3, max_length=2000)
    # Number of chunks to retrieve from the vector store.
    top_k: int = Field(default=5, ge=1, le=20)
    # Gemini model used for answer generation.
    # NOTE(review): default here is gemini-2.0-flash while generate_answer's
    # own default is gemini-2.5-flash — confirm which is intended.
    model: str = Field(default="gemini-2.0-flash")


class AskResponse(BaseModel):
    """Answer plus provenance returned by POST /ask."""
    answer: str
    # One {document, chunk_index, relevance_score} dict per retrieved chunk.
    sources: list[dict]
    chunks_used: int


class IngestResponse(BaseModel):
    """Outcome of POST /ingest for a single uploaded PDF."""
    message: str
    filename: str
    # 0 when the document was already indexed and force was not set.
    chunks_added: int


class StatsResponse(BaseModel):
    """Vector-store summary returned by GET /stats."""
    total_chunks: int
    total_documents: int
    # One {name, chunks} dict per indexed document.
    documents: list[dict]


class HealthResponse(BaseModel):
    """GET /health payload: "healthy" only when both backends are up."""
    status: str
    gemini_connected: bool
    chroma_connected: bool
    chunks_indexed: int
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@app.get("/", response_model=dict)
async def root():
    """Service metadata and a human-readable map of available endpoints."""
    endpoint_map = {
        "POST /ask": "Ask a question about Nigerian tax law",
        "POST /ingest": "Upload and index a new PDF document",
        "GET /stats": "Get database statistics",
        "GET /health": "Health check"
    }
    return {
        "name": "Nigerian Tax Law RAG API",
        "version": "1.0.0",
        "endpoints": endpoint_map
    }
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report liveness of the Gemini client and the vector store."""
    has_gemini = gemini_client is not None
    has_chroma = chroma_collection is not None
    indexed = chroma_collection.count() if has_chroma else 0
    overall = "healthy" if has_gemini and has_chroma else "degraded"

    return HealthResponse(
        status=overall,
        gemini_connected=has_gemini,
        chroma_connected=has_chroma,
        chunks_indexed=indexed,
    )
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@app.post("/ask", response_model=AskResponse)
async def ask_question(request: AskRequest):
    """Answer a question via RAG: embed the query, retrieve the top_k most
    similar chunks from the vector store, and generate a grounded answer.

    Raises:
        HTTPException 503: Gemini or the vector store is not initialized.
        HTTPException 404: the index is empty.
        HTTPException 500: embedding, retrieval, or generation failed.
    """
    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )

    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    if chroma_collection.count() == 0:
        raise HTTPException(
            status_code=404,
            detail="No documents indexed. Please ingest documents first using: python ingest.py"
        )

    try:
        query_embedding = generate_query_embedding(gemini_client, request.question)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating query embedding: {str(e)}")

    try:
        results = chroma_collection.query(
            query_embeddings=[query_embedding],
            n_results=request.top_k,
            include=["documents", "metadatas", "distances"]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying vector database: {str(e)}")

    # Chroma returns one result list per query embedding; we sent exactly one.
    documents = results["documents"][0] if results["documents"] else []
    metadatas = results["metadatas"][0] if results["metadatas"] else []
    distances = results["distances"][0] if results["distances"] else []

    if not documents:
        return AskResponse(
            answer="I couldn't find any relevant information in the indexed documents.",
            sources=[],
            chunks_used=0
        )

    context_parts = []
    sources = []

    # Fix: the previous enumerate() index was unused — iterate the triples
    # directly.
    for doc, meta, dist in zip(documents, metadatas, distances):
        source_name = meta.get("source", "Unknown")
        chunk_idx = meta.get("chunk_index", 0)

        context_parts.append(f"[Source: {source_name}, Chunk {chunk_idx + 1}]\n{doc}")
        sources.append({
            "document": source_name,
            "chunk_index": chunk_idx,
            # 1 - distance as a similarity proxy — assumes a distance metric
            # normalized to [0, 1] (e.g. cosine); TODO confirm the
            # collection's configured metric.
            "relevance_score": round(1 - dist, 4)
        })

    context = "\n\n---\n\n".join(context_parts)

    try:
        answer = generate_answer(
            gemini_client,
            request.question,
            context,
            model=request.model
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")

    return AskResponse(
        answer=answer,
        sources=sources,
        chunks_used=len(documents)
    )
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@app.post("/ingest", response_model=IngestResponse)
async def ingest_document(file: UploadFile = File(...), force: bool = False):
    """Upload a PDF, persist it under DATA_DIR, and index its chunks.

    Args:
        file: uploaded PDF (rejected with 400 if the name is not *.pdf).
        force: re-ingest even when the document is already indexed.

    Raises:
        HTTPException 503: Gemini is not configured.
        HTTPException 400: non-PDF upload.
        HTTPException 500: save or ingestion failure.
    """
    if gemini_client is None:
        raise HTTPException(
            status_code=503,
            detail="Gemini API not configured. Set GEMINI_API_KEY environment variable."
        )

    if not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    DATA_DIR.mkdir(parents=True, exist_ok=True)
    # Security fix: keep only the base name of the client-supplied filename so
    # a crafted name like "../../evil.pdf" cannot escape DATA_DIR.
    safe_name = Path(file.filename).name
    file_path = DATA_DIR / safe_name

    try:
        contents = await file.read()
        with open(file_path, "wb") as f:
            f.write(contents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving file: {str(e)}")

    try:
        chunks_added, _ = ingest_single_pdf(
            file_path,
            chroma_collection,
            gemini_client,
            force=force
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error ingesting document: {str(e)}")

    return IngestResponse(
        message="Document ingested successfully" if chunks_added > 0 else "Document already exists",
        filename=file.filename,
        chunks_added=chunks_added
    )
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Summarize the vector store: total chunks plus a per-document breakdown."""
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    total = chroma_collection.count()

    if total == 0:
        return StatsResponse(total_chunks=0, total_documents=0, documents=[])

    records = chroma_collection.get(limit=total, include=["metadatas"])

    # Tally chunk counts per source document.
    per_doc = {}
    for entry in records["metadatas"]:
        if entry:
            name = entry.get("source", "Unknown")
            per_doc[name] = per_doc.get(name, 0) + 1

    breakdown = [
        {"name": name, "chunks": chunk_count}
        for name, chunk_count in sorted(per_doc.items())
    ]

    return StatsResponse(
        total_chunks=total,
        total_documents=len(per_doc),
        documents=breakdown
    )
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@app.delete("/documents/{document_name}")
async def delete_document(document_name: str):
    """Remove every chunk belonging to one source document from the index.

    Raises:
        HTTPException 503: vector store not initialized.
        HTTPException 404: no chunks matched the document name.
    """
    if chroma_collection is None:
        raise HTTPException(status_code=503, detail="Vector database not initialized.")

    matches = chroma_collection.get(
        where={"source": document_name},
        include=["metadatas"]
    )
    matched_ids = matches["ids"]

    if not matched_ids:
        raise HTTPException(status_code=404, detail=f"Document '{document_name}' not found in index.")

    chroma_collection.delete(ids=matched_ids)

    return {
        "message": f"Document '{document_name}' deleted successfully",
        "chunks_deleted": len(matched_ids)
    }
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
if __name__ == "__main__":
    # Dev entry point: serve on all interfaces; reload=True restarts the
    # worker on code changes (do not use in production).
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
rag/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
python-multipart
|
| 4 |
+
pdfplumber
|
| 5 |
+
pinecone
|
| 6 |
+
tiktoken
|
| 7 |
+
google-genai
|
| 8 |
+
pydantic
|
| 9 |
+
python-dotenv
|
| 10 |
+
Pillow
|
rag/utils.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import io
|
| 4 |
+
import time
|
| 5 |
+
import tiktoken
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from google import genai
|
| 8 |
+
from google.genai import types
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_gemini_client():
    """Build a Gemini client from the GEMINI_API_KEY environment variable.

    Returns:
        A configured genai.Client.

    Raises:
        ValueError: if GEMINI_API_KEY is missing or empty.
    """
    key = os.environ.get("GEMINI_API_KEY")
    if key:
        return genai.Client(api_key=key)
    raise ValueError(
        "GEMINI_API_KEY environment variable is not set. "
        "Please set it with: export GEMINI_API_KEY='your-api-key'"
    )
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def count_tokens(text: str, model: str = "cl100k_base") -> int:
    """Return the token count of *text* under the given tiktoken encoding."""
    return len(tiktoken.get_encoding(model).encode(text))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def chunk_text(
    text: str,
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    encoding_name: str = "cl100k_base"
) -> list[str]:
    """Split text into token-based chunks with a sliding overlap.

    Args:
        text: input text to split.
        chunk_size: maximum tokens per chunk.
        chunk_overlap: tokens shared between consecutive chunks.
        encoding_name: tiktoken encoding used to tokenize.

    Returns:
        Decoded chunk strings in document order; [] for empty input.
    """
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)

    chunks = []
    start = 0

    while start < len(tokens):
        end = start + chunk_size
        # Fix: this local was previously named chunk_text, shadowing the
        # function itself.
        piece = encoding.decode(tokens[start:end])
        chunks.append(piece)
        start = end - chunk_overlap

        # Guard against a non-advancing window when chunk_overlap >= chunk_size.
        if start <= 0 and len(chunks) > 0:
            break

    return chunks
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def generate_embedding(client: genai.Client, text: str) -> list[float]:
    """Embed one document chunk with text-embedding-004 (RETRIEVAL_DOCUMENT)."""
    response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=text,
        config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT"),
    )
    return response.embeddings[0].values
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def generate_query_embedding(client: genai.Client, query: str) -> list[float]:
    """Embed a search query with text-embedding-004 (RETRIEVAL_QUERY task)."""
    response = client.models.embed_content(
        model="models/text-embedding-004",
        contents=query,
        config=types.EmbedContentConfig(task_type="RETRIEVAL_QUERY"),
    )
    return response.embeddings[0].values
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def generate_batch_embeddings(
    client: genai.Client,
    texts: list[str],
    batch_size: int = 100
) -> list[list[float]]:
    """Embed many texts in batches, preserving input order.

    Args:
        client: Gemini client.
        texts: strings to embed as RETRIEVAL_DOCUMENT content.
        batch_size: maximum texts sent per API call.

    Returns:
        One embedding vector per input text, in the same order.
    """
    embeddings = []

    for offset in range(0, len(texts), batch_size):
        window = texts[offset:offset + batch_size]
        response = client.models.embed_content(
            model="models/text-embedding-004",
            contents=window,
            config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT"),
        )
        embeddings.extend(item.values for item in response.embeddings)

    return embeddings
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def generate_answer(
    client: genai.Client,
    question: str,
    context: str,
    model: str = "gemini-2.5-flash",
    image_data: bytes = None,
    image_mime_type: str = None,
    conversation_history: list = None
) -> str:
    """Generate a SabiTax answer with Gemini, with prompt routing and retries.

    Routing: greetings and "who are you" questions get short persona prompts;
    everything else gets the full RAG prompt built from *context* and the last
    six turns of *conversation_history*.

    Args:
        client: Gemini client.
        question: user question.
        context: retrieved document chunks for grounding.
        model: Gemini model name.
        image_data: optional raw image bytes to attach to the request.
        image_mime_type: declared MIME type of the image (currently unused
            by this function's body).
        conversation_history: list of {"role", "content"} dicts; only the
            last 6 entries are included in the prompt.

    Returns:
        The model's text response.

    Raises:
        Exception: when the service stays overloaded after all retries, or
            on any non-retryable API error.
    """
    question_lower = question.lower().strip()

    # Route 1: simple greeting — answer from persona alone, no RAG context.
    greetings = ["hello", "hi", "hey", "good morning", "good afternoon", "good evening", "greetings"]
    is_greeting = any(question_lower.startswith(g) or question_lower == g for g in greetings)

    if is_greeting:
        prompt = f"""You are SabiTax, a friendly and conversational legal and tax expert assistant specializing in Nigerian law.
The user has greeted you. Respond naturally and warmly, like you're chatting with a friend. Introduce yourself as SabiTax in a casual, friendly way, and let them know you're here to help with any questions about Nigerian tax laws.

User: {question}

Respond conversationally - be warm, natural, and brief (2-3 sentences). Use a friendly, approachable tone."""
    else:
        # Route 2: identity questions — short self-introduction prompt.
        name_questions = ["what is your name", "who are you", "what are you called", "what's your name", "tell me your name", "introduce yourself"]
        is_name_question = any(q in question_lower for q in name_questions)

        if is_name_question:
            prompt = f"""You are SabiTax, a friendly and conversational legal and tax expert assistant specializing in Nigerian law and taxation.

User: {question}

Respond naturally and conversationally. Introduce yourself as SabiTax in a friendly, casual way. Explain that you help people understand Nigerian tax laws in simple terms, like you're explaining to a friend. Keep it brief, warm, and conversational."""
        else:
            # Route 3: full RAG prompt; fold in the last 6 history turns.
            history_text = ""
            if conversation_history and len(conversation_history) > 0:
                history_text = "\n\nPrevious conversation:\n"
                for msg in conversation_history[-6:]:
                    role = "User" if msg["role"] == "user" else "You (SabiTax)"
                    history_text += f"{role}: {msg['content']}\n"
                history_text += "\n"

            prompt = f"""You are SabiTax, a friendly and conversational legal and tax expert assistant specializing in Nigerian law and taxation. You talk to users like you're having a natural conversation with a friend - warm, approachable, and easy to understand.

Your style:
- Talk naturally, like you're chatting over coffee
- Use "you" and "I" - make it personal and engaging
- Be warm and friendly, not robotic or formal
- Use everyday language and simple explanations
- Reference previous parts of the conversation when relevant: "As I mentioned earlier..." or "Building on what we discussed..."
- Ask follow-up questions if helpful: "Does that make sense?" or "Want me to explain that differently?"
- Show enthusiasm about helping: "Great question!" or "I'm happy to help with that!"

Your approach:
1. **Reason through the information**: Think about what the user really needs to know
2. **Break it down simply**: Translate complex legal stuff into everyday language
3. **Make it practical**: Focus on "what this means for you" and "what you need to do"
4. **Prioritize current info**: Always mention the most recent laws first (2025 over 2020, etc.) and note if something's been updated
5. **Continue the conversation**: If this is part of an ongoing discussion, naturally reference what was said before

Important rules:
- Answer based ONLY on the provided context from the documents
- Always prioritize the most recent/current legislation (e.g., 2025 acts over 2020 acts)
- If there's old info, mention it's been updated: "The old 2020 law has been replaced by the 2025 act..."
- Explain everything in simple terms - no legal jargon without explanation
- Use examples and analogies to make things clearer
- If you don't have enough info, say so honestly: "I don't have enough details on that, but here's what I know..."
- Keep it conversational - use short paragraphs, bullet points when helpful, but write like you're talking
- If the user is continuing a topic from earlier, acknowledge it and build on the previous conversation

{history_text}Context from documents:
{context}

Question: {question}

Respond naturally and conversationally. Explain things like you're helping a friend understand their taxes. Be clear, friendly, and focus on what they actually need to know. If this continues a previous topic, reference it naturally."""

    # Attach the image alongside the prompt when provided.
    if image_data:
        img = Image.open(io.BytesIO(image_data))
        contents = [prompt, img]
    else:
        contents = prompt

    # Retry overload errors with exponential backoff (2s, 4s); any other
    # API error is re-raised immediately.
    max_retries = 3
    retry_delay = 2

    for attempt in range(max_retries):
        try:
            response = client.models.generate_content(
                model=model,
                contents=contents
            )
            return response.text
        except Exception as e:
            error_str = str(e)
            if "503" in error_str or "UNAVAILABLE" in error_str or "overloaded" in error_str.lower():
                if attempt < max_retries - 1:
                    wait_time = retry_delay * (2 ** attempt)
                    time.sleep(wait_time)
                    continue
                else:
                    raise Exception("Gemini service is temporarily overloaded. Please try again in a few moments.")
            else:
                raise e

    # Defensive: unreachable in practice (the loop either returns or raises).
    raise Exception("Failed to generate answer after multiple attempts")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def clean_text(text: str) -> str:
    """Normalize raw PDF-extracted text for chunking and embedding.

    Strips control characters and PDF boilerplate, collapses decorative runs
    and whitespace, rejoins numbers/hyphenated words split by extraction,
    drops 1-2 character junk lines, and de-duplicates short repeated lines
    (headers/footers).
    """
    # Drop bytes that cannot round-trip UTF-8, then remaining control chars.
    text = text.encode('utf-8', errors='ignore').decode('utf-8')
    text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]', '', text)

    # PDF boilerplate: page footers, bare line numbers, horizontal rules.
    text = re.sub(r'Page \d+ of \d+', '', text, flags=re.IGNORECASE)
    text = re.sub(r'^\d+\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^[-_=]{3,}$', '', text, flags=re.MULTILINE)

    # Collapse decorative character runs.
    text = re.sub(r'\.{3,}', '...', text)
    text = re.sub(r'_{2,}', ' ', text)
    text = re.sub(r'-{3,}', ' - ', text)

    # Whitespace normalization.
    text = re.sub(r'\t+', ' ', text)
    text = re.sub(r' +', ' ', text)
    text = re.sub(r'\n{3,}', '\n\n', text)

    # Rejoin numbers ("2 . 5" -> "2.5") and hyphenated word splits.
    text = re.sub(r'(\d+)\s*\.\s*(\d+)', r'\1.\2', text)
    text = re.sub(r'([a-z])\s*-\s*([a-z])', r'\1\2', text)

    # Keep blank lines and lines longer than 2 characters; drop tiny junk.
    stripped = [raw.strip() for raw in text.split('\n')]
    kept = [ln for ln in stripped if len(ln) > 2 or ln == '']

    # De-duplicate repeated short lines (likely running headers/footers):
    # only lines of 6-49 characters participate in the dedup set.
    seen = set()
    result = []
    for ln in kept:
        key = ln.lower().strip()
        if len(key) < 50 and key in seen:
            continue
        if len(key) > 5:
            seen.add(key)
        result.append(ln)

    return '\n'.join(result).strip()
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
python-multipart
|
| 4 |
+
pdfplumber
|
| 5 |
+
pinecone
|
| 6 |
+
tiktoken
|
| 7 |
+
google-genai
|
| 8 |
+
pydantic
|
| 9 |
+
python-dotenv
|
| 10 |
+
Pillow
|