Vivek Kadamati committed on
Commit
ee444c0
·
0 Parent(s):

Initial commit

Browse files
Files changed (24) hide show
  1. .env.example +15 -0
  2. .gitignore +26 -0
  3. Dockerfile +28 -0
  4. ENHANCEMENTS.md +120 -0
  5. GIT_PUSH_GUIDE.md +156 -0
  6. Procfile +1 -0
  7. README.md +294 -0
  8. SETUP.md +69 -0
  9. UPDATE_REMOTE.md +178 -0
  10. __init__.py +15 -0
  11. api.py +374 -0
  12. chunking_strategies.py +207 -0
  13. cleanup_chroma.py +93 -0
  14. config.py +64 -0
  15. dataset_loader.py +178 -0
  16. docker-compose.yml +26 -0
  17. embedding_models.py +325 -0
  18. example.py +118 -0
  19. llm_client.py +351 -0
  20. requirements.txt +40 -0
  21. run.py +99 -0
  22. streamlit_app.py +721 -0
  23. trace_evaluator.py +352 -0
  24. vector_store.py +412 -0
.env.example ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Groq API Configuration
2
+ GROQ_API_KEY=your_groq_api_key_here
3
+
4
+ # Google Gemini API Configuration (for gemini-embedding-001)
5
+ GEMINI_API_KEY=your_gemini_api_key_here
6
+
7
+ # ChromaDB Configuration
8
+ CHROMA_PERSIST_DIRECTORY=./chroma_db
9
+
10
+ # Rate Limiting
11
+ GROQ_RPM_LIMIT=30
12
+ RATE_LIMIT_DELAY=2.0
13
+
14
+ # Application Configuration
15
+ LOG_LEVEL=INFO
.gitignore ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .Python
6
+ *.so
7
+ *.egg
8
+ *.egg-info/
9
+ dist/
10
+ build/
11
+ .env
12
+ .venv
13
+ venv/
14
+ env/
15
+ data_cache/
16
+ *.log
17
+ .DS_Store
18
+ .vscode/
19
+ .idea/
20
+ *.swp
21
+ *.swo
22
+ *~
23
+ .pytest_cache/
24
+ .coverage
25
+ htmlcov/
26
+ chroma_db/
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ build-essential \
8
+ curl \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Copy requirements and install Python dependencies
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ # Copy application files
16
+ COPY . .
17
+
18
+ # Create directories for data
19
+ RUN mkdir -p chroma_db data_cache
20
+
21
+ # Expose ports
22
+ EXPOSE 8501 8000
23
+
24
+ # Set environment variables
25
+ ENV PYTHONUNBUFFERED=1
26
+
27
+ # Run Streamlit by default
28
+ CMD ["streamlit", "run", "streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
ENHANCEMENTS.md ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RAG Application Enhancements
2
+
3
+ ## Overview
4
+ The application has been enhanced with collection management, LLM selection, and improved user experience.
5
+
6
+ ## Key Enhancements
7
+
8
+ ### 1. **Existing Collections Management** 🗂️
9
+ - **Auto-detection**: On application startup, the system automatically detects all existing collections in ChromaDB
10
+ - **Load Existing Collection**: Users can now choose from existing collections and load them directly without recreating
11
+ - **Collection Selection**: Dropdown menu shows all available collections for quick access
12
+ - **Seamless Loading**: Click "📖 Load Existing Collection" to use a previously created collection
13
+
14
+ ### 2. **Smart Collection Recreation** 🔄
15
+ - **Selective Deletion**: When creating a new collection, only that specific collection is deleted and recreated
16
+ - **Other Collections Preserved**: All other existing collections remain untouched and unaffected
17
+ - **Conflict Resolution**: If a collection with the same name exists, it's deleted before creating the new one
18
+ - **User Feedback**: Clear warnings and progress messages when deleting and recreating collections
19
+
20
+ ### 3. **LLM Selection Options** 🤖
21
+
22
+ #### Chat Interface
23
+ - **Dynamic LLM Selector**: Switch between different LLM models while chatting
24
+ - **Real-time Switching**: Change LLM without reloading the collection
25
+ - **Automatic Pipeline Update**: The RAG pipeline automatically updates when a new LLM is selected
26
+ - **Persistent Selection**: The selected LLM is remembered in the session state
27
+
28
+ #### Evaluation Interface
29
+ - **Evaluation-specific LLM**: Choose a different LLM for running TRACE evaluation
30
+ - **Independent Selection**: Evaluation LLM can be different from the chat LLM
31
+ - **Automatic Restoration**: After evaluation, the system restores the original LLM
32
+ - **Flexible Testing**: Test different LLM models on the same dataset and collection
33
+
34
+ ### 4. **User Interface Improvements** 🎨
35
+ - **Two-step Process**:
36
+ 1. Load existing collection OR
37
+ 2. Create new collection (with all configuration options)
38
+ - **Clear Sections**: Separated sidebar sections for existing vs. new collections
39
+ - **Visual Indicators**: Icons and colors to distinguish different actions
40
+ - **Better Organization**: Configuration options logically grouped and hierarchical
41
+
42
+ ## Technical Implementation
43
+
44
+ ### New Functions
45
+ - `get_available_collections()`: Fetches list of collections from ChromaDB
46
+ - `load_existing_collection()`: Loads a pre-existing collection with LLM selection
47
+ - Updated `load_and_create_collection()`: Handles selective collection deletion
48
+
49
+ ### Session State Variables
50
+ - `current_llm`: Tracks the currently selected LLM
51
+ - `selected_collection`: Tracks which collection is loaded
52
+ - `available_collections`: Stores list of available collections
53
+
54
+ ### Collection Naming Convention
55
+ Collections are named as: `{dataset}_{chunking_strategy}_{embedding_model_short_name}`
56
+
57
+ Example: `covidqa_dense_all_mpnet`
58
+
59
+ ## User Workflow
60
+
61
+ ### Scenario 1: Using Existing Collection
62
+ 1. Application starts and detects existing collections
63
+ 2. User selects a collection from the dropdown
64
+ 3. User clicks "📖 Load Existing Collection"
65
+ 4. User selects an LLM for chatting
66
+ 5. User can start chatting immediately
67
+
68
+ ### Scenario 2: Creating New Collection
69
+ 1. User selects dataset from sidebar
70
+ 2. User clicks "🔍 Check Dataset Size" (optional)
71
+ 3. User configures chunking strategy and chunk parameters
72
+ 4. User selects embedding model
73
+ 5. User selects LLM
74
+ 6. User clicks "🚀 Load Data & Create Collection"
75
+ 7. System deletes any existing collection with same name
76
+ 8. System creates new collection with fresh data
77
+
78
+ ### Scenario 3: Switching LLMs During Chat
79
+ 1. Chat interface shows current collection and LLM selector
80
+ 2. User selects different LLM from "Select LLM for chat"
81
+ 3. RAG pipeline automatically updates with new LLM
82
+ 4. Continue chatting with new LLM
83
+
84
+ ### Scenario 4: Running Evaluation with Different LLM
85
+ 1. In Evaluation tab, user can select a different LLM
86
+ 2. Click "🔬 Run Evaluation"
87
+ 3. System uses selected LLM for evaluation
88
+ 4. Results are displayed with metrics
89
+ 5. Original chat LLM is restored after evaluation
90
+
91
+ ## Benefits
92
+
93
+ ✅ **Efficiency**: No need to recreate collections when testing different configurations
94
+ ✅ **Flexibility**: Easily compare different LLM models on the same data
95
+ ✅ **Safety**: Other collections remain untouched when managing new ones
96
+ ✅ **User Experience**: Clearer navigation and configuration options
97
+ ✅ **Time Saving**: Reuse existing collections instead of recreating them
98
+ ✅ **Testing**: Run evaluations with different LLMs for comprehensive analysis
99
+
100
+ ## API Key Management
101
+
102
+ - API Key input is required at startup
103
+ - Store in sidebar for use across all operations
104
+ - Used for both chat and evaluation with different LLMs
105
+
106
+ ## Error Handling
107
+
108
+ - Collection not found errors show helpful messages
109
+ - LLM loading failures fall back to default model
110
+ - Graceful error messages for all operations
111
+ - Automatic reconnection to ChromaDB if connection is lost
112
+
113
+ ## Future Enhancement Ideas
114
+
115
+ - 💾 Save evaluation results with metadata
116
+ - 📊 Compare multiple LLM evaluation results
117
+ - 🔄 Batch collection operations (delete multiple)
118
+ - 📈 Analytics dashboard for collection usage
119
+ - 🏷️ Collection tagging/categorization system
120
+ - 💬 Multi-turn evaluation with conversation history
GIT_PUSH_GUIDE.md ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Push Code to GitHub - Steps
2
+
3
+ ## Option 1: Push to a New GitHub Repository
4
+
5
+ If you want to push this project to a repository on GitHub, follow these steps:
6
+
7
+ ### 1. On GitHub, create a new repository:
8
+ - Go to https://github.com/new
9
+ - Fill in repository name: `RAG-Capstone-Project` (or your preferred name)
10
+ - Add description: "Retrieval-Augmented Generation (RAG) system with TRACE evaluation metrics"
11
+ - Choose: Public or Private
12
+ - **DO NOT** initialize with README, .gitignore, or license (we already have these)
13
+ - Click "Create repository"
14
+
15
+ ### 2. In PowerShell, add remote and push:
16
+
17
+ ```powershell
18
+ cd "D:\CapStoneProject\RAG Capstone Project"
19
+
20
+ # Add remote (replace YOUR_USERNAME and REPO_NAME)
21
+ git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
22
+
23
+ # Rename branch to main (optional but recommended)
24
+ git branch -M main
25
+
26
+ # Push to GitHub
27
+ git push -u origin main
28
+ ```
29
+
30
+ ### 3. When prompted, enter your GitHub credentials:
31
+ - Username: Your GitHub username
32
+ - Password: Your GitHub personal access token (not your password)
33
+ - Generate token: https://github.com/settings/tokens
34
+
35
+ ---
36
+
37
+ ## Option 2: If You Don't Have GitHub Account Yet
38
+
39
+ 1. Go to https://github.com/join
40
+ 2. Create a free account
41
+ 3. Follow Option 1 steps above
42
+
43
+ ---
44
+
45
+ ## Troubleshooting
46
+
47
+ ### If you get "fatal: remote origin already exists":
48
+ ```powershell
49
+ git remote remove origin
50
+ git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
51
+ git push -u origin main
52
+ ```
53
+
54
+ ### If you get authentication error:
55
+ 1. Generate Personal Access Token:
56
+ - Go to https://github.com/settings/tokens
57
+ - Click "Generate new token (classic)"
58
+ - Select: repo, write:packages
59
+ - Copy the token
60
+ 2. Use token instead of password when prompted
61
+
62
+ ### If you have SSH key setup:
63
+ ```powershell
64
+ git remote add origin git@github.com:YOUR_USERNAME/RAG-Capstone-Project.git
65
+ git push -u origin main
66
+ ```
67
+
68
+ ---
69
+
70
+ ## What Will Be Pushed
71
+
72
+ ✅ All Python source files:
73
+ - `streamlit_app.py` - Main Streamlit UI
74
+ - `llm_client.py` - Groq LLM integration
75
+ - `vector_store.py` - ChromaDB management
76
+ - `embedding_models.py` - 8 embedding models
77
+ - `trace_evaluator.py` - TRACE metrics
78
+ - `dataset_loader.py` - RAGBench dataset loading
79
+ - `chunking_strategies.py` - 4 chunking strategies
80
+ - And more...
81
+
82
+ ✅ Documentation:
83
+ - `README.md` - Project overview
84
+ - `SETUP.md` - Installation guide
85
+ - `ENHANCEMENTS.md` - Recent enhancements
86
+ - `.env.example` - Environment template
87
+
88
+ ✅ Configuration:
89
+ - `requirements.txt` - All dependencies
90
+ - `Dockerfile` - Docker containerization
91
+ - `docker-compose.yml` - Multi-container setup
92
+ - `Procfile` - Heroku deployment
93
+
94
+ ❌ Excluded (in .gitignore):
95
+ - `venv/` - Virtual environment
96
+ - `chroma_db/` - ChromaDB data
97
+ - `.env` - API keys (keep local!)
98
+ - `__pycache__/` - Python cache
99
+ - `.streamlit/` - Streamlit config
100
+
101
+ ---
102
+
103
+ ## Next Steps After Pushing
104
+
105
+ 1. **Add a GitHub Actions workflow** (optional):
106
+ - Automated testing
107
+ - Code quality checks
108
+ - Deployment automation
109
+
110
+ 2. **Set up branch protection**:
111
+ - Require pull request reviews
112
+ - Enforce status checks
113
+
114
+ 3. **Add GitHub Pages documentation**:
115
+ - Host project documentation
116
+ - API documentation
117
+ - Evaluation results
118
+
119
+ 4. **Setup CI/CD**:
120
+ - Test on every push
121
+ - Deploy to Heroku/Cloud Run
122
+
123
+ ---
124
+
125
+ ## Commands Summary
126
+
127
+ ```powershell
128
+ # Navigate to project
129
+ cd "D:\CapStoneProject\RAG Capstone Project"
130
+
131
+ # Configure git
132
+ git config user.email "your_email@example.com"
133
+ git config user.name "Your Name"
134
+
135
+ # Add remote (replace placeholders)
136
+ git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
137
+
138
+ # Rename branch to main
139
+ git branch -M main
140
+
141
+ # Push to GitHub
142
+ git push -u origin main
143
+
144
+ # Verify
145
+ git remote -v
146
+ git log --oneline
147
+ ```
148
+
149
+ ---
150
+
151
+ **Share the GitHub URL with your team:**
152
+ ```
153
+ https://github.com/YOUR_USERNAME/RAG-Capstone-Project
154
+ ```
155
+
156
+ Let me know when you have your GitHub username ready, and I can help you complete the push!
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: streamlit run streamlit_app.py --server.port=$PORT --server.address=0.0.0.0
README.md ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RAG Capstone Project
2
+
3
+ A comprehensive Retrieval-Augmented Generation (RAG) system with TRACE evaluation metrics for medical/clinical domains.
4
+
5
+ ## Features
6
+
7
+ - 🔍 **Multiple RAG Bench Datasets**: HotpotQA, 2WikiMultihopQA, MuSiQue, Natural Questions, TriviaQA
8
+ - 🧩 **Chunking Strategies**: Dense, Sparse, Hybrid, Re-ranking
9
+ - 🤖 **Medical Embedding Models**:
10
+ - sentence-transformers/embeddinggemma-300m-medical
11
+ - emilyalsentzer/Bio_ClinicalBERT
12
+ - Simonlee711/Clinical_ModernBERT
13
+ - 💾 **ChromaDB Vector Storage**: Persistent vector storage with efficient retrieval
14
+ - 🦙 **Groq LLM Integration**: With rate limiting (30 RPM)
15
+ - meta-llama/llama-4-maverick-17b-128e-instruct
16
+ - llama-3.1-8b-instant
17
+ - openai/gpt-oss-120b
18
+ - 📊 **TRACE Evaluation Metrics**:
19
+ - **U**tilization: How well the system uses retrieved documents
20
+ - **R**elevance: Relevance of retrieved documents to the query
21
+ - **A**dherence: How well the response adheres to the retrieved context
22
+ - **C**ompleteness: How complete the response is
23
+ - 💬 **Chat Interface**: Streamlit-based interactive chat with history
24
+ - 🔌 **REST API**: FastAPI backend for integration
25
+
26
+ ## Installation
27
+
28
+ ### Prerequisites
29
+
30
+ - Python 3.8+
31
+ - pip
32
+ - Groq API key
33
+
34
+ ### Setup
35
+
36
+ 1. Clone the repository:
37
+ ```bash
38
+ git clone <repository-url>
39
+ cd "RAG Capstone Project"
40
+ ```
41
+
42
+ 2. Create a virtual environment:
43
+ ```bash
44
+ python -m venv venv
45
+ ```
46
+
47
+ 3. Activate the virtual environment:
48
+
49
+ **Windows:**
50
+ ```bash
51
+ .\venv\Scripts\activate
52
+ ```
53
+
54
+ **Linux/Mac:**
55
+ ```bash
56
+ source venv/bin/activate
57
+ ```
58
+
59
+ 4. Install dependencies:
60
+ ```bash
61
+ pip install -r requirements.txt
62
+ ```
63
+
64
+ 5. Create a `.env` file from the example:
65
+ ```bash
66
+ copy .env.example .env
67
+ ```
68
+
69
+ 6. Edit `.env` and add your Groq API key:
70
+ ```
71
+ GROQ_API_KEY=your_groq_api_key_here
72
+ ```
73
+
74
+ ## Usage
75
+
76
+ ### Streamlit Application
77
+
78
+ Run the interactive Streamlit interface:
79
+
80
+ ```bash
81
+ streamlit run streamlit_app.py
82
+ ```
83
+
84
+ Then open your browser to `http://localhost:8501`
85
+
86
+ **Workflow:**
87
+ 1. Enter your Groq API key in the sidebar
88
+ 2. Select a dataset from RAG Bench
89
+ 3. Choose chunking strategy
90
+ 4. Select embedding model
91
+ 5. Choose LLM model
92
+ 6. Click "Load Data & Create Collection"
93
+ 7. Start chatting!
94
+ 8. View retrieved documents
95
+ 9. Run TRACE evaluation
96
+ 10. Export chat history
97
+
98
+ ### FastAPI Backend
99
+
100
+ Run the REST API server:
101
+
102
+ ```bash
103
+ python api.py
104
+ ```
105
+
106
+ Or with uvicorn:
107
+ ```bash
108
+ uvicorn api:app --reload --host 0.0.0.0 --port 8000
109
+ ```
110
+
111
+ API documentation available at: `http://localhost:8000/docs`
112
+
113
+ #### API Endpoints
114
+
115
+ - `GET /` - Root endpoint
116
+ - `GET /health` - Health check
117
+ - `GET /datasets` - List available datasets
118
+ - `GET /models/embedding` - List embedding models
119
+ - `GET /models/llm` - List LLM models
120
+ - `GET /chunking-strategies` - List chunking strategies
121
+ - `GET /collections` - List all collections
122
+ - `GET /collections/{name}` - Get collection info
123
+ - `POST /load-dataset` - Load dataset and create collection
124
+ - `POST /query` - Query the RAG system
125
+ - `GET /chat-history` - Get chat history
126
+ - `DELETE /chat-history` - Clear chat history
127
+ - `POST /evaluate` - Run TRACE evaluation
128
+ - `DELETE /collections/{name}` - Delete collection
129
+
130
+ ### Python API
131
+
132
+ Use the components programmatically:
133
+
134
+ ```python
135
+ from config import settings
136
+ from dataset_loader import RAGBenchLoader
137
+ from vector_store import ChromaDBManager
138
+ from llm_client import GroqLLMClient, RAGPipeline
139
+ from trace_evaluator import TRACEEvaluator
140
+
141
+ # Load dataset
142
+ loader = RAGBenchLoader()
143
+ dataset = loader.load_dataset("hotpotqa", max_samples=100)
144
+
145
+ # Create vector store
146
+ vector_store = ChromaDBManager()
147
+ vector_store.load_dataset_into_collection(
148
+ collection_name="my_collection",
149
+ embedding_model_name="emilyalsentzer/Bio_ClinicalBERT",
150
+ chunking_strategy="hybrid",
151
+ dataset_data=dataset
152
+ )
153
+
154
+ # Initialize LLM
155
+ llm = GroqLLMClient(
156
+ api_key="your_api_key",
157
+ model_name="llama-3.1-8b-instant"
158
+ )
159
+
160
+ # Create RAG pipeline
161
+ rag = RAGPipeline(llm, vector_store)
162
+
163
+ # Query
164
+ result = rag.query("What is the capital of France?")
165
+ print(result["response"])
166
+
167
+ # Evaluate
168
+ evaluator = TRACEEvaluator()
169
+ test_cases = [...] # Your test cases
170
+ results = evaluator.evaluate_batch(test_cases)
171
+ print(results)
172
+ ```
173
+
174
+ ## Project Structure
175
+
176
+ ```
177
+ RAG Capstone Project/
178
+ ├── __init__.py # Package initialization
179
+ ├── config.py # Configuration management
180
+ ├── dataset_loader.py # RAG Bench dataset loader
181
+ ├── chunking_strategies.py # Document chunking strategies
182
+ ├── embedding_models.py # Embedding model implementations
183
+ ├── vector_store.py # ChromaDB integration
184
+ ├── llm_client.py # Groq LLM client with rate limiting
185
+ ├── trace_evaluator.py # TRACE evaluation metrics
186
+ ├── streamlit_app.py # Streamlit chat interface
187
+ ├── api.py # FastAPI REST API
188
+ ├── requirements.txt # Python dependencies
189
+ ├── .env.example # Environment variables template
190
+ ├── .gitignore # Git ignore file
191
+ └── README.md # This file
192
+ ```
193
+
194
+ ## TRACE Metrics Explained
195
+
196
+ ### Utilization (U)
197
+ Measures how well the system uses the retrieved documents in generating the response. Higher scores indicate that the system effectively incorporates information from multiple retrieved documents.
198
+
199
+ ### Relevance (R)
200
+ Evaluates the relevance of retrieved documents to the user's query. Uses lexical overlap and keyword matching to determine if the right documents were retrieved.
201
+
202
+ ### Adherence (A)
203
+ Assesses how well the generated response adheres to the retrieved context. Ensures the response is grounded in the provided documents rather than hallucinated.
204
+
205
+ ### Completeness (C)
206
+ Evaluates how complete the response is in answering the query. Considers response length, question type, and comparison with ground truth if available.
207
+
208
+ ## Deployment Options
209
+
210
+ ### Heroku
211
+
212
+ 1. Create `Procfile`:
213
+ ```
214
+ web: streamlit run streamlit_app.py --server.port=$PORT --server.address=0.0.0.0
215
+ api: uvicorn api:app --host=0.0.0.0 --port=$PORT
216
+ ```
217
+
218
+ 2. Deploy:
219
+ ```bash
220
+ heroku create your-app-name
221
+ git push heroku main
222
+ ```
223
+
224
+ ### Docker
225
+
226
+ Create `Dockerfile`:
227
+ ```dockerfile
228
+ FROM python:3.9-slim
229
+
230
+ WORKDIR /app
231
+ COPY requirements.txt .
232
+ RUN pip install -r requirements.txt
233
+
234
+ COPY . .
235
+
236
+ EXPOSE 8501 8000
237
+
238
+ CMD ["streamlit", "run", "streamlit_app.py"]
239
+ ```
240
+
241
+ Build and run:
242
+ ```bash
243
+ docker build -t rag-capstone .
244
+ docker run -p 8501:8501 -p 8000:8000 rag-capstone
245
+ ```
246
+
247
+ ### Cloud Run / AWS / Azure
248
+
249
+ The application can be deployed to any cloud platform that supports Python applications. See the respective platform documentation for deployment instructions.
250
+
251
+ ## Configuration
252
+
253
+ Edit `config.py` or set environment variables in `.env`:
254
+
255
+ ```env
256
+ GROQ_API_KEY=your_api_key
257
+ CHROMA_PERSIST_DIRECTORY=./chroma_db
258
+ GROQ_RPM_LIMIT=30
259
+ RATE_LIMIT_DELAY=2.0
260
+ LOG_LEVEL=INFO
261
+ ```
262
+
263
+ ## Rate Limiting
264
+
265
+ The application implements rate limiting for Groq API calls:
266
+ - Maximum 30 requests per minute (configurable)
267
+ - Automatic delay of 2 seconds between requests
268
+ - Smart waiting when rate limit is reached
269
+
270
+ ## Troubleshooting
271
+
272
+ ### ChromaDB Issues
273
+ If you encounter ChromaDB errors, try deleting the `chroma_db` directory and recreating collections.
274
+
275
+ ### Embedding Model Loading
276
+ Medical embedding models may require significant memory. If you encounter out-of-memory errors, try:
277
+ - Using a smaller model
278
+ - Reducing batch size
279
+ - Using CPU instead of GPU
280
+
281
+ ### API Key Errors
282
+ Ensure your Groq API key is correctly set in the `.env` file or passed to the application.
283
+
284
+ ## License
285
+
286
+ MIT License
287
+
288
+ ## Contributors
289
+
290
+ RAG Capstone Team
291
+
292
+ ## Support
293
+
294
+ For issues and questions, please open an issue on the GitHub repository.
SETUP.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quick Setup Guide (Windows)
2
+
3
+ ## Requirements
4
+ - Python 3.10+
5
+ - Groq API Key
6
+
7
+ ## Installation Steps
8
+
9
+ ### 1. Create Virtual Environment
10
+ ```powershell
11
+ python -m venv venv
12
+ ```
13
+
14
+ ### 2. Activate Virtual Environment
15
+ ```powershell
16
+ .\venv\Scripts\Activate.ps1
17
+ ```
18
+
19
+ **If you get execution policy error:**
20
+ ```powershell
21
+ Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
22
+ ```
23
+
24
+ ### 3. Upgrade pip
25
+ ```powershell
26
+ python -m pip install --upgrade pip
27
+ ```
28
+
29
+ ### 4. Install Dependencies
30
+ ```powershell
31
+ pip install -r requirements.txt
32
+ ```
33
+
34
+ ### 5. Configure API Key
35
+ ```powershell
36
+ copy .env.example .env
37
+ notepad .env
38
+ ```
39
+
40
+ Add your Groq API key:
41
+ ```
42
+ GROQ_API_KEY=your_groq_api_key_here
43
+ ```
44
+
45
+ ### 6. Run Application
46
+ ```powershell
47
+ streamlit run streamlit_app.py
48
+ ```
49
+
50
+ Open browser to: **http://localhost:8501**
51
+
52
+ ---
53
+
54
+ ## Common Issues
55
+
56
+ **Execution Policy Error:**
57
+ ```powershell
58
+ Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
59
+ ```
60
+
61
+ **Reset ChromaDB:**
62
+ ```powershell
63
+ Remove-Item -Recurse -Force .\chroma_db
64
+ ```
65
+
66
+ **Deactivate venv:**
67
+ ```powershell
68
+ deactivate
69
+ ```
UPDATE_REMOTE.md ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to Update Remote Origin
2
+
3
+ ## Method 1: Change Existing Remote URL (Recommended)
4
+
5
+ If you already have a remote set and want to change it:
6
+
7
+ ```powershell
8
+ # View current remote
9
+ git remote -v
10
+
11
+ # Option A: Update with HTTPS URL
12
+ git remote set-url origin https://github.com/YOUR_USERNAME/YOUR_REPO_NAME.git
13
+
14
+ # Option B: Update with SSH URL
15
+ git remote set-url origin git@github.com:YOUR_USERNAME/YOUR_REPO_NAME.git
16
+
17
+ # Verify the change
18
+ git remote -v
19
+ ```
20
+
21
+ ---
22
+
23
+ ## Method 2: Remove and Re-add Remote
24
+
25
+ If you want to completely remove and add a new remote:
26
+
27
+ ```powershell
28
+ # Remove existing remote
29
+ git remote remove origin
30
+
31
+ # Add new remote
32
+ git remote add origin https://github.com/YOUR_USERNAME/YOUR_REPO_NAME.git
33
+
34
+ # Verify
35
+ git remote -v
36
+ ```
37
+
38
+ ---
39
+
40
+ ## Method 3: Using Your GitHub Repository
41
+
42
+ ### Step 1: Create Repository on GitHub
43
+ 1. Go to https://github.com/new
44
+ 2. Repository name: `RAG-Capstone-Project`
45
+ 3. Description: `Retrieval-Augmented Generation system with TRACE evaluation`
46
+ 4. Select Public or Private
47
+ 5. **IMPORTANT**: Don't initialize with README, .gitignore, or license
48
+ 6. Click "Create repository"
49
+
50
+ ### Step 2: Copy Your Repository URL
51
+ After creating, GitHub will show you the URL. Copy it (should look like):
52
+ ```
53
+ https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
54
+ ```
55
+
56
+ ### Step 3: Update Remote in Your Local Repository
57
+ ```powershell
58
+ cd "D:\CapStoneProject\RAG Capstone Project"
59
+
60
+ git remote set-url origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
61
+
62
+ # Verify
63
+ git remote -v
64
+ ```
65
+
66
+ ---
67
+
68
+ ## Complete Step-by-Step Example
69
+
70
+ Let's say your GitHub username is `john-doe`:
71
+
72
+ ```powershell
73
+ # 1. Navigate to project
74
+ cd "D:\CapStoneProject\RAG Capstone Project"
75
+
76
+ # 2. Update the remote URL
77
+ git remote set-url origin https://github.com/john-doe/RAG-Capstone-Project.git
78
+
79
+ # 3. Verify it was updated
80
+ git remote -v
81
+
82
+ # 4. Push to GitHub (first time)
83
+ git push -u origin main
84
+
85
+ # 5. For future pushes (just use)
86
+ git push
87
+ ```
88
+
89
+ ---
90
+
91
+ ## Using SSH Instead of HTTPS
92
+
93
+ ### Step 1: Generate SSH Key (if you don't have one)
94
+ ```powershell
95
+ ssh-keygen -t ed25519 -C "your_email@example.com"
96
+ ```
97
+
98
+ ### Step 2: Add SSH Key to GitHub
99
+ 1. Copy the public key: `cat ~/.ssh/id_ed25519.pub`
100
+ 2. Go to https://github.com/settings/keys
101
+ 3. Click "New SSH key"
102
+ 4. Paste your public key
103
+ 5. Save
104
+
105
+ ### Step 3: Update Remote to SSH
106
+ ```powershell
107
+ git remote set-url origin git@github.com:YOUR_USERNAME/RAG-Capstone-Project.git
108
+
109
+ # Verify
110
+ git remote -v
111
+
112
+ # Test connection
113
+ ssh -T git@github.com
114
+ ```
115
+
116
+ ---
117
+
118
+ ## Troubleshooting
119
+
120
+ ### Problem: "fatal: remote origin already exists"
121
+ ```powershell
122
+ # Remove the old remote first
123
+ git remote remove origin
124
+
125
+ # Then add the new one
126
+ git remote add origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
127
+ ```
128
+
129
+ ### Problem: "Permission denied (publickey)"
130
+ This means SSH authentication failed. Use HTTPS instead:
131
+ ```powershell
132
+ git remote set-url origin https://github.com/YOUR_USERNAME/RAG-Capstone-Project.git
133
+ ```
134
+
135
+ ### Problem: "fatal: Authentication failed"
136
+ This means your GitHub credentials are incorrect. Use a Personal Access Token:
137
+ 1. Generate token: https://github.com/settings/tokens
138
+ 2. When pushing, use the token as password
139
+
140
+ ### Check Current Configuration
141
+ ```powershell
142
+ # View remote URLs
143
+ git remote -v
144
+
145
+ # View detailed remote info
146
+ git remote show origin
147
+
148
+ # View git config
149
+ git config --local -l
150
+ ```
151
+
152
+ ---
153
+
154
+ ## Quick Reference
155
+
156
+ | Task | Command |
157
+ |------|---------|
158
+ | View remotes | `git remote -v` |
159
+ | Update URL (HTTPS) | `git remote set-url origin https://github.com/USER/REPO.git` |
160
+ | Update URL (SSH) | `git remote set-url origin git@github.com:USER/REPO.git` |
161
+ | Remove remote | `git remote remove origin` |
162
+ | Add remote | `git remote add origin <URL>` |
163
+ | Push to remote | `git push -u origin main` |
164
+ | Check remote details | `git remote show origin` |
165
+
166
+ ---
167
+
168
+ ## What You Need to Push
169
+
170
+ Before pushing, make sure you have:
171
+ - ✅ GitHub account created
172
+ - ✅ Repository created on GitHub
173
+ - ✅ Remote URL updated correctly
174
+ - ✅ Local commits ready (already done ✓)
175
+
176
+ ---
177
+
178
+ **What's your GitHub username?** I can help you with the exact commands once you provide it!
__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG Capstone Project - Retrieval-Augmented Generation with TRACE Evaluation
3
+
4
+ This application provides a complete RAG system with:
5
+ - Multiple embedding models (Bio-medical BERT models)
6
+ - Various chunking strategies (dense, sparse, hybrid, re-ranking)
7
+ - ChromaDB vector storage
8
+ - Groq LLM integration with rate limiting
9
+ - TRACE evaluation metrics
10
+ - Streamlit chat interface
11
+ - FastAPI REST API
12
+ """
13
+
14
+ __version__ = "1.0.0"
15
+ __author__ = "RAG Capstone Team"
api.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI backend service for RAG application."""
2
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel, Field
5
+ from typing import List, Optional, Dict
6
+ import uvicorn
7
+ from datetime import datetime
8
+ import os
9
+
10
+ from config import settings
11
+ from dataset_loader import RAGBenchLoader
12
+ from vector_store import ChromaDBManager
13
+ from llm_client import GroqLLMClient, RAGPipeline
14
+ from trace_evaluator import TRACEEvaluator
15
+
16
+ # Initialize FastAPI app
17
+ app = FastAPI(
18
+ title="RAG Capstone API",
19
+ description="API for RAG system with TRACE evaluation",
20
+ version="1.0.0"
21
+ )
22
+
23
+ # Add CORS middleware
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"],
27
+ allow_credentials=True,
28
+ allow_methods=["*"],
29
+ allow_headers=["*"],
30
+ )
31
+
32
+ # Global state
33
+ rag_pipeline: Optional[RAGPipeline] = None
34
+ vector_store: Optional[ChromaDBManager] = None
35
+ current_collection: Optional[str] = None
36
+
37
+
38
+ # Request/Response models
39
class DatasetLoadRequest(BaseModel):
    """Request model for loading dataset.

    Bundles dataset selection, chunking configuration and model choices for
    the ``POST /load-dataset`` endpoint.
    """
    dataset_name: str = Field(..., description="Name of the dataset")
    num_samples: int = Field(50, description="Number of samples to load")
    chunking_strategy: str = Field("hybrid", description="Chunking strategy")
    chunk_size: int = Field(512, description="Size of chunks")
    overlap: int = Field(50, description="Overlap between chunks")
    embedding_model: str = Field(..., description="Embedding model name")
    llm_model: str = Field("llama-3.1-8b-instant", description="LLM model name")
    # NOTE(review): the API key travels in the request body; consider moving
    # it to a header or server-side environment variable.
    groq_api_key: str = Field(..., description="Groq API key")
49
+
50
+
51
class QueryRequest(BaseModel):
    """Request model for querying the RAG pipeline via ``POST /query``."""
    query: str = Field(..., description="User query")
    n_results: int = Field(5, description="Number of documents to retrieve")
    max_tokens: int = Field(1024, description="Maximum tokens to generate")
    temperature: float = Field(0.7, description="Sampling temperature")
57
+
58
+
59
class QueryResponse(BaseModel):
    """Response model for query: generated answer plus retrieval evidence."""
    query: str
    response: str
    retrieved_documents: List[Dict]
    timestamp: str
65
+
66
+
67
class EvaluationRequest(BaseModel):
    """Request model for TRACE evaluation (``POST /evaluate``)."""
    num_samples: int = Field(10, description="Number of test samples")
70
+
71
+
72
class CollectionInfo(BaseModel):
    """Collection information model (name, document count, metadata)."""
    name: str
    count: int
    metadata: Dict
77
+
78
+
79
+ # API endpoints
80
@app.get("/")
async def root():
    """Return basic service metadata and a pointer to the interactive docs."""
    payload = {
        "message": "RAG Capstone API",
        "version": "1.0.0",
        "docs": "/docs",
    }
    return payload
88
+
89
+
90
@app.get("/health")
async def health_check():
    """Liveness probe: report service status with the current timestamp."""
    now = datetime.now().isoformat()
    return {"status": "healthy", "timestamp": now}
97
+
98
+
99
@app.get("/datasets")
async def list_datasets():
    """Expose the RAGBench dataset names configured in settings."""
    return {"datasets": settings.ragbench_datasets}
105
+
106
+
107
@app.get("/models/embedding")
async def list_embedding_models():
    """Expose the embedding model identifiers configured in settings."""
    return {"embedding_models": settings.embedding_models}
113
+
114
+
115
@app.get("/models/llm")
async def list_llm_models():
    """Expose the LLM model identifiers configured in settings."""
    return {"llm_models": settings.llm_models}
121
+
122
+
123
@app.get("/chunking-strategies")
async def list_chunking_strategies():
    """Expose the chunking strategy names configured in settings."""
    return {"chunking_strategies": settings.chunking_strategies}
129
+
130
+
131
@app.get("/collections")
async def list_collections():
    """List every ChromaDB collection, lazily creating the store handle."""
    global vector_store

    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    names = vector_store.list_collections()
    return {"collections": names, "count": len(names)}
145
+
146
+
147
@app.get("/collections/{collection_name}")
async def get_collection_info(collection_name: str):
    """Return stats for one collection, or 404 when it cannot be found."""
    global vector_store

    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    try:
        return vector_store.get_collection_stats(collection_name)
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"Collection not found: {str(e)}")
160
+
161
+
162
@app.post("/load-dataset")
async def load_dataset(request: DatasetLoadRequest, background_tasks: BackgroundTasks):
    """Load dataset and create vector collection.

    Loads up to ``request.num_samples`` rows of the chosen RAGBench dataset,
    chunks and embeds them into a new ChromaDB collection, then wires up the
    Groq LLM client and RAG pipeline as module-level state for the other
    endpoints.

    NOTE(review): ``background_tasks`` is accepted but unused — loading runs
    synchronously in this handler; confirm whether async loading was intended.
    """
    global rag_pipeline, vector_store, current_collection

    try:
        # Initialize dataset loader
        loader = RAGBenchLoader()

        # Load dataset
        dataset = loader.load_dataset(
            request.dataset_name,
            split="train",
            max_samples=request.num_samples
        )

        if not dataset:
            raise HTTPException(status_code=400, detail="Failed to load dataset")

        # Initialize vector store
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

        # Create collection name as "<dataset>_<strategy>_<model-tail>".
        # "-" and "." are replaced with "_" to stay within ChromaDB's
        # collection-name character restrictions.
        collection_name = f"{request.dataset_name}_{request.chunking_strategy}_{request.embedding_model.split('/')[-1]}"
        collection_name = collection_name.replace("-", "_").replace(".", "_")

        # Load data into collection
        vector_store.load_dataset_into_collection(
            collection_name=collection_name,
            embedding_model_name=request.embedding_model,
            chunking_strategy=request.chunking_strategy,
            dataset_data=dataset,
            chunk_size=request.chunk_size,
            overlap=request.overlap
        )

        # Initialize LLM client
        llm_client = GroqLLMClient(
            api_key=request.groq_api_key,
            model_name=request.llm_model,
            max_rpm=settings.groq_rpm_limit,
            rate_limit_delay=settings.rate_limit_delay
        )

        # Create RAG pipeline
        rag_pipeline = RAGPipeline(llm_client, vector_store)
        current_collection = collection_name

        return {
            "status": "success",
            "collection_name": collection_name,
            "num_documents": len(dataset),
            "message": f"Collection '{collection_name}' created successfully"
        }

    except HTTPException:
        # Bug fix: HTTPException subclasses Exception, so the 400 raised above
        # was previously caught below and re-wrapped as a 500. Re-raise it
        # unchanged so clients see the intended status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading dataset: {str(e)}")
219
+
220
+
221
@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Answer a user query with the active RAG pipeline (retrieve + generate)."""
    global rag_pipeline

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    try:
        answer = rag_pipeline.query(
            query=request.query,
            n_results=request.n_results,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
        )
        # Stamp the response so clients can order and inspect results.
        answer["timestamp"] = datetime.now().isoformat()
        return answer
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
246
+
247
+
248
@app.get("/chat-history")
async def get_chat_history():
    """Return the conversation turns accumulated by the pipeline."""
    global rag_pipeline

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    return {"history": rag_pipeline.get_chat_history()}
262
+
263
+
264
@app.delete("/chat-history")
async def clear_chat_history():
    """Drop all stored conversation turns from the pipeline."""
    global rag_pipeline

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    rag_pipeline.clear_history()
    return {"status": "success", "message": "Chat history cleared"}
281
+
282
+
283
@app.post("/evaluate")
async def run_evaluation(request: EvaluationRequest):
    """Run TRACE evaluation.

    Pulls held-out test samples for the dataset behind the current
    collection, answers each question through the RAG pipeline, and scores
    the (query, response, retrieved documents, ground truth) tuples with the
    TRACE evaluator.
    """
    global rag_pipeline, current_collection

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    try:
        # Bug fix: the previous dead assignment
        # `collection_metadata = vector_store.current_collection.metadata`
        # was never used and could raise AttributeError; removed.
        # Recover the dataset name from the "<dataset>_<strategy>_<model>"
        # collection naming convention, falling back to a default.
        dataset_name = current_collection.split("_")[0] if current_collection else "hotpotqa"

        # Get test data
        loader = RAGBenchLoader()
        test_data = loader.get_test_data(dataset_name, request.num_samples)

        # Answer each test question through the pipeline and collect the
        # tuples the evaluator expects.
        test_cases = []
        for sample in test_data:
            result = rag_pipeline.query(sample["question"], n_results=5)

            test_cases.append({
                "query": sample["question"],
                "response": result["response"],
                "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                "ground_truth": sample.get("answer", "")
            })

        # Run evaluation
        evaluator = TRACEEvaluator()
        results = evaluator.evaluate_batch(test_cases)

        return {
            "status": "success",
            "results": results
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during evaluation: {str(e)}")
327
+
328
+
329
@app.delete("/collections/{collection_name}")
async def delete_collection(collection_name: str):
    """Delete a named collection from the vector store."""
    global vector_store

    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    try:
        vector_store.delete_collection(collection_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")

    return {
        "status": "success",
        "message": f"Collection '{collection_name}' deleted"
    }
345
+
346
+
347
@app.get("/current-collection")
async def get_current_collection():
    """Report which collection (if any) the pipeline is currently using."""
    global current_collection, vector_store

    if not current_collection:
        return {
            "collection": None,
            "message": "No collection loaded"
        }

    try:
        stats = vector_store.get_collection_stats(current_collection)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")

    return {
        "collection": current_collection,
        "stats": stats
    }
366
+
367
+
368
if __name__ == "__main__":
    # Dev entrypoint: serve the API locally with auto-reload on code changes.
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)
chunking_strategies.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chunking strategies for document processing."""
2
+ from typing import List, Dict, Tuple
3
+ from abc import ABC, abstractmethod
4
+ import re
5
+ from rank_bm25 import BM25Okapi
6
+ import numpy as np
7
+
8
+
9
class ChunkingStrategy(ABC):
    """Abstract base class for chunking strategies.

    Concrete strategies turn one document string into a list of chunk
    strings; sizes here are measured in characters, not tokens.
    """

    @abstractmethod
    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Chunk text into smaller pieces.

        Args:
            text: Input text to chunk
            chunk_size: Maximum size of each chunk (characters)
            overlap: Number of characters to overlap between chunks

        Returns:
            List of text chunks
        """
        pass
26
+
27
+
28
class DenseChunking(ChunkingStrategy):
    """Dense chunking strategy - fixed-size chunks with overlap."""

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Create dense chunks with fixed size and overlap.

        Chunks are cut every ``chunk_size`` characters, preferring to end at
        the last '.' or '\\n' when one occurs in the second half of the
        window. Consecutive chunks share up to ``overlap`` characters.

        Args:
            text: Input text to chunk
            chunk_size: Maximum size of each chunk
            overlap: Number of characters to overlap between chunks

        Returns:
            List of non-empty text chunks
        """
        if not text:
            return []

        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = start + chunk_size
            chunk = text[start:end]

            # Try to break at a sentence boundary, but only if that keeps at
            # least half of the window (avoids degenerate tiny chunks).
            if end < text_length:
                break_point = max(chunk.rfind('.'), chunk.rfind('\n'))
                if break_point > chunk_size * 0.5:
                    chunk = chunk[:break_point + 1]
                    end = start + break_point + 1

            chunks.append(chunk.strip())
            # Bug fix: guarantee forward progress. With overlap >= chunk_size,
            # or after a boundary trim with overlap > chunk_size / 2,
            # `end - overlap` could stall or move backwards and loop forever.
            start = max(end - overlap, start + 1)

        return [c for c in chunks if c]  # Remove empty chunks
62
+
63
+
64
class SparseChunking(ChunkingStrategy):
    """Sparse chunking strategy - semantic-based chunks (paragraphs/sections)."""

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Split on blank lines and pack paragraphs into size-bounded chunks."""
        if not text:
            return []

        pieces = []
        buffer = ""

        # Blank-line-separated paragraphs are the semantic units.
        for paragraph in (p.strip() for p in re.split(r'\n\s*\n', text)):
            if not paragraph:
                continue

            fits = len(buffer) + len(paragraph) <= chunk_size
            if buffer and not fits:
                pieces.append(buffer.strip())
                if overlap > 0:
                    # Seed the next chunk with the tail words of the previous
                    # one to preserve context across the boundary.
                    words = buffer.split()
                    keep = min(overlap // 5, len(words))
                    buffer = " ".join(words[-keep:]) + " " + paragraph
                else:
                    buffer = paragraph
            else:
                buffer += ("\n\n" if buffer else "") + paragraph

        # Flush whatever is left in the buffer.
        if buffer:
            pieces.append(buffer.strip())

        return pieces
101
+
102
+
103
class HybridChunking(ChunkingStrategy):
    """Hybrid chunking strategy - combines dense and sparse approaches."""

    def __init__(self):
        self.dense_chunker = DenseChunking()
        self.sparse_chunker = SparseChunking()

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Semantic-first split, then size-bounded re-chunking of long pieces."""
        if not text:
            return []

        # Pass 1: paragraph-level split at twice the target size, no overlap.
        coarse = self.sparse_chunker.chunk_text(text, chunk_size * 2, 0)

        # Pass 2: anything still over the limit gets dense re-chunking.
        refined = []
        for piece in coarse:
            if len(piece) <= chunk_size:
                refined.append(piece)
            else:
                refined.extend(
                    self.dense_chunker.chunk_text(piece, chunk_size, overlap)
                )

        return refined
131
+
132
+
133
class ReRankingChunking(ChunkingStrategy):
    """Re-ranking chunking strategy - creates chunks and provides relevance scoring."""

    def __init__(self):
        # Delegates chunk creation to the hybrid strategy; the BM25 index and
        # chunk list are rebuilt on every chunk_text() call.
        self.base_chunker = HybridChunking()
        self.bm25 = None
        self.chunks = []

    def chunk_text(self, text: str, chunk_size: int = 512,
                   overlap: int = 50) -> List[str]:
        """Create chunks suitable for re-ranking."""
        self.chunks = self.base_chunker.chunk_text(text, chunk_size, overlap)

        # Initialize BM25 for re-ranking capability.
        # NOTE(review): lowercase whitespace split is a crude tokenizer —
        # punctuation stays attached to tokens; confirm this is acceptable.
        # NOTE(review): BM25Okapi raises on an empty corpus — confirm callers
        # never pass empty text here.
        tokenized_chunks = [chunk.lower().split() for chunk in self.chunks]
        self.bm25 = BM25Okapi(tokenized_chunks)

        return self.chunks

    def rerank_chunks(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Re-rank chunks based on query relevance.

        Must be called after chunk_text(); scores refer to the most recently
        chunked document.

        Args:
            query: Query string
            top_k: Number of top chunks to return

        Returns:
            List of (chunk, score) tuples, highest BM25 score first
        """
        if not self.bm25 or not self.chunks:
            return []

        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)

        # Get top-k chunks with scores (descending by BM25 score)
        top_indices = np.argsort(scores)[::-1][:top_k]
        ranked_chunks = [
            (self.chunks[idx], float(scores[idx]))
            for idx in top_indices
        ]

        return ranked_chunks
176
+
177
+
178
class ChunkingFactory:
    """Factory for creating chunking strategy instances."""

    # Registry mapping public strategy names to their implementing classes.
    STRATEGIES = {
        "dense": DenseChunking,
        "sparse": SparseChunking,
        "hybrid": HybridChunking,
        "re-ranking": ReRankingChunking
    }

    @classmethod
    def create_chunker(cls, strategy: str) -> ChunkingStrategy:
        """Instantiate the chunker registered under ``strategy``.

        Args:
            strategy: Name of the chunking strategy

        Returns:
            ChunkingStrategy instance

        Raises:
            ValueError: If ``strategy`` is not a registered name.
        """
        try:
            chunker_cls = cls.STRATEGIES[strategy]
        except KeyError:
            raise ValueError(f"Unknown chunking strategy: {strategy}. "
                             f"Available: {list(cls.STRATEGIES.keys())}") from None
        return chunker_cls()

    @classmethod
    def get_available_strategies(cls) -> List[str]:
        """Names of all registered chunking strategies."""
        return [*cls.STRATEGIES]
cleanup_chroma.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Script to clean up ChromaDB collections and cache."""
3
+
4
+ import shutil
5
+ import os
6
+ from pathlib import Path
7
+
8
def cleanup_chroma_db():
    """Clean up ChromaDB collections and cache.

    Interactive maintenance helper: force-deletes the on-disk ChromaDB
    directories, optionally (with y/n prompts) clears the HuggingFace
    dataset cache and the user-level ChromaDB cache, then attempts to drop
    any remaining collections through ChromaDBManager. All errors are
    reported but never raised — the script is deliberately best-effort.
    """

    print("=" * 60)
    print("ChromaDB Cleanup Utility")
    print("=" * 60)

    # First, forcefully delete the chroma_db directory
    chroma_path = Path("./chroma_db")
    if chroma_path.exists():
        print(f"\n🗑️ Removing chroma_db directory: {chroma_path}")
        try:
            shutil.rmtree(chroma_path)
            print(f"✅ Deleted directory: {chroma_path}")
        except Exception as e:
            print(f"❌ Error deleting directory: {e}")
    else:
        print(f"\n✅ chroma_db directory not found: {chroma_path}")

    # Also check for ChromaDB in .chroma directory (alternative location)
    chroma_alt_path = Path("./.chroma")
    if chroma_alt_path.exists():
        print(f"\n🗑️ Removing .chroma directory: {chroma_alt_path}")
        try:
            shutil.rmtree(chroma_alt_path)
            print(f"✅ Deleted directory: {chroma_alt_path}")
        except Exception as e:
            print(f"❌ Error deleting directory: {e}")

    # Clear HuggingFace dataset cache (optional; removes ALL cached HF
    # datasets for this user, not only RAGBench — hence the prompt)
    response = input("\n🗑️ Clear HuggingFace dataset cache? (y/n): ").lower()
    if response == 'y':
        cache_path = Path.home() / ".cache" / "huggingface" / "datasets"
        if cache_path.exists():
            print(f"🗑️ Removing HF cache: {cache_path}")
            try:
                shutil.rmtree(cache_path)
                print(f"✅ Deleted HF cache: {cache_path}")
            except Exception as e:
                print(f"❌ Error deleting HF cache: {e}")
        else:
            print("ℹ️ HuggingFace cache not found")

    # Clear ChromaDB chroma cache directory (user-level, in $HOME)
    response = input("\n🗑️ Clear ChromaDB chroma cache? (y/n): ").lower()
    if response == 'y':
        chroma_cache = Path.home() / ".chroma"
        if chroma_cache.exists():
            print(f"🗑️ Removing ChromaDB cache: {chroma_cache}")
            try:
                shutil.rmtree(chroma_cache)
                print(f"✅ Deleted ChromaDB cache: {chroma_cache}")
            except Exception as e:
                print(f"❌ Error deleting ChromaDB cache: {e}")

    # Try to use ChromaDBManager if possible. Note this recreates ./chroma_db
    # (the client persists an empty store), which is expected after a wipe.
    print("\n📋 Attempting to connect to ChromaDB...")
    try:
        from vector_store import ChromaDBManager

        manager = ChromaDBManager(persist_directory="./chroma_db")

        # List existing collections
        collections = manager.list_collections()
        print(f"📊 Found {len(collections)} collection(s):")
        for col in collections:
            print(f" - {col}")

        # Clear all collections
        if collections:
            print("\n🗑️ Clearing all collections...")
            deleted = manager.clear_all_collections()
            print(f"✅ Deleted {deleted} collection(s)")
        else:
            print("\n✅ No collections to delete")
    except Exception as e:
        print(f"⚠️ Could not connect to ChromaDB via manager: {e}")
        print("ℹ️ This is OK - the directory has been deleted already.")

    print("\n" + "=" * 60)
    print("✅ Cleanup completed! You can now start fresh.")
    print("=" * 60)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ cleanup_chroma_db()
config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration management for RAG Application."""
2
+ from pydantic_settings import BaseSettings
3
+ from typing import Optional
4
+ import os
5
+
6
+
7
class Settings(BaseSettings):
    """Application settings.

    Values are read from the environment (and a local ``.env`` file); the
    literals below act as defaults. pydantic copies mutable defaults per
    instance, so the list defaults here are safe.
    """

    # API Keys
    groq_api_key: str = ""

    # ChromaDB
    chroma_persist_directory: str = "./chroma_db"

    # Rate Limiting
    groq_rpm_limit: int = 30
    rate_limit_delay: float = 2.0

    # Embedding Models (HuggingFace Hub ids, plus one Gemini API model)
    embedding_models: list = [
        "sentence-transformers/all-mpnet-base-v2",  # Stable, high quality
        "emilyalsentzer/Bio_ClinicalBERT",  # Clinical domain
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",  # Medical domain
        "sentence-transformers/all-MiniLM-L6-v2",  # Fast, lightweight
        # NOTE(review): verify this id exists on the Hub — the commonly
        # published name is "paraphrase-multilingual-MiniLM-L12-v2".
        "sentence-transformers/multilingual-MiniLM-L12-v2",  # Multilingual
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",  # Paraphrase
        "allenai/specter",  # Academic papers
        "gemini-embedding-001"  # Gemini API
    ]

    # LLM Models (Groq model identifiers)
    llm_models: list = [
        "meta-llama/llama-4-maverick-17b-128e-instruct",
        "llama-3.1-8b-instant",
        "openai/gpt-oss-120b"
    ]

    # Chunking Strategies (must match ChunkingFactory.STRATEGIES keys)
    chunking_strategies: list = ["dense", "sparse", "hybrid", "re-ranking"]

    # RAG Bench Datasets (from rungalileo/ragbench)
    ragbench_datasets: list = [
        "covidqa",
        "cuad",
        "delucionqa",
        "emanual",
        "expertqa",
        "finqa",
        "hagrid",
        "hotpotqa",
        "msmarco",
        "pubmedqa",
        "tatqa",
        "techqa"
    ]

    class Config:
        # Load overrides from .env; env var names are case-insensitive and
        # unknown extra variables are tolerated rather than rejected.
        env_file = ".env"
        case_sensitive = False
        extra = "allow"
+
63
+
64
+ settings = Settings()
dataset_loader.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset loader for RAG Bench datasets."""
2
+ import os
3
+ from typing import List, Dict, Optional
4
+ from datasets import load_dataset
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+
8
+
9
class RAGBenchLoader:
    """Load and manage RAG Bench datasets.

    Wraps the HuggingFace ``rungalileo/ragbench`` dataset family and
    normalizes each record into a common dict shape (question, answer,
    documents, context, dataset). Falls back to synthetic sample data when
    the real dataset cannot be downloaded.
    """

    SUPPORTED_DATASETS = [
        'covidqa',
        'cuad',
        'delucionqa',
        'emanual',
        'expertqa',
        'finqa',
        'hagrid',
        'hotpotqa',
        'msmarco',
        'pubmedqa',
        'tatqa',
        'techqa'
    ]

    def __init__(self, cache_dir: str = "./data_cache"):
        """Initialize the dataset loader.

        Args:
            cache_dir: Directory to cache downloaded datasets
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, dataset_name: str, split: str = "test",
                     max_samples: Optional[int] = None) -> List[Dict]:
        """Load a RAG Bench dataset from rungalileo/ragbench.

        Args:
            dataset_name: Name of the dataset to load
            split: Dataset split (train/validation/test)
            max_samples: Maximum number of samples to load

        Returns:
            List of dictionaries containing dataset samples

        Raises:
            ValueError: If ``dataset_name`` is not in SUPPORTED_DATASETS.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}. "
                             f"Supported: {self.SUPPORTED_DATASETS}")

        print(f"Loading {dataset_name} dataset ({split} split) from rungalileo/ragbench...")

        try:
            # Load from rungalileo/ragbench
            dataset = load_dataset("rungalileo/ragbench", dataset_name, split=split,
                                   cache_dir=self.cache_dir)

            processed_data = []
            samples = dataset if max_samples is None else dataset.select(range(min(max_samples, len(dataset))))

            # Process the dataset into the standardized record shape
            for item in tqdm(samples, desc=f"Processing {dataset_name}"):
                processed_data.append(self._process_ragbench_item(item, dataset_name))

            print(f"Loaded {len(processed_data)} samples from {dataset_name}")
            return processed_data

        except Exception as e:
            # Download/parse failures should not break callers; serve
            # deterministic synthetic data so the pipeline stays usable.
            print(f"Error loading {dataset_name}: {str(e)}")
            print("Falling back to sample data for testing...")
            return self._create_sample_data(dataset_name, max_samples or 10)

    @staticmethod
    def _coerce_documents(value) -> List[str]:
        """Normalize a raw documents/context field to a list of strings."""
        if isinstance(value, list):
            return [str(v) for v in value]
        return [str(value)]

    def _process_ragbench_item(self, item: Dict, dataset_name: str) -> Dict:
        """Process a single RAGBench dataset item into standardized format.

        Args:
            item: Raw dataset item
            dataset_name: Name of the dataset

        Returns:
            Processed item dictionary with keys: question, answer, context,
            documents, dataset (plus metadata when present).
        """
        processed = {
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "context": "",  # For embedding and retrieval
            "documents": [],  # Store original documents list
            "dataset": dataset_name
        }

        # Extract documents from the first field present, in priority order:
        # documents > retrieved_contexts > context. (Deduplicated from three
        # previously identical branches.)
        for field in ("documents", "retrieved_contexts", "context"):
            if field in item:
                processed["documents"] = self._coerce_documents(item[field])
                processed["context"] = " ".join(processed["documents"])
                break

        # Store additional metadata if available
        if "metadata" in item:
            processed["metadata"] = item["metadata"]

        return processed

    def load_all_datasets(self, split: str = "test", max_samples: Optional[int] = None) -> Dict[str, List[Dict]]:
        """Load all RAGBench datasets.

        Args:
            split: Dataset split to load
            max_samples: Maximum samples per dataset

        Returns:
            Dictionary mapping dataset names to their data (empty list for
            datasets that failed to load).
        """
        all_data = {}
        for dataset_name in self.SUPPORTED_DATASETS:
            print(f"\n{'='*50}")
            print(f"Loading {dataset_name}...")
            print(f"{'='*50}")
            try:
                all_data[dataset_name] = self.load_dataset(dataset_name, split, max_samples)
            except Exception as e:
                print(f"Failed to load {dataset_name}: {str(e)}")
                all_data[dataset_name] = []

        return all_data

    def _create_sample_data(self, dataset_name: str, num_samples: int) -> List[Dict]:
        """Create sample data for testing when actual dataset is unavailable."""
        sample_data = []
        for i in range(num_samples):
            # Create multiple sample documents per question
            sample_docs = [
                f"Document 1: This is the first sample document {i+1} for {dataset_name} dataset. "
                f"It contains relevant information to answer the question.",
                f"Document 2: This is the second sample document {i+1} providing additional context. "
                f"It includes more details about the topic.",
                f"Document 3: This is the third sample document {i+1} with supplementary information."
            ]

            sample_data.append({
                "question": f"Sample question {i+1} for {dataset_name}?",
                "answer": f"Sample answer {i+1}",
                "documents": sample_docs,
                "context": " ".join(sample_docs),  # Combined for backward compatibility
                "dataset": dataset_name
            })
        return sample_data

    def get_test_data(self, dataset_name: str, num_samples: int = 100) -> List[Dict]:
        """Get test data for TRACE evaluation.

        Args:
            dataset_name: Name of the dataset
            num_samples: Number of test samples

        Returns:
            List of test samples
        """
        return self.load_dataset(dataset_name, split="test", max_samples=num_samples)
docker-compose.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: '3.8'
2
+
3
+ services:
4
+ streamlit:
5
+ build: .
6
+ ports:
7
+ - "8501:8501"
8
+ environment:
9
+ - GROQ_API_KEY=${GROQ_API_KEY}
10
+ - CHROMA_PERSIST_DIRECTORY=/app/chroma_db
11
+ volumes:
12
+ - ./chroma_db:/app/chroma_db
13
+ - ./data_cache:/app/data_cache
14
+ command: streamlit run streamlit_app.py --server.port=8501 --server.address=0.0.0.0
15
+
16
+ api:
17
+ build: .
18
+ ports:
19
+ - "8000:8000"
20
+ environment:
21
+ - GROQ_API_KEY=${GROQ_API_KEY}
22
+ - CHROMA_PERSIST_DIRECTORY=/app/chroma_db
23
+ volumes:
24
+ - ./chroma_db:/app/chroma_db
25
+ - ./data_cache:/app/data_cache
26
+ command: uvicorn api:app --host 0.0.0.0 --port 8000 --reload
embedding_models.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Embedding models for document vectorization."""
2
+ from typing import List, Optional
3
+ import torch
4
+ from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import os
9
+
10
+
11
class EmbeddingModel:
    """Abstract base for all embedding backends.

    Subclasses provide ``load_model`` and ``embed_documents``; this base
    class holds shared state and the single-query convenience wrapper.
    """

    def __init__(self, model_name: str, device: Optional[str] = None):
        """Set up shared state for an embedding backend.

        Args:
            model_name: Name/path of the model.
            device: Torch device string; auto-detects CUDA when omitted.
        """
        self.model_name = model_name
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None      # populated lazily by load_model()
        self.tokenizer = None  # only used by tokenizer-based subclasses

    def load_model(self):
        """Load the embedding model; must be supplied by subclasses."""
        raise NotImplementedError

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed a list of documents; must be supplied by subclasses.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Numpy array of embeddings, one row per text.
        """
        raise NotImplementedError

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query.

        Args:
            query: Query text.

        Returns:
            Numpy array of the query's embedding (first and only row).
        """
        vectors = self.embed_documents([query])
        return vectors[0]
52
+
53
+
54
class SentenceTransformerEmbedding(EmbeddingModel):
    """Embedding backend built on the sentence-transformers library."""

    def load_model(self):
        """Instantiate the SentenceTransformer, falling back to MiniLM on failure."""
        print(f"Loading SentenceTransformer model: {self.model_name}")
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
        else:
            print(f"Model loaded successfully on {self.device}")

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* in batches and stack the per-batch results.

        Args:
            texts: Texts to embed.
            batch_size: Batch size handed to the encoder.

        Returns:
            Array of shape (len(texts), dim); empty array when no input.
        """
        if self.model is None:
            self.load_model()

        chunks = []
        for start in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            piece = texts[start:start + batch_size]
            chunks.append(
                self.model.encode(
                    piece,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    batch_size=batch_size,
                )
            )

        if not chunks:
            return np.array([])
        return np.vstack(chunks)
85
+
86
+
87
class BioMedicalEmbedding(EmbeddingModel):
    """Embedding backend using a biomedical BERT checkpoint with mean pooling."""

    def load_model(self):
        """Load tokenizer + encoder; fall back to bert-base-uncased on failure."""
        print(f"Loading Bio-Medical model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, weighting each token by the attention mask."""
        tokens = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
        summed = torch.sum(tokens * mask, 1)
        # Clamp avoids division by zero for fully-masked rows.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Tokenize, encode, mean-pool and L2-normalize *texts* in batches.

        Args:
            texts: Texts to embed.
            batch_size: Number of texts per forward pass.

        Returns:
            Array of shape (len(texts), 768); empty array when no input.
        """
        if self.model is None:
            self.load_model()

        pieces = []
        with torch.no_grad():
            for start in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
                batch = texts[start:start + batch_size]

                # Tokenize (truncated to the encoder's 512-token window).
                encoded = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                # Encode, pool, then L2-normalize each row.
                outputs = self.model(**encoded)
                pooled = self.mean_pooling(outputs, encoded['attention_mask'])
                pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)

                pieces.append(pooled.cpu().numpy())

        return np.vstack(pieces) if pieces else np.array([])
148
+
149
+
150
class GeminiEmbedding(EmbeddingModel):
    """Embedding backend backed by the Google Gemini embeddings API.

    Requires the GEMINI_API_KEY environment variable. When the SDK or the
    key is unavailable, ``load_model`` falls back to a local
    SentenceTransformer ('all-MiniLM-L6-v2').
    """

    def load_model(self):
        """Configure the Gemini SDK, or fall back to a local model.

        On success ``self.model`` is the configured ``google.generativeai``
        module; on failure it is a local SentenceTransformer instance.
        """
        print(f"Initializing Gemini embedding model: {self.model_name}")
        try:
            import google.generativeai as genai
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("GEMINI_API_KEY environment variable not set")
            genai.configure(api_key=api_key)
            self.model = genai
            print(f"Gemini model initialized successfully")
        except Exception as e:
            print(f"Error loading Gemini model: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed *texts* via the Gemini API (or the local fallback model).

        Args:
            texts: Texts to embed.
            batch_size: Progress-reporting granularity only; the Gemini API
                is called one text at a time.

        Returns:
            Array of shape (len(texts), dim). A text that fails to embed is
            replaced by a 768-dim zero vector so the array stays rectangular.
        """
        if self.model is None:
            self.load_model()

        # Decide once which backend we have. BUG FIX: the previous
        # implementation re-checked per text and instantiated a brand-new
        # SentenceTransformer for every single text on the fallback path
        # (one model load per document).
        use_gemini = hasattr(self.model, 'embed_content')
        local_encoder = None if use_gemini else self.model

        embeddings = []

        # Gemini API has rate limits; texts are sent one at a time.
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            batch = texts[i:i + batch_size]

            for text in batch:
                try:
                    if use_gemini:
                        result = self.model.embed_content(
                            model="models/embedding-001",
                            content=text,
                            task_type="retrieval_document"
                        )
                        embeddings.append(result['embedding'])
                    else:
                        # Reuse the fallback model loaded in load_model().
                        embeddings.append(local_encoder.encode([text])[0])
                except Exception as e:
                    print(f"Error embedding text: {str(e)}")
                    # Zero vector keeps the output shape consistent.
                    embeddings.append(np.zeros(768))

        return np.array(embeddings)
201
+
202
+
203
class EmbeddingFactory:
    """Factory for creating embedding model instances."""

    # Map model names to their backend implementation type.
    MODEL_TYPES = {
        "sentence-transformers/all-mpnet-base-v2": "sentence-transformer",  # Stable, well-supported
        "emilyalsentzer/Bio_ClinicalBERT": "biomedical",  # Clinical domain
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": "biomedical",  # Medical domain
        "sentence-transformers/all-MiniLM-L6-v2": "sentence-transformer",  # Fast, lightweight
        "sentence-transformers/multilingual-MiniLM-L12-v2": "sentence-transformer",  # Multilingual
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "sentence-transformer",  # Paraphrase
        "allenai/specter": "biomedical",  # Academic paper embeddings
        "gemini-embedding-001": "gemini"  # Gemini API
    }

    @classmethod
    def create_embedding_model(cls, model_name: str, device: Optional[str] = None) -> "EmbeddingModel":
        """Create an embedding model instance.

        Args:
            model_name: Name of the embedding model.
            device: Device to run model on.

        Returns:
            EmbeddingModel instance; unknown names default to the
            sentence-transformer backend.
        """
        model_type = cls.MODEL_TYPES.get(model_name, "sentence-transformer")

        if model_type == "gemini":
            return GeminiEmbedding(model_name, device)
        elif model_type == "biomedical":
            return BioMedicalEmbedding(model_name, device)
        else:
            return SentenceTransformerEmbedding(model_name, device)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Get list of available embedding models."""
        return list(cls.MODEL_TYPES.keys())

    @classmethod
    def get_model_info(cls, model_name: str) -> dict:
        """Get information about a specific model.

        Args:
            model_name: Name of the model.

        Returns:
            Dictionary with model information; unknown models get a generic
            entry with the default 768 dimension.
        """
        info = {
            "sentence-transformers/all-mpnet-base-v2": {
                # BUG FIX: description previously said "(384d)" although the
                # model's dimension (below) is 768.
                "description": "High-quality, general-purpose sentence embeddings (768d)",
                "dimension": 768,
                "type": "sentence-transformer",
                "note": "Recommended for general use"
            },
            "emilyalsentzer/Bio_ClinicalBERT": {
                "description": "Clinical BERT for biomedical and clinical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": {
                "description": "PubMedBERT for biomedical and medical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "sentence-transformers/all-MiniLM-L6-v2": {
                "description": "Fast, lightweight sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for speed-sensitive applications"
            },
            "sentence-transformers/multilingual-MiniLM-L12-v2": {
                "description": "Fast multilingual sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Supports 50+ languages"
            },
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": {
                "description": "Multilingual paraphrase embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for paraphrase detection"
            },
            "allenai/specter": {
                "description": "Embeddings for academic papers and citations",
                "dimension": 768,
                "type": "biomedical",
                "note": "Optimized for scientific literature"
            },
            "gemini-embedding-001": {
                "description": "Google Gemini embedding model via API",
                "dimension": 768,
                "type": "gemini",
                "url": "https://ai.google.dev/gemini-api/docs/embeddings",
                "note": "Requires GEMINI_API_KEY environment variable"
            }
        }
        return info.get(model_name, {"description": "Unknown model", "dimension": 768})

    @classmethod
    def get_embedding_dimension(cls, model_name: str) -> int:
        """Get embedding dimension for a model.

        Args:
            model_name: Name of the model.

        Returns:
            Embedding dimension (768 for unknown models).
        """
        # CONSISTENCY FIX: derive from get_model_info so each model's
        # dimension is defined in exactly one place (the old code kept a
        # second, duplicate dimension table that could drift).
        return int(cls.get_model_info(model_name).get("dimension", 768))
example.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example script demonstrating how to use the RAG system programmatically.
3
+ """
4
+ import os
5
+ from config import settings
6
+ from dataset_loader import RAGBenchLoader
7
+ from vector_store import ChromaDBManager
8
+ from llm_client import GroqLLMClient, RAGPipeline
9
+ from trace_evaluator import TRACEEvaluator
10
+
11
+
12
RULE = "=" * 50  # horizontal rule used for section banners


def main():
    """Run the end-to-end RAG demo.

    Loads a small slice of hotpotqa, indexes it into ChromaDB, answers a
    few demo questions through the Groq-backed pipeline, scores a handful
    of held-out samples with TRACE, and prints the chat history. Requires
    a real GROQ_API_KEY in the environment (or .env).
    """

    # Resolve the API key; bail out early when it is still the placeholder.
    api_key = os.getenv("GROQ_API_KEY") or "your_api_key_here"
    if api_key == "your_api_key_here":
        print("Please set your GROQ_API_KEY in .env file or environment variable")
        return

    print(RULE)
    print("RAG System Example")
    print(RULE)

    # 1. Load dataset
    print("\n1. Loading dataset...")
    data_loader = RAGBenchLoader()
    train_samples = data_loader.load_dataset("hotpotqa", split="train", max_samples=20)
    print(f"Loaded {len(train_samples)} samples")

    # 2. Create vector store and collection
    print("\n2. Creating vector store...")
    store = ChromaDBManager()

    collection_name = "example_collection"
    embedding_model = "emilyalsentzer/Bio_ClinicalBERT"
    chunking_strategy = "hybrid"

    print(f"Loading data into collection with {chunking_strategy} chunking...")
    store.load_dataset_into_collection(
        collection_name=collection_name,
        embedding_model_name=embedding_model,
        chunking_strategy=chunking_strategy,
        dataset_data=train_samples,
        chunk_size=512,
        overlap=50
    )

    # 3. Initialize LLM client
    print("\n3. Initializing LLM client...")
    llm = GroqLLMClient(
        api_key=api_key,
        model_name="llama-3.1-8b-instant",
        max_rpm=30,
        rate_limit_delay=2.0
    )

    # 4. Create RAG pipeline
    print("\n4. Creating RAG pipeline...")
    rag = RAGPipeline(llm, store)

    # 5. Query the system
    print("\n5. Querying the system...")
    demo_questions = [
        "What is machine learning?",
        "How does neural network work?",
        "What is deep learning?"
    ]

    for number, question in enumerate(demo_questions, 1):
        print(f"\n--- Query {number}: {question} ---")
        outcome = rag.query(question, n_results=3)

        print(f"Response: {outcome['response']}")
        print(f"\nRetrieved {len(outcome['retrieved_documents'])} documents:")
        for rank, doc in enumerate(outcome['retrieved_documents'], 1):
            print(f"\nDocument {rank} (Distance: {doc.get('distance', 'N/A')}):")
            print(f"{doc['document'][:200]}...")

    # 6. Run TRACE evaluation over a few held-out samples
    print("\n6. Running TRACE evaluation...")
    evaluator = TRACEEvaluator(llm)

    eval_cases = []
    held_out = data_loader.get_test_data("hotpotqa", num_samples=5)
    for sample in held_out:
        outcome = rag.query(sample["question"], n_results=5)
        eval_cases.append({
            "query": sample["question"],
            "response": outcome["response"],
            "retrieved_documents": [doc["document"] for doc in outcome["retrieved_documents"]],
            "ground_truth": sample.get("answer", "")
        })

    scores = evaluator.evaluate_batch(eval_cases)

    print("\nTRACE Evaluation Results:")
    print(f"Utilization: {scores['utilization']:.3f}")
    print(f"Relevance: {scores['relevance']:.3f}")
    print(f"Adherence: {scores['adherence']:.3f}")
    print(f"Completeness: {scores['completeness']:.3f}")
    print(f"Average: {scores['average']:.3f}")

    # 7. View chat history
    print("\n7. Chat History:")
    conversations = rag.get_chat_history()
    print(f"Total conversations: {len(conversations)}")

    print("\n" + RULE)
    print("Example completed successfully!")
    print(RULE)


if __name__ == "__main__":
    main()
llm_client.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Groq LLM integration with rate limiting."""
2
+ from typing import List, Dict, Optional, AsyncIterator
3
+ import time
4
+ from groq import Groq
5
+ import asyncio
6
+ from datetime import datetime, timedelta
7
+ from collections import deque
8
+ import os
9
+
10
+
11
+ class RateLimiter:
12
+ """Rate limiter for API calls."""
13
+
14
+ def __init__(self, max_requests_per_minute: int = 30):
15
+ """Initialize rate limiter.
16
+
17
+ Args:
18
+ max_requests_per_minute: Maximum requests allowed per minute
19
+ """
20
+ self.max_requests = max_requests_per_minute
21
+ self.request_times = deque()
22
+ self.lock = asyncio.Lock()
23
+
24
+ async def acquire(self):
25
+ """Acquire permission to make a request."""
26
+ async with self.lock:
27
+ now = datetime.now()
28
+
29
+ # Remove requests older than 1 minute
30
+ while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
31
+ self.request_times.popleft()
32
+
33
+ # If at limit, wait
34
+ if len(self.request_times) >= self.max_requests:
35
+ # Calculate how long to wait
36
+ oldest_request = self.request_times[0]
37
+ wait_time = 60 - (now - oldest_request).total_seconds()
38
+
39
+ if wait_time > 0:
40
+ print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
41
+ await asyncio.sleep(wait_time)
42
+ # Recursive call after waiting
43
+ return await self.acquire()
44
+
45
+ # Record this request
46
+ self.request_times.append(now)
47
+
48
+ def acquire_sync(self):
49
+ """Synchronous version of acquire."""
50
+ now = datetime.now()
51
+
52
+ # Remove requests older than 1 minute
53
+ while self.request_times and (now - self.request_times[0]) > timedelta(minutes=1):
54
+ self.request_times.popleft()
55
+
56
+ # If at limit, wait
57
+ if len(self.request_times) >= self.max_requests:
58
+ oldest_request = self.request_times[0]
59
+ wait_time = 60 - (now - oldest_request).total_seconds()
60
+
61
+ if wait_time > 0:
62
+ print(f"Rate limit reached. Waiting {wait_time:.2f} seconds...")
63
+ time.sleep(wait_time)
64
+ return self.acquire_sync()
65
+
66
+ # Record this request
67
+ self.request_times.append(now)
68
+
69
+
70
class GroqLLMClient:
    """Client for Groq LLM API with rate limiting."""

    def __init__(
        self,
        api_key: str,
        model_name: str = "llama-3.1-8b-instant",
        max_rpm: int = 30,
        rate_limit_delay: float = 2.0
    ):
        """Initialize Groq client.

        Args:
            api_key: Groq API key.
            model_name: Name of the LLM model.
            max_rpm: Maximum requests per minute.
            rate_limit_delay: Additional delay between requests (seconds).
        """
        self.client = Groq(api_key=api_key)
        self.model_name = model_name
        self.rate_limiter = RateLimiter(max_rpm)
        self.rate_limit_delay = rate_limit_delay

        # Models known to this client; others are allowed with a warning.
        self.available_models = [
            "meta-llama/llama-4-maverick-17b-128e-instruct",
            "llama-3.1-8b-instant",
            "openai/gpt-oss-120b"
        ]

    def set_model(self, model_name: str):
        """Switch the active LLM model (warns on unknown names).

        Args:
            model_name: Name of the model.
        """
        if model_name not in self.available_models:
            print(f"Warning: {model_name} not in available models. Using anyway...")
        self.model_name = model_name

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> List[Dict]:
        """Assemble the chat message list for a single-turn request."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> str:
        """Generate text using the Groq LLM, honouring the rate limit.

        Args:
            prompt: Input prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            system_prompt: Optional system prompt.

        Returns:
            Generated text, or an "Error: ..." string on failure.
        """
        self.rate_limiter.acquire_sync()
        chat_messages = self._build_messages(prompt, system_prompt)

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=chat_messages,
                max_tokens=max_tokens,
                temperature=temperature
            )
            # Extra spacing between consecutive calls.
            time.sleep(self.rate_limit_delay)
            return completion.choices[0].message.content
        except Exception as e:
            print(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    async def generate_async(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> str:
        """Asynchronous version of generate.

        Note: the underlying Groq client call is synchronous and will block
        the event loop for the duration of the request.

        Args:
            prompt: Input prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            system_prompt: Optional system prompt.

        Returns:
            Generated text, or an "Error: ..." string on failure.
        """
        await self.rate_limiter.acquire()
        chat_messages = self._build_messages(prompt, system_prompt)

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=chat_messages,
                max_tokens=max_tokens,
                temperature=temperature
            )
            # Extra spacing between consecutive calls.
            await asyncio.sleep(self.rate_limit_delay)
            return completion.choices[0].message.content
        except Exception as e:
            print(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def generate_with_context(
        self,
        query: str,
        context_documents: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> str:
        """Answer *query* grounded in the retrieved *context_documents*.

        Args:
            query: User query.
            context_documents: List of retrieved documents.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            Generated response.
        """
        # Number each document so the model can reference them.
        context = "\n\n".join(
            f"Document {position + 1}: {doc}"
            for position, doc in enumerate(context_documents)
        )

        prompt = f"""Answer the following question based on the provided context.

Context:
{context}

Question: {query}

Answer:"""

        system_prompt = "You are a helpful AI assistant. Answer questions based on the provided context. If the answer is not in the context, say so."

        return self.generate(prompt, max_tokens, temperature, system_prompt)

    def batch_generate(
        self,
        prompts: List[str],
        max_tokens: int = 1024,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None
    ) -> List[str]:
        """Generate responses for multiple prompts, sequentially.

        Args:
            prompts: List of prompts.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            system_prompt: Optional system prompt shared by all prompts.

        Returns:
            List of generated responses, in input order.
        """
        outputs = []
        total = len(prompts)
        for index, single_prompt in enumerate(prompts):
            print(f"Processing prompt {index+1}/{total}")
            outputs.append(self.generate(single_prompt, max_tokens, temperature, system_prompt))
        return outputs
275
+
276
+
277
class RAGPipeline:
    """Complete RAG pipeline tying an LLM client to a vector store."""

    def __init__(
        self,
        llm_client: GroqLLMClient,
        vector_store_manager
    ):
        """Wire the pipeline together.

        Args:
            llm_client: Groq LLM client used for generation.
            vector_store_manager: ChromaDB manager used for retrieval.
        """
        self.llm = llm_client
        self.vector_store = vector_store_manager
        self.chat_history = []  # one entry per answered query

    def query(
        self,
        query: str,
        n_results: int = 5,
        max_tokens: int = 1024,
        temperature: float = 0.7
    ) -> Dict:
        """Retrieve context for *query*, generate an answer, and log it.

        Args:
            query: User query.
            n_results: Number of documents to retrieve.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            Dict with the query, the response text, and the retrieved docs.
        """
        # Retrieve supporting documents from the vector store.
        hits = self.vector_store.get_retrieved_documents(query, n_results)
        passages = [hit["document"] for hit in hits]

        # Generate a grounded answer from the retrieved passages.
        answer = self.llm.generate_with_context(
            query,
            passages,
            max_tokens,
            temperature
        )

        # Log the full exchange for later inspection.
        self.chat_history.append({
            "query": query,
            "response": answer,
            "retrieved_docs": hits,
            "timestamp": datetime.now().isoformat()
        })

        return {
            "query": query,
            "response": answer,
            "retrieved_documents": hits
        }

    def get_chat_history(self) -> List[Dict]:
        """Return the accumulated chat history (oldest first).

        Returns:
            List of chat history entries.
        """
        return self.chat_history

    def clear_history(self):
        """Drop all chat history entries."""
        self.chat_history = []
requirements.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Dependencies
2
+ fastapi==0.109.0
3
+ uvicorn[standard]==0.27.0
4
+ streamlit==1.31.0
5
+ python-dotenv==1.0.0
6
+
7
+ # LLM and AI
8
+ groq>=0.11.0
9
+ openai==1.12.0
10
+ google-generativeai>=0.3.0
11
+
12
+ # Embeddings and Vector Store
13
+ sentence-transformers==2.7.0
14
+ transformers==4.40.2
15
+ torch>=2.0.0
16
+ chromadb==0.5.23
17
+
18
+ # Data Processing
19
+ pandas==2.2.0
20
+ numpy==1.26.3
21
+ datasets==2.16.1
22
+
23
+ # RAG and Retrieval
24
+ langchain==0.1.6
25
+ langchain-community==0.0.19
26
+ langchain-groq==0.0.1
27
+
28
+ # Evaluation
29
+ ragas==0.1.4
30
+ rank-bm25==0.2.2
31
+
32
+ # Utilities
33
+ pydantic==2.6.0
34
+ pydantic-settings==2.1.0
35
+ tenacity==8.2.3
36
+ aiohttp==3.9.3
37
+ tqdm==4.66.1
38
+
39
+ # Deployment
40
+ gunicorn==21.2.0
run.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick start script to run the RAG application.
3
+ """
4
+ import subprocess
5
+ import sys
6
+ import os
7
+
8
+
9
def check_dependencies():
    """Return True when the core runtime packages can be imported.

    Probes streamlit, fastapi and groq — the packages the launcher needs.
    """
    try:
        for package in ("streamlit", "fastapi", "groq"):
            __import__(package)
    except ImportError:
        return False
    return True
18
+
19
+
20
def install_dependencies():
    """Install everything listed in requirements.txt via the current Python."""
    print("Installing dependencies...")
    pip_command = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
    subprocess.check_call(pip_command)
    print("Dependencies installed successfully!")
25
+
26
+
27
def check_env_file():
    """Ensure a .env file exists, seeding it from .env.example when possible.

    Returns:
        True when .env exists (or was just created from the example);
        False when neither .env nor .env.example is present.
    """
    if os.path.exists(".env"):
        return True

    print("\n⚠️ Warning: .env file not found!")
    print("Creating .env from .env.example...")
    if not os.path.exists(".env.example"):
        print("❌ .env.example not found. Please create .env manually.")
        return False

    # Copy the example file's contents verbatim as the starting .env.
    with open(".env.example", "r") as src:
        template = src.read()
    with open(".env", "w") as dst:
        dst.write(template)
    print("✅ .env file created. Please edit it and add your Groq API key.")
    return True
41
+
42
+
43
def run_streamlit():
    """Launch the Streamlit chat UI in the foreground (blocks until exit)."""
    print("\n🚀 Starting Streamlit application...")
    print("📱 Open your browser to: http://localhost:8501")
    streamlit_cmd = [sys.executable, "-m", "streamlit", "run", "streamlit_app.py"]
    subprocess.run(streamlit_cmd)
48
+
49
+
50
def run_api():
    """Launch the FastAPI backend in the foreground (blocks until exit)."""
    print("\n🚀 Starting FastAPI server...")
    print("📱 API available at: http://localhost:8000")
    print("📚 API docs at: http://localhost:8000/docs")
    subprocess.run([sys.executable, "api.py"])
56
+
57
+
58
def main():
    """Interactive launcher: verify setup, then start the chosen service."""
    print("=" * 50)
    print("RAG Capstone Project - Quick Start")
    print("=" * 50)

    # Install missing core packages before anything else.
    if not check_dependencies():
        print("\n📦 Installing dependencies...")
        install_dependencies()

    # A usable .env is required to proceed.
    if not check_env_file():
        print("\n❌ Please configure your .env file before running the application.")
        return

    # Ask user what to run
    print("\nWhat would you like to run?")
    print("1. Streamlit Chat Interface (recommended)")
    print("2. FastAPI Backend")
    print("3. Both (requires two terminals)")

    choice = input("\nEnter your choice (1-3): ").strip()

    launchers = {"1": run_streamlit, "2": run_api}
    if choice in launchers:
        launchers[choice]()
    elif choice == "3":
        print("\n📌 To run both:")
        print("Terminal 1: python api.py")
        print("Terminal 2: streamlit run streamlit_app.py")
        print("\nStarting Streamlit in this terminal...")
        run_streamlit()
    else:
        print("Invalid choice. Exiting.")


if __name__ == "__main__":
    main()
streamlit_app.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit chat interface for RAG application."""
2
+ import streamlit as st
3
+ import sys
4
+ import os
5
+ from datetime import datetime
6
+ import json
7
+ import pandas as pd
8
+ from typing import Optional
9
+ import warnings
10
+
11
+ # Suppress warnings
12
+ warnings.filterwarnings('ignore')
13
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
14
+
15
+ # Add parent directory to path
16
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
17
+
18
+ from config import settings
19
+ from dataset_loader import RAGBenchLoader
20
+ from vector_store import ChromaDBManager
21
+ from llm_client import GroqLLMClient, RAGPipeline
22
+ from trace_evaluator import TRACEEvaluator
23
+ from embedding_models import EmbeddingFactory
24
+ from chunking_strategies import ChunkingFactory
25
+
26
+
27
+ # Page configuration
28
+ st.set_page_config(
29
+ page_title="RAG Capstone Project",
30
+ page_icon="🤖",
31
+ layout="wide"
32
+ )
33
+
34
+ # Initialize session state
35
+ if "chat_history" not in st.session_state:
36
+ st.session_state.chat_history = []
37
+
38
+ if "rag_pipeline" not in st.session_state:
39
+ st.session_state.rag_pipeline = None
40
+
41
+ if "vector_store" not in st.session_state:
42
+ st.session_state.vector_store = None
43
+
44
+ if "collection_loaded" not in st.session_state:
45
+ st.session_state.collection_loaded = False
46
+
47
+ if "evaluation_results" not in st.session_state:
48
+ st.session_state.evaluation_results = None
49
+
50
+ if "dataset_size" not in st.session_state:
51
+ st.session_state.dataset_size = 10000
52
+
53
+ if "current_dataset" not in st.session_state:
54
+ st.session_state.current_dataset = None
55
+
56
+ if "current_llm" not in st.session_state:
57
+ st.session_state.current_llm = settings.llm_models[1]
58
+
59
+ if "selected_collection" not in st.session_state:
60
+ st.session_state.selected_collection = None
61
+
62
+ if "available_collections" not in st.session_state:
63
+ st.session_state.available_collections = []
64
+
65
+
66
def get_available_collections():
    """Return the names of all collections persisted in ChromaDB.

    Any failure (e.g. an unreadable persist directory) is logged to stdout
    and reported as an empty list so the UI can still render.
    """
    try:
        store = ChromaDBManager(settings.chroma_persist_directory)
        return store.list_collections()
    except Exception as e:
        print(f"Error getting collections: {e}")
        return []
75
+
76
+
77
def main():
    """Main Streamlit application.

    Renders the sidebar configuration panel (API key, existing-collection
    loader, and new-collection builder) and, once a collection is active,
    the Chat / Evaluation / History tabs.
    """
    st.title("🤖 RAG Capstone Project")
    st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")

    # Get available collections at startup (refreshed on every script rerun)
    available_collections = get_available_collections()
    st.session_state.available_collections = available_collections

    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")

        # API Key input; pre-filled from settings if configured via environment
        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            value=settings.groq_api_key or "",
            help="Enter your Groq API key"
        )

        st.divider()

        # Option 1: Use existing collection (only shown when any exist)
        if available_collections:
            st.subheader("📚 Existing Collections")
            st.write(f"Found {len(available_collections)} collection(s)")

            selected_collection = st.selectbox(
                "Or select existing collection:",
                available_collections,
                key="collection_selector"
            )

            if st.button("📖 Load Existing Collection", type="secondary"):
                if not groq_api_key:
                    st.error("Please enter your Groq API key")
                else:
                    load_existing_collection(groq_api_key, selected_collection)

            st.divider()

        # Option 2: Create new collection
        st.subheader("🆕 Create New Collection")

        # Dataset selection
        st.subheader("1. Dataset Selection")
        dataset_name = st.selectbox(
            "Choose Dataset",
            settings.ragbench_datasets,
            index=0
        )

        # Get dataset size dynamically so the sample slider has a real upper bound
        if st.button("🔍 Check Dataset Size", key="check_size"):
            with st.spinner("Checking dataset size..."):
                try:
                    from datetime import datetime  # noqa: F401 (kept local per original)
                    from datasets import load_dataset
                    import os

                    # Load dataset with download_mode to avoid cache issues
                    st.info(f"Fetching dataset info for '{dataset_name}'...")
                    ds = load_dataset(
                        "rungalileo/ragbench",
                        dataset_name,
                        split="train",
                        trust_remote_code=True,
                        download_mode="force_redownload"  # Force fresh download to avoid cache corruption
                    )
                    dataset_size = len(ds)

                    st.session_state.dataset_size = dataset_size
                    st.session_state.current_dataset = dataset_name
                    st.success(f"✅ Dataset '{dataset_name}' has {dataset_size:,} samples available")
                except Exception as e:
                    # Fall back to a conservative default so the UI stays usable
                    st.error(f"❌ Error: {str(e)}")
                    st.exception(e)
                    st.warning(f"Could not determine dataset size. Using default of 10,000.")
                    st.session_state.dataset_size = 10000
                    st.session_state.current_dataset = dataset_name

        # Use stored dataset size or default
        max_samples_available = st.session_state.get('dataset_size', 10000)

        st.caption(f"Max available samples: {max_samples_available:,}")

        num_samples = st.slider(
            "Number of samples",
            min_value=10,
            max_value=max_samples_available,
            value=min(100, max_samples_available),
            step=50 if max_samples_available > 1000 else 10,
            help="Adjust slider to select number of samples"
        )

        load_all_samples = st.checkbox(
            "Load all available samples",
            value=False,
            help="Override slider and load entire dataset"
        )

        st.divider()

        # Chunking strategy
        st.subheader("2. Chunking Strategy")
        chunking_strategy = st.selectbox(
            "Choose Chunking Strategy",
            settings.chunking_strategies,
            index=0
        )

        chunk_size = st.slider(
            "Chunk Size",
            min_value=256,
            max_value=1024,
            value=512,
            step=128
        )

        overlap = st.slider(
            "Overlap",
            min_value=0,
            max_value=200,
            value=50,
            step=10
        )

        st.divider()

        # Embedding model
        st.subheader("3. Embedding Model")
        embedding_model = st.selectbox(
            "Choose Embedding Model",
            settings.embedding_models,
            index=0
        )

        st.divider()

        # LLM model selection for new collection
        st.subheader("4. LLM Model")
        llm_model = st.selectbox(
            "Choose LLM",
            settings.llm_models,
            index=1
        )

        st.divider()

        # Load data button
        if st.button("🚀 Load Data & Create Collection", type="primary"):
            if not groq_api_key:
                st.error("Please enter your Groq API key")
            else:
                # Use None for num_samples if loading all data
                samples_to_load = None if load_all_samples else num_samples
                load_and_create_collection(
                    groq_api_key,
                    dataset_name,
                    samples_to_load,
                    chunking_strategy,
                    chunk_size,
                    overlap,
                    embedding_model,
                    llm_model
                )

    # Main content area: onboarding help until a collection is active
    if not st.session_state.collection_loaded:
        st.info("👈 Please configure and load a dataset from the sidebar to begin")

        # Show instructions
        with st.expander("📖 How to Use", expanded=True):
            st.markdown("""
            1. **Enter your Groq API Key** in the sidebar
            2. **Select a dataset** from RAG Bench
            3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
            4. **Select an embedding model** for document vectorization
            5. **Choose an LLM model** for response generation
            6. **Click "Load Data & Create Collection"** to initialize
            7. **Start chatting** in the chat interface
            8. **View retrieved documents** and evaluation metrics
            9. **Run TRACE evaluation** on test data
            """)

        # Show available options
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("📊 Available Datasets")
            for ds in settings.ragbench_datasets:
                st.markdown(f"- {ds}")

        with col2:
            st.subheader("🤖 Available Models")
            st.markdown("**Embedding Models:**")
            for em in settings.embedding_models:
                st.markdown(f"- {em}")

            st.markdown("**LLM Models:**")
            for lm in settings.llm_models:
                st.markdown(f"- {lm}")

    else:
        # Create tabs for different functionalities
        tab1, tab2, tab3 = st.tabs(["💬 Chat", "📊 Evaluation", "📜 History"])

        with tab1:
            chat_interface()

        with tab2:
            evaluation_interface()

        with tab3:
            history_interface()
292
+
293
+
294
def load_existing_collection(api_key: str, collection_name: str):
    """Load an existing collection from ChromaDB and wire up a RAG pipeline.

    Args:
        api_key: Groq API key used to build the LLM client.
        collection_name: Name of an existing ChromaDB collection.

    On success, populates session state (vector_store, rag_pipeline,
    collection_loaded, current_collection, groq_api_key) and reruns the app.
    """
    with st.spinner(f"Loading collection '{collection_name}'..."):
        try:
            # Initialize vector store and get collection
            vector_store = ChromaDBManager(settings.chroma_persist_directory)
            vector_store.get_collection(collection_name)

            # Prompt for LLM selection
            # NOTE(review): this selectbox is rendered inside a button-triggered
            # code path and the function ends with st.rerun(), so the user never
            # gets a chance to interact with it — the returned value is always
            # the widget's default (first model). Confirm whether the selector
            # should instead live in the sidebar before the load button.
            st.session_state.current_llm = st.selectbox(
                "Select LLM for this collection:",
                settings.llm_models,
                key=f"llm_selector_{collection_name}"
            )

            # Initialize LLM client
            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=st.session_state.current_llm,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )

            # Create RAG pipeline with correct parameter names
            st.info("Creating RAG pipeline...")
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state so the chat/eval tabs can use the pipeline
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            # Keep the API key around so the chat tab can rebuild the client
            # when the user switches LLMs mid-session.
            st.session_state.groq_api_key = api_key

            st.success(f"✅ Collection '{collection_name}' loaded successfully!")
            st.rerun()

        except Exception as e:
            st.error(f"Error loading collection: {str(e)}")
            st.exception(e)
339
+
340
+
341
def load_and_create_collection(
    api_key: str,
    dataset_name: str,
    num_samples: Optional[int],
    chunking_strategy: str,
    chunk_size: int,
    overlap: int,
    embedding_model: str,
    llm_model: str
):
    """Load a RAGBench dataset and build a new ChromaDB collection from it.

    Args:
        api_key: Groq API key used to build the LLM client.
        dataset_name: RAGBench subset name (e.g. "covidqa").
        num_samples: Number of samples to load, or None for the full split.
        chunking_strategy: Name of the chunking strategy to apply.
        chunk_size: Target chunk size in characters/tokens.
        overlap: Overlap between consecutive chunks.
        embedding_model: Embedding model identifier.
        llm_model: Groq LLM model name for the RAG pipeline.

    On success, populates session state and reruns the app so the tabs appear.
    """
    with st.spinner("Loading dataset and creating collection..."):
        try:
            # Initialize dataset loader
            loader = RAGBenchLoader()

            # BUG FIX: the dataset was previously loaded twice in a row
            # (two identical load_dataset calls), doubling download/parse time.
            # Load it exactly once.
            if num_samples is None:
                st.info(f"Loading {dataset_name} dataset (all available samples)...")
            else:
                st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
            dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)

            if not dataset:
                st.error("Failed to load dataset")
                return

            # Initialize vector store
            st.info("Initializing vector store...")
            vector_store = ChromaDBManager(settings.chroma_persist_directory)

            # Derive a ChromaDB-safe collection name from the configuration
            collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
            collection_name = collection_name.replace("-", "_").replace(".", "_")

            # Delete existing collection with same name (if exists) so stale
            # chunks from a previous run don't mix with the new ones
            existing_collections = vector_store.list_collections()
            if collection_name in existing_collections:
                st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
                vector_store.delete_collection(collection_name)
                st.info("Old collection deleted. Creating new one...")

            # Chunk, embed, and load data into the collection
            st.info(f"Creating collection with {chunking_strategy} chunking...")
            vector_store.load_dataset_into_collection(
                collection_name=collection_name,
                embedding_model_name=embedding_model,
                chunking_strategy=chunking_strategy,
                dataset_data=dataset,
                chunk_size=chunk_size,
                overlap=overlap
            )

            # Initialize LLM client
            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=llm_model,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )

            # Create RAG pipeline with correct parameter names
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.dataset_name = dataset_name
            st.session_state.dataset = dataset
            # CONSISTENCY FIX: mirror load_existing_collection() — the chat tab's
            # LLM switcher and run_evaluation() read these keys, which this path
            # previously never set (leaving the API key empty after a new-collection
            # load and the LLM selector out of sync).
            st.session_state.groq_api_key = api_key
            st.session_state.current_llm = llm_model

            st.success(f"✅ Collection '{collection_name}' created successfully!")
            st.rerun()

        except Exception as e:
            st.error(f"Error: {str(e)}")
            st.exception(e)
424
+
425
+
426
def chat_interface():
    """Chat interface tab.

    Renders prior conversation turns, a per-chat LLM switcher, and a chat
    input that runs the RAG pipeline and appends the result to history.
    """
    st.subheader("💬 Chat Interface")

    # Check if collection is loaded
    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
        st.info("""
        Steps:
        1. Select a dataset from the dropdown
        2. Click "Load Data & Create Collection" button
        3. Wait for the collection to be created
        4. Then you can start chatting
        """)
        return

    # Display collection info and LLM selector
    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        st.info(f"📚 Collection: {st.session_state.current_collection}")

    with col2:
        # LLM selector for chat
        selected_llm = st.selectbox(
            "Select LLM for chat:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="chat_llm_selector"
        )

        if selected_llm != st.session_state.current_llm:
            st.session_state.current_llm = selected_llm
            # Recreate RAG pipeline's client with the newly selected LLM.
            # NOTE(review): groq_api_key is only set in session state by the
            # collection-loading helpers; if absent, an empty key is used and
            # subsequent queries will fail to authenticate — confirm.
            llm_client = GroqLLMClient(
                api_key=st.session_state.groq_api_key if "groq_api_key" in st.session_state else "",
                model_name=selected_llm,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )
            st.session_state.rag_pipeline.llm_client = llm_client

    with col3:
        if st.button("🗑️ Clear History"):
            st.session_state.chat_history = []
            st.session_state.rag_pipeline.clear_history()
            st.rerun()

    # Chat container
    chat_container = st.container()

    # Display chat history (one user + assistant pair per stored entry)
    with chat_container:
        for chat_idx, entry in enumerate(st.session_state.chat_history):
            # User message
            with st.chat_message("user"):
                st.write(entry["query"])

            # Assistant message
            with st.chat_message("assistant"):
                st.write(entry["response"])

                # Show retrieved documents in expander
                with st.expander("📄 Retrieved Documents"):
                    for doc_idx, doc in enumerate(entry["retrieved_documents"]):
                        # NOTE(review): applying :.4f to doc.get('distance', 'N/A')
                        # raises if 'distance' is missing (can't float-format 'N/A')
                        # — confirm retrieval results always include a distance.
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})")
                        st.text_area(
                            f"doc_{chat_idx}_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_area_{chat_idx}_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")

    # Chat input
    query = st.chat_input("Ask a question...")

    if query:
        # Check if collection exists before querying.
        # NOTE(review): this reads rag_pipeline.vector_store, while the pipeline
        # is constructed with vector_store_manager= — verify the attribute name
        # matches RAGPipeline's implementation.
        if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
            st.error("❌ No data loaded. Please load a dataset first using the configuration panel.")
            st.stop()

        # Add user message
        with chat_container:
            with st.chat_message("user"):
                st.write(query)

        # Generate response
        with st.spinner("Generating response..."):
            try:
                result = st.session_state.rag_pipeline.query(query)
            except Exception as e:
                st.error(f"❌ Error querying: {str(e)}")
                st.info("Please load a dataset and create a collection first.")
                st.stop()

        # Add assistant message
        with chat_container:
            with st.chat_message("assistant"):
                st.write(result["response"])

                # Show retrieved documents
                with st.expander("📄 Retrieved Documents"):
                    for doc_idx, doc in enumerate(result["retrieved_documents"]):
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})")
                        st.text_area(
                            f"doc_current_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_current_area_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")

        # Store in history; st.rerun() re-renders the turn from chat_history
        st.session_state.chat_history.append(result)
        st.rerun()
546
+
547
+
548
def evaluation_interface():
    """Evaluation interface tab.

    Lets the user pick an LLM and a sample count, run TRACE evaluation, and
    view/download the aggregated and per-sample scores.
    """
    st.subheader("📊 TRACE Evaluation")

    # Check if collection is loaded
    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please load a collection first.")
        return

    # LLM selector for evaluation (col2 is intentionally left empty spacing)
    col1, col2 = st.columns([3, 1])
    with col1:
        selected_llm = st.selectbox(
            "Select LLM for evaluation:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="eval_llm_selector"
        )

    st.markdown("""
    Run TRACE evaluation metrics on test data:
    - **Utilization**: How well the system uses retrieved documents
    - **Relevance**: Relevance of retrieved documents to the query
    - **Adherence**: How well the response adheres to the retrieved context
    - **Completeness**: How complete the response is in answering the query
    """)

    num_test_samples = st.slider(
        "Number of test samples",
        min_value=5,
        max_value=50,
        value=10,
        step=5
    )

    if st.button("🔬 Run Evaluation", type="primary"):
        # Use selected LLM for evaluation
        run_evaluation(num_test_samples, selected_llm)

    # Display results from the most recent evaluation run (persisted in session)
    if st.session_state.evaluation_results:
        results = st.session_state.evaluation_results

        st.success("✅ Evaluation Complete!")

        # Display aggregate scores
        col1, col2, col3, col4, col5 = st.columns(5)

        with col1:
            st.metric("📊 Utilization", f"{results['utilization']:.3f}")
        with col2:
            st.metric("🎯 Relevance", f"{results['relevance']:.3f}")
        with col3:
            st.metric("✅ Adherence", f"{results['adherence']:.3f}")
        with col4:
            st.metric("📝 Completeness", f"{results['completeness']:.3f}")
        with col5:
            st.metric("⭐ Average", f"{results['average']:.3f}")

        # Detailed per-sample results
        with st.expander("📋 Detailed Results"):
            df = pd.DataFrame(results["individual_scores"])
            st.dataframe(df, use_container_width=True)

        # Download results
        # NOTE(review): json.dumps requires all score values to be plain Python
        # floats — confirm the evaluator does not emit numpy scalars here.
        results_json = json.dumps(results, indent=2)
        st.download_button(
            label="💾 Download Results (JSON)",
            data=results_json,
            file_name=f"trace_evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )
620
+
621
+
622
def run_evaluation(num_samples: int, selected_llm: str = None):
    """Run TRACE evaluation over test samples from the current dataset.

    Args:
        num_samples: Requested number of test samples to evaluate.
        selected_llm: Optional model name; when it differs from the session's
            current LLM, a temporary client is swapped in for the duration of
            the run and restored afterwards.

    Stores the aggregated results in st.session_state.evaluation_results.
    """
    with st.spinner(f"Running evaluation on {num_samples} samples..."):
        # Only set when we temporarily swap LLM clients; checked in `finally`.
        original_llm = None
        try:
            # Use selected LLM if provided
            if selected_llm and selected_llm != st.session_state.current_llm:
                st.info(f"Switching to {selected_llm} for evaluation...")
                groq_api_key = st.session_state.groq_api_key if "groq_api_key" in st.session_state else ""
                eval_llm_client = GroqLLMClient(
                    api_key=groq_api_key,
                    model_name=selected_llm,
                    max_rpm=settings.groq_rpm_limit,
                    rate_limit_delay=settings.rate_limit_delay
                )
                # Temporarily replace LLM client
                original_llm = st.session_state.rag_pipeline.llm_client
                st.session_state.rag_pipeline.llm_client = eval_llm_client

            # FIX: dataset_name is only set when a collection is created in this
            # session; guard instead of raising AttributeError/KeyError when the
            # user loaded a pre-existing collection.
            dataset_name = st.session_state.get("dataset_name")
            if not dataset_name:
                st.error("No dataset is associated with this session. Create a collection via the sidebar to enable evaluation.")
                return

            # Get test data
            loader = RAGBenchLoader()
            test_data = loader.get_test_data(dataset_name, num_samples)

            if not test_data:
                st.warning("No test data available for evaluation.")
                return

            # Prepare test cases by querying the RAG system for each sample
            test_cases = []
            progress_bar = st.progress(0)

            for i, sample in enumerate(test_data):
                # Query the RAG system
                result = st.session_state.rag_pipeline.query(
                    sample["question"],
                    n_results=5
                )

                # Prepare test case
                test_cases.append({
                    "query": sample["question"],
                    "response": result["response"],
                    "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                    "ground_truth": sample.get("answer", "")
                })

                # BUG FIX: divide by the actual number of returned test samples,
                # not the requested count — the loader may return fewer, which
                # previously left the bar short of 100%.
                progress_bar.progress((i + 1) / len(test_data))

            # Run evaluation
            evaluator = TRACEEvaluator()
            results = evaluator.evaluate_batch(test_cases)

            st.session_state.evaluation_results = results

        except Exception as e:
            st.error(f"Error during evaluation: {str(e)}")
        finally:
            # BUG FIX: restore the original LLM client even when evaluation
            # fails mid-run; previously an exception left the swapped client
            # in place for subsequent chat queries.
            if original_llm is not None:
                st.session_state.rag_pipeline.llm_client = original_llm
682
+
683
+
684
def history_interface():
    """History interface tab.

    Shows each stored conversation in an expander and offers a JSON export
    of the full chat history.
    """
    st.subheader("📜 Chat History")

    history = st.session_state.chat_history
    if not history:
        st.info("No chat history yet. Start a conversation in the Chat tab!")
        return

    # Export button lives in the narrow right-hand column
    _, export_col = st.columns([3, 1])
    with export_col:
        st.download_button(
            label="💾 Export History",
            data=json.dumps(history, indent=2),
            file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

    # One collapsible panel per stored conversation turn
    for turn_idx, record in enumerate(history):
        with st.expander(f"💬 Conversation {turn_idx+1}: {record['query'][:50]}..."):
            st.markdown(f"**Query:** {record['query']}")
            st.markdown(f"**Response:** {record['response']}")
            st.markdown(f"**Timestamp:** {record.get('timestamp', 'N/A')}")

            st.markdown("**Retrieved Documents:**")
            for doc_pos, retrieved in enumerate(record["retrieved_documents"]):
                st.text_area(
                    f"Document {doc_pos+1}",
                    value=retrieved["document"],
                    height=100,
                    key=f"history_doc_{turn_idx}_{doc_pos}"
                )
718
+
719
+
720
+ if __name__ == "__main__":
721
+ main()
trace_evaluator.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TRACE evaluation metrics for RAG systems.
2
+
3
+ TRACE Metrics:
4
+ - uTilization: How well the system uses retrieved documents
5
+ - Relevance: Relevance of retrieved documents to the query
6
+ - Adherence: How well the response adheres to the retrieved context
7
+ - Completeness: How complete the response is in answering the query
8
+ """
9
+ from typing import List, Dict, Optional
10
+ import numpy as np
11
+ from dataclasses import dataclass
12
+ import re
13
+ from collections import Counter
14
+
15
+
16
@dataclass
class TRACEScores:
    """Container for TRACE evaluation scores.

    Attributes:
        utilization: How well the system uses retrieved documents (0-1).
        relevance: Relevance of retrieved documents to the query (0-1).
        adherence: How grounded the response is in the context (0-1).
        completeness: How fully the response answers the query (0-1).
    """
    utilization: float
    relevance: float
    adherence: float
    completeness: float

    def average(self) -> float:
        """Return the mean of the four metric scores."""
        parts = (self.utilization, self.relevance, self.adherence, self.completeness)
        return sum(parts) / 4

    def to_dict(self) -> Dict:
        """Serialize the scores (plus their mean) into a plain dictionary."""
        return {
            "utilization": self.utilization,
            "relevance": self.relevance,
            "adherence": self.adherence,
            "completeness": self.completeness,
            "average": self.average(),
        }
38
+
39
+
40
class TRACEEvaluator:
    """TRACE evaluation metrics for RAG systems.

    All metric methods are heuristic (lexical-overlap based) and return plain
    Python floats in [0, 1] so results stay JSON-serializable.
    """

    def __init__(self, llm_client=None):
        """Initialize TRACE evaluator.

        Args:
            llm_client: Optional LLM client for LLM-based evaluation
                (currently unused by the heuristic metrics).
        """
        self.llm_client = llm_client

    def evaluate(
        self,
        query: str,
        response: str,
        retrieved_documents: List[str],
        ground_truth: Optional[str] = None
    ) -> TRACEScores:
        """Evaluate a RAG response using TRACE metrics.

        Args:
            query: User query
            response: Generated response
            retrieved_documents: List of retrieved documents
            ground_truth: Optional ground truth answer

        Returns:
            TRACEScores object
        """
        utilization = self._compute_utilization(response, retrieved_documents)
        relevance = self._compute_relevance(query, retrieved_documents)
        adherence = self._compute_adherence(response, retrieved_documents)
        completeness = self._compute_completeness(query, response, ground_truth)

        return TRACEScores(
            utilization=utilization,
            relevance=relevance,
            adherence=adherence,
            completeness=completeness
        )

    def _compute_utilization(
        self,
        response: str,
        retrieved_documents: List[str]
    ) -> float:
        """Compute utilization score.

        Measures how well the system uses retrieved documents, combining the
        proportion of documents with significant word overlap (weight 0.6)
        and the average depth of that overlap (weight 0.4).

        Args:
            response: Generated response
            retrieved_documents: List of retrieved documents

        Returns:
            Utilization score (0-1)
        """
        if not retrieved_documents or not response:
            return 0.0

        response_words = set(self._tokenize(response.lower()))

        # Count how many documents contributed and how strongly
        docs_used = 0
        total_overlap = 0

        for doc in retrieved_documents:
            doc_words = set(self._tokenize(doc.lower()))

            overlap = len(response_words & doc_words)
            if overlap > 5:  # Threshold for significant contribution
                docs_used += 1
                total_overlap += overlap

        # Score based on proportion of documents used
        proportion_used = docs_used / len(retrieved_documents)

        # Also consider depth of utilization (normalized to cap at 1.0)
        avg_overlap = total_overlap / len(retrieved_documents)
        depth_score = min(avg_overlap / 20, 1.0)

        utilization_score = 0.6 * proportion_used + 0.4 * depth_score

        return min(utilization_score, 1.0)

    def _compute_relevance(
        self,
        query: str,
        retrieved_documents: List[str]
    ) -> float:
        """Compute relevance score.

        Measures relevance of retrieved documents to the query using a 50/50
        blend of lexical word overlap and keyword matching, averaged across
        documents.

        Args:
            query: User query
            retrieved_documents: List of retrieved documents

        Returns:
            Relevance score (0-1) as a plain float
        """
        if not retrieved_documents or not query:
            return 0.0

        query_lower = query.lower()
        query_words = set(self._tokenize(query_lower))
        query_keywords = self._extract_keywords(query_lower)

        relevance_scores = []

        for doc in retrieved_documents:
            doc_lower = doc.lower()
            doc_words = set(self._tokenize(doc_lower))

            # Lexical overlap
            overlap = len(query_words & doc_words)
            overlap_score = overlap / len(query_words) if query_words else 0

            # Keyword matching (substring containment)
            keyword_matches = sum(1 for kw in query_keywords if kw in doc_lower)
            keyword_score = keyword_matches / len(query_keywords) if query_keywords else 0

            relevance_scores.append(0.5 * overlap_score + 0.5 * keyword_score)

        # BUG FIX: cast to a plain float — np.mean returns numpy.float64,
        # which leaks into to_dict()/evaluate_batch() output and breaks
        # json.dumps on the results.
        return float(np.mean(relevance_scores))

    def _compute_adherence(
        self,
        response: str,
        retrieved_documents: List[str]
    ) -> float:
        """Compute adherence score.

        Measures how well the response adheres to the retrieved context:
        per response sentence, the fraction of its words that also occur in
        the combined documents, averaged across sentences.

        Args:
            response: Generated response
            retrieved_documents: List of retrieved documents

        Returns:
            Adherence score (0-1) as a plain float
        """
        if not retrieved_documents or not response:
            return 0.0

        # Combine all documents into one vocabulary
        combined_docs = " ".join(retrieved_documents).lower()
        doc_words = set(self._tokenize(combined_docs))

        response_sentences = self._split_sentences(response.lower())

        adherence_scores = []

        for sentence in response_sentences:
            sentence_words = set(self._tokenize(sentence))

            # Proportion of sentence words grounded in the documents
            if sentence_words:
                grounded_words = len(sentence_words & doc_words)
                adherence_scores.append(grounded_words / len(sentence_words))

        # BUG FIX: plain float (see _compute_relevance)
        return float(np.mean(adherence_scores)) if adherence_scores else 0.0

    def _compute_completeness(
        self,
        query: str,
        response: str,
        ground_truth: Optional[str] = None
    ) -> float:
        """Compute completeness score.

        Measures how complete the response is: a length factor, bonuses when
        the response type matches the question word (when/where/who), and —
        if available — word overlap with the ground-truth answer.

        Args:
            query: User query
            response: Generated response
            ground_truth: Optional ground truth answer

        Returns:
            Completeness score (0-1) as a plain float
        """
        if not response or not query:
            return 0.0

        query_lower = query.lower()

        # Question-type detection (is_what/is_why/is_how computed for parity
        # with the heuristics even though only when/where/who carry bonuses)
        is_what = any(w in query_lower for w in ["what", "which"])
        is_when = "when" in query_lower
        is_where = "where" in query_lower
        is_who = "who" in query_lower
        is_why = "why" in query_lower
        is_how = "how" in query_lower

        response_lower = response.lower()

        completeness_factors = []

        # Length check (not too short); saturates at min_length characters
        min_length = 50
        length_score = min(len(response) / min_length, 1.0)
        completeness_factors.append(length_score)

        # Check for appropriate response type
        if is_when and any(w in response_lower for w in ["year", "date", "time", "century"]):
            completeness_factors.append(1.0)
        elif is_where and any(w in response_lower for w in ["location", "place", "country", "city"]):
            completeness_factors.append(1.0)
        elif is_who and any(w in response_lower for w in ["person", "people", "name"]):
            completeness_factors.append(1.0)

        # If ground truth available, compare word overlap
        if ground_truth:
            gt_words = set(self._tokenize(ground_truth.lower()))
            response_words = set(self._tokenize(response_lower))

            overlap = len(gt_words & response_words)
            gt_score = overlap / len(gt_words) if gt_words else 0
            completeness_factors.append(gt_score)

        # BUG FIX: plain float (see _compute_relevance)
        return float(np.mean(completeness_factors)) if completeness_factors else 0.5

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text into words, dropping punctuation, very short words,
        and a small stop-word list."""
        text = re.sub(r'[^\w\s]', ' ', text)
        words = text.split()
        stop_words = {"a", "an", "the", "is", "are", "was", "were", "in", "on", "at", "to", "for"}
        return [w for w in words if len(w) > 2 and w not in stop_words]

    def _extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from text.

        Simple frequency-based extraction over the tokenized words; in
        production, use TF-IDF or similar.
        """
        word_freq = Counter(self._tokenize(text))
        return list(word_freq.keys())

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into non-empty sentences on ./!/? boundaries."""
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]

    def evaluate_batch(
        self,
        test_data: List[Dict]
    ) -> Dict:
        """Evaluate multiple test cases.

        Args:
            test_data: List of test cases, each containing:
                - query: User query
                - response: Generated response
                - retrieved_documents: Retrieved documents
                - ground_truth: Ground truth answer (optional)

        Returns:
            Dictionary with aggregated scores (all plain floats, so the
            result is directly JSON-serializable).
        """
        # BUG FIX: guard the empty batch — np.mean([]) yields NaN with a
        # RuntimeWarning, which previously propagated into the results dict.
        if not test_data:
            return {
                "utilization": 0.0,
                "relevance": 0.0,
                "adherence": 0.0,
                "completeness": 0.0,
                "average": 0.0,
                "num_samples": 0,
                "individual_scores": []
            }

        all_scores = []

        for i, test_case in enumerate(test_data):
            print(f"Evaluating test case {i+1}/{len(test_data)}")

            scores = self.evaluate(
                query=test_case.get("query", ""),
                response=test_case.get("response", ""),
                retrieved_documents=test_case.get("retrieved_documents", []),
                ground_truth=test_case.get("ground_truth")
            )

            all_scores.append(scores)

        # Aggregate per-metric means across the batch
        avg_utilization = np.mean([s.utilization for s in all_scores])
        avg_relevance = np.mean([s.relevance for s in all_scores])
        avg_adherence = np.mean([s.adherence for s in all_scores])
        avg_completeness = np.mean([s.completeness for s in all_scores])

        return {
            "utilization": float(avg_utilization),
            "relevance": float(avg_relevance),
            "adherence": float(avg_adherence),
            "completeness": float(avg_completeness),
            "average": float((avg_utilization + avg_relevance +
                              avg_adherence + avg_completeness) / 4),
            "num_samples": len(test_data),
            "individual_scores": [s.to_dict() for s in all_scores]
        }
vector_store.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ChromaDB integration for vector storage and retrieval."""
2
+ from typing import List, Dict, Optional, Tuple
3
+ import chromadb
4
+ from chromadb.config import Settings
5
+ import uuid
6
+ import os
7
+ from embedding_models import EmbeddingFactory, EmbeddingModel
8
+ from chunking_strategies import ChunkingFactory
9
+ import json
10
+
11
+
12
+ class ChromaDBManager:
13
+ """Manager for ChromaDB operations."""
14
+
15
    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize ChromaDB manager.

        Args:
            persist_directory: Directory to persist ChromaDB data
        """
        self.persist_directory = persist_directory
        # Ensure the storage directory exists before the client touches it.
        os.makedirs(persist_directory, exist_ok=True)

        # Initialize ChromaDB client with is_persistent=True to use persistent storage
        try:
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True  # Allow reset if needed
                )
            )
        except Exception as e:
            # NOTE(review): with recent chromadb versions, chromadb.Client may
            # be in-memory only despite persist_directory — confirm the
            # fallback still persists data on the installed version.
            print(f"Warning: Could not create persistent client: {e}")
            print("Falling back to regular client...")
            self.client = chromadb.Client(Settings(
                persist_directory=persist_directory,
                anonymized_telemetry=False,
                allow_reset=True
            ))

        # Both are populated lazily by create_collection()/get_collection().
        self.embedding_model = None
        self.current_collection = None
44
+
45
+ def reconnect(self):
46
+ """Reconnect to ChromaDB in case of connection loss."""
47
+ try:
48
+ self.client = chromadb.PersistentClient(
49
+ path=self.persist_directory,
50
+ settings=Settings(
51
+ anonymized_telemetry=False,
52
+ allow_reset=True
53
+ )
54
+ )
55
+ print("✅ Reconnected to ChromaDB")
56
+ except Exception as e:
57
+ print(f"Error reconnecting: {e}")
58
+
59
+ def create_collection(
60
+ self,
61
+ collection_name: str,
62
+ embedding_model_name: str,
63
+ metadata: Optional[Dict] = None
64
+ ) -> chromadb.Collection:
65
+ """Create a new collection.
66
+
67
+ Args:
68
+ collection_name: Name of the collection
69
+ embedding_model_name: Name of the embedding model
70
+ metadata: Additional metadata for the collection
71
+
72
+ Returns:
73
+ ChromaDB collection
74
+ """
75
+ # Delete if exists
76
+ try:
77
+ self.client.delete_collection(collection_name)
78
+ except:
79
+ pass
80
+
81
+ # Create embedding model
82
+ self.embedding_model = EmbeddingFactory.create_embedding_model(embedding_model_name)
83
+ self.embedding_model.load_model()
84
+
85
+ # Create collection with metadata
86
+ collection_metadata = {
87
+ "embedding_model": embedding_model_name,
88
+ "hnsw:space": "cosine"
89
+ }
90
+ if metadata:
91
+ collection_metadata.update(metadata)
92
+
93
+ self.current_collection = self.client.create_collection(
94
+ name=collection_name,
95
+ metadata=collection_metadata
96
+ )
97
+
98
+ print(f"Created collection: {collection_name}")
99
+ return self.current_collection
100
+
101
+ def get_collection(self, collection_name: str) -> chromadb.Collection:
102
+ """Get an existing collection.
103
+
104
+ Args:
105
+ collection_name: Name of the collection
106
+
107
+ Returns:
108
+ ChromaDB collection
109
+ """
110
+ self.current_collection = self.client.get_collection(collection_name)
111
+
112
+ # Load embedding model from metadata
113
+ metadata = self.current_collection.metadata
114
+ if "embedding_model" in metadata:
115
+ self.embedding_model = EmbeddingFactory.create_embedding_model(
116
+ metadata["embedding_model"]
117
+ )
118
+ self.embedding_model.load_model()
119
+
120
+ return self.current_collection
121
+
122
+ def list_collections(self) -> List[str]:
123
+ """List all collections.
124
+
125
+ Returns:
126
+ List of collection names
127
+ """
128
+ collections = self.client.list_collections()
129
+ return [col.name for col in collections]
130
+
131
+ def clear_all_collections(self) -> int:
132
+ """Delete all collections from the database.
133
+
134
+ Returns:
135
+ Number of collections deleted
136
+ """
137
+ collections = self.list_collections()
138
+ count = 0
139
+
140
+ for collection_name in collections:
141
+ try:
142
+ self.client.delete_collection(collection_name)
143
+ print(f"Deleted collection: {collection_name}")
144
+ count += 1
145
+ except Exception as e:
146
+ print(f"Error deleting collection {collection_name}: {e}")
147
+
148
+ self.current_collection = None
149
+ self.embedding_model = None
150
+ print(f"✅ Cleared {count} collections")
151
+ return count
152
+
153
    def delete_collection(self, collection_name: str) -> bool:
        """Delete a specific collection.

        NOTE(review): a second ``delete_collection`` is defined later in this
        class; in Python the later definition wins, so this version is dead
        code. The two should be consolidated — this one has the richer
        contract (boolean result plus cache invalidation).

        Args:
            collection_name: Name of the collection to delete

        Returns:
            True if deleted successfully, False otherwise
        """
        try:
            self.client.delete_collection(collection_name)
            # Drop stale references if the deleted collection was active.
            if self.current_collection and self.current_collection.name == collection_name:
                self.current_collection = None
                self.embedding_model = None
            print(f"✅ Deleted collection: {collection_name}")
            return True
        except Exception as e:
            print(f"❌ Error deleting collection: {e}")
            return False
172
+
173
+ def add_documents(
174
+ self,
175
+ documents: List[str],
176
+ metadatas: Optional[List[Dict]] = None,
177
+ ids: Optional[List[str]] = None,
178
+ batch_size: int = 100
179
+ ):
180
+ """Add documents to the current collection.
181
+
182
+ Args:
183
+ documents: List of document texts
184
+ metadatas: List of metadata dictionaries
185
+ ids: List of document IDs
186
+ batch_size: Batch size for processing
187
+ """
188
+ if not self.current_collection:
189
+ raise ValueError("No collection selected. Create or get a collection first.")
190
+
191
+ if not self.embedding_model:
192
+ raise ValueError("No embedding model loaded.")
193
+
194
+ # Generate IDs if not provided
195
+ if ids is None:
196
+ ids = [str(uuid.uuid4()) for _ in documents]
197
+
198
+ # Generate default metadata if not provided
199
+ if metadatas is None:
200
+ metadatas = [{"index": i} for i in range(len(documents))]
201
+
202
+ # Process in batches
203
+ total_docs = len(documents)
204
+ print(f"Adding {total_docs} documents to collection...")
205
+
206
+ for i in range(0, total_docs, batch_size):
207
+ batch_docs = documents[i:i + batch_size]
208
+ batch_ids = ids[i:i + batch_size]
209
+ batch_metadatas = metadatas[i:i + batch_size]
210
+
211
+ # Generate embeddings
212
+ embeddings = self.embedding_model.embed_documents(batch_docs)
213
+
214
+ # Add to collection
215
+ self.current_collection.add(
216
+ documents=batch_docs,
217
+ embeddings=embeddings.tolist(),
218
+ metadatas=batch_metadatas,
219
+ ids=batch_ids
220
+ )
221
+
222
+ print(f"Added batch {i//batch_size + 1}/{(total_docs-1)//batch_size + 1}")
223
+
224
+ print(f"Successfully added {total_docs} documents")
225
+
226
+ def load_dataset_into_collection(
227
+ self,
228
+ collection_name: str,
229
+ embedding_model_name: str,
230
+ chunking_strategy: str,
231
+ dataset_data: List[Dict],
232
+ chunk_size: int = 512,
233
+ overlap: int = 50
234
+ ):
235
+ """Load a dataset into a new collection with chunking.
236
+
237
+ Args:
238
+ collection_name: Name for the new collection
239
+ embedding_model_name: Embedding model to use
240
+ chunking_strategy: Chunking strategy to use
241
+ dataset_data: List of dataset samples
242
+ chunk_size: Size of chunks
243
+ overlap: Overlap between chunks
244
+ """
245
+ # Create collection
246
+ self.create_collection(
247
+ collection_name,
248
+ embedding_model_name,
249
+ metadata={
250
+ "chunking_strategy": chunking_strategy,
251
+ "chunk_size": chunk_size,
252
+ "overlap": overlap
253
+ }
254
+ )
255
+
256
+ # Get chunker
257
+ chunker = ChunkingFactory.create_chunker(chunking_strategy)
258
+
259
+ # Process documents
260
+ all_chunks = []
261
+ all_metadatas = []
262
+
263
+ print(f"Processing {len(dataset_data)} documents with {chunking_strategy} chunking...")
264
+
265
+ for idx, sample in enumerate(dataset_data):
266
+ # Use 'documents' list if available, otherwise fall back to 'context'
267
+ documents = sample.get("documents", [])
268
+
269
+ # If documents is empty, use context as fallback
270
+ if not documents:
271
+ context = sample.get("context", "")
272
+ if context:
273
+ documents = [context]
274
+
275
+ if not documents:
276
+ continue
277
+
278
+ # Process each document separately for better granularity
279
+ for doc_idx, document in enumerate(documents):
280
+ if not document or not str(document).strip():
281
+ continue
282
+
283
+ # Chunk each document
284
+ chunks = chunker.chunk_text(str(document), chunk_size, overlap)
285
+
286
+ # Create metadata for each chunk
287
+ for chunk_idx, chunk in enumerate(chunks):
288
+ all_chunks.append(chunk)
289
+ all_metadatas.append({
290
+ "doc_id": idx,
291
+ "doc_idx": doc_idx, # Track which document within the sample
292
+ "chunk_id": chunk_idx,
293
+ "question": sample.get("question", ""),
294
+ "answer": sample.get("answer", ""),
295
+ "dataset": sample.get("dataset", ""),
296
+ "total_docs": len(documents)
297
+ })
298
+
299
+ # Add all chunks to collection
300
+ self.add_documents(all_chunks, all_metadatas)
301
+
302
+ print(f"Loaded {len(all_chunks)} chunks from {len(dataset_data)} samples")
303
+
304
+ def query(
305
+ self,
306
+ query_text: str,
307
+ n_results: int = 5,
308
+ filter_metadata: Optional[Dict] = None
309
+ ) -> Dict:
310
+ """Query the collection.
311
+
312
+ Args:
313
+ query_text: Query text
314
+ n_results: Number of results to return
315
+ filter_metadata: Metadata filter
316
+
317
+ Returns:
318
+ Query results
319
+ """
320
+ if not self.current_collection:
321
+ raise ValueError("No collection selected.")
322
+
323
+ if not self.embedding_model:
324
+ raise ValueError("No embedding model loaded.")
325
+
326
+ # Generate query embedding
327
+ query_embedding = self.embedding_model.embed_query(query_text)
328
+
329
+ # Query collection with retry logic
330
+ try:
331
+ results = self.current_collection.query(
332
+ query_embeddings=[query_embedding.tolist()],
333
+ n_results=n_results,
334
+ where=filter_metadata
335
+ )
336
+ except Exception as e:
337
+ if "default_tenant" in str(e):
338
+ print("Warning: Lost connection to ChromaDB, reconnecting...")
339
+ self.reconnect()
340
+ # Try again after reconnecting
341
+ results = self.current_collection.query(
342
+ query_embeddings=[query_embedding.tolist()],
343
+ n_results=n_results,
344
+ where=filter_metadata
345
+ )
346
+ else:
347
+ raise
348
+
349
+ return results
350
+
351
+ def get_retrieved_documents(
352
+ self,
353
+ query_text: str,
354
+ n_results: int = 5
355
+ ) -> List[Dict]:
356
+ """Get retrieved documents with metadata.
357
+
358
+ Args:
359
+ query_text: Query text
360
+ n_results: Number of results
361
+
362
+ Returns:
363
+ List of retrieved documents with metadata
364
+ """
365
+ results = self.query(query_text, n_results)
366
+
367
+ retrieved_docs = []
368
+ for i in range(len(results['documents'][0])):
369
+ retrieved_docs.append({
370
+ "document": results['documents'][0][i],
371
+ "metadata": results['metadatas'][0][i],
372
+ "distance": results['distances'][0][i] if 'distances' in results else None
373
+ })
374
+
375
+ return retrieved_docs
376
+
377
+ def delete_collection(self, collection_name: str):
378
+ """Delete a collection.
379
+
380
+ Args:
381
+ collection_name: Name of collection to delete
382
+ """
383
+ try:
384
+ self.client.delete_collection(collection_name)
385
+ print(f"Deleted collection: {collection_name}")
386
+ except Exception as e:
387
+ print(f"Error deleting collection: {str(e)}")
388
+
389
+ def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict:
390
+ """Get statistics for a collection.
391
+
392
+ Args:
393
+ collection_name: Name of collection (uses current if None)
394
+
395
+ Returns:
396
+ Dictionary with collection statistics
397
+ """
398
+ if collection_name:
399
+ collection = self.client.get_collection(collection_name)
400
+ elif self.current_collection:
401
+ collection = self.current_collection
402
+ else:
403
+ raise ValueError("No collection specified or selected")
404
+
405
+ count = collection.count()
406
+ metadata = collection.metadata
407
+
408
+ return {
409
+ "name": collection.name,
410
+ "count": count,
411
+ "metadata": metadata
412
+ }