Spaces:
Sleeping
Sleeping
Auto deploy backend
Browse files- .github/workflows/ci.yml +54 -4
- README.md +196 -9
- hf_backend/.github/workflows/ci.yml +27 -0
- hf_backend/.gitignore +0 -0
- hf_backend/Dockerfile +18 -0
- hf_backend/Procfile +1 -0
- hf_backend/README.md +196 -0
- hf_backend/agent.py +141 -0
- hf_backend/config.py +26 -0
- hf_backend/ingestion.py +127 -0
- hf_backend/main.py +104 -0
- hf_backend/requirements.txt +17 -0
- hf_backend/retriever.py +81 -0
- hf_backend/runtime.txt +1 -0
- hf_backend/tests/__init__.py +0 -0
- hf_backend/tests/test_integration.py +51 -0
- hf_backend/tests/test_unit.py +119 -0
- pytest.ini +4 -0
- tests/test_api.py +12 -0
.github/workflows/ci.yml
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
name: RAG
|
| 2 |
|
| 3 |
on:
|
| 4 |
push:
|
|
@@ -21,7 +21,57 @@ jobs:
|
|
| 21 |
- name: Install dependencies
|
| 22 |
run: pip install -r requirements.txt
|
| 23 |
|
| 24 |
-
- name: Run unit tests only
|
| 25 |
env:
|
| 26 |
-
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
| 27 |
-
run: pytest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: RAG CI/CD
|
| 2 |
|
| 3 |
on:
|
| 4 |
push:
|
|
|
|
| 21 |
- name: Install dependencies
|
| 22 |
run: pip install -r requirements.txt
|
| 23 |
|
| 24 |
+
- name: Run unit tests only
|
| 25 |
env:
|
| 26 |
+
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
| 27 |
+
run: pytest -v -m "not integration"
|
| 28 |
+
|
| 29 |
+
# 🚀 DEPLOY BACKEND
|
| 30 |
+
- name: Deploy Backend to HF
|
| 31 |
+
env:
|
| 32 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 33 |
+
run: |
|
| 34 |
+
set -e
|
| 35 |
+
|
| 36 |
+
pip install huggingface_hub
|
| 37 |
+
sudo apt-get update
|
| 38 |
+
sudo apt-get install -y rsync
|
| 39 |
+
|
| 40 |
+
git config --global user.email "you@example.com"
|
| 41 |
+
git config --global user.name "github-actions"
|
| 42 |
+
|
| 43 |
+
# clone repo
|
| 44 |
+
git clone https://huggingface.co/spaces/Hitan2004/agentic-corrective-rag hf_backend
|
| 45 |
+
|
| 46 |
+
cd hf_backend
|
| 47 |
+
|
| 48 |
+
# 🔥 FIXED AUTH (IMPORTANT)
|
| 49 |
+
git remote set-url origin https://user:${HF_TOKEN}@huggingface.co/spaces/Hitan2004/agentic-corrective-rag
|
| 50 |
+
|
| 51 |
+
# copy backend files (exclude UI + .git)
|
| 52 |
+
rsync -av --exclude='.git' --exclude='ui' ../ ./
|
| 53 |
+
|
| 54 |
+
git add .
|
| 55 |
+
git commit -m "Auto deploy backend" || echo "No changes to commit"
|
| 56 |
+
git push
|
| 57 |
+
|
| 58 |
+
# 🎨 DEPLOY UI
|
| 59 |
+
- name: Deploy UI to HF
|
| 60 |
+
env:
|
| 61 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 62 |
+
run: |
|
| 63 |
+
set -e
|
| 64 |
+
|
| 65 |
+
git clone https://huggingface.co/spaces/Hitan2004/agentic-corrective-rag-ui hf_ui
|
| 66 |
+
|
| 67 |
+
cd hf_ui
|
| 68 |
+
|
| 69 |
+
# 🔥 FIXED AUTH (IMPORTANT)
|
| 70 |
+
git remote set-url origin https://user:${HF_TOKEN}@huggingface.co/spaces/Hitan2004/agentic-corrective-rag-ui
|
| 71 |
+
|
| 72 |
+
# copy UI files only
|
| 73 |
+
rsync -av ../ui/ ./
|
| 74 |
+
|
| 75 |
+
git add .
|
| 76 |
+
git commit -m "Auto deploy UI" || echo "No changes to commit"
|
| 77 |
+
git push
|
README.md
CHANGED
|
@@ -1,9 +1,196 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Agentic Corrective RAG — Document Q&A
|
| 2 |
+
|
| 3 |
+
[](https://github.com/Hitan547/agentic-corrective-rag/actions)
|
| 4 |
+

|
| 5 |
+

|
| 6 |
+

|
| 7 |
+
|
| 8 |
+
> A production-aware document Q&A system that answers questions **only from your uploaded documents** — not from the model's imagination. Built with hybrid retrieval, cross-encoder reranking, and a self-correcting LangGraph agent that automatically retries if the answer isn't grounded in the source material.
|
| 9 |
+
|
| 10 |
+
## π Live Demo
|
| 11 |
+
|
| 12 |
+
| Service | URL |
|
| 13 |
+
|---------|-----|
|
| 14 |
+
| 🖥️ Frontend UI | [hitan2004-agentic-corrective-rag-ui.hf.space](https://hitan2004-agentic-corrective-rag-ui.hf.space) |
|
| 15 |
+
| ⚙️ Backend API | [hitan2004-agentic-corrective-rag.hf.space](https://hitan2004-agentic-corrective-rag.hf.space) |
|
| 16 |
+
| π API Docs | [hitan2004-agentic-corrective-rag.hf.space/docs](https://hitan2004-agentic-corrective-rag.hf.space/docs) |
|
| 17 |
+
|
| 18 |
+
## What It Does
|
| 19 |
+
|
| 20 |
+
Upload any PDF or TXT file, ask a question, and get an answer backed by:
|
| 21 |
+
- The exact source chunks it used
|
| 22 |
+
- A validation verdict (PASS/FAIL)
|
| 23 |
+
- How many self-correction retries were needed
|
| 24 |
+
|
| 25 |
+
## Architecture
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
PDF/TXT Upload
        │
        ▼
┌─────────────────────────────────┐
│  Ingestion Pipeline             │
│  PyMuPDF → Chunking → Embeddings│
│  FAISS Index + BM25 Index       │
└─────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────┐
│  Hybrid Retrieval               │
│  FAISS (dense) + BM25 (sparse)  │
│  → RRF Fusion                   │
│  → Cross-Encoder Reranking      │
└─────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────┐
│  Corrective RAG Agent           │
│  LangGraph StateGraph           │
│  Generate → Validate → Retry    │
│  (up to 3 automatic retries)    │
└─────────────────────────────────┘
        │
        ▼
Static HTML UI + FastAPI Backend
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Tech Stack
|
| 58 |
+
|
| 59 |
+
| Layer | Technology |
|
| 60 |
+
|-------|-----------|
|
| 61 |
+
| LLM | LLaMA 3.3 70B via Groq API |
|
| 62 |
+
| Agent Framework | LangGraph (StateGraph) |
|
| 63 |
+
| Dense Retrieval | FAISS + all-MiniLM-L6-v2 |
|
| 64 |
+
| Sparse Retrieval | BM25 (rank-bm25) |
|
| 65 |
+
| Reranker | cross-encoder/ms-marco-MiniLM-L-6-v2 |
|
| 66 |
+
| Fusion | Reciprocal Rank Fusion (RRF) |
|
| 67 |
+
| PDF Parsing | PyMuPDF (fitz) |
|
| 68 |
+
| Backend | FastAPI |
|
| 69 |
+
| Frontend | Static HTML/CSS/JS |
|
| 70 |
+
| Testing | pytest (unit + integration) |
|
| 71 |
+
| CI/CD | GitHub Actions |
|
| 72 |
+
| Deployment | Hugging Face Spaces (Docker) |
|
| 73 |
+
|
| 74 |
+
## Key Features
|
| 75 |
+
|
| 76 |
+
- **Hybrid Search** — combines FAISS semantic search and BM25 keyword search, fused with Reciprocal Rank Fusion (RRF)
|
| 77 |
+
- **Cross-Encoder Reranking** — re-scores top candidates by reading query + chunk together for higher precision
|
| 78 |
+
- **Self-Correcting Agent** — LangGraph pipeline automatically detects hallucinations and retries up to 3 times
|
| 79 |
+
- **Hallucination Validation** — a second LLM call checks every answer against the source context before returning it
|
| 80 |
+
- **Session Memory** — remembers the last 5 turns of conversation per session
|
| 81 |
+
- **Synchronous Indexing** — reliable document ingestion that completes before returning a response
|
| 82 |
+
- **CI/CD** — unit tests run automatically on every push via GitHub Actions
|
| 83 |
+
|
| 84 |
+
## Project Structure
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
agentic-corrective-rag/
├── agent.py                 # LangGraph corrective RAG agent
├── retriever.py             # Hybrid retrieval + RRF + reranking
├── ingestion.py             # PDF/TXT ingestion + FAISS/BM25 indexing
├── main.py                  # FastAPI backend
├── config.py                # Configuration and constants
├── requirements.txt
├── Dockerfile               # HF Spaces deployment
├── ui/
│   └── index.html           # Static HTML/JS frontend
├── tests/
│   ├── test_unit.py         # Unit tests (CI)
│   └── test_integration.py  # Integration tests (local only)
└── .github/
    └── workflows/
        └── ci.yml           # GitHub Actions CI pipeline
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Setup
|
| 106 |
+
|
| 107 |
+
### 1. Clone the repo
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
git clone https://github.com/Hitan547/agentic-corrective-rag.git
|
| 111 |
+
cd agentic-corrective-rag
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
### 2. Install dependencies
|
| 115 |
+
|
| 116 |
+
```bash
|
| 117 |
+
pip install -r requirements.txt
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### 3. Set up environment
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
echo "GROQ_API_KEY=your_key_here" > .env
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
Get your free API key at [console.groq.com](https://console.groq.com)
|
| 127 |
+
|
| 128 |
+
### 4. Run the backend
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
uvicorn main:app --reload --port 8000
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
### 5. Open the frontend
|
| 135 |
+
|
| 136 |
+
Open `ui/index.html` in your browser, or serve it locally:
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
python -m http.server 3000
|
| 140 |
+
# Visit http://localhost:3000/ui/index.html
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
## Running Tests
|
| 144 |
+
|
| 145 |
+
```bash
|
| 146 |
+
# Unit tests (fast, no API needed)
|
| 147 |
+
python -m pytest tests/test_unit.py -v
|
| 148 |
+
|
| 149 |
+
# Integration tests (requires GROQ_API_KEY)
|
| 150 |
+
python -m pytest tests/test_integration.py -v -m integration
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## How the Agent Works
|
| 154 |
+
|
| 155 |
+
1. **Generate** — LLaMA 3.3 70B answers using only the retrieved chunks
|
| 156 |
+
2. **Validate** — a second LLM call checks if every claim is supported by the context
|
| 157 |
+
3. **Retry** — if validation fails, the agent retries with the failure reason as feedback
|
| 158 |
+
4. **Stop** — returns the answer after PASS or after 3 retries
|
| 159 |
+
|
| 160 |
+
## API Endpoints
|
| 161 |
+
|
| 162 |
+
| Method | Endpoint | Description |
|
| 163 |
+
|--------|----------|-------------|
|
| 164 |
+
| `GET` | `/` | Health check |
|
| 165 |
+
| `GET` | `/health` | Returns API status + index state |
|
| 166 |
+
| `POST` | `/upload` | Upload and index a PDF or TXT file |
|
| 167 |
+
| `POST` | `/query` | Ask a question, get a grounded answer |
|
| 168 |
+
| `DELETE` | `/session/{id}` | Clear conversation history |
|
| 169 |
+
| `GET` | `/docs` | Interactive Swagger UI |
|
| 170 |
+
|
| 171 |
+
## Environment Variables
|
| 172 |
+
|
| 173 |
+
| Variable | Required | Description |
|
| 174 |
+
|----------|----------|-------------|
|
| 175 |
+
| `GROQ_API_KEY` | ✅ Yes | Your Groq API key from console.groq.com |
|
| 176 |
+
|
| 177 |
+
## Known Limitations
|
| 178 |
+
|
| 179 |
+
- **No index persistence** — indexes are stored in-memory and reset on redeploy. Re-upload your document after each redeploy on free hosting.
|
| 180 |
+
- **Free tier cold starts** — HF Spaces free tier may take 30–60 seconds to wake up after inactivity.
|
| 181 |
+
- **Single document at a time** — uploading a new document replaces the previous index.
|
| 182 |
+
|
| 183 |
+
## Deployment
|
| 184 |
+
|
| 185 |
+
This project is deployed as two separate services on Hugging Face Spaces:
|
| 186 |
+
|
| 187 |
+
- **Backend** (`agentic-corrective-rag`) — FastAPI app running in a Docker container
|
| 188 |
+
- **Frontend** (`agentic-corrective-rag-ui`) — Static HTML/JS served via HF Static Space
|
| 189 |
+
|
| 190 |
+
## Author
|
| 191 |
+
|
| 192 |
+
**Hitan K** — Final-year CS undergraduate (AI specialization)
|
| 193 |
+
|
| 194 |
+
[](https://linkedin.com/in/hitan-k)
|
| 195 |
+
[](https://github.com/Hitan547)
|
| 196 |
+
[](https://huggingface.co/Hitan2004)
|
hf_backend/.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: RAG Unit Tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
pull_request:
|
| 7 |
+
branches: [main]
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
test:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
|
| 13 |
+
steps:
|
| 14 |
+
- uses: actions/checkout@v4
|
| 15 |
+
|
| 16 |
+
- name: Set up Python
|
| 17 |
+
uses: actions/setup-python@v5
|
| 18 |
+
with:
|
| 19 |
+
python-version: "3.11"
|
| 20 |
+
|
| 21 |
+
- name: Install dependencies
|
| 22 |
+
run: pip install -r requirements.txt
|
| 23 |
+
|
| 24 |
+
- name: Run unit tests only # ← integration tests are skipped here
|
| 25 |
+
env:
|
| 26 |
+
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} # add this in GitHub → Settings → Secrets
|
| 27 |
+
run: pytest tests/test_unit.py -v
|
hf_backend/.gitignore
ADDED
|
Binary file (116 Bytes). View file
|
|
|
hf_backend/Dockerfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
build-essential \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
COPY . .
|
| 13 |
+
|
| 14 |
+
RUN mkdir -p docs indexes
|
| 15 |
+
|
| 16 |
+
EXPOSE 7860
|
| 17 |
+
|
| 18 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
hf_backend/Procfile
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
web: uvicorn main:app --host 0.0.0.0 --port $PORT
|
hf_backend/README.md
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Agentic Corrective RAG — Document Q&A
|
| 2 |
+
|
| 3 |
+
[](https://github.com/Hitan547/agentic-corrective-rag/actions)
|
| 4 |
+

|
| 5 |
+

|
| 6 |
+

|
| 7 |
+
|
| 8 |
+
> A production-aware document Q&A system that answers questions **only from your uploaded documents** — not from the model's imagination. Built with hybrid retrieval, cross-encoder reranking, and a self-correcting LangGraph agent that automatically retries if the answer isn't grounded in the source material.
|
| 9 |
+
|
| 10 |
+
## π Live Demo
|
| 11 |
+
|
| 12 |
+
| Service | URL |
|
| 13 |
+
|---------|-----|
|
| 14 |
+
| 🖥️ Frontend UI | [hitan2004-agentic-corrective-rag-ui.hf.space](https://hitan2004-agentic-corrective-rag-ui.hf.space) |
|
| 15 |
+
| ⚙️ Backend API | [hitan2004-agentic-corrective-rag.hf.space](https://hitan2004-agentic-corrective-rag.hf.space) |
|
| 16 |
+
| π API Docs | [hitan2004-agentic-corrective-rag.hf.space/docs](https://hitan2004-agentic-corrective-rag.hf.space/docs) |
|
| 17 |
+
|
| 18 |
+
## What It Does
|
| 19 |
+
|
| 20 |
+
Upload any PDF or TXT file, ask a question, and get an answer backed by:
|
| 21 |
+
- The exact source chunks it used
|
| 22 |
+
- A validation verdict (PASS/FAIL)
|
| 23 |
+
- How many self-correction retries were needed
|
| 24 |
+
|
| 25 |
+
## Architecture
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
PDF/TXT Upload
        │
        ▼
┌─────────────────────────────────┐
│  Ingestion Pipeline             │
│  PyMuPDF → Chunking → Embeddings│
│  FAISS Index + BM25 Index       │
└─────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────┐
│  Hybrid Retrieval               │
│  FAISS (dense) + BM25 (sparse)  │
│  → RRF Fusion                   │
│  → Cross-Encoder Reranking      │
└─────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────┐
│  Corrective RAG Agent           │
│  LangGraph StateGraph           │
│  Generate → Validate → Retry    │
│  (up to 3 automatic retries)    │
└─────────────────────────────────┘
        │
        ▼
Static HTML UI + FastAPI Backend
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Tech Stack
|
| 58 |
+
|
| 59 |
+
| Layer | Technology |
|
| 60 |
+
|-------|-----------|
|
| 61 |
+
| LLM | LLaMA 3.3 70B via Groq API |
|
| 62 |
+
| Agent Framework | LangGraph (StateGraph) |
|
| 63 |
+
| Dense Retrieval | FAISS + all-MiniLM-L6-v2 |
|
| 64 |
+
| Sparse Retrieval | BM25 (rank-bm25) |
|
| 65 |
+
| Reranker | cross-encoder/ms-marco-MiniLM-L-6-v2 |
|
| 66 |
+
| Fusion | Reciprocal Rank Fusion (RRF) |
|
| 67 |
+
| PDF Parsing | PyMuPDF (fitz) |
|
| 68 |
+
| Backend | FastAPI |
|
| 69 |
+
| Frontend | Static HTML/CSS/JS |
|
| 70 |
+
| Testing | pytest (unit + integration) |
|
| 71 |
+
| CI/CD | GitHub Actions |
|
| 72 |
+
| Deployment | Hugging Face Spaces (Docker) |
|
| 73 |
+
|
| 74 |
+
## Key Features
|
| 75 |
+
|
| 76 |
+
- **Hybrid Search** — combines FAISS semantic search and BM25 keyword search, fused with Reciprocal Rank Fusion (RRF)
|
| 77 |
+
- **Cross-Encoder Reranking** — re-scores top candidates by reading query + chunk together for higher precision
|
| 78 |
+
- **Self-Correcting Agent** — LangGraph pipeline automatically detects hallucinations and retries up to 3 times
|
| 79 |
+
- **Hallucination Validation** — a second LLM call checks every answer against the source context before returning it
|
| 80 |
+
- **Session Memory** — remembers the last 5 turns of conversation per session
|
| 81 |
+
- **Synchronous Indexing** — reliable document ingestion that completes before returning a response
|
| 82 |
+
- **CI/CD** — unit tests run automatically on every push via GitHub Actions
|
| 83 |
+
|
| 84 |
+
## Project Structure
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
agentic-corrective-rag/
├── agent.py                 # LangGraph corrective RAG agent
├── retriever.py             # Hybrid retrieval + RRF + reranking
├── ingestion.py             # PDF/TXT ingestion + FAISS/BM25 indexing
├── main.py                  # FastAPI backend
├── config.py                # Configuration and constants
├── requirements.txt
├── Dockerfile               # HF Spaces deployment
├── ui/
│   └── index.html           # Static HTML/JS frontend
├── tests/
│   ├── test_unit.py         # Unit tests (CI)
│   └── test_integration.py  # Integration tests (local only)
└── .github/
    └── workflows/
        └── ci.yml           # GitHub Actions CI pipeline
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Setup
|
| 106 |
+
|
| 107 |
+
### 1. Clone the repo
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
git clone https://github.com/Hitan547/agentic-corrective-rag.git
|
| 111 |
+
cd agentic-corrective-rag
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
### 2. Install dependencies
|
| 115 |
+
|
| 116 |
+
```bash
|
| 117 |
+
pip install -r requirements.txt
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### 3. Set up environment
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
echo "GROQ_API_KEY=your_key_here" > .env
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
Get your free API key at [console.groq.com](https://console.groq.com)
|
| 127 |
+
|
| 128 |
+
### 4. Run the backend
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
uvicorn main:app --reload --port 8000
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
### 5. Open the frontend
|
| 135 |
+
|
| 136 |
+
Open `ui/index.html` in your browser, or serve it locally:
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
python -m http.server 3000
|
| 140 |
+
# Visit http://localhost:3000/ui/index.html
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
## Running Tests
|
| 144 |
+
|
| 145 |
+
```bash
|
| 146 |
+
# Unit tests (fast, no API needed)
|
| 147 |
+
python -m pytest tests/test_unit.py -v
|
| 148 |
+
|
| 149 |
+
# Integration tests (requires GROQ_API_KEY)
|
| 150 |
+
python -m pytest tests/test_integration.py -v -m integration
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## How the Agent Works
|
| 154 |
+
|
| 155 |
+
1. **Generate** — LLaMA 3.3 70B answers using only the retrieved chunks
|
| 156 |
+
2. **Validate** — a second LLM call checks if every claim is supported by the context
|
| 157 |
+
3. **Retry** — if validation fails, the agent retries with the failure reason as feedback
|
| 158 |
+
4. **Stop** — returns the answer after PASS or after 3 retries
|
| 159 |
+
|
| 160 |
+
## API Endpoints
|
| 161 |
+
|
| 162 |
+
| Method | Endpoint | Description |
|
| 163 |
+
|--------|----------|-------------|
|
| 164 |
+
| `GET` | `/` | Health check |
|
| 165 |
+
| `GET` | `/health` | Returns API status + index state |
|
| 166 |
+
| `POST` | `/upload` | Upload and index a PDF or TXT file |
|
| 167 |
+
| `POST` | `/query` | Ask a question, get a grounded answer |
|
| 168 |
+
| `DELETE` | `/session/{id}` | Clear conversation history |
|
| 169 |
+
| `GET` | `/docs` | Interactive Swagger UI |
|
| 170 |
+
|
| 171 |
+
## Environment Variables
|
| 172 |
+
|
| 173 |
+
| Variable | Required | Description |
|
| 174 |
+
|----------|----------|-------------|
|
| 175 |
+
| `GROQ_API_KEY` | ✅ Yes | Your Groq API key from console.groq.com |
|
| 176 |
+
|
| 177 |
+
## Known Limitations
|
| 178 |
+
|
| 179 |
+
- **No index persistence** — indexes are stored in-memory and reset on redeploy. Re-upload your document after each redeploy on free hosting.
|
| 180 |
+
- **Free tier cold starts** — HF Spaces free tier may take 30–60 seconds to wake up after inactivity.
|
| 181 |
+
- **Single document at a time** — uploading a new document replaces the previous index.
|
| 182 |
+
|
| 183 |
+
## Deployment
|
| 184 |
+
|
| 185 |
+
This project is deployed as two separate services on Hugging Face Spaces:
|
| 186 |
+
|
| 187 |
+
- **Backend** (`agentic-corrective-rag`) — FastAPI app running in a Docker container
|
| 188 |
+
- **Frontend** (`agentic-corrective-rag-ui`) — Static HTML/JS served via HF Static Space
|
| 189 |
+
|
| 190 |
+
## Author
|
| 191 |
+
|
| 192 |
+
**Hitan K** — Final-year CS undergraduate (AI specialization)
|
| 193 |
+
|
| 194 |
+
[](https://linkedin.com/in/hitan-k)
|
| 195 |
+
[](https://github.com/Hitan547)
|
| 196 |
+
[](https://huggingface.co/Hitan2004)
|
hf_backend/agent.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#agent.py
|
| 2 |
+
from typing import TypedDict
|
| 3 |
+
from langgraph.graph import StateGraph, END
|
| 4 |
+
from langchain_groq import ChatGroq
|
| 5 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 6 |
+
from config import GROQ_API_KEY, GROQ_MODEL, MAX_RETRIES
|
| 7 |
+
|
| 8 |
+
# Module-level Groq chat client; generate_node and validate_node both call
# llm.invoke() on it. Model name and API key come from config.
llm = ChatGroq(model=GROQ_MODEL, temperature=0, api_key=GROQ_API_KEY)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RAGState(TypedDict):
    """State dictionary passed between the nodes of the RAG graph."""

    # The user's question, inserted verbatim into both prompts.
    question: str
    # Retrieved chunks; the nodes read each item's "chunk" (and "source") key.
    context_chunks: list
    # Latest answer written by generate_node.
    answer: str
    # "PASS" or "FAIL", written by validate_node.
    validation_result: str
    # Reason text parsed from the validator's reply ("" when none was found).
    fail_reason: str
    # Number of retries so far; bumped by increment_retry_node.
    retry_count: int
    # Prior conversation messages; generate_node uses the last six.
    chat_history: list
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def generate_node(state: RAGState) -> dict:
    """Ask the LLM for an answer grounded in the retrieved chunks.

    The prompt is assembled from four pieces: the fixed instructions, an
    optional correction notice (only when ``retry_count`` is above zero), the
    last six chat-history messages, and the context chunks labelled with
    their source.

    Args:
        state: Current graph state. ``context_chunks`` items must expose
            ``"source"`` and ``"chunk"`` keys.

    Returns:
        A partial state update of the form ``{"answer": <LLM reply text>}``.
    """
    context_text = "\n\n---\n\n".join(
        f"[Source: {item['source']}]\n{item['chunk']}"
        for item in state["context_chunks"]
    )

    # Any message that is not a HumanMessage is labelled "Assistant".
    recent_messages = state.get("chat_history", [])[-6:]
    transcript = "\n".join(
        f"{'User' if isinstance(message, HumanMessage) else 'Assistant'}: {message.content}"
        for message in recent_messages
    )
    history_text = transcript or "None"

    if state.get("retry_count", 0) > 0:
        # Feed the validator's complaint back so the retry can address it.
        rejected_because = state.get("fail_reason", "unverifiable claims")
        correction = (
            "\n\nIMPORTANT CORRECTION REQUIRED: Your previous answer was "
            f"rejected because: {rejected_because}. "
            "Re-answer using ONLY the context provided."
        )
    else:
        correction = ""

    prompt = "".join(
        [
            "You are an AI assistant that answers questions AND generates content based on provided documents.\n",
            "Answer ONLY using information from the CONTEXT below.\n",
            "If the answer cannot be found, say exactly: ",
            '"I don\'t have enough information in the provided documents."\n',
            "Do NOT invent facts or use outside knowledge.",
            correction,
            f"\n\nPREVIOUS CONVERSATION:\n{history_text}",
            f"\n\nCONTEXT:\n{context_text}",
            f"\n\nQUESTION: {state['question']}\n\nAnswer:",
        ]
    )

    reply = llm.invoke([HumanMessage(content=prompt)])
    return {"answer": reply.content}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def validate_node(state: RAGState) -> dict:
    """Have the LLM judge whether the current answer is grounded in the context.

    The reply is expected to contain a ``VERDICT: PASS|FAIL`` line and a
    ``REASON: ...`` line. Anything that does not contain ``VERDICT: PASS``
    (case-insensitively) is treated as a failure.

    Args:
        state: Current graph state; reads ``context_chunks``, ``question``
            and ``answer``.

    Returns:
        ``{"validation_result": "PASS" | "FAIL", "fail_reason": <str>}``,
        where the reason is the text after the first ``REASON:`` line, or
        ``""`` if the reply has no such line.
    """
    context_text = "\n\n".join(item["chunk"] for item in state["context_chunks"])

    prompt = "".join(
        [
            "You are a strict hallucination checker for a RAG system.\n\n",
            "Given the CONTEXT and the ANSWER below, check:\n",
            "1. Is every factual claim directly supported by the context?\n",
            "2. Does the answer address the question?\n",
            "3. Are there any invented facts not in the context?\n\n",
            f"Context:\n{context_text}\n\n",
            f"Question: {state['question']}\n",
            f"Answer: {state['answer']}\n\n",
            "Respond in EXACTLY this format:\n",
            "VERDICT: PASS\n",
            "REASON: <one sentence>\n\n",
            "or\n\n",
            "VERDICT: FAIL\n",
            "REASON: <one sentence explaining what is wrong>",
        ]
    )

    reply_text = llm.invoke([HumanMessage(content=prompt)]).content.strip()

    # Default to FAIL so a malformed reply is never counted as a pass.
    verdict = "FAIL"
    if "VERDICT: PASS" in reply_text.upper():
        verdict = "PASS"

    # Only the first REASON: line is used.
    reason = next(
        (
            candidate.split(":", 1)[1].strip()
            for candidate in reply_text.splitlines()
            if candidate.upper().startswith("REASON:")
        ),
        "",
    )

    return {"validation_result": verdict, "fail_reason": reason}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def increment_retry_node(state: RAGState) -> dict:
    """Return a state update that raises ``retry_count`` by one.

    A missing ``retry_count`` is treated as zero.
    """
    attempts_so_far = state.get("retry_count", 0)
    return {"retry_count": attempts_so_far + 1}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def route_after_validation(state: RAGState) -> str:
    """Choose the edge to follow once validation has run.

    Returns:
        ``"retry"`` when validation failed and fewer than ``MAX_RETRIES``
        retries have been used; ``"done"`` in every other case.
    """
    if state["validation_result"] != "FAIL":
        return "done"
    if state.get("retry_count", 0) >= MAX_RETRIES:
        # Retry budget exhausted: stop even though validation failed.
        return "done"
    return "retry"
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _build_graph():
    """Wire up and compile the generate -> validate -> (retry) state graph.

    Flow: ``generate`` always leads to ``validate``; from there
    ``route_after_validation`` either ends the run or goes through
    ``increment_retry`` back to ``generate``.
    """
    graph = StateGraph(RAGState)

    node_handlers = (
        ("generate", generate_node),
        ("validate", validate_node),
        ("increment_retry", increment_retry_node),
    )
    for node_name, handler in node_handlers:
        graph.add_node(node_name, handler)

    graph.set_entry_point("generate")
    graph.add_edge("generate", "validate")

    # Map the router's return values onto the next node (or the end).
    routes = {"retry": "increment_retry", "done": END}
    graph.add_conditional_edges("validate", route_after_validation, routes)

    graph.add_edge("increment_retry", "generate")
    return graph.compile()


# Compiled once at import time and reused by run_rag_agent.
_rag_graph = _build_graph()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def run_rag_agent(
    question: str,
    context_chunks: list,
    chat_history: "list | None" = None,
) -> tuple:
    """Run the corrective RAG graph for one question and return its outcome.

    Args:
        question: The user's question.
        context_chunks: Retrieved chunks, passed into the graph state as-is.
        chat_history: Prior conversation messages. ``None`` (the default)
            means no history.

    Returns:
        ``(answer, retry_count, validation_result)`` read from the final
        graph state.
    """
    # Build a fresh list per call: a literal ``[]`` default is created once
    # and would be shared by every call that omits chat_history.
    history = [] if chat_history is None else chat_history

    init_state: RAGState = {
        "question": question,
        "context_chunks": context_chunks,
        "answer": "",
        "validation_result": "",
        "fail_reason": "",
        "retry_count": 0,
        "chat_history": history,
    }
    final = _rag_graph.invoke(init_state)
    return final["answer"], final["retry_count"], final["validation_result"]
|
hf_backend/config.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
import os
|
| 3 |
+
import warnings
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
# The Groq key is read from the process environment. A missing key falls back
# to "" and only warns, so importing this module never raises.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    warnings.warn("GROQ_API_KEY not set — LLM calls will fail")

# ── Anchor all paths to the directory this file lives in ──
_BASE = os.path.dirname(os.path.abspath(__file__))

# Model identifiers.
GROQ_MODEL = "llama-3.3-70b-versatile"
EMBEDDER_NAME = "all-MiniLM-L6-v2"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Filesystem locations, all resolved relative to _BASE.
DOCS_DIR = os.path.join(_BASE, "docs")
FAISS_INDEX_PATH = os.path.join(_BASE, "faiss.index")
BM25_PATH = os.path.join(_BASE, "bm25.pkl")
CHUNKS_PATH = os.path.join(_BASE, "chunks.pkl")
SOURCES_PATH = os.path.join(_BASE, "sources.pkl")

# Chunking, retrieval and agent tuning values.
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
TOP_K = 5
MAX_RETRIES = 3
MAX_HISTORY_TURNS = 5
|
hf_backend/ingestion.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ingestion.py
|
| 2 |
+
import os, pickle
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import numpy as np
|
| 5 |
+
import faiss
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from rank_bm25 import BM25Okapi
|
| 8 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 9 |
+
from config import (
|
| 10 |
+
DOCS_DIR, FAISS_INDEX_PATH, BM25_PATH,
|
| 11 |
+
CHUNKS_PATH, SOURCES_PATH,
|
| 12 |
+
EMBEDDER_NAME, CHUNK_SIZE, CHUNK_OVERLAP
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def read_pdf_text(fpath):
    """Return the text of every page of the PDF at *fpath*.

    Page texts are joined with newlines and the result is stripped of
    leading and trailing whitespace.
    """
    import fitz  # PyMuPDF

    # The context manager closes the document even if a page fails to
    # extract; the original left the file handle open.
    with fitz.open(fpath) as doc:
        pages = [page.get_text() for page in doc]
    return "\n".join(pages).strip()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def clean_text(text):
    """Collapse each run of whitespace in *text* to one space and trim both ends."""
    words = text.split()
    return " ".join(words)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_documents():
    """Read every .txt and .pdf file in DOCS_DIR and return the cleaned text.

    Text files are processed before PDFs, each group in sorted filename
    order so chunk ids are reproducible. Suffixes are matched
    case-insensitively, because the upload endpoint accepts ``.TXT`` and
    ``.PDF`` as well. Files that cannot be read, or that contain no text
    after cleaning, are reported and skipped.

    Returns:
        ``(docs, filenames)``: two parallel lists with one cleaned text and
        one file name per loaded document.

    Raises:
        FileNotFoundError: if no document produced any text.
    """
    docs, filenames = [], []
    path = Path(DOCS_DIR)
    path.mkdir(exist_ok=True)

    # (suffix, label used in the log line, function returning the raw text)
    readers = (
        (".txt", "text", lambda p: p.read_text(encoding="utf-8")),
        (".pdf", "PDF", read_pdf_text),
    )
    candidates = sorted(p for p in path.iterdir() if p.is_file())

    for suffix, label, reader in readers:
        for fpath in candidates:
            if fpath.suffix.lower() != suffix:
                continue
            try:
                text = clean_text(reader(fpath))
            except Exception as e:
                # Best effort: one unreadable file must not abort ingestion.
                print(f" Skipped {fpath.name}: {e}")
                continue
            if not text:
                # An empty document yields no chunks and would only skew
                # the statistics downstream.
                print(f" WARNING: {fpath.name} extracted empty text")
                continue
            docs.append(text)
            filenames.append(fpath.name)
            print(f" Loaded {label}: {fpath.name}")

    if not docs:
        raise FileNotFoundError(
            f"No .txt or .pdf files found in '{DOCS_DIR}'. "
            "Add at least one document and re-run."
        )

    print(f"\nLoaded {len(docs)} document(s)")
    return docs, filenames
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def semantic_chunk(docs, filenames):
    """Split each document into overlapping chunks and track their origin.

    Args:
        docs: Document texts.
        filenames: File name for each entry of *docs*, in the same order.

    Returns:
        ``(all_chunks, all_sources)``: parallel lists where
        ``all_sources[i]`` is the file name ``all_chunks[i]`` came from.

    Raises:
        ValueError: if splitting produced no chunks at all.
    """
    # NOTE: load_documents() collapses all whitespace to single spaces, so
    # for text coming from it only the ". " and " " separators can match.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=["\n\n", "\n", ". ", " "],
    )

    all_chunks, all_sources = [], []
    for doc, fname in zip(docs, filenames):
        chunks = splitter.split_text(doc)
        all_chunks.extend(chunks)
        all_sources.extend([fname] * len(chunks))

    # Without this guard the average below divides by zero and the sample
    # print indexes an empty list.
    if not all_chunks:
        raise ValueError("No chunks were produced; the documents contain no text.")

    avg_len = sum(len(c) for c in all_chunks) // len(all_chunks)
    print(f"Created {len(all_chunks)} chunks (avg {avg_len} chars each)")
    print("\n--- SAMPLE CHUNK ---")
    print(all_chunks[0][:500])
    print("--------------------\n")

    return all_chunks, all_sources
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def build_indexes(chunks, model=None):
    """Build the dense (FAISS) and sparse (BM25) indexes for *chunks*.

    Args:
        chunks: Non-empty list of chunk strings.
        model: Optional embedder exposing ``encode``; when ``None`` a
            ``SentenceTransformer(EMBEDDER_NAME)`` is created.

    Returns:
        ``(faiss_index, bm25_index)``.

    Raises:
        ValueError: if *chunks* is empty.
    """
    # With no chunks there is no embedding dimension to read, so fail with
    # a clear message instead of an IndexError on ``shape[1]``.
    if not chunks:
        raise ValueError("Cannot build indexes from an empty chunk list.")

    print("\nBuilding dense embeddings...")
    if model is None:
        model = SentenceTransformer(EMBEDDER_NAME)
    embeddings = model.encode(chunks, show_progress_bar=True, batch_size=32)
    embeddings = np.array(embeddings, dtype="float32")
    # Vectors are L2-normalised before going into an inner-product index,
    # so inner product equals cosine similarity.
    faiss.normalize_L2(embeddings)
    dim = embeddings.shape[1]
    faiss_index = faiss.IndexFlatIP(dim)
    faiss_index.add(embeddings)
    print(f"FAISS index: {faiss_index.ntotal} vectors, dim={dim}")

    # BM25 uses lower-cased whitespace tokens; hybrid_retrieve() tokenises
    # queries the same way.
    tokenized = [c.lower().split() for c in chunks]
    bm25_index = BM25Okapi(tokenized)
    print("BM25 index: built")
    return faiss_index, bm25_index
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def save_indexes(faiss_index, bm25_index, chunks, sources):
    """Write the FAISS index and the pickled BM25 index, chunks and sources to disk."""
    faiss.write_index(faiss_index, FAISS_INDEX_PATH)

    # Same write order as before: BM25, then chunks, then sources.
    pickled = (
        (BM25_PATH, bm25_index),
        (CHUNKS_PATH, chunks),
        (SOURCES_PATH, sources),
    )
    for target, payload in pickled:
        with open(target, "wb") as handle:
            pickle.dump(payload, handle)

    print("\nSaved indexes to disk.")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def run_ingestion(model=None):
    """Run the full pipeline: load documents, chunk, build indexes, save.

    *model* is forwarded to ``build_indexes`` as its embedder.
    """
    print("=== Starting ingestion ===\n")
    texts, names = load_documents()
    pieces, origins = semantic_chunk(texts, names)
    dense_index, sparse_index = build_indexes(pieces, model=model)
    save_indexes(dense_index, sparse_index, pieces, origins)
    print("\n=== Ingestion complete ===")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
    # Running this file as a script performs one full ingestion pass.
    run_ingestion()
|
hf_backend/main.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 7 |
+
from retriever import load_indexes, reload_indexes, hybrid_retrieve, indexes_loaded as _indexes_loaded
|
| 8 |
+
from agent import run_rag_agent
|
| 9 |
+
from ingestion import run_ingestion
|
| 10 |
+
from config import DOCS_DIR, TOP_K, MAX_HISTORY_TURNS
|
| 11 |
+
|
| 12 |
+
# In-process chat history: session_id -> list of messages (written by /query,
# cleared by DELETE /session/{session_id}).
sessions: dict = {}
|
| 13 |
+
|
| 14 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Try to load the retrieval indexes before the app starts serving.

    A ``FileNotFoundError`` from loading is reported and ignored so the
    service can start with no indexes and receive uploads later.
    """
    try:
        load_indexes()
    except FileNotFoundError:
        print("WARNING: No indexes found. Upload documents first.")
    # The application runs while suspended here; there is no shutdown work.
    yield
|
| 21 |
+
|
| 22 |
+
# The ASGI application; `lifespan` runs the startup index load.
app = FastAPI(title="Corrective RAG API", version="1.0", lifespan=lifespan)
|
| 23 |
+
|
| 24 |
+
@app.get("/")
def home():
    """Root endpoint: return a short message showing the API is up."""
    # The trailing character was mis-encoded in the source; restored as the
    # rocket emoji (assumed from the byte pattern — confirm against the repo).
    return {"message": "RAG API running 🚀"}
|
| 27 |
+
|
| 28 |
+
class QueryRequest(BaseModel):
    # Request body for POST /query.
    question: str                # the user question
    session_id: str = "default"  # key into the in-memory session history
    top_k: int = TOP_K           # number of chunks to retrieve
|
| 32 |
+
|
| 33 |
+
class QueryResponse(BaseModel):
    # Response body for POST /query.
    answer: str        # final answer returned by the agent
    sources: list      # [{"chunk": <first 300 chars>, "source": <name>}, ...]
    retries_used: int  # retry count reported by the agent
    validation: str    # validation verdict reported by the agent
    session_id: str    # session the exchange was recorded under
|
| 39 |
+
|
| 40 |
+
@app.post("/query", response_model=QueryResponse)
async def query(req: QueryRequest):
    """Answer a question from the indexed documents.

    Retrieves the top chunks, runs the RAG agent with the session's chat
    history, stores the new exchange and returns the answer with sources.

    Raises:
        HTTPException 422: ``top_k`` is not a positive integer.
        HTTPException 503: indexes are not available.
        HTTPException 404: retrieval returned nothing.
    """
    # Reject values the retriever cannot handle before doing any work.
    if req.top_k < 1:
        raise HTTPException(status_code=422, detail="top_k must be a positive integer.")

    if not _indexes_loaded():
        try:
            load_indexes()
        except Exception as e:
            # Best-effort lazy load: report the cause, then fall through
            # to the 503 below instead of hiding the error.
            print(f"Lazy index load failed: {e}")
    if not _indexes_loaded():
        raise HTTPException(
            status_code=503,
            detail="Indexes not ready. Upload and index documents first.",
        )

    results = hybrid_retrieve(req.question, top_k=req.top_k)
    if not results:
        raise HTTPException(status_code=404, detail="No relevant chunks found.")

    history = sessions.get(req.session_id, [])
    answer, retries, verdict = run_rag_agent(req.question, results, history)

    history.append(HumanMessage(content=req.question))
    history.append(AIMessage(content=answer))
    # Each turn is two messages (human + AI); keep only the newest turns.
    sessions[req.session_id] = history[-(MAX_HISTORY_TURNS * 2):]

    return QueryResponse(
        answer=answer,
        sources=[{"chunk": r["chunk"][:300], "source": r["source"]} for r in results],
        retries_used=retries,
        validation=verdict,
        session_id=req.session_id,
    )
|
| 67 |
+
|
| 68 |
+
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Store an uploaded .txt/.pdf file in DOCS_DIR and rebuild the indexes.

    Returns a JSON object with the stored file name and a message saying
    whether re-indexing succeeded.

    Raises:
        HTTPException 400: the file extension is not .txt or .pdf.
    """
    allowed = {".txt", ".pdf"}
    # The client controls ``filename``. Keep only its last path component so
    # values such as "../../x.txt" or "/etc/x.txt" cannot escape DOCS_DIR.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in allowed:
        raise HTTPException(status_code=400, detail="Only .txt and .pdf files allowed.")

    os.makedirs(DOCS_DIR, exist_ok=True)
    dest = os.path.join(DOCS_DIR, safe_name)
    with open(dest, "wb") as f:
        shutil.copyfileobj(file.file, f)

    # Report the real outcome instead of always claiming success.
    if _reindex():
        message = "Indexing complete."
    else:
        message = "Indexing failed; see server logs."
    return {"status": "uploaded", "filename": safe_name, "message": message}


def _reindex():
    """Re-run ingestion over DOCS_DIR and reload the in-memory indexes.

    Returns:
        True when ingestion and reload both finished, False when either
        raised. Failures are printed with a traceback and never propagate.
    """
    try:
        run_ingestion()
        print("Ingestion done, reloading indexes...")
        reload_indexes()
        print(f"Re-indexing complete. Indexes loaded: {_indexes_loaded()}")
    except Exception as e:
        import traceback

        print(f"Re-indexing failed: {e}")
        traceback.print_exc()
        return False
    return True
|
| 92 |
+
|
| 93 |
+
@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Forget the stored history for *session_id*; unknown ids are ignored."""
    if session_id in sessions:
        del sessions[session_id]
    return {"status": "cleared", "session_id": session_id}
|
| 97 |
+
|
| 98 |
+
@app.get("/health")
def health():
    """Report liveness and whether the retrieval indexes are loaded."""
    loaded = _indexes_loaded()
    return {"status": "ok", "indexes_loaded": loaded}
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port from $PORT (default 7860).
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
hf_backend/requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langchain==0.3.25
|
| 2 |
+
langchain-groq==0.3.2
|
| 3 |
+
langgraph==0.3.29
|
| 4 |
+
sentence-transformers==3.4.1
|
| 5 |
+
faiss-cpu==1.13.2
|
| 6 |
+
rank-bm25==0.2.2
|
| 7 |
+
fastapi==0.115.12
|
| 8 |
+
uvicorn==0.34.0
|
| 9 |
+
pymupdf==1.25.3
|
| 10 |
+
python-dotenv==1.1.0
|
| 11 |
+
numpy==1.26.4
|
| 12 |
+
requests==2.32.3
|
| 13 |
+
pydantic>=2.7
|
| 14 |
+
pydantic-core>=2.20.0
|
| 15 |
+
python-multipart==0.0.20
|
| 16 |
+
pytest==8.3.5
|
| 17 |
+
|
hf_backend/retriever.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import numpy as np
|
| 4 |
+
import faiss
|
| 5 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 6 |
+
from config import (
|
| 7 |
+
FAISS_INDEX_PATH, BM25_PATH, CHUNKS_PATH,
|
| 8 |
+
SOURCES_PATH, EMBEDDER_NAME, RERANKER_MODEL
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
# Module-level retrieval state. Everything starts as None and is filled in
# by load_indexes().
_faiss_index = None  # dense vector index
_bm25_index = None   # sparse BM25 index
_chunks = None       # chunk texts, addressed by chunk id
_sources = None      # source name for each chunk id
_model = None        # query embedder
_reranker = None     # cross-encoder used for re-ranking
|
| 17 |
+
|
| 18 |
+
def indexes_loaded() -> bool:
    """Return True once a FAISS index has been loaded into this module."""
    loaded = _faiss_index is not None
    return loaded
|
| 20 |
+
|
| 21 |
+
def load_indexes():
    """Load the indexes, chunk data and models from disk into module state.

    If the FAISS index file does not exist, a warning is printed and
    nothing changes. Otherwise every artefact is loaded first and the
    module globals are replaced together at the end, so a failure part-way
    through never leaves the module half-initialised.

    Raises:
        Whatever the underlying readers raise, for example
        ``FileNotFoundError`` when a pickle file is missing.
    """
    global _faiss_index, _bm25_index, _chunks, _sources, _model, _reranker

    if not os.path.exists(FAISS_INDEX_PATH):
        print("WARNING: No FAISS index found at startup. Upload documents to initialize.")
        return

    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    # NOTE(security): pickle.load runs arbitrary code from the file. These
    # paths must only ever hold files written by the ingestion step.
    with open(BM25_PATH, "rb") as f:
        bm25_index = pickle.load(f)
    with open(CHUNKS_PATH, "rb") as f:
        chunks = pickle.load(f)
    with open(SOURCES_PATH, "rb") as f:
        sources = pickle.load(f)
    model = SentenceTransformer(EMBEDDER_NAME)
    reranker = CrossEncoder(RERANKER_MODEL)

    # Publish everything at once, only after every load succeeded.
    _faiss_index, _bm25_index = faiss_index, bm25_index
    _chunks, _sources = chunks, sources
    _model, _reranker = model, reranker
    print(f"Indexes loaded: {_faiss_index.ntotal} vectors, {len(_chunks)} chunks")
|
| 35 |
+
|
| 36 |
+
def reload_indexes():
    """Drop all in-memory indexes and models, then load them again from disk."""
    global _faiss_index, _bm25_index, _chunks, _sources, _model, _reranker
    _faiss_index = None
    _bm25_index = None
    _chunks = None
    _sources = None
    _model = None
    _reranker = None
    load_indexes()
|
| 40 |
+
|
| 41 |
+
def _reciprocal_rank_fusion(lists: list, k: int = 60) -> dict:
    """Fuse several rankings into one score per document id.

    A document at 1-based position ``p`` in a ranking contributes
    ``1 / (k + p)``; contributions from all rankings are summed.

    Args:
        lists: Rankings, each a sequence of document ids, best first.
        k: Smoothing constant added to the position.

    Returns:
        Mapping of document id to fused score.
    """
    fused: dict = {}
    for ranking in lists:
        for position, doc_id in enumerate(ranking, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (k + position)
    return fused
|
| 47 |
+
|
| 48 |
+
def hybrid_retrieve(query: str, top_k: int = 5) -> list:
    """Retrieve the *top_k* best chunks for *query*.

    Dense (FAISS) and sparse (BM25) rankings of ``top_k * 3`` candidates
    each are fused with reciprocal rank fusion. The best ``top_k * 2``
    fused candidates are re-scored by the cross-encoder, and the *top_k*
    highest cross-encoder scores are returned.

    Returns:
        A list of dicts with keys ``chunk``, ``source``, ``chunk_id``,
        ``rrf_score`` and ``ce_score``, ordered by ``ce_score`` descending.

    Raises:
        RuntimeError: if any index or model has not been loaded.
        ValueError: if *top_k* is less than 1.
    """
    # Check every piece of state, not only the FAISS index, so a partial
    # load cannot surface later as an opaque error on a None object.
    required = (_faiss_index, _bm25_index, _chunks, _sources, _model, _reranker)
    if any(part is None for part in required):
        raise RuntimeError("Indexes not loaded. Call load_indexes() first.")
    if top_k < 1:
        raise ValueError("top_k must be a positive integer.")

    # Dense ranking: normalised query embedding against the FAISS index.
    q_emb = _model.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    _, dense_ids = _faiss_index.search(q_emb, top_k * 3)
    # Negative ids are dropped here (FAISS pads missing results with -1).
    dense_ranking = [int(i) for i in dense_ids[0] if i >= 0]

    # Sparse ranking: BM25 over lower-cased whitespace tokens.
    bm25_scores = _bm25_index.get_scores(query.lower().split())
    sparse_ranking = np.argsort(bm25_scores)[::-1][: top_k * 3].tolist()

    rrf_scores = _reciprocal_rank_fusion([dense_ranking, sparse_ranking])
    fused_ids = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[: top_k * 2]
    if not fused_ids:
        # Nothing to re-rank; skip calling the cross-encoder on an empty batch.
        return []

    candidates = [(query, _chunks[i]) for i in fused_ids]
    ce_scores = _reranker.predict(candidates)

    ranked = sorted(
        zip(fused_ids, ce_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )[:top_k]

    return [
        {
            "chunk": _chunks[i],
            "source": _sources[i],
            "chunk_id": i,
            "rrf_score": round(float(rrf_scores[i]), 4),
            "ce_score": round(float(score), 4),
        }
        for i, score in ranked
    ]
|
hf_backend/runtime.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python-3.11.9
|
hf_backend/tests/__init__.py
ADDED
|
File without changes
|
hf_backend/tests/test_integration.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_integration.py
|
| 2 |
+
# Run with: pytest tests/test_integration.py -v -m integration
|
| 3 |
+
# These call real APIs — don't run in CI automatically.
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
pytestmark = pytest.mark.integration # tag so CI can skip these
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_groq_connection_live():
    """A live Groq call returns a non-empty reply."""
    from langchain_core.messages import HumanMessage
    from langchain_groq import ChatGroq

    import config

    llm = ChatGroq(
        model=config.GROQ_MODEL,
        temperature=0,
        api_key=config.GROQ_API_KEY,
    )
    prompt = [HumanMessage(content="Reply with just the word OK")]
    reply = llm.invoke(prompt)
    assert len(reply.content) > 0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_full_pipeline_live():
    """Ingests a tiny doc, retrieves, runs agent — end to end."""
    from pathlib import Path

    from config import DOCS_DIR

    # Write the test doc into the directory ingestion actually reads.
    # "./docs" only matched it when pytest was started from the backend
    # directory.
    docs_dir = Path(DOCS_DIR)
    docs_dir.mkdir(exist_ok=True)
    test_file = docs_dir / "_pytest_temp.txt"
    test_file.write_text(
        "The Eiffel Tower is in Paris, France. "
        "It was built in 1889. It is 330 metres tall.",
        encoding="utf-8",
    )

    try:
        from ingestion import run_ingestion
        from retriever import load_indexes, hybrid_retrieve
        from agent import run_rag_agent

        run_ingestion()
        load_indexes()

        results = hybrid_retrieve("How tall is the Eiffel Tower?", top_k=3)
        assert len(results) > 0
        assert "ce_score" in results[0]  # reranker ran

        answer, retries, verdict = run_rag_agent(
            "How tall is the Eiffel Tower?", results
        )
        assert "330" in answer or "metres" in answer.lower()
        assert verdict in {"PASS", "FAIL"}

    finally:
        test_file.unlink(missing_ok=True)  # always clean up
|
hf_backend/tests/test_unit.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_unit.py
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
# ── RRF logic ────────────────────────────────────────────────────────────────
|
| 5 |
+
|
| 6 |
+
def test_rrf_prefers_doc_appearing_in_both_lists():
    """A doc ranked by both retrievers beats docs ranked by only one."""
    from retriever import _reciprocal_rank_fusion

    scores = _reciprocal_rank_fusion([[0, 1, 2], [2, 0, 1]])
    # doc 2 is rank-0 in sparse and rank-2 in dense -> should beat doc 1
    assert scores[2] > scores[1]

    # The property the test is named after: only doc 2 is in both lists,
    # so it must outscore every doc that appears in just one of them.
    scores = _reciprocal_rank_fusion([[0, 1, 2], [2, 3, 4]])
    assert all(scores[2] > scores[other] for other in (0, 1, 3, 4))
|
| 11 |
+
|
| 12 |
+
def test_rrf_returns_all_docs():
    """Every id present in any input ranking gets a fused score."""
    from retriever import _reciprocal_rank_fusion as fuse

    fused = fuse([[0, 1], [1, 2]])
    assert set(fused) == {0, 1, 2}
|
| 16 |
+
|
| 17 |
+
def test_rrf_scores_are_positive():
    """Fused scores are strictly positive."""
    from retriever import _reciprocal_rank_fusion as fuse

    fused = fuse([[0, 1, 2]])
    non_positive = [value for value in fused.values() if not value > 0]
    assert not non_positive
|
| 21 |
+
|
| 22 |
+
# ── Config sanity ────────────────────────────────────────────────────────────
|
| 23 |
+
|
| 24 |
+
def test_config_values_are_sane():
    """Basic invariants between the tunable constants in config."""
    import config

    assert config.CHUNK_SIZE > config.CHUNK_OVERLAP, "overlap must be smaller than chunk size"
    assert config.TOP_K > 0, "TOP_K must be positive"
    assert config.MAX_RETRIES >= 1, "need at least 1 retry"
|
| 29 |
+
|
| 30 |
+
def test_groq_api_key_present(monkeypatch):
    """config picks GROQ_API_KEY up from the environment when (re)imported."""
    import importlib

    import config

    # patch so we don't need a real key in CI
    monkeypatch.setenv("GROQ_API_KEY", "gsk_fakekeyfortesting1234567890")
    try:
        importlib.reload(config)  # re-reads env
        assert len(config.GROQ_API_KEY) > 10
    finally:
        # Restore the real environment and reload again. Otherwise the fake
        # key stays in the config module for every test that runs later.
        monkeypatch.undo()
        importlib.reload(config)
|
| 36 |
+
|
| 37 |
+
# ── Agent routing logic ──────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
def test_route_returns_done_on_pass():
    """A PASS verdict routes to "done"."""
    import agent

    verdict_state = {"validation_result": "PASS", "retry_count": 0}
    outcome = agent.route_after_validation(verdict_state)
    assert outcome == "done"
|
| 43 |
+
|
| 44 |
+
def test_route_returns_retry_on_fail_within_limit():
    """A FAIL verdict with no retries used so far routes to "retry"."""
    import agent

    verdict_state = {"validation_result": "FAIL", "retry_count": 0}
    outcome = agent.route_after_validation(verdict_state)
    assert outcome == "retry"
|
| 48 |
+
|
| 49 |
+
def test_route_returns_done_when_retries_exhausted():
    """A FAIL verdict with retry_count 3 routes to "done"."""
    import agent

    verdict_state = {"validation_result": "FAIL", "retry_count": 3}
    outcome = agent.route_after_validation(verdict_state)
    assert outcome == "done"
|
| 53 |
+
|
| 54 |
+
def test_increment_retry_node():
    """The node returns a state update with retry_count increased by one."""
    import agent

    update = agent.increment_retry_node({"retry_count": 1})
    assert update["retry_count"] == 2
|
| 58 |
+
|
| 59 |
+
# ── Retriever output shape (mocked indexes) ──────────────────────────────────
|
| 60 |
+
|
| 61 |
+
@pytest.fixture
def mock_indexes(monkeypatch):
    """Replace every module-level global in retriever with an in-memory fake.

    No index files or models are needed. Returns the fake chunk list.
    """
    import numpy as np
    import retriever

    fake_chunks = ["Paris is in France.", "Tower is 330m tall.", "Built in 1889."]
    fake_sources = ["doc1.txt", "doc1.txt", "doc1.txt"]

    class FakeFaiss:
        # Ignores the query and k; always answers with ids 0, 1, 2.
        ntotal = 3

        def search(self, vec, k):
            return None, np.array([[0, 1, 2]])

    class FakeBM25:
        # Fixed, descending scores for the three chunks.
        def get_scores(self, tokens):
            return np.array([0.9, 0.5, 0.3])

    class FakeModel:
        # One random 384-dimensional float32 vector per input text.
        def encode(self, texts, convert_to_numpy=True):
            return np.random.rand(len(texts), 384).astype("float32")

    class FakeReranker:
        # Fixed scores, truncated to the number of candidate pairs.
        def predict(self, pairs):
            return np.array([0.9, 0.7, 0.5][: len(pairs)])

    replacements = {
        "_faiss_index": FakeFaiss(),
        "_bm25_index": FakeBM25(),
        "_chunks": fake_chunks,
        "_sources": fake_sources,
        "_model": FakeModel(),
        "_reranker": FakeReranker(),
    }
    for attr, fake in replacements.items():
        monkeypatch.setattr(retriever, attr, fake)
    return fake_chunks
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def test_hybrid_retrieve_returns_top_k(mock_indexes):
    """Asking for two results yields exactly two."""
    import retriever

    hits = retriever.hybrid_retrieve("Where is Paris?", top_k=2)
    assert len(hits) == 2
|
| 106 |
+
|
| 107 |
+
def test_hybrid_retrieve_result_has_required_keys(mock_indexes):
    """Each hit exposes the chunk, its source and both scores."""
    import retriever

    hit = retriever.hybrid_retrieve("Where is Paris?", top_k=1)[0]
    for key in ("chunk", "source", "rrf_score", "ce_score"):
        assert key in hit
|
| 114 |
+
|
| 115 |
+
def test_hybrid_retrieve_scores_are_floats(mock_indexes):
    """Both scores are plain Python floats."""
    import retriever

    hit = retriever.hybrid_retrieve("test", top_k=1)[0]
    for key in ("rrf_score", "ce_score"):
        assert isinstance(hit[key], float)
|
pytest.ini
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
markers =
|
| 3 |
+
integration: marks integration tests
|
| 4 |
+
addopts = -ra
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
# Put the parent of this tests/ directory on sys.path so `from main import app` resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 4 |
+
|
| 5 |
+
from main import app
|
| 6 |
+
from fastapi.testclient import TestClient
|
| 7 |
+
|
| 8 |
+
# One in-process test client for the app, shared by the tests in this module.
client = TestClient(app)
|
| 9 |
+
|
| 10 |
+
def test_health():
    """The root endpoint answers with HTTP 200."""
    status = client.get("/").status_code
    assert status == 200
|