Vivek Kadamati committed
Commit · ee444c0
Parent(s): none (initial commit)
Initial commit
Browse files:
- .env.example +15 -0
- .gitignore +26 -0
- Dockerfile +28 -0
- ENHANCEMENTS.md +120 -0
- GIT_PUSH_GUIDE.md +156 -0
- Procfile +1 -0
- README.md +294 -0
- SETUP.md +69 -0
- UPDATE_REMOTE.md +178 -0
- __init__.py +15 -0
- api.py +374 -0
- chunking_strategies.py +207 -0
- cleanup_chroma.py +93 -0
- config.py +64 -0
- dataset_loader.py +178 -0
- docker-compose.yml +26 -0
- embedding_models.py +325 -0
- example.py +118 -0
- llm_client.py +351 -0
- requirements.txt +40 -0
- run.py +99 -0
- streamlit_app.py +721 -0
- trace_evaluator.py +352 -0
- vector_store.py +412 -0
.env.example
ADDED
@@ -0,0 +1,15 @@
# Groq API Configuration
GROQ_API_KEY=your_groq_api_key_here

# Google Gemini API Configuration (for gemini-embedding-001)
GEMINI_API_KEY=your_gemini_api_key_here

# ChromaDB Configuration
CHROMA_PERSIST_DIRECTORY=./chroma_db

# Rate Limiting
GROQ_RPM_LIMIT=30
RATE_LIMIT_DELAY=2.0

# Application Configuration
LOG_LEVEL=INFO
.gitignore
ADDED
@@ -0,0 +1,26 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.so
*.egg
*.egg-info/
dist/
build/
.env
.venv
venv/
env/
data_cache/
*.log
.DS_Store
.vscode/
.idea/
*.swp
*.swo
*~
.pytest_cache/
.coverage
htmlcov/
chroma_db/
Dockerfile
ADDED
@@ -0,0 +1,28 @@
FROM python:3.9-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY . .

# Create directories for data
RUN mkdir -p chroma_db data_cache

# Expose ports
EXPOSE 8501 8000

# Set environment variables
ENV PYTHONUNBUFFERED=1

# Run Streamlit by default
CMD ["streamlit", "run", "streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
ENHANCEMENTS.md
ADDED
@@ -0,0 +1,120 @@
# RAG Application Enhancements

## Overview
The application has been enhanced with collection management, LLM selection, and improved user experience.

## Key Enhancements

### 1. **Existing Collections Management** 🗂️
- **Auto-detection**: On startup, the system automatically detects all existing collections in ChromaDB
- **Load Existing Collection**: Users can choose from existing collections and load them directly without recreating them
- **Collection Selection**: A dropdown menu shows all available collections for quick access
- **Seamless Loading**: Click "📖 Load Existing Collection" to use a previously created collection

### 2. **Smart Collection Recreation** 🔄
- **Selective Deletion**: When creating a new collection, only that specific collection is deleted and recreated
- **Other Collections Preserved**: All other existing collections remain untouched
- **Conflict Resolution**: If a collection with the same name exists, it is deleted before the new one is created
- **User Feedback**: Clear warnings and progress messages when deleting and recreating collections

### 3. **LLM Selection Options** 🤖

#### Chat Interface
- **Dynamic LLM Selector**: Switch between LLM models while chatting
- **Real-time Switching**: Change the LLM without reloading the collection
- **Automatic Pipeline Update**: The RAG pipeline updates automatically when a new LLM is selected
- **Persistent Selection**: The selected LLM is remembered in the session state

#### Evaluation Interface
- **Evaluation-specific LLM**: Choose a different LLM for running TRACE evaluation
- **Independent Selection**: The evaluation LLM can differ from the chat LLM
- **Automatic Restoration**: After evaluation, the system restores the original LLM
- **Flexible Testing**: Test different LLM models on the same dataset and collection

### 4. **User Interface Improvements** 🎨
- **Two-step Process**:
  1. Load an existing collection, OR
  2. Create a new collection (with all configuration options)
- **Clear Sections**: Separate sidebar sections for existing vs. new collections
- **Visual Indicators**: Icons and colors distinguish different actions
- **Better Organization**: Configuration options are logically grouped and hierarchical

## Technical Implementation

### New Functions
- `get_available_collections()`: Fetches the list of collections from ChromaDB (see the sketch after this list)
- `load_existing_collection()`: Loads a pre-existing collection with LLM selection
- Updated `load_and_create_collection()`: Handles selective collection deletion
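
A minimal sketch of what `get_available_collections()` might look like, assuming the standard `chromadb` client API; the actual implementation in this commit may differ:

```python
# Hypothetical sketch of get_available_collections(); the real function in
# this commit may be implemented differently.
import chromadb

def get_available_collections(persist_directory: str = "./chroma_db") -> list:
    """Return the names of all collections stored in ChromaDB."""
    client = chromadb.PersistentClient(path=persist_directory)
    return [collection.name for collection in client.list_collections()]
```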

### Session State Variables
- `current_llm`: Tracks the currently selected LLM
- `selected_collection`: Tracks which collection is loaded
- `available_collections`: Stores the list of available collections (initialized as sketched below)
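
For illustration, initializing these keys in Streamlit could follow the standard session-state pattern (a sketch, not the exact code in `streamlit_app.py`):

```python
# Illustrative initialization of the session-state keys listed above;
# streamlit_app.py may organize this differently.
import streamlit as st

for key, default in {
    "current_llm": None,
    "selected_collection": None,
    "available_collections": [],
}.items():
    if key not in st.session_state:
        st.session_state[key] = default
```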

### Collection Naming Convention
Collections are named as: `{dataset}_{chunking_strategy}_{embedding_model_short_name}`

Example: `covidqa_dense_all_mpnet`
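
This matches the naming logic in `api.py`'s `/load-dataset` endpoint, which takes the last path segment of the embedding model ID and normalizes `-` and `.` to `_` (the helper name below is illustrative):

```python
# Collection-name construction mirroring api.py's /load-dataset endpoint;
# the wrapper function name is hypothetical.
def build_collection_name(dataset: str, chunking_strategy: str,
                          embedding_model: str) -> str:
    short_name = embedding_model.split("/")[-1]
    name = f"{dataset}_{chunking_strategy}_{short_name}"
    return name.replace("-", "_").replace(".", "_")

print(build_collection_name("covidqa", "dense",
                            "sentence-transformers/all-mpnet-base-v2"))
# -> covidqa_dense_all_mpnet_base_v2
```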

## User Workflow

### Scenario 1: Using an Existing Collection
1. The application starts and detects existing collections
2. The user selects a collection from the dropdown
3. The user clicks "📖 Load Existing Collection"
4. The user selects an LLM for chatting
5. The user can start chatting immediately

### Scenario 2: Creating a New Collection
1. Select a dataset from the sidebar
2. Click "🔍 Check Dataset Size" (optional)
3. Configure the chunking strategy and chunk parameters
4. Select an embedding model
5. Select an LLM
6. Click "🚀 Load Data & Create Collection"
7. The system deletes any existing collection with the same name
8. The system creates a new collection with fresh data

### Scenario 3: Switching LLMs During Chat
1. The chat interface shows the current collection and an LLM selector
2. The user selects a different LLM from "Select LLM for chat"
3. The RAG pipeline automatically updates with the new LLM
4. Chatting continues with the new LLM

### Scenario 4: Running Evaluation with a Different LLM
1. In the Evaluation tab, the user can select a different LLM
2. Click "🔬 Run Evaluation"
3. The system uses the selected LLM for evaluation
4. Results are displayed with metrics
5. The original chat LLM is restored after evaluation (see the sketch below)
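
The swap-and-restore behavior in step 5 follows a standard try/finally pattern; a minimal sketch (the function and attribute names here are illustrative, not the exact code in this commit):

```python
# Illustrative swap-and-restore pattern for the evaluation LLM; names are
# hypothetical and may not match the app code exactly.
def run_evaluation_with_llm(rag_pipeline, eval_llm, test_cases, evaluator):
    original_llm = rag_pipeline.llm      # remember the chat LLM
    rag_pipeline.llm = eval_llm          # temporarily swap in the eval LLM
    try:
        return evaluator.evaluate_batch(test_cases)
    finally:
        rag_pipeline.llm = original_llm  # always restore the original
```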
+
|
| 91 |
+
## Benefits
|
| 92 |
+
|
| 93 |
+
✅ **Efficiency**: No need to recreate collections when testing different configurations
|
| 94 |
+
✅ **Flexibility**: Easily compare different LLM models on the same data
|
| 95 |
+
✅ **Safety**: Other collections remain untouched when managing new ones
|
| 96 |
+
✅ **User Experience**: Clearer navigation and configuration options
|
| 97 |
+
✅ **Time Saving**: Reuse existing collections instead of recreating them
|
| 98 |
+
✅ **Testing**: Run evaluations with different LLMs for comprehensive analysis
|
| 99 |
+
|
| 100 |
+
## API Key Management
|
| 101 |
+
|
| 102 |
+
- API Key input is required at startup
|
| 103 |
+
- Store in sidebar for use across all operations
|
| 104 |
+
- Used for both chat and evaluation with different LLMs
|
| 105 |
+
|
| 106 |
+
## Error Handling
|
| 107 |
+
|
| 108 |
+
- Collection not found errors show helpful messages
|
| 109 |
+
- LLM loading failures fall back to default model
|
| 110 |
+
- Graceful error messages for all operations
|
| 111 |
+
- Automatic reconnection to ChromaDB if connection is lost
|
| 112 |
+
|
| 113 |
+
## Future Enhancement Ideas
|
| 114 |
+
|
| 115 |
+
- 💾 Save evaluation results with metadata
|
| 116 |
+
- 📊 Compare multiple LLM evaluation results
|
| 117 |
+
- 🔄 Batch collection operations (delete multiple)
|
| 118 |
+
- 📈 Analytics dashboard for collection usage
|
| 119 |
+
- 🏷️ Collection tagging/categorization system
|
| 120 |
+
- 💬 Multi-turn evaluation with conversation history
|
GIT_PUSH_GUIDE.md
ADDED
@@ -0,0 +1,156 @@
# Push Code to GitHub - Steps

## Option 1: Push to a GitHub Repository

If you already have a GitHub account, follow these steps:

### 1. On GitHub, create a new repository:
- Go to https://github.com/new
- Fill in the repository name: `RAG-Capstone-Project` (or your preferred name)
- Add description: "Retrieval-Augmented Generation (RAG) system with TRACE evaluation metrics"
- Choose: Public or Private
- **DO NOT** initialize with README, .gitignore, or license (we already have these)
- Click "Create repository"

### 2. In PowerShell, add the remote and push:

```powershell
cd "D:\CapStoneProject\RAG Capstone Project"

# Add remote (replace YOUR_USERNAME and REPO_NAME)
git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git

# Rename branch to main (optional but recommended)
git branch -M main

# Push to GitHub
git push -u origin main
```

### 3. When prompted, enter your GitHub credentials:
- Username: Your GitHub username
- Password: Your GitHub personal access token (not your account password)
- Generate a token at: https://github.com/settings/tokens

---

## Option 2: If You Don't Have a GitHub Account Yet

1. Go to https://github.com/join
2. Create a free account
3. Follow the Option 1 steps above

---

## Troubleshooting

### If you get "fatal: remote origin already exists":
```powershell
git remote remove origin
git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
git push -u origin main
```

### If you get an authentication error:
1. Generate a Personal Access Token:
   - Go to https://github.com/settings/tokens
   - Click "Generate new token (classic)"
   - Select: repo, write:packages
   - Copy the token
2. Use the token instead of your password when prompted

### If you have an SSH key set up:
```powershell
git remote add origin git@github.com:YOUR_USERNAME/RAG-Capstone-Project.git
git push -u origin main
```

---

## What Will Be Pushed

✅ All Python source files:
- `streamlit_app.py` - Main Streamlit UI
- `llm_client.py` - Groq LLM integration
- `vector_store.py` - ChromaDB management
- `embedding_models.py` - 8 embedding models
- `trace_evaluator.py` - TRACE metrics
- `dataset_loader.py` - RAGBench dataset loading
- `chunking_strategies.py` - 4 chunking strategies
- And more...

✅ Documentation:
- `README.md` - Project overview
- `SETUP.md` - Installation guide
- `ENHANCEMENTS.md` - Recent enhancements
- `.env.example` - Environment template

✅ Configuration:
- `requirements.txt` - All dependencies
- `Dockerfile` - Docker containerization
- `docker-compose.yml` - Multi-container setup
- `Procfile` - Heroku deployment

❌ Excluded (in .gitignore):
- `venv/` - Virtual environment
- `chroma_db/` - ChromaDB data
- `.env` - API keys (keep local!)
- `__pycache__/` - Python cache
- `.streamlit/` - Streamlit config

---

## Next Steps After Pushing

1. **Add a GitHub Actions workflow** (optional):
   - Automated testing
   - Code quality checks
   - Deployment automation

2. **Set up branch protection**:
   - Require pull request reviews
   - Enforce status checks

3. **Add GitHub Pages documentation**:
   - Host project documentation
   - API documentation
   - Evaluation results

4. **Set up CI/CD**:
   - Test on every push
   - Deploy to Heroku/Cloud Run

---

## Commands Summary

```powershell
# Navigate to project
cd "D:\CapStoneProject\RAG Capstone Project"

# Configure git
git config user.email "your_email@example.com"
git config user.name "Your Name"

# Add remote (replace placeholders)
git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git

# Rename branch to main
git branch -M main

# Push to GitHub
git push -u origin main

# Verify
git remote -v
git log --oneline
```

---

**Share the GitHub URL with your team:**
```
https://github.com/YOUR_USERNAME/RAG-Capstone-Project
```

Let me know when you have your GitHub username ready, and I can help you complete the push!
Procfile
ADDED
@@ -0,0 +1 @@
web: streamlit run streamlit_app.py --server.port=$PORT --server.address=0.0.0.0
README.md
ADDED
@@ -0,0 +1,294 @@
# RAG Capstone Project

A comprehensive Retrieval-Augmented Generation (RAG) system with TRACE evaluation metrics for medical/clinical domains.

## Features

- 🔍 **Multiple RAG Bench Datasets**: HotpotQA, 2WikiMultihopQA, MuSiQue, Natural Questions, TriviaQA
- 🧩 **Chunking Strategies**: Dense, Sparse, Hybrid, Re-ranking
- 🤖 **Medical Embedding Models**:
  - sentence-transformers/embeddinggemma-300m-medical
  - emilyalsentzer/Bio_ClinicalBERT
  - Simonlee711/Clinical_ModernBERT
- 💾 **ChromaDB Vector Storage**: Persistent vector storage with efficient retrieval
- 🦙 **Groq LLM Integration**: With rate limiting (30 RPM)
  - meta-llama/llama-4-maverick-17b-128e-instruct
  - llama-3.1-8b-instant
  - openai/gpt-oss-120b
- 📊 **TRACE Evaluation Metrics**:
  - **U**tilization: How well the system uses retrieved documents
  - **R**elevance: Relevance of retrieved documents to the query
  - **A**dherence: How well the response adheres to the retrieved context
  - **C**ompleteness: How complete the response is
- 💬 **Chat Interface**: Streamlit-based interactive chat with history
- 🔌 **REST API**: FastAPI backend for integration

## Installation

### Prerequisites

- Python 3.8+
- pip
- Groq API key

### Setup

1. Clone the repository:
```bash
git clone <repository-url>
cd "RAG Capstone Project"
```

2. Create a virtual environment:
```bash
python -m venv venv
```

3. Activate the virtual environment:

**Windows:**
```bash
.\venv\Scripts\activate
```

**Linux/Mac:**
```bash
source venv/bin/activate
```

4. Install dependencies:
```bash
pip install -r requirements.txt
```

5. Create a `.env` file from the example (use `cp .env.example .env` on Linux/Mac):
```bash
copy .env.example .env
```

6. Edit `.env` and add your Groq API key:
```
GROQ_API_KEY=your_groq_api_key_here
```

## Usage

### Streamlit Application

Run the interactive Streamlit interface:

```bash
streamlit run streamlit_app.py
```

Then open your browser to `http://localhost:8501`

**Workflow:**
1. Enter your Groq API key in the sidebar
2. Select a dataset from RAG Bench
3. Choose a chunking strategy
4. Select an embedding model
5. Choose an LLM model
6. Click "Load Data & Create Collection"
7. Start chatting!
8. View retrieved documents
9. Run TRACE evaluation
10. Export chat history

### FastAPI Backend

Run the REST API server:

```bash
python api.py
```

Or with uvicorn:
```bash
uvicorn api:app --reload --host 0.0.0.0 --port 8000
```

API documentation is available at: `http://localhost:8000/docs`

#### API Endpoints

- `GET /` - Root endpoint
- `GET /health` - Health check
- `GET /datasets` - List available datasets
- `GET /models/embedding` - List embedding models
- `GET /models/llm` - List LLM models
- `GET /chunking-strategies` - List chunking strategies
- `GET /collections` - List all collections
- `GET /collections/{name}` - Get collection info
- `POST /load-dataset` - Load dataset and create collection
- `POST /query` - Query the RAG system
- `GET /chat-history` - Get chat history
- `DELETE /chat-history` - Clear chat history
- `POST /evaluate` - Run TRACE evaluation
- `DELETE /collections/{name}` - Delete collection

### Python API

Use the components programmatically:

```python
from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator

# Load dataset
loader = RAGBenchLoader()
dataset = loader.load_dataset("hotpotqa", max_samples=100)

# Create vector store
vector_store = ChromaDBManager()
vector_store.load_dataset_into_collection(
    collection_name="my_collection",
    embedding_model_name="emilyalsentzer/Bio_ClinicalBERT",
    chunking_strategy="hybrid",
    dataset_data=dataset
)

# Initialize LLM
llm = GroqLLMClient(
    api_key="your_api_key",
    model_name="llama-3.1-8b-instant"
)

# Create RAG pipeline
rag = RAGPipeline(llm, vector_store)

# Query
result = rag.query("What is the capital of France?")
print(result["response"])

# Evaluate
evaluator = TRACEEvaluator()
test_cases = [...]  # Your test cases
results = evaluator.evaluate_batch(test_cases)
print(results)
```
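
For reference, the `/evaluate` endpoint in `api.py` builds each test case as a dict with `query`, `response`, `retrieved_documents`, and `ground_truth` keys, so a minimal hand-built batch could look like this (the string values are illustrative placeholders):

```python
# Minimal test-case batch mirroring the shape built in api.py's /evaluate
# endpoint; the contents here are illustrative.
test_cases = [
    {
        "query": "What is the capital of France?",
        "response": "The capital of France is Paris.",
        "retrieved_documents": [
            "Paris is the capital and largest city of France.",
        ],
        "ground_truth": "Paris",
    },
]
results = evaluator.evaluate_batch(test_cases)
```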

## Project Structure

```
RAG Capstone Project/
├── __init__.py              # Package initialization
├── config.py                # Configuration management
├── dataset_loader.py        # RAG Bench dataset loader
├── chunking_strategies.py   # Document chunking strategies
├── embedding_models.py      # Embedding model implementations
├── vector_store.py          # ChromaDB integration
├── llm_client.py            # Groq LLM client with rate limiting
├── trace_evaluator.py       # TRACE evaluation metrics
├── streamlit_app.py         # Streamlit chat interface
├── api.py                   # FastAPI REST API
├── requirements.txt         # Python dependencies
├── .env.example             # Environment variables template
├── .gitignore               # Git ignore file
└── README.md                # This file
```

## TRACE Metrics Explained

### Utilization (U)
Measures how well the system uses the retrieved documents in generating the response. Higher scores indicate that the system effectively incorporates information from multiple retrieved documents.

### Relevance (R)
Evaluates the relevance of retrieved documents to the user's query. Uses lexical overlap and keyword matching to determine if the right documents were retrieved.

### Adherence (A)
Assesses how well the generated response adheres to the retrieved context. Ensures the response is grounded in the provided documents rather than hallucinated.

### Completeness (C)
Evaluates how complete the response is in answering the query. Considers response length, question type, and comparison with ground truth if available.
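
As a rough illustration of the lexical-overlap idea behind the Relevance score, consider the following simplified sketch; the actual logic lives in `trace_evaluator.py` and may differ in detail:

```python
# Simplified sketch of a lexical-overlap relevance score; the real
# implementation is in trace_evaluator.py and may differ.
import re
from typing import List

def relevance_score(query: str, documents: List[str]) -> float:
    """Average fraction of query keywords that appear in each document."""
    def tokenize(text: str) -> set:
        return set(re.findall(r"[a-z0-9]+", text.lower()))

    query_terms = tokenize(query)
    if not query_terms or not documents:
        return 0.0
    overlaps = [
        len(query_terms & tokenize(doc)) / len(query_terms)
        for doc in documents
    ]
    return sum(overlaps) / len(overlaps)

print(relevance_score("capital of France",
                      ["Paris is the capital of France."]))  # -> 1.0
```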

## Deployment Options

### Heroku

1. Create `Procfile`:
```
web: streamlit run streamlit_app.py --server.port=$PORT --server.address=0.0.0.0
api: uvicorn api:app --host=0.0.0.0 --port=$PORT
```

2. Deploy:
```bash
heroku create your-app-name
git push heroku main
```

### Docker

Create `Dockerfile`:
```dockerfile
FROM python:3.9-slim

WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt

COPY . .

EXPOSE 8501 8000

CMD ["streamlit", "run", "streamlit_app.py"]
```

Build and run:
```bash
docker build -t rag-capstone .
docker run -p 8501:8501 -p 8000:8000 rag-capstone
```

### Cloud Run / AWS / Azure

The application can be deployed to any cloud platform that supports Python applications. See the respective platform documentation for deployment instructions.

## Configuration

Edit `config.py` or set environment variables in `.env`:

```env
GROQ_API_KEY=your_api_key
CHROMA_PERSIST_DIRECTORY=./chroma_db
GROQ_RPM_LIMIT=30
RATE_LIMIT_DELAY=2.0
LOG_LEVEL=INFO
```
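
`config.py` itself is not shown in this diff excerpt; based on the attributes `api.py` reads from it (`settings.chroma_persist_directory`, `settings.groq_rpm_limit`, `settings.rate_limit_delay`), a settings object consistent with that usage might be sketched as follows (purely an assumption about its shape, not the actual file):

```python
# Hypothetical sketch of a settings object consistent with how api.py uses
# `settings`; the actual config.py in this commit may differ.
import os
from dataclasses import dataclass

from dotenv import load_dotenv  # python-dotenv, assuming it is a dependency

load_dotenv()  # read values from .env if present

@dataclass
class Settings:
    groq_api_key: str = os.getenv("GROQ_API_KEY", "")
    chroma_persist_directory: str = os.getenv("CHROMA_PERSIST_DIRECTORY", "./chroma_db")
    groq_rpm_limit: int = int(os.getenv("GROQ_RPM_LIMIT", "30"))
    rate_limit_delay: float = float(os.getenv("RATE_LIMIT_DELAY", "2.0"))
    log_level: str = os.getenv("LOG_LEVEL", "INFO")

settings = Settings()
```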

## Rate Limiting

The application implements rate limiting for Groq API calls:
- Maximum 30 requests per minute (configurable)
- Automatic delay of 2 seconds between requests
- Smart waiting when the rate limit is reached (see the sketch below)
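
A minimal sketch of this sliding-window pattern, for illustration only; the project's actual implementation is in `llm_client.py` and may differ:

```python
# Illustrative sliding-window rate limiter; llm_client.py implements the
# project's real version.
import time
from collections import deque

class RateLimiter:
    def __init__(self, max_rpm: int = 30, delay: float = 2.0):
        self.max_rpm = max_rpm
        self.delay = delay
        self.timestamps = deque()

    def wait(self) -> None:
        """Block until another request is allowed."""
        now = time.monotonic()
        # Drop timestamps older than the 60-second window
        while self.timestamps and now - self.timestamps[0] > 60:
            self.timestamps.popleft()
        if len(self.timestamps) >= self.max_rpm:
            # Sleep until the oldest request leaves the window
            time.sleep(60 - (now - self.timestamps[0]))
        else:
            time.sleep(self.delay)  # fixed delay between requests
        self.timestamps.append(time.monotonic())
```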

## Troubleshooting

### ChromaDB Issues
If you encounter ChromaDB errors, try deleting the `chroma_db` directory and recreating collections.

### Embedding Model Loading
Medical embedding models may require significant memory. If you encounter out-of-memory errors, try:
- Using a smaller model
- Reducing the batch size
- Using CPU instead of GPU

### API Key Errors
Ensure your Groq API key is correctly set in the `.env` file or passed to the application.

## License

MIT License

## Contributors

RAG Capstone Team

## Support

For issues and questions, please open an issue on the GitHub repository.
SETUP.md
ADDED
@@ -0,0 +1,69 @@
# Quick Setup Guide (Windows)

## Requirements
- Python 3.10+
- Groq API Key

## Installation Steps

### 1. Create Virtual Environment
```powershell
python -m venv venv
```

### 2. Activate Virtual Environment
```powershell
.\venv\Scripts\Activate.ps1
```

**If you get an execution policy error:**
```powershell
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
```

### 3. Upgrade pip
```powershell
python -m pip install --upgrade pip
```

### 4. Install Dependencies
```powershell
pip install -r requirements.txt
```

### 5. Configure API Key
```powershell
copy .env.example .env
notepad .env
```

Add your Groq API key:
```
GROQ_API_KEY=your_groq_api_key_here
```

### 6. Run Application
```powershell
streamlit run streamlit_app.py
```

Open your browser to: **http://localhost:8501**

---

## Common Issues

**Execution Policy Error:**
```powershell
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
```

**Reset ChromaDB:**
```powershell
Remove-Item -Recurse -Force .\chroma_db
```

**Deactivate venv:**
```powershell
deactivate
```
UPDATE_REMOTE.md
ADDED
@@ -0,0 +1,178 @@
# How to Update Remote Origin

## Method 1: Change Existing Remote URL (Recommended)

If you already have a remote set and want to change it:

```powershell
# View current remote
git remote -v

# Option A: Update with HTTPS URL
git remote set-url origin https://github.com/YOUR_USERNAME/YOUR_REPO_NAME.git

# Option B: Update with SSH URL
git remote set-url origin git@github.com:YOUR_USERNAME/YOUR_REPO_NAME.git

# Verify the change
git remote -v
```

---

## Method 2: Remove and Re-add Remote

If you want to completely remove and re-add the remote:

```powershell
# Remove existing remote
git remote remove origin

# Add new remote
git remote add origin https://github.com/YOUR_USERNAME/YOUR_REPO_NAME.git

# Verify
git remote -v
```

---

## Method 3: Using Your GitHub Repository

### Step 1: Create Repository on GitHub
1. Go to https://github.com/new
2. Repository name: `RAG-Capstone-Project`
3. Description: `Retrieval-Augmented Generation system with TRACE evaluation`
4. Select Public or Private
5. **IMPORTANT**: Don't initialize with README, .gitignore, or license
6. Click "Create repository"

### Step 2: Copy Your Repository URL
After creating it, GitHub will show you the URL. Copy it (it should look like):
```
https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
```

### Step 3: Update Remote in Your Local Repository
```powershell
cd "D:\CapStoneProject\RAG Capstone Project"

git remote set-url origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git

# Verify
git remote -v
```

---

## Complete Step-by-Step Example

Let's say your GitHub username is `john-doe`:

```powershell
# 1. Navigate to project
cd "D:\CapStoneProject\RAG Capstone Project"

# 2. Update the remote URL
git remote set-url origin https://github.com/john-doe/RAG-Capstone-Project.git

# 3. Verify it was updated
git remote -v

# 4. Push to GitHub (first time)
git push -u origin main

# 5. For future pushes, just use
git push
```

---

## Using SSH Instead of HTTPS

### Step 1: Generate SSH Key (if you don't have one)
```powershell
ssh-keygen -t ed25519 -C "your_email@example.com"
```

### Step 2: Add SSH Key to GitHub
1. Copy the public key: `cat ~/.ssh/id_ed25519.pub`
2. Go to https://github.com/settings/keys
3. Click "New SSH key"
4. Paste your public key
5. Save

### Step 3: Update Remote to SSH
```powershell
git remote set-url origin git@github.com:YOUR_USERNAME/RAG-Capstone-Project.git

# Verify
git remote -v

# Test connection
ssh -T git@github.com
```

---

## Troubleshooting

### Problem: "fatal: remote origin already exists"
```powershell
# Remove the old remote first
git remote remove origin

# Then add the new one
git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
```

### Problem: "Permission denied (publickey)"
This means SSH authentication failed. Use HTTPS instead:
```powershell
git remote set-url origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
```

### Problem: "fatal: Authentication failed"
This means your GitHub credentials are incorrect. Use a Personal Access Token:
1. Generate a token: https://github.com/settings/tokens
2. When pushing, use the token as the password

### Check Current Configuration
```powershell
# View remote URLs
git remote -v

# View detailed remote info
git remote show origin

# View git config
git config --local -l
```

---

## Quick Reference

| Task | Command |
|------|---------|
| View remotes | `git remote -v` |
| Update URL (HTTPS) | `git remote set-url origin https://github.com/USER/REPO.git` |
| Update URL (SSH) | `git remote set-url origin git@github.com:USER/REPO.git` |
| Remove remote | `git remote remove origin` |
| Add remote | `git remote add origin <URL>` |
| Push to remote | `git push -u origin main` |
| Check remote details | `git remote show origin` |

---

## What You Need to Push

Before pushing, make sure you have:
- ✅ GitHub account created
- ✅ Repository created on GitHub
- ✅ Remote URL updated correctly
- ✅ Local commits ready (already done ✓)

---

**What's your GitHub username?** I can help you with the exact commands once you provide it!
__init__.py
ADDED
@@ -0,0 +1,15 @@
"""
RAG Capstone Project - Retrieval-Augmented Generation with TRACE Evaluation

This application provides a complete RAG system with:
- Multiple embedding models (Bio-medical BERT models)
- Various chunking strategies (dense, sparse, hybrid, re-ranking)
- ChromaDB vector storage
- Groq LLM integration with rate limiting
- TRACE evaluation metrics
- Streamlit chat interface
- FastAPI REST API
"""

__version__ = "1.0.0"
__author__ = "RAG Capstone Team"
api.py
ADDED
@@ -0,0 +1,374 @@
"""FastAPI backend service for RAG application."""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
import uvicorn
from datetime import datetime
import os

from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator

# Initialize FastAPI app
app = FastAPI(
    title="RAG Capstone API",
    description="API for RAG system with TRACE evaluation",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state
rag_pipeline: Optional[RAGPipeline] = None
vector_store: Optional[ChromaDBManager] = None
current_collection: Optional[str] = None


# Request/Response models
class DatasetLoadRequest(BaseModel):
    """Request model for loading a dataset."""
    dataset_name: str = Field(..., description="Name of the dataset")
    num_samples: int = Field(50, description="Number of samples to load")
    chunking_strategy: str = Field("hybrid", description="Chunking strategy")
    chunk_size: int = Field(512, description="Size of chunks")
    overlap: int = Field(50, description="Overlap between chunks")
    embedding_model: str = Field(..., description="Embedding model name")
    llm_model: str = Field("llama-3.1-8b-instant", description="LLM model name")
    groq_api_key: str = Field(..., description="Groq API key")


class QueryRequest(BaseModel):
    """Request model for querying."""
    query: str = Field(..., description="User query")
    n_results: int = Field(5, description="Number of documents to retrieve")
    max_tokens: int = Field(1024, description="Maximum tokens to generate")
    temperature: float = Field(0.7, description="Sampling temperature")


class QueryResponse(BaseModel):
    """Response model for a query."""
    query: str
    response: str
    retrieved_documents: List[Dict]
    timestamp: str


class EvaluationRequest(BaseModel):
    """Request model for evaluation."""
    num_samples: int = Field(10, description="Number of test samples")


class CollectionInfo(BaseModel):
    """Collection information model."""
    name: str
    count: int
    metadata: Dict


# API endpoints
@app.get("/")
async def root():
    """Root endpoint."""
    return {
        "message": "RAG Capstone API",
        "version": "1.0.0",
        "docs": "/docs"
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat()
    }


@app.get("/datasets")
async def list_datasets():
    """List available datasets."""
    return {
        "datasets": settings.ragbench_datasets
    }


@app.get("/models/embedding")
async def list_embedding_models():
    """List available embedding models."""
    return {
        "embedding_models": settings.embedding_models
    }


@app.get("/models/llm")
async def list_llm_models():
    """List available LLM models."""
    return {
        "llm_models": settings.llm_models
    }


@app.get("/chunking-strategies")
async def list_chunking_strategies():
    """List available chunking strategies."""
    return {
        "chunking_strategies": settings.chunking_strategies
    }


@app.get("/collections")
async def list_collections():
    """List all vector store collections."""
    global vector_store

    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    collections = vector_store.list_collections()

    return {
        "collections": collections,
        "count": len(collections)
    }


@app.get("/collections/{collection_name}")
async def get_collection_info(collection_name: str):
    """Get information about a specific collection."""
    global vector_store

    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    try:
        stats = vector_store.get_collection_stats(collection_name)
        return stats
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"Collection not found: {str(e)}")


@app.post("/load-dataset")
async def load_dataset(request: DatasetLoadRequest, background_tasks: BackgroundTasks):
    """Load dataset and create vector collection."""
    global rag_pipeline, vector_store, current_collection

    try:
        # Initialize dataset loader
        loader = RAGBenchLoader()

        # Load dataset
        dataset = loader.load_dataset(
            request.dataset_name,
            split="train",
            max_samples=request.num_samples
        )

        if not dataset:
            raise HTTPException(status_code=400, detail="Failed to load dataset")

        # Initialize vector store
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

        # Create collection name
        collection_name = f"{request.dataset_name}_{request.chunking_strategy}_{request.embedding_model.split('/')[-1]}"
        collection_name = collection_name.replace("-", "_").replace(".", "_")

        # Load data into collection
        vector_store.load_dataset_into_collection(
            collection_name=collection_name,
            embedding_model_name=request.embedding_model,
            chunking_strategy=request.chunking_strategy,
            dataset_data=dataset,
            chunk_size=request.chunk_size,
            overlap=request.overlap
        )

        # Initialize LLM client
        llm_client = GroqLLMClient(
            api_key=request.groq_api_key,
            model_name=request.llm_model,
            max_rpm=settings.groq_rpm_limit,
            rate_limit_delay=settings.rate_limit_delay
        )

        # Create RAG pipeline
        rag_pipeline = RAGPipeline(llm_client, vector_store)
        current_collection = collection_name

        return {
            "status": "success",
            "collection_name": collection_name,
            "num_documents": len(dataset),
            "message": f"Collection '{collection_name}' created successfully"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading dataset: {str(e)}")


@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Query the RAG system."""
    global rag_pipeline

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    try:
        result = rag_pipeline.query(
            query=request.query,
            n_results=request.n_results,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )

        result["timestamp"] = datetime.now().isoformat()

        return result

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")


@app.get("/chat-history")
async def get_chat_history():
    """Get chat history."""
    global rag_pipeline

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    return {
        "history": rag_pipeline.get_chat_history()
    }


@app.delete("/chat-history")
async def clear_chat_history():
    """Clear chat history."""
    global rag_pipeline

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    rag_pipeline.clear_history()

    return {
        "status": "success",
        "message": "Chat history cleared"
    }


@app.post("/evaluate")
async def run_evaluation(request: EvaluationRequest):
    """Run TRACE evaluation."""
    global rag_pipeline, current_collection

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    try:
        # Derive the dataset name from the collection name
        dataset_name = current_collection.split("_")[0] if current_collection else "hotpotqa"

        # Get test data
        loader = RAGBenchLoader()
        test_data = loader.get_test_data(dataset_name, request.num_samples)

        # Prepare test cases
        test_cases = []

        for sample in test_data:
            result = rag_pipeline.query(sample["question"], n_results=5)

            test_cases.append({
                "query": sample["question"],
                "response": result["response"],
                "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                "ground_truth": sample.get("answer", "")
            })

        # Run evaluation
        evaluator = TRACEEvaluator()
        results = evaluator.evaluate_batch(test_cases)

        return {
            "status": "success",
            "results": results
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during evaluation: {str(e)}")


@app.delete("/collections/{collection_name}")
async def delete_collection(collection_name: str):
    """Delete a collection."""
    global vector_store

    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    try:
        vector_store.delete_collection(collection_name)
        return {
            "status": "success",
            "message": f"Collection '{collection_name}' deleted"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")


@app.get("/current-collection")
async def get_current_collection():
    """Get current collection information."""
    global current_collection, vector_store

    if not current_collection:
        return {
            "collection": None,
            "message": "No collection loaded"
        }

    try:
        stats = vector_store.get_collection_stats(current_collection)
        return {
            "collection": current_collection,
            "stats": stats
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        reload=True
    )
chunking_strategies.py
ADDED
|
@@ -0,0 +1,207 @@
"""Chunking strategies for document processing."""
from typing import List, Dict, Tuple
from abc import ABC, abstractmethod
import re
from rank_bm25 import BM25Okapi
import numpy as np


class ChunkingStrategy(ABC):
    """Abstract base class for chunking strategies."""

    @abstractmethod
    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Chunk text into smaller pieces.

        Args:
            text: Input text to chunk
            chunk_size: Maximum size of each chunk
            overlap: Number of characters to overlap between chunks

        Returns:
            List of text chunks
        """
        pass


class DenseChunking(ChunkingStrategy):
    """Dense chunking strategy - fixed-size chunks with overlap."""

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Create dense chunks with fixed size and overlap."""
        if not text:
            return []

        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = start + chunk_size
            chunk = text[start:end]

            # Try to break at sentence boundary
            if end < text_length:
                last_period = chunk.rfind('.')
                last_newline = chunk.rfind('\n')
                break_point = max(last_period, last_newline)

                if break_point > chunk_size * 0.5:  # At least 50% of chunk size
                    chunk = chunk[:break_point + 1]
                    end = start + break_point + 1

            chunks.append(chunk.strip())
            start = end - overlap

            if start >= text_length:
                break

        return [c for c in chunks if c]  # Remove empty chunks


class SparseChunking(ChunkingStrategy):
    """Sparse chunking strategy - semantic-based chunks (paragraphs/sections)."""

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Create sparse chunks based on semantic boundaries."""
        if not text:
            return []

        # Split by double newlines (paragraphs)
        paragraphs = re.split(r'\n\s*\n', text)
        chunks = []
        current_chunk = ""

        for para in paragraphs:
            para = para.strip()
            if not para:
                continue

            # If adding this paragraph exceeds chunk_size, save current chunk
            if len(current_chunk) + len(para) > chunk_size and current_chunk:
                chunks.append(current_chunk.strip())
                # Start new chunk with overlap
                if overlap > 0:
                    words = current_chunk.split()
                    overlap_words = words[-min(overlap // 5, len(words)):]
                    current_chunk = " ".join(overlap_words) + " " + para
                else:
                    current_chunk = para
            else:
                current_chunk += ("\n\n" if current_chunk else "") + para

        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks


class HybridChunking(ChunkingStrategy):
    """Hybrid chunking strategy - combines dense and sparse approaches."""

    def __init__(self):
        self.dense_chunker = DenseChunking()
        self.sparse_chunker = SparseChunking()

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Create hybrid chunks combining both strategies."""
        if not text:
            return []

        # First apply sparse chunking to get semantic boundaries
        sparse_chunks = self.sparse_chunker.chunk_text(text, chunk_size * 2, 0)

        # Then apply dense chunking to each sparse chunk
        all_chunks = []
        for sparse_chunk in sparse_chunks:
            if len(sparse_chunk) > chunk_size:
                dense_chunks = self.dense_chunker.chunk_text(
                    sparse_chunk, chunk_size, overlap
                )
                all_chunks.extend(dense_chunks)
            else:
                all_chunks.append(sparse_chunk)

        return all_chunks


class ReRankingChunking(ChunkingStrategy):
    """Re-ranking chunking strategy - creates chunks and provides relevance scoring."""

    def __init__(self):
        self.base_chunker = HybridChunking()
        self.bm25 = None
        self.chunks = []

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Create chunks suitable for re-ranking."""
        self.chunks = self.base_chunker.chunk_text(text, chunk_size, overlap)

        # Initialize BM25 for re-ranking capability
        tokenized_chunks = [chunk.lower().split() for chunk in self.chunks]
        self.bm25 = BM25Okapi(tokenized_chunks)

        return self.chunks

    def rerank_chunks(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Re-rank chunks based on query relevance.

        Args:
            query: Query string
            top_k: Number of top chunks to return

        Returns:
            List of (chunk, score) tuples
        """
        if not self.bm25 or not self.chunks:
            return []

        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)

        # Get top-k chunks with scores
        top_indices = np.argsort(scores)[::-1][:top_k]
        ranked_chunks = [
            (self.chunks[idx], float(scores[idx]))
            for idx in top_indices
        ]

        return ranked_chunks


class ChunkingFactory:
    """Factory for creating chunking strategy instances."""

    STRATEGIES = {
        "dense": DenseChunking,
        "sparse": SparseChunking,
        "hybrid": HybridChunking,
        "re-ranking": ReRankingChunking
    }

    @classmethod
    def create_chunker(cls, strategy: str) -> ChunkingStrategy:
        """Create a chunking strategy instance.

        Args:
            strategy: Name of the chunking strategy

        Returns:
            ChunkingStrategy instance
        """
        if strategy not in cls.STRATEGIES:
            raise ValueError(f"Unknown chunking strategy: {strategy}. "
                             f"Available: {list(cls.STRATEGIES.keys())}")

        return cls.STRATEGIES[strategy]()

    @classmethod
    def get_available_strategies(cls) -> List[str]:
        """Get list of available chunking strategies."""
        return list(cls.STRATEGIES.keys())
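A minimal usage sketch for the factory and the BM25 re-ranker above; the sample text and query are invented for illustration.

from chunking_strategies import ChunkingFactory

text = ("Transformers process tokens in parallel.\n\n"
        "BM25 is a lexical ranking function used in retrieval.\n\n"
        "Dense retrieval relies on learned embeddings.")

# Any of: "dense", "sparse", "hybrid", "re-ranking"
chunker = ChunkingFactory.create_chunker("re-ranking")
chunks = chunker.chunk_text(text, chunk_size=128, overlap=20)

# Only ReRankingChunking exposes rerank_chunks(); it lowercases and
# whitespace-tokenizes, so keep the query in plain lowercase words
for chunk, score in chunker.rerank_chunks("bm25 lexical ranking", top_k=2):
    print(f"{score:.2f}  {chunk[:60]}")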
cleanup_chroma.py
ADDED
|
@@ -0,0 +1,93 @@
#!/usr/bin/env python3
"""Script to clean up ChromaDB collections and cache."""

import shutil
import os
from pathlib import Path

def cleanup_chroma_db():
    """Clean up ChromaDB collections and cache."""

    print("=" * 60)
    print("ChromaDB Cleanup Utility")
    print("=" * 60)

    # First, forcefully delete the chroma_db directory
    chroma_path = Path("./chroma_db")
    if chroma_path.exists():
        print(f"\n🗑️  Removing chroma_db directory: {chroma_path}")
        try:
            shutil.rmtree(chroma_path)
            print(f"✅ Deleted directory: {chroma_path}")
        except Exception as e:
            print(f"❌ Error deleting directory: {e}")
    else:
        print(f"\n✅ chroma_db directory not found: {chroma_path}")

    # Also check for ChromaDB in .chroma directory (alternative location)
    chroma_alt_path = Path("./.chroma")
    if chroma_alt_path.exists():
        print(f"\n🗑️  Removing .chroma directory: {chroma_alt_path}")
        try:
            shutil.rmtree(chroma_alt_path)
            print(f"✅ Deleted directory: {chroma_alt_path}")
        except Exception as e:
            print(f"❌ Error deleting directory: {e}")

    # Clear HuggingFace dataset cache (optional)
    response = input("\n🗑️  Clear HuggingFace dataset cache? (y/n): ").lower()
    if response == 'y':
        cache_path = Path.home() / ".cache" / "huggingface" / "datasets"
        if cache_path.exists():
            print(f"🗑️  Removing HF cache: {cache_path}")
            try:
                shutil.rmtree(cache_path)
                print(f"✅ Deleted HF cache: {cache_path}")
            except Exception as e:
                print(f"❌ Error deleting HF cache: {e}")
        else:
            print("ℹ️  HuggingFace cache not found")

    # Clear ChromaDB chroma cache directory
    response = input("\n🗑️  Clear ChromaDB chroma cache? (y/n): ").lower()
    if response == 'y':
        chroma_cache = Path.home() / ".chroma"
        if chroma_cache.exists():
            print(f"🗑️  Removing ChromaDB cache: {chroma_cache}")
            try:
                shutil.rmtree(chroma_cache)
                print(f"✅ Deleted ChromaDB cache: {chroma_cache}")
            except Exception as e:
                print(f"❌ Error deleting ChromaDB cache: {e}")

    # Try to use ChromaDBManager if possible
    print("\n📋 Attempting to connect to ChromaDB...")
    try:
        from vector_store import ChromaDBManager

        manager = ChromaDBManager(persist_directory="./chroma_db")

        # List existing collections
        collections = manager.list_collections()
        print(f"📊 Found {len(collections)} collection(s):")
        for col in collections:
            print(f"  - {col}")

        # Clear all collections
        if collections:
            print("\n🗑️  Clearing all collections...")
            deleted = manager.clear_all_collections()
            print(f"✅ Deleted {deleted} collection(s)")
        else:
            print("\n✅ No collections to delete")
    except Exception as e:
        print(f"⚠️  Could not connect to ChromaDB via manager: {e}")
        print("ℹ️  This is OK - the directory has been deleted already.")

    print("\n" + "=" * 60)
    print("✅ Cleanup completed! You can now start fresh.")
    print("=" * 60)


if __name__ == "__main__":
    cleanup_chroma_db()
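The script is interactive (it prompts before clearing caches); for non-interactive use, for example in CI, the forced-removal step can be replicated directly. A minimal sketch assuming the default paths used above:

import shutil
from pathlib import Path

# Non-interactive equivalent of the directory-removal steps above
for path in (Path("./chroma_db"), Path("./.chroma")):
    if path.exists():
        shutil.rmtree(path)
        print(f"Deleted {path}")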
config.py
ADDED
|
@@ -0,0 +1,64 @@
"""Configuration management for RAG Application."""
from pydantic_settings import BaseSettings
from typing import Optional
import os


class Settings(BaseSettings):
    """Application settings."""

    # API Keys
    groq_api_key: str = ""

    # ChromaDB
    chroma_persist_directory: str = "./chroma_db"

    # Rate Limiting
    groq_rpm_limit: int = 30
    rate_limit_delay: float = 2.0

    # Embedding Models
    embedding_models: list = [
        "sentence-transformers/all-mpnet-base-v2",  # Stable, high quality
        "emilyalsentzer/Bio_ClinicalBERT",  # Clinical domain
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",  # Medical domain
        "sentence-transformers/all-MiniLM-L6-v2",  # Fast, lightweight
        "sentence-transformers/multilingual-MiniLM-L12-v2",  # Multilingual
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",  # Paraphrase
        "allenai/specter",  # Academic papers
        "gemini-embedding-001"  # Gemini API
    ]

    # LLM Models
    llm_models: list = [
        "meta-llama/llama-4-maverick-17b-128e-instruct",
        "llama-3.1-8b-instant",
        "openai/gpt-oss-120b"
    ]

    # Chunking Strategies
    chunking_strategies: list = ["dense", "sparse", "hybrid", "re-ranking"]

    # RAG Bench Datasets (from rungalileo/ragbench)
    ragbench_datasets: list = [
        "covidqa",
        "cuad",
        "delucionqa",
        "emanual",
        "expertqa",
        "finqa",
        "hagrid",
        "hotpotqa",
        "msmarco",
        "pubmedqa",
        "tatqa",
        "techqa"
    ]

    class Config:
        env_file = ".env"
        case_sensitive = False
        extra = "allow"


settings = Settings()
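Since Settings extends pydantic's BaseSettings with env_file = ".env" and case_sensitive = False, any field can be overridden by an environment variable of the same name. A quick sketch (the override value is illustrative):

import os

# Environment variables take precedence over the defaults declared in Settings
os.environ["GROQ_RPM_LIMIT"] = "10"  # maps to groq_rpm_limit (case-insensitive)

from config import Settings
print(Settings().groq_rpm_limit)  # -> 10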
dataset_loader.py
ADDED
|
@@ -0,0 +1,178 @@
"""Dataset loader for RAG Bench datasets."""
import os
from typing import List, Dict, Optional
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm


class RAGBenchLoader:
    """Load and manage RAG Bench datasets."""

    SUPPORTED_DATASETS = [
        'covidqa',
        'cuad',
        'delucionqa',
        'emanual',
        'expertqa',
        'finqa',
        'hagrid',
        'hotpotqa',
        'msmarco',
        'pubmedqa',
        'tatqa',
        'techqa'
    ]

    def __init__(self, cache_dir: str = "./data_cache"):
        """Initialize the dataset loader.

        Args:
            cache_dir: Directory to cache downloaded datasets
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, dataset_name: str, split: str = "test",
                     max_samples: Optional[int] = None) -> List[Dict]:
        """Load a RAG Bench dataset from rungalileo/ragbench.

        Args:
            dataset_name: Name of the dataset to load
            split: Dataset split (train/validation/test)
            max_samples: Maximum number of samples to load

        Returns:
            List of dictionaries containing dataset samples
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}. "
                             f"Supported: {self.SUPPORTED_DATASETS}")

        print(f"Loading {dataset_name} dataset ({split} split) from rungalileo/ragbench...")

        try:
            # Load from rungalileo/ragbench
            dataset = load_dataset("rungalileo/ragbench", dataset_name, split=split,
                                   cache_dir=self.cache_dir)

            processed_data = []
            samples = dataset if max_samples is None else dataset.select(range(min(max_samples, len(dataset))))

            # Process the dataset
            for item in tqdm(samples, desc=f"Processing {dataset_name}"):
                processed_data.append(self._process_ragbench_item(item, dataset_name))

            print(f"Loaded {len(processed_data)} samples from {dataset_name}")
            return processed_data

        except Exception as e:
            print(f"Error loading {dataset_name}: {str(e)}")
            print("Falling back to sample data for testing...")
            return self._create_sample_data(dataset_name, max_samples or 10)

    def _process_ragbench_item(self, item: Dict, dataset_name: str) -> Dict:
        """Process a single RAGBench dataset item into standardized format.

        Args:
            item: Raw dataset item
            dataset_name: Name of the dataset

        Returns:
            Processed item dictionary
        """
        # RAGBench datasets typically have: question, documents, answer, and retrieved_contexts
        processed = {
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "context": "",  # For embedding and retrieval
            "documents": [],  # Store original documents list
            "dataset": dataset_name
        }

        # Extract documents - RAGBench uses 'documents' as primary source for embeddings
        # Priority: documents > retrieved_contexts > context
        if "documents" in item:
            if isinstance(item["documents"], list):
                processed["documents"] = [str(doc) for doc in item["documents"]]
                processed["context"] = " ".join(processed["documents"])
            else:
                processed["documents"] = [str(item["documents"])]
                processed["context"] = str(item["documents"])
        elif "retrieved_contexts" in item:
            if isinstance(item["retrieved_contexts"], list):
                processed["documents"] = [str(ctx) for ctx in item["retrieved_contexts"]]
                processed["context"] = " ".join(processed["documents"])
            else:
                processed["documents"] = [str(item["retrieved_contexts"])]
                processed["context"] = str(item["retrieved_contexts"])
        elif "context" in item:
            if isinstance(item["context"], list):
                processed["documents"] = [str(ctx) for ctx in item["context"]]
                processed["context"] = " ".join(processed["documents"])
            else:
                processed["documents"] = [str(item["context"])]
                processed["context"] = str(item["context"])

        # Store additional metadata if available
        if "metadata" in item:
            processed["metadata"] = item["metadata"]

        return processed

    def load_all_datasets(self, split: str = "test", max_samples: Optional[int] = None) -> Dict[str, List[Dict]]:
        """Load all RAGBench datasets.

        Args:
            split: Dataset split to load
            max_samples: Maximum samples per dataset

        Returns:
            Dictionary mapping dataset names to their data
        """
        all_data = {}
        for dataset_name in self.SUPPORTED_DATASETS:
            print(f"\n{'='*50}")
            print(f"Loading {dataset_name}...")
            print(f"{'='*50}")
            try:
                all_data[dataset_name] = self.load_dataset(dataset_name, split, max_samples)
            except Exception as e:
                print(f"Failed to load {dataset_name}: {str(e)}")
                all_data[dataset_name] = []

        return all_data

    def _create_sample_data(self, dataset_name: str, num_samples: int) -> List[Dict]:
        """Create sample data for testing when actual dataset is unavailable."""
        sample_data = []
        for i in range(num_samples):
            # Create multiple sample documents per question
            sample_docs = [
                f"Document 1: This is the first sample document {i+1} for {dataset_name} dataset. "
                f"It contains relevant information to answer the question.",
                f"Document 2: This is the second sample document {i+1} providing additional context. "
                f"It includes more details about the topic.",
                f"Document 3: This is the third sample document {i+1} with supplementary information."
            ]

            sample_data.append({
                "question": f"Sample question {i+1} for {dataset_name}?",
                "answer": f"Sample answer {i+1}",
                "documents": sample_docs,
                "context": " ".join(sample_docs),  # Combined for backward compatibility
                "dataset": dataset_name
            })
        return sample_data

    def get_test_data(self, dataset_name: str, num_samples: int = 100) -> List[Dict]:
        """Get test data for TRACE evaluation.

        Args:
            dataset_name: Name of the dataset
            num_samples: Number of test samples

        Returns:
            List of test samples
        """
        return self.load_dataset(dataset_name, split="test", max_samples=num_samples)
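A minimal sketch of the loader's intended use; note that if the HuggingFace download fails, load_dataset falls back to synthetic sample data, so the shape of the output is the same either way.

from dataset_loader import RAGBenchLoader

loader = RAGBenchLoader(cache_dir="./data_cache")
samples = loader.load_dataset("pubmedqa", split="test", max_samples=5)

# Each item is normalized to question/answer/documents/context/dataset
first = samples[0]
print(first["question"])
print(len(first["documents"]), "source documents")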
docker-compose.yml
ADDED
|
@@ -0,0 +1,26 @@
version: '3.8'

services:
  streamlit:
    build: .
    ports:
      - "8501:8501"
    environment:
      - GROQ_API_KEY=${GROQ_API_KEY}
      - CHROMA_PERSIST_DIRECTORY=/app/chroma_db
    volumes:
      - ./chroma_db:/app/chroma_db
      - ./data_cache:/app/data_cache
    command: streamlit run streamlit_app.py --server.port=8501 --server.address=0.0.0.0

  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - GROQ_API_KEY=${GROQ_API_KEY}
      - CHROMA_PERSIST_DIRECTORY=/app/chroma_db
    volumes:
      - ./chroma_db:/app/chroma_db
      - ./data_cache:/app/data_cache
    command: uvicorn api:app --host 0.0.0.0 --port 8000 --reload
embedding_models.py
ADDED
|
@@ -0,0 +1,325 @@
"""Embedding models for document vectorization."""
from typing import List, Optional
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
import numpy as np
from tqdm import tqdm
import os


class EmbeddingModel:
    """Base class for embedding models."""

    def __init__(self, model_name: str, device: Optional[str] = None):
        """Initialize embedding model.

        Args:
            model_name: Name/path of the model
            device: Device to run model on (cuda/cpu)
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the embedding model."""
        raise NotImplementedError

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed a list of documents.

        Args:
            texts: List of texts to embed
            batch_size: Batch size for processing

        Returns:
            Numpy array of embeddings
        """
        raise NotImplementedError

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query.

        Args:
            query: Query text

        Returns:
            Numpy array of embedding
        """
        return self.embed_documents([query])[0]


class SentenceTransformerEmbedding(EmbeddingModel):
    """Sentence Transformer based embedding model."""

    def load_model(self):
        """Load sentence transformer model."""
        print(f"Loading SentenceTransformer model: {self.model_name}")
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using sentence transformer."""
        if self.model is None:
            self.load_model()

        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.model.encode(
                batch,
                convert_to_numpy=True,
                show_progress_bar=False,
                batch_size=batch_size
            )
            embeddings.append(batch_embeddings)

        return np.vstack(embeddings) if embeddings else np.array([])


class BioMedicalEmbedding(EmbeddingModel):
    """Bio-medical BERT based embedding model."""

    def load_model(self):
        """Load bio-medical BERT model."""
        print(f"Loading Bio-Medical model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Apply mean pooling to get sentence embeddings."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using bio-medical BERT."""
        if self.model is None:
            self.load_model()

        embeddings = []

        with torch.no_grad():
            for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
                batch = texts[i:i + batch_size]

                # Tokenize
                encoded_input = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                # Get embeddings
                model_output = self.model(**encoded_input)

                # Apply mean pooling
                batch_embeddings = self.mean_pooling(
                    model_output,
                    encoded_input['attention_mask']
                )

                # Normalize
                batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)

                embeddings.append(batch_embeddings.cpu().numpy())

        return np.vstack(embeddings) if embeddings else np.array([])


class GeminiEmbedding(EmbeddingModel):
    """Gemini embedding model using Google AI API."""

    def load_model(self):
        """Load Gemini embedding model."""
        print(f"Initializing Gemini embedding model: {self.model_name}")
        try:
            import google.generativeai as genai
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("GEMINI_API_KEY environment variable not set")
            genai.configure(api_key=api_key)
            self.model = genai
            print("Gemini model initialized successfully")
        except Exception as e:
            print(f"Error loading Gemini model: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using Gemini API."""
        if self.model is None:
            self.load_model()

        embeddings = []

        # Gemini API has rate limits, process with delays
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            batch = texts[i:i + batch_size]

            for text in batch:
                try:
                    if hasattr(self.model, 'embed_content'):
                        result = self.model.embed_content(
                            model="models/embedding-001",
                            content=text,
                            task_type="retrieval_document"
                        )
                        embeddings.append(result['embedding'])
                    else:
                        # Fallback if Gemini not available
                        from sentence_transformers import SentenceTransformer
                        fallback_model = SentenceTransformer('all-MiniLM-L6-v2')
                        emb = fallback_model.encode([text])[0]
                        embeddings.append(emb)
                except Exception as e:
                    print(f"Error embedding text: {str(e)}")
                    # Use zero vector as fallback
                    embeddings.append(np.zeros(768))

        return np.array(embeddings)


class EmbeddingFactory:
    """Factory for creating embedding model instances."""

    # Map model names to their types
    MODEL_TYPES = {
        "sentence-transformers/all-mpnet-base-v2": "sentence-transformer",  # Stable, well-supported
        "emilyalsentzer/Bio_ClinicalBERT": "biomedical",  # Clinical domain
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": "biomedical",  # Medical domain
        "sentence-transformers/all-MiniLM-L6-v2": "sentence-transformer",  # Fast, lightweight
        "sentence-transformers/multilingual-MiniLM-L12-v2": "sentence-transformer",  # Multilingual
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "sentence-transformer",  # Paraphrase
        "allenai/specter": "biomedical",  # Academic paper embeddings
        "gemini-embedding-001": "gemini"  # Gemini API
    }

    @classmethod
    def create_embedding_model(cls, model_name: str, device: Optional[str] = None) -> EmbeddingModel:
        """Create an embedding model instance.

        Args:
            model_name: Name of the embedding model
            device: Device to run model on

        Returns:
            EmbeddingModel instance
        """
        model_type = cls.MODEL_TYPES.get(model_name, "sentence-transformer")

        if model_type == "gemini":
            return GeminiEmbedding(model_name, device)
        elif model_type == "biomedical":
            return BioMedicalEmbedding(model_name, device)
        else:
            return SentenceTransformerEmbedding(model_name, device)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Get list of available embedding models."""
        return list(cls.MODEL_TYPES.keys())

    @classmethod
    def get_model_info(cls, model_name: str) -> dict:
        """Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Dictionary with model information
        """
        info = {
            "sentence-transformers/all-mpnet-base-v2": {
                "description": "High-quality, general-purpose sentence embeddings (768d)",
                "dimension": 768,
                "type": "sentence-transformer",
                "note": "Recommended for general use"
            },
            "emilyalsentzer/Bio_ClinicalBERT": {
                "description": "Clinical BERT for biomedical and clinical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": {
                "description": "PubMedBERT for biomedical and medical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "sentence-transformers/all-MiniLM-L6-v2": {
                "description": "Fast, lightweight sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for speed-sensitive applications"
            },
            "sentence-transformers/multilingual-MiniLM-L12-v2": {
                "description": "Fast multilingual sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Supports 50+ languages"
            },
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": {
                "description": "Multilingual paraphrase embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for paraphrase detection"
            },
            "allenai/specter": {
                "description": "Embeddings for academic papers and citations",
                "dimension": 768,
                "type": "biomedical",
                "note": "Optimized for scientific literature"
            },
            "gemini-embedding-001": {
                "description": "Google Gemini embedding model via API",
                "dimension": 768,
                "type": "gemini",
                "url": "https://ai.google.dev/gemini-api/docs/embeddings",
                "note": "Requires GEMINI_API_KEY environment variable"
            }
        }
        return info.get(model_name, {"description": "Unknown model", "dimension": 768})

    @classmethod
    def get_embedding_dimension(cls, model_name: str) -> int:
        """Get embedding dimension for a model.

        Args:
            model_name: Name of the model

        Returns:
            Embedding dimension
        """
        # Default dimensions (adjust based on actual models)
        dimensions = {
            "sentence-transformers/all-mpnet-base-v2": 768,
            "emilyalsentzer/Bio_ClinicalBERT": 768,
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": 768,
            "sentence-transformers/all-MiniLM-L6-v2": 384,
            "sentence-transformers/multilingual-MiniLM-L12-v2": 384,
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": 384,
            "allenai/specter": 768,
            "gemini-embedding-001": 768
        }
        return dimensions.get(model_name, 768)
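A quick sketch of the factory in use; MiniLM is chosen here because it is the lightest model in the table, and the expected shapes follow get_embedding_dimension above.

from embedding_models import EmbeddingFactory

model = EmbeddingFactory.create_embedding_model(
    "sentence-transformers/all-MiniLM-L6-v2"
)
vectors = model.embed_documents(["a short test sentence", "another one"])
print(vectors.shape)  # (2, 384)

query_vec = model.embed_query("test")
print(query_vec.shape)  # (384,)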
example.py
ADDED
|
@@ -0,0 +1,118 @@
"""
Example script demonstrating how to use the RAG system programmatically.
"""
import os
from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator


def main():
    """Example usage of RAG system."""

    # Set your API key
    api_key = os.getenv("GROQ_API_KEY") or "your_api_key_here"

    if api_key == "your_api_key_here":
        print("Please set your GROQ_API_KEY in .env file or environment variable")
        return

    print("=" * 50)
    print("RAG System Example")
    print("=" * 50)

    # 1. Load dataset
    print("\n1. Loading dataset...")
    loader = RAGBenchLoader()
    dataset = loader.load_dataset("hotpotqa", split="train", max_samples=20)
    print(f"Loaded {len(dataset)} samples")

    # 2. Create vector store and collection
    print("\n2. Creating vector store...")
    vector_store = ChromaDBManager()

    collection_name = "example_collection"
    embedding_model = "emilyalsentzer/Bio_ClinicalBERT"
    chunking_strategy = "hybrid"

    print(f"Loading data into collection with {chunking_strategy} chunking...")
    vector_store.load_dataset_into_collection(
        collection_name=collection_name,
        embedding_model_name=embedding_model,
        chunking_strategy=chunking_strategy,
        dataset_data=dataset,
        chunk_size=512,
        overlap=50
    )

    # 3. Initialize LLM client
    print("\n3. Initializing LLM client...")
    llm_client = GroqLLMClient(
        api_key=api_key,
        model_name="llama-3.1-8b-instant",
        max_rpm=30,
        rate_limit_delay=2.0
    )

    # 4. Create RAG pipeline
    print("\n4. Creating RAG pipeline...")
    rag = RAGPipeline(llm_client, vector_store)

    # 5. Query the system
    print("\n5. Querying the system...")
    queries = [
        "What is machine learning?",
        "How does a neural network work?",
        "What is deep learning?"
    ]

    for i, query in enumerate(queries, 1):
        print(f"\n--- Query {i}: {query} ---")
        result = rag.query(query, n_results=3)

        print(f"Response: {result['response']}")
        print(f"\nRetrieved {len(result['retrieved_documents'])} documents:")
        for j, doc in enumerate(result['retrieved_documents'], 1):
            print(f"\nDocument {j} (Distance: {doc.get('distance', 'N/A')}):")
            print(f"{doc['document'][:200]}...")

    # 6. Run evaluation
    print("\n6. Running TRACE evaluation...")
    evaluator = TRACEEvaluator(llm_client)

    # Prepare test cases
    test_cases = []
    test_samples = loader.get_test_data("hotpotqa", num_samples=5)

    for sample in test_samples:
        result = rag.query(sample["question"], n_results=5)
        test_cases.append({
            "query": sample["question"],
            "response": result["response"],
            "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
            "ground_truth": sample.get("answer", "")
        })

    results = evaluator.evaluate_batch(test_cases)

    print("\nTRACE Evaluation Results:")
    print(f"Utilization: {results['utilization']:.3f}")
    print(f"Relevance: {results['relevance']:.3f}")
    print(f"Adherence: {results['adherence']:.3f}")
    print(f"Completeness: {results['completeness']:.3f}")
    print(f"Average: {results['average']:.3f}")

    # 7. View chat history
    print("\n7. Chat History:")
    history = rag.get_chat_history()
    print(f"Total conversations: {len(history)}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")
    print("=" * 50)


if __name__ == "__main__":
    main()
llm_client.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Groq LLM integration with rate limiting."""
|
| 2 |
+
from typing import List, Dict, Optional, AsyncIterator
|
| 3 |
+
import time
|
| 4 |
+
from groq import Groq
|
| 5 |
+
import asyncio
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
from collections import deque
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RateLimiter:
|
| 12 |
+
"""Rate limiter for API calls."""
|
| 13 |
+
|
| 14 |
+
def __init__(self, max_requests_per_minute: int = 30):
|
| 15 |
+
"""Initialize rate limiter.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
max_requests_per_minute: Maximum requests allowed per minute
|
| 19 |
+
"""
|
| 20 |
+
self.max_requests = max_requests_per_minute
|
| 21 |
+
self.request_times = deque()
|
| 22 |
+
self.lock = asyncio.Lock()
|
| 23 |
+
|
| 24 |
+
async def acquire(self):
|
| 25 |
+
"""Acquire permission to make a request."""
|
| 26 |
+
async with self.lock:
|
| 27 |
+
now = datetime.now()
|
| 28 |
+
|
| 29 |
+
# Remove requests older than 1 minute
|
| 30 |
+
while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
|
| 31 |
+
self.request_times.popleft()
|
| 32 |
+
|
| 33 |
+
# If at limit, wait
|
| 34 |
+
if len(self.request_times) >= self.max_requests:
|
| 35 |
+
# Calculate how long to wait
|
| 36 |
+
oldest_request = self.request_times[0]
|
| 37 |
+
wait_time = 60 - (now - oldest_request).total_seconds()
|
| 38 |
+
|
| 39 |
+
if wait_time > 0:
|
| 40 |
+
print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
|
| 41 |
+
await asyncio.sleep(wait_time)
|
| 42 |
+
# Recursive call after waiting
|
| 43 |
+
return await self.acquire()
|
| 44 |
+
|
| 45 |
+
# Record this request
|
| 46 |
+
self.request_times.append(now)
|
| 47 |
+
|
| 48 |
+
def acquire_sync(self):
|
| 49 |
+
"""Synchronous version of acquire."""
|
| 50 |
+
now = datetime.now()
|
| 51 |
+
|
| 52 |
+
# Remove requests older than 1 minute
|
| 53 |
+
while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
|
| 54 |
+
self.request_times.popleft()
|
| 55 |
+
|
| 56 |
+
# If at limit, wait
|
| 57 |
+
if len(self.request_times) >= self.max_requests:
|
| 58 |
+
oldest_request = self.request_times[0]
|
| 59 |
+
wait_time = 60 - (now - oldest_request).total_seconds()
|
| 60 |
+
|
| 61 |
+
if wait_time > 0:
|
| 62 |
+
print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
|
| 63 |
+
time.sleep(wait_time)
|
| 64 |
+
return self.acquire_sync()
|
| 65 |
+
|
| 66 |
+
# Record this request
|
| 67 |
+
self.request_times.append(now)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class GroqLLMClient:
|
| 71 |
+
"""Client for Groq LLM API with rate limiting."""
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
api_key: str,
|
| 76 |
+
model_name: str = "llama-3.1-8b-instant",
|
| 77 |
+
max_rpm: int = 30,
|
| 78 |
+
rate_limit_delay: float = 2.0
|
| 79 |
+
):
|
| 80 |
+
"""Initialize Groq client.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
api_key: Groq API key
|
| 84 |
+
model_name: Name of the LLM model
|
| 85 |
+
max_rpm: Maximum requests per minute
|
| 86 |
+
rate_limit_delay: Additional delay between requests (seconds)
|
| 87 |
+
"""
|
| 88 |
+
self.client = Groq(api_key=api_key)
|
| 89 |
+
self.model_name = model_name
|
| 90 |
+
self.rate_limiter = RateLimiter(max_rpm)
|
| 91 |
+
self.rate_limit_delay = rate_limit_delay
|
| 92 |
+
|
| 93 |
+
# Available models
|
| 94 |
+
self.available_models = [
|
| 95 |
+
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
| 96 |
+
"llama-3.1-8b-instant",
|
| 97 |
+
"openai/gpt-oss-120b"
|
| 98 |
+
]
|
| 99 |
+
|
| 100 |
+
def set_model(self, model_name: str):
|
| 101 |
+
"""Set the LLM model.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
model_name: Name of the model
|
| 105 |
+
"""
|
| 106 |
+
if model_name not in self.available_models:
|
| 107 |
+
print(f"Warning: {model_name} not in available models. Using anyway...")
|
| 108 |
+
self.model_name = model_name
|
| 109 |
+
|
| 110 |
+
def generate(
|
| 111 |
+
self,
|
| 112 |
+
prompt: str,
|
| 113 |
+
max_tokens: int = 1024,
|
| 114 |
+
temperature: float = 0.7,
|
| 115 |
+
system_prompt: Optional[str] = None
|
| 116 |
+
) -> str:
|
| 117 |
+
"""Generate text using Groq LLM.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
prompt: Input prompt
|
| 121 |
+
max_tokens: Maximum tokens to generate
|
| 122 |
+
temperature: Sampling temperature
|
| 123 |
+
system_prompt: System prompt
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Generated text
|
| 127 |
+
"""
|
| 128 |
+
# Apply rate limiting
|
| 129 |
+
self.rate_limiter.acquire_sync()
|
| 130 |
+
|
| 131 |
+
# Prepare messages
|
| 132 |
+
messages = []
|
| 133 |
+
if system_prompt:
|
| 134 |
+
messages.append({
|
| 135 |
+
"role": "system",
|
| 136 |
+
"content": system_prompt
|
| 137 |
+
})
|
| 138 |
+
messages.append({
|
| 139 |
+
"role": "user",
|
| 140 |
+
"content": prompt
|
| 141 |
+
})
|
| 142 |
+
|
| 143 |
+
try:
|
| 144 |
+
# Make API call
|
| 145 |
+
response = self.client.chat.completions.create(
|
| 146 |
+
model=self.model_name,
|
| 147 |
+
messages=messages,
|
| 148 |
+
max_tokens=max_tokens,
|
| 149 |
+
temperature=temperature
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# Add delay
|
| 153 |
+
time.sleep(self.rate_limit_delay)
|
| 154 |
+
|
| 155 |
+
return response.choices[0].message.content
|
| 156 |
+
|
| 157 |
+
except Exception as e:
|
| 158 |
+
print(f"Error generating response: {str(e)}")
|
| 159 |
+
return f"Error: {str(e)}"
|
| 160 |
+
|
| 161 |
+
async def generate_async(
|
| 162 |
+
self,
|
| 163 |
+
prompt: str,
|
| 164 |
+
max_tokens: int = 1024,
|
| 165 |
+
temperature: float = 0.7,
|
| 166 |
+
system_prompt: Optional[str] = None
|
| 167 |
+
) -> str:
|
| 168 |
+
"""Asynchronous version of generate.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
prompt: Input prompt
|
| 172 |
+
max_tokens: Maximum tokens to generate
|
| 173 |
+
temperature: Sampling temperature
|
| 174 |
+
system_prompt: System prompt
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
Generated text
|
| 178 |
+
"""
|
| 179 |
+
# Apply rate limiting
|
| 180 |
+
await self.rate_limiter.acquire()
|
| 181 |
+
|
| 182 |
+
# Prepare messages
|
| 183 |
+
messages = []
|
| 184 |
+
if system_prompt:
|
| 185 |
+
messages.append({
|
| 186 |
+
"role": "system",
|
| 187 |
+
"content": system_prompt
|
| 188 |
+
})
|
| 189 |
+
messages.append({
|
| 190 |
+
"role": "user",
|
| 191 |
+
"content": prompt
|
| 192 |
+
})
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
# Make API call (synchronous client used in async context)
|
| 196 |
+
response = self.client.chat.completions.create(
|
| 197 |
+
model=self.model_name,
|
| 198 |
+
messages=messages,
|
| 199 |
+
max_tokens=max_tokens,
|
| 200 |
+
temperature=temperature
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Add delay
|
| 204 |
+
await asyncio.sleep(self.rate_limit_delay)
|
| 205 |
+
|
| 206 |
+
return response.choices[0].message.content
|
| 207 |
+
|
| 208 |
+
except Exception as e:
|
| 209 |
+
print(f"Error generating response: {str(e)}")
|
| 210 |
+
return f"Error: {str(e)}"
|
| 211 |
+
|
| 212 |
+
def generate_with_context(
|
| 213 |
+
self,
|
| 214 |
+
query: str,
|
| 215 |
+
context_documents: List[str],
|
| 216 |
+
max_tokens: int = 1024,
|
| 217 |
+
temperature: float = 0.7
|
| 218 |
+
) -> str:
|
| 219 |
+
"""Generate response with retrieved context.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
query: User query
|
| 223 |
+
context_documents: List of retrieved documents
|
| 224 |
+
max_tokens: Maximum tokens to generate
|
| 225 |
+
temperature: Sampling temperature
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
Generated response
|
| 229 |
+
"""
|
| 230 |
+
# Build context
|
| 231 |
+
context = "\n\n".join([
|
| 232 |
+
f"Document {i+1}: {doc}"
|
| 233 |
+
for i, doc in enumerate(context_documents)
|
| 234 |
+
])
|
| 235 |
+
|
| 236 |
+
# Build prompt
|
| 237 |
+
prompt = f"""Answer the following question based on the provided context.
|
| 238 |
+
|
| 239 |
+
Context:
|
| 240 |
+
{context}
|
| 241 |
+
|
| 242 |
+
Question: {query}
|
| 243 |
+
|
| 244 |
+
Answer:"""
|
| 245 |
+
|
| 246 |
+
system_prompt = "You are a helpful AI assistant. Answer questions based on the provided context. If the answer is not in the context, say so."
|
| 247 |
+
|
| 248 |
+
return self.generate(prompt, max_tokens, temperature, system_prompt)
|
| 249 |
+
|
| 250 |
+
def batch_generate(
|
| 251 |
+
self,
|
| 252 |
+
prompts: List[str],
|
| 253 |
+
max_tokens: int = 1024,
|
| 254 |
+
temperature: float = 0.7,
|
| 255 |
+
system_prompt: Optional[str] = None
|
| 256 |
+
) -> List[str]:
|
| 257 |
+
"""Generate responses for multiple prompts.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
prompts: List of prompts
|
| 261 |
+
max_tokens: Maximum tokens to generate
|
| 262 |
+
temperature: Sampling temperature
|
| 263 |
+
system_prompt: System prompt
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
List of generated responses
|
| 267 |
+
"""
|
| 268 |
+
responses = []
|
| 269 |
+
for i, prompt in enumerate(prompts):
|
| 270 |
+
print(f"Processing prompt {i+1}/{len(prompts)}")
|
| 271 |
+
response = self.generate(prompt, max_tokens, temperature, system_prompt)
|
| 272 |
+
responses.append(response)
|
| 273 |
+
|
| 274 |
+
return responses
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class RAGPipeline:
|
| 278 |
+
"""Complete RAG pipeline with LLM and vector store."""
|
| 279 |
+
|
| 280 |
+
def __init__(
|
| 281 |
+
self,
|
| 282 |
+
llm_client: GroqLLMClient,
|
| 283 |
+
vector_store_manager
|
| 284 |
+
):
|
| 285 |
+
"""Initialize RAG pipeline.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
llm_client: Groq LLM client
|
| 289 |
+
vector_store_manager: ChromaDB manager
|
| 290 |
+
"""
|
| 291 |
+
self.llm = llm_client
|
| 292 |
+
self.vector_store = vector_store_manager
|
| 293 |
+
self.chat_history = []
|
| 294 |
+
|
| 295 |
+
def query(
|
| 296 |
+
self,
|
| 297 |
+
query: str,
|
| 298 |
+
n_results: int = 5,
|
| 299 |
+
max_tokens: int = 1024,
|
| 300 |
+
temperature: float = 0.7
|
| 301 |
+
) -> Dict:
|
| 302 |
+
"""Query the RAG system.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
query: User query
|
| 306 |
+
n_results: Number of documents to retrieve
|
| 307 |
+
max_tokens: Maximum tokens to generate
|
| 308 |
+
temperature: Sampling temperature
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
Dictionary with response and retrieved documents
|
| 312 |
+
"""
|
| 313 |
+
# Retrieve documents
|
| 314 |
+
retrieved_docs = self.vector_store.get_retrieved_documents(query, n_results)
|
| 315 |
+
|
| 316 |
+
# Extract document texts
|
| 317 |
+
doc_texts = [doc["document"] for doc in retrieved_docs]
|
| 318 |
+
|
| 319 |
+
# Generate response
|
| 320 |
+
response = self.llm.generate_with_context(
|
| 321 |
+
query,
|
| 322 |
+
doc_texts,
|
| 323 |
+
max_tokens,
|
| 324 |
+
temperature
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Store in chat history
|
| 328 |
+
self.chat_history.append({
|
| 329 |
+
"query": query,
|
| 330 |
+
"response": response,
|
| 331 |
+
"retrieved_docs": retrieved_docs,
|
| 332 |
+
"timestamp": datetime.now().isoformat()
|
| 333 |
+
})
|
| 334 |
+
|
| 335 |
+
return {
|
| 336 |
+
"query": query,
|
| 337 |
+
"response": response,
|
| 338 |
+
"retrieved_documents": retrieved_docs
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
def get_chat_history(self) -> List[Dict]:
|
| 342 |
+
"""Get chat history.
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
List of chat history entries
|
| 346 |
+
"""
|
| 347 |
+
return self.chat_history
|
| 348 |
+
|
| 349 |
+
def clear_history(self):
|
| 350 |
+
"""Clear chat history."""
|
| 351 |
+
self.chat_history = []
|
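
For orientation, a minimal usage sketch of the pieces above (not part of the commit; the collection name, model name, and rate limits are placeholder assumptions):

# sketch.py -- hypothetical example, assumes GROQ_API_KEY is set and a
# collection has already been built by the Streamlit app or API.
import os

from llm_client import GroqLLMClient, RAGPipeline
from vector_store import ChromaDBManager

store = ChromaDBManager(persist_directory="./chroma_db")
store.get_collection("covidqa_fixed_size_MiniLM")  # placeholder collection name

llm = GroqLLMClient(
    api_key=os.environ["GROQ_API_KEY"],
    model_name="llama-3.1-8b-instant",  # placeholder; pick any entry from settings.llm_models
    max_rpm=30,
    rate_limit_delay=2.0,
)

pipeline = RAGPipeline(llm_client=llm, vector_store_manager=store)
result = pipeline.query("What does the dataset say about transmission?", n_results=5)
print(result["response"])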
requirements.txt
ADDED
@@ -0,0 +1,40 @@
# Core Dependencies
fastapi==0.109.0
uvicorn[standard]==0.27.0
streamlit==1.31.0
python-dotenv==1.0.0

# LLM and AI
groq>=0.11.0
openai==1.12.0
google-generativeai>=0.3.0

# Embeddings and Vector Store
sentence-transformers==2.7.0
transformers==4.40.2
torch>=2.0.0
chromadb==0.5.23

# Data Processing
pandas==2.2.0
numpy==1.26.3
datasets==2.16.1

# RAG and Retrieval
langchain==0.1.6
langchain-community==0.0.19
langchain-groq==0.0.1

# Evaluation
ragas==0.1.4
rank-bm25==0.2.2

# Utilities
pydantic==2.6.0
pydantic-settings==2.1.0
tenacity==8.2.3
aiohttp==3.9.3
tqdm==4.66.1

# Deployment
gunicorn==21.2.0
run.py
ADDED
@@ -0,0 +1,99 @@
"""
Quick start script to run the RAG application.
"""
import subprocess
import sys
import os


def check_dependencies():
    """Check if required dependencies are installed."""
    try:
        import streamlit
        import fastapi
        import groq
        return True
    except ImportError:
        return False


def install_dependencies():
    """Install dependencies from requirements.txt."""
    print("Installing dependencies...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
    print("Dependencies installed successfully!")


def check_env_file():
    """Check if .env file exists."""
    if not os.path.exists(".env"):
        print("\n⚠️  Warning: .env file not found!")
        print("Creating .env from .env.example...")
        if os.path.exists(".env.example"):
            with open(".env.example", "r") as src:
                with open(".env", "w") as dst:
                    dst.write(src.read())
            print("✅ .env file created. Please edit it and add your Groq API key.")
        else:
            print("❌ .env.example not found. Please create .env manually.")
            return False
    return True


def run_streamlit():
    """Run the Streamlit application."""
    print("\n🚀 Starting Streamlit application...")
    print("📱 Open your browser to: http://localhost:8501")
    subprocess.run([sys.executable, "-m", "streamlit", "run", "streamlit_app.py"])


def run_api():
    """Run the FastAPI application."""
    print("\n🚀 Starting FastAPI server...")
    print("📱 API available at: http://localhost:8000")
    print("📚 API docs at: http://localhost:8000/docs")
    subprocess.run([sys.executable, "api.py"])


def main():
    """Main function."""
    print("=" * 50)
    print("RAG Capstone Project - Quick Start")
    print("=" * 50)

    # Check dependencies
    if not check_dependencies():
        print("\n📦 Installing dependencies...")
        install_dependencies()

    # Check .env file
    env_exists = check_env_file()

    if not env_exists:
        print("\n❌ Please configure your .env file before running the application.")
        return

    # Ask user what to run
    print("\nWhat would you like to run?")
    print("1. Streamlit Chat Interface (recommended)")
    print("2. FastAPI Backend")
    print("3. Both (requires two terminals)")

    choice = input("\nEnter your choice (1-3): ").strip()

    if choice == "1":
        run_streamlit()
    elif choice == "2":
        run_api()
    elif choice == "3":
        print("\n📌 To run both:")
        print("Terminal 1: python api.py")
        print("Terminal 2: streamlit run streamlit_app.py")
        print("\nStarting Streamlit in this terminal...")
        run_streamlit()
    else:
        print("Invalid choice. Exiting.")


if __name__ == "__main__":
    main()
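
Option 3 above only prints instructions; as a hedged sketch (not part of the commit), the two processes could also be launched from one helper script:

# launch_both.py -- hypothetical helper. Starts the FastAPI backend and the
# Streamlit UI as child processes and shuts the API down when the UI exits.
import subprocess
import sys

api = subprocess.Popen([sys.executable, "api.py"])  # serves on port 8000
ui = subprocess.Popen([sys.executable, "-m", "streamlit", "run", "streamlit_app.py"])  # port 8501
try:
    ui.wait()
finally:
    api.terminate()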
streamlit_app.py
ADDED
@@ -0,0 +1,721 @@
"""Streamlit chat interface for RAG application."""
import streamlit as st
import sys
import os
from datetime import datetime
import json
import pandas as pd
from typing import Optional
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator
from embedding_models import EmbeddingFactory
from chunking_strategies import ChunkingFactory


# Page configuration
st.set_page_config(
    page_title="RAG Capstone Project",
    page_icon="🤖",
    layout="wide"
)

# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

if "rag_pipeline" not in st.session_state:
    st.session_state.rag_pipeline = None

if "vector_store" not in st.session_state:
    st.session_state.vector_store = None

if "collection_loaded" not in st.session_state:
    st.session_state.collection_loaded = False

if "evaluation_results" not in st.session_state:
    st.session_state.evaluation_results = None

if "dataset_size" not in st.session_state:
    st.session_state.dataset_size = 10000

if "current_dataset" not in st.session_state:
    st.session_state.current_dataset = None

if "current_llm" not in st.session_state:
    st.session_state.current_llm = settings.llm_models[1]

if "selected_collection" not in st.session_state:
    st.session_state.selected_collection = None

if "available_collections" not in st.session_state:
    st.session_state.available_collections = []


def get_available_collections():
    """Get list of available collections from ChromaDB."""
    try:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)
        collections = vector_store.list_collections()
        return collections
    except Exception as e:
        print(f"Error getting collections: {e}")
        return []


def main():
    """Main Streamlit application."""
    st.title("🤖 RAG Capstone Project")
    st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")

    # Get available collections at startup
    available_collections = get_available_collections()
    st.session_state.available_collections = available_collections

    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")

        # API Key input
        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            value=settings.groq_api_key or "",
            help="Enter your Groq API key"
        )

        st.divider()

        # Option 1: Use existing collection
        if available_collections:
            st.subheader("📚 Existing Collections")
            st.write(f"Found {len(available_collections)} collection(s)")

            selected_collection = st.selectbox(
                "Or select existing collection:",
                available_collections,
                key="collection_selector"
            )

            if st.button("📖 Load Existing Collection", type="secondary"):
                if not groq_api_key:
                    st.error("Please enter your Groq API key")
                else:
                    load_existing_collection(groq_api_key, selected_collection)

            st.divider()

        # Option 2: Create new collection
        st.subheader("🆕 Create New Collection")

        # Dataset selection
        st.subheader("1. Dataset Selection")
        dataset_name = st.selectbox(
            "Choose Dataset",
            settings.ragbench_datasets,
            index=0
        )

        # Get dataset size dynamically
        if st.button("🔍 Check Dataset Size", key="check_size"):
            with st.spinner("Checking dataset size..."):
                try:
                    from datasets import load_dataset

                    st.info(f"Fetching dataset info for '{dataset_name}'...")
                    ds = load_dataset(
                        "rungalileo/ragbench",
                        dataset_name,
                        split="train",
                        trust_remote_code=True,
                        download_mode="force_redownload"  # Force fresh download to avoid cache corruption
                    )
                    dataset_size = len(ds)

                    st.session_state.dataset_size = dataset_size
                    st.session_state.current_dataset = dataset_name
                    st.success(f"✅ Dataset '{dataset_name}' has {dataset_size:,} samples available")
                except Exception as e:
                    st.error(f"❌ Error: {str(e)}")
                    st.exception(e)
                    st.warning("Could not determine dataset size. Using default of 10,000.")
                    st.session_state.dataset_size = 10000
                    st.session_state.current_dataset = dataset_name

        # Use stored dataset size or default
        max_samples_available = st.session_state.get('dataset_size', 10000)

        st.caption(f"Max available samples: {max_samples_available:,}")

        num_samples = st.slider(
            "Number of samples",
            min_value=10,
            max_value=max_samples_available,
            value=min(100, max_samples_available),
            step=50 if max_samples_available > 1000 else 10,
            help="Adjust slider to select number of samples"
        )

        load_all_samples = st.checkbox(
            "Load all available samples",
            value=False,
            help="Override slider and load entire dataset"
        )

        st.divider()

        # Chunking strategy
        st.subheader("2. Chunking Strategy")
        chunking_strategy = st.selectbox(
            "Choose Chunking Strategy",
            settings.chunking_strategies,
            index=0
        )

        chunk_size = st.slider(
            "Chunk Size",
            min_value=256,
            max_value=1024,
            value=512,
            step=128
        )

        overlap = st.slider(
            "Overlap",
            min_value=0,
            max_value=200,
            value=50,
            step=10
        )

        st.divider()

        # Embedding model
        st.subheader("3. Embedding Model")
        embedding_model = st.selectbox(
            "Choose Embedding Model",
            settings.embedding_models,
            index=0
        )

        st.divider()

        # LLM model selection for new collection
        st.subheader("4. LLM Model")
        llm_model = st.selectbox(
            "Choose LLM",
            settings.llm_models,
            index=1
        )

        st.divider()

        # Load data button
        if st.button("🚀 Load Data & Create Collection", type="primary"):
            if not groq_api_key:
                st.error("Please enter your Groq API key")
            else:
                # Use None for num_samples if loading all data
                samples_to_load = None if load_all_samples else num_samples
                load_and_create_collection(
                    groq_api_key,
                    dataset_name,
                    samples_to_load,
                    chunking_strategy,
                    chunk_size,
                    overlap,
                    embedding_model,
                    llm_model
                )

    # Main content area
    if not st.session_state.collection_loaded:
        st.info("👈 Please configure and load a dataset from the sidebar to begin")

        # Show instructions
        with st.expander("📖 How to Use", expanded=True):
            st.markdown("""
            1. **Enter your Groq API Key** in the sidebar
            2. **Select a dataset** from RAG Bench
            3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
            4. **Select an embedding model** for document vectorization
            5. **Choose an LLM model** for response generation
            6. **Click "Load Data & Create Collection"** to initialize
            7. **Start chatting** in the chat interface
            8. **View retrieved documents** and evaluation metrics
            9. **Run TRACE evaluation** on test data
            """)

        # Show available options
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("📊 Available Datasets")
            for ds in settings.ragbench_datasets:
                st.markdown(f"- {ds}")

        with col2:
            st.subheader("🤖 Available Models")
            st.markdown("**Embedding Models:**")
            for em in settings.embedding_models:
                st.markdown(f"- {em}")

            st.markdown("**LLM Models:**")
            for lm in settings.llm_models:
                st.markdown(f"- {lm}")

    else:
        # Create tabs for different functionalities
        tab1, tab2, tab3 = st.tabs(["💬 Chat", "📊 Evaluation", "📜 History"])

        with tab1:
            chat_interface()

        with tab2:
            evaluation_interface()

        with tab3:
            history_interface()


def load_existing_collection(api_key: str, collection_name: str):
    """Load an existing collection from ChromaDB."""
    with st.spinner(f"Loading collection '{collection_name}'..."):
        try:
            # Initialize vector store and get collection
            vector_store = ChromaDBManager(settings.chroma_persist_directory)
            vector_store.get_collection(collection_name)

            # Prompt for LLM selection
            st.session_state.current_llm = st.selectbox(
                "Select LLM for this collection:",
                settings.llm_models,
                key=f"llm_selector_{collection_name}"
            )

            # Initialize LLM client
            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=st.session_state.current_llm,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )

            # Create RAG pipeline
            st.info("Creating RAG pipeline...")
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            st.session_state.groq_api_key = api_key

            st.success(f"✅ Collection '{collection_name}' loaded successfully!")
            st.rerun()

        except Exception as e:
            st.error(f"Error loading collection: {str(e)}")
            st.exception(e)


def load_and_create_collection(
    api_key: str,
    dataset_name: str,
    num_samples: Optional[int],
    chunking_strategy: str,
    chunk_size: int,
    overlap: int,
    embedding_model: str,
    llm_model: str
):
    """Load dataset and create vector collection."""
    with st.spinner("Loading dataset and creating collection..."):
        try:
            # Initialize dataset loader
            loader = RAGBenchLoader()

            # Load dataset (once; num_samples=None means all available samples)
            if num_samples is None:
                st.info(f"Loading {dataset_name} dataset (all available samples)...")
            else:
                st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
            dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)

            if not dataset:
                st.error("Failed to load dataset")
                return

            # Initialize vector store
            st.info("Initializing vector store...")
            vector_store = ChromaDBManager(settings.chroma_persist_directory)

            # Create collection name
            collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
            collection_name = collection_name.replace("-", "_").replace(".", "_")

            # Delete existing collection with same name (if exists)
            existing_collections = vector_store.list_collections()
            if collection_name in existing_collections:
                st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
                vector_store.delete_collection(collection_name)
                st.info("Old collection deleted. Creating new one...")

            # Load data into collection
            st.info(f"Creating collection with {chunking_strategy} chunking...")
            vector_store.load_dataset_into_collection(
                collection_name=collection_name,
                embedding_model_name=embedding_model,
                chunking_strategy=chunking_strategy,
                dataset_data=dataset,
                chunk_size=chunk_size,
                overlap=overlap
            )

            # Initialize LLM client
            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=llm_model,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )

            # Create RAG pipeline
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state (including the API key, so the LLM can be
            # switched later from the chat and evaluation tabs)
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.dataset_name = dataset_name
            st.session_state.dataset = dataset
            st.session_state.groq_api_key = api_key

            st.success(f"✅ Collection '{collection_name}' created successfully!")
            st.rerun()

        except Exception as e:
            st.error(f"Error: {str(e)}")


def chat_interface():
    """Chat interface tab."""
    st.subheader("💬 Chat Interface")

    # Check if collection is loaded
    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
        st.info("""
        Steps:
        1. Select a dataset from the dropdown
        2. Click "Load Data & Create Collection" button
        3. Wait for the collection to be created
        4. Then you can start chatting
        """)
        return

    # Display collection info and LLM selector
    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        st.info(f"📚 Collection: {st.session_state.current_collection}")

    with col2:
        # LLM selector for chat
        selected_llm = st.selectbox(
            "Select LLM for chat:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="chat_llm_selector"
        )

        if selected_llm != st.session_state.current_llm:
            st.session_state.current_llm = selected_llm
            # Recreate the LLM client with the new model; RAGPipeline stores the
            # client as `llm`, so that is the attribute to replace
            llm_client = GroqLLMClient(
                api_key=st.session_state.get("groq_api_key", ""),
                model_name=selected_llm,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )
            st.session_state.rag_pipeline.llm = llm_client

    with col3:
        if st.button("🗑️ Clear History"):
            st.session_state.chat_history = []
            st.session_state.rag_pipeline.clear_history()
            st.rerun()

    # Chat container
    chat_container = st.container()

    # Display chat history
    with chat_container:
        for chat_idx, entry in enumerate(st.session_state.chat_history):
            # User message
            with st.chat_message("user"):
                st.write(entry["query"])

            # Assistant message
            with st.chat_message("assistant"):
                st.write(entry["response"])

                # Show retrieved documents in expander
                with st.expander("📄 Retrieved Documents"):
                    for doc_idx, doc in enumerate(entry["retrieved_documents"]):
                        distance = doc.get("distance")
                        distance_text = f"{distance:.4f}" if isinstance(distance, (int, float)) else "N/A"
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {distance_text})")
                        st.text_area(
                            f"doc_{chat_idx}_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_area_{chat_idx}_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")

    # Chat input
    query = st.chat_input("Ask a question...")

    if query:
        # Check if collection exists
        if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
            st.error("❌ No data loaded. Please load a dataset first using the configuration panel.")
            st.stop()

        # Add user message
        with chat_container:
            with st.chat_message("user"):
                st.write(query)

        # Generate response
        with st.spinner("Generating response..."):
            try:
                result = st.session_state.rag_pipeline.query(query)
            except Exception as e:
                st.error(f"❌ Error querying: {str(e)}")
                st.info("Please load a dataset and create a collection first.")
                st.stop()

        # Add assistant message
        with chat_container:
            with st.chat_message("assistant"):
                st.write(result["response"])

                # Show retrieved documents
                with st.expander("📄 Retrieved Documents"):
                    for doc_idx, doc in enumerate(result["retrieved_documents"]):
                        distance = doc.get("distance")
                        distance_text = f"{distance:.4f}" if isinstance(distance, (int, float)) else "N/A"
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {distance_text})")
                        st.text_area(
                            f"doc_current_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_current_area_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")

        # Store in history
        st.session_state.chat_history.append(result)
        st.rerun()


def evaluation_interface():
    """Evaluation interface tab."""
    st.subheader("📊 TRACE Evaluation")

    # Check if collection is loaded
    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please load a collection first.")
        return

    # LLM selector for evaluation
    col1, col2 = st.columns([3, 1])
    with col1:
        selected_llm = st.selectbox(
            "Select LLM for evaluation:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="eval_llm_selector"
        )

    st.markdown("""
    Run TRACE evaluation metrics on test data:
    - **Utilization**: How well the system uses retrieved documents
    - **Relevance**: Relevance of retrieved documents to the query
    - **Adherence**: How well the response adheres to the retrieved context
    - **Completeness**: How complete the response is in answering the query
    """)

    num_test_samples = st.slider(
        "Number of test samples",
        min_value=5,
        max_value=50,
        value=10,
        step=5
    )

    if st.button("🔬 Run Evaluation", type="primary"):
        # Use selected LLM for evaluation
        run_evaluation(num_test_samples, selected_llm)

    # Display results
    if st.session_state.evaluation_results:
        results = st.session_state.evaluation_results

        st.success("✅ Evaluation Complete!")

        # Display aggregate scores
        col1, col2, col3, col4, col5 = st.columns(5)

        with col1:
            st.metric("📊 Utilization", f"{results['utilization']:.3f}")
        with col2:
            st.metric("🎯 Relevance", f"{results['relevance']:.3f}")
        with col3:
            st.metric("✅ Adherence", f"{results['adherence']:.3f}")
        with col4:
            st.metric("📝 Completeness", f"{results['completeness']:.3f}")
        with col5:
            st.metric("⭐ Average", f"{results['average']:.3f}")

        # Detailed results
        with st.expander("📋 Detailed Results"):
            df = pd.DataFrame(results["individual_scores"])
            st.dataframe(df, use_container_width=True)

        # Download results
        results_json = json.dumps(results, indent=2)
        st.download_button(
            label="💾 Download Results (JSON)",
            data=results_json,
            file_name=f"trace_evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )


def run_evaluation(num_samples: int, selected_llm: str = None):
    """Run TRACE evaluation."""
    with st.spinner(f"Running evaluation on {num_samples} samples..."):
        try:
            # Use selected LLM if provided; RAGPipeline stores the client as `llm`
            llm_switched = bool(selected_llm) and selected_llm != st.session_state.current_llm
            if llm_switched:
                st.info(f"Switching to {selected_llm} for evaluation...")
                groq_api_key = st.session_state.get("groq_api_key", "")
                eval_llm_client = GroqLLMClient(
                    api_key=groq_api_key,
                    model_name=selected_llm,
                    max_rpm=settings.groq_rpm_limit,
                    rate_limit_delay=settings.rate_limit_delay
                )
                # Temporarily replace the LLM client
                original_llm = st.session_state.rag_pipeline.llm
                st.session_state.rag_pipeline.llm = eval_llm_client

            # Get test data
            loader = RAGBenchLoader()
            test_data = loader.get_test_data(
                st.session_state.dataset_name,
                num_samples
            )

            # Prepare test cases
            test_cases = []

            progress_bar = st.progress(0)

            for i, sample in enumerate(test_data):
                # Query the RAG system
                result = st.session_state.rag_pipeline.query(
                    sample["question"],
                    n_results=5
                )

                # Prepare test case
                test_cases.append({
                    "query": sample["question"],
                    "response": result["response"],
                    "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                    "ground_truth": sample.get("answer", "")
                })

                # Update progress
                progress_bar.progress((i + 1) / len(test_data))

            # Run evaluation
            evaluator = TRACEEvaluator()
            results = evaluator.evaluate_batch(test_cases)

            st.session_state.evaluation_results = results

            # Restore original LLM if it was switched
            if llm_switched:
                st.session_state.rag_pipeline.llm = original_llm

        except Exception as e:
            st.error(f"Error during evaluation: {str(e)}")


def history_interface():
    """History interface tab."""
    st.subheader("📜 Chat History")

    if not st.session_state.chat_history:
        st.info("No chat history yet. Start a conversation in the Chat tab!")
        return

    # Export history
    col1, col2 = st.columns([3, 1])
    with col2:
        history_json = json.dumps(st.session_state.chat_history, indent=2)
        st.download_button(
            label="💾 Export History",
            data=history_json,
            file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

    # Display history
    for i, entry in enumerate(st.session_state.chat_history):
        with st.expander(f"💬 Conversation {i+1}: {entry['query'][:50]}..."):
            st.markdown(f"**Query:** {entry['query']}")
            st.markdown(f"**Response:** {entry['response']}")
            st.markdown(f"**Timestamp:** {entry.get('timestamp', 'N/A')}")

            st.markdown("**Retrieved Documents:**")
            for j, doc in enumerate(entry["retrieved_documents"]):
                st.text_area(
                    f"Document {j+1}",
                    value=doc["document"],
                    height=100,
                    key=f"history_doc_{i}_{j}"
                )


if __name__ == "__main__":
    main()
trace_evaluator.py
ADDED
@@ -0,0 +1,352 @@
"""TRACE evaluation metrics for RAG systems.

TRACE Metrics:
- uTilization: How well the system uses retrieved documents
- Relevance: Relevance of retrieved documents to the query
- Adherence: How well the response adheres to the retrieved context
- Completeness: How complete the response is in answering the query
"""
from typing import List, Dict, Optional
import numpy as np
from dataclasses import dataclass
import re
from collections import Counter


@dataclass
class TRACEScores:
    """Container for TRACE evaluation scores."""
    utilization: float
    relevance: float
    adherence: float
    completeness: float

    def to_dict(self) -> Dict:
        """Convert to dictionary."""
        return {
            "utilization": self.utilization,
            "relevance": self.relevance,
            "adherence": self.adherence,
            "completeness": self.completeness,
            "average": self.average()
        }

    def average(self) -> float:
        """Calculate average score."""
        return (self.utilization + self.relevance +
                self.adherence + self.completeness) / 4


class TRACEEvaluator:
    """TRACE evaluation metrics for RAG systems."""

    def __init__(self, llm_client=None):
        """Initialize TRACE evaluator.

        Args:
            llm_client: Optional LLM client for LLM-based evaluation
        """
        self.llm_client = llm_client

    def evaluate(
        self,
        query: str,
        response: str,
        retrieved_documents: List[str],
        ground_truth: Optional[str] = None
    ) -> TRACEScores:
        """Evaluate a RAG response using TRACE metrics.

        Args:
            query: User query
            response: Generated response
            retrieved_documents: List of retrieved documents
            ground_truth: Optional ground truth answer

        Returns:
            TRACEScores object
        """
        utilization = self._compute_utilization(response, retrieved_documents)
        relevance = self._compute_relevance(query, retrieved_documents)
        adherence = self._compute_adherence(response, retrieved_documents)
        completeness = self._compute_completeness(query, response, ground_truth)

        return TRACEScores(
            utilization=utilization,
            relevance=relevance,
            adherence=adherence,
            completeness=completeness
        )

    def _compute_utilization(
        self,
        response: str,
        retrieved_documents: List[str]
    ) -> float:
        """Compute utilization score.

        Measures how well the system uses retrieved documents.
        Score based on:
        - Number of documents that contributed to the response
        - Proportion of retrieved documents used

        Args:
            response: Generated response
            retrieved_documents: List of retrieved documents

        Returns:
            Utilization score (0-1)
        """
        if not retrieved_documents or not response:
            return 0.0

        response_lower = response.lower()
        response_words = set(self._tokenize(response_lower))

        # Count how many documents contributed
        docs_used = 0
        total_overlap = 0

        for doc in retrieved_documents:
            doc_lower = doc.lower()
            doc_words = set(self._tokenize(doc_lower))

            # Check for significant overlap
            overlap = len(response_words & doc_words)
            if overlap > 5:  # Threshold for significant contribution
                docs_used += 1
                total_overlap += overlap

        # Score based on proportion of documents used
        proportion_used = docs_used / len(retrieved_documents)

        # Also consider depth of utilization
        avg_overlap = total_overlap / len(retrieved_documents) if retrieved_documents else 0
        depth_score = min(avg_overlap / 20, 1.0)  # Normalize

        # Combined score
        utilization_score = 0.6 * proportion_used + 0.4 * depth_score

        return min(utilization_score, 1.0)

    def _compute_relevance(
        self,
        query: str,
        retrieved_documents: List[str]
    ) -> float:
        """Compute relevance score.

        Measures relevance of retrieved documents to the query.
        Uses lexical overlap and keyword matching.

        Args:
            query: User query
            retrieved_documents: List of retrieved documents

        Returns:
            Relevance score (0-1)
        """
        if not retrieved_documents or not query:
            return 0.0

        query_lower = query.lower()
        query_words = set(self._tokenize(query_lower))
        query_keywords = self._extract_keywords(query_lower)

        relevance_scores = []

        for doc in retrieved_documents:
            doc_lower = doc.lower()
            doc_words = set(self._tokenize(doc_lower))

            # Lexical overlap
            overlap = len(query_words & doc_words)
            overlap_score = overlap / len(query_words) if query_words else 0

            # Keyword matching
            keyword_matches = sum(1 for kw in query_keywords if kw in doc_lower)
            keyword_score = keyword_matches / len(query_keywords) if query_keywords else 0

            # Combined relevance for this document
            doc_relevance = 0.5 * overlap_score + 0.5 * keyword_score
            relevance_scores.append(doc_relevance)

        # Average relevance across documents
        return np.mean(relevance_scores)

    def _compute_adherence(
        self,
        response: str,
        retrieved_documents: List[str]
    ) -> float:
        """Compute adherence score.

        Measures how well the response adheres to the retrieved context.
        Higher score means response is grounded in the documents.

        Args:
            response: Generated response
            retrieved_documents: List of retrieved documents

        Returns:
            Adherence score (0-1)
        """
        if not retrieved_documents or not response:
            return 0.0

        # Combine all documents
        combined_docs = " ".join(retrieved_documents).lower()
        doc_words = set(self._tokenize(combined_docs))

        # Analyze response
        response_lower = response.lower()
        response_sentences = self._split_sentences(response_lower)

        adherence_scores = []

        for sentence in response_sentences:
            sentence_words = set(self._tokenize(sentence))

            # Check what proportion of sentence words appear in documents
            if sentence_words:
                grounded_words = len(sentence_words & doc_words)
                sentence_adherence = grounded_words / len(sentence_words)
                adherence_scores.append(sentence_adherence)

        # Average adherence across sentences
        return np.mean(adherence_scores) if adherence_scores else 0.0

    def _compute_completeness(
        self,
        query: str,
        response: str,
        ground_truth: Optional[str] = None
    ) -> float:
        """Compute completeness score.

        Measures how complete the response is in answering the query.

        Args:
            query: User query
            response: Generated response
            ground_truth: Optional ground truth answer

        Returns:
            Completeness score (0-1)
        """
        if not response or not query:
            return 0.0

        # Query analysis
        query_lower = query.lower()

        # Check for question types and expected components
        is_what = any(w in query_lower for w in ["what", "which"])
        is_when = "when" in query_lower
        is_where = "where" in query_lower
        is_who = "who" in query_lower
        is_why = "why" in query_lower
        is_how = "how" in query_lower

        response_lower = response.lower()

        # Basic completeness checks
        completeness_factors = []

        # Length check (not too short)
        min_length = 50
        length_score = min(len(response) / min_length, 1.0)
        completeness_factors.append(length_score)

        # Check for appropriate response type
        if is_when and any(w in response_lower for w in ["year", "date", "time", "century"]):
            completeness_factors.append(1.0)
        elif is_where and any(w in response_lower for w in ["location", "place", "country", "city"]):
            completeness_factors.append(1.0)
        elif is_who and any(w in response_lower for w in ["person", "people", "name"]):
            completeness_factors.append(1.0)

        # If ground truth available, compare
        if ground_truth:
            gt_lower = ground_truth.lower()
            gt_words = set(self._tokenize(gt_lower))
            response_words = set(self._tokenize(response_lower))

            # Check overlap with ground truth
            overlap = len(gt_words & response_words)
            gt_score = overlap / len(gt_words) if gt_words else 0
            completeness_factors.append(gt_score)

        # Average all factors
        return np.mean(completeness_factors) if completeness_factors else 0.5

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text into words."""
        # Remove punctuation and split
        text = re.sub(r'[^\w\s]', ' ', text)
        words = text.split()
        # Filter out very short words and common stop words
        stop_words = {"a", "an", "the", "is", "are", "was", "were", "in", "on", "at", "to", "for"}
        return [w for w in words if len(w) > 2 and w not in stop_words]

    def _extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from text."""
        words = self._tokenize(text)
        # Simple keyword extraction - words that appear in query
        # In production, use TF-IDF or similar
        word_freq = Counter(words)
        # Return words that appear at least once
        return list(word_freq.keys())

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences."""
        # Simple sentence splitting
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]

    def evaluate_batch(
        self,
        test_data: List[Dict]
    ) -> Dict:
        """Evaluate multiple test cases.

        Args:
            test_data: List of test cases, each containing:
                - query: User query
                - response: Generated response
                - retrieved_documents: Retrieved documents
                - ground_truth: Ground truth answer (optional)

        Returns:
            Dictionary with aggregated scores
        """
        all_scores = []

        for i, test_case in enumerate(test_data):
            print(f"Evaluating test case {i+1}/{len(test_data)}")

            scores = self.evaluate(
                query=test_case.get("query", ""),
                response=test_case.get("response", ""),
                retrieved_documents=test_case.get("retrieved_documents", []),
                ground_truth=test_case.get("ground_truth")
            )

            all_scores.append(scores)

        # Aggregate scores
        avg_utilization = np.mean([s.utilization for s in all_scores])
        avg_relevance = np.mean([s.relevance for s in all_scores])
        avg_adherence = np.mean([s.adherence for s in all_scores])
        avg_completeness = np.mean([s.completeness for s in all_scores])

        return {
            "utilization": float(avg_utilization),
            "relevance": float(avg_relevance),
            "adherence": float(avg_adherence),
            "completeness": float(avg_completeness),
            "average": float((avg_utilization + avg_relevance +
                              avg_adherence + avg_completeness) / 4),
            "num_samples": len(test_data),
            "individual_scores": [s.to_dict() for s in all_scores]
        }
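
A quick sanity check of the evaluator on a toy example (illustrative only, not part of the commit; all four scores are heuristic values in [0, 1]):

from trace_evaluator import TRACEEvaluator

evaluator = TRACEEvaluator()
scores = evaluator.evaluate(
    query="When was the Eiffel Tower completed?",
    response="The Eiffel Tower was completed in 1889 for the World's Fair in Paris.",
    retrieved_documents=[
        "The Eiffel Tower was completed in 1889 as the entrance arch to the World's Fair.",
        "Paris is the capital and most populous city of France.",
    ],
    ground_truth="It was completed in 1889.",
)
print(scores.to_dict())  # utilization, relevance, adherence, completeness, average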
vector_store.py
ADDED
@@ -0,0 +1,412 @@
"""ChromaDB integration for vector storage and retrieval."""
from typing import List, Dict, Optional, Tuple
import chromadb
from chromadb.config import Settings
import uuid
import os
from embedding_models import EmbeddingFactory, EmbeddingModel
from chunking_strategies import ChunkingFactory
import json


class ChromaDBManager:
    """Manager for ChromaDB operations."""

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize ChromaDB manager.

        Args:
            persist_directory: Directory to persist ChromaDB data
        """
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)

        # Initialize a persistent ChromaDB client, falling back to the
        # in-process client if persistent storage cannot be created
        try:
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True  # Allow reset if needed
                )
            )
        except Exception as e:
            print(f"Warning: Could not create persistent client: {e}")
            print("Falling back to regular client...")
            self.client = chromadb.Client(Settings(
                persist_directory=persist_directory,
                anonymized_telemetry=False,
                allow_reset=True
            ))

        self.embedding_model = None
        self.current_collection = None

    def reconnect(self):
        """Reconnect to ChromaDB in case of connection loss."""
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            print("✅ Reconnected to ChromaDB")
        except Exception as e:
            print(f"Error reconnecting: {e}")

    def create_collection(
        self,
        collection_name: str,
        embedding_model_name: str,
        metadata: Optional[Dict] = None
    ) -> chromadb.Collection:
        """Create a new collection.

        Args:
            collection_name: Name of the collection
            embedding_model_name: Name of the embedding model
            metadata: Additional metadata for the collection

        Returns:
            ChromaDB collection
        """
        # Delete if exists (ignore the error raised when it does not)
        try:
            self.client.delete_collection(collection_name)
        except Exception:
            pass

        # Create embedding model
        self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
        self.embedding_model.load_model()

        # Create collection with metadata
        collection_metadata = {
            "embedding_model": embedding_model_name,
            "hnsw:space": "cosine"
        }
        if metadata:
            collection_metadata.update(metadata)

        self.current_collection = self.client.create_collection(
            name=collection_name,
            metadata=collection_metadata
        )

        print(f"Created collection: {collection_name}")
        return self.current_collection

    def get_collection(self, collection_name: str) -> chromadb.Collection:
        """Get an existing collection.

        Args:
            collection_name: Name of the collection

        Returns:
            ChromaDB collection
        """
        self.current_collection = self.client.get_collection(collection_name)

        # Load embedding model from metadata
|
| 113 |
+
metadata = self.current_collection.metadata
|
| 114 |
+
if "embedding_model" in metadata:
|
| 115 |
+
self.embedding_model = EmbeddingFactory.create_embedding_model(
|
| 116 |
+
metadata["embedding_model"]
|
| 117 |
+
)
|
| 118 |
+
self.embedding_model.load_model()
|
| 119 |
+
|
| 120 |
+
return self.current_collection
|
| 121 |
+
|
| 122 |
+
def list_collections(self) -> List[str]:
|
| 123 |
+
"""List all collections.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
List of collection names
|
| 127 |
+
"""
|
| 128 |
+
collections = self.client.list_collections()
|
| 129 |
+
return [col.name for col in collections]
|
| 130 |
+
|
| 131 |
+
def clear_all_collections(self) -> int:
|
| 132 |
+
"""Delete all collections from the database.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
Number of collections deleted
|
| 136 |
+
"""
|
| 137 |
+
collections = self.list_collections()
|
| 138 |
+
count = 0
|
| 139 |
+
|
| 140 |
+
for collection_name in collections:
|
| 141 |
+
try:
|
| 142 |
+
self.client.delete_collection(collection_name)
|
| 143 |
+
print(f"Deleted collection: {collection_name}")
|
| 144 |
+
count += 1
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f"Error deleting collection {collection_name}: {e}")
|
| 147 |
+
|
| 148 |
+
self.current_collection = None
|
| 149 |
+
self.embedding_model = None
|
| 150 |
+
print(f"✅ Cleared {count} collections")
|
| 151 |
+
return count
|
| 152 |
+
|
| 153 |
+
def delete_collection(self, collection_name: str) -> bool:
|
| 154 |
+
"""Delete a specific collection.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
collection_name: Name of the collection to delete
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
True if deleted successfully, False otherwise
|
| 161 |
+
"""
|
| 162 |
+
try:
|
| 163 |
+
self.client.delete_collection(collection_name)
|
| 164 |
+
if self.current_collection and self.current_collection.name == collection_name:
|
| 165 |
+
self.current_collection = None
|
| 166 |
+
self.embedding_model = None
|
| 167 |
+
print(f"✅ Deleted collection: {collection_name}")
|
| 168 |
+
return True
|
| 169 |
+
except Exception as e:
|
| 170 |
+
print(f"❌ Error deleting collection: {e}")
|
| 171 |
+
return False
|
| 172 |
+
|
| 173 |
+
def add_documents(
|
| 174 |
+
self,
|
| 175 |
+
documents: List[str],
|
| 176 |
+
metadatas: Optional[List[Dict]] = None,
|
| 177 |
+
ids: Optional[List[str]] = None,
|
| 178 |
+
batch_size: int = 100
|
| 179 |
+
):
|
| 180 |
+
"""Add documents to the current collection.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
documents: List of document texts
|
| 184 |
+
metadatas: List of metadata dictionaries
|
| 185 |
+
ids: List of document IDs
|
| 186 |
+
batch_size: Batch size for processing
|
| 187 |
+
"""
|
| 188 |
+
if not self.current_collection:
|
| 189 |
+
raise ValueError("No collection selected. Create or get a collection first.")
|
| 190 |
+
|
| 191 |
+
if not self.embedding_model:
|
| 192 |
+
raise ValueError("No embedding model loaded.")
|
| 193 |
+
|
| 194 |
+
# Generate IDs if not provided
|
| 195 |
+
if ids is None:
|
| 196 |
+
ids = [str(uuid.uuid4()) for _ in documents]
|
| 197 |
+
|
| 198 |
+
# Generate default metadata if not provided
|
| 199 |
+
if metadatas is None:
|
| 200 |
+
metadatas = [{"index": i} for i in range(len(documents))]
|
| 201 |
+
|
| 202 |
+
# Process in batches
|
| 203 |
+
total_docs = len(documents)
|
| 204 |
+
print(f"Adding {total_docs} documents to collection...")
|
| 205 |
+
|
| 206 |
+
for i in range(0, total_docs, batch_size):
|
| 207 |
+
batch_docs = documents[i:i + batch_size]
|
| 208 |
+
batch_ids = ids[i:i + batch_size]
|
| 209 |
+
batch_metadatas = metadatas[i:i + batch_size]
|
| 210 |
+
|
| 211 |
+
# Generate embeddings
|
| 212 |
+
embeddings = self.embedding_model.embed_documents(batch_docs)
|
| 213 |
+
|
| 214 |
+
# Add to collection
|
| 215 |
+
self.current_collection.add(
|
| 216 |
+
documents=batch_docs,
|
| 217 |
+
embeddings=embeddings.tolist(),
|
| 218 |
+
metadatas=batch_metadatas,
|
| 219 |
+
ids=batch_ids
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")
|
| 223 |
+
|
| 224 |
+
print(f"Successfully added {total_docs} documents")
|
| 225 |
+
|
| 226 |
+
def load_dataset_into_collection(
|
| 227 |
+
self,
|
| 228 |
+
collection_name: str,
|
| 229 |
+
embedding_model_name: str,
|
| 230 |
+
chunking_strategy: str,
|
| 231 |
+
dataset_data: List[Dict],
|
| 232 |
+
chunk_size: int = 512,
|
| 233 |
+
overlap: int = 50
|
| 234 |
+
):
|
| 235 |
+
"""Load a dataset into a new collection with chunking.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
collection_name: Name for the new collection
|
| 239 |
+
embedding_model_name: Embedding model to use
|
| 240 |
+
chunking_strategy: Chunking strategy to use
|
| 241 |
+
dataset_data: List of dataset samples
|
| 242 |
+
chunk_size: Size of chunks
|
| 243 |
+
overlap: Overlap between chunks
|
| 244 |
+
"""
|
| 245 |
+
# Create collection
|
| 246 |
+
self.create_collection(
|
| 247 |
+
collection_name,
|
| 248 |
+
embedding_model_name,
|
| 249 |
+
metadata={
|
| 250 |
+
"chunking_strategy": chunking_strategy,
|
| 251 |
+
"chunk_size": chunk_size,
|
| 252 |
+
"overlap": overlap
|
| 253 |
+
}
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Get chunker
|
| 257 |
+
chunker = ChunkingFactory.create_chunker(chunking_strategy)
|
| 258 |
+
|
| 259 |
+
# Process documents
|
| 260 |
+
all_chunks = []
|
| 261 |
+
all_metadatas = []
|
| 262 |
+
|
| 263 |
+
print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
|
| 264 |
+
|
| 265 |
+
for idx, sample in enumerate(dataset_data):
|
| 266 |
+
# Use 'documents' list if available, otherwise fall back to 'context'
|
| 267 |
+
documents = sample.get("documents", [])
|
| 268 |
+
|
| 269 |
+
# If documents is empty, use context as fallback
|
| 270 |
+
if not documents:
|
| 271 |
+
context = sample.get("context", "")
|
| 272 |
+
if context:
|
| 273 |
+
documents = [context]
|
| 274 |
+
|
| 275 |
+
if not documents:
|
| 276 |
+
continue
|
| 277 |
+
|
| 278 |
+
# Process each document separately for better granularity
|
| 279 |
+
for doc_idx, document in enumerate(documents):
|
| 280 |
+
if not document or not str(document).strip():
|
| 281 |
+
continue
|
| 282 |
+
|
| 283 |
+
# Chunk each document
|
| 284 |
+
chunks = chunker.chunk_text(str(document), chunk_size, overlap)
|
| 285 |
+
|
| 286 |
+
# Create metadata for each chunk
|
| 287 |
+
for chunk_idx, chunk in enumerate(chunks):
|
| 288 |
+
all_chunks.append(chunk)
|
| 289 |
+
all_metadatas.append({
|
| 290 |
+
"doc_id": idx,
|
| 291 |
+
"doc_idx": doc_idx, # Track which document within the sample
|
| 292 |
+
"chunk_id": chunk_idx,
|
| 293 |
+
"question": sample.get("question", ""),
|
| 294 |
+
"answer": sample.get("answer", ""),
|
| 295 |
+
"dataset": sample.get("dataset", ""),
|
| 296 |
+
"total_docs": len(documents)
|
| 297 |
+
})
|
| 298 |
+
|
| 299 |
+
# Add all chunks to collection
|
| 300 |
+
self.add_documents(all_chunks, all_metadatas)
|
| 301 |
+
|
| 302 |
+
print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")
|
| 303 |
+
|
| 304 |
+
def query(
|
| 305 |
+
self,
|
| 306 |
+
query_text: str,
|
| 307 |
+
n_results: int = 5,
|
| 308 |
+
filter_metadata: Optional[Dict] = None
|
| 309 |
+
) -> Dict:
|
| 310 |
+
"""Query the collection.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
query_text: Query text
|
| 314 |
+
n_results: Number of results to return
|
| 315 |
+
filter_metadata: Metadata filter
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
Query results
|
| 319 |
+
"""
|
| 320 |
+
if not self.current_collection:
|
| 321 |
+
raise ValueError("No collection selected.")
|
| 322 |
+
|
| 323 |
+
if not self.embedding_model:
|
| 324 |
+
raise ValueError("No embedding model loaded.")
|
| 325 |
+
|
| 326 |
+
# Generate query embedding
|
| 327 |
+
query_embedding = self.embedding_model.embed_query(query_text)
|
| 328 |
+
|
| 329 |
+
# Query collection with retry logic
|
| 330 |
+
try:
|
| 331 |
+
results = self.current_collection.query(
|
| 332 |
+
query_embeddings=[query_embedding.tolist()],
|
| 333 |
+
n_results=n_results,
|
| 334 |
+
where=filter_metadata
|
| 335 |
+
)
|
| 336 |
+
except Exception as e:
|
| 337 |
+
if "default_tenant" in str(e):
|
| 338 |
+
print("Warning: Lost connection to ChromaDB, reconnecting...")
|
| 339 |
+
self.reconnect()
|
| 340 |
+
# Try again after reconnecting
|
| 341 |
+
results = self.current_collection.query(
|
| 342 |
+
query_embeddings=[query_embedding.tolist()],
|
| 343 |
+
n_results=n_results,
|
| 344 |
+
where=filter_metadata
|
| 345 |
+
)
|
| 346 |
+
else:
|
| 347 |
+
raise
|
| 348 |
+
|
| 349 |
+
return results
|
| 350 |
+
|
| 351 |
+
def get_retrieved_documents(
|
| 352 |
+
self,
|
| 353 |
+
query_text: str,
|
| 354 |
+
n_results: int = 5
|
| 355 |
+
) -> List[Dict]:
|
| 356 |
+
"""Get retrieved documents with metadata.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
query_text: Query text
|
| 360 |
+
n_results: Number of results
|
| 361 |
+
|
| 362 |
+
Returns:
|
| 363 |
+
List of retrieved documents with metadata
|
| 364 |
+
"""
|
| 365 |
+
results = self.query(query_text, n_results)
|
| 366 |
+
|
| 367 |
+
retrieved_docs = []
|
| 368 |
+
for i in range(len(results['documents'][0])):
|
| 369 |
+
retrieved_docs.append({
|
| 370 |
+
"document": results['documents'][0][i],
|
| 371 |
+
"metadata": results['metadatas'][0][i],
|
| 372 |
+
"distance": results['distances'][0][i] if 'distances' in results else None
|
| 373 |
+
})
|
| 374 |
+
|
| 375 |
+
return retrieved_docs
|
| 376 |
+
|
| 377 |
+
def delete_collection(self, collection_name: str):
|
| 378 |
+
"""Delete a collection.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
collection_name: Name of collection to delete
|
| 382 |
+
"""
|
| 383 |
+
try:
|
| 384 |
+
self.client.delete_collection(collection_name)
|
| 385 |
+
print(f"Deleted collection: {collection_name}")
|
| 386 |
+
except Exception as e:
|
| 387 |
+
print(f"Error deleting collection: {str(e)}")
|
| 388 |
+
|
| 389 |
+
def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
|
| 390 |
+
"""Get statistics for a collection.
|
| 391 |
+
|
| 392 |
+
Args:
|
| 393 |
+
collection_name: Name of collection (uses current if None)
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
Dictionary with collection statistics
|
| 397 |
+
"""
|
| 398 |
+
if collection_name:
|
| 399 |
+
collection = self.client.get_collection(collection_name)
|
| 400 |
+
elif self.current_collection:
|
| 401 |
+
collection = self.current_collection
|
| 402 |
+
else:
|
| 403 |
+
raise ValueError("No collection specified or selected")
|
| 404 |
+
|
| 405 |
+
count = collection.count()
|
| 406 |
+
metadata = collection.metadata
|
| 407 |
+
|
| 408 |
+
return {
|
| 409 |
+
"name": collection.name,
|
| 410 |
+
"count": count,
|
| 411 |
+
"metadata": metadata
|
| 412 |
+
}
|
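For reference, a minimal usage sketch of ChromaDBManager under stated assumptions: the embedding model name and chunking strategy below are placeholders, and must match whatever EmbeddingFactory and ChunkingFactory actually register in this repository.

# Minimal sketch: build a collection from one sample, then query it.
# "all-MiniLM-L6-v2" and "fixed" are placeholder names, not confirmed
# by this diff; substitute the registered model/strategy names.
from vector_store import ChromaDBManager

manager = ChromaDBManager(persist_directory="./chroma_db")
manager.load_dataset_into_collection(
    collection_name="demo",
    embedding_model_name="all-MiniLM-L6-v2",
    chunking_strategy="fixed",
    dataset_data=[{
        "question": "What is retrieval-augmented generation?",
        "answer": "Generation grounded in retrieved documents.",
        "documents": ["RAG retrieves relevant chunks and feeds them to an LLM."],
        "dataset": "demo",
    }],
    chunk_size=512,
    overlap=50,
)

for doc in manager.get_retrieved_documents("What is RAG?", n_results=3):
    print(doc["distance"], doc["document"][:80])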