Spaces · Runtime error
Commit 3951d64
Parent(s): 62c0d9a
dd
Files changed:
- .env.example +8 -0
- .gitignore +132 -0
- Dockerfile +46 -0
- Procfile +1 -0
- README.md +1 -1
- deploy.bat +41 -0
- deploy.sh +35 -0
- models/all-MiniLM-L6-v2/1_Pooling/config.json +10 -0
- models/all-MiniLM-L6-v2/README.md +173 -0
- models/all-MiniLM-L6-v2/config.json +25 -0
- models/all-MiniLM-L6-v2/config_sentence_transformers.json +10 -0
- models/all-MiniLM-L6-v2/model.safetensors +3 -0
- models/all-MiniLM-L6-v2/modules.json +20 -0
- models/all-MiniLM-L6-v2/sentence_bert_config.json +4 -0
- models/all-MiniLM-L6-v2/special_tokens_map.json +37 -0
- models/all-MiniLM-L6-v2/tokenizer.json +0 -0
- models/all-MiniLM-L6-v2/tokenizer_config.json +65 -0
- models/all-MiniLM-L6-v2/vocab.txt +0 -0
- rag.py +816 -0
- railway.json +12 -0
- req.txt +18 -0
- requirements.txt +14 -0
- test.py +231 -0
- test_api.py +124 -0
.env.example
ADDED
@@ -0,0 +1,8 @@
+# Environment Variables Template
+# Copy this file to .env and fill in your actual API keys
+
+PINECONE_API_KEY=your_pinecone_api_key_here
+GROQ_API_KEY=your_groq_api_key_here
+LANGSMITH_API_KEY=your_langsmith_api_key_here
+LANGSMITH_TRACING=true
+LANGSMITH_PROJECT=BajaRX
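These keys are read at startup; rag.py (added below) parses .env by hand, but the conventional route is python-dotenv. A minimal sketch, assuming `pip install python-dotenv` and a .env next to the app:

```python
# Sketch: load .env into the process environment (assumes python-dotenv is installed).
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory

pinecone_key = os.getenv("PINECONE_API_KEY", "")
groq_key = os.getenv("GROQ_API_KEY", "")
if not pinecone_key or not groq_key:
    raise RuntimeError("PINECONE_API_KEY and GROQ_API_KEY must be set")
```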
.gitignore
ADDED
@@ -0,0 +1,132 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Project specific
+rag_system.db
+backup/
+*.pdf
+*.docx
+*.doc
+*.eml
+*.msg
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# OS
+.DS_Store
+Thumbs.db
Dockerfile
ADDED
@@ -0,0 +1,46 @@
+FROM python:3.11-slim
+
+# Set working directory
+WORKDIR /app
+
+# Set environment variables
+ENV PYTHONPATH=/app
+ENV PYTHONUNBUFFERED=1
+ENV PORT=8000
+# Set Hugging Face and Sentence Transformers cache directories
+ENV HF_HOME=/app/.cache/huggingface
+ENV TRANSFORMERS_CACHE=/app/.cache/huggingface
+ENV SENTENCE_TRANSFORMERS_HOME=/app/.cache/sentence_transformers
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    gcc \
+    g++ \
+    libmagic1 \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create cache directories with proper permissions
+RUN mkdir -p /app/.cache/huggingface /app/.cache/sentence_transformers && chmod -R 777 /app/.cache
+
+# Copy requirements first for better caching
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code and model files
+COPY . .
+
+# Ensure model directory has proper permissions
+RUN chmod -R 777 /app/models
+
+# Expose port
+EXPOSE $PORT
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
+    CMD curl -f http://localhost:$PORT/health || exit 1
+
+# Run the application
+CMD ["sh", "-c", "uvicorn rag:app --host 0.0.0.0 --port $PORT"]
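The HEALTHCHECK above presumes the app serves a /health route (the deploy scripts below probe the same URL). The equivalent check from Python, as a sketch assuming the container listens on localhost:8000:

```python
# Sketch: the same liveness probe the Dockerfile HEALTHCHECK performs via curl.
import sys
import requests

try:
    r = requests.get("http://localhost:8000/health", timeout=5)
    r.raise_for_status()
except requests.RequestException as exc:
    sys.exit(f"health check failed: {exc}")
print("healthy:", r.status_code)
```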
Procfile
ADDED
@@ -0,0 +1 @@
+web: uvicorn rag:app --host 0.0.0.0 --port $PORT
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: HackRx
-emoji:
+emoji: π
 colorFrom: indigo
 colorTo: purple
 sdk: docker
deploy.bat
ADDED
@@ -0,0 +1,41 @@
+@echo off
+echo π Deploying Ultra-Fast RAG System to Railway...
+
+REM Check if Railway CLI is installed
+railway --version >nul 2>&1
+if %errorlevel% neq 0 (
+    echo β Railway CLI not found. Please install from: https://railway.app/cli
+    echo Or run: npm install -g @railway/cli
+    pause
+    exit /b 1
+)
+
+REM Login to Railway (if not already logged in)
+echo π Checking Railway authentication...
+railway whoami
+if %errorlevel% neq 0 (
+    railway login
+)
+
+REM Initialize project (if not already initialized)
+if not exist "railway.toml" (
+    echo π¦ Initializing Railway project...
+    railway init
+)
+
+REM Set environment variables reminder
+echo π§ Environment Variables Setup Required:
+echo Please set these in Railway dashboard after deployment:
+echo - PINECONE_API_KEY=your_pinecone_api_key
+echo - GROQ_API_KEY=your_groq_api_key
+echo - LANGSMITH_API_KEY=your_langsmith_api_key
+echo.
+
+REM Deploy
+echo π Deploying to Railway...
+railway up
+
+echo β Deployment complete!
+echo π Your API will be available at the Railway-provided URL
+echo π Test with: GET https://your-app.railway.app/health
+pause
deploy.sh
ADDED
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Deploy to Railway Script
+echo "π Deploying Ultra-Fast RAG System to Railway..."
+
+# Check if Railway CLI is installed
+if ! command -v railway &> /dev/null; then
+    echo "β Railway CLI not found. Installing..."
+    curl -fsSL https://railway.app/install.sh | sh
+fi
+
+# Login to Railway (if not already logged in)
+echo "π Checking Railway authentication..."
+railway whoami || railway login
+
+# Initialize project (if not already initialized)
+if [ ! -f "railway.toml" ]; then
+    echo "π¦ Initializing Railway project..."
+    railway init
+fi
+
+# Set environment variables
+echo "π§ Setting environment variables..."
+echo "Please set these environment variables in Railway dashboard:"
+echo "PINECONE_API_KEY=your_pinecone_api_key"
+echo "GROQ_API_KEY=your_groq_api_key"
+echo "LANGSMITH_API_KEY=your_langsmith_api_key"
+
+# Deploy
+echo "π Deploying to Railway..."
+railway up
+
+echo "β Deployment complete!"
+echo "π Your API will be available at the Railway-provided URL"
+echo "π Test with: GET https://your-app.railway.app/health"
models/all-MiniLM-L6-v2/1_Pooling/config.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "word_embedding_dimension": 384,
+  "pooling_mode_cls_token": false,
+  "pooling_mode_mean_tokens": true,
+  "pooling_mode_max_tokens": false,
+  "pooling_mode_mean_sqrt_len_tokens": false,
+  "pooling_mode_weightedmean_tokens": false,
+  "pooling_mode_lasttoken": false,
+  "include_prompt": true
+}
models/all-MiniLM-L6-v2/README.md
ADDED
@@ -0,0 +1,173 @@
+---
+language: en
+license: apache-2.0
+library_name: sentence-transformers
+tags:
+- sentence-transformers
+- feature-extraction
+- sentence-similarity
+- transformers
+datasets:
+- s2orc
+- flax-sentence-embeddings/stackexchange_xml
+- ms_marco
+- gooaq
+- yahoo_answers_topics
+- code_search_net
+- search_qa
+- eli5
+- snli
+- multi_nli
+- wikihow
+- natural_questions
+- trivia_qa
+- embedding-data/sentence-compression
+- embedding-data/flickr30k-captions
+- embedding-data/altlex
+- embedding-data/simple-wiki
+- embedding-data/QQP
+- embedding-data/SPECTER
+- embedding-data/PAQ_pairs
+- embedding-data/WikiAnswers
+pipeline_tag: sentence-similarity
+---
+
+
+# all-MiniLM-L6-v2
+This is a [sentence-transformers](https://www.SBERT.net) model: it maps sentences & paragraphs to a 384-dimensional dense vector space and can be used for tasks like clustering or semantic search.
+
+## Usage (Sentence-Transformers)
+Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
+
+```
+pip install -U sentence-transformers
+```
+
+Then you can use the model like this:
+```python
+from sentence_transformers import SentenceTransformer
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+embeddings = model.encode(sentences)
+print(embeddings)
+```
+
+## Usage (HuggingFace Transformers)
+Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: first, you pass your input through the transformer model, then you have to apply the right pooling operation on top of the contextualized word embeddings.
+
+```python
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
+
+# Mean Pooling - Take attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+# Sentences we want sentence embeddings for
+sentences = ['This is an example sentence', 'Each sentence is converted']
+
+# Load model from HuggingFace Hub
+tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+
+# Tokenize sentences
+encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+
+# Compute token embeddings
+with torch.no_grad():
+    model_output = model(**encoded_input)
+
+# Perform pooling
+sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+
+# Normalize embeddings
+sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+
+print("Sentence embeddings:")
+print(sentence_embeddings)
+```
+
+------
+
+## Background
+
+The project aims to train sentence embedding models on very large sentence-level datasets using a self-supervised
+contrastive learning objective. We used the pretrained [`nreimers/MiniLM-L6-H384-uncased`](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) model and fine-tuned it on a
+1B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences was actually paired with it in our dataset.
+
+We developed this model during the
+[Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104),
+organized by Hugging Face. We developed this model as part of the project:
+[Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project: 7 TPU v3-8s, as well as guidance from Google's Flax, JAX, and Cloud team members about efficient deep learning frameworks.
+
+## Intended uses
+
+Our model is intended to be used as a sentence and short paragraph encoder. Given an input text, it outputs a vector which captures
+the semantic information. The sentence vector may be used for information retrieval, clustering or sentence similarity tasks.
+
+By default, input text longer than 256 word pieces is truncated.
+
+
+## Training procedure
+
+### Pre-training
+
+We use the pretrained [`nreimers/MiniLM-L6-H384-uncased`](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) model. Please refer to the model card for more detailed information about the pre-training procedure.
+
+### Fine-tuning
+
+We fine-tune the model using a contrastive objective. Formally, we compute the cosine similarity of each possible sentence pair from the batch.
+We then apply the cross entropy loss by comparing with the true pairs.
+
+#### Hyper parameters
+
+We trained our model on a TPU v3-8. We trained the model for 100k steps using a batch size of 1024 (128 per TPU core).
+We used a learning rate warm up of 500. The sequence length was limited to 128 tokens. We used the AdamW optimizer with
+a 2e-5 learning rate. The full training script is accessible in this current repository: `train_script.py`.
+
+#### Training data
+
+We use the concatenation of multiple datasets to fine-tune our model. The total number of sentence pairs is above 1 billion.
+We sampled each dataset with a weighted probability; the configuration is detailed in the `data_config.json` file.
+
+
+| Dataset | Paper | Number of training tuples |
+|--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
+| [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
+| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts) | [paper](https://aclanthology.org/2020.acl-main.447/) | 116,288,806 |
+| [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
+| [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
+| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
+| [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs | - | 25,316,456 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title+Body, Answer) pairs | - | 21,396,559 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Answer) pairs | - | 21,396,559 |
+| [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
+| [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
+| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
+| [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,151,414 |
+| [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395 |
+| [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
+| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
+| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
+| [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
+| [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
+| [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles) | | 304,525 |
+| AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/)) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (bodies) | | 250,519 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles+bodies) | | 250,460 |
+| [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
+| [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
+| [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
+| [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
+| [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
+| [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
+| [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
+| [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
+| **Total** | | **1,170,060,424** |
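The fine-tuning objective described in this model card (cosine similarity across all in-batch pairs, cross-entropy against the true pairing) is essentially what sentence-transformers calls MultipleNegativesRankingLoss. A minimal PyTorch sketch; the `scale` value here is an assumption, since the card does not state the factor used:

```python
import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(emb_a: torch.Tensor, emb_b: torch.Tensor,
                              scale: float = 20.0) -> torch.Tensor:
    """Cross-entropy over scaled cosine similarities; row i's positive is column i."""
    a = F.normalize(emb_a, p=2, dim=1)   # (batch, dim)
    b = F.normalize(emb_b, p=2, dim=1)   # (batch, dim)
    scores = a @ b.T * scale             # (batch, batch) cosine similarity matrix
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)  # all other rows act as in-batch negatives
```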
models/all-MiniLM-L6-v2/config.json
ADDED
@@ -0,0 +1,25 @@
+{
+  "architectures": [
+    "BertModel"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "classifier_dropout": null,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 384,
+  "initializer_range": 0.02,
+  "intermediate_size": 1536,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 6,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "torch_dtype": "float32",
+  "transformers_version": "4.50.1",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 30522
+}
models/all-MiniLM-L6-v2/config_sentence_transformers.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "__version__": {
+    "sentence_transformers": "4.0.1",
+    "transformers": "4.50.1",
+    "pytorch": "2.6.0+cu118"
+  },
+  "prompts": {},
+  "default_prompt_name": null,
+  "similarity_fn_name": "cosine"
+}
models/all-MiniLM-L6-v2/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1377e9af0ca0b016a9f2aa584d6fc71ab3ea6804fae21ef9fb1416e2944057ac
+size 90864192
models/all-MiniLM-L6-v2/modules.json
ADDED
@@ -0,0 +1,20 @@
+[
+  {
+    "idx": 0,
+    "name": "0",
+    "path": "",
+    "type": "sentence_transformers.models.Transformer"
+  },
+  {
+    "idx": 1,
+    "name": "1",
+    "path": "1_Pooling",
+    "type": "sentence_transformers.models.Pooling"
+  },
+  {
+    "idx": 2,
+    "name": "2",
+    "path": "2_Normalize",
+    "type": "sentence_transformers.models.Normalize"
+  }
+]
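modules.json declares the pipeline sentence-transformers assembles when this directory is loaded: a BERT encoder, mean pooling (per 1_Pooling/config.json above), then L2 normalization, so encodings come out as unit-length 384-d vectors. A small sketch, loading the vendored folder the way rag.py below does:

```python
# Sketch: loading the local model directory (Transformer -> Pooling -> Normalize).
from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("models/all-MiniLM-L6-v2")
vec = model.encode("What is the grace period?")
print(vec.shape)            # (384,)
print(np.linalg.norm(vec))  # ~1.0, because of the Normalize module
```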
models/all-MiniLM-L6-v2/sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
+{
+  "max_seq_length": 256,
+  "do_lower_case": false
+}
models/all-MiniLM-L6-v2/special_tokens_map.json
ADDED
@@ -0,0 +1,37 @@
+{
+  "cls_token": {
+    "content": "[CLS]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "mask_token": {
+    "content": "[MASK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "[PAD]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "sep_token": {
+    "content": "[SEP]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "[UNK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
models/all-MiniLM-L6-v2/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
models/all-MiniLM-L6-v2/tokenizer_config.json
ADDED
@@ -0,0 +1,65 @@
+{
+  "added_tokens_decoder": {
+    "0": {
+      "content": "[PAD]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "100": {
+      "content": "[UNK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "101": {
+      "content": "[CLS]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "102": {
+      "content": "[SEP]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "103": {
+      "content": "[MASK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "clean_up_tokenization_spaces": false,
+  "cls_token": "[CLS]",
+  "do_basic_tokenize": true,
+  "do_lower_case": true,
+  "extra_special_tokens": {},
+  "mask_token": "[MASK]",
+  "max_length": 128,
+  "model_max_length": 256,
+  "never_split": null,
+  "pad_to_multiple_of": null,
+  "pad_token": "[PAD]",
+  "pad_token_type_id": 0,
+  "padding_side": "right",
+  "sep_token": "[SEP]",
+  "stride": 0,
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "truncation_side": "right",
+  "truncation_strategy": "longest_first",
+  "unk_token": "[UNK]"
+}
models/all-MiniLM-L6-v2/vocab.txt
ADDED
The diff for this file is too large to render. See raw diff.
rag.py
ADDED
@@ -0,0 +1,816 @@
+import os
+import json
+import asyncio
+import logging
+import io
+import traceback
+from datetime import datetime, timezone
+from typing import List, Dict, Tuple, Optional
+from fastapi import FastAPI, HTTPException, Depends, Security, BackgroundTasks
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field, validator
+import requests
+import pdfplumber
+from sentence_transformers import SentenceTransformer
+from pinecone import Pinecone, ServerlessSpec
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.schema import Document
+from langchain_groq import ChatGroq
+
+from langchain_core.messages import SystemMessage, HumanMessage
+import time
+import hashlib
+from urllib.parse import urlparse
+import magic
+import docx2txt
+
+model_path = "/app/models/all-MiniLM-L6-v2"
+# Set cache directories to writable path inside the container
+os.environ["HF_HOME"] = "/app/.cache/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface/transformers"
+os.environ["TOKENIZERS_CACHE"] = "/app/.cache/huggingface/tokenizers"
+
+# Optional: create the folders if not exist (may help avoid errors)
+os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
+os.makedirs(os.environ["TOKENIZERS_CACHE"], exist_ok=True)
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Load environment variables
+from pathlib import Path
+env_file = Path(".env")
+if env_file.exists():
+    with open(env_file, 'r') as f:
+        for line in f:
+            line = line.strip()
+            if '=' in line and not line.startswith('#'):
+                key, value = line.split('=', 1)
+                os.environ[key] = value
+
+# Configuration
+class Config:
+    GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
+    PINECONE_API_KEY = os.getenv("PINECONE_API_KEY", "")
+    BEARER_TOKEN = "dbbdb701cfc45d4041e22a03edbfc65753fe9d7b4b9ba1df4884e864f3bb934d"
+    EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+    MAX_CHUNK_SIZE = 1536  # works well with models like MiniLM and e5
+    CHUNK_OVERLAP = 200
+    SIMILARITY_THRESHOLD = 0.2
+    TOP_K = 11
+    PINECONE_INDEX_NAME = "insurance-documents"
+    PINECONE_REGION = "us-east-1"
+    MAX_DOCUMENT_SIZE = 50 * 1024 * 1024  # 50MB
+    REQUEST_TIMEOUT = 60
+    MAX_RETRIES = 3
+
+config = Config()
+
+# Validate configuration
+if not config.GROQ_API_KEY:
+    logger.error("GROQ_API_KEY not found in environment variables")
+if not config.PINECONE_API_KEY:
+    logger.error("PINECONE_API_KEY not found in environment variables")
+
+# Initialize LLM and embeddings with error handling
+try:
+    llm = ChatGroq(
+        api_key=config.GROQ_API_KEY,
+        model="llama3-70b-8192",
+        temperature=0.3,  # Slightly higher temperature for more complete responses
+        max_tokens=2048,  # Explicitly set max tokens
+        max_retries=config.MAX_RETRIES
+    )
+    embedding_model = SentenceTransformer(model_path)
+    logger.info("LLM and embedding model initialized successfully")
+except Exception as e:
+    logger.error(f"Failed to initialize models: {str(e)}")
+    raise
+
+security = HTTPBearer()
+
+# Pydantic Models with validation
+class QueryRequest(BaseModel):
+    documents: str = Field(..., description="Comma-separated URLs to document blobs", min_length=1)
+    questions: List[str] = Field(..., description="List of questions to answer", min_items=1, max_items=50)
+
+    @validator('questions')
+    def validate_questions(cls, v):
+        if not all(question.strip() for question in v):
+            raise ValueError("All questions must be non-empty strings")
+        return [question.strip() for question in v]
+
+    @validator('documents')
+    def validate_documents(cls, v):
+        urls = [url.strip() for url in v.split(',') if url.strip()]
+        if not urls:
+            raise ValueError("At least one valid document URL must be provided")
+        for url in urls:
+            parsed = urlparse(url)
+            if not parsed.scheme or not parsed.netloc:
+                raise ValueError(f"Invalid URL format: {url}")
+        return v
+
+class QueryResponse(BaseModel):
+    answers: List[str] = Field(..., description="List of answers")
+    processing_time: float = Field(..., description="Total processing time in seconds")
+    documents_processed: int = Field(..., description="Number of documents processed")
+    chunks_retrieved: int = Field(..., description="Total chunks retrieved for all questions")
+
+# Enhanced Document Processor
+class DocumentProcessor:
+    def __init__(self):
+        self.text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.MAX_CHUNK_SIZE,
+            chunk_overlap=config.CHUNK_OVERLAP,
+            separators=["\n\n", "\n", ". ", " ", ""]
+        )
+        self.document_cache = {}
+        self.supported_types = {
+            'application/pdf': self._extract_pdf_text,
+            'application/vnd.openxmlformats-officedocument.wordprocessingml.document': self._extract_docx_text,
+            'text/plain': self._extract_text_content,
+            'text/html': self._extract_text_content
+        }
+
+    def _get_document_hash(self, url: str) -> str:
+        """Generate a hash for the document URL for caching"""
+        return hashlib.md5(url.encode()).hexdigest()
+
+    def download_document(self, url: str) -> Tuple[bytes, str]:
+        """Download document and return content with MIME type"""
+        try:
+            # Validate URL format more strictly
+            parsed = urlparse(url)
+            if not parsed.scheme or not parsed.netloc:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Invalid URL format: {url}"
+                )
+
+            # Check if domain is reachable (basic validation)
+            import socket
+            try:
+                socket.gethostbyname(parsed.netloc.split(':')[0])
+            except socket.gaierror:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Domain not reachable: {parsed.netloc}"
+                )
+
+            headers = {
+                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
+            }
+
+            with requests.get(
+                url,
+                timeout=config.REQUEST_TIMEOUT,
+                headers=headers,
+                stream=True
+            ) as response:
+                response.raise_for_status()
+
+                # Check content length
+                content_length = response.headers.get('content-length')
+                if content_length and int(content_length) > config.MAX_DOCUMENT_SIZE:
+                    raise HTTPException(
+                        status_code=413,
+                        detail=f"Document too large. Max size: {config.MAX_DOCUMENT_SIZE} bytes"
+                    )
+
+                content = response.content
+
+                # Detect MIME type
+                try:
+                    mime_type = magic.from_buffer(content[:1024], mime=True)
+                except:
+                    # Fallback to content-type header or URL extension
+                    mime_type = response.headers.get('content-type', '').split(';')[0]
+                    if not mime_type:
+                        if url.lower().endswith('.pdf'):
+                            mime_type = 'application/pdf'
+                        elif url.lower().endswith('.docx'):
+                            mime_type = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
+                        else:
+                            mime_type = 'text/plain'
+
+                return content, mime_type
+
+        except requests.RequestException as e:
+            logger.error(f"Failed to download {url}: {str(e)}")
+            raise HTTPException(
+                status_code=400,
+                detail=f"Failed to download document: {str(e)}"
+            )
+        except HTTPException:
+            raise  # Re-raise HTTP exceptions
+        except Exception as e:
+            logger.error(f"Unexpected error downloading {url}: {str(e)}")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Unexpected error downloading document: {str(e)}"
+            )
+
+    def _extract_pdf_text(self, content: bytes) -> str:
+        """Extract text from PDF content"""
+        try:
+            # Convert bytes to file-like object
+            pdf_file = io.BytesIO(content)
+
+            with pdfplumber.open(pdf_file) as pdf:
+                text_parts = []
+
+                for page_num, page in enumerate(pdf.pages):
+                    try:
+                        page_text = page.extract_text()
+                        if page_text:
+                            text_parts.append(f"\n--- Page {page_num + 1} ---\n{page_text.strip()}")
+                    except Exception as e:
+                        logger.warning(f"Failed to extract text from page {page_num + 1}: {str(e)}")
+                        continue
+
+                full_text = "\n".join(text_parts)
+
+                if not full_text.strip():
+                    # Try alternative extraction methods
+                    logger.info("Standard extraction failed, trying alternative methods")
+                    # You could add OCR here if needed (like pytesseract)
+                    return "No readable text content found in PDF"
+
+                return full_text.strip()
+
+        except Exception as e:
+            logger.error(f"PDF extraction failed: {str(e)}")
+            logger.error(traceback.format_exc())
+            raise HTTPException(
+                status_code=400,
+                detail=f"Failed to extract PDF text: {str(e)}"
+            )
+
+    def _extract_docx_text(self, content: bytes) -> str:
+        """Extract text from DOCX content"""
+        try:
+            docx_file = io.BytesIO(content)
+            text = docx2txt.process(docx_file)
+            return text.strip() if text else "No text content found in document"
+        except Exception as e:
+            logger.error(f"DOCX extraction failed: {str(e)}")
+            raise HTTPException(
+                status_code=400,
+                detail=f"Failed to extract DOCX text: {str(e)}"
+            )
+
+    def _extract_text_content(self, content: bytes) -> str:
+        """Extract text from plain text or HTML content"""
+        try:
+            # Try different encodings
+            encodings = ['utf-8', 'utf-16', 'latin-1', 'cp1252']
+
+            for encoding in encodings:
+                try:
+                    text = content.decode(encoding, errors='ignore')
+                    if text.strip():
+                        return text.strip()
+                except:
+                    continue
+
+            return "Unable to decode text content"
+        except Exception as e:
+            logger.error(f"Text extraction failed: {str(e)}")
+            return "Failed to extract text content"
+
+    def process_document(self, url: str) -> List[Document]:
+        """Process a document and return chunks"""
+        doc_hash = self._get_document_hash(url)
+
+        if doc_hash in self.document_cache:
+            logger.info(f"Using cached document for {url}")
+            return self.document_cache[doc_hash]
+
+        try:
+            content, mime_type = self.download_document(url)
+            logger.info(f"Downloaded document {url} with MIME type: {mime_type}")
+
+            # Extract text based on MIME type
+            if mime_type in self.supported_types:
+                text = self.supported_types[mime_type](content)
+            else:
+                logger.warning(f"Unsupported MIME type {mime_type}, treating as plain text")
+                text = self._extract_text_content(content)
+
+            if not text or len(text.strip()) < 10:
+                raise HTTPException(
+                    status_code=400,
+                    detail="Document appears to be empty or contains insufficient text content"
+                )
+
+            # Split text into chunks
+            chunks = self.text_splitter.split_text(text)
+
+            # Filter out very short chunks
+            meaningful_chunks = [chunk for chunk in chunks if len(chunk.strip()) > 20]
+
+            if not meaningful_chunks:
+                raise HTTPException(
+                    status_code=400,
+                    detail="No meaningful text chunks could be extracted from the document"
+                )
+
+            documents = [
+                Document(
+                    page_content=chunk,
+                    metadata={
+                        "source": url,
+                        "chunk_id": i,
+                        "mime_type": mime_type,
+                        "doc_hash": doc_hash
+                    }
+                )
+                for i, chunk in enumerate(meaningful_chunks)
+            ]
+
+            self.document_cache[doc_hash] = documents
+            logger.info(f"Processed {len(documents)} chunks for {url}")
+            return documents
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error processing document {url}: {str(e)}")
+            logger.error(traceback.format_exc())
+            raise HTTPException(
+                status_code=500,
+                detail=f"Unexpected error processing document: {str(e)}"
+            )
+
+# Enhanced Pinecone Vector Store
+class PineconeVectorStore:
+    def __init__(self, api_key: str, index_name: str):
+        try:
+            self.pc = Pinecone(api_key=api_key)
+            self.index_name = index_name
+            self.dimension = 384
+
+            # Check if index exists, create if not
+            existing_indexes = [index.name for index in self.pc.list_indexes()]
+
+            if index_name not in existing_indexes:
+                logger.info(f"Creating new Pinecone index: {index_name}")
+                self.pc.create_index(
+                    name=index_name,
+                    dimension=self.dimension,
+                    metric="cosine",
+                    spec=ServerlessSpec(cloud="aws", region=config.PINECONE_REGION)
+                )
+                # Wait for index to be ready
+                time.sleep(10)
+
+            self.index = self.pc.Index(index_name)
+            self.processed_docs = set()
+            logger.info(f"Pinecone vector store initialized successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Pinecone: {str(e)}")
+            raise
+
+    def document_exists(self, doc_hash: str) -> bool:
+        """Check if document is already indexed"""
+        return doc_hash in self.processed_docs
+
+    async def add_documents(self, documents: List[Document], batch_size: int = 100):
+        """Add documents to the vector store in batches"""
+        try:
+            doc_hash = documents[0].metadata.get('doc_hash')
+
+            if self.document_exists(doc_hash):
+                logger.info(f"Document {doc_hash} already indexed")
+                return
+
+            # Process in batches to avoid memory issues
+            for i in range(0, len(documents), batch_size):
+                batch = documents[i:i + batch_size]
+                vectors = []
+
+                for doc in batch:
+                    try:
+                        embedding = embedding_model.encode(doc.page_content).tolist()
+                        vector = {
+                            "id": f"{doc_hash}_{doc.metadata['chunk_id']}",
+                            "values": embedding,
+                            "metadata": {
+                                "text": doc.page_content[:1000],  # Limit metadata size
+                                "source": doc.metadata['source'],
+                                "chunk_id": doc.metadata['chunk_id'],
+                                "doc_hash": doc_hash
+                            }
+                        }
+                        vectors.append(vector)
+                    except Exception as e:
+                        logger.error(f"Failed to create embedding for chunk {doc.metadata['chunk_id']}: {str(e)}")
+                        continue
+
+                if vectors:
+                    self.index.upsert(vectors=vectors)
+                    logger.info(f"Upserted batch of {len(vectors)} vectors")
+
+            self.processed_docs.add(doc_hash)
+            logger.info(f"Successfully indexed {len(documents)} chunks")
+
+        except Exception as e:
+            logger.error(f"Failed to add documents to vector store: {str(e)}")
+            raise
+
+    async def similarity_search(self, query: str, top_k: int = config.TOP_K) -> List[Tuple[Document, float]]:
+        """Perform similarity search"""
+        try:
+            query_embedding = embedding_model.encode(query).tolist()
+
+            results = self.index.query(
+                vector=query_embedding,
+                top_k=top_k,
+                include_metadata=True
+            )
+
+            documents_with_scores = []
+            for match in results.matches:
+                if match.score >= config.SIMILARITY_THRESHOLD:
+                    doc = Document(
+                        page_content=match.metadata.get("text", ""),
+                        metadata=match.metadata
+                    )
+                    documents_with_scores.append((doc, float(match.score)))
+
+            logger.info(f"Retrieved {len(documents_with_scores)} relevant chunks for query")
+            return documents_with_scores
+
+        except Exception as e:
+            logger.error(f"Similarity search failed: {str(e)}")
+            return []
+
+    async def delete_documents(self, doc_hashes: List[str]):
+        """Delete documents from the vector store"""
+        try:
+            for doc_hash in doc_hashes:
+                # Delete all vectors for this document
+                delete_response = self.index.delete(filter={"doc_hash": {"$eq": doc_hash}})
+                logger.info(f"Deleted vectors for document {doc_hash}")
+                self.processed_docs.discard(doc_hash)
+        except Exception as e:
+            logger.error(f"Failed to delete documents: {str(e)}")
+
+# Enhanced Insurance Query Processor
+class InsuranceQueryEnhancer:
+    def __init__(self):
+        self.insurance_terms = {
+            'premium': ['payment', 'installment', 'fee', 'cost'],
+            'coverage': ['benefit', 'protection', 'indemnity', 'compensation'],
+            'waiting period': ['qualification period', 'cooling period'],
+            'grace period': ['extension period', 'buffer period'],
+            'maternity': ['pregnancy', 'childbirth', 'delivery'],
+            'pre-existing': ['prior condition', 'existing condition'],
+            'deductible': ['excess', 'co-payment'],
+            'exclusion': ['limitation', 'restriction'],
+            'claim': ['settlement', 'reimbursement'],
+            'policy': ['contract', 'agreement', 'plan']
+        }
+
+    def expand_query(self, query: str) -> str:
+        """Expand query with insurance-specific synonyms"""
+        query_lower = query.lower()
+        expanded_terms = [query]
+
+        for main_term, synonyms in self.insurance_terms.items():
+            if main_term in query_lower:
+                for synonym in synonyms:
+                    expanded_terms.append(query.lower().replace(main_term, synonym))
+
+        return ' '.join(expanded_terms)
+
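As a quick illustration of what `expand_query` feeds the retriever (the lowercased copies reflect how the method behaves as written):

```python
enhancer = InsuranceQueryEnhancer()
print(enhancer.expand_query("What is the grace period?"))
# -> "What is the grace period? what is the extension period? what is the buffer period?"
```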
# FastAPI App
app = FastAPI(
    title="Robust RAG System for Insurance Documents",
    description="Advanced RAG system with comprehensive error handling and document processing",
    version="3.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)

# Initialize components
processor = DocumentProcessor()
vector_store = PineconeVectorStore(config.PINECONE_API_KEY, config.PINECONE_INDEX_NAME)
query_enhancer = InsuranceQueryEnhancer()

# Authentication
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    if credentials.credentials != config.BEARER_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid authentication token")
    return credentials.credentials

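# Client-side sketch of the expected header (illustrative; the token value is
# a placeholder for config.BEARER_TOKEN and the host is assumed to be local):
#
#   import requests
#   headers = {"Authorization": "Bearer <your-token>"}
#   requests.get("http://localhost:8000/metrics", headers=headers)
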
# API Endpoints
@app.post("/hackrx/run", response_model=QueryResponse)
async def query_retrieval(
    request: QueryRequest,
    background_tasks: BackgroundTasks,
    token: str = Depends(verify_token)
):
    start_time = time.time()
    total_chunks_retrieved = 0
    processed_docs = 0

    try:
        doc_urls = [url.strip() for url in request.documents.split(',') if url.strip()]
        logger.info(f"Processing {len(doc_urls)} documents and {len(request.questions)} questions")

        # Process documents with better error handling
        doc_hashes = []
        failed_docs = []

        for url in doc_urls:
            try:
                doc_hash = processor._get_document_hash(url)

                if not vector_store.document_exists(doc_hash):
                    logger.info(f"Processing new document: {url}")
                    documents = processor.process_document(url)
                    await vector_store.add_documents(documents)
                    processed_docs += 1
                else:
                    logger.info(f"Document already processed: {url}")
                    processed_docs += 1

                doc_hashes.append(doc_hash)

            except HTTPException as e:
                logger.error(f"HTTP error processing document {url}: {e.detail}")
                failed_docs.append(f"{url}: {e.detail}")
                continue
            except Exception as e:
                logger.error(f"Unexpected error processing document {url}: {str(e)}")
                failed_docs.append(f"{url}: {str(e)}")
                continue

        # If no documents were successfully processed, return an error
        if processed_docs == 0:
            error_msg = "No documents could be processed successfully."
            if failed_docs:
                error_msg += f" Errors: {'; '.join(failed_docs[:3])}"  # Show first 3 errors
            raise HTTPException(
                status_code=400,
                detail=error_msg
            )

        # Process questions
        async def process_question(question: str) -> str:
            nonlocal total_chunks_retrieved
            try:
                expanded_query = query_enhancer.expand_query(question)
                retrieved_docs = await vector_store.similarity_search(expanded_query)

                if not retrieved_docs:
                    logger.warning(f"No relevant information found for question: {question}")
                    return "No relevant information found in the documents for this question."

                total_chunks_retrieved += len(retrieved_docs)

                # Build context from retrieved documents
                context_parts = []
                for i, (doc, score) in enumerate(retrieved_docs):
                    context_parts.append(
                        f"[Chunk {i+1} - Relevance: {score:.3f}]\n{doc.page_content}"
                    )

                context = "\n\n".join(context_parts)

                # Enhanced system prompt
                system_prompt = """You are an expert insurance policy analyst with comprehensive knowledge of insurance regulations, particularly Indian insurance policies.

Your expertise includes:
- Policy terms, conditions, and exclusions
- Premium calculations and payment structures
- Claim procedures and settlement processes
- Waiting periods, grace periods, coverage limits, and deductibles
- Pre-existing disease clauses, maternity benefits, and specialized treatments
- Regulatory compliance and policy terminology

Instructions for answering:
1. Provide precise, factual answers based exclusively on the document context provided
2. Include specific amounts, percentages, time periods, and conditions when mentioned
3. Clearly state any conditions, limitations, or exclusions that apply
4. Use proper insurance terminology and maintain professional language
5. If information is not available in the context, explicitly state this
6. When referencing policy sections or clauses, mention them if available
7. Provide comprehensive answers that address all aspects of the question

Format your response clearly and professionally."""

                messages = [
                    SystemMessage(content=system_prompt),
                    HumanMessage(content=f"""Based on the following document context, please answer the question comprehensively:

CONTEXT:
{context}

QUESTION: {question}

Please provide a detailed, accurate answer based solely on the information in the context above.""")
                ]

                response = await llm.ainvoke(messages)
                return response.content

            except Exception as e:
                logger.error(f"Error processing question '{question}': {str(e)}")
                return f"An error occurred while processing this question: {str(e)}"

        # Process all questions concurrently
        logger.info("Processing questions concurrently...")
        answers = await asyncio.gather(*[process_question(q) for q in request.questions])

        processing_time = time.time() - start_time
        logger.info(f"Completed processing in {processing_time:.2f} seconds")

        # Schedule cleanup in background
        background_tasks.add_task(vector_store.delete_documents, doc_hashes)

        response_data = QueryResponse(
            answers=answers,
            processing_time=processing_time,
            documents_processed=processed_docs,
            chunks_retrieved=total_chunks_retrieved
        )

        # Log a warning if some documents failed
        if failed_docs:
            logger.warning(f"Some documents failed to process: {failed_docs}")

        return response_data

    except HTTPException:
        raise  # Re-raise HTTP exceptions
    except Exception as e:
        logger.error(f"Error in query retrieval: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )

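# Illustrative request body for POST /hackrx/run (note that `documents` is a
# single comma-separated string of URLs, not a list; the URLs here are
# placeholders):
#
#   {
#       "documents": "https://example.com/policy-a.pdf,https://example.com/policy-b.pdf",
#       "questions": ["What is the grace period for premium payment?"]
#   }
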
# Dedicated document validation endpoint
@app.post("/validate-documents")
async def validate_documents(
    documents: str,
    token: str = Depends(verify_token)
):
    """Validate document URLs without processing them"""
    try:
        doc_urls = [url.strip() for url in documents.split(',') if url.strip()]
        results = []

        for url in doc_urls:
            try:
                # Basic URL validation
                parsed = urlparse(url)
                if not parsed.scheme or not parsed.netloc:
                    results.append({
                        "url": url,
                        "valid": False,
                        "error": "Invalid URL format"
                    })
                    continue

                # Test connectivity
                response = requests.head(url, timeout=10, allow_redirects=True)

                results.append({
                    "url": url,
                    "valid": response.status_code < 400,
                    "status_code": response.status_code,
                    "content_type": response.headers.get('content-type', 'unknown'),
                    "content_length": response.headers.get('content-length', 'unknown')
                })

            except Exception as e:
                results.append({
                    "url": url,
                    "valid": False,
                    "error": str(e)
                })

        return {
            "validation_results": results,
            "valid_count": sum(1 for r in results if r.get('valid', False)),
            "total_count": len(results)
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Validation error: {str(e)}"
        )

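# Illustrative call (since `documents` is declared as a plain `str` parameter,
# FastAPI reads it from the query string; host and token are placeholders):
#
#   requests.post(
#       "http://localhost:8000/validate-documents?documents=https://example.com/policy.pdf",
#       headers={"Authorization": "Bearer <your-token>"},
#   )
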
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        # Test basic functionality
        test_embedding = embedding_model.encode("test")

        return {
            "status": "healthy",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "version": "3.0.0",
            "components": {
                "embedding_model": "operational",
                "vector_store": "operational",
                "llm": "operational"
            }
        }
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        raise HTTPException(status_code=503, detail="Service unhealthy")

@app.get("/metrics")
async def get_metrics(token: str = Depends(verify_token)):
    """Get system metrics"""
    return {
        "status": "operational",
        "configuration": {
            "pinecone_index": config.PINECONE_INDEX_NAME,
            "embedding_model": config.EMBEDDING_MODEL,
            "max_chunk_size": config.MAX_CHUNK_SIZE,
            "similarity_threshold": config.SIMILARITY_THRESHOLD,
            "top_k": config.TOP_K
        },
        "version": "3.0.0",
        "features": [
            "multi_format_document_processing",
            "pinecone_vector_database",
            "parallel_question_processing",
            "insurance_domain_optimization",
            "robust_error_handling",
            "document_caching",
            "batch_processing"
        ]
    }

@app.post("/webhook")
async def hackathon_webhook(request: dict):
    """Webhook endpoint for hackathon"""
    logger.info(f"Webhook received: {request}")
    return {
        "status": "success",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "system_health": await health_check(),
        "api_endpoints": {
            "main_submission": "/hackrx/run",
            "health_check": "/health",
            "metrics": "/metrics"
        }
    }

@app.get("/")
def read_root():
    return {"message": "RAG backend is up and running!"}

# Error handlers
from fastapi.responses import JSONResponse  # needed for the global handler below

@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    logger.error(f"Global exception handler: {str(exc)}")
    logger.error(traceback.format_exc())
    # An exception handler must return a response object; returning an
    # HTTPException here would not produce a valid 500 response.
    return JSONResponse(
        status_code=500,
        content={"detail": "An unexpected error occurred. Please try again later."}
    )

# Startup event
@app.on_event("startup")
async def startup_event():
    logger.info("RAG System starting up...")
    logger.info(f"Configuration loaded: Index={config.PINECONE_INDEX_NAME}, Model={config.EMBEDDING_MODEL}")

# Shutdown event
@app.on_event("shutdown")
async def shutdown_event():
    logger.info("RAG System shutting down...")

# Run the app
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
        access_log=True
    )
railway.json
ADDED
@@ -0,0 +1,12 @@
{
    "build": {
        "builder": "DOCKERFILE"
    },
    "deploy": {
        "startCommand": "uvicorn rag:app --host 0.0.0.0 --port $PORT",
        "healthcheckPath": "/health",
        "healthcheckTimeout": 300,
        "restartPolicyType": "ON_FAILURE",
        "restartPolicyMaxRetries": 3
    }
}
req.txt
ADDED
@@ -0,0 +1,13 @@
fastapi
langchain
langchain-groq
pydantic
pinecone
python-docx
python-dotenv
python-magic
requests
sentence-transformers
pdfplumber
uvicorn
docx2txt
requirements.txt
ADDED
@@ -0,0 +1,14 @@

fastapi
langchain
langchain-groq
pydantic
pinecone  # Correct package name for the Pinecone client
python-docx
python-dotenv
python-magic  # Requires libmagic on Linux/macOS; on Windows use python-magic-bin instead
requests
sentence-transformers
pdfplumber
uvicorn
docx2txt
test.py
ADDED
@@ -0,0 +1,231 @@
import requests
import json
import time
import asyncio
from typing import Dict, List
import logging
from concurrent.futures import ThreadPoolExecutor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RAGSystemTester:
    def __init__(self, base_url: str = "http://localhost:8000", bearer_token: str = None):
        self.base_url = base_url
        self.bearer_token = bearer_token or "dbbdb701cfc45d4041e22a03edbfc65753fe9d7b4b9ba1df4884e864f3bb934d"
        self.headers = {
            "Authorization": f"Bearer {self.bearer_token}",
            "Content-Type": "application/json"
        }
        self.executor = ThreadPoolExecutor(max_workers=3)

    def test_health_check(self) -> bool:
        """Test health check endpoint"""
        try:
            response = requests.get(f"{self.base_url}/health", timeout=10)
            if response.status_code == 200:
                data = response.json()
                print(f"✅ Health check passed: {data}")
                return True
            else:
                print(f"❌ Health check failed: Status {response.status_code}")
                logger.error(f"Health check failed with status: {response.status_code}")
                return False
        except Exception as e:
            print(f"❌ Health check error: {str(e)}")
            logger.error(f"Health check error: {str(e)}")
            return False

    def test_metrics_endpoint(self) -> bool:
        """Test metrics endpoint"""
        try:
            response = requests.get(f"{self.base_url}/metrics", headers=self.headers, timeout=10)
            if response.status_code == 200:
                data = response.json()
                print(f"✅ Metrics endpoint passed: {json.dumps(data, indent=2)}")
                return True
            else:
                print(f"❌ Metrics endpoint failed: Status {response.status_code}")
                logger.error(f"Metrics endpoint failed with status: {response.status_code}")
                return False
        except Exception as e:
            print(f"❌ Metrics endpoint error: {str(e)}")
            logger.error(f"Metrics endpoint error: {str(e)}")
            return False

    def test_sample_query(self) -> bool:
        """Test with the provided sample data"""
        sample_data = {
            "documents": "https://hackrx.blob.core.windows.net/assets/Arogya%20Sanjeevani%20Policy%20-%20CIN%20-%20U10200WB1906GOI001713%201.pdf?sv=2023-01-03&st=2025-07-21T08%3A29%3A02Z&se=2025-09-22T08%3A29%3A00Z&sr=b&sp=r&sig=nzrz1K9Iurt%2BBXom%2FB%2BMPTFMFP3PRnIvEsipAX10Ig4%3D",
            "questions": [
                "What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?",
                "What is the waiting period for pre-existing diseases (PED) to be covered?",
                "Does this policy cover maternity expenses, and what are the conditions?",
                "What is the waiting period for cataract surgery?",
                "Are the medical expenses for an organ donor covered under this policy?",
                "What is the No Claim Discount (NCD) offered in this policy?",
                "Is there a benefit for preventive health check-ups?",
                "How does the policy define a 'Hospital'?",
                "What is the extent of coverage for AYUSH treatments?",
                "Are there any sub-limits on room rent and ICU charges for Plan A?"
            ]
        }

        try:
            print("🔍 Testing sample query...")
            start_time = time.time()

            response = requests.post(
                f"{self.base_url}/hackrx/run",
                headers=self.headers,
                json=sample_data,
                timeout=120
            )

            end_time = time.time()
            latency = end_time - start_time

            if response.status_code == 200:
                data = response.json()
                answers = data.get("answers", [])

                print(f"✅ Sample query successful (Latency: {latency:.2f}s)")
                print(f"📝 Received {len(answers)} answers")

                # Print all answers for validation
                for i, (question, answer) in enumerate(zip(sample_data['questions'], answers)):
                    print(f"Q{i+1}: {question}")
                    print(f"A{i+1}: {answer[:200]}..." if len(answer) > 200 else f"A{i+1}: {answer}")
                    print("-" * 50)

                # Validate that we received answers for all questions
                if len(answers) == len(sample_data['questions']):
                    print("✅ All questions answered")
                    return True
                else:
                    print(f"❌ Incomplete response: Expected {len(sample_data['questions'])} answers, got {len(answers)}")
                    logger.warning(f"Incomplete response: Expected {len(sample_data['questions'])} answers, got {len(answers)}")
                    return False
            else:
                print(f"❌ Sample query failed: Status {response.status_code}")
                print(f"Response: {response.text}")
                logger.error(f"Sample query failed: Status {response.status_code}, Response: {response.text}")
                return False
        except Exception as e:
            print(f"❌ Sample query error: {str(e)}")
            logger.error(f"Sample query error: {str(e)}")
            return False

    async def test_concurrent_queries(self, num_requests: int = 3) -> bool:
        """Test system under concurrent load"""
        def make_request() -> bool:
            try:
                response = requests.post(
                    f"{self.base_url}/hackrx/run",
                    headers=self.headers,
                    json={
                        "documents": "https://hackrx.blob.core.windows.net/assets/Arogya%20Sanjeevani%20Policy%20-%20CIN%20-%20U10200WB1906GOI001713%201.pdf?sv=2023-01-03&st=2025-07-21T08%3A29%3A02Z&se=2025-09-22T08%3A29%3A00Z&sr=b&sp=r&sig=nzrz1K9Iurt%2BBXom%2FB%2BMPTFMFP3PRnIvEsipAX10Ig4%3D",
                        "questions": ["What is the grace period for premium payment?"]
                    },
                    timeout=60
                )
                return response.status_code == 200
            except Exception as e:
                logger.error(f"Concurrent query error: {str(e)}")
                return False

        print(f"🚀 Testing {num_requests} concurrent queries...")
        # requests is blocking, so run each call in the thread pool to get
        # genuine concurrency instead of sequential awaits
        loop = asyncio.get_running_loop()
        tasks = [loop.run_in_executor(self.executor, make_request) for _ in range(num_requests)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        success_count = sum(1 for result in results if result is True)
        print(f"✅ Concurrent test completed: {success_count}/{num_requests} successful")

        return success_count == num_requests

    def test_invalid_token(self) -> bool:
        """Test authentication with invalid token"""
        try:
            invalid_headers = {
                "Authorization": "Bearer invalid_token",
                "Content-Type": "application/json"
            }
            response = requests.post(
                f"{self.base_url}/hackrx/run",
                headers=invalid_headers,
                json={
                    "documents": "https://hackrx.blob.core.windows.net/assets/Arogya%20Sanjeevani%20Policy%20-%20CIN%20-%20U10200WB1906GOI001713%201.pdf?sv=2023-01-03&st=2025-07-21T08%3A29%3A02Z&se=2025-09-22T08%3A29%3A00Z&sr=b&sp=r&sig=nzrz1K9Iurt%2BBXom%2FB%2BMPTFMFP3PRnIvEsipAX10Ig4%3D",
                    "questions": ["Test question"]
                },
                timeout=10
            )
            if response.status_code == 401:
                print("✅ Invalid token test passed: Correctly rejected")
                return True
            else:
                print(f"❌ Invalid token test failed: Expected 401, got {response.status_code}")
                logger.warning(f"Invalid token test failed: Expected 401, got {response.status_code}")
                return False
        except Exception as e:
            print(f"❌ Invalid token test error: {str(e)}")
            logger.error(f"Invalid token test error: {str(e)}")
            return False

    def test_invalid_url(self) -> bool:
        """Test with invalid document URL"""
        try:
            response = requests.post(
                f"{self.base_url}/hackrx/run",
                headers=self.headers,
                json={
                    "documents": "https://invalid-url-that-does-not-exist.com/fake.pdf",  # Deliberately unreachable URL
                    "questions": ["Test question"]
                },
                timeout=30
            )
            # Accept either 400 or 500 as valid error responses for invalid URLs
            if response.status_code in [400, 500]:
                print("✅ Invalid URL test passed: Correctly handled")
                return True
            else:
                print(f"❌ Invalid URL test failed: Expected 400/500, got {response.status_code}")
                logger.warning(f"Invalid URL test failed: Expected 400/500, got {response.status_code}")
                return False
        except Exception as e:
            print(f"❌ Invalid URL test error: {str(e)}")
            logger.error(f"Invalid URL test error: {str(e)}")
            return False

    async def run_all_tests(self):
        """Run all test cases"""
        print("🚀 Starting RAG System Tests")
        print("=" * 50)

        results = {
            "health_check": self.test_health_check(),
            "metrics_endpoint": self.test_metrics_endpoint(),
            "sample_query": self.test_sample_query(),
            "concurrent_queries": await self.test_concurrent_queries(),
            "invalid_token": self.test_invalid_token(),
            "invalid_url": self.test_invalid_url()
        }

        print("\n📊 Test Summary")
        print("=" * 50)
        passed_count = sum(1 for result in results.values() if result)
        total = len(results)

        # Use a distinct loop variable so the pass count is not overwritten
        for test_name, result in results.items():
            status = "✅ PASSED" if result else "❌ FAILED"
            print(f"{test_name}: {status}")

        print(f"\n🎯 Overall: {passed_count}/{total} tests passed")
        return passed_count == total

def main():
    tester = RAGSystemTester()
    asyncio.run(tester.run_all_tests())

if __name__ == "__main__":
    main()
test_api.py
ADDED
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
"""
Test script for the Ultra-Fast RAG System API
Usage: python test_api.py <base_url>
Example: python test_api.py https://your-app.railway.app
"""

import requests
import json
import time
import sys

def test_api(base_url):
    """Test the deployed API endpoints"""

    print(f"🧪 Testing API at: {base_url}")

    # Test 1: Health Check
    print("\n1️⃣ Testing Health Check...")
    try:
        response = requests.get(f"{base_url}/health", timeout=10)
        if response.status_code == 200:
            print("✅ Health check passed!")
            print(f"   Response: {response.json()}")
        else:
            print(f"❌ Health check failed: {response.status_code}")
            return False
    except Exception as e:
        print(f"❌ Health check error: {e}")
        return False

    # Test 2: Sample RAG Query
    print("\n2️⃣ Testing RAG Endpoint...")

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": "Bearer dbbdb701cfc45d4041e22a03edbfc65753fe9d7b4b9ba1df4884e864f3bb934d"
    }

    test_payload = {
        "documents": "https://hackrx.blob.core.windows.net/assets/policy.pdf?sv=2023-01-03&st=2025-07-04T09%3A11%3A24Z&se=2027-07-05T09%3A11%3A00Z&sr=b&sp=r&sig=N4a9OU0w0QXO6AOIBiu4bpl7AXvEZogeT%2FjUHNO7HzQ%3D",
        "questions": [
            "What is the grace period for premium payment?",
            "What is the waiting period for pre-existing diseases?"
        ]
    }

    try:
        print("   Sending request... (this may take 15-30 seconds)")
        start_time = time.time()

        response = requests.post(
            f"{base_url}/hackrx/run",
            json=test_payload,
            headers=headers,
            timeout=120
        )

        end_time = time.time()
        response_time = end_time - start_time

        if response.status_code == 200:
            result = response.json()
            print(f"✅ RAG query successful! ({response_time:.2f} seconds)")
            print(f"   Questions: {len(test_payload['questions'])}")
            print(f"   Answers: {len(result['answers'])}")
            print("\n   Sample answers:")
            for i, answer in enumerate(result['answers'][:2]):
                print(f"   Q{i+1}: {answer[:100]}...")
        else:
            print(f"❌ RAG query failed: {response.status_code}")
            print(f"   Response: {response.text}")
            return False

    except Exception as e:
        print(f"❌ RAG query error: {e}")
        return False

    # Test 3: Metrics (optional)
    print("\n3️⃣ Testing Metrics Endpoint...")
    try:
        response = requests.get(
            f"{base_url}/metrics",
            headers=headers,
            timeout=10
        )
        if response.status_code == 200:
            print("✅ Metrics endpoint working!")
            metrics = response.json()
            # The /metrics endpoint reports status and configuration, not query counts
            print(f"   Status: {metrics.get('status', 'unknown')}")
        else:
            print(f"⚠️ Metrics endpoint issue: {response.status_code}")
    except Exception as e:
        print(f"⚠️ Metrics endpoint error: {e}")

    print(f"\n🎉 API testing complete! System is ready for hackathon use.")
    return True

def main():
    if len(sys.argv) != 2:
        print("Usage: python test_api.py <base_url>")
        print("Example: python test_api.py https://your-app.railway.app")
        sys.exit(1)

    base_url = sys.argv[1].rstrip('/')

    print("🚀 Ultra-Fast RAG System API Tester")
    print("=" * 50)

    success = test_api(base_url)

    if success:
        print("\n✅ All tests passed! Your API is ready for the hackathon! 🎉")
        print(f"\n📋 API Usage Summary:")
        print(f"   Endpoint: POST {base_url}/hackrx/run")
        print(f"   Auth: Bearer dbbdb701cfc45d4041e22a03edbfc65753fe9d7b4b9ba1df4884e864f3bb934d")
        print(f"   Health: GET {base_url}/health")
    else:
        print("\n❌ Some tests failed. Please check your deployment.")
        sys.exit(1)

if __name__ == "__main__":
    main()