stanleydukor commited on
Commit
702ea87
·
1 Parent(s): 1c61c6e

Initial deployment

Browse files
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
 + Pipfile.lock
28
+
29
+ # PyInstaller
30
+ *.manifest
31
+ *.spec
32
+
33
+ # Installer logs
34
+ pip-log.txt
35
+ pip-delete-this-directory.txt
36
+
37
+ # Unit test / coverage reports
38
+ htmlcov/
39
+ .tox/
40
+ .nox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ *.py,cover
48
+ .hypothesis/
49
+ .pytest_cache/
50
+ cover/
51
+
52
+ # Translations
53
+ *.mo
54
+ *.pot
55
+
56
+ # Django stuff:
57
+ *.log
58
+ local_settings.py
59
+ db.sqlite3
60
+ db.sqlite3-journal
61
+
62
+ # Flask stuff:
63
+ instance/
64
+ .webassets-cache
65
+
66
+ # Scrapy stuff:
67
+ .scrapy
68
+
69
+ # Sphinx documentation
70
+ docs/_build/
71
+
72
+ # PyBuilder
73
+ .pybuilder/
74
+ target/
75
+
76
+ # Jupyter Notebook
77
+ .ipynb_checkpoints
78
+
79
+ # IPython
80
+ profile_default/
81
+ ipython_config.py
82
+
83
+ # pyenv
84
+ .python-version
85
+
86
+ # pipenv
87
+ Pipfile.lock
88
+
89
+ # poetry
90
+ poetry.lock
91
+
92
+ # pdm
93
+ .pdm.toml
94
+
95
+ # PEP 582
96
+ __pypackages__/
97
+
98
+ # Celery stuff
99
+ celerybeat-schedule
100
+ celerybeat.pid
101
+
102
+ # SageMath parsed files
103
+ *.sage.py
104
+
105
+ # Environments
106
+ .env
107
+ .venv
108
+ env/
109
+ venv/
110
+ ENV/
111
+ env.bak/
112
+ venv.bak/
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+
132
+ # pytype static type analyzer
133
+ .pytype/
134
+
135
+ # Cython debug symbols
136
+ cython_debug/
137
+
138
+ # IDE
139
+ .vscode/
140
+ .idea/
141
+ *.swp
142
+ *.swo
143
+ *~
144
+ .DS_Store
145
+
146
+ # Project specific
147
+ data/raw/*
148
+ !data/raw/.gitkeep
149
+ data/processed/*
150
+ !data/processed/.gitkeep
151
+ data/vectorstore/*
152
+ !data/vectorstore/.gitkeep
153
+
154
+ # Model files
155
+ *.bin
156
+ *.onnx
157
+ *.pt
158
+ *.pth
159
+
160
+ # Logs
161
+ logs/
162
+ *.log
CHANGELOG.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog
2
+
3
+ ## [2.0.0] - 2026-01-05
4
+
5
+ ### Major Improvements
6
+
7
+ #### Gradio UI Enhancements
8
+ - **Fixed HTML rendering issue**: Changed from HTML badges to clean emoji-based confidence indicators
9
+ - High Confidence: ✅ (≥70%)
10
+ - Medium Confidence: ⚠️ (50-69%)
11
+ - Low Confidence: ⚡ (<50%)
12
+ - **Improved message formatting**: Removed raw HTML display in chat interface
13
+ - **Cleaner disclaimers**: Updated medical disclaimer to be more concise
14
+
15
+ #### Content Updates
16
+ - **Removed "educational purposes" language** across all files:
17
+ - Updated system prompts
18
+ - Updated medical disclaimers
19
+ - Updated README
20
+ - Updated UI text
21
+ - **Streamlined medical disclaimers**: More professional, less verbose
22
+
23
+ #### Bug Fixes
24
+ - **Fixed Ollama GPU support**: Configured Ollama to use RTX 5090 GPU instead of CPU
25
+ - Added GPU initialization script
26
+ - Set proper CUDA environment variables
27
+ - Verified VRAM usage (4.79 GB on GPU)
28
+ - Performance improvement: ~10-50x faster inference
29
+
30
+ - **Fixed Qdrant API compatibility**: Updated to qdrant-client v1.16.1 API
31
+ - Changed from `client.search()` to `client.query_points()`
32
+ - Added `using="dense"` parameter for named vectors
33
+ - Fixed both search and hybrid_search methods
34
+
35
+ - **Fixed Pydantic validation errors**:
36
+ - Removed `ge=0.0` constraint from `RetrievalResult.score` (cross-encoder scores can be negative)
37
+ - Removed `ge=0.0, le=1.0` constraints from `SourceInfo.relevance_score`
38
+
39
+ - **Fixed QdrantStoreManager initialization**:
40
+ - Changed `vector_size` → `embedding_dim`
41
+ - Changed `qdrant_path` → `path`
42
+ - Use `embedding_client.embedding_dim` instead of non-existent settings attribute
43
+
44
+ - **Added missing Settings attributes**:
45
+ - `ollama_timeout` (default: 30)
46
+ - `reranker_model` (default: "cross-encoder/ms-marco-MiniLM-L-6-v2")
47
+ - `max_context_tokens` (default: 4096)
48
+
49
+ - **Fixed OllamaClient embedding model verification**:
50
+ - Skip embedding model verification when `embedding_model=None`
51
+ - Prevents false errors when using SentenceTransformerClient for embeddings
52
+
53
+ #### Code Cleanup
54
+ - Removed unnecessary comments and annotations
55
+ - Cleaned up fix-related comments
56
+ - Improved code documentation
57
+ - Removed redundant validation constraints
58
+
59
+ ### Technical Details
60
+
61
+ #### Performance
62
+ - **GPU Acceleration**: Full GPU support for Ollama (RTX 5090)
63
+ - **Model Loading**: 4.79 GB VRAM usage confirmed
64
+ - **Faster Inference**: Significant speedup from CPU to GPU
65
+
66
+ #### API Changes
67
+ - Qdrant API updated to v1.16.1 syntax
68
+ - Improved error handling for cross-encoder scores
69
+ - Better validation for unbounded reranker scores
70
+
71
+ #### Configuration
72
+ - New environment variables for Ollama GPU support:
73
+ - `CUDA_VISIBLE_DEVICES=0`
74
+ - `OLLAMA_NUM_PARALLEL=1`
75
+ - `OLLAMA_MAX_LOADED_MODELS=1`
76
+
77
+ ### Files Modified
78
+ - `src/api/gradio_ui.py` - UI improvements and HTML rendering fix
79
+ - `src/api/main.py` - Fixed initialization parameters
80
+ - `src/rag/query_engine.py` - Updated disclaimers and validation
81
+ - `src/rag/retriever.py` - Removed score constraints
82
+ - `src/vectorstore/qdrant_store.py` - Updated Qdrant API calls
83
+ - `src/llm/ollama_client.py` - Fixed embedding model handling
84
+ - `config/settings.py` - Added missing configuration fields
85
+ - `prompts/medical_disclaimer.txt` - Removed educational language
86
+ - `prompts/system_prompt.txt` - Streamlined instructions
87
+ - `README.md` - Updated disclaimers and documentation
88
+
89
+ ### Breaking Changes
90
+ - None - all changes are backward compatible
91
+
92
+ ### Upgrade Notes
93
+ 1. Restart Ollama with GPU support using provided script
94
+ 2. Clear Python cache if experiencing import issues
95
+ 3. Verify GPU usage with `curl -s http://localhost:11434/api/ps | python3 -m json.tool`
DOCKER.md ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker Deployment Guide
2
+
3
+ Complete guide for deploying EyeWiki RAG using Docker.
4
+
5
+ ## 📋 Table of Contents
6
+ - [Prerequisites](#prerequisites)
7
+ - [Quick Start](#quick-start)
8
+ - [Architecture](#architecture)
9
+ - [Configuration](#configuration)
10
+ - [Operations](#operations)
11
+ - [Troubleshooting](#troubleshooting)
12
+ - [Production](#production)
13
+
14
+ ## Prerequisites
15
+
16
+ ### Required Software
17
+ - **Docker** 20.10+ ([Install Docker](https://docs.docker.com/get-docker/))
18
+ - **Docker Compose** 2.0+ ([Install Compose](https://docs.docker.com/compose/install/))
19
+ - **Ollama** running on host ([Install Ollama](https://ollama.ai/download))
20
+
21
+ ### System Requirements
22
+ - 8GB+ RAM allocated to Docker
23
+ - 20GB+ disk space
24
+ - CPU: 4+ cores recommended
25
+ - GPU: Optional, for faster processing
26
+
27
+ ### Verify Installation
28
+ ```bash
29
+ docker --version
30
+ docker-compose --version
31
+ ollama --version
32
+ ```
33
+
34
+ ## Quick Start
35
+
36
+ ### 1. Prepare Ollama (Host Machine)
37
+
38
+ ```bash
39
+ # Start Ollama service
40
+ ollama serve
41
+
42
+ # Pull required models
43
+ ollama pull nomic-embed-text # ~270MB
44
+ ollama pull mistral # ~4.1GB
45
+
46
+ # Verify models
47
+ ollama list
48
+ ```
49
+
50
+ ### 2. Build and Start Services
51
+
52
+ ```bash
53
+ # Clone repository
54
+ git clone <repo-url>
55
+ cd eyewiki-rag
56
+
57
+ # Build images
58
+ docker-compose build
59
+
60
+ # Start services
61
+ docker-compose up -d
62
+
63
+ # Check status
64
+ docker-compose ps
65
+ ```
66
+
67
+ ### 3. Verify Services
68
+
69
+ ```bash
70
+ # Check API health
71
+ curl http://localhost:8000/health
72
+
73
+ # Check Qdrant
74
+ curl http://localhost:6333/
75
+
76
+ # View logs
77
+ docker-compose logs -f
78
+ ```
79
+
80
+ ### 4. Access Services
81
+
82
+ - **API**: http://localhost:8000
83
+ - **Gradio UI**: http://localhost:8000/ui
84
+ - **API Docs**: http://localhost:8000/docs
85
+ - **Qdrant Dashboard**: http://localhost:6333/dashboard
86
+
87
+ ## Architecture
88
+
89
+ ### Container Network
90
+
91
+ ```
92
+ ┌─────────────────────────────────────────────┐
93
+ │ Host Machine │
94
+ │ ┌──────────────────────────────────────┐ │
95
+ │ │ Ollama (GPU Access) │ │
96
+ │ │ - Port: 11434 │ │
97
+ │ │ - Models: mistral, nomic-embed │ │
98
+ │ └────────────┬─────────────────────────┘ │
99
+ │ │ │
100
+ │ ┌────────────▼─────────────────────────┐ │
101
+ │ │ Docker Network │ │
102
+ │ │ ┌─────────────────────────────────┐ │ │
103
+ │ │ │ eyewiki-rag (API Server) │ │ │
104
+ │ │ │ - Port: 8000 │ │ │
105
+ │ │ │ - Connects to Ollama via │ │ │
106
+ │ │ │ host.docker.internal │ │ │
107
+ │ │ └─────────────┬───────────────────┘ │ │
108
+ │ │ │ │ │
109
+ │ │ ┌─────────────▼───────────────────┐ │ │
110
+ │ │ │ qdrant (Vector DB) │ │ │
111
+ │ │ │ - Ports: 6333, 6334 │ │ │
112
+ │ │ │ - Persistent volume │ │ │
113
+ │ │ └─────────────────────────────────┘ │ │
114
+ │ └──────────────────────────────────────┘ │
115
+ └─────────────────────────────────────────────┘
116
+ ```
117
+
118
+ ### Data Flow
119
+
120
+ 1. **User Request** → API Container (port 8000)
121
+ 2. **Query Engine** → Qdrant Container (vector search)
122
+ 3. **Embedding** → Ollama on Host (via host.docker.internal)
123
+ 4. **LLM Generation** → Ollama on Host
124
+ 5. **Response** → User
125
+
126
+ ### Volumes
127
+
128
+ | Volume | Path | Purpose |
129
+ |--------|------|---------|
130
+ | `./data/raw` | `/app/data/raw` | Scraped content |
131
+ | `./data/processed` | `/app/data/processed` | Chunked documents |
132
+ | `qdrant_data` | `/app/data/qdrant` | Vector database |
133
+ | `./prompts` | `/app/prompts` | Customizable prompts |
134
+
135
+ ## Configuration
136
+
137
+ ### Environment Variables
138
+
139
+ Edit `docker-compose.yml`:
140
+
141
+ ```yaml
142
+ environment:
143
+ # Ollama Configuration
144
+ - OLLAMA_BASE_URL=http://host.docker.internal:11434
145
+ - LLM_MODEL=mistral
146
+ - EMBEDDING_MODEL=nomic-embed-text
147
+ - OLLAMA_TIMEOUT=120
148
+
149
+ # Qdrant Configuration
150
+ - QDRANT_HOST=qdrant
151
+ - QDRANT_PORT=6333
152
+ - QDRANT_COLLECTION_NAME=eyewiki_rag
153
+ - QDRANT_PATH=/app/data/qdrant
154
+
155
+ # Processing Configuration
156
+ - CHUNK_SIZE=512
157
+ - CHUNK_OVERLAP=50
158
+ - MIN_CHUNK_SIZE=100
159
+ - MAX_CONTEXT_TOKENS=4000
160
+
161
+ # Retrieval Configuration
162
+ - RETRIEVAL_K=20
163
+ - RERANK_K=5
164
 + - RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2
165
+ ```
166
+
167
+ ### Custom Prompts
168
+
169
+ Edit files in `./prompts/` directory (mounted into container):
170
+ - `system_prompt.txt`
171
+ - `query_prompt.txt`
172
+ - `medical_disclaimer.txt`
173
+
174
+ Changes take effect on container restart.
175
+
176
+ ### Resource Limits
177
+
178
+ Add to service in `docker-compose.yml`:
179
+
180
+ ```yaml
181
+ deploy:
182
+ resources:
183
+ limits:
184
+ cpus: '4'
185
+ memory: 8G
186
+ reservations:
187
+ cpus: '2'
188
+ memory: 4G
189
+ ```
190
+
191
+ ## Operations
192
+
193
+ ### Makefile Commands
194
+
195
+ ```bash
196
+ # Service Management
197
+ make up # Start services
198
+ make down # Stop services
199
+ make restart # Restart services
200
+ make ps # Show status
201
+ make logs # View all logs
202
+ make logs-api # View API logs only
203
+ make logs-qdrant # View Qdrant logs only
204
+
205
+ # Health & Monitoring
206
+ make health # Check service health
207
+ make stats # Show resource usage
208
+
209
+ # Data Operations
210
+ make scrape # Run scraper
211
+ make build-index # Build vector index
212
+ make evaluate # Run evaluation
213
+ make test # Run tests
214
+
215
+ # Maintenance
216
+ make clean # Remove containers & volumes
217
+ make rebuild # Clean rebuild
218
+ make backup-qdrant # Backup vector DB
219
+ make restore-qdrant # Restore from backup
220
+
221
+ # Development
222
+ make exec-api # Bash into API container
223
+ make exec-qdrant # Shell into Qdrant container
224
+ ```
225
+
226
+ ### Manual Commands
227
+
228
+ #### Start Services
229
+ ```bash
230
+ docker-compose up -d
231
+ ```
232
+
233
+ #### Stop Services
234
+ ```bash
235
+ docker-compose down
236
+ ```
237
+
238
+ #### View Logs
239
+ ```bash
240
+ # All services
241
+ docker-compose logs -f
242
+
243
+ # Specific service
244
+ docker-compose logs -f eyewiki-rag
245
+ docker-compose logs -f qdrant
246
+
247
+ # Last N lines
248
+ docker-compose logs --tail=100 -f
249
+ ```
250
+
251
+ #### Execute Commands in Container
252
+ ```bash
253
+ # Run scraper
254
+ docker-compose exec eyewiki-rag \
255
+ python scripts/scrape_eyewiki.py --max-pages 100
256
+
257
+ # Build index
258
+ docker-compose exec eyewiki-rag \
259
+ python scripts/build_index.py --index-vectors
260
+
261
+ # Run evaluation
262
+ docker-compose exec eyewiki-rag \
263
+ python scripts/evaluate.py -v
264
+
265
+ # Run tests
266
+ docker-compose exec eyewiki-rag pytest tests/ -v
267
+
268
+ # Interactive shell
269
+ docker-compose exec eyewiki-rag bash
270
+ ```
271
+
272
+ #### Inspect Services
273
+ ```bash
274
+ # Container status
275
+ docker-compose ps
276
+
277
+ # Resource usage
278
+ docker stats eyewiki-rag-api eyewiki-qdrant
279
+
280
+ # Network info
281
+ docker network inspect eyewiki-network
282
+
283
+ # Volume info
284
+ docker volume ls
285
+ docker volume inspect eyewiki-rag_qdrant_data
286
+ ```
287
+
288
+ ### Data Management
289
+
290
+ #### Backup Qdrant
291
+ ```bash
292
+ # Using Makefile
293
+ make backup-qdrant
294
+
295
+ # Manual
296
+ docker-compose exec qdrant tar -czf /tmp/backup.tar.gz /qdrant/storage
297
+ docker cp eyewiki-qdrant:/tmp/backup.tar.gz ./backups/qdrant-$(date +%Y%m%d).tar.gz
298
+ ```
299
+
300
+ #### Restore Qdrant
301
+ ```bash
302
+ # Stop services
303
+ docker-compose down
304
+
305
+ # Restore backup
306
+ docker-compose up -d qdrant
307
+ docker cp ./backups/qdrant-20241209.tar.gz eyewiki-qdrant:/tmp/backup.tar.gz
308
+ docker-compose exec qdrant tar -xzf /tmp/backup.tar.gz -C /
309
+
310
+ # Restart all services
311
+ docker-compose up -d
312
+ ```
313
+
314
+ #### Clear Data
315
+ ```bash
316
+ # Remove all data and volumes
317
+ docker-compose down -v
318
+
319
+ # Remove only processed data
320
+ rm -rf data/processed/*
321
+ rm -rf data/qdrant/*
322
+ ```
323
+
324
+ ## Troubleshooting
325
+
326
+ ### Cannot Connect to Ollama
327
+
328
+ **Symptoms:**
329
+ - `ConnectionError: Failed to connect to Ollama`
330
+ - 503 errors on API startup
331
+
332
+ **Solutions:**
333
+
334
+ 1. **Verify Ollama is running:**
335
+ ```bash
336
+ curl http://localhost:11434/api/tags
337
+ ```
338
+
339
+ 2. **On Linux, add to docker-compose.yml:**
340
+ ```yaml
341
+ extra_hosts:
342
+ - "host.docker.internal:host-gateway"
343
+ ```
344
+
345
+ 3. **Use host IP instead:**
346
+ ```bash
347
+ # Get host IP
348
+ ip addr show docker0 | grep inet
349
+
350
+ # Update OLLAMA_BASE_URL
351
+ OLLAMA_BASE_URL=http://172.17.0.1:11434
352
+ ```
353
+
354
+ ### Qdrant Permission Errors
355
+
356
+ **Symptoms:**
357
+ - Permission denied errors in Qdrant logs
358
+ - Cannot write to volume
359
+
360
+ **Solution:**
361
+ ```bash
362
+ # Fix permissions
363
+ sudo chown -R 1000:1000 data/qdrant/
364
+
365
+ # Or recreate volume
366
+ docker-compose down -v
367
+ docker-compose up -d
368
+ ```
369
+
370
+ ### Out of Memory
371
+
372
+ **Symptoms:**
373
+ - Container killed (exit code 137)
374
+ - Slow performance
375
+
376
+ **Solutions:**
377
+
378
+ 1. **Increase Docker memory:**
379
+ - Docker Desktop: Settings → Resources → Memory → 8GB+
380
+
381
+ 2. **Add resource limits:**
382
+ ```yaml
383
+ deploy:
384
+ resources:
385
+ limits:
386
+ memory: 8G
387
+ ```
388
+
389
+ 3. **Use smaller models:**
390
+ ```bash
391
+ ollama pull llama3.2:3b # Instead of mistral
392
+ ```
393
+
394
+ ### Port Already in Use
395
+
396
+ **Symptoms:**
397
+ - `Bind for 0.0.0.0:8000 failed: port is already allocated`
398
+
399
+ **Solutions:**
400
+
401
+ 1. **Find and kill process:**
402
+ ```bash
403
+ lsof -i :8000
404
+ kill <PID>
405
+ ```
406
+
407
+ 2. **Change port in docker-compose.yml:**
408
+ ```yaml
409
+ ports:
410
+ - "8080:8000" # Use 8080 instead
411
+ ```
412
+
413
+ ### Slow Performance
414
+
415
+ **Solutions:**
416
+
417
+ 1. **Reduce batch sizes:**
418
+ ```yaml
419
+ environment:
420
+ - RETRIEVAL_K=10 # Instead of 20
421
+ - RERANK_K=3 # Instead of 5
422
+ ```
423
+
424
+ 2. **Allocate more resources:**
425
+ ```yaml
426
+ deploy:
427
+ resources:
428
+ limits:
429
+ cpus: '4'
430
+ memory: 8G
431
+ ```
432
+
433
+ 3. **Use GPU for Ollama** (on host)
434
+
435
+ ## Production
436
+
437
+ ### Production Configuration
438
+
439
+ Create `docker-compose.prod.yml`:
440
+
441
+ ```yaml
442
+ version: '3.8'
443
+
444
+ services:
445
+ eyewiki-rag:
446
+ restart: always
447
+ deploy:
448
+ resources:
449
+ limits:
450
+ cpus: '4'
451
+ memory: 8G
452
+ reservations:
453
+ cpus: '2'
454
+ memory: 4G
455
+ logging:
456
+ driver: "json-file"
457
+ options:
458
+ max-size: "100m"
459
+ max-file: "5"
460
+ environment:
461
+ - LOG_LEVEL=WARNING
462
+ healthcheck:
463
+ interval: 30s
464
+ timeout: 10s
465
+ retries: 3
466
+ start_period: 60s
467
+
468
+ qdrant:
469
+ restart: always
470
+ deploy:
471
+ resources:
472
+ limits:
473
+ cpus: '2'
474
+ memory: 4G
475
+ logging:
476
+ driver: "json-file"
477
+ options:
478
+ max-size: "50m"
479
+ max-file: "3"
480
+ ```
481
+
482
+ ### Start Production
483
+
484
+ ```bash
485
+ # Use production config
486
+ docker-compose -f docker-compose.yml -f docker-compose.prod.yml up -d
487
+
488
+ # Or use Makefile
489
+ make prod
490
+ ```
491
+
492
+ ### Monitoring
493
+
494
+ ```bash
495
+ # Watch container status
496
+ watch docker-compose ps
497
+
498
+ # Monitor resources
499
+ docker stats --no-stream eyewiki-rag-api eyewiki-qdrant
500
+
501
+ # Check logs
502
+ docker-compose logs --tail=100 -f
503
+
504
+ # Test health endpoints
505
+ watch curl -s http://localhost:8000/health
506
+ ```
507
+
508
+ ### Backup Strategy
509
+
510
+ ```bash
511
+ # Daily backup script (add to cron)
512
+ #!/bin/bash
513
+ BACKUP_DIR="/backups/eyewiki-rag"
514
+ DATE=$(date +%Y%m%d)
515
+
516
+ # Backup Qdrant
517
+ make backup-qdrant
518
+
519
+ # Backup configuration
520
+ tar -czf $BACKUP_DIR/config-$DATE.tar.gz \
521
+ docker-compose.yml prompts/ data/raw/
522
+
523
+ # Keep last 7 days
524
+ find $BACKUP_DIR -name "*.tar.gz" -mtime +7 -delete
525
+ ```
526
+
527
+ ### Update Strategy
528
+
529
+ ```bash
530
+ # 1. Backup current state
531
+ make backup-qdrant
532
+
533
+ # 2. Pull latest code
534
+ git pull origin main
535
+
536
+ # 3. Rebuild images
537
+ docker-compose build --no-cache
538
+
539
+ # 4. Restart services with zero downtime
540
+ docker-compose up -d --no-deps --build eyewiki-rag
541
+
542
+ # 5. Verify health
543
+ make health
544
+ ```
545
+
546
+ ## Best Practices
547
+
548
+ ### Security
549
+ - Use environment files for secrets
550
+ - Don't expose unnecessary ports
551
+ - Run as non-root user (add to Dockerfile)
552
+ - Keep base images updated
553
+ - Use Docker secrets for production
554
+
555
+ ### Performance
556
+ - Allocate sufficient memory (8GB+)
557
+ - Use volume for Qdrant data
558
+ - Monitor resource usage
559
+ - Scale horizontally if needed
560
+
561
+ ### Maintenance
562
+ - Regular backups
563
+ - Monitor logs for errors
564
+ - Update dependencies
565
+ - Prune unused images/volumes
566
+
567
+ ### Development
568
+ - Use `docker-compose.override.yml` for local config
569
+ - Mount source code as volume for hot reload
570
+ - Keep production and development configs separate
571
+
572
+ ---
573
+
574
+ For more information, see the main [README.md](README.md).
Makefile ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EyeWiki RAG System - Makefile for Docker operations
2
+
3
+ .PHONY: help build up down restart logs ps clean test
4
+
5
+ help: ## Show this help message
6
+ @echo "EyeWiki RAG System - Docker Commands"
7
+ @echo ""
8
+ @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
9
+
10
+ build: ## Build Docker images
11
+ docker-compose build
12
+
13
+ up: ## Start all services
14
+ docker-compose up -d
15
+ @echo "Services starting..."
16
+ @echo "API: http://localhost:8000"
17
+ @echo "Gradio UI: http://localhost:8000/ui"
18
+ @echo "API Docs: http://localhost:8000/docs"
19
+ @echo "Qdrant: http://localhost:6333/dashboard"
20
+
21
+ down: ## Stop all services
22
+ docker-compose down
23
+
24
+ restart: ## Restart all services
25
+ docker-compose restart
26
+
27
+ logs: ## View logs from all services
28
+ docker-compose logs -f
29
+
30
+ logs-api: ## View API logs only
31
+ docker-compose logs -f eyewiki-rag
32
+
33
+ logs-qdrant: ## View Qdrant logs only
34
+ docker-compose logs -f qdrant
35
+
36
+ ps: ## Show running containers
37
+ docker-compose ps
38
+
39
+ health: ## Check health of services
40
+ @echo "Checking Qdrant..."
41
+ @curl -s http://localhost:6333/healthz || echo "Qdrant not healthy"
42
+ @echo "\nChecking API..."
43
+ @curl -s http://localhost:8000/health | python -m json.tool || echo "API not healthy"
44
+
45
+ exec-api: ## Execute bash in API container
46
+ docker-compose exec eyewiki-rag bash
47
+
48
+ exec-qdrant: ## Execute bash in Qdrant container
49
+ docker-compose exec qdrant /bin/sh
50
+
51
+ clean: ## Remove all containers, volumes, and images
52
+ docker-compose down -v
53
+ docker rmi eyewiki-rag_eyewiki-rag 2>/dev/null || true
54
+
55
+ clean-volumes: ## Remove only volumes (keeps images)
56
+ docker-compose down -v
57
+
58
+ rebuild: clean build up ## Clean rebuild and start
59
+
60
+ test: ## Run tests in container
61
+ docker-compose exec eyewiki-rag pytest tests/ -v
62
+
63
+ scrape: ## Run scraper in container (example: make scrape ARGS="--max-pages 50")
64
+ docker-compose exec eyewiki-rag python scripts/scrape_eyewiki.py $(ARGS)
65
+
66
+ build-index: ## Build vector index in container
67
+ docker-compose exec eyewiki-rag python scripts/build_index.py --index-vectors
68
+
69
+ evaluate: ## Run evaluation in container
70
+ docker-compose exec eyewiki-rag python scripts/evaluate.py
71
+
72
+ stats: ## Show system statistics
73
+ @echo "Docker stats:"
74
+ docker stats --no-stream eyewiki-rag-api eyewiki-qdrant
75
+ @echo "\nDisk usage:"
76
+ docker system df
77
+
78
+ backup-qdrant: ## Backup Qdrant data
79
+ docker-compose exec qdrant tar -czf /tmp/qdrant-backup.tar.gz /qdrant/storage
80
+ docker cp eyewiki-qdrant:/tmp/qdrant-backup.tar.gz ./backups/qdrant-backup-$$(date +%Y%m%d-%H%M%S).tar.gz
81
+ @echo "Backup saved to ./backups/"
82
+
83
+ restore-qdrant: ## Restore Qdrant data (usage: make restore-qdrant BACKUP=backups/file.tar.gz)
84
+ docker cp $(BACKUP) eyewiki-qdrant:/tmp/qdrant-backup.tar.gz
85
+ docker-compose exec qdrant tar -xzf /tmp/qdrant-backup.tar.gz -C /
86
+
87
+ prod: ## Start in production mode (detached, with restart policy)
88
+ docker-compose up -d --remove-orphans
89
+ @echo "Production services started"
90
+
91
+ dev: ## Start in development mode (with logs)
92
+ docker-compose up
93
+
94
+ .DEFAULT_GOAL := help
README.md CHANGED
@@ -1,11 +1,1026 @@
1
- ---
2
- title: Eye Wiki
3
- emoji: 📊
4
- colorFrom: gray
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
8
- short_description: Eye Wiki RAG
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # 🏥 EyeWiki RAG System
2
+
3
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
4
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
5
+
6
+ A production-ready Retrieval-Augmented Generation (RAG) system for ophthalmology knowledge, powered by EyeWiki content and local LLMs.
7
+
8
+ ## 📋 Overview
9
+
10
+ The EyeWiki RAG system provides intelligent question-answering capabilities for ophthalmology topics by combining:
11
+ - **Web scraping** of authoritative EyeWiki content
12
+ - **Semantic search** with hybrid retrieval (dense + sparse)
13
+ - **Cross-encoder reranking** for precision
14
+ - **Local LLM inference** via Ollama for privacy and control
15
+ - **RESTful API** with interactive web UI
16
+
17
+ Built for medical professionals, researchers, and students seeking quick, evidence-based answers to ophthalmology questions.
18
+
19
+ ## ✨ Features
20
+
21
+ ### Core Capabilities
22
+ - 🔍 **Intelligent Retrieval**: Hybrid search combining dense embeddings and sparse BM25
23
+ - 🎯 **Precise Reranking**: Cross-encoder models for relevance scoring
24
+ - 🏠 **Local Processing**: All data stays on your machine (HIPAA-friendly)
25
+ - 📚 **Source Citations**: Every answer includes EyeWiki article references
26
+ - ⚡ **Streaming Responses**: Real-time answer generation
27
+ - 🌐 **Web Interface**: Beautiful Gradio UI for easy interaction
28
+ - 🔌 **REST API**: Programmatic access with FastAPI
29
+ - ✅ **Comprehensive Testing**: 25+ pytest tests with mocking
30
+
31
+ ### Technical Highlights
32
+ - **Polite Web Scraping**: Respects robots.txt and implements rate limiting
33
+ - **Smart Chunking**: Hierarchical markdown splitting with section awareness
34
+ - **Metadata Extraction**: Automatic ICD-10 codes, anatomical terms, medications
35
+ - **Vector Store**: Local Qdrant with payload indexing
36
+ - **Medical Disclaimer**: Automatic inclusion in all responses
37
+
38
+ ## 🏗️ Architecture
39
+
40
+ ```
41
+ ┌─────────────────────────────────────────────────────────────────┐
42
+ │ User Interface │
43
+ │ ┌──────────────────┐ ┌─────────────────────────┐ │
44
+ │ │ Gradio Web UI │ │ REST API (FastAPI) │ │
45
+ │ │ - Chat interface│ │ - /query │ │
46
+ │ │ - Examples │ │ - /query/stream │ │
47
+ │ │ - Source display│ │ - /health, /stats │ │
48
+ │ └────────┬─────────┘ └───────────┬─────────────┘ │
49
+ └───────────┼────────────────────────────────────┼────────────────┘
50
+ │ │
51
+ └────────────────┬───────────────────┘
52
+
53
+ ┌────────────────────────────────────────┐
54
+ │ Query Engine (Orchestrator) │
55
+ │ - Context assembly │
56
+ │ - Prompt formatting │
57
+ │ - Source diversity │
58
+ └──┬────────────────────────┬────────────┘
59
+ │ │
60
+ ┌───────▼──────┐ ┌──────▼──────────┐
61
+ │ Retriever │ │ Ollama Client │
62
+ │ (Hybrid) │ │ - LLM (Mistral)│
63
+ │ Dense: 0.7 │ │ │
64
+ │ Sparse: 0.3 │ │ Sentence- │
65
+ └──┬───────────┘ │ Transformers │
66
+ │ │ - Embeddings │
67
+ │ │ (all-mpnet) │
68
+ │ └─────────────────┘
69
+
70
+ ┌───────▼──────────┐
71
+ │ Reranker │
72
+ │ (CrossEncoder) │
73
+ │ ms-marco-MiniLM │
74
+ └──┬───────────────┘
75
+
76
+
77
+ ┌────────────────────────────────────┐
78
+ │ Qdrant Vector Store │
79
+ │ - Dense vectors (768-dim) │
80
+ │ - Sparse vectors (BM25) │
81
+ │ - Metadata filtering │
82
+ │ - Local storage │
83
+ └────────────────────────────────────┘
84
+ ```
85
+
86
+ **Data Flow:**
87
+ 1. **Scraping** → EyeWiki → Raw Markdown
88
+ 2. **Processing** → Chunking → Metadata Extraction → JSON
89
+ 3. **Indexing** → Embeddings → Vector Store
90
+ 4. **Query** → Retrieval → Reranking → LLM → Response
91
+
92
+ ## 📁 Project Structure
93
+
94
+ ```
95
+ eyewiki-rag/
96
+ ├── src/
97
+ │ ├── scraper/ # Web scraping (crawl4ai)
98
+ │ │ └── eyewiki_crawler.py
99
+ │ ├── processing/ # Document processing
100
+ │ │ ├── chunker.py # Semantic chunking
101
+ │ │ └── metadata_extractor.py # Medical metadata
102
+ │ ├── vectorstore/ # Vector database
103
+ │ │ └── qdrant_store.py
104
+ │ ├── rag/ # RAG components
105
+ │ │ ├── retriever.py # Hybrid retrieval
106
+ │ │ ├── reranker.py # Cross-encoder reranking
107
+ │ │ └── query_engine.py # Main orchestrator
108
+ │ ├── llm/ # LLM integration
109
+ │ │ ├── ollama_client.py # Ollama for LLM generation
110
+ │ │ └── sentence_transformer_client.py # Stable embeddings
111
+ │ ├── api/ # FastAPI server
112
+ │ │ ├── main.py # API endpoints
113
+ │ │ └── gradio_ui.py # Web interface
114
+ │ └── config/ # Configuration
115
+ │ └── settings.py
116
+ ├── prompts/ # Customizable prompts
117
+ │ ├── system_prompt.txt
118
+ │ ├── query_prompt.txt
119
+ │ └── medical_disclaimer.txt
120
+ ├── scripts/ # Utility scripts
121
+ │ ├── scrape_eyewiki.py # Web scraping
122
+ │ ├── build_index.py # Index building
123
+ │ ├── run_server.py # Server startup
124
+ │ └── evaluate.py # System evaluation
125
+ ├── tests/ # Comprehensive test suite
126
+ │ ├── test_components.py # Component tests
127
+ │ ├── test_questions.json # Evaluation questions
128
+ │ └── conftest.py
129
+ ├── data/ # Data storage (gitignored)
130
+ │ ├── raw/ # Scraped content
131
+ │ ├── processed/ # Chunked documents
132
+ │ └── qdrant/ # Vector database
133
+ └── requirements.txt # Python dependencies
134
+ ```
135
+
136
+ ## 📋 Prerequisites
137
+
138
+ ### Required
139
+ - **Python 3.10+** (tested on 3.10, 3.11)
140
+ - **Ollama** (for local LLM text generation only)
141
+ - Install: https://ollama.ai/download
142
+ - Note: Embeddings now use sentence-transformers (more stable)
143
+ - **8GB+ RAM** (16GB recommended for larger datasets)
144
+ - **10GB+ disk space** (for models and vector store)
145
+
146
+ ### Optional
147
+ - **CUDA-capable GPU** (for faster embedding generation with sentence-transformers)
148
+ - **Docker** (if running Qdrant in container)
149
+
150
+ ### System Requirements by Component
151
+ | Component | RAM | CPU | GPU | Disk |
152
+ |-----------|-----|-----|-----|------|
153
+ | Scraping | 2GB | 2 cores | No | 500MB |
154
+ | Processing | 4GB | 4 cores | No | 2GB |
155
+ | Indexing | 8GB | 4 cores | Optional | 5GB |
156
+ | API Server | 4GB | 2 cores | Optional | 100MB |
157
+
158
+ ## 🚀 Quick Start
159
+
160
+ ### Step 1: Installation
161
+
162
+ ```bash
163
+ # Clone repository
164
+ git clone <repository-url>
165
+ cd eyewiki-rag
166
+
167
+ # Create virtual environment
168
+ python -m venv venv
169
+ source venv/bin/activate # Windows: venv\Scripts\activate
170
+
171
+ # Install Python dependencies
172
+ pip install -r requirements.txt
173
+
174
+ # Install system dependencies for Playwright (Linux/WSL only)
175
+ # This installs required shared libraries (libnss3, libnspr4, etc.)
176
+ python -m playwright install-deps
177
+ ```
178
+
179
+ ### Step 2: Install Ollama and LLM Model
180
+
181
+ ```bash
182
+ # Install Ollama from https://ollama.ai/download
183
+ # Then pull required LLM model:
184
+
185
+ ollama pull mistral # LLM model (4.1GB)
186
+ # or use smaller alternative:
187
+ ollama pull llama3.2:3b # Smaller LLM (2GB)
188
+
189
+ # Note: Embedding model (sentence-transformers) will be auto-downloaded
190
+ # when you first run build_index.py (no Ollama needed for embeddings!)
191
+ ```
192
+
193
+ ### Step 3: Scrape EyeWiki
194
+
195
+ ```bash
196
+ # Quick test (50 pages, ~5 minutes)
197
+ python scripts/scrape_eyewiki.py --max-pages 50
198
+
199
+ # Full crawl (1000+ pages, ~2 hours)
200
+ python scripts/scrape_eyewiki.py --max-pages 1000
201
+
202
+ # Resume from checkpoint
203
+ python scripts/scrape_eyewiki.py --resume
204
+ ```
205
+
206
+ **Output:** `data/raw/*.json` (markdown files with metadata)
207
+
208
+ ### Step 4: Build Vector Index
209
+
210
+ ```bash
211
+ # Process documents and build vector index
212
+ python scripts/build_index.py --index-vectors
213
+
214
+ # This will:
215
+ # 1. Chunk documents (data/processed/)
216
+ # 2. Extract metadata
217
+ # 3. Generate embeddings using sentence-transformers (all-mpnet-base-v2)
218
+ # 4. Build Qdrant index (data/qdrant/)
219
+
220
+ # Optional: Use different embedding model
221
+ python scripts/build_index.py --index-vectors --embedding-model "BAAI/bge-base-en-v1.5"
222
+ ```
223
+
224
+ **Time:** ~10-30 minutes depending on dataset size
225
+ **Note:** First run will download the embedding model (~400MB for all-mpnet-base-v2)
226
+
227
+ ### Step 5: Start Server
228
+
229
+ ```bash
230
+ # Run with pre-flight checks
231
+ python scripts/run_server.py
232
+
233
+ # Development mode with hot reload
234
+ python scripts/run_server.py --reload
235
+
236
+ # Custom port
237
+ python scripts/run_server.py --port 8080
238
+ ```
239
+
240
+ ### Step 6: Access the System
241
+
242
+ **Web Interface:** http://localhost:8000/ui
243
+ - Beautiful chat interface
244
+ - Example questions
245
+ - Source citations
246
+ - Settings sidebar
247
+
248
+ **API Docs:** http://localhost:8000/docs
249
+ - Swagger UI
250
+ - Interactive testing
251
+ - Full API documentation
252
+
253
+ **Health Check:** http://localhost:8000/health
254
+
255
+ ### Example Query
256
+
257
+ ```bash
258
+ curl -X POST http://localhost:8000/query \
259
+ -H "Content-Type: application/json" \
260
+ -d '{
261
+ "question": "What are the symptoms of glaucoma?",
262
+ "include_sources": true
263
+ }'
264
+ ```
265
+
266
+ ## 🐳 Docker Deployment
267
+
268
+ ### Prerequisites
269
+ - **Docker** and **Docker Compose** installed
270
+ - **Ollama** running on host machine (for GPU access)
271
+ - **8GB+ RAM** allocated to Docker
272
+
273
+ ### Quick Start with Docker
274
+
275
+ ```bash
276
+ # 1. Ensure Ollama is running on host
277
+ ollama serve
278
+
279
+ # 2. Pull required models (on host)
280
+ ollama pull nomic-embed-text   # only needed if you configure Ollama embeddings; the default setup uses sentence-transformers
281
+ ollama pull mistral
282
+
283
+ # 3. Build and start services
284
+ docker-compose up -d
285
+
286
+ # 4. Check status
287
+ docker-compose ps
288
+
289
+ # 5. View logs
290
+ docker-compose logs -f eyewiki-rag
291
+ ```
292
+
293
+ **Access:**
294
+ - API: http://localhost:8000
295
+ - Gradio UI: http://localhost:8000/ui
296
+ - API Docs: http://localhost:8000/docs
297
+ - Qdrant Dashboard: http://localhost:6333/dashboard
298
+
299
+ ### Using Makefile Commands
300
+
301
+ ```bash
302
+ # Start services
303
+ make up
304
+
305
+ # View logs
306
+ make logs
307
+
308
+ # Check health
309
+ make health
310
+
311
+ # Run scraper in container
312
+ make scrape ARGS="--max-pages 50"
313
+
314
+ # Build index
315
+ make build-index
316
+
317
+ # Run evaluation
318
+ make evaluate
319
+
320
+ # Stop services
321
+ make down
322
+
323
+ # Clean everything
324
+ make clean
325
+ ```
326
+
327
+ ### Docker Compose Services
328
+
329
+ **eyewiki-rag** (API Server)
330
+ - Built from Dockerfile
331
+ - Exposes port 8000
332
+ - Connects to Ollama on host via `host.docker.internal`
333
+ - Connects to Qdrant container
334
+ - Mounts data volumes for persistence
335
+
336
+ **qdrant** (Vector Database)
337
+ - Official Qdrant image
338
+ - Exposes ports 6333 (REST) and 6334 (gRPC)
339
+ - Persistent volume for vector storage
340
+ - Health checks enabled
341
+
342
+ ### Volume Management
343
+
344
+ **Persistent volumes:**
345
+ - `./data/raw` - Scraped content
346
+ - `./data/processed` - Chunked documents
347
+ - `qdrant_data` - Vector database (Docker volume)
348
+ - `./prompts` - Customizable prompts
349
+
350
+ **Backup Qdrant data:**
351
+ ```bash
352
+ make backup-qdrant
353
+ # Saves to ./backups/qdrant-backup-YYYYMMDD-HHMMSS.tar.gz
354
+ ```
355
+
356
+ **Restore Qdrant data:**
357
+ ```bash
358
+ make restore-qdrant BACKUP=backups/qdrant-backup-20241209-120000.tar.gz
359
+ ```
360
+
361
+ ### Configuration via Environment Variables
362
+
363
+ Edit `docker-compose.yml` to customize:
364
+
365
+ ```yaml
366
+ environment:
367
+ # Ollama settings
368
+ - OLLAMA_BASE_URL=http://host.docker.internal:11434
369
+ - LLM_MODEL=mistral
370
+ - EMBEDDING_MODEL=nomic-embed-text
371
+
372
+ # Qdrant settings
373
+ - QDRANT_HOST=qdrant
374
+ - QDRANT_PORT=6333
375
+
376
+ # Processing settings
377
+ - CHUNK_SIZE=512
378
+ - TOP_K=20
379
+ - RERANK_TOP_K=5
380
+ ```
381
+
382
+ ### Running Scripts in Container
383
+
384
+ ```bash
385
+ # Scrape EyeWiki
386
+ docker-compose exec eyewiki-rag python scripts/scrape_eyewiki.py --max-pages 100
387
+
388
+ # Build index
389
+ docker-compose exec eyewiki-rag python scripts/build_index.py --index-vectors
390
+
391
+ # Run evaluation
392
+ docker-compose exec eyewiki-rag python scripts/evaluate.py
393
+
394
+ # Run tests
395
+ docker-compose exec eyewiki-rag pytest tests/ -v
396
+ ```
397
+
398
+ ### Production Deployment
399
+
400
+ ```bash
401
+ # Start in production mode
402
+ make prod
403
+
404
+ # Or manually:
405
+ docker-compose up -d --remove-orphans
406
+
407
+ # Monitor with healthchecks
408
+ watch docker-compose ps
409
+
410
+ # View metrics
411
+ docker stats eyewiki-rag-api eyewiki-qdrant
412
+ ```
413
+
414
+ ### Troubleshooting Docker
415
+
416
+ **Problem:** Cannot connect to Ollama
417
+
418
+ **Solution:**
419
+ ```bash
420
+ # macOS/Windows (Docker Desktop): host.docker.internal resolves automatically; on Linux use extra_hosts (below)
421
+ # If not working, use host IP:
422
+ docker network inspect eyewiki-network
423
+ # Update OLLAMA_BASE_URL to http://<host-ip>:11434
424
+
425
+ # Or on Linux, add to docker-compose.yml:
426
+ extra_hosts:
427
+ - "host.docker.internal:host-gateway"
428
+ ```
429
+
430
+ **Problem:** Qdrant volume permission issues
431
+
432
+ **Solution:**
433
+ ```bash
434
+ # Fix permissions
435
+ sudo chown -R 1000:1000 data/qdrant/
436
+ ```
437
+
438
+ **Problem:** Out of memory
439
+
440
+ **Solution:**
441
+ ```bash
442
+ # Increase Docker memory limit in Docker Desktop
443
+ # Or in docker-compose.yml, add:
444
+ deploy:
445
+ resources:
446
+ limits:
447
+ memory: 8G
448
+ ```
449
+
450
+ ### Docker Image Sizes
451
+
452
+ | Image | Size | Purpose |
453
+ |-------|------|---------|
454
+ | eyewiki-rag | ~2.5GB | API server with dependencies |
455
+ | qdrant/qdrant | ~200MB | Vector database |
456
+ | **Total** | ~2.7GB | Both services |
457
+
458
+ **Note:** Ollama models (~4-5GB) run on host for GPU access.
459
+
460
+ ## ⚙️ Configuration
461
+
462
+ Configuration via `src/config/settings.py` (uses pydantic-settings):
463
+
464
+ | Parameter | Default | Description |
465
+ |-----------|---------|-------------|
466
+ | **LLM Settings** |
467
+ | `llm_model` | `mistral` | Ollama LLM model name |
468
+ | `ollama_base_url` | `http://localhost:11434` | Ollama API URL |
469
+ | `llm_temperature` | `0.7` | LLM sampling temperature |
470
+ | `llm_max_tokens` | `2048` | Max tokens for LLM response |
471
+ | **Embedding Settings** |
472
+ | `embedding_model` | `all-mpnet-base-v2` | Sentence-transformers model (note: code default in `settings.py` is `nomic-embed-text` — override via `.env`) |
473
+ | **Vector Store** |
474
+ | `qdrant_collection_name` | `eyewiki_rag` | Collection name |
475
+ | `qdrant_path` | `./data/vectorstore` | Local storage path |
476
+ | `qdrant_url` | `None` | Remote Qdrant URL (optional) |
477
+ | **Chunking** |
478
+ | `chunk_size` | `512` | Max tokens per chunk |
479
+ | `chunk_overlap` | `50` | Overlap between chunks |
480
+ | `min_chunk_size` | `100` | Minimum chunk size |
481
+ | **Retrieval** |
482
+ | `top_k` | `10` | Initial retrieval count |
483
+ | `rerank_top_k` | `5` | After reranking |
484
+ | `similarity_threshold` | `0.7` | Minimum similarity score |
485
+ | **Scraper** |
486
+ | `scraper_delay` | `1.0` | Delay between requests (seconds) |
487
+ | `scraper_timeout` | `30` | Request timeout (seconds) |
488
+
489
+ ### Environment Variables
490
+
491
+ Create `.env` file to override defaults (see `.env.example`):
492
+
493
+ ```env
494
+ # Ollama Configuration (for LLM only)
495
+ OLLAMA_BASE_URL=http://localhost:11434
496
+ LLM_MODEL=mistral
497
+ LLM_TEMPERATURE=0.7
498
+ LLM_MAX_TOKENS=2048
499
+
500
+ # Embedding Configuration (sentence-transformers)
501
+ EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2
502
+
503
+ # Qdrant Vector Store
504
+ QDRANT_COLLECTION_NAME=eyewiki_rag
505
+ QDRANT_PATH=./data/vectorstore
506
+ # QDRANT_URL=http://localhost:6333 # For remote Qdrant
507
+ # QDRANT_API_KEY=your-key # For Qdrant Cloud
508
+
509
+ # Document Processing
510
+ CHUNK_SIZE=512
511
+ CHUNK_OVERLAP=50
512
+ MIN_CHUNK_SIZE=100
513
+
514
+ # RAG Retrieval
515
+ TOP_K=10
516
+ RERANK_TOP_K=5
517
+ SIMILARITY_THRESHOLD=0.7
518
+
519
+ # Web Scraper
520
+ SCRAPER_DELAY=1.0
521
+ SCRAPER_TIMEOUT=30
522
+
523
+ # API Server
524
+ API_HOST=0.0.0.0
525
+ API_PORT=8000
526
+ API_WORKERS=4
527
+
528
+ # Gradio UI
529
+ GRADIO_HOST=0.0.0.0
530
+ GRADIO_PORT=7860
531
+ GRADIO_SHARE=false
532
+
533
+ # Data Paths
534
+ DATA_RAW_PATH=./data/raw
535
+ DATA_PROCESSED_PATH=./data/processed
536
+
537
+ # Logging
538
+ LOG_LEVEL=INFO
539
+ LOG_FILE=logs/eyewiki_rag.log
540
+ ```
541
+
542
+ ### Customizing Prompts
543
+
544
+ Edit files in `prompts/` directory:
545
+ - `system_prompt.txt` - System instructions for LLM
546
+ - `query_prompt.txt` - Query template with `{context}` and `{question}` placeholders
547
+ - `medical_disclaimer.txt` - Medical disclaimer text
548
+
549
+ ## 📡 API Documentation
550
+
551
+ ### Endpoints
552
+
553
+ #### `GET /`
554
+ Root endpoint with API information
555
+
556
+ #### `GET /health`
557
+ Health check endpoint
558
+
559
+ **Response:**
560
+ ```json
561
+ {
562
+ "status": "healthy",
563
+ "ollama": {"status": "healthy", "models": {...}},
564
+ "qdrant": {"status": "healthy", "vectors_count": 1234},
565
+ "query_engine": {"status": "initialized"},
566
+ "timestamp": 1702134567.89
567
+ }
568
+ ```
569
+
570
+ #### `POST /query`
571
+ Main query endpoint
572
+
573
+ **Request:**
574
+ ```json
575
+ {
576
+ "question": "What is glaucoma?",
577
+ "include_sources": true,
578
+ "filters": {"disease_name": "Glaucoma"} // optional
579
+ }
580
+ ```
581
+
582
+ **Response:**
583
+ ```json
584
+ {
585
+ "answer": "Glaucoma is a group of eye diseases...",
586
+ "sources": [
587
+ {
588
+ "title": "Primary Open-Angle Glaucoma",
589
+ "url": "https://eyewiki.aao.org/...",
590
+ "section": "Overview",
591
+ "relevance_score": 0.89
592
+ }
593
+ ],
594
+ "confidence": 0.85,
595
+ "disclaimer": "Medical disclaimer text...",
596
+ "query": "What is glaucoma?"
597
+ }
598
+ ```
599
+
600
+ #### `POST /query/stream`
601
+ Streaming query with Server-Sent Events
602
+
603
+ **Request:**
604
+ ```json
605
+ {
606
+ "question": "What is glaucoma?",
607
+ "filters": {} // optional
608
+ }
609
+ ```
610
+
611
+ **Response:** SSE stream
612
+ ```
613
+ data: Glaucoma
614
+ data: is
615
+ data: a group of eye diseases...
616
+ ```
617
+
618
+ #### `GET /stats`
619
+ Index and pipeline statistics
620
+
621
+ **Response:**
622
+ ```json
623
+ {
624
+ "collection_info": {
625
+ "name": "eyewiki_rag",
626
+ "vectors_count": 1234
627
+ },
628
+ "pipeline_config": {
629
+ "retrieval_k": 20,
630
+ "rerank_k": 5,
631
+ "llm_model": "mistral"
632
+ },
633
+ "documents_indexed": 1234,
634
+ "timestamp": 1702134567.89
635
+ }
636
+ ```
637
+
638
+ ### Python Client Example
639
+
640
+ ```python
641
+ import requests
642
+
643
+ # Query the API
644
+ response = requests.post(
645
+ "http://localhost:8000/query",
646
+ json={
647
+ "question": "What causes diabetic retinopathy?",
648
+ "include_sources": True
649
+ }
650
+ )
651
+
652
+ result = response.json()
653
+ print(f"Answer: {result['answer']}")
654
+ print(f"Confidence: {result['confidence']:.2%}")
655
+ print(f"Sources: {len(result['sources'])}")
656
+ ```
657
+
658
+ ### Streaming Example
659
+
660
+ ```python
661
+ import requests
662
+
663
+ response = requests.post(
664
+ "http://localhost:8000/query/stream",
665
+ json={"question": "What is glaucoma?"},
666
+ stream=True
667
+ )
668
+
669
+ for line in response.iter_lines():
670
+ if line.startswith(b"data: "):
671
+ chunk = line[6:].decode()
672
+ print(chunk, end="", flush=True)
673
+ ```
674
+
675
+ ## 🧪 Development
676
+
677
+ ### Running Tests
678
+
679
+ ```bash
680
+ # Run all tests
681
+ pytest
682
+
683
+ # Run with coverage
684
+ pytest --cov=src --cov-report=html
685
+
686
+ # Run specific test file
687
+ pytest tests/test_components.py -v
688
+
689
+ # Run specific test
690
+ pytest tests/test_components.py::test_chunk_respects_headers -v
691
+
692
+ # Run by marker
693
+ pytest -m unit # Fast unit tests
694
+ pytest -m api # API tests
695
+ ```
696
+
697
+ ### Code Quality
698
+
699
+ ```bash
700
+ # Format code
701
+ black src/ scripts/ tests/
702
+ isort src/ scripts/ tests/
703
+
704
+ # Lint
705
+ flake8 src/
706
+ pylint src/
707
+
708
+ # Type checking
709
+ mypy src/
710
+ ```
711
+
712
+ ### Evaluation
713
+
714
+ Run system evaluation on test questions:
715
+
716
+ ```bash
717
+ # Run evaluation
718
+ python scripts/evaluate.py
719
+
720
+ # With custom questions
721
+ python scripts/evaluate.py --questions tests/custom_questions.json
722
+
723
+ # Save results
724
+ python scripts/evaluate.py --output results/eval.json
725
+
726
+ # Verbose output
727
+ python scripts/evaluate.py -v
728
+ ```
729
+
730
+ **Metrics:**
731
+ - Retrieval Recall
732
+ - Answer Relevance
733
+ - Citation Precision/Recall/F1
734
+ - Performance by category
735
+
736
+ ## 🔧 Troubleshooting
737
+
738
+ ### Ollama Issues
739
+
740
+ **Problem:** "Connection refused" to Ollama
741
+
742
+ **Solution:**
743
+ ```bash
744
+ # Check if Ollama is running
745
+ curl http://localhost:11434/api/tags
746
+
747
+ # Start Ollama
748
+ ollama serve
749
+
750
+ # Verify models are installed
751
+ ollama list
752
+ ```
753
+
754
+ **Problem:** "Model not found"
755
+
756
+ **Solution:**
757
+ ```bash
758
+ # Pull required models
759
+ ollama pull nomic-embed-text   # only if Ollama embeddings are configured (default uses sentence-transformers)
760
+ ollama pull mistral
761
+
762
+ # List available models
763
+ ollama list
764
+ ```
765
+
766
+ ### Vector Store Issues
767
+
768
+ **Problem:** "Collection not found"
769
+
770
+ **Solution:**
771
+ ```bash
772
+ # Rebuild the index
773
+ python scripts/build_index.py --index-vectors --recreate-collection
774
+
775
+ # Check Qdrant data directory
776
+ ls -la data/vectorstore/   # default qdrant_path in settings.py
777
+ ```
778
+
779
+ **Problem:** "Out of memory during indexing"
780
+
781
+ **Solution:**
782
+ ```bash
783
+ # Use smaller batch size
784
+ python scripts/build_index.py --index-vectors --embedding-batch-size 16
785
+
786
+ # Or process in stages
787
+ python scripts/build_index.py # Process only (no indexing)
788
+ python scripts/build_index.py --index-only # Index separately
789
+ ```
790
+
791
+ ### Scraping Issues
792
+
793
+ **Problem:** "Rate limited by EyeWiki"
794
+
795
+ **Solution:**
796
+ ```bash
797
+ # Increase delay between requests
798
+ python scripts/scrape_eyewiki.py --delay 5.0
799
+
800
+ # Resume from checkpoint if interrupted
801
+ python scripts/scrape_eyewiki.py --resume
802
+ ```
803
+
804
+ **Problem:** "Timeout during scraping"
805
+
806
+ **Solution:**
807
+ ```bash
808
+ # Increase timeout
809
+ python scripts/scrape_eyewiki.py --timeout 60
810
+ ```
811
+
812
+ **Problem:** "error while loading shared libraries: libnspr4.so" or browser crashes
813
+
814
+ **Solution:**
815
+ ```bash
816
+ # Install Playwright system dependencies (Linux/WSL)
817
+ python -m playwright install-deps
818
+
819
+ # Or manually install required libraries
820
+ sudo apt-get update && sudo apt-get install -y \
821
+ libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 \
822
+ libcups2 libdrm2 libdbus-1-3 libxkbcommon0 \
823
+ libatspi2.0-0 libxcomposite1 libxdamage1 \
824
+ libxfixes3 libxrandr2 libgbm1 libasound2
825
+ ```
826
+
827
+ **Problem:** "Executable doesn't exist" - Chromium browser not found
828
+
829
+ **Solution:**
830
+ ```bash
831
+ # Install Playwright browsers
832
+ playwright install chromium
833
+
834
+ # Or install all browsers
835
+ playwright install
836
+ ```
837
+
838
+ ### API Server Issues
839
+
840
+ **Problem:** "Pre-flight checks failed"
841
+
842
+ **Solution:**
843
+ 1. Check Ollama is running: `ollama serve`
844
+ 2. Verify models: `ollama list`
845
+ 3. Check vector store: `ls data/vectorstore/` (default `qdrant_path`)
846
+ 4. View logs for specific error
847
+
848
+ **Problem:** "Gradio UI not loading"
849
+
850
+ **Solution:**
851
+ ```bash
852
+ # Check if port is in use
853
+ lsof -i :8000
854
+
855
+ # Use different port
856
+ python scripts/run_server.py --port 8080
857
+
858
+ # Skip checks for testing
859
+ python scripts/run_server.py --skip-checks
860
+ ```
861
+
862
+ ### Performance Issues
863
+
864
+ **Problem:** "Slow query responses"
865
+
866
+ **Solution:**
867
+ 1. Use GPU for embeddings (if available)
868
+ 2. Reduce `top_k` and `rerank_top_k` in config
869
+ 3. Decrease `max_context_tokens`
870
+ 4. Use smaller LLM model (llama3.2:3b instead of mistral)
871
+
872
+ **Problem:** "High memory usage"
873
+
874
+ **Solution:**
875
+ ```bash
876
+ # Use smaller models
877
+ ollama pull llama3.2:3b # Only 2GB
878
+
879
+ # Reduce batch sizes in config
880
+ # Edit src/config/settings.py:
881
+ # chunk_size = 256 (instead of 512)
882
+ # top_k = 5 (instead of 10)
883
+ ```
884
+
885
+ ### Common Error Messages
886
+
887
+ | Error | Cause | Solution |
888
+ |-------|-------|----------|
889
+ | `ConnectionError: Ollama` | Ollama not running | `ollama serve` |
890
+ | `Collection 'eyewiki_rag' not found` | Index not built | `python scripts/build_index.py --index-vectors` |
891
+ | `Model 'mistral' not found` | Model not pulled | `ollama pull mistral` |
892
+ | `503 Service Unavailable` | System not initialized | Check logs, verify dependencies |
893
+ | `422 Validation Error` | Invalid request format | Check API docs |
894
+
895
+ ## 📊 Performance Benchmarks
896
+
897
+ Typical performance on a modern laptop (16GB RAM, M1/M2 or equivalent):
898
+
899
+ | Operation | Time | Notes |
900
+ |-----------|------|-------|
901
+ | Scraping (100 pages) | ~5-10 min | Network dependent |
902
+ | Processing | ~2-5 min | 100 documents |
903
+ | Embedding generation | ~5-10 min | 100 documents |
904
+ | Index building | ~3-5 min | 100 documents |
905
+ | Query (no streaming) | ~2-5s | Includes retrieval + LLM |
906
+ | Query (streaming) | ~0.5s first token | Then ~50 tokens/s |
907
+
908
+ ## 📚 Additional Resources
909
+
910
+ ### Documentation
911
+ - [EyeWiki](https://eyewiki.aao.org/) - Source of medical content
912
+ - [Ollama Documentation](https://github.com/ollama/ollama/blob/main/docs/README.md)
913
+ - [Qdrant Documentation](https://qdrant.tech/documentation/)
914
+ - [FastAPI Documentation](https://fastapi.tiangolo.com/)
915
+
916
+ ### Related Projects
917
+ - [LlamaIndex](https://www.llamaindex.ai/) - Data framework for LLM applications
918
+ - [LangChain](https://www.langchain.com/) - Framework for developing LLM applications
919
+ - [Haystack](https://haystack.deepset.ai/) - End-to-end NLP framework
920
+
921
+ ### Papers & Resources
922
+ - [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401)
923
+ - [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906)
924
+
925
+ ## ⚠️ Medical Disclaimer
926
+
927
+ **IMPORTANT:** This system provides information from EyeWiki, a resource of the American Academy of Ophthalmology (AAO).
928
+
929
+ The information provided by this system:
930
+ - Is not a substitute for professional medical advice, diagnosis, or treatment
931
+ - May contain errors due to AI limitations
932
+ - Should be verified with authoritative sources before clinical use
933
+
934
+ Always consult with a qualified ophthalmologist or eye care professional for medical concerns. This system should not be used for:
935
+ - Clinical decision-making
936
+ - Patient diagnosis
937
+ - Treatment recommendations
938
+ - Emergency medical situations
939
+
940
+ ## 📄 License
941
+
942
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
943
+
944
+ ### Third-Party Licenses
945
+ - **EyeWiki Content**: © American Academy of Ophthalmology - Used under fair use for research purposes
946
+ - **Ollama**: Apache 2.0 License
947
+ - **Qdrant**: Apache 2.0 License
948
+ - **FastAPI**: MIT License
949
+ - **Gradio**: Apache 2.0 License
950
+
951
+ ## 🙏 Attribution
952
+
953
+ ### EyeWiki & AAO
954
+ This project uses content from [EyeWiki](https://eyewiki.aao.org/), the collaborative online encyclopedia of ophthalmology created and maintained by the [American Academy of Ophthalmology (AAO)](https://www.aao.org/).
955
+
956
+ **Citation:**
957
+ > American Academy of Ophthalmology. EyeWiki. Available at: https://eyewiki.aao.org/. Accessed [Date].
958
+
959
+ ### Models & Libraries
960
+ - **nomic-embed-text**: [Nomic AI](https://www.nomic.ai/)
961
+ - **mistral**: [Mistral AI](https://mistral.ai/)
962
+ - **sentence-transformers**: [UKPLab](https://www.ukp.tu-darmstadt.de/)
963
+ - **crawl4ai**: [Web scraping framework](https://github.com/unclecode/crawl4ai)
964
+
965
+ ## 🤝 Contributing
966
+
967
+ Contributions are welcome! Here's how you can help:
968
+
969
+ ### Areas for Contribution
970
+ - 🐛 Bug fixes
971
+ - ✨ New features
972
+ - 📝 Documentation improvements
973
+ - 🧪 Test coverage
974
+ - 🎨 UI/UX enhancements
975
+ - 🌍 Internationalization
976
+
977
+ ### Development Workflow
978
+ 1. Fork the repository
979
+ 2. Create a feature branch: `git checkout -b feature/amazing-feature`
980
+ 3. Make your changes
981
+ 4. Run tests: `pytest`
982
+ 5. Commit: `git commit -m 'Add amazing feature'`
983
+ 6. Push: `git push origin feature/amazing-feature`
984
+ 7. Open a Pull Request
985
+
986
+ ### Code Style
987
+ - Follow PEP 8
988
+ - Use Black for formatting
989
+ - Add type hints
990
+ - Write docstrings
991
+ - Include tests for new features
992
+
993
+ ## 📞 Support
994
+
995
+ - **Issues**: [GitHub Issues](https://github.com/your-repo/issues)
996
+ - **Discussions**: [GitHub Discussions](https://github.com/your-repo/discussions)
997
+ - **Email**: your-email@example.com
998
+
999
+ ## 🗺️ Roadmap
1000
+
1001
+ ### Planned Features
1002
+ - [ ] Multi-language support
1003
+ - [ ] PDF document upload
1004
+ - [ ] Advanced filtering (date, author, etc.)
1005
+ - [ ] Conversation history
1006
+ - [ ] Feedback mechanism
1007
+ - [ ] Export answers to PDF
1008
+ - [ ] Mobile-responsive UI
1009
+ - [x] Docker deployment
1010
+ - [ ] Cloud deployment guide (AWS, GCP, Azure)
1011
+ - [ ] Integration with medical record systems
1012
+
1013
+ ### Future Improvements
1014
+ - [ ] Support for images in articles
1015
+ - [ ] Better handling of tables and diagrams
1016
+ - [ ] Citation formatting options (APA, MLA, etc.)
1017
+ - [ ] Multi-modal retrieval (text + images)
1018
+ - [ ] Custom model fine-tuning
1019
+
1020
+ ## ⭐ Star History
1021
+
1022
+ If you find this project helpful, please consider giving it a star!
1023
+
1024
  ---
1025
 
1026
+ **Built with ❤️ for the ophthalmology community**
config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Configuration package
config/settings.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration settings for EyeWiki RAG system."""
2
+
3
+ from enum import Enum
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from pydantic import Field
8
+ from pydantic_settings import BaseSettings, SettingsConfigDict
9
+
10
+
11
+ class LLMProvider(str, Enum):
12
+ """Supported LLM providers."""
13
+
14
+ OLLAMA = "ollama"
15
+ OPENAI = "openai"
16
+
17
+
18
+ class Settings(BaseSettings):
19
+ """Application settings loaded from environment variables."""
20
+
21
+ model_config = SettingsConfigDict(
22
+ env_file=".env",
23
+ env_file_encoding="utf-8",
24
+ case_sensitive=False,
25
+ extra="ignore",
26
+ )
27
+
28
+ # LLM Provider Configuration
29
+ llm_provider: LLMProvider = Field(
30
+ default=LLMProvider.OLLAMA,
31
+ description="LLM provider to use: 'ollama' for local Ollama, 'openai' for OpenAI-compatible APIs (Groq, DeepSeek, OpenAI)",
32
+ )
33
+
34
+ # Ollama Configuration
35
+ ollama_base_url: str = Field(
36
+ default="http://localhost:11434",
37
+ description="Base URL for Ollama API",
38
+ )
39
+ ollama_timeout: int = Field(
40
+ default=30,
41
+ gt=0,
42
+ description="Request timeout for Ollama API in seconds",
43
+ )
44
+ embedding_model: str = Field(
45
+ default="nomic-embed-text",
46
+ description="Ollama embedding model name",
47
+ )
48
+ llm_model: str = Field(
49
+ default="mistral",
50
+ description="Ollama LLM model name",
51
+ )
52
+ llm_temperature: float = Field(
53
+ default=0.7,
54
+ ge=0.0,
55
+ le=2.0,
56
+ description="LLM temperature for response generation",
57
+ )
58
+ llm_max_tokens: int = Field(
59
+ default=2048,
60
+ gt=0,
61
+ description="Maximum tokens for LLM response",
62
+ )
63
+
64
+ # OpenAI-compatible API Configuration (for Groq, DeepSeek, OpenAI, etc.)
65
+ openai_api_key: Optional[str] = Field(
66
+ default=None,
67
+ description="API key for OpenAI-compatible provider",
68
+ )
69
+ openai_base_url: Optional[str] = Field(
70
+ default=None,
71
+ description="Base URL for OpenAI-compatible API (e.g., https://api.groq.com/openai/v1 for Groq)",
72
+ )
73
+ openai_model: str = Field(
74
+ default="llama-3.3-70b-versatile",
75
+ description="Model name for OpenAI-compatible provider (e.g., llama-3.3-70b-versatile for Groq)",
76
+ )
77
+
78
+ # Qdrant Configuration
79
+ qdrant_path: str = Field(
80
+ default="./data/vectorstore",
81
+ description="Path to Qdrant vector database",
82
+ )
83
+ qdrant_collection_name: str = Field(
84
+ default="eyewiki_rag",
85
+ description="Qdrant collection name",
86
+ )
87
+ qdrant_url: Optional[str] = Field(
88
+ default=None,
89
+ description="Qdrant server URL (for remote Qdrant)",
90
+ )
91
+ qdrant_api_key: Optional[str] = Field(
92
+ default=None,
93
+ description="Qdrant API key (for Qdrant Cloud)",
94
+ )
95
+
96
+ # Document Processing Configuration
97
+ chunk_size: int = Field(
98
+ default=512,
99
+ gt=0,
100
+ description="Size of text chunks for processing",
101
+ )
102
+ chunk_overlap: int = Field(
103
+ default=50,
104
+ ge=0,
105
+ description="Overlap between consecutive chunks",
106
+ )
107
+ min_chunk_size: int = Field(
108
+ default=100,
109
+ gt=0,
110
+ description="Minimum chunk size in tokens (skip smaller chunks)",
111
+ )
112
+
113
+ # RAG Configuration
114
+ top_k: int = Field(
115
+ default=10,
116
+ gt=0,
117
+ description="Number of documents to retrieve",
118
+ )
119
+ rerank_top_k: int = Field(
120
+ default=5,
121
+ gt=0,
122
+ description="Number of documents after reranking",
123
+ )
124
+ similarity_threshold: float = Field(
125
+ default=0.7,
126
+ ge=0.0,
127
+ le=1.0,
128
+ description="Minimum similarity score for retrieval",
129
+ )
130
+ reranker_model: str = Field(
131
+ default="cross-encoder/ms-marco-MiniLM-L-6-v2",
132
+ description="Cross-encoder model for reranking",
133
+ )
134
+ max_context_tokens: int = Field(
135
+ default=4096,
136
+ gt=0,
137
+ description="Maximum tokens for context in LLM prompt",
138
+ )
139
+
140
+ # Scraper Configuration
141
+ scraper_delay: float = Field(
142
+ default=1.0,
143
+ ge=0.0,
144
+ description="Delay between scraping requests in seconds",
145
+ )
146
+ scraper_max_pages: Optional[int] = Field(
147
+ default=None,
148
+ description="Maximum number of pages to scrape (None for unlimited)",
149
+ )
150
+ scraper_timeout: int = Field(
151
+ default=30,
152
+ gt=0,
153
+ description="Request timeout in seconds",
154
+ )
155
+
156
+ # API Configuration
157
+ api_host: str = Field(
158
+ default="0.0.0.0",
159
+ description="API server host",
160
+ )
161
+ api_port: int = Field(
162
+ default=8000,
163
+ gt=0,
164
+ le=65535,
165
+ description="API server port",
166
+ )
167
+ api_workers: int = Field(
168
+ default=4,
169
+ gt=0,
170
+ description="Number of API workers",
171
+ )
172
+
173
+ # Gradio UI Configuration
174
+ gradio_host: str = Field(
175
+ default="0.0.0.0",
176
+ description="Gradio UI host",
177
+ )
178
+ gradio_port: int = Field(
179
+ default=7860,
180
+ gt=0,
181
+ le=65535,
182
+ description="Gradio UI port",
183
+ )
184
+ gradio_share: bool = Field(
185
+ default=False,
186
+ description="Create public Gradio share link",
187
+ )
188
+
189
+ # Data Paths
190
+ data_raw_path: str = Field(
191
+ default="./data/raw",
192
+ description="Path to raw scraped data",
193
+ )
194
+ data_processed_path: str = Field(
195
+ default="./data/processed",
196
+ description="Path to processed documents",
197
+ )
198
+
199
+ # Logging
200
+ log_level: str = Field(
201
+ default="INFO",
202
+ description="Logging level",
203
+ )
204
+ log_file: Optional[str] = Field(
205
+ default="logs/eyewiki_rag.log",
206
+ description="Log file path",
207
+ )
208
+
209
+ def get_data_paths(self) -> dict[str, Path]:
210
+ """Get all data paths as Path objects."""
211
+ return {
212
+ "raw": Path(self.data_raw_path),
213
+ "processed": Path(self.data_processed_path),
214
+ "vectorstore": Path(self.qdrant_path),
215
+ }
216
+
217
+ def ensure_data_directories(self) -> None:
218
+ """Create data directories if they don't exist."""
219
+ for path in self.get_data_paths().values():
220
+ path.mkdir(parents=True, exist_ok=True)
221
+
222
+ # Create logs directory if log_file is specified
223
+ if self.log_file:
224
+ log_path = Path(self.log_file)
225
+ log_path.parent.mkdir(parents=True, exist_ok=True)
226
+
227
+
228
+ # Create global settings instance
229
+ settings = Settings()
data/processed/.gitkeep ADDED
File without changes
data/raw/.gitkeep ADDED
File without changes
data/vectorstore/.gitkeep ADDED
File without changes
deployment_readme.md ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deployment Guide - EyeWiki RAG on Free Hosting
2
+
3
+ This guide covers deploying the EyeWiki RAG system using free/cheap cloud services:
4
+
5
+ - **App Hosting**: Hugging Face Spaces (Docker SDK)
6
+ - **Vector Database**: Qdrant Cloud (Free Tier)
7
+ - **LLM Provider**: Groq (Free Tier) or any OpenAI-compatible API
8
+
9
+ ## Prerequisites
10
+
11
+ - A [Hugging Face](https://huggingface.co) account
12
+ - A [Qdrant Cloud](https://cloud.qdrant.io) account
13
+ - A [Groq](https://console.groq.com) account (or other OpenAI-compatible provider)
14
+
15
+ ---
16
+
17
+ ## Step 1: Set Up Qdrant Cloud
18
+
19
+ 1. Go to [Qdrant Cloud](https://cloud.qdrant.io) and create a free cluster.
20
+ 2. Once created, note down:
21
+ - **Cluster URL** (e.g., `https://abc123-xyz.aws.cloud.qdrant.io:6333`)
22
+ - **API Key** (from the cluster dashboard)
23
+ 3. You will need to index your data into the Qdrant Cloud cluster. You can do this locally:
24
+
25
+ ```bash
26
+ export QDRANT_URL="https://your-cluster-url:6333"
27
+ export QDRANT_API_KEY="your-qdrant-api-key"
28
+ python scripts/build_index.py --index-vectors
29
+ ```
30
+
31
+ ## Step 2: Get a Groq API Key
32
+
33
+ 1. Go to [Groq Console](https://console.groq.com) and sign up.
34
+ 2. Create an API key from the dashboard.
35
+ 3. Note down the API key.
36
+
37
+ ## Step 3: Deploy to Hugging Face Spaces
38
+
39
+ ### Option A: Via the HF Web UI
40
+
41
+ 1. Go to [Hugging Face Spaces](https://huggingface.co/spaces) and click **Create new Space**.
42
+ 2. Choose **Docker** as the SDK.
43
+ 3. Upload the project files (or connect a Git repo).
44
+ 4. In the Space **Settings > Variables and secrets**, add:
45
+
46
+ | Variable | Value |
47
+ |-------------------|-----------------------------------------------|
48
+ | `LLM_PROVIDER` | `openai` |
49
+ | `OPENAI_API_KEY` | `gsk_your_groq_api_key` |
50
+ | `OPENAI_BASE_URL` | `https://api.groq.com/openai/v1` |
51
+ | `OPENAI_MODEL` | `llama-3.3-70b-versatile` |
52
+ | `QDRANT_URL` | `https://your-cluster.cloud.qdrant.io:6333` |
53
+ | `QDRANT_API_KEY` | `your_qdrant_api_key` |
54
+
55
+ 5. The Space will build using `Dockerfile.deploy` and start automatically.
56
+
57
+ ### Option B: Via the HF CLI
58
+
59
+ ```bash
60
+ # Install HF CLI
61
+ pip install huggingface_hub
62
+
63
+ # Login
64
+ huggingface-cli login
65
+
66
+ # Create Space
67
+ huggingface-cli repo create eyewiki-rag --type space --space-sdk docker
68
+
69
+ # Clone and push
70
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/eyewiki-rag
71
+ cd eyewiki-rag
72
+ # Copy project files here, then:
73
+ cp /path/to/project/Dockerfile.deploy ./Dockerfile
74
+ git add . && git commit -m "Initial deployment" && git push
75
+ ```
76
+
77
+ Then add the environment variables via the web UI (Settings > Variables and secrets).
78
+
79
+ ## Step 4: Verify Deployment
80
+
81
+ Once the Space is running:
82
+
83
+ 1. Visit your Space URL (e.g., `https://your-username-eyewiki-rag.hf.space`)
84
+ 2. Check the health endpoint: `https://your-username-eyewiki-rag.hf.space/health`
85
+ 3. Try the Gradio UI: `https://your-username-eyewiki-rag.hf.space/ui`
86
+
87
+ ---
88
+
89
+ ## Environment Variables Reference
90
+
91
+ | Variable | Required | Default | Description |
92
+ |---------------------------|----------|----------------------------|----------------------------------------------------|
93
+ | `LLM_PROVIDER` | No | `ollama` | LLM provider: `ollama` or `openai` |
94
+ | `OPENAI_API_KEY` | If openai| - | API key for OpenAI-compatible provider |
95
+ | `OPENAI_BASE_URL` | No | `https://api.openai.com/v1`| Base URL for OpenAI-compatible API |
96
+ | `OPENAI_MODEL` | No | `llama-3.3-70b-versatile` | Model name for the provider |
97
+ | `OLLAMA_BASE_URL` | No | `http://localhost:11434` | Ollama API URL (only for ollama provider) |
98
+ | `LLM_MODEL` | No | `mistral` | Ollama model name (only for ollama provider) |
99
+ | `QDRANT_URL` | No | - | Qdrant Cloud cluster URL |
100
+ | `QDRANT_API_KEY` | No | - | Qdrant Cloud API key |
101
+ | `QDRANT_PATH` | No | `./data/vectorstore` | Local Qdrant path (if not using cloud) |
102
+ | `QDRANT_COLLECTION_NAME` | No | `eyewiki_rag` | Qdrant collection name |
103
+ | `EMBEDDING_MODEL` | No | `nomic-embed-text` | Sentence-transformer embedding model |
104
+ | `API_PORT` | No | `8000` | API server port |
105
+
106
+ ---
107
+
108
+ ## Provider Examples
109
+
110
+ ### Groq (Free Tier)
111
+
112
+ ```env
113
+ LLM_PROVIDER=openai
114
+ OPENAI_API_KEY=gsk_your_key_here
115
+ OPENAI_BASE_URL=https://api.groq.com/openai/v1
116
+ OPENAI_MODEL=llama-3.3-70b-versatile
117
+ ```
118
+
119
+ ### OpenAI
120
+
121
+ ```env
122
+ LLM_PROVIDER=openai
123
+ OPENAI_API_KEY=sk-your_key_here
124
+ OPENAI_MODEL=gpt-4o-mini
125
+ ```
126
+
127
+ ### DeepSeek
128
+
129
+ ```env
130
+ LLM_PROVIDER=openai
131
+ OPENAI_API_KEY=your_key_here
132
+ OPENAI_BASE_URL=https://api.deepseek.com/v1
133
+ OPENAI_MODEL=deepseek-chat
134
+ ```
135
+
136
+ ### Local Ollama (Default)
137
+
138
+ ```env
139
+ LLM_PROVIDER=ollama
140
+ OLLAMA_BASE_URL=http://localhost:11434
141
+ LLM_MODEL=mistral
142
+ ```
143
+
144
+ ---
145
+
146
+ ## Troubleshooting
147
+
148
+ - **Space fails to build**: Check that `Dockerfile.deploy` is renamed to `Dockerfile` in the Space repo.
149
+ - **Model download slow on startup**: The embedding model (`all-mpnet-base-v2`) downloads on first run. Subsequent restarts use the cached version.
150
+ - **Qdrant connection errors**: Verify your `QDRANT_URL` includes the port (`:6333`) and the API key is correct.
151
+ - **LLM errors**: Check that your API key is valid and the model name is supported by your provider.
docker-compose.yml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EyeWiki RAG System - Docker Compose Configuration
2
+ version: '3.8'
3
+
4
+ services:
5
+ # Qdrant vector database
6
+ qdrant:
7
+ image: qdrant/qdrant:latest
8
+ container_name: eyewiki-qdrant
9
+ ports:
10
+ - "6333:6333" # REST API
11
+ - "6334:6334" # gRPC (optional)
12
+ volumes:
13
+ - qdrant_data:/qdrant/storage
14
+ environment:
15
+ - QDRANT__SERVICE__GRPC_PORT=6334
16
+ networks:
17
+ - eyewiki-network
18
+ restart: unless-stopped
19
+ healthcheck:
20
+ test: ["CMD", "curl", "-f", "http://localhost:6333/"]
21
+ interval: 30s
22
+ timeout: 10s
23
+ retries: 3
24
+ start_period: 40s
25
+
26
+ # EyeWiki RAG API
27
+ eyewiki-rag:
28
+ build:
29
+ context: .
30
+ dockerfile: Dockerfile
31
+ container_name: eyewiki-rag-api
32
+ ports:
33
+ - "8000:8000"
34
+ volumes:
35
+ # Mount data directories for persistence
36
+ - ./data/raw:/app/data/raw
37
+ - ./data/processed:/app/data/processed
38
+ - qdrant_data:/app/data/qdrant
39
+ # Mount prompts for easy customization
40
+ - ./prompts:/app/prompts
41
+ environment:
42
+ # Ollama on host (access via host.docker.internal)
43
+ - OLLAMA_BASE_URL=http://host.docker.internal:11434
44
+ - LLM_MODEL=mistral
45
+ - EMBEDDING_MODEL=nomic-embed-text
46
+
47
+ # Qdrant service
48
+ - QDRANT_HOST=qdrant
49
+ - QDRANT_PORT=6333
50
+ - QDRANT_COLLECTION_NAME=eyewiki_rag
51
+ - QDRANT_PATH=/app/data/qdrant
52
+
53
+ # Processing settings
54
+ - CHUNK_SIZE=512
55
+ - CHUNK_OVERLAP=50
56
+ - MAX_CONTEXT_TOKENS=4000
57
+
58
+ # Retrieval settings
59
+ - RETRIEVAL_K=20
60
+ - RERANK_K=5
61
+ networks:
62
+ - eyewiki-network
63
+ depends_on:
64
+ qdrant:
65
+ condition: service_healthy
66
+ restart: unless-stopped
67
+ healthcheck:
68
+ test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
69
+ interval: 30s
70
+ timeout: 10s
71
+ retries: 3
72
+ start_period: 60s
73
+
74
+ networks:
75
+ eyewiki-network:
76
+ driver: bridge
77
+
78
+ volumes:
79
+ qdrant_data:
80
+ driver: local
plan/implementation_plan.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation Plan - EyeWiki RAG Deployment
2
+
3
+ This plan outlines the steps to prepare the EyeWiki RAG application for deployment on free/cheap hosting providers (specifically Hugging Face Spaces + Groq + Qdrant Cloud), by decoupling the local Ollama dependency.
4
+
5
+ ## User Review Required
6
+
7
+ > [!IMPORTANT]
8
+ > **LLM Provider Switch**: The deployment will support switching from local Ollama to "OpenAI-compatible" APIs (like Groq, DeepSeek, or OpenAI itself). This requires an API key for the chosen provider.
9
+
10
+ > [!NOTE]
11
+ > **Hosting Choice**: The recommended "free" stack is **Hugging Face Spaces (Docker)** for the app, **Qdrant Cloud (Free Tier)** for the vector DB, and **Groq (Free Tier)** for the LLM.
12
+
13
+ ## Proposed Changes
14
+
15
+ ### Configuration
16
+ #### [MODIFY] [settings.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/config/settings.py)
17
+ - Add `llm_provider` field (enum: "ollama", "openai").
18
+ - Add `openai_api_key`, `openai_base_url`, `openai_model` fields.
19
+
20
+ ### LLM Abstraction
21
+ #### [NEW] [llm_client.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/llm/llm_client.py)
22
+ - Define `LLMClient` abstract base class/protocol with `generate` and `stream_generate` methods.
23
+
24
+ #### [MODIFY] [ollama_client.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/llm/ollama_client.py)
25
+ - Implement `LLMClient` interface.
26
+
27
+ #### [NEW] [openai_client.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/llm/openai_client.py)
28
+ - Implement `LLMClient` using `openai` python package.
29
+ - Support standard OpenAI API and compatible endpoints (Groq).
30
+
31
+ ### Application Logic
32
+ #### [MODIFY] [query_engine.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/rag/query_engine.py)
33
+ - Update type hints to use abstract `LLMClient`.
34
+
35
+ #### [MODIFY] [main.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/api/main.py)
36
+ - Instantiate appropriate client based on `settings.llm_provider`.
37
+ - Update lifecycle events.
38
+
39
+ #### [MODIFY] [run_server.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/scripts/run_server.py)
40
+ - Modify pre-flight checks to only check Ollama if `llm_provider_is_ollama`.
41
+ - Add checks for API keys if provider is OpenAI/Groq.
42
+
43
+ ### Deployment Configuration
44
+ #### [NEW] [Dockerfile.deploy](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/Dockerfile.deploy)
45
+ - Optimized Dockerfile for Hugging Face Spaces (non-root user, specific cache directories).
46
+
47
+ #### [NEW] [deployment_readme.md](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/deployment_readme.md)
48
+ - Step-by-step guide for deploying to HF Spaces and setting up Qdrant Cloud.
49
+
50
+ #### [MODIFY] [requirements.txt](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/requirements.txt)
51
+ - Add `openai>=1.0.0`.
52
+
53
+ ## Verification Plan
54
+
55
+ ### Automated Tests
56
+ - Run existing tests to ensure no regression: `pytest tests/`
57
+ - *Note:* New client tests would require mocking OpenAI API, which might be out of scope for a "test deployment", but we will verify the code compiles and runs.
58
+
59
+ ### Manual Verification
60
+ 1. **Local Test (Ollama)**: Run server with `LLM_PROVIDER=ollama` and verify standard functionality.
61
+ 2. **Local Test (Mock/Groq)**: Run server with `LLM_PROVIDER=openai` and a valid API key (or mock) to verify the switch works.
62
+ 3. **Deployment Build**: Build the `Dockerfile.deploy` locally to ensure it builds correctly.
prompts/medical_disclaimer.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ **Medical Disclaimer:** This information is sourced from EyeWiki, a resource of the American Academy of Ophthalmology (AAO). It is not a substitute for professional medical advice, diagnosis, or treatment. AI systems can make errors. Always consult with a qualified ophthalmologist or eye care professional for medical concerns and verify any critical information with authoritative sources.
prompts/query_prompt.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are answering a question using information from the EyeWiki medical knowledge base.
2
+
3
+ CONTEXT FROM EYEWIKI:
4
+ {context}
5
+
6
+ ---
7
+
8
+ QUESTION: {question}
9
+
10
+ INSTRUCTIONS:
11
+ 1. Answer the question using ONLY the information provided in the context above
12
+ 2. Cite sources for all claims using the format: [Source: Article Title]
13
+ 3. If the context does not contain enough information to fully answer the question, clearly state: "The provided sources do not contain sufficient information about [specific aspect]"
14
+ 4. Organize your answer with:
15
+ - A direct answer to the question (1-2 sentences)
16
+ - Supporting details from the sources with citations
17
+ - Any relevant additional context from the sources
18
+ 5. Use clear medical terminology with explanations for technical terms
19
+ 6. Do NOT make up or infer information beyond what is explicitly stated in the context
20
+
21
+ ANSWER:
prompts/system_prompt.txt ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are an expert ophthalmology knowledge assistant powered by the EyeWiki medical database. Your role is to provide accurate, evidence-based information about eye diseases, conditions, treatments, and procedures.
2
+
3
+ CRITICAL GUIDELINES:
4
+
5
+ 1. CONTEXT-ONLY RESPONSES
6
+ - Base ALL answers strictly on the provided context from EyeWiki articles
7
+ - NEVER make up, infer, or add information that is not explicitly in the context
8
+ - If the context does not contain enough information to answer a question, clearly state this
9
+ - Do not use general knowledge or information from other sources
10
+
11
+ 2. SOURCE CITATION
12
+ - Always cite the specific EyeWiki article when referencing information
13
+ - Use the format: [Source: Article Title] or "According to [Article Title]..."
14
+ - When multiple sources support a point, cite all relevant sources
15
+ - Include section names when specific information comes from a particular section
16
+
17
+ 3. RESPONSE STRUCTURE
18
+ Your answers should follow this format:
19
+
20
+ a) Direct Answer
21
+ - Begin with a clear, concise answer to the specific question
22
+ - Use 1-2 sentences to address the core query
23
+
24
+ b) Supporting Details
25
+ - Provide relevant details, definitions, and explanations from the sources
26
+ - Use proper medical terminology, but include clear explanations for complex terms
27
+ - Organize information logically (e.g., causes, symptoms, diagnosis, treatment)
28
+
29
+ c) Additional Context (when appropriate)
30
+ - Include related information that provides valuable context
31
+ - Mention important considerations, risk factors, or variations
32
+ - Connect concepts to help understanding
33
+
34
+ d) Limitations
35
+ - If the context is incomplete, specify what information is missing
36
+ - Acknowledge when a question requires clinical judgment or patient-specific evaluation
37
+
38
+ 4. MEDICAL TERMINOLOGY
39
+ - Use accurate medical terminology as it appears in the sources
40
+ - Immediately follow technical terms with clear explanations in parentheses
41
+ - Example: "trabecular meshwork (the eye's drainage system)"
42
+ - Balance professional precision with accessibility
43
+
44
+ 5. UNCERTAINTY AND LIMITATIONS
45
+ When you cannot fully answer a question:
46
+ - Explicitly state: "The provided sources do not contain sufficient information about..."
47
+ - Offer what partial information IS available
48
+ - Suggest what type of information would be needed for a complete answer
49
+ - NEVER guess or extrapolate beyond what the sources explicitly state
50
+
51
+ 6. CLINICAL CONSULTATION REMINDER
52
+ - For questions about specific symptoms, diagnosis, or treatment decisions, remind users to consult a qualified eye care professional
53
+ - Emphasize that individual cases vary and require professional medical evaluation
54
+ - Do not provide specific medical advice for individual situations
55
+
56
+ 7. RESPONSE QUALITY
57
+ - Be thorough but concise - avoid unnecessary verbosity
58
+ - Use clear section headers for longer responses
59
+ - Present information in a logical, easy-to-follow structure
60
+ - Use bullet points or numbered lists when appropriate for clarity
61
+ - Maintain a professional yet approachable tone
62
+
63
+ 8. ACCURACY PRIORITIES
64
+ - Accuracy is more important than completeness
65
+ - It is better to say "I don't have enough information" than to speculate
66
+ - When sources conflict or present multiple perspectives, present all views and cite each
67
+ - Distinguish between established facts and areas of ongoing research or debate
68
+
69
+ EXAMPLE RESPONSE PATTERNS:
70
+
71
+ Good Response:
72
+ "Primary open-angle glaucoma (POAG) is characterized by progressive optic nerve damage and visual field loss [Source: Primary Open-Angle Glaucoma]. The primary risk factor is elevated intraocular pressure (IOP), which occurs when the eye's drainage system (trabecular meshwork) becomes less efficient at draining aqueous humor [Source: Glaucoma Pathophysiology]..."
73
+
74
+ Poor Response:
75
+ "Glaucoma is usually treated with eye drops, and most patients do well with treatment."
76
+ (No citations, no source verification, making general claims)
77
+
78
+ When Uncertain:
79
+ "The provided sources discuss glaucoma treatment options including medications and surgery [Source: Glaucoma Management], but do not contain specific information about the long-term success rates you're asking about. For detailed statistics on treatment outcomes, you would need additional clinical research data."
80
+
81
+ REMEMBER:
82
+ - You are a knowledge assistant, not a medical professional
83
+ - Your purpose is to provide information, not to diagnose or prescribe
84
+ - Every piece of information should be traceable to the provided sources
85
+ - Professional consultation is irreplaceable for medical care
86
+
87
+ Maintain these standards in every response to ensure users receive accurate, well-sourced, and appropriately contextualized medical information.
pytest.ini ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [pytest]
2
+ # Pytest configuration file
3
+
4
+ # Test discovery patterns
5
+ python_files = test_*.py
6
+ python_classes = Test*
7
+ python_functions = test_*
8
+
9
+ # Test paths
10
+ testpaths = tests
11
+
12
+ # Output options
13
+ addopts =
14
+ -v
15
+ --strict-markers
16
+ --tb=short
17
+ --disable-warnings
18
+
19
+ # Markers
20
+ markers =
21
+ unit: Unit tests (fast, isolated)
22
+ integration: Integration tests (may be slow)
23
+ api: API tests (requires server components)
24
+
25
+ # Minimum pytest version (note: pytest's `minversion` is the required pytest version, not Python)
26
+ minversion = 3.8
requirements.txt ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Web Scraping
2
+ crawl4ai>=0.3.0
3
+ beautifulsoup4>=4.12.0
4
+ markdownify>=0.11.0
5
+
6
+ # RAG Framework
7
+ llama-index>=0.10.0
8
+ llama-index-vector-stores-qdrant>=0.2.0
9
+ llama-index-embeddings-ollama>=0.1.0
10
+ llama-index-llms-ollama>=0.1.0
11
+
12
+ # Vector Storage
13
+ qdrant-client>=1.7.0
14
+
15
+ # Embeddings & Reranking
16
+ sentence-transformers>=2.2.0 # For stable embeddings and cross-encoder reranking
17
+ torch>=2.0.0 # Required by sentence-transformers
18
+
19
+ # API Server
20
+ fastapi>=0.104.0
21
+ uvicorn[standard]>=0.24.0
22
+
23
+ # UI
24
+ gradio>=4.0.0
25
+
26
+ # Configuration
27
+ python-dotenv>=1.0.0
28
+ pydantic>=2.0.0
29
+ pydantic-settings>=2.0.0
30
+
31
+ # CLI Output & Progress
32
+ rich>=13.0.0
33
+ tqdm>=4.66.0
34
+
35
+ # OpenAI-compatible API
36
+ openai>=1.0.0
37
+
38
+ # Utilities
39
+ requests>=2.31.0
40
+ aiohttp>=3.9.0
41
+
42
+ # Development
43
+ pytest>=7.4.0
44
+ pytest-asyncio>=0.21.0
45
+ black>=23.11.0
46
+ isort>=5.12.0
47
+ flake8>=6.1.0
scripts/build_index.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Build index by processing raw markdown files into semantic chunks with metadata."""
3
+
4
+ import argparse
5
+ import json
6
+ import sys
7
+ import traceback
8
+ from pathlib import Path
9
+ from typing import Dict, List
10
+
11
+ from tqdm import tqdm
12
+
13
+ # Add parent directory to path
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+ from src.processing.chunker import SemanticChunker, ChunkNode
17
+ from src.processing.metadata_extractor import MetadataExtractor
18
+ from src.vectorstore.qdrant_store import QdrantStoreManager
19
+ from src.llm.sentence_transformer_client import SentenceTransformerClient
20
+ from config.settings import settings
21
+ from rich.console import Console
22
+ from rich.panel import Panel
23
+ from rich.table import Table
24
+
25
+
26
+ def parse_args():
27
+ """Parse command line arguments."""
28
+ parser = argparse.ArgumentParser(
29
+ description="Process raw EyeWiki markdown into semantic chunks with medical metadata",
30
+ formatter_class=argparse.RawDescriptionHelpFormatter,
31
+ epilog="""
32
+ Examples:
33
+ # Just process files (no vector indexing)
34
+ python scripts/build_index.py
35
+
36
+ # Process AND build vector index
37
+ python scripts/build_index.py --index-vectors
38
+
39
+ # Only build vector index from existing processed files
40
+ python scripts/build_index.py --index-only
41
+
42
+ # Process with custom directories
43
+ python scripts/build_index.py --input-dir ./my_raw --output-dir ./my_processed
44
+
45
+ # Force rebuild with fresh Qdrant collection
46
+ python scripts/build_index.py --rebuild --index-vectors --recreate-collection
47
+
48
+ # Process only files matching pattern
49
+ python scripts/build_index.py --pattern "Glaucoma*.md" --index-vectors
50
+
51
+ # Custom chunking and embedding parameters
52
+ python scripts/build_index.py --chunk-size 1024 --embedding-batch-size 64 --index-vectors
53
+ """,
54
+ )
55
+
56
+ parser.add_argument(
57
+ "--input-dir",
58
+ type=str,
59
+ default=None,
60
+ help=f"Input directory with raw markdown files (default: {settings.data_raw_path})",
61
+ )
62
+
63
+ parser.add_argument(
64
+ "--output-dir",
65
+ type=str,
66
+ default=None,
67
+ help=f"Output directory for processed chunks (default: {settings.data_processed_path})",
68
+ )
69
+
70
+ parser.add_argument(
71
+ "--rebuild",
72
+ action="store_true",
73
+ help="Force rebuild even if output files exist",
74
+ )
75
+
76
+ parser.add_argument(
77
+ "--pattern",
78
+ type=str,
79
+ default="*.md",
80
+ help="Glob pattern for files to process (default: *.md)",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--chunk-size",
85
+ type=int,
86
+ default=None,
87
+ help=f"Chunk size in tokens (default: {settings.chunk_size})",
88
+ )
89
+
90
+ parser.add_argument(
91
+ "--chunk-overlap",
92
+ type=int,
93
+ default=None,
94
+ help=f"Chunk overlap in tokens (default: {settings.chunk_overlap})",
95
+ )
96
+
97
+ parser.add_argument(
98
+ "--min-chunk-size",
99
+ type=int,
100
+ default=None,
101
+ help=f"Minimum chunk size in tokens (default: {settings.min_chunk_size})",
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--verbose",
106
+ "-v",
107
+ action="store_true",
108
+ help="Enable verbose output with detailed error messages",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--index-vectors",
113
+ action="store_true",
114
+ help="Build vector index in Qdrant after processing",
115
+ )
116
+
117
+ parser.add_argument(
118
+ "--index-only",
119
+ action="store_true",
120
+ help="Skip processing, only build vector index from existing processed files",
121
+ )
122
+
123
+ parser.add_argument(
124
+ "--recreate-collection",
125
+ action="store_true",
126
+ help="Recreate Qdrant collection (deletes existing data)",
127
+ )
128
+
129
+ parser.add_argument(
130
+ "--embedding-batch-size",
131
+ type=int,
132
+ default=32,
133
+ help="Batch size for embedding generation (default: 32)",
134
+ )
135
+
136
+ parser.add_argument(
137
+ "--embedding-model",
138
+ type=str,
139
+ default="sentence-transformers/all-mpnet-base-v2",
140
+ help="Sentence transformer model name (default: all-mpnet-base-v2)",
141
+ )
142
+
143
+ return parser.parse_args()
144
+
145
+
146
+ def print_banner(console: Console):
147
+ """Print welcome banner."""
148
+ banner = """
149
+ [bold cyan]EyeWiki Index Builder[/bold cyan]
150
+ [dim]Processing pipeline: Markdown → Metadata Extraction → Semantic Chunking → JSON[/dim]
151
+ """
152
+ console.print(Panel(banner, border_style="cyan"))
153
+
154
+
155
+ def load_markdown_file(md_file: Path) -> tuple[str, Dict]:
156
+ """
157
+ Load markdown content and corresponding JSON metadata.
158
+
159
+ Args:
160
+ md_file: Path to markdown file
161
+
162
+ Returns:
163
+ Tuple of (content, metadata)
164
+
165
+ Raises:
166
+ FileNotFoundError: If JSON metadata file not found
167
+ ValueError: If content is empty or metadata is invalid
168
+ """
169
+ # Read markdown content
170
+ with open(md_file, "r", encoding="utf-8") as f:
171
+ content = f.read()
172
+
173
+ if not content.strip():
174
+ raise ValueError("Empty markdown content")
175
+
176
+ # Look for corresponding JSON metadata
177
+ json_file = md_file.with_suffix(".json")
178
+ if not json_file.exists():
179
+ raise FileNotFoundError(f"Metadata file not found: {json_file}")
180
+
181
+ # Read metadata
182
+ with open(json_file, "r", encoding="utf-8") as f:
183
+ metadata = json.load(f)
184
+
185
+ if not isinstance(metadata, dict):
186
+ raise ValueError("Invalid metadata format (must be dict)")
187
+
188
+ return content, metadata
189
+
190
+
191
+ def process_file(
192
+ md_file: Path,
193
+ output_dir: Path,
194
+ chunker: SemanticChunker,
195
+ extractor: MetadataExtractor,
196
+ rebuild: bool = False,
197
+ verbose: bool = False,
198
+ ) -> Dict:
199
+ """
200
+ Process a single markdown file through the pipeline.
201
+
202
+ Pipeline:
203
+ 1. Load markdown and metadata
204
+ 2. Extract medical metadata
205
+ 3. Chunk document
206
+ 4. Save chunks to JSON
207
+
208
+ Args:
209
+ md_file: Path to markdown file
210
+ output_dir: Output directory for chunks
211
+ chunker: SemanticChunker instance
212
+ extractor: MetadataExtractor instance
213
+ rebuild: Force rebuild even if output exists
214
+ verbose: Enable verbose error output
215
+
216
+ Returns:
217
+ Dictionary with processing results and statistics
218
+ """
219
+ result = {
220
+ "file": md_file.name,
221
+ "status": "pending",
222
+ "chunks_created": 0,
223
+ "total_tokens": 0,
224
+ "error": None,
225
+ }
226
+
227
+ output_file = output_dir / f"{md_file.stem}_chunks.json"
228
+
229
+ # Check if output already exists
230
+ if output_file.exists() and not rebuild:
231
+ result["status"] = "skipped"
232
+ result["error"] = "Output already exists (use --rebuild to force)"
233
+ return result
234
+
235
+ try:
236
+ # Step 1: Load file
237
+ content, metadata = load_markdown_file(md_file)
238
+
239
+ # Step 2: Extract medical metadata
240
+ enhanced_metadata = extractor.extract(content, metadata)
241
+
242
+ # Step 3: Chunk document
243
+ chunks = chunker.chunk_document(content, enhanced_metadata)
244
+
245
+ if not chunks:
246
+ result["status"] = "skipped"
247
+ result["error"] = "No chunks created (content too small or filtered)"
248
+ return result
249
+
250
+ # Step 4: Save chunks to JSON
251
+ output_dir.mkdir(parents=True, exist_ok=True)
252
+ with open(output_file, "w", encoding="utf-8") as f:
253
+ chunk_dicts = [chunk.to_dict() for chunk in chunks]
254
+ json.dump(chunk_dicts, f, indent=2, ensure_ascii=False)
255
+
256
+ # Update result
257
+ result["status"] = "success"
258
+ result["chunks_created"] = len(chunks)
259
+ result["total_tokens"] = sum(chunk.token_count for chunk in chunks)
260
+
261
+ except FileNotFoundError as e:
262
+ result["status"] = "error"
263
+ result["error"] = f"File not found: {e}"
264
+ if verbose:
265
+ result["traceback"] = traceback.format_exc()
266
+
267
+ except ValueError as e:
268
+ result["status"] = "error"
269
+ result["error"] = f"Invalid data: {e}"
270
+ if verbose:
271
+ result["traceback"] = traceback.format_exc()
272
+
273
+ except Exception as e:
274
+ result["status"] = "error"
275
+ result["error"] = f"Unexpected error: {e}"
276
+ if verbose:
277
+ result["traceback"] = traceback.format_exc()
278
+
279
+ return result
280
+
281
+
282
def print_statistics(results: List[Dict], console: Console):
    """
    Print processing statistics.

    Args:
        results: List of processing results
        console: Rich console for output
    """
    # Status tallies.
    n_total = len(results)
    n_ok = len([r for r in results if r["status"] == "success"])
    n_skipped = len([r for r in results if r["status"] == "skipped"])
    n_failed = len([r for r in results if r["status"] == "error"])

    chunk_total = sum(r["chunks_created"] for r in results)
    token_total = sum(r["total_tokens"] for r in results)

    # Guard each average against a zero denominator.
    mean_chunks = chunk_total / n_ok if n_ok > 0 else 0
    mean_tokens_per_chunk = token_total / chunk_total if chunk_total > 0 else 0
    mean_tokens_per_doc = token_total / n_ok if n_ok > 0 else 0

    # Metric/value summary table.
    table = Table(title="Processing Statistics", border_style="green")
    table.add_column("Metric", style="cyan", justify="left")
    table.add_column("Value", style="white", justify="right")

    rows = [
        ("Total Files", f"{n_total:,}"),
        ("Successfully Processed", f"{n_ok:,}"),
        ("Skipped", f"{n_skipped:,}"),
        ("Errors", f"{n_failed:,}"),
        ("", ""),  # Separator
        ("Total Chunks Created", f"{chunk_total:,}"),
        ("Total Tokens", f"{token_total:,}"),
        ("", ""),  # Separator
        ("Avg Chunks per Document", f"{mean_chunks:.1f}"),
        ("Avg Tokens per Chunk", f"{mean_tokens_per_chunk:.1f}"),
        ("Avg Tokens per Document", f"{mean_tokens_per_doc:.1f}"),
    ]
    for label, value in rows:
        table.add_row(label, value)

    console.print("\n")
    console.print(table)

    # Error details, capped at 10 entries to keep the output readable.
    failed = [r for r in results if r["status"] == "error"]
    if failed:
        console.print("\n[yellow]Error Details:[/yellow]")
        for i, result in enumerate(failed[:10], 1):
            console.print(f"  {i}. [red]{result['file']}[/red]")
            console.print(f"     [dim]{result['error']}[/dim]")
            if "traceback" in result:
                console.print(f"     [dim]{result['traceback']}[/dim]")

        if len(failed) > 10:
            console.print(f"  [dim]... and {len(failed) - 10} more errors[/dim]")

    # Skipped files are only listed when the list is short enough to be useful.
    skipped = [r for r in results if r["status"] == "skipped"]
    if skipped and len(skipped) <= 5:
        console.print("\n[yellow]Skipped Files:[/yellow]")
        for i, result in enumerate(skipped, 1):
            console.print(f"  {i}. {result['file']}: {result['error']}")
342
+
343
+
344
def load_processed_chunks(processed_dir: Path, console: Console) -> List[ChunkNode]:
    """
    Load all processed chunks from JSON files.

    Args:
        processed_dir: Directory containing processed chunk JSON files
        console: Rich console for output

    Returns:
        List of ChunkNode objects
    """
    chunk_files = list(processed_dir.glob("*_chunks.json"))

    if not chunk_files:
        console.print(f"[yellow]No processed chunk files found in {processed_dir}[/yellow]")
        return []

    console.print(f"\n[cyan]Loading processed chunks from {len(chunk_files)} files...[/cyan]")

    loaded: List[ChunkNode] = []
    with tqdm(chunk_files, desc="Loading chunks", unit="file") as pbar:
        for path in pbar:
            try:
                with open(path, "r", encoding="utf-8") as fh:
                    raw_chunks = json.load(fh)

                # Rehydrate each serialized dict into a ChunkNode.
                loaded.extend(ChunkNode.from_dict(d) for d in raw_chunks)

                pbar.set_postfix({"total_chunks": len(loaded)})

            except Exception as exc:
                # One corrupt file should not abort the whole load.
                console.print(f"[red]Error loading {path.name}: {exc}[/red]")

    console.print(f"[green]✓[/green] Loaded {len(loaded):,} chunks")
    return loaded
383
+
384
+
385
def build_vector_index(
    chunks: List[ChunkNode],
    embedding_client: SentenceTransformerClient,
    qdrant_manager: QdrantStoreManager,
    batch_size: int,
    console: Console,
) -> Dict:
    """
    Build vector index by generating embeddings and inserting into Qdrant.

    Args:
        chunks: List of ChunkNode objects
        embedding_client: SentenceTransformerClient for stable embeddings
        qdrant_manager: QdrantStoreManager for vector storage
        batch_size: Batch size for embedding generation
        console: Rich console for output

    Returns:
        Dictionary with indexing statistics
    """
    if not chunks:
        console.print("[yellow]No chunks to index[/yellow]")
        return {"chunks_indexed": 0, "time_taken": 0}

    console.print(f"\n[bold cyan]Building Vector Index[/bold cyan]")
    console.print(f"Chunks to index: {len(chunks):,}")
    console.print(f"Embedding batch size: {batch_size}")

    import time
    started = time.time()

    # Embed the raw chunk text.
    texts = [chunk.content for chunk in chunks]

    console.print("\n[cyan]Generating embeddings...[/cyan]")
    try:
        dense_vectors = embedding_client.embed_batch(
            texts=texts,
            batch_size=batch_size,
            show_progress=True,
        )
    except Exception as e:
        console.print(f"[red]Failed to generate embeddings: {e}[/red]")
        raise

    console.print("\n[cyan]Inserting into Qdrant...[/cyan]")
    try:
        num_added = qdrant_manager.add_documents(
            chunks=chunks,
            dense_embeddings=dense_vectors,
        )
    except Exception as e:
        console.print(f"[red]Failed to insert into Qdrant: {e}[/red]")
        raise

    elapsed = time.time() - started

    # Collection info is best-effort: indexing has already succeeded,
    # so a failure here only degrades the reported statistics.
    try:
        collection_info = qdrant_manager.get_collection_info()
    except Exception as e:
        console.print(f"[yellow]Could not get collection info: {e}[/yellow]")
        collection_info = {}

    return {
        "chunks_indexed": num_added,
        "time_taken": elapsed,
        "chunks_per_second": num_added / elapsed if elapsed > 0 else 0,
        "collection_info": collection_info,
    }
459
+
460
+
461
def print_index_statistics(stats: Dict, console: Console):
    """
    Print vector indexing statistics.

    Args:
        stats: Statistics dictionary
        console: Rich console for output
    """
    table = Table(title="Vector Index Statistics", border_style="green")
    table.add_column("Metric", style="cyan", justify="left")
    table.add_column("Value", style="white", justify="right")

    for label, value in (
        ("Chunks Indexed", f"{stats['chunks_indexed']:,}"),
        ("Time Taken", f"{stats['time_taken']:.1f}s"),
        ("Chunks/Second", f"{stats['chunks_per_second']:.1f}"),
    ):
        table.add_row(label, value)

    # Collection details are optional (missing/empty when lookup failed).
    info = stats.get("collection_info")
    if info:
        table.add_row("", "")  # Separator
        table.add_row("Collection Name", info.get("name", "N/A"))
        table.add_row("Total Vectors", f"{info.get('vectors_count', 0):,}")
        table.add_row("Total Points", f"{info.get('points_count', 0):,}")
        table.add_row("Status", info.get("status", "N/A"))

    console.print("\n")
    console.print(table)
487
+
488
+
489
def main():
    """Main entry point for index building.

    Runs up to two phases driven by CLI flags:
      1. Processing: chunk markdown files from the input directory and write
         per-file ``*_chunks.json`` artifacts to the output directory
         (skipped with --index-only).
      2. Vector indexing: embed the processed chunks and insert them into
         Qdrant (enabled by --index-vectors or --index-only).

    Returns:
        Process exit code: 0 on success, 1 on failure, 130 on Ctrl+C
        during indexing.
    """
    args = parse_args()
    console = Console()

    # Print banner
    print_banner(console)

    # Prepare directories; CLI arguments override the configured defaults.
    input_dir = Path(args.input_dir) if args.input_dir else Path(settings.data_raw_path)
    output_dir = Path(args.output_dir) if args.output_dir else Path(settings.data_processed_path)

    # Check mode: --index-only implies indexing without reprocessing.
    index_only = args.index_only
    should_index = args.index_vectors or args.index_only

    # Print mode
    if index_only:
        console.print("[cyan]Mode:[/cyan] Index only (skip processing)")
    elif should_index:
        console.print("[cyan]Mode:[/cyan] Process and build vector index")
    else:
        console.print("[cyan]Mode:[/cyan] Process only (no vector indexing)")

    # Validate input directory (only needed if not index-only)
    if not index_only and not input_dir.exists():
        console.print(f"[bold red]Error: Input directory does not exist: {input_dir}[/bold red]")
        return 1

    # Validate output directory exists (needed for index-only)
    if index_only and not output_dir.exists():
        console.print(f"[bold red]Error: Output directory does not exist: {output_dir}[/bold red]")
        console.print("[yellow]Please run processing first without --index-only[/yellow]")
        return 1

    # Print configuration
    if not index_only:
        # Find all markdown files
        md_files = list(input_dir.glob(args.pattern))

        if not md_files:
            console.print(f"[yellow]No files matching pattern '{args.pattern}' found in {input_dir}[/yellow]")
            return 0

        console.print(f"[cyan]Input directory:[/cyan] {input_dir}")
        console.print(f"[cyan]Output directory:[/cyan] {output_dir}")
        console.print(f"[cyan]Files found:[/cyan] {len(md_files)}")
        console.print(f"[cyan]Pattern:[/cyan] {args.pattern}")
        console.print(f"[cyan]Rebuild mode:[/cyan] {'Yes' if args.rebuild else 'No'}")
    else:
        console.print(f"[cyan]Processed directory:[/cyan] {output_dir}")

    # Initialize components (only if processing)
    results = []

    if not index_only:
        # CLI values override settings; None means "use the configured default".
        chunker = SemanticChunker(
            chunk_size=args.chunk_size if args.chunk_size is not None else settings.chunk_size,
            chunk_overlap=args.chunk_overlap if args.chunk_overlap is not None else settings.chunk_overlap,
            min_chunk_size=args.min_chunk_size if args.min_chunk_size is not None else settings.min_chunk_size,
        )

        extractor = MetadataExtractor()

        console.print(f"[cyan]Chunk size:[/cyan] {chunker.chunk_size} tokens")
        console.print(f"[cyan]Chunk overlap:[/cyan] {chunker.chunk_overlap} tokens")
        console.print(f"[cyan]Min chunk size:[/cyan] {chunker.min_chunk_size} tokens")
        console.print()

        # Process files with progress bar
        console.print("[bold cyan]Processing Files...[/bold cyan]\n")

        with tqdm(
            total=len(md_files),
            desc="Processing",
            unit="file",
            ncols=100,
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
        ) as pbar:

            for md_file in md_files:
                # Update progress bar description (name truncated/padded to 30 chars)
                pbar.set_description(f"Processing {md_file.name[:30]:30}")

                # Process file
                result = process_file(
                    md_file=md_file,
                    output_dir=output_dir,
                    chunker=chunker,
                    extractor=extractor,
                    rebuild=args.rebuild,
                    verbose=args.verbose,
                )

                results.append(result)

                # Update progress bar postfix with running stats
                successful = sum(1 for r in results if r["status"] == "success")
                chunks = sum(r["chunks_created"] for r in results)
                pbar.set_postfix({"success": successful, "chunks": chunks})

                pbar.update(1)

        # Print statistics
        print_statistics(results, console)

        # Check processing status
        successful = sum(1 for r in results if r["status"] == "success")
        errors = sum(1 for r in results if r["status"] == "error")

        console.print()
        if errors == 0 and successful > 0:
            console.print("[bold green]Processing completed successfully![/bold green]")
            console.print(f"[green]Processed files saved to: {output_dir}[/green]")
        elif successful > 0:
            console.print("[bold yellow]Processing completed with some errors.[/bold yellow]")
            console.print(f"[yellow]Processed files saved to: {output_dir}[/yellow]")
        else:
            console.print("[bold red]Processing failed - no files were processed successfully.[/bold red]")
            # NOTE(review): when processing fails entirely but indexing was
            # requested, execution falls through and still attempts to index
            # previously processed chunks on disk — confirm this is intended.
            if not should_index:
                return 1

    # Vector indexing phase
    if should_index:
        try:
            # Initialize embedding client with sentence-transformers
            console.print("\n[bold cyan]Initializing Sentence Transformers Client...[/bold cyan]")
            try:
                embedding_client = SentenceTransformerClient(model_name=args.embedding_model)
                model_info = embedding_client.get_model_info()
                console.print(f"[green]✓[/green] Loaded model: {model_info['model_name']}")
                console.print(f"[green]✓[/green] Device: {model_info['device']}")
                console.print(f"[green]✓[/green] Embedding dimension: {model_info['embedding_dim']}")
            except Exception as e:
                console.print(f"[bold red]Failed to initialize Sentence Transformers: {e}[/bold red]")
                console.print("[yellow]Install sentence-transformers: pip install sentence-transformers torch[/yellow]")
                return 1

            # Initialize Qdrant store
            console.print("\n[bold cyan]Initializing Qdrant Store...[/bold cyan]")
            try:
                qdrant_manager = QdrantStoreManager()
                qdrant_manager.initialize_collection(recreate=args.recreate_collection)
            except Exception as e:
                console.print(f"[bold red]Failed to initialize Qdrant: {e}[/bold red]")
                return 1

            # Load processed chunks from the output directory written above
            # (or by a previous run when --index-only is used).
            chunks = load_processed_chunks(output_dir, console)

            if not chunks:
                console.print("[yellow]No chunks to index. Please process documents first.[/yellow]")
                return 0

            # Build vector index
            try:
                index_stats = build_vector_index(
                    chunks=chunks,
                    embedding_client=embedding_client,
                    qdrant_manager=qdrant_manager,
                    batch_size=args.embedding_batch_size,
                    console=console,
                )

                # Print index statistics
                print_index_statistics(index_stats, console)

                console.print("\n[bold green]Vector indexing completed successfully![/bold green]")

            except Exception as e:
                console.print(f"\n[bold red]Vector indexing failed: {e}[/bold red]")
                if args.verbose:
                    traceback.print_exc()
                return 1

        except KeyboardInterrupt:
            # 130 is the conventional exit code for SIGINT.
            console.print("\n[yellow]Indexing interrupted by user (Ctrl+C)[/yellow]")
            return 130

    return 0
669
+
670
+
671
if __name__ == "__main__":
    # Translate main()'s return value into a process exit code, keeping
    # Ctrl+C (130) and unexpected failures (1) distinct.
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        Console().print("\n[yellow]Process interrupted by user (Ctrl+C)[/yellow]")
        sys.exit(130)
    except Exception as e:
        Console().print(f"\n[bold red]Fatal error: {e}[/bold red]")
        traceback.print_exc()
        sys.exit(1)
scripts/evaluate.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Evaluation script for EyeWiki RAG system.
4
+
5
+ Evaluates the system on a set of test questions and measures:
6
+ - Retrieval recall (relevant sources retrieved)
7
+ - Answer relevance (expected topics covered)
8
+ - Source citation accuracy
9
+
10
+ Usage:
11
+ python scripts/evaluate.py
12
+ python scripts/evaluate.py --questions tests/custom_questions.json
13
+ python scripts/evaluate.py --output results/eval_results.json
14
+ """
15
+
16
+ import argparse
17
+ import json
18
+ import sys
19
+ import time
20
+ from pathlib import Path
21
+ from typing import Dict, List, Any
22
+
23
+ from rich.console import Console
24
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
25
+ from rich.table import Table
26
+ from rich.panel import Panel
27
+
28
+ # Add project root to path
29
+ project_root = Path(__file__).parent.parent
30
+ sys.path.insert(0, str(project_root))
31
+
32
+ from config.settings import Settings
33
+ from src.llm.ollama_client import OllamaClient
34
+ from src.rag.query_engine import EyeWikiQueryEngine
35
+ from src.rag.reranker import CrossEncoderReranker
36
+ from src.rag.retriever import HybridRetriever
37
+ from src.vectorstore.qdrant_store import QdrantStoreManager
38
+
39
+
40
+ console = Console()
41
+
42
+
43
+ # ============================================================================
44
+ # Evaluation Metrics
45
+ # ============================================================================
46
+
47
def calculate_retrieval_recall(
    retrieved_sources: List[str],
    expected_sources: List[str],
) -> float:
    """
    Calculate retrieval recall.

    Recall = (# of expected sources retrieved) / (# of expected sources)

    Matching is case-insensitive and allows partial matches: an expected
    title counts as retrieved when it contains, or is contained in, any
    retrieved title.

    Args:
        retrieved_sources: List of retrieved source titles
        expected_sources: List of expected source titles

    Returns:
        Recall score (0-1)
    """
    if not expected_sources:
        # Nothing was expected, so nothing can be missed.
        return 1.0

    # Lowercase sets for case-insensitive, deduplicated comparison.
    retrieved_lower = {s.lower() for s in retrieved_sources}
    expected_lower = {s.lower() for s in expected_sources}

    hits = sum(
        1
        for expected in expected_lower
        if any(expected in retrieved or retrieved in expected for retrieved in retrieved_lower)
    )

    return hits / len(expected_sources)
81
+
82
+
83
def calculate_answer_relevance(
    answer: str,
    expected_topics: List[str],
) -> float:
    """
    Calculate answer relevance based on topic coverage.

    Relevance = (# of expected topics found) / (# of expected topics)

    A topic counts as covered when it appears as a case-insensitive
    substring of the answer.

    Args:
        answer: Generated answer text
        expected_topics: List of expected topic keywords

    Returns:
        Relevance score (0-1)
    """
    if not expected_topics:
        # No topics expected: vacuously relevant.
        return 1.0

    text = answer.lower()
    covered = [topic for topic in expected_topics if topic.lower() in text]
    return len(covered) / len(expected_topics)
109
+
110
+
111
def calculate_citation_accuracy(
    answer: str,
    cited_sources: List[str],
    expected_sources: List[str],
) -> Dict[str, float]:
    """
    Calculate citation accuracy metrics.

    Precision = fraction of unique cited sources that match some expected source.
    Recall    = fraction of unique expected sources matched by some cited source.
    Matching is case-insensitive and allows partial matches (one title
    contained in the other).

    Args:
        answer: Generated answer text
        cited_sources: Sources returned by system
        expected_sources: Expected sources

    Returns:
        Dictionary with citation metrics:
        "has_explicit_citations", "precision", "recall", "f1"
    """
    # Check if answer contains explicit citations
    has_citations = "[Source:" in answer or "According to" in answer

    if cited_sources and expected_sources:
        cited_set = {s.lower() for s in cited_sources}
        expected_set = {s.lower() for s in expected_sources}

        def _match(a: str, b: str) -> bool:
            # Partial match: either title contained in the other.
            return a in b or b in a

        # BUG FIX: the original counted matches over the deduplicated sets but
        # divided by the raw list lengths, which understated precision when
        # citations were duplicated; it also counted *cited* hits against the
        # *expected* denominator, letting recall exceed 1.0 when several cited
        # titles matched the same expected title. Count and divide over the
        # same deduplicated sets, per-side.
        cited_hits = sum(
            1 for cited in cited_set if any(_match(cited, exp) for exp in expected_set)
        )
        expected_hits = sum(
            1 for exp in expected_set if any(_match(exp, cited) for cited in cited_set)
        )

        precision = cited_hits / len(cited_set)
        recall = expected_hits / len(expected_set)

        # F1 score (harmonic mean), guarded against division by zero.
        f1 = (
            2 * (precision * recall) / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
    else:
        # No citations or no expectations: all scores are zero by convention.
        precision = 0.0
        recall = 0.0
        f1 = 0.0

    return {
        "has_explicit_citations": has_citations,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
163
+
164
+
165
+ # ============================================================================
166
+ # Question Evaluation
167
+ # ============================================================================
168
+
169
def evaluate_question(
    question_data: Dict[str, Any],
    query_engine: EyeWikiQueryEngine,
) -> Dict[str, Any]:
    """
    Evaluate a single question.

    Queries the engine, then scores the response on retrieval recall,
    answer relevance (topic coverage) and citation accuracy. Any exception
    raised while querying or scoring is captured into a failure record
    rather than propagated.

    Args:
        question_data: Question data with expected answers; must contain
            "id", "question", "expected_topics" and "expected_sources"
            ("category" is optional and defaults to "unknown").
        query_engine: Query engine instance

    Returns:
        Evaluation results dictionary. On success it holds "answer",
        "confidence", "query_time", per-question "metrics" and "details";
        on failure it holds "error" instead, with "success" set to False.
    """
    # NOTE(review): the required keys are read outside the try block, so a
    # malformed question record raises KeyError here instead of producing a
    # failure result — confirm that is intended.
    question_id = question_data["id"]
    question = question_data["question"]
    expected_topics = question_data["expected_topics"]
    expected_sources = question_data["expected_sources"]

    # Query the system
    start_time = time.time()
    try:
        response = query_engine.query(
            question=question,
            include_sources=True,
        )
        query_time = time.time() - start_time

        # Extract retrieved sources
        retrieved_sources = [s.title for s in response.sources]

        # Calculate metrics
        retrieval_recall = calculate_retrieval_recall(
            retrieved_sources, expected_sources
        )

        answer_relevance = calculate_answer_relevance(
            response.answer, expected_topics
        )

        citation_metrics = calculate_citation_accuracy(
            response.answer, retrieved_sources, expected_sources
        )

        # Detailed topic analysis: case-insensitive substring presence of
        # each expected topic keyword in the generated answer.
        topics_found = [
            topic for topic in expected_topics if topic.lower() in response.answer.lower()
        ]
        topics_missing = [
            topic
            for topic in expected_topics
            if topic.lower() not in response.answer.lower()
        ]

        # Source analysis: partial, case-insensitive matching in either
        # direction (expected title contained in retrieved title or vice versa).
        sources_retrieved = []
        sources_missing = []

        for expected in expected_sources:
            found = False
            for retrieved in retrieved_sources:
                if expected.lower() in retrieved.lower() or retrieved.lower() in expected.lower():
                    sources_retrieved.append(expected)
                    found = True
                    break
            if not found:
                sources_missing.append(expected)

        result = {
            "id": question_id,
            "question": question,
            "category": question_data.get("category", "unknown"),
            "answer": response.answer,
            "confidence": response.confidence,
            "query_time": query_time,
            "metrics": {
                "retrieval_recall": retrieval_recall,
                "answer_relevance": answer_relevance,
                "citation_precision": citation_metrics["precision"],
                "citation_recall": citation_metrics["recall"],
                "citation_f1": citation_metrics["f1"],
            },
            "details": {
                "retrieved_sources": retrieved_sources,
                "expected_sources": expected_sources,
                "sources_retrieved": sources_retrieved,
                "sources_missing": sources_missing,
                "topics_found": topics_found,
                "topics_missing": topics_missing,
                "has_explicit_citations": citation_metrics["has_explicit_citations"],
            },
            "success": True,
        }

    except Exception as e:
        # Failure record: keeps the evaluation loop running and preserves
        # how long the attempt took before it failed.
        result = {
            "id": question_id,
            "question": question,
            "category": question_data.get("category", "unknown"),
            "error": str(e),
            "query_time": time.time() - start_time,
            "success": False,
        }

    return result
274
+
275
+
276
+ # ============================================================================
277
+ # Aggregate Analysis
278
+ # ============================================================================
279
+
280
def calculate_aggregate_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Calculate aggregate metrics across all questions.

    Only successful evaluations contribute to the averages; failures are
    counted but otherwise ignored.

    Args:
        results: List of evaluation results

    Returns:
        Aggregate metrics, or {"error": ...} when nothing succeeded.
    """
    ok = [r for r in results if r["success"]]

    if not ok:
        return {"error": "No successful evaluations"}

    n = len(ok)

    def metric_mean(key: str) -> float:
        # Mean of one per-question metric over the successful results.
        return sum(r["metrics"][key] for r in ok) / n

    # How many answers carried explicit citations.
    citations_present = sum(1 for r in ok if r["details"]["has_explicit_citations"])

    # Per-category sums, then converted to means below.
    categories: Dict[str, Dict[str, Any]] = {}
    for r in ok:
        bucket = categories.setdefault(
            r["category"],
            {"count": 0, "retrieval_recall": 0, "answer_relevance": 0},
        )
        bucket["count"] += 1
        bucket["retrieval_recall"] += r["metrics"]["retrieval_recall"]
        bucket["answer_relevance"] += r["metrics"]["answer_relevance"]

    for bucket in categories.values():
        bucket["retrieval_recall"] /= bucket["count"]
        bucket["answer_relevance"] /= bucket["count"]

    return {
        "total_questions": len(results),
        "successful": n,
        "failed": len(results) - n,
        "metrics": {
            "retrieval_recall": metric_mean("retrieval_recall"),
            "answer_relevance": metric_mean("answer_relevance"),
            "citation_precision": metric_mean("citation_precision"),
            "citation_recall": metric_mean("citation_recall"),
            "citation_f1": metric_mean("citation_f1"),
            "avg_confidence": sum(r["confidence"] for r in ok) / n,
            "avg_query_time": sum(r["query_time"] for r in ok) / n,
            "citation_rate": citations_present / n,
        },
        "by_category": categories,
    }
365
+
366
+
367
+ # ============================================================================
368
+ # Output Functions
369
+ # ============================================================================
370
+
371
def print_question_result(result: Dict[str, Any]):
    """Print result for a single question."""
    if not result["success"]:
        console.print(
            f"\n[red]✗ {result['id']}: {result['question']}[/red]",
            f"[red]Error: {result['error']}[/red]",
        )
        return

    # Compact label/value table for the per-question metrics.
    table = Table(show_header=False, box=None, padding=(0, 1))
    table.add_column(style="cyan")
    table.add_column(style="yellow")

    metrics = result["metrics"]
    for label, value in (
        ("Retrieval Recall", f"{metrics['retrieval_recall']:.2%}"),
        ("Answer Relevance", f"{metrics['answer_relevance']:.2%}"),
        ("Citation F1", f"{metrics['citation_f1']:.2%}"),
        ("Confidence", f"{result['confidence']:.2%}"),
        ("Query Time", f"{result['query_time']:.2f}s"),
    ):
        table.add_row(label, value)

    # Pass/partial/fail banner from the mean of retrieval and relevance.
    avg_score = (metrics["retrieval_recall"] + metrics["answer_relevance"]) / 2
    if avg_score >= 0.8:
        status = "[green]✓ PASS[/green]"
    elif avg_score >= 0.6:
        status = "[yellow]~ PARTIAL[/yellow]"
    else:
        status = "[red]✗ FAIL[/red]"

    console.print(f"\n{status} [bold]{result['id']}:[/bold] {result['question']}")
    console.print(table)

    # Surface what was missed, when anything was.
    details = result["details"]
    if details["topics_missing"]:
        console.print(
            f"  [dim]Missing topics: {', '.join(details['topics_missing'])}[/dim]"
        )
    if details["sources_missing"]:
        console.print(
            f"  [dim]Missing sources: {', '.join(details['sources_missing'])}[/dim]"
        )
414
+
415
+
416
def print_aggregate_results(aggregate: Dict[str, Any]):
    """Print aggregate results."""
    console.print("\n")
    console.print(
        Panel.fit(
            "[bold cyan]Evaluation Summary[/bold cyan]",
            border_style="cyan",
        )
    )

    # Overall metrics table
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Metric", style="cyan")
    table.add_column("Score", style="yellow", justify="right")
    table.add_column("Grade", style="green", justify="center")

    metrics = aggregate["metrics"]

    def get_grade(score: float) -> str:
        # Letter grade thresholds: A>=0.9, B>=0.8, C>=0.7, D>=0.6, else F.
        for cutoff, grade in (
            (0.9, "[green]A[/green]"),
            (0.8, "[green]B[/green]"),
            (0.7, "[yellow]C[/yellow]"),
            (0.6, "[yellow]D[/yellow]"),
        ):
            if score >= cutoff:
                return grade
        return "[red]F[/red]"

    for label, key in (
        ("Retrieval Recall", "retrieval_recall"),
        ("Answer Relevance", "answer_relevance"),
        ("Citation Precision", "citation_precision"),
        ("Citation Recall", "citation_recall"),
        ("Citation F1", "citation_f1"),
    ):
        table.add_row(label, f"{metrics[key]:.2%}", get_grade(metrics[key]))

    console.print(table)

    # Run-level statistics
    console.print(f"\n[bold]Statistics:[/bold]")
    console.print(
        f"  Total Questions: {aggregate['total_questions']}",
        f"  Successful: [green]{aggregate['successful']}[/green]",
        f"  Failed: [red]{aggregate['failed']}[/red]",
        f"  Avg Confidence: {metrics['avg_confidence']:.2%}",
        f"  Avg Query Time: {metrics['avg_query_time']:.2f}s",
        f"  Citation Rate: {metrics['citation_rate']:.2%}",
    )

    # Category breakdown
    if aggregate["by_category"]:
        console.print(f"\n[bold]Performance by Category:[/bold]")
        cat_table = Table(show_header=True, header_style="bold magenta")
        cat_table.add_column("Category", style="cyan")
        cat_table.add_column("Count", justify="right")
        cat_table.add_column("Retrieval", justify="right")
        cat_table.add_column("Relevance", justify="right")

        for category, data in sorted(aggregate["by_category"].items()):
            cat_table.add_row(
                category,
                str(data["count"]),
                f"{data['retrieval_recall']:.2%}",
                f"{data['answer_relevance']:.2%}",
            )

        console.print(cat_table)
503
+
504
+
505
+ # ============================================================================
506
+ # Main Evaluation
507
+ # ============================================================================
508
+
509
def load_test_questions(questions_file: Path) -> List[Dict[str, Any]]:
    """Load test questions from JSON file.

    Args:
        questions_file: Path to the JSON file holding the question list.

    Returns:
        List of question dictionaries.

    Exits the process with status 1 when the file does not exist.
    """
    if not questions_file.exists():
        console.print(f"[red]Error: Questions file not found: {questions_file}[/red]")
        sys.exit(1)

    # FIX: declare the encoding explicitly instead of relying on the platform
    # default — the rest of this pipeline reads/writes its JSON as UTF-8.
    with open(questions_file, "r", encoding="utf-8") as f:
        questions = json.load(f)

    console.print(f"[green]✓[/green] Loaded {len(questions)} test questions")
    return questions
520
+
521
+
522
def initialize_system() -> EyeWikiQueryEngine:
    """Initialize the RAG system and return a ready-to-use query engine."""
    console.print("[bold]Initializing RAG system...[/bold]")

    # Load settings
    settings = Settings()

    # LLM + embedding access via Ollama.
    ollama_client = OllamaClient(
        base_url=settings.ollama_base_url,
        llm_model=settings.llm_model,
        embedding_model=settings.embedding_model,
    )

    # Vector store backing hybrid retrieval.
    qdrant_manager = QdrantStoreManager(
        collection_name=settings.qdrant_collection_name,
        qdrant_path=settings.qdrant_path,
        vector_size=settings.embedding_dim,
    )

    retriever = HybridRetriever(
        qdrant_manager=qdrant_manager,
        ollama_client=ollama_client,
    )

    reranker = CrossEncoderReranker(
        model_name=settings.reranker_model,
    )

    # Prompt files are optional: pass None for any that are absent.
    prompts_dir = project_root / "prompts"
    prompt_kwargs = {
        key: (path if path.exists() else None)
        for key, path in (
            ("system_prompt_path", prompts_dir / "system_prompt.txt"),
            ("query_prompt_path", prompts_dir / "query_prompt.txt"),
            ("disclaimer_path", prompts_dir / "medical_disclaimer.txt"),
        )
    }

    query_engine = EyeWikiQueryEngine(
        retriever=retriever,
        reranker=reranker,
        llm_client=ollama_client,
        max_context_tokens=settings.max_context_tokens,
        retrieval_k=20,
        rerank_k=5,
        **prompt_kwargs,
    )

    console.print("[green]✓[/green] System initialized\n")
    return query_engine
571
+
572
+
573
def run_evaluation(
    questions_file: Path,
    output_file: "Path | None" = None,
    verbose: bool = False,
):
    """
    Run evaluation on test questions.

    Args:
        questions_file: Path to test questions JSON
        output_file: Optional path to save results (JSON); None disables saving
        verbose: Print detailed results for each question as it completes
    """
    console.print(
        Panel.fit(
            "[bold blue]EyeWiki RAG Evaluation[/bold blue]",
            border_style="blue",
        )
    )

    # Load questions
    questions = load_test_questions(questions_file)

    # Initialize system
    query_engine = initialize_system()

    # Evaluate questions one by one with a progress bar
    results = []
    console.print("[bold]Evaluating questions...[/bold]\n")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:

        task = progress.add_task("Processing...", total=len(questions))

        for question_data in questions:
            result = evaluate_question(question_data, query_engine)
            results.append(result)

            if verbose:
                print_question_result(result)

            progress.update(task, advance=1)

    # Calculate aggregate metrics
    aggregate = calculate_aggregate_metrics(results)

    # Print per-question results here only if they were not already shown
    # live in verbose mode.
    if not verbose:
        console.print("\n[bold]Per-Question Results:[/bold]")
        for result in results:
            print_question_result(result)

    print_aggregate_results(aggregate)

    # Save results
    if output_file:
        output_data = {
            "results": results,
            "aggregate": aggregate,
            "timestamp": time.time(),
        }

        output_file.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding avoids platform-dependent defaults.
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)

        console.print(f"\n[green]✓[/green] Results saved to {output_file}")
647
+
648
+
649
def main():
    """CLI entry point: parse arguments, then run the evaluation."""
    arg_parser = argparse.ArgumentParser(
        description="Evaluate EyeWiki RAG system on test questions"
    )
    arg_parser.add_argument(
        "--questions",
        type=Path,
        default=project_root / "tests" / "test_questions.json",
        help="Path to test questions JSON file",
    )
    arg_parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Path to save evaluation results (JSON)",
    )
    arg_parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Print detailed results for each question",
    )
    args = arg_parser.parse_args()

    try:
        run_evaluation(
            questions_file=args.questions,
            output_file=args.output,
            verbose=args.verbose,
        )
    except KeyboardInterrupt:
        # Ctrl+C: report and exit non-zero without a traceback.
        console.print("\n[yellow]Evaluation interrupted by user[/yellow]")
        sys.exit(1)
    except Exception as e:
        console.print(f"\n[red]Error: {e}[/red]")
        import traceback

        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
scripts/run_server.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Server startup script with pre-flight checks.
4
+
5
+ Usage:
6
+ python scripts/run_server.py
7
+ python scripts/run_server.py --port 8080 --reload
8
+ python scripts/run_server.py --host 0.0.0.0 --port 8000
9
+ """
10
+
11
+ import argparse
12
+ import sys
13
+ import time
14
+ from pathlib import Path
15
+
16
+ import requests
17
+ from rich.console import Console
18
+ from rich.panel import Panel
19
+ from rich.table import Table
20
+
21
+ # Add project root to path
22
+ project_root = Path(__file__).parent.parent
23
+ sys.path.insert(0, str(project_root))
24
+
25
+ from config.settings import LLMProvider, Settings
26
+
27
+
28
+ console = Console()
29
+
30
+
31
def parse_args():
    """Parse and return command line arguments for the server launcher."""
    parser = argparse.ArgumentParser(
        description="Start EyeWiki RAG API server with pre-flight checks"
    )

    # (flags, keyword options) pairs for every supported option.
    option_specs = [
        (
            "--host",
            dict(type=str, default="0.0.0.0", help="Host to bind (default: 0.0.0.0)"),
        ),
        (
            "--port",
            dict(type=int, default=8000, help="Port number (default: 8000)"),
        ),
        (
            "--reload",
            dict(action="store_true", help="Enable hot reload for development"),
        ),
        (
            "--skip-checks",
            dict(action="store_true", help="Skip pre-flight checks (not recommended)"),
        ),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
64
+
65
+
66
def print_header():
    """Render the startup banner panel, padded with blank lines."""
    banner = (
        "[bold blue]EyeWiki RAG API Server[/bold blue]\n"
        "[dim]Retrieval-Augmented Generation for Medical Knowledge[/dim]"
    )
    console.print()
    console.print(Panel.fit(banner, border_style="blue"))
    console.print()
77
+
78
+
79
def check_ollama(settings: Settings) -> bool:
    """
    Check if Ollama is running and has required models.

    Args:
        settings: Application settings (provides base URL and model names)

    Returns:
        True if Ollama responds and every required model is available
    """
    console.print("[bold cyan]1. Checking Ollama service...[/bold cyan]")

    try:
        # Check if Ollama is running
        response = requests.get(f"{settings.ollama_base_url}/api/tags", timeout=5)
        response.raise_for_status()

        models_data = response.json()
        available_models = [model["name"] for model in models_data.get("models", [])]

        # Check for required LLM model (embedding model is sentence-transformers, not Ollama)
        required_models = {
            "LLM": settings.llm_model,
        }

        def _is_available(name: str) -> bool:
            # Loose substring match so "llama3" matches "llama3:latest" and
            # vice versa (tags are optional in both places).
            return any(name in m or m in name for m in available_models)

        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Model Type", style="cyan")
        table.add_column("Required Model", style="yellow")
        table.add_column("Status", style="green")

        # Compute availability once; previously it was re-checked in the
        # missing-model hint loop below.
        missing = []
        for model_type, model_name in required_models.items():
            found = _is_available(model_name)
            status = "[green]✓ Found[/green]" if found else "[red]✗ Missing[/red]"
            table.add_row(model_type, model_name, status)
            if not found:
                missing.append(model_name)

        console.print(table)

        if missing:
            console.print(
                "\n[red]Error:[/red] Some required models are missing. "
                "Pull them with:"
            )
            for model_name in missing:
                console.print(f"  [yellow]ollama pull {model_name}[/yellow]")
            console.print()
            return False

        console.print("[green]✓ Ollama is running with all required models[/green]\n")
        return True

    except requests.RequestException as e:
        console.print(f"[red]✗ Failed to connect to Ollama:[/red] {e}")
        console.print(
            f"\nMake sure Ollama is running at [yellow]{settings.ollama_base_url}[/yellow]"
        )
        console.print("Start it with: [yellow]ollama serve[/yellow]\n")
        return False
148
+
149
+
150
def check_openai_config(settings: Settings) -> bool:
    """
    Check if OpenAI-compatible API is configured with required API key.

    Args:
        settings: Application settings

    Returns:
        True if check passed, False otherwise
    """
    console.print("[bold cyan]1. Checking OpenAI-compatible API configuration...[/bold cyan]")

    summary = Table(show_header=True, header_style="bold magenta")
    summary.add_column("Property", style="cyan")
    summary.add_column("Value", style="yellow")
    summary.add_column("Status", style="green")

    api_key = settings.openai_api_key
    if api_key:
        # Show only a prefix of the key so it is never fully printed.
        summary.add_row("API Key", f"{api_key[:8]}...", "[green]✓ Set[/green]")
    else:
        summary.add_row("API Key", "(not set)", "[red]✗ Missing[/red]")

    summary.add_row("Base URL", settings.openai_base_url or "(OpenAI default)", "[green]✓[/green]")
    summary.add_row("Model", settings.openai_model, "[green]✓[/green]")

    console.print(summary)

    if not api_key:
        console.print(
            "\n[red]Error:[/red] API key is required for OpenAI-compatible provider."
        )
        console.print(
            "Set the [yellow]OPENAI_API_KEY[/yellow] environment variable or add it to your [yellow].env[/yellow] file.\n"
        )
        return False

    console.print("[green]✓ OpenAI-compatible API configuration looks good[/green]\n")
    return True
193
+
194
+
195
def check_vector_store(settings: Settings) -> bool:
    """
    Check if vector store exists and has documents.

    Args:
        settings: Application settings

    Returns:
        True if check passed, False otherwise
    """
    console.print("[bold cyan]2. Checking vector store...[/bold cyan]")

    store_path = Path(settings.qdrant_path)
    collection = settings.qdrant_collection_name

    # The embedded Qdrant store lives on disk; no directory means nothing
    # has been indexed yet.
    if not store_path.exists():
        console.print(f"[red]✗ Qdrant directory not found:[/red] {store_path}")
        console.print(
            "\nRun the indexing pipeline first:\n"
            "  [yellow]python scripts/build_index.py --index-vectors[/yellow]\n"
        )
        return False

    try:
        from qdrant_client import QdrantClient

        qdrant = QdrantClient(path=str(store_path))

        existing = [col.name for col in qdrant.get_collections().collections]
        if collection not in existing:
            console.print(
                f"[red]✗ Collection '{collection}' not found[/red]\n"
                f"Available collections: {existing}"
            )
            console.print(
                "\nRun the indexing pipeline first:\n"
                "  [yellow]python scripts/build_index.py --index-vectors[/yellow]\n"
            )
            return False

        points_count = qdrant.get_collection(collection).points_count

        # A present-but-empty collection is also treated as a failure.
        if points_count == 0:
            console.print(
                f"[yellow]⚠ Collection '{collection}' exists but is empty[/yellow]"
            )
            console.print(
                "\nRun the indexing pipeline:\n"
                "  [yellow]python scripts/build_index.py --index-vectors[/yellow]\n"
            )
            return False

        stats = Table(show_header=True, header_style="bold magenta")
        stats.add_column("Property", style="cyan")
        stats.add_column("Value", style="yellow")
        stats.add_row("Collection", collection)
        stats.add_row("Location", str(store_path))
        stats.add_row("Documents", f"{points_count:,}")

        console.print(stats)
        console.print("[green]✓ Vector store is ready[/green]\n")
        return True

    except Exception as e:
        console.print(f"[red]✗ Failed to access vector store:[/red] {e}\n")
        return False
270
+
271
+
272
def check_required_files() -> bool:
    """
    Check if all required prompt files exist.

    Returns:
        True if all files exist, False otherwise
    """
    console.print("[bold cyan]3. Checking required files...[/bold cyan]")

    prompts_dir = project_root / "prompts"
    required_files = {
        "System Prompt": prompts_dir / "system_prompt.txt",
        "Query Prompt": prompts_dir / "query_prompt.txt",
        "Medical Disclaimer": prompts_dir / "medical_disclaimer.txt",
    }

    report = Table(show_header=True, header_style="bold magenta")
    report.add_column("File", style="cyan")
    report.add_column("Path", style="yellow")
    report.add_column("Status", style="green")

    missing = []
    for label, path in required_files.items():
        if path.exists():
            status = "[green]✓ Found[/green]"
        else:
            status = "[red]✗ Missing[/red]"
            missing.append(label)
        report.add_row(label, str(path.relative_to(project_root)), status)

    console.print(report)

    if missing:
        console.print(
            "\n[red]Error:[/red] Some required files are missing.\n"
            "Make sure all prompt files are in the [yellow]prompts/[/yellow] directory.\n"
        )
        return False

    console.print("[green]✓ All required files found[/green]\n")
    return True
312
+
313
+
314
def run_preflight_checks(skip_checks: bool = False) -> bool:
    """
    Run all pre-flight checks.

    Args:
        skip_checks: Skip all checks if True

    Returns:
        True if all checks passed, False otherwise
    """
    if skip_checks:
        console.print("[yellow]⚠ Skipping pre-flight checks[/yellow]\n")
        return True

    console.print("[bold yellow]Running Pre-flight Checks...[/bold yellow]\n")

    try:
        settings = Settings()
    except Exception as e:
        console.print(f"[red]✗ Failed to load settings:[/red] {e}\n")
        return False

    console.print(f"[dim]LLM Provider: {settings.llm_provider.value}[/dim]\n")

    # Pick the LLM check matching the configured provider, then run the
    # provider-independent checks. All checks run unconditionally so every
    # problem is reported in one pass.
    if settings.llm_provider == LLMProvider.OLLAMA:
        llm_ok = check_ollama(settings)
    else:
        llm_ok = check_openai_config(settings)

    store_ok = check_vector_store(settings)
    files_ok = check_required_files()

    if not (llm_ok and store_ok and files_ok):
        console.print("[bold red]✗ Pre-flight checks failed[/bold red]")
        console.print("Fix the issues above and try again.\n")
        return False

    console.print("[bold green]✓ All pre-flight checks passed![/bold green]\n")
    return True
359
+
360
+
361
def print_access_urls(host: str, port: int):
    """
    Print access URLs for the server.

    Args:
        host: Server host
        port: Server port
    """
    # Bind-all / loopback hosts are shown as "localhost" for convenience.
    display_host = host if host not in ("0.0.0.0", "127.0.0.1") else "localhost"
    base = f"http://{display_host}:{port}"

    table = Table(
        show_header=True,
        header_style="bold magenta",
        title="[bold green]Server Access URLs[/bold green]",
        title_style="bold green",
    )
    table.add_column("Service", style="cyan", width=20)
    table.add_column("URL", style="yellow")
    table.add_column("Description", style="dim")

    table.add_row("API Root", base, "API information")
    table.add_row("Health Check", f"{base}/health", "Service health status")
    table.add_row("Interactive Docs", f"{base}/docs", "Swagger UI documentation")
    table.add_row("ReDoc", f"{base}/redoc", "Alternative API docs")
    table.add_row("Gradio UI", f"{base}/ui", "Web chat interface")

    console.print()
    console.print(table)
    console.print()

    # Copy-pasteable curl commands for a quick smoke test.
    console.print("[bold cyan]Quick Test Commands:[/bold cyan]")
    console.print(
        f"  [dim]# Test health endpoint[/dim]\n"
        f"  [yellow]curl {base}/health[/yellow]\n"
    )
    console.print(
        f"  [dim]# Query the API[/dim]\n"
        f"  [yellow]curl -X POST {base}/query \\[/yellow]\n"
        f'  [yellow]  -H "Content-Type: application/json" \\[/yellow]\n'
        f'  [yellow]  -d \'{{"question": "What is glaucoma?"}}\' [/yellow]\n'
    )
417
+
418
+
419
def start_server(host: str, port: int, reload: bool):
    """
    Start the uvicorn server.

    Args:
        host: Server host
        port: Server port
        reload: Enable hot reload
    """
    console.print("[bold green]Starting server...[/bold green]\n")

    # Print URLs before starting
    print_access_urls(host, port)

    # Import uvicorn here to avoid import errors if not installed
    try:
        import uvicorn
    except ImportError:
        console.print("[red]Error:[/red] uvicorn is not installed")
        console.print("Install it with: [yellow]pip install uvicorn[/yellow]\n")
        sys.exit(1)

    # Start server
    try:
        # Two positional strings: console.print joins them with a space.
        console.print(
            f"[dim]Server listening on {host}:{port}[/dim]",
            "[dim](Press CTRL+C to stop)[/dim]\n",  # fixed: was an f-string with no placeholders
        )

        # Blocks until the server exits.
        uvicorn.run(
            "src.api.main:app",
            host=host,
            port=port,
            reload=reload,
            log_level="info",
        )

    except KeyboardInterrupt:
        console.print("\n\n[yellow]Server stopped by user[/yellow]")
    except Exception as e:
        console.print(f"\n[red]Error starting server:[/red] {e}")
        sys.exit(1)
461
+
462
+
463
def main():
    """Launcher entry point: banner, pre-flight checks, then the server."""
    args = parse_args()

    print_header()

    # Abort before binding the port if any pre-flight check fails.
    if not run_preflight_checks(skip_checks=args.skip_checks):
        console.print("[red]Startup aborted due to failed checks[/red]\n")
        sys.exit(1)

    start_server(host=args.host, port=args.port, reload=args.reload)


if __name__ == "__main__":
    main()
scripts/scrape_eyewiki.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """CLI script to run the EyeWiki crawler."""
3
+
4
+ import argparse
5
+ import asyncio
6
+ import sys
7
+ import time
8
+ from pathlib import Path
9
+
10
+ # Add parent directory to path
11
+ sys.path.insert(0, str(Path(__file__).parent.parent))
12
+
13
+ from src.scraper.eyewiki_crawler import EyeWikiCrawler
14
+ from config.settings import settings
15
+ from rich.console import Console
16
+ from rich.panel import Panel
17
+ from rich.table import Table
18
+
19
+
20
def parse_args():
    """Parse command line arguments for the EyeWiki crawler CLI."""
    cli = argparse.ArgumentParser(
        description="Crawl EyeWiki medical articles",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Crawl up to 100 pages with default settings
    python scripts/scrape_eyewiki.py --max-pages 100

    # Resume previous crawl
    python scripts/scrape_eyewiki.py --resume

    # Crawl with depth 3 to custom directory
    python scripts/scrape_eyewiki.py --depth 3 --output-dir ./my_data

    # Full crawl (no page limit)
    python scripts/scrape_eyewiki.py
""",
    )

    # Crawl scope controls
    cli.add_argument(
        "--max-pages",
        type=int,
        default=None,
        help="Maximum number of pages to crawl (default: unlimited)",
    )
    cli.add_argument(
        "--depth",
        type=int,
        default=2,
        help="Maximum crawl depth (default: 2)",
    )
    cli.add_argument(
        "--start-urls",
        type=str,
        nargs="+",
        default=None,
        help="Starting URLs for crawl (default: EyeWiki main page and disease category)",
    )

    # Output and checkpointing
    cli.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help=f"Output directory for scraped articles (default: {settings.data_raw_path})",
    )
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Resume from previous checkpoint if available",
    )
    cli.add_argument(
        "--checkpoint-file",
        type=str,
        default=None,
        help="Custom checkpoint file path (default: output_dir/crawler_checkpoint.json)",
    )

    # Politeness / network settings
    cli.add_argument(
        "--delay",
        type=float,
        default=None,
        help=f"Delay between requests in seconds (default: {settings.scraper_delay})",
    )
    cli.add_argument(
        "--timeout",
        type=int,
        default=None,
        help=f"Request timeout in seconds (default: {settings.scraper_timeout})",
    )

    return cli.parse_args()
98
+
99
+
100
def print_banner(console: Console):
    """Print welcome banner."""
    banner = """
    [bold cyan]EyeWiki Medical Article Crawler[/bold cyan]
    [dim]Powered by crawl4ai[/dim]
    """
    console.print(Panel(banner, border_style="cyan"))
107
+
108
+
109
def print_configuration(console: Console, args, crawler: EyeWikiCrawler):
    """Print crawler configuration."""
    # Collect the (setting, value) pairs first, then render them in order.
    rows = [
        ("Output Directory", str(crawler.output_dir)),
        ("Max Pages", str(args.max_pages) if args.max_pages else "Unlimited"),
        ("Depth", str(args.depth)),
        ("Delay", f"{crawler.delay}s"),
        ("Timeout", f"{crawler.timeout}s"),
        ("Checkpoint File", str(crawler.checkpoint_file)),
        ("Resume Mode", "Yes" if args.resume else "No"),
    ]

    table = Table(title="Crawler Configuration", show_header=False, border_style="blue")
    table.add_column("Setting", style="cyan")
    table.add_column("Value", style="white")
    for setting, value in rows:
        table.add_row(setting, value)

    console.print(table)
    console.print()
125
+
126
+
127
def print_summary(console: Console, crawler: EyeWikiCrawler, elapsed_time: float):
    """Print crawl summary statistics."""
    console.print("\n")

    saved = crawler.articles_saved
    visited = len(crawler.visited_urls)

    # Derived stats; guards avoid division by zero on empty or instant runs.
    pages_per_minute = (saved / elapsed_time * 60) if elapsed_time > 0 else 0
    success_rate = (saved / visited * 100) if crawler.visited_urls else 0

    table = Table(title="Crawl Summary", border_style="green", show_header=True)
    table.add_column("Metric", style="cyan", justify="left")
    table.add_column("Value", style="white", justify="right")

    for metric, value in (
        ("Articles Saved", f"{saved:,}"),
        ("URLs Visited", f"{visited:,}"),
        ("URLs Failed", f"{len(crawler.failed_urls):,}"),
        ("URLs Remaining", f"{len(crawler.to_crawl):,}"),
        ("Success Rate", f"{success_rate:.1f}%"),
        ("Time Elapsed", f"{elapsed_time:.1f}s"),
        ("Pages/Minute", f"{pages_per_minute:.1f}"),
    ):
        table.add_row(metric, value)

    console.print(table)

    # Show at most five failed URLs with their error messages.
    if crawler.failed_urls:
        console.print("\n[yellow]Failed URLs:[/yellow]")
        for i, (url, error) in enumerate(list(crawler.failed_urls.items())[:5], 1):
            console.print(f" {i}. [red]{url}[/red]")
            console.print(f" [dim]{error}[/dim]")

        if len(crawler.failed_urls) > 5:
            console.print(f" [dim]... and {len(crawler.failed_urls) - 5} more[/dim]")

    # Final status line depends on whether anything was saved at all.
    console.print()
    if saved > 0:
        console.print("[bold green]Crawl completed successfully![/bold green]")
        console.print(f"[green]Articles saved to: {crawler.output_dir}[/green]")
    else:
        console.print("[bold yellow]No articles were saved.[/bold yellow]")
        console.print("[yellow]Check the logs above for errors.[/yellow]")
173
+
174
+
175
async def main():
    """Main entry point for the crawler script.

    Returns a process exit code: 0 on success, 1 on error, 130 on Ctrl+C.
    """
    # Parse arguments
    args = parse_args()

    # Initialize console
    console = Console()

    # Print banner
    print_banner(console)

    # Prepare output directory
    output_dir = Path(args.output_dir) if args.output_dir else Path(settings.data_raw_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare checkpoint file (None lets the crawler pick its default path)
    checkpoint_file = None
    if args.checkpoint_file:
        checkpoint_file = Path(args.checkpoint_file)

    # If not resuming and a checkpoint exists, warn the user.
    # NOTE(review): this only covers a *custom* --checkpoint-file path; an
    # existing default checkpoint (output_dir/crawler_checkpoint.json) is not
    # detected here — confirm whether that is intentional.
    if not args.resume and checkpoint_file and checkpoint_file.exists():
        console.print("[yellow]Warning: Checkpoint file exists![/yellow]")
        console.print(f"[yellow]File: {checkpoint_file}[/yellow]")
        console.print("[yellow]Use --resume to continue from checkpoint, or it will be overwritten.[/yellow]")
        console.print()

    # Initialize crawler
    try:
        crawler = EyeWikiCrawler(
            base_url="https://eyewiki.org",
            output_dir=output_dir,
            checkpoint_file=checkpoint_file,
            delay=args.delay if args.delay is not None else settings.scraper_delay,
            timeout=args.timeout if args.timeout is not None else settings.scraper_timeout,
        )
    except Exception as e:
        console.print(f"[bold red]Error initializing crawler: {e}[/bold red]")
        return 1

    # Print configuration
    print_configuration(console, args, crawler)

    # Prepare start URLs
    start_urls = args.start_urls
    if not start_urls and not args.resume:
        # Start with popular medical articles that link to many other articles
        start_urls = [
            "https://eyewiki.org/Category:Articles"
        ]
        console.print("[blue]Using default start URLs (seed articles):[/blue]")
        for url in start_urls:
            console.print(f" - {url}")
        console.print()

    # Start crawling
    start_time = time.time()

    try:
        await crawler.crawl(
            max_pages=args.max_pages,
            depth=args.depth,
            start_urls=start_urls,
        )

        elapsed_time = time.time() - start_time

        # Print summary
        print_summary(console, crawler, elapsed_time)

        return 0

    except KeyboardInterrupt:
        elapsed_time = time.time() - start_time
        console.print("\n[yellow]Crawl interrupted by user (Ctrl+C)[/yellow]")
        console.print("[yellow]Saving checkpoint...[/yellow]")

        # Crawler already saves checkpoint in its exception handler
        # Just print summary
        print_summary(console, crawler, elapsed_time)

        console.print("\n[blue]You can resume with:[/blue]")
        # fixed: was an f-string with no placeholders
        console.print("[blue] python scripts/scrape_eyewiki.py --resume[/blue]")

        return 130  # Standard exit code for SIGINT

    except Exception as e:
        elapsed_time = time.time() - start_time
        console.print(f"\n[bold red]Unexpected error: {e}[/bold red]")

        # Print summary of what was accomplished
        print_summary(console, crawler, elapsed_time)

        return 1
269
+
270
+
271
if __name__ == "__main__":
    # Run the async entry point and propagate its exit code. SystemExit is
    # not an Exception subclass, so sys.exit() passes through the handler.
    try:
        sys.exit(asyncio.run(main()))
    except Exception as e:
        Console().print(f"[bold red]Fatal error: {e}[/bold red]")
        sys.exit(1)
src/__init__.py ADDED
File without changes
src/api/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """API module for EyeWiki RAG system."""
2
+
3
+ from src.api.main import app
4
+
5
+ __all__ = ["app"]
src/api/gradio_ui.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio UI for EyeWiki RAG system."""
2
+
3
+ import logging
4
+ from typing import List, Dict
5
+
6
+ import gradio as gr
7
+
8
+ from src.rag.query_engine import QueryResponse
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
# ============================================================================
# Example Questions
# ============================================================================

# Seed questions rendered as one-click buttons in the sidebar of
# create_gradio_interface(); clicking one copies it into the input box.
EXAMPLE_QUESTIONS = [
    "What are the symptoms of glaucoma?",
    "How is diabetic retinopathy treated?",
    "What causes macular degeneration?",
    "What is the difference between open-angle and angle-closure glaucoma?",
    "What are the risk factors for cataracts?",
    "How is retinal detachment diagnosed?",
]


# ============================================================================
# Styling
# ============================================================================

# Stylesheet passed to gr.Blocks(css=...). The class names below are attached
# to components via elem_classes in create_gradio_interface(); the
# confidence-* classes are also emitted by format_sources_html().
CUSTOM_CSS = """
/* Main container */
.gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
    max-width: 1400px;
    margin: 0 auto;
}

/* Header */
.header {
    background: linear-gradient(135deg, #1e3a8a 0%, #3b82f6 100%);
    color: white;
    padding: 2rem;
    border-radius: 12px;
    margin-bottom: 2rem;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}

.header h1 {
    margin: 0 0 0.5rem 0;
    font-size: 2rem;
    font-weight: 700;
}

.header p {
    margin: 0;
    font-size: 1rem;
    opacity: 0.95;
}

/* Chat interface */
.chatbot {
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}

/* Text input */
.input-text textarea {
    border: 2px solid #e5e7eb;
    border-radius: 8px;
    font-size: 1rem;
    padding: 0.75rem;
    transition: border-color 0.2s;
}

.input-text textarea:focus {
    border-color: #3b82f6;
    outline: none;
}

/* Buttons */
.primary-button {
    background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
    color: white;
    border: none;
    border-radius: 8px;
    padding: 0.75rem 1.5rem;
    font-weight: 600;
    cursor: pointer;
    transition: transform 0.1s, box-shadow 0.2s;
}

.primary-button:hover {
    transform: translateY(-1px);
    box-shadow: 0 4px 8px rgba(59, 130, 246, 0.3);
}

.secondary-button {
    background: white;
    color: #374151;
    border: 1px solid #d1d5db;
    border-radius: 8px;
    padding: 0.5rem 1rem;
    font-weight: 500;
    cursor: pointer;
    transition: background 0.2s;
}

.secondary-button:hover {
    background: #f9fafb;
}

/* Sources accordion */
.accordion {
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    margin-top: 1rem;
}

/* Disclaimer */
.disclaimer {
    background: #fef3c7;
    border-left: 4px solid #f59e0b;
    padding: 1rem;
    border-radius: 8px;
    margin-top: 2rem;
    font-size: 0.875rem;
    line-height: 1.5;
}

.disclaimer strong {
    color: #92400e;
    font-weight: 700;
}

/* Settings sidebar */
.settings {
    background: #f9fafb;
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 1rem;
}

/* Example questions */
.examples {
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 8px;
    padding: 1rem;
    margin-bottom: 1rem;
}

.example-btn {
    display: block;
    width: 100%;
    text-align: left;
    padding: 0.75rem;
    margin-bottom: 0.5rem;
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 6px;
    cursor: pointer;
    transition: all 0.2s;
    font-size: 0.875rem;
}

.example-btn:hover {
    background: #f0f9ff;
    border-color: #3b82f6;
    transform: translateX(4px);
}

/* Confidence indicator */
.confidence-high {
    color: #059669;
    font-weight: 600;
}

.confidence-medium {
    color: #d97706;
    font-weight: 600;
}

.confidence-low {
    color: #dc2626;
    font-weight: 600;
}

/* Source cards */
.source-card {
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 6px;
    padding: 0.75rem;
    margin-bottom: 0.5rem;
    transition: box-shadow 0.2s;
}

.source-card:hover {
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}

.source-title {
    font-weight: 600;
    color: #1e40af;
    margin-bottom: 0.25rem;
}

.source-score {
    font-size: 0.75rem;
    color: #6b7280;
}
"""
216
+
217
+
218
+ # ============================================================================
219
+ # Formatting Functions
220
+ # ============================================================================
221
+
222
def format_sources_html(response: QueryResponse, max_sources: int = 5) -> str:
    """
    Format sources as HTML.

    Each source becomes a "card" (styled by CUSTOM_CSS .source-card) with a
    link to the article, the section name when present, and a colour-coded
    relevance percentage.

    Args:
        response: Query response with sources
        max_sources: Maximum number of sources to display

    Returns:
        HTML string with formatted sources
    """
    if not response.sources:
        return "<p style='color: #6b7280; font-style: italic;'>No sources available.</p>"

    html_parts = []

    # Limit sources
    sources = response.sources[:max_sources]

    for i, source in enumerate(sources, 1):
        # Confidence indicator — thresholds mirror format_confidence_text()
        score_pct = int(source.relevance_score * 100)
        if source.relevance_score >= 0.7:
            score_class = "confidence-high"
        elif source.relevance_score >= 0.5:
            score_class = "confidence-medium"
        else:
            score_class = "confidence-low"

        # NOTE(review): source.title / source.section are interpolated into
        # HTML without escaping — assumes they contain no markup; confirm
        # upstream sanitization.
        html = f"""
        <div class="source-card">
            <div class="source-title">
                {i}. <a href="{source.url}" target="_blank" style="text-decoration: none;">
                    {source.title}
                </a>
            </div>
            {f'<div style="font-size: 0.875rem; color: #6b7280; margin-bottom: 0.25rem;">Section: {source.section}</div>' if source.section else ''}
            <div class="source-score">
                Relevance: <span class="{score_class}">{score_pct}%</span>
            </div>
        </div>
        """
        html_parts.append(html)

    return "\n".join(html_parts)
267
+
268
+
269
def format_confidence_text(confidence: float) -> str:
    """
    Render a 0-1 confidence score as an emoji-labelled percentage.

    Args:
        confidence: Confidence score (0-1)

    Returns:
        Formatted string, e.g. "✅ High Confidence (85%)"
    """
    # Thresholds mirror the confidence-* CSS classes used for sources.
    if confidence >= 0.7:
        emoji, label = "✅", "High Confidence"
    elif confidence >= 0.5:
        emoji, label = "⚠️", "Medium Confidence"
    else:
        emoji, label = "⚡", "Low Confidence"

    return f"{emoji} {label} ({int(confidence * 100)}%)"
292
+
293
+
294
+ # ============================================================================
295
+ # Chat Interface Functions
296
+ # ============================================================================
297
+
298
def process_question(
    question: str,
    history: List[Dict[str, str]],
    include_sources: bool,
    max_sources: int,
    query_engine_getter,
) -> tuple[List[Dict[str, str]], str]:
    """
    Process a user question and update chat history.

    Args:
        question: User's question
        history: Chat history (list of message dicts with 'role' and 'content');
            mutated in place and also returned, per the Gradio callback pattern
        include_sources: Whether to include sources
        max_sources: Maximum number of sources to show
        query_engine_getter: Callable that returns the query engine instance,
            or a falsy value while the engine is still initializing

    Returns:
        Updated history and sources HTML
    """
    # Ignore empty / whitespace-only submissions.
    if not question or not question.strip():
        return history, ""

    # Resolve the engine lazily so the UI can come up before the RAG
    # pipeline has finished loading.
    # Fix: removed leftover debug `print(query_engine)` that wrote to stdout
    # on every submission.
    query_engine = query_engine_getter()
    if not query_engine:
        error_msg = "System is still initializing. Please wait a moment and try again."
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": error_msg})
        return history, ""

    try:
        # Run the full RAG pipeline.
        response = query_engine.query(
            question=question,
            include_sources=include_sources,
        )

        # Prefix the answer with a confidence badge.
        confidence_text = format_confidence_text(response.confidence)
        answer = f"**{confidence_text}**\n\n{response.answer}"

        # Append the disclaimer unless it is the generic "educational
        # purposes" one (filtered by keyword).
        if response.disclaimer and not any(word in response.disclaimer.lower() for word in ['educational', 'education']):
            answer += f"\n\n---\n\n{response.disclaimer}"

        # Update history with message dicts
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": answer})

        # Format sources panel (empty when the user disabled sources).
        sources_html = format_sources_html(response, max_sources) if include_sources else ""

        return history, sources_html

    except Exception as e:
        # Surface pipeline failures in-chat instead of crashing the UI.
        logger.error(f"Error processing question: {e}", exc_info=True)
        error_msg = f"Sorry, I encountered an error processing your question: {str(e)}"
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": error_msg})
        return history, ""
360
+
361
+
362
def clear_chat() -> tuple[List, str]:
    """
    Reset the conversation.

    Returns:
        An empty chat history and an empty sources panel.
    """
    fresh_history: List = []
    return fresh_history, ""
370
+
371
+
372
def load_example(example: str) -> str:
    """
    Load an example question.

    Identity callback: copies the clicked example's text into the
    question input box.

    Args:
        example: Example question text

    Returns:
        The example question, unchanged
    """
    return example
383
+
384
+
385
+ # ============================================================================
386
+ # Gradio Interface
387
+ # ============================================================================
388
+
389
def create_gradio_interface(query_engine_getter) -> gr.Blocks:
    """
    Create Gradio interface for EyeWiki RAG.

    Builds a two-column layout (chat + sources on the left, settings and
    example questions on the right) and wires the event handlers to
    process_question / clear_chat / load_example.

    Args:
        query_engine_getter: Callable that returns the query engine instance
            (captured by the submit lambdas so the engine can finish loading
            after the UI is built)

    Returns:
        Gradio Blocks interface
    """
    with gr.Blocks(
        css=CUSTOM_CSS,
        title="EyeWiki Medical Assistant",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
            neutral_hue="slate",
        ),
    ) as interface:

        # Header
        gr.HTML("""
        <div class="header">
            <h1>🏥 EyeWiki Medical Assistant</h1>
            <p>Ask questions about ophthalmology conditions, treatments, and procedures</p>
        </div>
        """)

        with gr.Row():
            # Main content (left side)
            with gr.Column(scale=3):

                # Chat interface
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    elem_classes=["chatbot"],
                    show_label=False,
                    avatar_images=(None, "🏥"),
                )

                # Input
                with gr.Row():
                    question_input = gr.Textbox(
                        placeholder="Ask a question about eye health...",
                        label="Your Question",
                        lines=2,
                        elem_classes=["input-text"],
                        scale=4,
                    )

                with gr.Row():
                    submit_btn = gr.Button(
                        "Send",
                        variant="primary",
                        elem_classes=["primary-button"],
                        scale=1,
                    )
                    clear_btn = gr.Button(
                        "Clear",
                        elem_classes=["secondary-button"],
                        scale=1,
                    )

                # Sources accordion (collapsed by default; filled per-answer)
                with gr.Accordion("📚 Sources", open=False, elem_classes=["accordion"]):
                    sources_display = gr.HTML(
                        value="<p style='color: #6b7280; font-style: italic;'>Sources will appear here after asking a question.</p>"
                    )

                # Medical disclaimer
                gr.HTML("""
                <div class="disclaimer">
                    <strong>⚠️ Medical Disclaimer:</strong> This information is sourced from EyeWiki,
                    a resource of the American Academy of Ophthalmology (AAO). It is not a substitute
                    for professional medical advice, diagnosis, or treatment. AI systems can make errors.
                    Always consult with a qualified ophthalmologist or eye care professional for medical
                    concerns and verify any critical information with authoritative sources.
                </div>
                """)

            # Sidebar (right side)
            with gr.Column(scale=1, elem_classes=["settings"]):

                gr.Markdown("### ⚙️ Settings")

                include_sources = gr.Checkbox(
                    label="Show sources",
                    value=True,
                    info="Include source citations in responses"
                )

                max_sources = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Max sources",
                    info="Maximum number of sources to display"
                )

                gr.Markdown("---")
                gr.Markdown("### 💡 Example Questions")

                # Example buttons — one per entry in EXAMPLE_QUESTIONS
                example_buttons = []
                for example in EXAMPLE_QUESTIONS:
                    btn = gr.Button(
                        example,
                        elem_classes=["example-btn"],
                        size="sm",
                    )
                    example_buttons.append(btn)

                gr.Markdown("---")
                gr.Markdown("""
                ### 📖 About

                **EyeWiki RAG System** - Powered by:
                - Hybrid retrieval (semantic + keyword search)
                - Cross-encoder reranking for precision
                - Local LLM inference (GPU-accelerated)
                - EyeWiki knowledge base (AAO)

                All processing happens locally on your machine.
                """)

        # Event handlers.
        # The .then(...) step clears the input box after the answer is posted.
        submit_event = submit_btn.click(
            fn=lambda q, h, inc, max_s: process_question(q, h, inc, max_s, query_engine_getter),
            inputs=[question_input, chatbot, include_sources, max_sources],
            outputs=[chatbot, sources_display],
        ).then(
            fn=lambda: "",
            outputs=[question_input],
        )

        # Pressing Enter in the textbox mirrors the Send button.
        question_input.submit(
            fn=lambda q, h, inc, max_s: process_question(q, h, inc, max_s, query_engine_getter),
            inputs=[question_input, chatbot, include_sources, max_sources],
            outputs=[chatbot, sources_display],
        ).then(
            fn=lambda: "",
            outputs=[question_input],
        )

        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, sources_display],
        )

        # Example button handlers
        # NOTE(review): the Button itself is used as the input; Gradio passes
        # its value (the label text) to load_example — confirm against the
        # Gradio version in use.
        for btn in example_buttons:
            btn.click(
                fn=load_example,
                inputs=[btn],
                outputs=[question_input],
            )

    return interface
src/api/main.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application for EyeWiki RAG system."""
2
+
3
+ import logging
4
+ import time
5
+ from contextlib import asynccontextmanager
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ from fastapi import FastAPI, HTTPException, Request, status
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import StreamingResponse
12
+ from pydantic import BaseModel, Field
13
+ import gradio as gr
14
+
15
+ from src.api.gradio_ui import create_gradio_interface
16
+ from config.settings import LLMProvider, Settings
17
+ from src.llm.llm_client import LLMClient
18
+ from src.llm.ollama_client import OllamaClient
19
+ from src.llm.openai_client import OpenAIClient
20
+ from src.llm.sentence_transformer_client import SentenceTransformerClient
21
+ from src.rag.query_engine import EyeWikiQueryEngine, QueryResponse
22
+ from src.rag.reranker import CrossEncoderReranker
23
+ from src.rag.retriever import HybridRetriever
24
+ from src.vectorstore.qdrant_store import QdrantStoreManager
25
+
26
+
27
+ # Configure logging
28
+ logging.basicConfig(
29
+ level=logging.INFO,
30
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
31
+ )
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ # ============================================================================
36
+ # Request/Response Models
37
+ # ============================================================================
38
+
39
class QueryRequest(BaseModel):
    """
    Request model for the POST /query endpoint.

    Attributes:
        question: User's question
        include_sources: Whether to include source information
        filters: Optional metadata filters (disease_name, icd_codes, etc.)
    """
    # min_length=3 rejects degenerate one/two-character questions at validation time
    question: str = Field(..., min_length=3, description="User's question")
    include_sources: bool = Field(default=True, description="Include source documents")
    filters: Optional[dict] = Field(default=None, description="Metadata filters")
51
+
52
+
53
class StreamQueryRequest(BaseModel):
    """
    Request model for the streaming POST /query/stream endpoint.

    Same fields as QueryRequest but without the include_sources flag.

    Attributes:
        question: User's question
        filters: Optional metadata filters
    """
    question: str = Field(..., min_length=3, description="User's question")
    filters: Optional[dict] = Field(default=None, description="Metadata filters")
63
+
64
+
65
class HealthResponse(BaseModel):
    """
    Response model for health check.

    Attributes:
        status: Overall status ("healthy", "degraded", or "unhealthy" —
            see health_check() for how it is derived)
        llm: LLM service status
        qdrant: Qdrant service status
        query_engine: Query engine initialization status
        timestamp: Check timestamp
    """
    status: str = Field(..., description="Overall status")
    llm: dict = Field(..., description="LLM service status")
    qdrant: dict = Field(..., description="Qdrant service status")
    query_engine: dict = Field(..., description="Query engine status")
    timestamp: float = Field(..., description="Unix timestamp")
81
+
82
+
83
class StatsResponse(BaseModel):
    """
    Response model for the GET /stats endpoint.

    Attributes:
        collection_info: Qdrant collection information
        pipeline_config: Query engine pipeline configuration
        documents_indexed: Number of indexed documents
        timestamp: Stats timestamp
    """
    collection_info: dict = Field(..., description="Collection information")
    pipeline_config: dict = Field(..., description="Pipeline configuration")
    documents_indexed: int = Field(..., description="Number of indexed documents")
    timestamp: float = Field(..., description="Unix timestamp")
97
+
98
+
99
class ErrorResponse(BaseModel):
    """
    Error response model.

    Attributes:
        error: Error message
        detail: Optional detailed error information
        timestamp: Error timestamp
    """
    # NOTE(review): the endpoints visible in this file raise HTTPException
    # directly rather than returning this model — confirm intended use.
    error: str = Field(..., description="Error message")
    detail: Optional[str] = Field(default=None, description="Error details")
    timestamp: float = Field(..., description="Unix timestamp")
111
+
112
+
113
+ # ============================================================================
114
+ # Global State
115
+ # ============================================================================
116
+
117
class AppState:
    """Application state container.

    Holds every long-lived service object built during startup (see
    lifespan()); all attributes stay None/False until initialization runs.
    """

    def __init__(self):
        self.settings: Optional[Settings] = None
        self.llm_client: Optional[LLMClient] = None
        self.embedding_client: Optional[SentenceTransformerClient] = None
        self.qdrant_manager: Optional[QdrantStoreManager] = None
        self.retriever: Optional[HybridRetriever] = None
        self.reranker: Optional[CrossEncoderReranker] = None
        self.query_engine: Optional[EyeWikiQueryEngine] = None
        # True once lifespan() finished building the pipeline successfully.
        self.initialized: bool = False
        # Human-readable startup failure; surfaced by /health and 503 errors.
        self.initialization_error: Optional[str] = None
130
+
131
+
132
+ app_state = AppState()
133
+
134
+
135
+ # ============================================================================
136
+ # Lifecycle Management
137
+ # ============================================================================
138
+
139
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.

    Startup: builds the full RAG stack (settings -> LLM client ->
    embeddings -> Qdrant -> retriever -> reranker -> query engine) and
    records success/failure on app_state. Shutdown: closes the Qdrant client.
    A startup failure is logged but NOT raised, so the server still starts
    and endpoints report the error via check_initialization()/health.
    """
    # Startup
    logger.info("Starting EyeWiki RAG API...")

    try:
        # Load settings
        logger.info("Loading settings...")
        app_state.settings = Settings()

        # Initialize LLM client based on provider
        logger.info(f"Initializing LLM client (provider: {app_state.settings.llm_provider.value})...")
        if app_state.settings.llm_provider == LLMProvider.OPENAI:
            app_state.llm_client = OpenAIClient(
                api_key=app_state.settings.openai_api_key,
                base_url=app_state.settings.openai_base_url,
                model=app_state.settings.openai_model,
            )
        else:
            app_state.llm_client = OllamaClient(
                base_url=app_state.settings.ollama_base_url,
                embedding_model=None,  # We use SentenceTransformerClient for embeddings
                llm_model=app_state.settings.llm_model,
                timeout=app_state.settings.ollama_timeout,
            )

        # Initialize embedding client (sentence-transformers for stable embeddings)
        logger.info("Initializing embedding client...")
        app_state.embedding_client = SentenceTransformerClient(
            model_name=app_state.settings.embedding_model,
        )
        logger.info(f"Embedding model loaded: {app_state.settings.embedding_model}")

        # Initialize Qdrant manager
        logger.info("Initializing Qdrant manager...")
        app_state.qdrant_manager = QdrantStoreManager(
            collection_name=app_state.settings.qdrant_collection_name,
            path=app_state.settings.qdrant_path,
            embedding_dim=app_state.embedding_client.embedding_dim,
        )

        # Verify collection exists — the index must be built offline first.
        collection_info = app_state.qdrant_manager.get_collection_info()
        if not collection_info:
            raise RuntimeError(
                f"Qdrant collection '{app_state.settings.qdrant_collection_name}' not found. "
                "Please run 'python scripts/build_index.py --index-vectors' first."
            )

        logger.info(
            f"Qdrant collection loaded: {collection_info['vectors_count']} vectors"
        )

        # Initialize retriever
        logger.info("Initializing retriever...")
        app_state.retriever = HybridRetriever(
            qdrant_manager=app_state.qdrant_manager,
            embedding_client=app_state.embedding_client,
        )

        # Initialize reranker
        logger.info("Initializing reranker...")
        app_state.reranker = CrossEncoderReranker(
            model_name=app_state.settings.reranker_model,
        )

        # Load prompt files from <project root>/prompts
        project_root = Path(__file__).parent.parent.parent
        prompts_dir = project_root / "prompts"

        system_prompt_path = prompts_dir / "system_prompt.txt"
        query_prompt_path = prompts_dir / "query_prompt.txt"
        disclaimer_path = prompts_dir / "medical_disclaimer.txt"

        # Verify prompts exist; missing files degrade to None (engine defaults)
        if not system_prompt_path.exists():
            logger.warning(f"System prompt not found: {system_prompt_path}")
            system_prompt_path = None

        if not query_prompt_path.exists():
            logger.warning(f"Query prompt not found: {query_prompt_path}")
            query_prompt_path = None

        if not disclaimer_path.exists():
            logger.warning(f"Disclaimer not found: {disclaimer_path}")
            disclaimer_path = None

        # Initialize query engine
        logger.info("Initializing query engine...")
        app_state.query_engine = EyeWikiQueryEngine(
            retriever=app_state.retriever,
            reranker=app_state.reranker,
            llm_client=app_state.llm_client,
            system_prompt_path=system_prompt_path,
            query_prompt_path=query_prompt_path,
            disclaimer_path=disclaimer_path,
            max_context_tokens=app_state.settings.max_context_tokens,
            retrieval_k=20,  # pipeline sizing — see EyeWikiQueryEngine
            rerank_k=5,
        )

        app_state.initialized = True
        logger.info("EyeWiki RAG API started successfully")
        logger.info("Gradio UI available at /ui")

    except Exception as e:
        error_msg = f"Failed to initialize application: {e}"
        logger.error(error_msg, exc_info=True)
        app_state.initialization_error = error_msg
        # Don't raise - allow app to start but endpoints will return errors

    yield

    # Shutdown
    logger.info("Shutting down EyeWiki RAG API...")

    # Cleanup Qdrant client
    if app_state.qdrant_manager:
        try:
            app_state.qdrant_manager.close()
            logger.info("Qdrant client closed")
        except Exception as e:
            logger.error(f"Error closing Qdrant client: {e}")
267
+
268
+
269
+ # ============================================================================
270
+ # FastAPI App
271
+ # ============================================================================
272
+
273
+ app = FastAPI(
274
+ title="EyeWiki RAG API",
275
+ description="Retrieval-Augmented Generation API for EyeWiki medical knowledge base",
276
+ version="1.0.0",
277
+ lifespan=lifespan,
278
+ )
279
+
280
+
281
+ # ============================================================================
282
+ # Middleware
283
+ # ============================================================================
284
+
285
+ # CORS middleware for local development
286
+ app.add_middleware(
287
+ CORSMiddleware,
288
+ allow_origins=["*"], # Configure appropriately for production
289
+ allow_credentials=True,
290
+ allow_methods=["*"],
291
+ allow_headers=["*"],
292
+ )
293
+
294
+
295
+ @app.middleware("http")
296
+ async def log_requests(request: Request, call_next):
297
+ """
298
+ Request logging middleware.
299
+
300
+ Logs all incoming requests with timing information.
301
+ """
302
+ start_time = time.time()
303
+
304
+ # Log request
305
+ logger.info(
306
+ f"Request: {request.method} {request.url.path} "
307
+ f"from {request.client.host if request.client else 'unknown'}"
308
+ )
309
+
310
+ # Process request
311
+ response = await call_next(request)
312
+
313
+ # Log response
314
+ duration = time.time() - start_time
315
+ logger.info(
316
+ f"Response: {response.status_code} "
317
+ f"in {duration:.3f}s"
318
+ )
319
+
320
+ return response
321
+
322
+
323
+ # ============================================================================
324
+ # Helper Functions
325
+ # ============================================================================
326
+
327
def check_initialization():
    """
    Check if application is initialized.

    Raises:
        HTTPException: 503 with the recorded startup error if the app
            failed (or has not yet finished) initializing
    """
    if app_state.initialized:
        return

    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail=app_state.initialization_error or "Application not initialized",
    )
340
+
341
+
342
+ # ============================================================================
343
+ # Endpoints
344
+ # ============================================================================
345
+
346
+ @app.get("/")
347
+ async def root():
348
+ """
349
+ Root endpoint.
350
+
351
+ Returns:
352
+ Welcome message with API information
353
+ """
354
+ return {
355
+ "name": "EyeWiki RAG API",
356
+ "version": "1.0.0",
357
+ "description": "Retrieval-Augmented Generation API for EyeWiki medical knowledge base",
358
+ "endpoints": {
359
+ "health": "GET /health",
360
+ "query": "POST /query",
361
+ "stream": "POST /query/stream",
362
+ "stats": "GET /stats",
363
+ "docs": "GET /docs",
364
+ }
365
+ }
366
+
367
+
368
+ @app.get("/health", response_model=HealthResponse)
369
+ async def health_check():
370
+ """
371
+ Health check endpoint.
372
+
373
+ Checks status of:
374
+ - Ollama service
375
+ - Qdrant service
376
+ - Query engine initialization
377
+
378
+ Returns:
379
+ HealthResponse with service statuses
380
+ """
381
+ timestamp = time.time()
382
+
383
+ # Check LLM provider
384
+ llm_status = {"status": "unknown", "detail": None}
385
+ if app_state.llm_client:
386
+ provider = app_state.settings.llm_provider.value if app_state.settings else "unknown"
387
+ llm_status["provider"] = provider
388
+ try:
389
+ if isinstance(app_state.llm_client, OllamaClient):
390
+ health_ok = app_state.llm_client.check_health()
391
+ llm_status["status"] = "healthy" if health_ok else "unhealthy"
392
+ llm_status["model"] = app_state.llm_client.llm_model
393
+ else:
394
+ # For OpenAI-compatible clients, assume healthy if initialized
395
+ llm_status["status"] = "healthy"
396
+ llm_status["model"] = app_state.llm_client.llm_model
397
+ except Exception as e:
398
+ llm_status = {"status": "unhealthy", "detail": str(e), "provider": provider}
399
+ else:
400
+ llm_status = {"status": "not_initialized", "detail": "Client not created"}
401
+
402
+ # Check Qdrant
403
+ qdrant_status = {"status": "unknown", "detail": None}
404
+ if app_state.qdrant_manager:
405
+ try:
406
+ info = app_state.qdrant_manager.get_collection_info()
407
+ if info:
408
+ qdrant_status = {
409
+ "status": "healthy",
410
+ "collection": info["name"],
411
+ "vectors_count": info["vectors_count"],
412
+ }
413
+ else:
414
+ qdrant_status = {
415
+ "status": "unhealthy",
416
+ "detail": "Collection not found"
417
+ }
418
+ except Exception as e:
419
+ qdrant_status = {"status": "unhealthy", "detail": str(e)}
420
+ else:
421
+ qdrant_status = {"status": "not_initialized", "detail": "Manager not created"}
422
+
423
+ # Check query engine
424
+ query_engine_status = {
425
+ "status": "initialized" if app_state.initialized else "not_initialized",
426
+ "error": app_state.initialization_error,
427
+ }
428
+
429
+ # Overall status
430
+ overall_status = "healthy"
431
+ if not app_state.initialized:
432
+ overall_status = "unhealthy"
433
+ elif llm_status["status"] != "healthy" or qdrant_status["status"] != "healthy":
434
+ overall_status = "degraded"
435
+
436
+ return HealthResponse(
437
+ status=overall_status,
438
+ llm=llm_status,
439
+ qdrant=qdrant_status,
440
+ query_engine=query_engine_status,
441
+ timestamp=timestamp,
442
+ )
443
+
444
+
445
@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """
    Main query endpoint.

    Runs the question through the complete RAG pipeline: hybrid retrieval,
    cross-encoder reranking, context assembly, and LLM answer generation.

    Args:
        request: QueryRequest with question and options

    Returns:
        QueryResponse with answer, sources, and disclaimer

    Raises:
        HTTPException: If service unavailable or query fails
    """
    check_initialization()

    try:
        logger.info(f"Processing query: '{request.question}'")

        result = app_state.query_engine.query(
            question=request.question,
            include_sources=request.include_sources,
            filters=request.filters,
        )

        logger.info(
            f"Query complete: {len(result.sources)} sources, "
            f"confidence: {result.confidence:.2f}"
        )

        return result

    except Exception as e:
        # Surface any pipeline failure as a 500 with the underlying message.
        logger.error(f"Error processing query: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing query: {str(e)}"
        )
489
+
490
+
491
@app.post("/query/stream")
async def stream_query(request: StreamQueryRequest):
    """
    Streaming query endpoint.

    Streams the generated answer back as Server-Sent Events (SSE).

    Args:
        request: StreamQueryRequest with question and options

    Returns:
        StreamingResponse with SSE

    Raises:
        HTTPException: If service unavailable or query fails
    """
    check_initialization()

    async def event_stream():
        """Yield SSE-formatted chunks from the query engine."""
        try:
            logger.info(f"Processing streaming query: '{request.question}'")

            chunks = app_state.query_engine.stream_query(
                question=request.question,
                filters=request.filters,
            )
            for chunk in chunks:
                # SSE frame format: "data: <content>\n\n"
                yield f"data: {chunk}\n\n"

            logger.info("Streaming query complete")

        except Exception as e:
            # Errors are delivered in-band since headers are already sent.
            logger.error(f"Error in streaming query: {e}", exc_info=True)
            yield f"data: [ERROR] {str(e)}\n\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
    }
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
536
+
537
+
538
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """
    Get index and pipeline statistics.

    Returns:
        StatsResponse with collection info and pipeline config

    Raises:
        HTTPException: If service unavailable or stats retrieval fails
    """
    check_initialization()

    try:
        info = app_state.qdrant_manager.get_collection_info()
        if not info:
            # No collection to report on — treat as a 404, not a server error.
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Collection not found"
            )

        config = app_state.query_engine.get_pipeline_info()

        return StatsResponse(
            collection_info=info,
            pipeline_config=config,
            documents_indexed=info.get("vectors_count", 0),
            timestamp=time.time(),
        )

    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 404 above) untouched.
        raise
    except Exception as e:
        logger.error(f"Error retrieving stats: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error retrieving stats: {str(e)}"
        )
578
+
579
+
580
+ # ============================================================================
581
+ # Error Handlers
582
+ # ============================================================================
583
+
584
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Handle HTTP exceptions.

    FastAPI exception handlers must return a Response object; the previous
    implementation returned a plain dict, which itself errors at request
    time and loses the intended status code. Wrap the payload in a
    JSONResponse carrying the exception's own status code.

    Args:
        request: The incoming request that triggered the exception
        exc: The raised HTTPException

    Returns:
        JSONResponse with the error detail and the exception's status code
    """
    # Local import: the file-level import block is outside this view.
    from fastapi.responses import JSONResponse

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": time.time(),
        },
    )
597
+
598
+
599
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """
    Handle uncaught exceptions.

    FastAPI exception handlers must return a Response object; the previous
    implementation returned a plain dict, which fails at request time.
    Wrap the payload in a JSONResponse with a 500 status code.

    Args:
        request: The incoming request that triggered the exception
        exc: The unhandled exception

    Returns:
        JSONResponse with a generic error message and 500 status
    """
    # Local import: the file-level import block is outside this view.
    from fastapi.responses import JSONResponse

    logger.error(f"Unhandled exception: {exc}", exc_info=True)

    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "error": "Internal server error",
            "detail": str(exc),
            "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            "timestamp": time.time(),
        },
    )
615
+
616
+
617
+ # ============================================================================
618
+ # Mount Gradio UI
619
+ # ============================================================================
620
+
621
# Create and mount the Gradio web UI onto the FastAPI app.
# The getter is a lambda so Gradio resolves app_state.query_engine lazily,
# after startup initialization has populated it (it may still be None until
# initialization completes — TODO confirm the UI handles that case).
gradio_interface = create_gradio_interface(
    query_engine_getter=lambda: app_state.query_engine
)
# Rebind `app` to the combined ASGI application with the UI served at /ui.
app = gr.mount_gradio_app(app, gradio_interface, path="/ui")
logger.info("Gradio UI mounted at /ui")
src/llm/__init__.py ADDED
File without changes
src/llm/llm_client.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Abstract base class for LLM clients."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Generator, List, Optional
5
+
6
+
7
class LLMClient(ABC):
    """
    Common interface for all LLM backends.

    Every provider (Ollama, OpenAI-compatible endpoints, etc.) implements
    this contract so the RAG pipeline can swap backends without code
    changes. Implementations must also expose a ``llm_model`` attribute
    (str) naming the model in use.
    """

    # Name of the underlying model; set by each concrete implementation.
    llm_model: str

    @abstractmethod
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """
        Produce a complete response in a single call (non-streaming).

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            stop: Stop sequences

        Returns:
            Generated text
        """
        ...

    @abstractmethod
    def stream_generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Produce a response incrementally, yielding text as it is generated.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            stop: Stop sequences

        Yields:
            Generated text chunks
        """
        ...
src/llm/ollama_client.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ollama client for embeddings and LLM inference."""
2
+
3
+ import logging
4
+ import time
5
+ from typing import Generator, List, Optional
6
+
7
+ import requests
8
+ from rich.console import Console
9
+
10
+ from config.settings import settings
11
+ from src.llm.llm_client import LLMClient
12
+
13
+
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class OllamaConnectionError(Exception):
    """Raised when the Ollama server cannot be reached or a request to it ultimately fails."""
23
+
24
+
25
class OllamaModelNotFoundError(Exception):
    """Raised when a requested model has not been pulled into the local Ollama instance."""
29
+
30
+
31
class OllamaClient(LLMClient):
    """
    Client for interacting with Ollama for embeddings and LLM inference.

    Features:
    - Embedding generation (single and batch)
    - LLM text generation (streaming and non-streaming)
    - Health checks
    - Automatic retry with exponential backoff
    - Model verification
    """

    # Fallback embedding dimension used for zero vectors on empty/failed
    # inputs. NOTE(review): assumed to match the configured embedding model
    # (768-dim) — confirm if the embedding model changes.
    EMBED_DIM = 768

    def __init__(
        self,
        base_url: Optional[str] = None,
        embedding_model: Optional[str] = None,
        llm_model: Optional[str] = None,
        timeout: int = 30,
        max_retries: int = 3,
    ):
        """
        Initialize Ollama client.

        Args:
            base_url: Ollama API base URL (default: from settings)
            embedding_model: Embedding model name (None to skip, or from settings if not provided)
            llm_model: LLM model name (default: from settings)
            timeout: Request timeout in seconds
            max_retries: Maximum number of retries for failed requests

        Raises:
            OllamaConnectionError: If the Ollama server is unreachable
            OllamaModelNotFoundError: If a required model is not pulled
        """
        self.base_url = (base_url or settings.ollama_base_url).rstrip("/")
        self.embedding_model = embedding_model
        self.llm_model = llm_model or settings.llm_model
        self.timeout = timeout
        self.max_retries = max_retries

        self.console = Console()

        # Test connection and verify models
        self._initialize()

    def _initialize(self):
        """Verify server connectivity and availability of configured models."""
        # Check if Ollama is running
        if not self.check_health():
            error_msg = (
                f"Cannot connect to Ollama at {self.base_url}. "
                "Please ensure Ollama is running."
            )
            logger.error(error_msg)
            raise OllamaConnectionError(error_msg)

        self.console.print(f"[green][/green] Connected to Ollama at {self.base_url}")

        # Verify embedding model (only if specified)
        if self.embedding_model and not self._check_model_exists(self.embedding_model):
            error_msg = (
                f"Embedding model '{self.embedding_model}' not found. "
                f"Please pull it with: ollama pull {self.embedding_model}"
            )
            logger.error(error_msg)
            raise OllamaModelNotFoundError(error_msg)

        # Get and log embedding model info
        if self.embedding_model:
            embed_info = self._get_model_info(self.embedding_model)
            if embed_info:
                self.console.print(
                    f"[green][/green] Embedding model: {self.embedding_model}"
                )
                logger.info(f"Embedding model info: {embed_info}")

        # Verify LLM model
        if not self._check_model_exists(self.llm_model):
            error_msg = (
                f"LLM model '{self.llm_model}' not found. "
                f"Please pull it with: ollama pull {self.llm_model}"
            )
            logger.error(error_msg)
            raise OllamaModelNotFoundError(error_msg)

        # Get and log LLM model info
        llm_info = self._get_model_info(self.llm_model)
        if llm_info:
            self.console.print(f"[green][/green] LLM model: {self.llm_model}")
            logger.info(f"LLM model info: {llm_info}")

    def check_health(self) -> bool:
        """
        Check if Ollama server is running and reachable.

        Returns:
            True if server is healthy, False otherwise
        """
        try:
            response = requests.get(
                f"{self.base_url}/api/tags", timeout=self.timeout
            )
            return response.status_code == 200
        except requests.exceptions.RequestException as e:
            logger.warning(f"Health check failed: {e}")
            return False

    def _check_model_exists(self, model_name: str) -> bool:
        """
        Check if a model exists in Ollama.

        Args:
            model_name: Name of the model to check

        Returns:
            True if model exists, False otherwise
        """
        try:
            response = requests.get(
                f"{self.base_url}/api/tags", timeout=self.timeout
            )
            if response.status_code == 200:
                data = response.json()
                models = [m["name"] for m in data.get("models", [])]
                # Accept exact match, ":latest", or any other tag variant
                return (
                    model_name in models
                    or f"{model_name}:latest" in models
                    or any(m.startswith(f"{model_name}:") for m in models)
                )
        except requests.exceptions.RequestException as e:
            logger.error(f"Error checking model existence: {e}")

        return False

    def _get_model_info(self, model_name: str) -> Optional[dict]:
        """
        Get information about a model.

        Args:
            model_name: Name of the model

        Returns:
            Dictionary with model information or None
        """
        try:
            response = requests.post(
                f"{self.base_url}/api/show",
                json={"name": model_name},
                timeout=self.timeout,
            )
            if response.status_code == 200:
                return response.json()
        except requests.exceptions.RequestException as e:
            logger.warning(f"Could not get model info: {e}")

        return None

    def _retry_with_backoff(self, func, *args, **kwargs):
        """
        Retry a function with exponential backoff (assumes max_retries >= 1).

        Args:
            func: Function to retry
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Function result

        Raises:
            requests.exceptions.RequestException: Last failure if all retries fail
        """
        last_exception = None

        for attempt in range(self.max_retries):
            try:
                return func(*args, **kwargs)
            except requests.exceptions.RequestException as e:
                last_exception = e
                if attempt < self.max_retries - 1:
                    # Exponential backoff: 1s, 2s, 4s, ...
                    wait_time = 2**attempt
                    logger.warning(
                        f"Request failed (attempt {attempt + 1}/{self.max_retries}), "
                        f"retrying in {wait_time}s: {e}"
                    )
                    time.sleep(wait_time)
                else:
                    logger.error(f"All {self.max_retries} attempts failed")

        raise last_exception

    def embed_text(self, text: str, return_zero_on_failure: bool = False, max_chars: int = 2000) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Input text to embed
            return_zero_on_failure: If True, return zero vector instead of raising exception
            max_chars: Maximum characters to send to Ollama (default: 2000, safe limit for WSL2)

        Returns:
            Embedding vector as list of floats

        Raises:
            OllamaConnectionError: If request fails after retries (unless return_zero_on_failure=True)
        """
        # Handle empty text
        if not text or not text.strip():
            logger.warning("Empty text provided for embedding, returning zero vector")
            return [0.0] * self.EMBED_DIM

        # Truncate if too long to prevent context overflow
        original_length = len(text)
        if len(text) > max_chars:
            text = text[:max_chars]
            logger.debug(f"Truncated text from {original_length} to {max_chars} chars for embedding")

        def _embed():
            response = requests.post(
                f"{self.base_url}/api/embed",  # Correct endpoint for Ollama 0.13.2+
                json={"model": self.embedding_model, "input": text},  # Use 'input' not 'prompt'
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()
            # API returns embeddings array, we want the first one
            return data["embeddings"][0] if "embeddings" in data else data["embedding"]

        try:
            return self._retry_with_backoff(_embed)
        except requests.exceptions.RequestException as e:
            if return_zero_on_failure:
                logger.warning(f"Failed to generate embedding (text length: {len(text)}), returning zero vector: {e}")
                return [0.0] * self.EMBED_DIM
            else:
                logger.error(f"Failed to generate embedding: {e}")
                raise OllamaConnectionError(f"Embedding generation failed: {e}")

    def embed_batch(
        self, texts: List[str], batch_size: int = 1, show_progress: bool = True
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts sequentially.

        Note: batch_size parameter is kept for API compatibility but is ignored.
        Processing is always sequential to avoid overwhelming local Ollama instance.

        Args:
            texts: List of input texts
            batch_size: Ignored (kept for compatibility)
            show_progress: Show progress bar

        Returns:
            List of embedding vectors (zero vectors mark failed inputs)
        """
        embeddings = []
        failed_count = 0
        # Hoisted: zero-vector sentinel built once instead of per iteration
        zero_vector = [0.0] * self.EMBED_DIM

        pbar = None
        if show_progress:
            from tqdm import tqdm
            pbar = tqdm(total=len(texts), desc="Generating embeddings", unit="chunk")

        for i, text in enumerate(texts):
            try:
                # Brief pause between requests to avoid overloading Ollama
                # (uses the module-level `time` import; the previous local
                # `import time` was redundant)
                if i > 0:
                    time.sleep(0.5)

                # Use return_zero_on_failure to prevent single failures from stopping the entire process
                embedding = self.embed_text(text, return_zero_on_failure=True)
                embeddings.append(embedding)

                # A zero vector signals that embed_text swallowed a failure
                if embedding == zero_vector:
                    failed_count += 1

            except Exception as e:
                logger.error(f"Unexpected error embedding text {i}: {e}")
                # Fallback to zero vector (fresh copy so rows never alias)
                embeddings.append(list(zero_vector))
                failed_count += 1

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        success_count = len(texts) - failed_count
        logger.info(f"Generated {success_count}/{len(texts)} embeddings successfully")

        if failed_count > 0:
            logger.warning(f"{failed_count} chunks failed and were assigned zero vectors")

        return embeddings

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """
        Generate text using LLM (non-streaming).

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from settings)
            max_tokens: Maximum tokens to generate (default: from settings)
            stop: Stop sequences

        Returns:
            Generated text

        Raises:
            OllamaConnectionError: If generation fails
        """
        temperature = temperature if temperature is not None else settings.llm_temperature
        max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens

        def _generate():
            payload = {
                "model": self.llm_model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": temperature,
                    "num_predict": max_tokens,
                },
            }

            if system_prompt:
                payload["system"] = system_prompt

            if stop:
                payload["options"]["stop"] = stop

            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=self.timeout * 2,  # Longer timeout for generation
            )
            response.raise_for_status()
            data = response.json()
            return data["response"]

        try:
            return self._retry_with_backoff(_generate)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to generate text: {e}")
            raise OllamaConnectionError(f"Text generation failed: {e}")

    def stream_generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Generate text using LLM with streaming.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from settings)
            max_tokens: Maximum tokens to generate (default: from settings)
            stop: Stop sequences

        Yields:
            Generated text chunks

        Raises:
            OllamaConnectionError: If generation fails
        """
        temperature = temperature if temperature is not None else settings.llm_temperature
        max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens

        payload = {
            "model": self.llm_model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }

        if system_prompt:
            payload["system"] = system_prompt

        if stop:
            payload["options"]["stop"] = stop

        # Hoisted out of the streaming loop (was re-imported per line)
        import json

        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                stream=True,
                timeout=self.timeout * 2,
            )
            response.raise_for_status()

            # Each non-empty line is a standalone JSON object with a chunk
            for line in response.iter_lines():
                if not line:
                    continue

                data = json.loads(line)
                if "response" in data:
                    yield data["response"]

                # Check if done
                if data.get("done", False):
                    break

        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to stream generate text: {e}")
            raise OllamaConnectionError(f"Streaming generation failed: {e}")

    def get_available_models(self) -> List[str]:
        """
        Get list of available models in Ollama.

        Returns:
            List of model names (empty list on error)
        """
        try:
            response = requests.get(
                f"{self.base_url}/api/tags", timeout=self.timeout
            )
            if response.status_code == 200:
                data = response.json()
                return [m["name"] for m in data.get("models", [])]
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get available models: {e}")

        return []

    def pull_model(self, model_name: str) -> bool:
        """
        Pull a model from Ollama registry.

        Args:
            model_name: Name of model to pull

        Returns:
            True if successful

        Note:
            This is a blocking operation that may take a while
        """
        # Hoisted out of the progress loop (was re-imported per line)
        import json

        try:
            self.console.print(f"[cyan]Pulling model: {model_name}...[/cyan]")

            response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                stream=True,
                timeout=None,  # No timeout for pulling
            )
            # Fail fast on HTTP errors instead of silently reporting success
            response.raise_for_status()

            # Stream progress
            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    status = data.get("status", "")
                    if status:
                        self.console.print(f"  {status}")

            self.console.print(f"[green][/green] Model pulled: {model_name}")
            return True

        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to pull model: {e}")
            self.console.print(f"[red][/red] Failed to pull model: {e}")
            return False
src/llm/openai_client.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenAI-compatible client for LLM inference (supports Groq, DeepSeek, OpenAI, etc.)."""
2
+
3
+ import logging
4
+ from typing import Generator, List, Optional
5
+
6
+ from rich.console import Console
7
+
8
+ from config.settings import settings
9
+ from src.llm.llm_client import LLMClient
10
+
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class OpenAIClientError(Exception):
    """Raised when a call to an OpenAI-compatible API fails."""
21
+
22
+
23
class OpenAIClient(LLMClient):
    """
    Client for interacting with OpenAI-compatible APIs for LLM inference.

    Supports:
    - OpenAI (https://api.openai.com/v1)
    - Groq (https://api.groq.com/openai/v1)
    - DeepSeek (https://api.deepseek.com/v1)
    - Any OpenAI-compatible endpoint

    Features:
    - Non-streaming and streaming text generation
    - Configurable model, temperature, and max tokens
    - Automatic retry via the openai SDK
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize OpenAI-compatible client.

        Args:
            api_key: API key (default: from settings)
            base_url: Base URL for the API (default: from settings, or OpenAI default)
            model: Model name (default: from settings)
            temperature: Default temperature (default: from settings)
            max_tokens: Default max tokens (default: from settings)

        Raises:
            ImportError: If the 'openai' package is not installed
            OpenAIClientError: If no API key is available
        """
        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError(
                "The 'openai' package is required for OpenAI-compatible providers. "
                "Install it with: pip install openai>=1.0.0"
            )

        self._api_key = api_key or settings.openai_api_key
        if not self._api_key:
            raise OpenAIClientError(
                "API key is required for OpenAI-compatible provider. "
                "Set OPENAI_API_KEY environment variable or pass api_key parameter."
            )

        self._base_url = base_url or settings.openai_base_url
        self.llm_model = model or settings.openai_model
        self._temperature = temperature if temperature is not None else settings.llm_temperature
        self._max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens

        self.console = Console()

        # Initialize OpenAI client (base_url omitted -> SDK's OpenAI default)
        client_kwargs = {"api_key": self._api_key}
        if self._base_url:
            client_kwargs["base_url"] = self._base_url

        self._client = OpenAI(**client_kwargs)

        # Log initialization
        provider_name = self._base_url or "OpenAI (default)"
        self.console.print(f"[green][/green] OpenAI-compatible client initialized")
        self.console.print(f"  Provider: {provider_name}")
        self.console.print(f"  Model: {self.llm_model}")
        logger.info(f"OpenAI client initialized: provider={provider_name}, model={self.llm_model}")

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> List[dict]:
        """Assemble the chat messages payload (system prompt first, when given)."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """
        Generate text using the OpenAI-compatible API (non-streaming).

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from init/settings)
            max_tokens: Maximum tokens to generate (default: from init/settings)
            stop: Stop sequences

        Returns:
            Generated text

        Raises:
            OpenAIClientError: If generation fails
        """
        temperature = temperature if temperature is not None else self._temperature
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens

        messages = self._build_messages(prompt, system_prompt)

        try:
            response = self._client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
            )
            # content may be None (e.g. refusals); normalize to empty string
            return response.choices[0].message.content or ""

        except Exception as e:
            logger.error(f"Failed to generate text via OpenAI-compatible API: {e}")
            raise OpenAIClientError(f"Text generation failed: {e}")

    def stream_generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Generate text using the OpenAI-compatible API with streaming.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from init/settings)
            max_tokens: Maximum tokens to generate (default: from init/settings)
            stop: Stop sequences

        Yields:
            Generated text chunks

        Raises:
            OpenAIClientError: If generation fails
        """
        temperature = temperature if temperature is not None else self._temperature
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens

        messages = self._build_messages(prompt, system_prompt)

        try:
            stream = self._client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
                stream=True,
            )

            for chunk in stream:
                # Skip keep-alive/empty deltas
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            logger.error(f"Failed to stream generate text via OpenAI-compatible API: {e}")
            raise OpenAIClientError(f"Streaming generation failed: {e}")
src/llm/sentence_transformer_client.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sentence Transformers client for reliable embeddings."""
2
+
3
+ import logging
4
+ from typing import List
5
+
6
+ import torch
7
+ from sentence_transformers import SentenceTransformer
8
+ from tqdm import tqdm
9
+
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class SentenceTransformerClient:
    """
    Client for generating embeddings with sentence-transformers.

    Drop-in replacement for OllamaClient embeddings with much better stability
    and performance: HuggingFace models run in-process, so no server is needed.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: str = None,
    ):
        """
        Initialize the sentence transformer client.

        Args:
            model_name: HuggingFace model name. Options include:
                - "sentence-transformers/all-MiniLM-L6-v2" (384 dim, fast, general)
                - "sentence-transformers/all-mpnet-base-v2" (768 dim, better quality)
                - "BAAI/bge-small-en-v1.5" (384 dim, good for retrieval)
                - "BAAI/bge-base-en-v1.5" (768 dim, better quality)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)
        """
        self.model_name = model_name
        # Prefer CUDA when available unless the caller pinned a device.
        self.device = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        logger.info(f"Loading embedding model: {model_name}")
        logger.info(f"Using device: {self.device}")

        self.model = SentenceTransformer(model_name, device=self.device)

        # Cached output dimensionality; used to size zero-vector fallbacks.
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        logger.info(f"Embedding dimension: {self.embedding_dim}")

    def embed_text(self, text: str, return_zero_on_failure: bool = False) -> List[float]:
        """
        Generate an embedding for a single text.

        Args:
            text: Input text to embed
            return_zero_on_failure: If True, return a zero vector on error
                instead of re-raising (kept for compatibility)

        Returns:
            Embedding vector as a list of floats
        """
        # Blank input always maps to the zero vector.
        if not text or not text.strip():
            logger.warning("Empty text provided, returning zero vector")
            return [0.0] * self.embedding_dim

        try:
            vector = self.model.encode(
                text,
                convert_to_numpy=True,
                show_progress_bar=False,
            )
            return vector.tolist()
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            if return_zero_on_failure:
                return [0.0] * self.embedding_dim
            raise

    def embed_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts efficiently.

        Args:
            texts: List of input texts
            batch_size: Number of texts to process in parallel
            show_progress: Show progress bar

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        logger.info(f"Generating embeddings for {len(texts)} texts (batch_size={batch_size})")

        try:
            matrix = self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=show_progress,
                convert_to_numpy=True,
            )
            embeddings_list = matrix.tolist()
            logger.info(f"Successfully generated {len(embeddings_list)} embeddings")
            return embeddings_list

        except Exception as e:
            logger.error(f"Batch embedding failed: {e}")
            # Batch path failed: degrade gracefully to one-by-one encoding,
            # substituting zero vectors for individual failures.
            logger.warning("Falling back to sequential processing")
            source = tqdm(texts, desc="Generating embeddings") if show_progress else texts
            embeddings = [
                self.embed_text(item, return_zero_on_failure=True) for item in source
            ]

            zero_vector = [0.0] * self.embedding_dim
            failed_count = sum(1 for emb in embeddings if emb == zero_vector)
            if failed_count > 0:
                logger.warning(f"{failed_count} embeddings failed and were assigned zero vectors")

            return embeddings

    def get_model_info(self) -> dict:
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "embedding_dim": self.embedding_dim,
            "max_seq_length": self.model.max_seq_length,
        }
149
+
150
+
151
# Convenience factory mirroring the project's default embedding settings.
def create_embedding_client(
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
) -> SentenceTransformerClient:
    """
    Create an embedding client with default settings.

    Defaults to all-mpnet-base-v2, whose 768-dim embeddings match
    nomic-embed-text in size while offering better quality.
    """
    return SentenceTransformerClient(model_name=model_name)
src/processing/__init__.py ADDED
File without changes
src/processing/chunker.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Semantic chunker for processing markdown documents with hierarchical structure."""
2
+
3
+ import hashlib
4
+ import json
5
+ import re
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ from llama_index.core.node_parser import SentenceSplitter
10
+ from pydantic import BaseModel, Field
11
+ from rich.console import Console
12
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
13
+
14
+ from config.settings import settings
15
+
16
+
17
class ChunkNode(BaseModel):
    """
    Pydantic model representing a semantic chunk of text.

    Attributes:
        chunk_id: Unique identifier for the chunk
        content: The actual text content
        parent_section: The section header this chunk belongs to
        document_title: Original article title
        source_url: EyeWiki URL of the source document
        chunk_index: Position of chunk in the document (0-indexed)
        token_count: Approximate number of tokens in the chunk
        metadata: Additional metadata from the source document
    """

    # Hash-based ID derived from source URL, position, and a content prefix,
    # so it stays stable when the same document is re-chunked.
    chunk_id: str = Field(..., description="Unique identifier (hash-based)")
    content: str = Field(..., description="Text content of the chunk")
    parent_section: str = Field(default="", description="Parent section header")
    document_title: str = Field(default="", description="Original document title")
    source_url: str = Field(default="", description="Source URL")
    # ge=0: positions and token counts can never be negative.
    chunk_index: int = Field(..., ge=0, description="Position in document")
    token_count: int = Field(..., ge=0, description="Approximate token count")
    metadata: Dict = Field(default_factory=dict, description="Additional metadata")

    def to_dict(self) -> Dict:
        """Convert to dictionary representation (pydantic v2 ``model_dump``)."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: Dict) -> "ChunkNode":
        """Create ChunkNode from dictionary (inverse of ``to_dict``)."""
        return cls(**data)
49
+
50
+
51
class SemanticChunker:
    """
    Hierarchical semantic chunker that respects markdown structure.

    Splits documents on ``##`` headers first, then breaks oversized sections
    into semantic chunks with LlamaIndex's SentenceSplitter, keeping the
    parent section header attached to every chunk. Chunk size and overlap
    are configurable.
    """

    def __init__(
        self,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        min_chunk_size: int = 100,
    ):
        """
        Initialize the SemanticChunker.

        Args:
            chunk_size: Target chunk size in tokens (default: from settings)
            chunk_overlap: Overlap between chunks in tokens (default: from settings)
            min_chunk_size: Minimum chunk size to keep (default: 100 tokens)
        """
        # Falsy values (None or 0) fall back to the project settings.
        self.chunk_size = chunk_size or settings.chunk_size
        self.chunk_overlap = chunk_overlap or settings.chunk_overlap
        self.min_chunk_size = min_chunk_size

        # LlamaIndex splitter handles the sentence-aware splitting.
        self.sentence_splitter = SentenceSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )

        self.console = Console()

    def _estimate_tokens(self, text: str) -> int:
        """
        Approximate token count using the ~4-characters-per-token heuristic
        (more accurate than word count for medical/technical text).
        """
        return len(text) // 4

    def _generate_chunk_id(self, content: str, chunk_index: int, source_url: str) -> str:
        """
        Derive a stable 16-hex-char chunk ID from source URL, index, and a
        content prefix.
        """
        fingerprint = f"{source_url}:{chunk_index}:{content[:100]}"
        return hashlib.sha256(fingerprint.encode()).hexdigest()[:16]

    def _parse_markdown_sections(self, markdown: str) -> List[Tuple[str, str]]:
        """
        Split markdown into (header, content) pairs on ``##`` (h2) headers.

        The header line itself stays inside the section content; any text
        before the first header is returned with an empty header string.
        """
        header_re = r"^##\s+(.+?)$"
        sections: List[Tuple[str, str]] = []
        active_header = ""
        buffered: List[str] = []

        for line in markdown.split("\n"):
            hit = re.match(header_re, line)
            if hit is None:
                buffered.append(line)
                continue
            # Flush the previous section before opening the new one.
            if buffered:
                sections.append((active_header, "\n".join(buffered)))
            active_header = hit.group(1).strip()
            buffered = [line]  # Keep the header line inside the content.

        if buffered:
            sections.append((active_header, "\n".join(buffered)))

        return sections

    def _split_large_section(self, text: str) -> List[str]:
        """Break an oversized section into semantic chunks via SentenceSplitter."""
        return self.sentence_splitter.split_text(text)

    def _clean_content(self, content: str) -> str:
        """Collapse runs of 3+ newlines to a blank line and trim the edges."""
        return re.sub(r"\n{3,}", "\n\n", content).strip()

    def _build_chunk(
        self,
        content: str,
        section_header: str,
        chunk_index: int,
        token_count: int,
        metadata: Dict,
        source_url: str,
        document_title: str,
    ) -> "ChunkNode":
        """Assemble a ChunkNode with a freshly generated ID."""
        return ChunkNode(
            chunk_id=self._generate_chunk_id(content, chunk_index, source_url),
            content=content,
            parent_section=section_header,
            document_title=document_title,
            source_url=source_url,
            chunk_index=chunk_index,
            token_count=token_count,
            metadata=metadata,
        )

    def chunk_document(
        self,
        markdown_content: str,
        metadata: Dict,
    ) -> List[ChunkNode]:
        """
        Chunk a markdown document with hierarchical structure.

        Sections at or below the target chunk size become single chunks
        (when they meet the minimum size); larger sections are split
        semantically. Every chunk keeps its parent section header.

        Args:
            markdown_content: Markdown text content
            metadata: Document metadata (must include 'url' and 'title')

        Returns:
            List of ChunkNode objects
        """
        source_url = metadata.get("url", "")
        document_title = metadata.get("title", "Untitled")

        sections = self._parse_markdown_sections(markdown_content)
        # Degenerate case: no headers at all -> whole document is one section.
        if not sections or (len(sections) == 1 and not sections[0][0]):
            sections = [("", markdown_content)]

        chunks: List[ChunkNode] = []
        next_index = 0

        for section_header, section_content in sections:
            section_content = self._clean_content(section_content)
            if not section_content:
                continue

            section_tokens = self._estimate_tokens(section_content)

            if section_tokens <= self.chunk_size:
                # Small enough to stand alone; drop tiny fragments.
                if section_tokens >= self.min_chunk_size:
                    chunks.append(
                        self._build_chunk(
                            section_content,
                            section_header,
                            next_index,
                            section_tokens,
                            metadata,
                            source_url,
                            document_title,
                        )
                    )
                    next_index += 1
                continue

            # Oversized section: split semantically, filtering tiny pieces.
            for piece in self._split_large_section(section_content):
                piece = self._clean_content(piece)
                piece_tokens = self._estimate_tokens(piece)
                if piece_tokens < self.min_chunk_size:
                    continue
                chunks.append(
                    self._build_chunk(
                        piece,
                        section_header,
                        next_index,
                        piece_tokens,
                        metadata,
                        source_url,
                        document_title,
                    )
                )
                next_index += 1

        return chunks

    def chunk_directory(
        self,
        input_dir: Path,
        output_dir: Path,
        pattern: str = "*.md",
    ) -> Dict[str, int]:
        """
        Process all markdown files in a directory.

        For each .md file, looks for a sibling .json metadata file, chunks
        the document, and writes the chunks to the output directory.

        Args:
            input_dir: Directory containing markdown files
            output_dir: Directory to save chunked outputs
            pattern: Glob pattern for files to process (default: "*.md")

        Returns:
            Dictionary with processing statistics
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        markdown_files = list(input_dir.glob(pattern))
        if not markdown_files:
            self.console.print(f"[yellow]No files matching '{pattern}' found in {input_dir}[/yellow]")
            return {"processed": 0, "failed": 0, "total_chunks": 0}

        stats = {
            "processed": 0,
            "failed": 0,
            "skipped": 0,
            "total_chunks": 0,
            "total_tokens": 0,
        }

        self.console.print("\n[bold cyan]Chunking Documents[/bold cyan]")
        self.console.print(f"Input: {input_dir}")
        self.console.print(f"Output: {output_dir}")
        self.console.print(f"Files found: {len(markdown_files)}\n")

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=self.console,
        ) as progress:

            task = progress.add_task("[cyan]Processing...", total=len(markdown_files))

            for md_file in markdown_files:
                try:
                    # Metadata lives in a sibling .json next to each .md file.
                    meta_path = md_file.with_suffix(".json")
                    if not meta_path.exists():
                        self.console.print(
                            f"[yellow]Skipping {md_file.name}: No metadata file found[/yellow]"
                        )
                        stats["skipped"] += 1
                        progress.advance(task)
                        continue

                    markdown_content = md_file.read_text(encoding="utf-8")
                    metadata = json.loads(meta_path.read_text(encoding="utf-8"))

                    # Documents below the minimum chunk size are not worth keeping.
                    if self._estimate_tokens(markdown_content) < self.min_chunk_size:
                        self.console.print(
                            f"[yellow]Skipping {md_file.name}: Content too small[/yellow]"
                        )
                        stats["skipped"] += 1
                        progress.advance(task)
                        continue

                    chunks = self.chunk_document(markdown_content, metadata)
                    if not chunks:
                        self.console.print(
                            f"[yellow]Skipping {md_file.name}: No chunks created[/yellow]"
                        )
                        stats["skipped"] += 1
                        progress.advance(task)
                        continue

                    out_path = output_dir / f"{md_file.stem}_chunks.json"
                    with open(out_path, "w", encoding="utf-8") as handle:
                        json.dump(
                            [chunk.to_dict() for chunk in chunks],
                            handle,
                            indent=2,
                            ensure_ascii=False,
                        )

                    stats["processed"] += 1
                    stats["total_chunks"] += len(chunks)
                    stats["total_tokens"] += sum(chunk.token_count for chunk in chunks)

                    progress.update(
                        task,
                        description=f"[cyan]Processing ({stats['processed']} done, {stats['total_chunks']} chunks): {md_file.name[:40]}...",
                    )
                    progress.advance(task)

                except Exception as e:
                    self.console.print(f"[red]Error processing {md_file.name}: {e}[/red]")
                    stats["failed"] += 1
                    progress.advance(task)

        self.console.print("\n[bold cyan]Chunking Summary[/bold cyan]")
        self.console.print(f"Files processed: {stats['processed']}")
        self.console.print(f"Files skipped: {stats['skipped']}")
        self.console.print(f"Files failed: {stats['failed']}")
        self.console.print(f"Total chunks created: {stats['total_chunks']}")
        self.console.print(f"Total tokens: {stats['total_tokens']:,}")

        if stats["processed"] > 0:
            avg_chunks = stats["total_chunks"] / stats["processed"]
            avg_tokens = stats["total_tokens"] / stats["total_chunks"] if stats["total_chunks"] > 0 else 0
            self.console.print(f"Average chunks per document: {avg_chunks:.1f}")
            self.console.print(f"Average tokens per chunk: {avg_tokens:.1f}")

        return stats
src/processing/metadata_extractor.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Medical metadata extractor for EyeWiki articles."""
2
+
3
+ import re
4
+ from typing import Dict, List, Set
5
+
6
+
7
class MetadataExtractor:
    """
    Extract medical metadata from EyeWiki articles.

    Extracted fields: disease names, ICD-10 codes, anatomical structures,
    symptoms, treatments (medications and procedures), and categories.
    All vocabulary matching is case-insensitive and word-bounded.
    """

    # Comprehensive list of eye anatomical structures.
    ANATOMICAL_STRUCTURES = {
        # Major structures
        "cornea", "corneal", "sclera", "scleral", "retina", "retinal",
        "lens", "crystalline lens", "iris", "iridial", "pupil", "pupillary",
        "choroid", "choroidal", "vitreous", "vitreous humor",
        "optic nerve", "optic disc", "optic cup",
        # Anterior segment
        "anterior chamber", "posterior chamber", "anterior segment",
        "trabecular meshwork", "schlemm's canal", "ciliary body", "ciliary muscle",
        "zonules", "zonular", "aqueous humor", "aqueous",
        # Posterior segment
        "posterior segment", "macula", "macular", "fovea", "foveal",
        "retinal pigment epithelium", "rpe", "photoreceptors",
        "rods", "cones", "ganglion cells",
        # Retinal layers
        "inner limiting membrane", "nerve fiber layer", "ganglion cell layer",
        "inner plexiform layer", "inner nuclear layer", "outer plexiform layer",
        "outer nuclear layer", "external limiting membrane",
        "photoreceptor layer", "bruch's membrane",
        # Extraocular
        "eyelid", "eyelids", "conjunctiva", "conjunctival",
        "lacrimal gland", "tear film", "meibomian glands",
        "extraocular muscles", "rectus muscle", "oblique muscle",
        "orbit", "orbital", "optic chiasm",
        # Blood vessels
        "central retinal artery", "central retinal vein",
        "retinal vessels", "vascular", "vasculature",
        "choriocapillaris",
        # Angles and spaces
        "angle", "iridocorneal angle", "suprachoroidal space",
    }

    # Common ophthalmic medications.
    MEDICATIONS = {
        # Glaucoma medications
        "latanoprost", "timolol", "dorzolamide", "brinzolamide",
        "brimonidine", "apraclonidine", "bimatoprost", "travoprost",
        "tafluprost", "pilocarpine", "carbachol",
        "acetazolamide", "methazolamide",
        # Anti-VEGF agents
        "bevacizumab", "ranibizumab", "aflibercept", "brolucizumab",
        "pegaptanib", "faricimab",
        # Steroids
        "prednisolone", "dexamethasone", "triamcinolone", "fluocinolone",
        "difluprednate", "fluorometholone", "loteprednol",
        "betamethasone", "hydrocortisone",
        # Antibiotics
        "moxifloxacin", "gatifloxacin", "ciprofloxacin", "ofloxacin",
        "levofloxacin", "tobramycin", "gentamicin", "erythromycin",
        "azithromycin", "bacitracin", "polymyxin", "neomycin",
        "vancomycin", "ceftazidime", "cefazolin",
        # Antivirals
        "acyclovir", "ganciclovir", "valganciclovir", "valacyclovir",
        "trifluridine", "foscarnet",
        # Anti-inflammatory
        "ketorolac", "diclofenac", "nepafenac", "bromfenac",
        "cyclosporine", "tacrolimus", "lifitegrast",
        # Mydriatics/Cycloplegics
        "tropicamide", "cyclopentolate", "atropine", "homatropine",
        "phenylephrine",
        # Other
        "mitomycin", "5-fluorouracil", "interferon",
        "methotrexate", "chlorambucil",
    }

    # Common ophthalmic procedures.
    PROCEDURES = {
        # Cataract surgery
        "phacoemulsification", "phaco", "cataract extraction",
        "extracapsular cataract extraction", "ecce",
        "intracapsular cataract extraction", "icce",
        "iol implantation", "intraocular lens",
        # Glaucoma procedures
        "trabeculectomy", "tube shunt", "glaucoma drainage device",
        "ahmed valve", "baerveldt implant", "molteno implant",
        "selective laser trabeculoplasty", "slt", "argon laser trabeculoplasty", "alt",
        "laser peripheral iridotomy", "lpi", "iridotomy",
        "cyclophotocoagulation", "cyclocryotherapy",
        "minimally invasive glaucoma surgery", "migs",
        "trabectome", "istent", "kahook dual blade", "goniotomy",
        # Retinal procedures
        "vitrectomy", "pars plana vitrectomy", "ppv",
        "membrane peeling", "epiretinal membrane peeling",
        "endolaser", "photocoagulation", "panretinal photocoagulation", "prp",
        "focal laser", "grid laser",
        "pneumatic retinopexy", "scleral buckle",
        "silicone oil", "gas tamponade", "c3f8", "sf6",
        # Corneal procedures
        "penetrating keratoplasty", "pkp", "corneal transplant",
        "descemet stripping endothelial keratoplasty", "dsek", "dsaek",
        "descemet membrane endothelial keratoplasty", "dmek",
        "deep anterior lamellar keratoplasty", "dalk",
        "phototherapeutic keratectomy", "ptk",
        "corneal crosslinking", "cxl",
        # Refractive surgery
        "lasik", "prk", "photorefractive keratectomy",
        "smile", "lasek", "refractive lens exchange",
        "phakic iol", "icl",
        # Injections
        "intravitreal injection", "intravitreal",
        "subtenon injection", "retrobulbar block", "peribulbar block",
        # Laser procedures
        "yag laser capsulotomy", "laser capsulotomy",
        "laser iridotomy", "laser trabeculoplasty",
        # Other
        "enucleation", "evisceration", "exenteration",
        "orbital decompression", "ptosis repair", "blepharoplasty",
        "dacryocystorhinostomy", "dcr",
    }

    # Common ophthalmic symptoms.
    SYMPTOMS = {
        # Visual symptoms
        "blurred vision", "blurring", "vision loss", "visual loss",
        "decreased vision", "blindness", "blind spot",
        "photophobia", "light sensitivity", "glare", "halos",
        "diplopia", "double vision", "metamorphopsia", "distortion",
        "scotoma", "floaters", "flashes", "photopsia",
        "night blindness", "nyctalopia", "color vision defect",
        "visual field defect", "peripheral vision loss",
        # Pain and discomfort
        "eye pain", "ocular pain", "pain", "foreign body sensation",
        "irritation", "burning", "stinging", "grittiness",
        "discomfort", "ache", "headache",
        # Discharge and tearing
        "discharge", "tearing", "epiphora", "watery eyes",
        "mucus", "crusting", "mattering",
        # Redness and inflammation
        "redness", "red eye", "injection", "hyperemia",
        "swelling", "edema", "chemosis", "inflammation",
        # Other
        "itching", "pruritus", "dryness", "dry eye",
        "eye strain", "asthenopia", "fatigue",
    }

    def __init__(self):
        """Initialize the metadata extractor and pre-compile regex patterns."""
        self.icd_pattern = re.compile(
            r'\b[A-Z]\d{2}(?:\.\d{1,2})?\b|'  # ICD-10: H40.1, H35.32, etc.
            r'\b[H][0-5]\d(?:\.\d{1,3})?\b'   # Ophthalmic ICD-10 (H00-H59)
        )

    @staticmethod
    def _match_vocabulary(text: str, vocabulary: Set[str], allow_plural: bool = False) -> List[str]:
        """
        Return the sorted subset of a vocabulary that appears in text.

        Matching is case-insensitive and word-bounded (re.escape guards terms
        containing regex metacharacters); when allow_plural is True a trailing
        's' on the match is also accepted.
        """
        haystack = text.lower()
        tail = r's?\b' if allow_plural else r'\b'
        hits = {
            term
            for term in vocabulary
            if re.search(r'\b' + re.escape(term) + tail, haystack)
        }
        return sorted(hits)

    def extract_icd_codes(self, text: str) -> List[str]:
        """
        Extract ICD-10 codes from text using regex.

        H-prefixed candidates are kept only when their numeric part falls in
        the ophthalmic H00-H59 range; other letters pass through unchanged.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique ICD-10 codes found
        """
        valid_codes = set()
        for code in self.icd_pattern.findall(text):
            if not code.startswith('H'):
                valid_codes.add(code)
                continue
            try:
                if 0 <= int(code[1:3]) <= 59:
                    valid_codes.add(code)
            except (ValueError, IndexError):
                continue
        return sorted(valid_codes)

    def extract_anatomical_terms(self, text: str) -> List[str]:
        """
        Extract anatomical structure mentions (plural forms allowed).

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique anatomical structures found
        """
        return self._match_vocabulary(text, self.ANATOMICAL_STRUCTURES, allow_plural=True)

    def extract_medications(self, text: str) -> List[str]:
        """
        Extract medication mentions from text.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique medications found
        """
        return self._match_vocabulary(text, self.MEDICATIONS)

    def extract_procedures(self, text: str) -> List[str]:
        """
        Extract procedure mentions from text.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique procedures found
        """
        return self._match_vocabulary(text, self.PROCEDURES)

    def extract_symptoms(self, text: str) -> List[str]:
        """
        Extract symptom mentions from text.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique symptoms found
        """
        return self._match_vocabulary(text, self.SYMPTOMS)

    def extract_disease_name(self, existing_metadata: Dict) -> str:
        """
        Extract the primary disease name from metadata.

        Sources tried in order: article title (with Disease:/Condition:/
        Syndrome: prefixes stripped), first category, last URL path segment.

        Args:
            existing_metadata: Metadata dict with 'title', 'url', 'categories'

        Returns:
            Primary disease/condition name, or "Unknown"
        """
        title = existing_metadata.get("title", "")
        if title:
            return re.sub(
                r'^(Disease|Condition|Syndrome):\s*', '', title, flags=re.IGNORECASE
            ).strip()

        categories = existing_metadata.get("categories", [])
        if categories and len(categories) > 0:
            return categories[0].strip()

        url = existing_metadata.get("url", "")
        if url:
            tail = re.search(r'/([^/]+)$', url)
            if tail:
                # URL slugs use underscores for spaces.
                return tail.group(1).replace('_', ' ').strip()

        return "Unknown"

    def extract(self, content: str, existing_metadata: Dict) -> Dict:
        """
        Extract comprehensive medical metadata from article content.

        Args:
            content: Article text content (markdown)
            existing_metadata: Existing metadata dict with basic info

        Returns:
            Enhanced metadata dictionary with medical information
        """
        enriched = existing_metadata.copy()

        enriched["disease_name"] = self.extract_disease_name(existing_metadata)
        enriched["icd_codes"] = self.extract_icd_codes(content)
        enriched["anatomical_structures"] = self.extract_anatomical_terms(content)
        enriched["symptoms"] = self.extract_symptoms(content)

        medications = self.extract_medications(content)
        procedures = self.extract_procedures(content)
        enriched["treatments"] = {
            "medications": medications,
            "procedures": procedures,
        }

        # Preserve existing categories; guarantee the key exists.
        enriched.setdefault("categories", [])

        enriched["extraction_stats"] = {
            "icd_codes_found": len(enriched["icd_codes"]),
            "anatomical_terms_found": len(enriched["anatomical_structures"]),
            "symptoms_found": len(enriched["symptoms"]),
            "medications_found": len(medications),
            "procedures_found": len(procedures),
        }

        return enriched

    def extract_batch(self, documents: List[Dict]) -> List[Dict]:
        """
        Extract metadata from multiple documents.

        Args:
            documents: List of dicts with 'content' and 'metadata' keys

        Returns:
            List of enhanced metadata dictionaries
        """
        return [
            self.extract(doc.get("content", ""), doc.get("metadata", {}))
            for doc in documents
        ]

    def get_anatomical_vocabulary(self) -> Set[str]:
        """Get a copy of the full anatomical vocabulary set."""
        return self.ANATOMICAL_STRUCTURES.copy()

    def get_medication_vocabulary(self) -> Set[str]:
        """Get a copy of the full medication vocabulary set."""
        return self.MEDICATIONS.copy()

    def get_procedure_vocabulary(self) -> Set[str]:
        """Get a copy of the full procedure vocabulary set."""
        return self.PROCEDURES.copy()

    def get_symptom_vocabulary(self) -> Set[str]:
        """Get a copy of the full symptom vocabulary set."""
        return self.SYMPTOMS.copy()
src/rag/__init__.py ADDED
File without changes
src/rag/query_engine.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query engine orchestrating the full RAG pipeline."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Generator, List, Optional
6
+
7
+ from pydantic import BaseModel, Field
8
+ from rich.console import Console
9
+
10
+ from src.rag.retriever import HybridRetriever, RetrievalResult
11
+ from src.rag.reranker import CrossEncoderReranker
12
+ from src.llm.llm_client import LLMClient
13
+
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ # Medical disclaimer (default)
21
+ MEDICAL_DISCLAIMER = (
22
+ "**Medical Disclaimer:** This information is sourced from EyeWiki, a resource of the "
23
+ "American Academy of Ophthalmology (AAO). It is not a substitute for professional "
24
+ "medical advice, diagnosis, or treatment. AI systems can make errors. Always consult "
25
+ "with a qualified ophthalmologist or eye care professional for medical concerns and "
26
+ "verify any critical information with authoritative sources."
27
+ )
28
+
29
+ # Default system prompt
30
+ DEFAULT_SYSTEM_PROMPT = """You are an expert ophthalmology assistant with comprehensive knowledge of eye diseases, treatments, and procedures.
31
+
32
+ Your role is to provide accurate, evidence-based information from the EyeWiki medical knowledge base.
33
+
34
+ Guidelines:
35
+ - Base your answers strictly on the provided context
36
+ - Cite sources using [Source: Title] format when referencing information
37
+ - If the context doesn't contain enough information, say so explicitly
38
+ - Use clear, precise medical terminology while remaining accessible
39
+ - Structure your responses logically with appropriate sections
40
+ - For treatment information, emphasize the importance of professional consultation
41
+ - Always maintain professional medical standards"""
42
+
43
+
44
class SourceInfo(BaseModel):
    """
    Information about a single source document backing an answer.

    Attributes:
        title: Document title
        url: Source URL
        section: Section within document ("" when unknown)
        relevance_score: Relevance score; cross-encoder scores are
            unbounded and may be negative
    """

    title: str = Field(..., description="Document title")
    url: str = Field(..., description="Source URL")
    section: str = Field(default="", description="Section within document")
    relevance_score: float = Field(..., description="Relevance score (cross-encoder, unbounded)")
59
+
60
+
61
class QueryResponse(BaseModel):
    """
    Response from the query engine.

    Attributes:
        answer: Generated answer text
        sources: List of source documents used
        confidence: Confidence score, constrained to [0, 1]
        disclaimer: Medical disclaimer text (defaults to MEDICAL_DISCLAIMER)
        query: Original query
    """

    answer: str = Field(..., description="Generated answer")
    sources: List[SourceInfo] = Field(default_factory=list, description="Source documents")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    disclaimer: str = Field(default=MEDICAL_DISCLAIMER, description="Medical disclaimer")
    query: str = Field(..., description="Original query")
78
+
79
+
80
+ class EyeWikiQueryEngine:
81
+ """
82
+ Query engine orchestrating the full RAG pipeline.
83
+
84
+ Pipeline:
85
+ 1. Query -> Retriever (hybrid search)
86
+ 2. Results -> Reranker (cross-encoder)
87
+ 3. Top results -> Context assembly
88
+ 4. Context + Query -> LLM generation
89
+ 5. Response + Sources + Disclaimer
90
+
91
+ Features:
92
+ - Two-stage retrieval (fast + precise)
93
+ - Context assembly with token limits
94
+ - Source diversity prioritization
95
+ - Medical disclaimer inclusion
96
+ - Streaming and non-streaming modes
97
+ """
98
+
99
def __init__(
    self,
    retriever: HybridRetriever,
    reranker: CrossEncoderReranker,
    llm_client: LLMClient,
    system_prompt_path: Optional[Path] = None,
    query_prompt_path: Optional[Path] = None,
    disclaimer_path: Optional[Path] = None,
    max_context_tokens: int = 4000,
    retrieval_k: int = 20,
    rerank_k: int = 5,
):
    """
    Initialize query engine.

    Each optional file path overrides a built-in default (system prompt,
    query prompt template, medical disclaimer) when the file exists;
    otherwise the module-level default is used.

    Args:
        retriever: HybridRetriever instance
        reranker: CrossEncoderReranker instance
        llm_client: LLMClient instance (OllamaClient or OpenAIClient)
        system_prompt_path: Path to custom system prompt file
        query_prompt_path: Path to custom query prompt template
        disclaimer_path: Path to custom medical disclaimer file
        max_context_tokens: Maximum (estimated) tokens for context assembly
        retrieval_k: Number of documents to retrieve initially
        rerank_k: Number of documents kept after reranking
    """
    self.retriever = retriever
    self.reranker = reranker
    self.llm_client = llm_client
    self.max_context_tokens = max_context_tokens
    self.retrieval_k = retrieval_k
    self.rerank_k = rerank_k

    self.console = Console()

    # Load system prompt (file content used verbatim, not stripped).
    if system_prompt_path and system_prompt_path.exists():
        with open(system_prompt_path, "r") as f:
            self.system_prompt = f.read()
        logger.info(f"Loaded system prompt from {system_prompt_path}")
    else:
        self.system_prompt = DEFAULT_SYSTEM_PROMPT
        logger.info("Using default system prompt")

    # Load query prompt template. When left as None, _format_prompt
    # falls back to its inline default formatting.
    if query_prompt_path and query_prompt_path.exists():
        with open(query_prompt_path, "r") as f:
            self.query_prompt_template = f.read()
        logger.info(f"Loaded query prompt from {query_prompt_path}")
    else:
        self.query_prompt_template = None
        logger.info("Using inline query prompt formatting")

    # Load medical disclaimer (stripped, unlike the system prompt).
    if disclaimer_path and disclaimer_path.exists():
        with open(disclaimer_path, "r") as f:
            self.medical_disclaimer = f.read().strip()
        logger.info(f"Loaded medical disclaimer from {disclaimer_path}")
    else:
        self.medical_disclaimer = MEDICAL_DISCLAIMER
        logger.info("Using default medical disclaimer")
160
+
161
+ def _estimate_tokens(self, text: str) -> int:
162
+ """
163
+ Estimate token count for text.
164
+
165
+ Uses simple heuristic: ~4 characters per token.
166
+
167
+ Args:
168
+ text: Input text
169
+
170
+ Returns:
171
+ Estimated token count
172
+ """
173
+ return len(text) // 4
174
+
175
def _prioritize_diverse_sources(
    self, results: List[RetrievalResult]
) -> List[RetrievalResult]:
    """
    Reorder results so each document's first chunk comes before any
    duplicate-document chunks.

    The first chunk seen for every distinct document_title keeps its
    original relative order; all remaining chunks follow, also in
    original order. No results are dropped.

    Args:
        results: Sorted list of retrieval results

    Returns:
        Reordered list prioritizing source diversity
    """
    titles_seen = set()
    primary_chunks = []
    duplicate_chunks = []

    for candidate in results:
        if candidate.document_title in titles_seen:
            duplicate_chunks.append(candidate)
        else:
            titles_seen.add(candidate.document_title)
            primary_chunks.append(candidate)

    return primary_chunks + duplicate_chunks
206
+
207
def _assemble_context(self, results: List[RetrievalResult]) -> str:
    """
    Assemble an LLM context string from retrieval results.

    Features:
    - Formats each chunk with a numbered [Source N: ...] citation header
    - Limits total size to max_context_tokens (heuristic estimate)
    - Prioritizes diverse sources before duplicate-document chunks
    - Includes source citations

    Note: assembly stops at the FIRST chunk that would exceed the token
    budget; later (possibly smaller) chunks are not considered.

    Args:
        results: List of retrieval results

    Returns:
        Formatted context string ("" when results is empty)
    """
    if not results:
        return ""

    # Prioritize diversity: one chunk per document first.
    diverse_results = self._prioritize_diverse_sources(results)

    context_parts = []
    total_tokens = 0

    for i, result in enumerate(diverse_results, 1):
        # Format context chunk with a citation header the LLM can echo.
        chunk_text = f"[Source {i}: {result.document_title}"
        if result.section:
            chunk_text += f" - {result.section}"
        chunk_text += f"]\n{result.content}\n"

        # Check token limit (~4 chars/token estimate, not a tokenizer).
        chunk_tokens = self._estimate_tokens(chunk_text)

        if total_tokens + chunk_tokens > self.max_context_tokens:
            logger.info(
                f"Reached context token limit ({self.max_context_tokens}), "
                f"using {i-1} of {len(diverse_results)} chunks"
            )
            break

        context_parts.append(chunk_text)
        total_tokens += chunk_tokens

    context = "\n".join(context_parts)

    logger.info(
        f"Assembled context: {len(context_parts)} chunks, "
        f"~{total_tokens} tokens"
    )

    return context
260
+
261
def _extract_sources(self, results: List[RetrievalResult]) -> List[SourceInfo]:
    """
    Build one SourceInfo per distinct document title.

    The first (highest-ranked) chunk of each document supplies the URL,
    section, and relevance score; later chunks of the same document are
    ignored. Output order follows first appearance.

    Args:
        results: List of retrieval results

    Returns:
        List of SourceInfo objects, deduplicated by title
    """
    by_title = {}

    for hit in results:
        if hit.document_title in by_title:
            continue
        by_title[hit.document_title] = SourceInfo(
            title=hit.document_title,
            url=hit.source_url,
            section=hit.section,
            relevance_score=hit.score,
        )

    return list(by_title.values())
287
+
288
def _calculate_confidence(self, results: List[RetrievalResult]) -> float:
    """
    Derive a confidence score from retrieval scores.

    Averages the scores of up to rerank_k top results, then clamps the
    average into [0, 1] (cross-encoder scores can fall outside that
    range).

    Args:
        results: List of retrieval results

    Returns:
        Confidence score in [0, 1]; 0.0 for an empty result list
    """
    top_scores = [hit.score for hit in results[: self.rerank_k]]

    if not top_scores:
        return 0.0

    average = sum(top_scores) / len(top_scores)

    # Clamp into the [0, 1] range expected by QueryResponse.
    return min(1.0, max(0.0, average))
315
+
316
+ def _format_prompt(self, query: str, context: str) -> str:
317
+ """
318
+ Format the prompt for LLM.
319
+
320
+ Uses query_prompt_template if loaded, otherwise uses default format.
321
+
322
+ Args:
323
+ query: User query
324
+ context: Assembled context
325
+
326
+ Returns:
327
+ Formatted prompt
328
+ """
329
+ if self.query_prompt_template:
330
+ # Use template with placeholders
331
+ prompt = self.query_prompt_template.format(
332
+ context=context,
333
+ question=query
334
+ )
335
+ else:
336
+ # Default inline formatting
337
+ prompt = f"""Context from EyeWiki medical knowledge base:
338
+
339
+ {context}
340
+
341
+ ---
342
+
343
+ Question: {query}
344
+
345
+ Please provide a comprehensive answer based on the context above. Structure your response clearly and cite sources where appropriate."""
346
+
347
+ return prompt
348
+
349
def query(
    self,
    question: str,
    include_sources: bool = True,
    filters: Optional[dict] = None,
) -> QueryResponse:
    """
    Run the full RAG pipeline for a single question.

    Pipeline:
    1. Retrieve documents (retrieval_k candidates)
    2. Rerank with cross-encoder (keep rerank_k)
    3. Assemble context with token limits
    4. Generate answer with LLM
    5. Return response with sources and disclaimer

    LLM failures are swallowed and replaced with an apologetic answer;
    this method does not raise on generation errors.

    Args:
        question: User question
        include_sources: Include source information in response
        filters: Optional metadata filters for retrieval

    Returns:
        QueryResponse object (confidence 0.0 and a canned answer when
        retrieval finds nothing)
    """
    logger.info(f"Processing query: '{question}'")

    # Step 1: Retrieve documents
    logger.info(f"Retrieving top {self.retrieval_k} candidates...")
    retrieval_results = self.retriever.retrieve(
        query=question,
        top_k=self.retrieval_k,
        filters=filters,
    )

    # Short-circuit: nothing retrieved -> canned answer, no LLM call.
    if not retrieval_results:
        logger.warning("No results found for query")
        return QueryResponse(
            answer="I couldn't find relevant information to answer this question in the EyeWiki knowledge base.",
            sources=[],
            confidence=0.0,
            query=question,
        )

    # Step 2: Rerank for precision
    logger.info(f"Reranking to top {self.rerank_k}...")
    reranked_results = self.reranker.rerank(
        query=question,
        documents=retrieval_results,
        top_k=self.rerank_k,
    )

    # Step 3: Assemble context
    context = self._assemble_context(reranked_results)

    # Step 4: Generate answer
    logger.info("Generating answer with LLM...")
    prompt = self._format_prompt(question, context)

    try:
        answer = self.llm_client.generate(
            prompt=prompt,
            system_prompt=self.system_prompt,
            temperature=0.1,  # Low temperature for factual responses
        )
    except Exception as e:
        # Degrade to a fixed message rather than propagating LLM errors.
        logger.error(f"Error generating answer: {e}")
        answer = (
            "I encountered an error while generating the answer. "
            "Please try again or rephrase your question."
        )

    # Step 5: Extract sources
    sources = self._extract_sources(reranked_results) if include_sources else []

    # Step 6: Calculate confidence
    confidence = self._calculate_confidence(reranked_results)

    # Create response
    response = QueryResponse(
        answer=answer,
        sources=sources,
        confidence=confidence,
        query=question,
    )

    logger.info(
        f"Query complete: {len(sources)} sources, "
        f"confidence: {confidence:.2f}"
    )

    return response
440
+
441
def stream_query(
    self,
    question: str,
    filters: Optional[dict] = None,
) -> Generator[str, None, None]:
    """
    Query with streaming response.

    Runs the same retrieve -> rerank -> context -> prompt pipeline as
    query(), but yields answer text incrementally. Unlike query(), no
    sources, confidence score, or disclaimer are produced.

    Args:
        question: User question
        filters: Optional metadata filters

    Yields:
        Answer chunks as they are generated
    """
    logger.info(f"Processing streaming query: '{question}'")

    # Retrieval and reranking (same as query())
    retrieval_results = self.retriever.retrieve(
        query=question,
        top_k=self.retrieval_k,
        filters=filters,
    )

    if not retrieval_results:
        yield "I couldn't find relevant information to answer this question."
        return

    reranked_results = self.reranker.rerank(
        query=question,
        documents=retrieval_results,
        top_k=self.rerank_k,
    )

    # Assemble context
    context = self._assemble_context(reranked_results)

    # Generate prompt
    prompt = self._format_prompt(question, context)

    # Stream generation; errors surface as a final bracketed message
    # instead of raising out of the generator.
    try:
        for chunk in self.llm_client.stream_generate(
            prompt=prompt,
            system_prompt=self.system_prompt,
            temperature=0.1,
        ):
            yield chunk

    except Exception as e:
        logger.error(f"Error in streaming generation: {e}")
        yield "\n\n[Error: Failed to generate response]"
496
def batch_query(
    self,
    questions: List[str],
    include_sources: bool = True,
) -> List[QueryResponse]:
    """
    Answer several questions sequentially via query().

    Args:
        questions: List of questions
        include_sources: Include sources in each response

    Returns:
        One QueryResponse per question, in input order
    """
    return [
        self.query(question, include_sources=include_sources)
        for question in questions
    ]
518
+
519
def get_pipeline_info(self) -> dict:
    """
    Report the pipeline's current configuration.

    Returns:
        Dictionary with retrieval/rerank settings, retriever weights,
        reranker model info, and the LLM model identifier.
    """
    info = {
        "retrieval_k": self.retrieval_k,
        "rerank_k": self.rerank_k,
        "max_context_tokens": self.max_context_tokens,
    }
    info["retriever_config"] = {
        "dense_weight": self.retriever.dense_weight,
        "sparse_weight": self.retriever.sparse_weight,
        "term_expansion": self.retriever.enable_term_expansion,
    }
    info["reranker_info"] = self.reranker.get_model_info()
    info["llm_model"] = self.llm_client.llm_model
    return info
src/rag/reranker.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cross-encoder reranker for improved retrieval relevance."""
2
+
3
+ import logging
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ from sentence_transformers import CrossEncoder
8
+ from rich.console import Console
9
+
10
+ from src.rag.retriever import RetrievalResult
11
+
12
+
13
+ # Configure logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class CrossEncoderReranker:
19
+ """
20
+ Reranker using cross-encoder models for improved relevance.
21
+
22
+ Features:
23
+ - Uses sentence-transformers cross-encoder
24
+ - Automatic GPU/CPU detection
25
+ - Model caching for efficiency
26
+ - Preserves original retrieval scores
27
+ - Batch processing for speed
28
+ """
29
+
30
+ # Model cache to avoid reloading
31
+ _model_cache = {}
32
+
33
+ # Available models
34
+ AVAILABLE_MODELS = {
35
+ "ms-marco-mini": "cross-encoder/ms-marco-MiniLM-L-6-v2",
36
+ "ms-marco-base": "cross-encoder/ms-marco-MiniLM-L-12-v2",
37
+ "medicalai": "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb", # Medical domain
38
+ }
39
+
40
def __init__(
    self,
    model_name: str = "ms-marco-mini",
    device: Optional[str] = None,
    max_length: int = 512,
):
    """
    Initialize cross-encoder reranker.

    Args:
        model_name: Key from AVAILABLE_MODELS, or any full
            sentence-transformers model path (in which case the
            reported model_name becomes "custom")
        device: Device to use ('cuda', 'cpu', or None for auto-detect)
        max_length: Maximum sequence length for the cross-encoder
    """
    # Resolve short alias to full model path; unknown names are treated
    # as full model paths/identifiers.
    if model_name in self.AVAILABLE_MODELS:
        self.model_path = self.AVAILABLE_MODELS[model_name]
        self.model_name = model_name
    else:
        self.model_path = model_name
        self.model_name = "custom"

    self.max_length = max_length
    self.console = Console()

    # Detect device: prefer CUDA when available and none was requested.
    if device is None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        self.device = device

    # Load model (may be served from the class-level cache).
    self._load_model()
73
+
74
def _load_model(self):
    """
    Load the cross-encoder model, reusing a cached instance when possible.

    The cache key combines model path and device, so the same weights
    loaded on different devices are cached separately. The cache dict
    itself is class-level and shared across instances.

    Raises:
        Exception: Re-raises whatever CrossEncoder() raised on failure.
    """
    cache_key = f"{self.model_path}_{self.device}"

    # Check class-level cache first to avoid reloading weights.
    if cache_key in self._model_cache:
        self.model = self._model_cache[cache_key]
        logger.info(f"Loaded reranker model from cache: {self.model_name}")
        return

    # Load model
    try:
        self.console.print(f"[cyan]Loading reranker model: {self.model_name}...[/cyan]")

        self.model = CrossEncoder(
            self.model_path,
            max_length=self.max_length,
            device=self.device,
        )

        # Cache model for subsequent instances.
        self._model_cache[cache_key] = self.model

        device_info = f"GPU ({torch.cuda.get_device_name(0)})" if self.device == "cuda" else "CPU"
        # NOTE(review): the "[green][/green]" markup below renders nothing —
        # a status glyph (likely a checkmark) appears lost to encoding; confirm.
        self.console.print(
            f"[green][/green] Loaded reranker model: {self.model_name} on {device_info}"
        )
        logger.info(
            f"Loaded cross-encoder model: {self.model_path} on {self.device}"
        )

    except Exception as e:
        logger.error(f"Failed to load reranker model: {e}")
        self.console.print(f"[red][/red] Failed to load reranker model: {e}")
        raise
109
+
110
def score_pairs(self, query: str, documents: List[str]) -> List[float]:
    """
    Score each (query, document) pair with the cross-encoder.

    Args:
        query: Search query
        documents: List of document texts

    Returns:
        One relevance score per document (higher is better); a list of
        zeros if the model call fails.
    """
    if not documents:
        return []

    # Build the (query, doc) pairs the cross-encoder expects.
    query_doc_pairs = [[query, text] for text in documents]

    try:
        raw_scores = self.model.predict(query_doc_pairs, convert_to_numpy=True)
        score_list = raw_scores.tolist()
        logger.debug(f"Scored {len(documents)} documents")
        return score_list
    except Exception as e:
        # Degrade gracefully: zero scores keep the pipeline running.
        logger.error(f"Error scoring pairs: {e}")
        return [0.0] * len(documents)
142
+
143
def rerank(
    self,
    query: str,
    documents: List[RetrievalResult],
    top_k: Optional[int] = None,
) -> List[RetrievalResult]:
    """
    Re-score documents with the cross-encoder and sort by that score.

    Each returned result's primary `score` is the cross-encoder score;
    the original retrieval score is preserved in its metadata under
    "original_retrieval_score" (the reranker score is also mirrored
    under "reranker_score").

    Args:
        query: Search query
        documents: List of RetrievalResult objects from retriever
        top_k: Number of top results to return (None for all)

    Returns:
        List of RetrievalResult objects sorted by reranker score
    """
    if not documents:
        logger.warning("No documents to rerank")
        return []

    # Score all documents in one batch.
    logger.info(f"Reranking {len(documents)} documents for query: '{query[:50]}...'")
    rerank_scores = self.score_pairs(query, [doc.content for doc in documents])

    # Rebuild each result with the reranker score as its primary score.
    rescored = []
    for original, new_score in zip(documents, rerank_scores):
        meta = dict(original.metadata)
        meta["original_retrieval_score"] = original.score
        meta["reranker_score"] = float(new_score)

        rescored.append(
            RetrievalResult(
                content=original.content,
                metadata=meta,
                score=float(new_score),
                source_url=original.source_url,
                section=original.section,
                chunk_id=original.chunk_id,
                document_title=original.document_title,
            )
        )

    # Sort by reranker score (descending).
    rescored.sort(key=lambda item: item.score, reverse=True)

    # Log score change of the new top result.
    if rescored:
        logger.info(
            f"Reranking complete. Top result score: {rescored[0].score:.4f} "
            f"(original: {rescored[0].metadata.get('original_retrieval_score', 0):.4f})"
        )

    return rescored if top_k is None else rescored[:top_k]
207
+
208
def rerank_with_comparison(
    self,
    query: str,
    documents: List[RetrievalResult],
    top_k: Optional[int] = None,
) -> List[Tuple[RetrievalResult, dict]]:
    """
    Rerank with a detailed before/after comparison of scores and ranks.

    Args:
        query: Search query
        documents: List of RetrievalResult objects
        top_k: Number of top results to return

    Returns:
        List of (RetrievalResult, comparison_dict) tuples where
        comparison_dict contains:
        - original_score: Original retrieval score
        - reranker_score: Cross-encoder score
        - score_change: Difference (reranker - original)
        - original_rank: Position before reranking
        - new_rank: Position after reranking
        - rank_change: Change in ranking position (positive = moved up)
    """
    if not documents:
        return []

    # Remember each chunk's pre-rerank position, keyed by chunk_id.
    # NOTE(review): assumes chunk_id is unique within `documents`;
    # duplicates would silently keep only the last position — confirm.
    original_rankings = {doc.chunk_id: idx for idx, doc in enumerate(documents)}

    # Rerank the full list; top_k slicing happens at the end so rank
    # changes are computed over every document.
    reranked_docs = self.rerank(query, documents, top_k=None)

    # Create comparison results
    results_with_comparison = []

    for new_rank, doc in enumerate(reranked_docs):
        original_rank = original_rankings[doc.chunk_id]
        original_score = doc.metadata.get("original_retrieval_score", 0.0)
        reranker_score = doc.score

        comparison = {
            "original_score": original_score,
            "reranker_score": reranker_score,
            "score_change": reranker_score - original_score,
            "original_rank": original_rank,
            "new_rank": new_rank,
            "rank_change": original_rank - new_rank,  # Positive = moved up
        }

        results_with_comparison.append((doc, comparison))

    # Return top_k if specified
    if top_k is not None:
        return results_with_comparison[:top_k]

    return results_with_comparison
263
+
264
def get_model_info(self) -> dict:
    """
    Describe the loaded reranker model and its execution environment.

    Returns:
        Dictionary with model name/path, configured device, max sequence
        length, and GPU availability/name (gpu_name is None without CUDA).
    """
    cuda_available = torch.cuda.is_available()
    return {
        "model_name": self.model_name,
        "model_path": self.model_path,
        "device": self.device,
        "max_length": self.max_length,
        "gpu_available": cuda_available,
        "gpu_name": torch.cuda.get_device_name(0) if cuda_available else None,
    }
279
+
280
def clear_cache(self):
    """Clear the model cache.

    Note: _model_cache is a class-level dict, so this drops cached
    models for every CrossEncoderReranker instance in the process.
    """
    self._model_cache.clear()
    logger.info("Cleared model cache")
284
+
285
@classmethod
def get_available_models(cls) -> dict:
    """
    List the built-in model aliases.

    Returns:
        A copy of the alias -> model-path mapping (safe to mutate)
    """
    return dict(cls.AVAILABLE_MODELS)
src/rag/retriever.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hybrid retriever combining dense and sparse search for optimal retrieval."""
2
+
3
+ import logging
4
+ import re
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ from pydantic import BaseModel, Field
8
+ from rich.console import Console
9
+
10
+ from src.vectorstore.qdrant_store import QdrantStoreManager, SearchResult
11
+ from src.llm.sentence_transformer_client import SentenceTransformerClient
12
+
13
+
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class RetrievalResult(BaseModel):
    """
    Pydantic model for retrieval results.

    Attributes:
        content: Retrieved text content
        metadata: Document metadata (disease, ICD codes, etc.)
        score: Relevance score; may be negative once replaced by a
            cross-encoder reranker score
        source_url: EyeWiki source URL
        section: Parent section header
        chunk_id: Unique chunk identifier
        document_title: Article title
    """

    content: str = Field(..., description="Retrieved text content")
    metadata: Dict = Field(default_factory=dict, description="Document metadata")
    score: float = Field(..., description="Relevance score (can be negative for cross-encoder)")
    source_url: str = Field(default="", description="EyeWiki source URL")
    section: str = Field(default="", description="Parent section header")
    chunk_id: str = Field(default="", description="Unique chunk identifier")
    document_title: str = Field(default="", description="Article title")

    @classmethod
    def from_search_result(cls, result: SearchResult) -> "RetrievalResult":
        """
        Convert a Qdrant SearchResult to a RetrievalResult.

        Note: SearchResult's `parent_section` maps onto this model's
        `section` field; all other fields copy across by name.

        Args:
            result: SearchResult from Qdrant

        Returns:
            RetrievalResult instance
        """
        return cls(
            content=result.content,
            metadata=result.metadata,
            score=result.score,
            source_url=result.source_url,
            section=result.parent_section,
            chunk_id=result.chunk_id,
            document_title=result.document_title,
        )
61
+
62
+
63
+ class HybridRetriever:
64
+ """
65
+ Hybrid retriever combining dense (semantic) and sparse (BM25) search.
66
+
67
+ Features:
68
+ - Dense vector search via embeddings (default weight: 0.7)
69
+ - Sparse BM25 keyword search (default weight: 0.3)
70
+ - Configurable fusion weights
71
+ - Query preprocessing
72
+ - Medical term expansion
73
+ - Metadata filtering
74
+ """
75
+
76
+ # Medical term synonyms and abbreviations for query expansion
77
+ MEDICAL_TERM_EXPANSIONS = {
78
+ # Common abbreviations
79
+ "iop": ["intraocular pressure", "iop"],
80
+ "amd": ["age-related macular degeneration", "amd"],
81
+ "armd": ["age-related macular degeneration", "armd"],
82
+ "dme": ["diabetic macular edema", "dme"],
83
+ "dr": ["diabetic retinopathy", "dr"],
84
+ "poag": ["primary open-angle glaucoma", "poag"],
85
+ "pacg": ["primary angle-closure glaucoma", "pacg"],
86
+ "rvo": ["retinal vein occlusion", "rvo"],
87
+ "rao": ["retinal artery occlusion", "rao"],
88
+ "crvo": ["central retinal vein occlusion", "crvo"],
89
+ "brvo": ["branch retinal vein occlusion", "brvo"],
90
+ "crao": ["central retinal artery occlusion", "crao"],
91
+ "vegf": ["vascular endothelial growth factor", "vegf"],
92
+ "oct": ["optical coherence tomography", "oct"],
93
+ "fa": ["fluorescein angiography", "fa"],
94
+ "icg": ["indocyanine green angiography", "icg"],
95
+ "erg": ["electroretinography", "erg"],
96
+ "vf": ["visual field", "vf"],
97
+ "va": ["visual acuity", "va"],
98
+
99
+ # Common synonyms
100
+ "retina": ["retina", "retinal"],
101
+ "cornea": ["cornea", "corneal"],
102
+ "glaucoma": ["glaucoma", "glaucomatous"],
103
+ "cataract": ["cataract", "lens opacity"],
104
+ "macula": ["macula", "macular"],
105
+ "optic nerve": ["optic nerve", "optic disc", "optic cup"],
106
+ }
107
+
108
def __init__(
    self,
    qdrant_manager: QdrantStoreManager,
    embedding_client: SentenceTransformerClient,
    dense_weight: float = 0.7,
    sparse_weight: float = 0.3,
    enable_term_expansion: bool = True,
):
    """
    Initialize hybrid retriever.

    Args:
        qdrant_manager: QdrantStoreManager for vector search
        embedding_client: SentenceTransformerClient for query embeddings
        dense_weight: Weight for dense (semantic) search (0-1)
        sparse_weight: Weight for sparse (BM25) search (0-1)
        enable_term_expansion: Enable medical term expansion

    Raises:
        ValueError: If dense_weight + sparse_weight is zero or negative,
            which makes weight normalization meaningless.
    """
    self.qdrant_manager = qdrant_manager
    self.embedding_client = embedding_client
    self.dense_weight = dense_weight
    self.sparse_weight = sparse_weight
    self.enable_term_expansion = enable_term_expansion

    self.console = Console()

    # Validate weights and normalize them to sum to 1.0.
    total_weight = dense_weight + sparse_weight
    if total_weight <= 0:
        # Fix: previously this divided by total_weight unconditionally,
        # raising ZeroDivisionError for (0, 0) and silently flipping
        # signs for negative totals.
        raise ValueError(
            f"dense_weight + sparse_weight must be positive, got {total_weight}"
        )
    if not (0.99 <= total_weight <= 1.01):  # Allow small floating point error
        logger.warning(
            f"Weights sum to {total_weight:.2f}, not 1.0. "
            "Normalizing weights."
        )
        self.dense_weight = dense_weight / total_weight
        self.sparse_weight = sparse_weight / total_weight

    logger.info(
        f"Initialized HybridRetriever (dense: {self.dense_weight:.2f}, "
        f"sparse: {self.sparse_weight:.2f})"
    )
148
+
149
+ def _preprocess_query(self, query: str) -> str:
150
+ """
151
+ Preprocess query text.
152
+
153
+ - Convert to lowercase
154
+ - Remove excessive whitespace
155
+ - Normalize punctuation
156
+
157
+ Args:
158
+ query: Raw query string
159
+
160
+ Returns:
161
+ Preprocessed query
162
+ """
163
+ # Convert to lowercase
164
+ query = query.lower()
165
+
166
+ # Remove excessive whitespace
167
+ query = re.sub(r'\s+', ' ', query)
168
+
169
+ # Strip leading/trailing whitespace
170
+ query = query.strip()
171
+
172
+ return query
173
+
174
+ def _expand_medical_terms(self, query: str) -> str:
175
+ """
176
+ Expand medical abbreviations and add synonyms.
177
+
178
+ Args:
179
+ query: Preprocessed query
180
+
181
+ Returns:
182
+ Expanded query with synonyms
183
+ """
184
+ if not self.enable_term_expansion:
185
+ return query
186
+
187
+ expanded_terms = []
188
+ words = query.split()
189
+
190
+ for word in words:
191
+ # Check if word matches any abbreviation or term
192
+ if word in self.MEDICAL_TERM_EXPANSIONS:
193
+ # Add all expansions
194
+ expansions = self.MEDICAL_TERM_EXPANSIONS[word]
195
+ expanded_terms.extend(expansions)
196
+ else:
197
+ # Keep original word
198
+ expanded_terms.append(word)
199
+
200
+ # Join and deduplicate
201
+ expanded_query = " ".join(expanded_terms)
202
+
203
+ logger.debug(f"Query expansion: '{query}' � '{expanded_query}'")
204
+
205
+ return expanded_query
206
+
207
+ def _generate_query_embedding(self, query: str) -> List[float]:
208
+ """
209
+ Generate embedding for query.
210
+
211
+ Args:
212
+ query: Query text
213
+
214
+ Returns:
215
+ Query embedding vector
216
+ """
217
+ try:
218
+ embedding = self.embedding_client.embed_text(query)
219
+ return embedding
220
+ except Exception as e:
221
+ logger.error(f"Failed to generate query embedding: {e}")
222
+ raise
223
+
224
    def _merge_results(
        self,
        dense_results: List[SearchResult],
        sparse_results: Optional[List[SearchResult]] = None,
    ) -> List[Tuple[RetrievalResult, float]]:
        """
        Merge dense and sparse results using weighted fusion.

        Combines the two result lists with a weighted linear sum of their
        raw scores (score * weight), keyed by chunk_id.  Note this is NOT
        Reciprocal Rank Fusion despite earlier notes: raw scores are
        combined directly, so the two score scales are assumed comparable.

        Args:
            dense_results: Results from dense search
            sparse_results: Results from sparse search (if available)

        Returns:
            List of (RetrievalResult, combined_score) tuples, sorted by
            combined score in descending order
        """
        # If no sparse results, just weight the dense scores and keep order.
        if not sparse_results:
            results = []
            for result in dense_results:
                retrieval_result = RetrievalResult.from_search_result(result)
                # Apply dense weight to score
                weighted_score = result.score * self.dense_weight
                results.append((retrieval_result, weighted_score))
            return results

        # Create score dictionaries keyed by chunk_id
        dense_scores = {r.chunk_id: r.score for r in dense_results}
        sparse_scores = {r.chunk_id: r.score for r in sparse_results}

        # Get all unique chunk_ids
        all_chunk_ids = set(dense_scores.keys()) | set(sparse_scores.keys())

        # Create lookup for full result objects; dense results win on ties.
        result_lookup = {}
        for result in dense_results:
            result_lookup[result.chunk_id] = result
        for result in sparse_results:
            if result.chunk_id not in result_lookup:
                result_lookup[result.chunk_id] = result

        # Calculate weighted combined scores (a missing side contributes 0.0).
        combined_results = []
        for chunk_id in all_chunk_ids:
            dense_score = dense_scores.get(chunk_id, 0.0)
            sparse_score = sparse_scores.get(chunk_id, 0.0)

            # Weighted combination
            combined_score = (
                dense_score * self.dense_weight + sparse_score * self.sparse_weight
            )

            result = result_lookup[chunk_id]
            retrieval_result = RetrievalResult.from_search_result(result)
            combined_results.append((retrieval_result, combined_score))

        # Sort by combined score (descending)
        combined_results.sort(key=lambda x: x[1], reverse=True)

        return combined_results
285
+
286
    def retrieve_with_scores(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[Dict] = None,
    ) -> List[Tuple[RetrievalResult, float]]:
        """
        Retrieve documents with scores.

        Pipeline: preprocess -> medical-term expansion -> embed -> dense
        vector search -> fusion (currently dense-only) -> truncate to top_k.

        Args:
            query: Search query
            top_k: Number of results to return
            filters: Optional metadata filters

        Returns:
            List of (RetrievalResult, score) tuples, best first
        """
        # Preprocess query
        processed_query = self._preprocess_query(query)

        # Expand medical terms
        expanded_query = self._expand_medical_terms(processed_query)

        logger.info(f"Retrieving for query: '{query}'")
        logger.debug(f"Processed query: '{expanded_query}'")

        # Generate query embedding
        query_embedding = self._generate_query_embedding(expanded_query)

        # Perform dense search; over-fetch so fusion has candidates to rerank
        dense_results = self.qdrant_manager.search(
            query_embedding=query_embedding,
            top_k=top_k * 2,  # Get more for fusion
            filters=filters,
        )

        logger.info(f"Dense search returned {len(dense_results)} results")

        # Note: For true hybrid search with sparse vectors, you would also:
        # 1. Generate sparse vector for query (BM25)
        # 2. Perform sparse search via qdrant_manager.hybrid_search()
        # 3. Merge results using RRF
        #
        # For now, we'll use dense search only
        # In production, implement proper BM25 sparse vector generation

        sparse_results = None  # Placeholder for sparse search

        # Merge results (with sparse_results=None this just scales every
        # dense score by dense_weight — ranking is unchanged)
        combined_results = self._merge_results(dense_results, sparse_results)

        # Return top_k
        return combined_results[:top_k]
339
+
340
+ def retrieve(
341
+ self,
342
+ query: str,
343
+ top_k: int = 10,
344
+ filters: Optional[Dict] = None,
345
+ ) -> List[RetrievalResult]:
346
+ """
347
+ Retrieve documents (without scores).
348
+
349
+ Args:
350
+ query: Search query
351
+ top_k: Number of results to return
352
+ filters: Optional metadata filters
353
+
354
+ Returns:
355
+ List of RetrievalResult objects
356
+ """
357
+ results_with_scores = self.retrieve_with_scores(query, top_k, filters)
358
+
359
+ # Extract just the results, drop scores
360
+ results = [result for result, score in results_with_scores]
361
+
362
+ return results
363
+
364
+ def retrieve_by_disease(
365
+ self,
366
+ query: str,
367
+ disease_name: str,
368
+ top_k: int = 10,
369
+ ) -> List[RetrievalResult]:
370
+ """
371
+ Retrieve documents filtered by disease name.
372
+
373
+ Args:
374
+ query: Search query
375
+ disease_name: Disease name to filter by
376
+ top_k: Number of results to return
377
+
378
+ Returns:
379
+ List of RetrievalResult objects
380
+ """
381
+ filters = {"disease_name": disease_name}
382
+ return self.retrieve(query, top_k, filters)
383
+
384
+ def retrieve_by_icd_code(
385
+ self,
386
+ query: str,
387
+ icd_codes: List[str],
388
+ top_k: int = 10,
389
+ ) -> List[RetrievalResult]:
390
+ """
391
+ Retrieve documents filtered by ICD codes.
392
+
393
+ Args:
394
+ query: Search query
395
+ icd_codes: List of ICD codes to filter by
396
+ top_k: Number of results to return
397
+
398
+ Returns:
399
+ List of RetrievalResult objects
400
+ """
401
+ filters = {"icd_codes": icd_codes}
402
+ return self.retrieve(query, top_k, filters)
403
+
404
+ def retrieve_by_anatomy(
405
+ self,
406
+ query: str,
407
+ anatomical_structures: List[str],
408
+ top_k: int = 10,
409
+ ) -> List[RetrievalResult]:
410
+ """
411
+ Retrieve documents filtered by anatomical structures.
412
+
413
+ Args:
414
+ query: Search query
415
+ anatomical_structures: List of anatomical terms
416
+ top_k: Number of results to return
417
+
418
+ Returns:
419
+ List of RetrievalResult objects
420
+ """
421
+ filters = {"anatomical_structures": anatomical_structures}
422
+ return self.retrieve(query, top_k, filters)
423
+
424
+ def get_similar_sections(
425
+ self,
426
+ section_content: str,
427
+ top_k: int = 5,
428
+ filters: Optional[Dict] = None,
429
+ ) -> List[RetrievalResult]:
430
+ """
431
+ Find similar sections based on content.
432
+
433
+ Useful for "related sections" or "see also" features.
434
+
435
+ Args:
436
+ section_content: Content to find similar sections for
437
+ top_k: Number of results to return
438
+ filters: Optional metadata filters
439
+
440
+ Returns:
441
+ List of RetrievalResult objects
442
+ """
443
+ # Use the section content itself as the query
444
+ return self.retrieve(section_content, top_k, filters)
445
+
446
+ def multi_query_retrieve(
447
+ self,
448
+ queries: List[str],
449
+ top_k: int = 10,
450
+ filters: Optional[Dict] = None,
451
+ deduplicate: bool = True,
452
+ ) -> List[RetrievalResult]:
453
+ """
454
+ Retrieve using multiple queries and combine results.
455
+
456
+ Useful for query decomposition or multi-faceted questions.
457
+
458
+ Args:
459
+ queries: List of query strings
460
+ top_k: Total number of results to return
461
+ filters: Optional metadata filters
462
+ deduplicate: Remove duplicate results
463
+
464
+ Returns:
465
+ List of RetrievalResult objects
466
+ """
467
+ all_results = []
468
+ seen_chunk_ids = set()
469
+
470
+ # Retrieve for each query
471
+ for query in queries:
472
+ results = self.retrieve(query, top_k=top_k, filters=filters)
473
+
474
+ for result in results:
475
+ if deduplicate:
476
+ if result.chunk_id not in seen_chunk_ids:
477
+ all_results.append(result)
478
+ seen_chunk_ids.add(result.chunk_id)
479
+ else:
480
+ all_results.append(result)
481
+
482
+ # Return top_k overall
483
+ return all_results[:top_k]
src/scraper/__init__.py ADDED
File without changes
src/scraper/eyewiki_crawler.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """EyeWiki crawler for medical article scraping using crawl4ai."""
2
+
3
import asyncio
import json
import re
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Optional, Set
from urllib.parse import urljoin, urlparse, parse_qs
from urllib.robotparser import RobotFileParser
12
+
13
+ import aiohttp
14
+ from bs4 import BeautifulSoup
15
+ from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
16
+ from rich.console import Console
17
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
18
+
19
+ from config.settings import settings
20
+
21
+
22
+ class EyeWikiCrawler:
23
+ """
24
+ Asynchronous crawler for EyeWiki medical articles.
25
+
26
+ Features:
27
+ - Asynchronous crawling with crawl4ai
28
+ - Respects robots.txt
29
+ - Polite crawling with configurable delays
30
+ - Markdown content extraction
31
+ - Checkpointing for resume capability
32
+ - Progress tracking with rich console
33
+ """
34
+
35
    def __init__(
        self,
        base_url: str = "https://eyewiki.org",
        output_dir: Optional[Path] = None,
        checkpoint_file: Optional[Path] = None,
        delay: float = 1.5,
        timeout: int = 30,
    ):
        """
        Initialize the EyeWiki crawler.

        Args:
            base_url: Base URL for EyeWiki
            output_dir: Directory to save scraped articles
            checkpoint_file: Path to checkpoint file
            delay: Delay between requests in seconds
            timeout: Request timeout in seconds
        """
        self.base_url = base_url
        # Host part of base_url, used to restrict crawling to this site
        self.domain = urlparse(base_url).netloc
        self.output_dir = output_dir or Path(settings.data_raw_path)
        self.checkpoint_file = checkpoint_file or (self.output_dir / "crawler_checkpoint.json")
        self.delay = delay
        self.timeout = timeout

        # Ensure output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Crawl state (to_crawl holds (url, depth) tuples)
        self.visited_urls: Set[str] = set()
        self.to_crawl: deque = deque()
        self.failed_urls: Dict[str, str] = {}
        self.articles_saved: int = 0

        # Rich console for logging
        self.console = Console()

        # Robot parser; robots.txt is actually fetched later via read() in crawl()
        self.robot_parser = RobotFileParser()
        self.robot_parser.set_url(urljoin(base_url, "/robots.txt"))

        # URL patterns that are never treated as articles (wiki plumbing)
        self.skip_patterns = [
            r"/index\.php\?title=.*&action=",  # Edit, history, etc.
            r"/index\.php\?title=.*&diff=",  # Page diffs
            r"/index\.php\?title=.*&oldid=",  # Page history/revisions
            r"/index\.php\?title=.*&direction=",  # Page navigation
            r"/index\.php\?title=Special:",  # Special pages (login, create account, etc.)
            r"/Special:",  # Special pages
            r"/User:",  # User pages
            r"/User_talk:",  # User talk pages
            r"/Talk:",  # Talk pages
            r"/File:",  # File pages
            r"/Template:",  # Template pages
            r"/Help:",  # Help pages
            r"/MediaWiki:",  # MediaWiki pages
            r"#",  # Anchor links
        ]
93
+
94
+ def _is_valid_article_url(self, url: str) -> bool:
95
+ """
96
+ Check if URL is a valid medical article.
97
+
98
+ Args:
99
+ url: URL to check
100
+
101
+ Returns:
102
+ True if valid article URL
103
+ """
104
+ # Must be from eyewiki.org domain
105
+ if self.domain not in url:
106
+ return False
107
+
108
+ # Skip patterns (these take precedence)
109
+ for pattern in self.skip_patterns:
110
+ if re.search(pattern, url):
111
+ return False
112
+
113
+ # Parse URL to check path
114
+ parsed = urlparse(url)
115
+ path = parsed.path.strip("/")
116
+
117
+ # Must be article-like URL
118
+ # EyeWiki articles can be:
119
+ # 1. Direct: /Article_Name (e.g., /Cataract)
120
+ # 2. Wiki-style: /wiki/Article_Name
121
+ # 3. Query-based: /w/index.php?title=Article_Name
122
+
123
+ # For query-based URLs, check if title parameter exists and is not a special page
124
+ if parsed.query and "title=" in parsed.query:
125
+ return True
126
+
127
+ # For direct URLs, check if path is non-empty and looks like an article
128
+ # (starts with capital letter, no file extension)
129
+ if path and not path.startswith("w/") and not "." in path:
130
+ # Path should look like an article name (capitalized, underscores/spaces)
131
+ if path[0].isupper() or path.startswith("wiki/"):
132
+ return True
133
+
134
+ return False
135
+
136
+ def _normalize_url(self, url: str) -> str:
137
+ """
138
+ Normalize URL for consistent comparison.
139
+
140
+ Args:
141
+ url: URL to normalize
142
+
143
+ Returns:
144
+ Normalized URL
145
+ """
146
+ # Remove fragment
147
+ url = url.split("#")[0]
148
+ # Remove trailing slash
149
+ url = url.rstrip("/")
150
+ return url
151
+
152
+ def _can_fetch(self, url: str) -> bool:
153
+ """
154
+ Check if URL can be fetched according to robots.txt.
155
+
156
+ Args:
157
+ url: URL to check
158
+
159
+ Returns:
160
+ True if allowed to fetch
161
+ """
162
+ try:
163
+ return self.robot_parser.can_fetch("*", url)
164
+ except Exception as e:
165
+ self.console.print(f"[yellow]Warning: Could not check robots.txt: {e}[/yellow]")
166
+ return True # Be permissive if robots.txt check fails
167
+
168
+ def _extract_links(self, html: str, current_url: str) -> Set[str]:
169
+ """
170
+ Extract valid article links from HTML.
171
+
172
+ Args:
173
+ html: HTML content
174
+ current_url: Current page URL for resolving relative links
175
+
176
+ Returns:
177
+ Set of valid article URLs
178
+ """
179
+ soup = BeautifulSoup(html, "html.parser")
180
+ links = set()
181
+
182
+ for a_tag in soup.find_all("a", href=True):
183
+ href = a_tag["href"]
184
+ # Resolve relative URLs
185
+ absolute_url = urljoin(current_url, href)
186
+ normalized_url = self._normalize_url(absolute_url)
187
+
188
+ if self._is_valid_article_url(normalized_url):
189
+ links.add(normalized_url)
190
+
191
+ return links
192
+
193
+ def _extract_metadata(self, soup: BeautifulSoup, url: str) -> Dict:
194
+ """
195
+ Extract metadata from article page.
196
+
197
+ Args:
198
+ soup: BeautifulSoup object
199
+ url: Article URL
200
+
201
+ Returns:
202
+ Dictionary of metadata
203
+ """
204
+ metadata = {
205
+ "url": url,
206
+ "title": "",
207
+ "last_updated": None,
208
+ "categories": [],
209
+ "scraped_at": datetime.utcnow().isoformat(),
210
+ }
211
+
212
+ # Extract title
213
+ title_tag = soup.find("h1", {"id": "firstHeading"}) or soup.find("h1")
214
+ if title_tag:
215
+ metadata["title"] = title_tag.get_text(strip=True)
216
+
217
+ # Extract categories
218
+ category_links = soup.find_all("a", href=re.compile(r"/Category:"))
219
+ metadata["categories"] = [link.get_text(strip=True) for link in category_links]
220
+
221
+ # Extract last modified date (if available)
222
+ last_modified = soup.find("li", {"id": "footer-info-lastmod"})
223
+ if last_modified:
224
+ metadata["last_updated"] = last_modified.get_text(strip=True)
225
+
226
+ return metadata
227
+
228
+ def save_article(self, content: Dict, filepath: Path) -> None:
229
+ """
230
+ Save article content and metadata to files.
231
+
232
+ Args:
233
+ content: Dictionary with 'markdown' and 'metadata' keys
234
+ filepath: Base filepath (without extension)
235
+ """
236
+ # Save markdown content
237
+ md_file = filepath.with_suffix(".md")
238
+ with open(md_file, "w", encoding="utf-8") as f:
239
+ f.write(content["markdown"])
240
+
241
+ # Save metadata as JSON sidecar
242
+ json_file = filepath.with_suffix(".json")
243
+ with open(json_file, "w", encoding="utf-8") as f:
244
+ json.dump(content["metadata"], f, indent=2, ensure_ascii=False)
245
+
246
+ self.articles_saved += 1
247
+ self.console.print(f"[green][/green] Saved: {content['metadata'].get('title', 'Untitled')}")
248
+
249
+ def load_checkpoint(self) -> bool:
250
+ """
251
+ Load checkpoint data to resume crawling.
252
+
253
+ Returns:
254
+ True if checkpoint was loaded successfully
255
+ """
256
+ if not self.checkpoint_file.exists():
257
+ return False
258
+
259
+ try:
260
+ with open(self.checkpoint_file, "r") as f:
261
+ data = json.load(f)
262
+
263
+ self.visited_urls = set(data.get("visited_urls", []))
264
+ self.to_crawl = deque(data.get("to_crawl", []))
265
+ self.failed_urls = data.get("failed_urls", {})
266
+ self.articles_saved = data.get("articles_saved", 0)
267
+
268
+ self.console.print(f"[blue]Loaded checkpoint:[/blue] {len(self.visited_urls)} visited, "
269
+ f"{len(self.to_crawl)} queued, {self.articles_saved} saved")
270
+ return True
271
+ except Exception as e:
272
+ self.console.print(f"[red]Error loading checkpoint: {e}[/red]")
273
+ return False
274
+
275
+ def save_checkpoint(self) -> None:
276
+ """Save current crawl state to checkpoint file."""
277
+ data = {
278
+ "visited_urls": list(self.visited_urls),
279
+ "to_crawl": list(self.to_crawl),
280
+ "failed_urls": self.failed_urls,
281
+ "articles_saved": self.articles_saved,
282
+ "last_checkpoint": datetime.utcnow().isoformat(),
283
+ }
284
+
285
+ try:
286
+ with open(self.checkpoint_file, "w") as f:
287
+ json.dump(data, f, indent=2)
288
+ except Exception as e:
289
+ self.console.print(f"[red]Error saving checkpoint: {e}[/red]")
290
+
291
    async def crawl_single_page(self, url: str) -> Optional[Dict]:
        """
        Crawl a single page and extract content.

        Checks robots.txt first, then renders the page with crawl4ai and
        returns its markdown, metadata, raw HTML and out-links.

        Args:
            url: URL to crawl

        Returns:
            Dictionary with 'markdown', 'metadata', 'html' and 'links'
            keys, or None if the fetch was blocked or failed.
        """
        if not self._can_fetch(url):
            self.console.print(f"[yellow]Blocked by robots.txt:[/yellow] {url}")
            return None

        try:
            # Configure browser settings
            browser_config = BrowserConfig(
                headless=True,
                verbose=False,
            )

            # Configure crawler settings
            crawler_config = CrawlerRunConfig(
                cache_mode=CacheMode.BYPASS,
                page_timeout=self.timeout * 1000,  # Convert to milliseconds
                wait_for="body",
            )

            # Create crawler and run.
            # NOTE(review): a fresh browser is launched for every page;
            # reusing one AsyncWebCrawler across pages would likely be
            # much cheaper — confirm crawl4ai supports it here.
            async with AsyncWebCrawler(config=browser_config) as crawler:
                result = await crawler.arun(
                    url=url,
                    config=crawler_config,
                )

                if not result.success:
                    self.console.print(f"[red]Failed to crawl:[/red] {url}")
                    return None

                # Parse HTML for metadata
                soup = BeautifulSoup(result.html, "html.parser")
                metadata = self._extract_metadata(soup, url)

                # Get markdown content
                markdown = result.markdown

                return {
                    "markdown": markdown,
                    "metadata": metadata,
                    "html": result.html,
                    "links": self._extract_links(result.html, url),
                }

        except Exception as e:
            # Record the failure so crawl() can report it in the summary
            self.console.print(f"[red]Error crawling {url}:[/red] {e}")
            self.failed_urls[url] = str(e)
            return None
348
+
349
+ async def crawl(
350
+ self,
351
+ max_pages: Optional[int] = None,
352
+ depth: int = 2,
353
+ start_urls: Optional[list] = None,
354
+ ) -> None:
355
+ """
356
+ Crawl EyeWiki starting from the main page.
357
+
358
+ Args:
359
+ max_pages: Maximum number of pages to crawl (None for unlimited)
360
+ depth: Maximum depth to crawl
361
+ start_urls: Optional list of starting URLs (defaults to base_url)
362
+ """
363
+ # Try to load checkpoint
364
+ checkpoint_loaded = self.load_checkpoint()
365
+
366
+ # Initialize robot parser
367
+ try:
368
+ self.robot_parser.read()
369
+ self.console.print("[green][/green] Loaded robots.txt")
370
+ except Exception as e:
371
+ self.console.print(f"[yellow]Warning: Could not load robots.txt: {e}[/yellow]")
372
+
373
+ # Initialize queue if not loaded from checkpoint
374
+ if not checkpoint_loaded:
375
+ if start_urls:
376
+ self.to_crawl.extend([(url, 0) for url in start_urls])
377
+ else:
378
+ self.to_crawl.append((self.base_url, 0))
379
+
380
+ self.console.print(f"\n[bold cyan]Starting EyeWiki Crawl[/bold cyan]")
381
+ self.console.print(f"Max pages: {max_pages or 'unlimited'}")
382
+ self.console.print(f"Max depth: {depth}")
383
+ self.console.print(f"Delay: {self.delay}s\n")
384
+
385
+ with Progress(
386
+ SpinnerColumn(),
387
+ TextColumn("[progress.description]{task.description}"),
388
+ BarColumn(),
389
+ TaskProgressColumn(),
390
+ console=self.console,
391
+ ) as progress:
392
+
393
+ task = progress.add_task(
394
+ "[cyan]Crawling...",
395
+ total=max_pages if max_pages else 100,
396
+ )
397
+
398
+ try:
399
+ while self.to_crawl:
400
+ # Check max_pages limit
401
+ if max_pages and self.articles_saved >= max_pages:
402
+ self.console.print(f"\n[yellow]Reached max_pages limit: {max_pages}[/yellow]")
403
+ break
404
+
405
+ # Get next URL
406
+ current_url, current_depth = self.to_crawl.popleft()
407
+
408
+ # Skip if already visited
409
+ if current_url in self.visited_urls:
410
+ continue
411
+
412
+ # Check depth limit
413
+ if current_depth > depth:
414
+ continue
415
+
416
+ # Mark as visited
417
+ self.visited_urls.add(current_url)
418
+
419
+ # Update progress
420
+ progress.update(
421
+ task,
422
+ completed=self.articles_saved,
423
+ description=f"[cyan]Crawling ({self.articles_saved} saved, {len(self.to_crawl)} queued): {current_url[:60]}...",
424
+ )
425
+
426
+ # Crawl the page
427
+ result = await self.crawl_single_page(current_url)
428
+
429
+ if result:
430
+ # Create filename from URL
431
+ parsed = urlparse(current_url)
432
+
433
+ # For URLs with query parameters (like index.php?title=Article_Name),
434
+ # extract the title parameter
435
+ if parsed.query:
436
+ query_params = parse_qs(parsed.query)
437
+ if 'title' in query_params:
438
+ # Use the title parameter as filename
439
+ filename = query_params['title'][0]
440
+ else:
441
+ # Fallback: use the entire query string
442
+ filename = parsed.query
443
+ else:
444
+ # Use path-based filename for clean URLs like /wiki/Article_Name
445
+ path_parts = parsed.path.strip("/").split("/")
446
+ filename = "_".join(path_parts[-2:]) if len(path_parts) > 1 else path_parts[-1]
447
+
448
+ # Clean filename
449
+ filename = re.sub(r"[^\w\s-]", "_", filename)
450
+ filename = re.sub(r"[-\s]+", "_", filename)
451
+ filename = filename[:200] # Limit length
452
+
453
+ # Save article
454
+ filepath = self.output_dir / filename
455
+ self.save_article(result, filepath)
456
+
457
+ # Add discovered links to queue
458
+ for link in result["links"]:
459
+ if link not in self.visited_urls:
460
+ self.to_crawl.append((link, current_depth + 1))
461
+
462
+ # Polite delay
463
+ await asyncio.sleep(self.delay)
464
+
465
+ # Periodic checkpoint save (every 10 articles)
466
+ if self.articles_saved % 10 == 0:
467
+ self.save_checkpoint()
468
+
469
+ except KeyboardInterrupt:
470
+ self.console.print("\n[yellow]Crawl interrupted by user[/yellow]")
471
+ except Exception as e:
472
+ self.console.print(f"\n[red]Error during crawl: {e}[/red]")
473
+ finally:
474
+ # Final checkpoint save
475
+ self.save_checkpoint()
476
+
477
+ # Print summary
478
+ self.console.print("\n[bold cyan]Crawl Summary[/bold cyan]")
479
+ self.console.print(f"Articles saved: {self.articles_saved}")
480
+ self.console.print(f"URLs visited: {len(self.visited_urls)}")
481
+ self.console.print(f"URLs failed: {len(self.failed_urls)}")
482
+ self.console.print(f"URLs remaining: {len(self.to_crawl)}")
483
+
484
+ if self.failed_urls:
485
+ self.console.print("\n[yellow]Failed URLs:[/yellow]")
486
+ for url, error in list(self.failed_urls.items())[:10]:
487
+ self.console.print(f" - {url}: {error}")
488
+ if len(self.failed_urls) > 10:
489
+ self.console.print(f" ... and {len(self.failed_urls) - 10} more")
src/vectorstore/__init__.py ADDED
File without changes
src/vectorstore/qdrant_store.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Qdrant vector store manager for EyeWiki RAG system."""
2
+
3
+ import uuid
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Dict, List, Optional
7
+
8
+ from pydantic import BaseModel, Field
9
+ from qdrant_client import QdrantClient
10
+ from qdrant_client.models import (
11
+ Distance,
12
+ VectorParams,
13
+ SparseVectorParams,
14
+ SparseIndexParams,
15
+ PointStruct,
16
+ Filter,
17
+ FieldCondition,
18
+ MatchValue,
19
+ MatchAny,
20
+ Range,
21
+ ScoredPoint,
22
+ )
23
+ from rich.console import Console
24
+
25
+ from config.settings import settings
26
+ from src.processing.chunker import ChunkNode
27
+
28
+
29
+ # Configure logging
30
+ logging.basicConfig(level=logging.INFO)
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class SearchResult(BaseModel):
    """
    Pydantic model for a single vector-search hit.

    Attributes:
        id: Unique identifier of the result
        score: Relevance score
        chunk_id: Chunk identifier
        content: Text content
        parent_section: Section header
        document_title: Article title
        source_url: EyeWiki URL
        metadata: Additional metadata
    """

    id: str = Field(..., description="Unique result identifier")
    # NOTE(review): ge=0.0 will reject negative similarity scores; cosine
    # similarity can be negative — confirm the expected score range.
    score: float = Field(..., ge=0.0, description="Relevance score")
    chunk_id: str = Field(..., description="Chunk identifier")
    content: str = Field(..., description="Text content")
    parent_section: str = Field(default="", description="Parent section header")
    document_title: str = Field(default="", description="Document title")
    source_url: str = Field(default="", description="Source URL")
    metadata: Dict = Field(default_factory=dict, description="Additional metadata")

    @classmethod
    def from_scored_point(cls, point: ScoredPoint) -> "SearchResult":
        """
        Create SearchResult from Qdrant ScoredPoint.

        Missing payload fields fall back to empty strings / empty dict, so
        hits from sparsely-populated points still validate.

        Args:
            point: Qdrant scored point

        Returns:
            SearchResult instance
        """
        # Payload may be None for points stored without one
        payload = point.payload or {}

        return cls(
            id=str(point.id),
            score=point.score,
            chunk_id=payload.get("chunk_id", ""),
            content=payload.get("content", ""),
            parent_section=payload.get("parent_section", ""),
            document_title=payload.get("document_title", ""),
            source_url=payload.get("source_url", ""),
            metadata=payload.get("metadata", {}),
        )
81
+
82
+
83
+ class QdrantStoreManager:
84
+ """
85
+ Qdrant vector store manager for EyeWiki documents.
86
+
87
+ Features:
88
+ - Local/persistent Qdrant storage
89
+ - Dense vector search (semantic)
90
+ - Sparse vector search (BM25)
91
+ - Hybrid search combining both
92
+ - Metadata filtering
93
+ - Batched operations for efficiency
94
+ """
95
+
96
    def __init__(
        self,
        collection_name: Optional[str] = None,
        path: Optional[str] = None,
        embedding_dim: int = 768,  # Default for nomic-embed-text
        batch_size: int = 100,
    ):
        """
        Initialize Qdrant store manager.

        Args:
            collection_name: Name of the collection (default: from settings)
            path: Path to Qdrant storage (default: from settings)
            embedding_dim: Dimension of dense embeddings
            batch_size: Batch size for bulk operations

        Raises:
            Exception: Propagates any failure from QdrantClient creation
                after logging it.
        """
        self.collection_name = collection_name or settings.qdrant_collection_name
        self.path = Path(path or settings.qdrant_path)
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size

        # Create storage directory
        self.path.mkdir(parents=True, exist_ok=True)

        # Initialize Qdrant client (local/persistent mode — no server process)
        try:
            self.client = QdrantClient(path=str(self.path))
            logger.info(f"Initialized Qdrant client at {self.path}")
        except Exception as e:
            logger.error(f"Failed to initialize Qdrant client: {e}")
            raise

        self.console = Console()
129
+
130
+ def initialize_collection(self, recreate: bool = False) -> None:
131
+ """
132
+ Initialize the Qdrant collection with vector configurations.
133
+
134
+ Creates collection with:
135
+ - Dense vectors for semantic search (cosine similarity)
136
+ - Sparse vectors for BM25/keyword search
137
+ - Payload indexing for metadata filtering
138
+
139
+ Args:
140
+ recreate: If True, delete existing collection and recreate
141
+ """
142
+ try:
143
+ # Check if collection exists
144
+ collections = self.client.get_collections().collections
145
+ collection_exists = any(c.name == self.collection_name for c in collections)
146
+
147
+ if collection_exists:
148
+ if recreate:
149
+ self.console.print(
150
+ f"[yellow]Deleting existing collection: {self.collection_name}[/yellow]"
151
+ )
152
+ self.client.delete_collection(self.collection_name)
153
+ else:
154
+ self.console.print(
155
+ f"[blue]Collection already exists: {self.collection_name}[/blue]"
156
+ )
157
+ return
158
+
159
+ # Create collection with dense and sparse vector configurations
160
+ self.console.print(f"[cyan]Creating collection: {self.collection_name}[/cyan]")
161
+
162
+ self.client.create_collection(
163
+ collection_name=self.collection_name,
164
+ vectors_config={
165
+ # Dense vector for semantic search
166
+ "dense": VectorParams(
167
+ size=self.embedding_dim,
168
+ distance=Distance.COSINE,
169
+ ),
170
+ },
171
+ sparse_vectors_config={
172
+ # Sparse vector for BM25/keyword search
173
+ "sparse": SparseVectorParams(
174
+ index=SparseIndexParams(
175
+ on_disk=False, # Keep in memory for speed
176
+ ),
177
+ ),
178
+ },
179
+ )
180
+
181
+ # Create payload indexes for efficient filtering
182
+ # Index on key metadata fields
183
+ self.client.create_payload_index(
184
+ collection_name=self.collection_name,
185
+ field_name="document_title",
186
+ field_schema="keyword",
187
+ )
188
+
189
+ self.client.create_payload_index(
190
+ collection_name=self.collection_name,
191
+ field_name="parent_section",
192
+ field_schema="keyword",
193
+ )
194
+
195
+ self.client.create_payload_index(
196
+ collection_name=self.collection_name,
197
+ field_name="metadata.disease_name",
198
+ field_schema="keyword",
199
+ )
200
+
201
+ self.client.create_payload_index(
202
+ collection_name=self.collection_name,
203
+ field_name="metadata.icd_codes",
204
+ field_schema="keyword",
205
+ )
206
+
207
+ self.console.print(
208
+ f"[green][/green] Collection created: {self.collection_name}"
209
+ )
210
+ logger.info(f"Created collection: {self.collection_name}")
211
+
212
+ except Exception as e:
213
+ logger.error(f"Failed to initialize collection: {e}")
214
+ raise
215
+
216
+ def add_documents(
217
+ self,
218
+ chunks: List[ChunkNode],
219
+ dense_embeddings: List[List[float]],
220
+ sparse_embeddings: Optional[List[Dict]] = None,
221
+ ) -> int:
222
+ """
223
+ Add documents to the vector store with batched upserts.
224
+
225
+ Args:
226
+ chunks: List of ChunkNode objects
227
+ dense_embeddings: List of dense embedding vectors
228
+ sparse_embeddings: Optional list of sparse vectors (for BM25)
229
+
230
+ Returns:
231
+ Number of documents successfully added
232
+
233
+ Raises:
234
+ ValueError: If chunks and embeddings length mismatch
235
+ """
236
+ if len(chunks) != len(dense_embeddings):
237
+ raise ValueError(
238
+ f"Chunks ({len(chunks)}) and embeddings ({len(dense_embeddings)}) "
239
+ "must have same length"
240
+ )
241
+
242
+ if sparse_embeddings and len(sparse_embeddings) != len(chunks):
243
+ raise ValueError(
244
+ f"Chunks ({len(chunks)}) and sparse embeddings ({len(sparse_embeddings)}) "
245
+ "must have same length"
246
+ )
247
+
248
+ total_added = 0
249
+
250
+ try:
251
+ # Process in batches
252
+ for i in range(0, len(chunks), self.batch_size):
253
+ batch_chunks = chunks[i : i + self.batch_size]
254
+ batch_dense = dense_embeddings[i : i + self.batch_size]
255
+ batch_sparse = (
256
+ sparse_embeddings[i : i + self.batch_size]
257
+ if sparse_embeddings
258
+ else None
259
+ )
260
+
261
+ # Create points for batch
262
+ points = []
263
+ for j, chunk in enumerate(batch_chunks):
264
+ # Prepare vector dict
265
+ vectors = {"dense": batch_dense[j]}
266
+
267
+ # Add sparse vector if available
268
+ if batch_sparse:
269
+ vectors["sparse"] = batch_sparse[j]
270
+
271
+ # Create point
272
+ point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, chunk.chunk_id))
273
+ point = PointStruct(
274
+ id=point_id,
275
+ vector=vectors,
276
+ payload={
277
+ "chunk_id": chunk.chunk_id,
278
+ "content": chunk.content,
279
+ "parent_section": chunk.parent_section,
280
+ "document_title": chunk.document_title,
281
+ "source_url": chunk.source_url,
282
+ "chunk_index": chunk.chunk_index,
283
+ "token_count": chunk.token_count,
284
+ "metadata": chunk.metadata,
285
+ },
286
+ )
287
+ points.append(point)
288
+
289
+ # Upsert batch
290
+ self.client.upsert(
291
+ collection_name=self.collection_name,
292
+ points=points,
293
+ )
294
+
295
+ total_added += len(points)
296
+
297
+ logger.info(
298
+ f"Uploaded batch {i // self.batch_size + 1}: "
299
+ f"{len(points)} points (total: {total_added})"
300
+ )
301
+
302
+ self.console.print(
303
+ f"[green][/green] Added {total_added} documents to {self.collection_name}"
304
+ )
305
+ return total_added
306
+
307
+ except Exception as e:
308
+ logger.error(f"Failed to add documents: {e}")
309
+ raise
310
+
311
+ def search(
312
+ self,
313
+ query_embedding: List[float],
314
+ top_k: int = 10,
315
+ filters: Optional[Dict] = None,
316
+ score_threshold: Optional[float] = None,
317
+ ) -> List[SearchResult]:
318
+ """
319
+ Search using dense vector (semantic search).
320
+
321
+ Args:
322
+ query_embedding: Dense query vector
323
+ top_k: Number of results to return
324
+ filters: Optional metadata filters (e.g., {"disease_name": "Glaucoma"})
325
+ score_threshold: Minimum score threshold
326
+
327
+ Returns:
328
+ List of SearchResult objects
329
+ """
330
+ try:
331
+ # Build filter conditions
332
+ query_filter = self._build_filter(filters) if filters else None
333
+
334
+ # Perform search
335
+ results = self.client.query_points(
336
+ collection_name=self.collection_name,
337
+ query=query_embedding,
338
+ using="dense", # Specify which named vector to use
339
+ limit=top_k,
340
+ query_filter=query_filter,
341
+ score_threshold=score_threshold,
342
+ ).points
343
+
344
+ # Convert to SearchResult objects
345
+ search_results = [SearchResult.from_scored_point(r) for r in results]
346
+
347
+ logger.info(f"Dense search returned {len(search_results)} results")
348
+ return search_results
349
+
350
+ except Exception as e:
351
+ logger.error(f"Search failed: {e}")
352
+ raise
353
+
354
+ def hybrid_search(
355
+ self,
356
+ query_embedding: List[float],
357
+ query_sparse: Optional[Dict] = None,
358
+ top_k: int = 10,
359
+ filters: Optional[Dict] = None,
360
+ ) -> List[SearchResult]:
361
+ """
362
+ Hybrid search combining dense (semantic) and sparse (BM25) vectors.
363
+
364
+ Args:
365
+ query_embedding: Dense query vector
366
+ query_sparse: Sparse query vector for BM25
367
+ top_k: Number of results to return
368
+ filters: Optional metadata filters
369
+
370
+ Returns:
371
+ List of SearchResult objects with combined scores
372
+ """
373
+ try:
374
+ # If no sparse vector provided, fall back to dense search
375
+ if query_sparse is None:
376
+ logger.warning("No sparse vector provided, using dense search only")
377
+ return self.search(query_embedding, top_k, filters)
378
+
379
+ # Build filter conditions
380
+ query_filter = self._build_filter(filters) if filters else None
381
+
382
+ # Perform hybrid search
383
+ # Note: Qdrant supports multiple vectors in search, but for true hybrid
384
+ # we'd need to do two separate searches and merge results
385
+ # For simplicity, we'll use the query API with dense vector
386
+ # In production, you'd want to implement proper RRF (Reciprocal Rank Fusion)
387
+
388
+ results = self.client.query_points(
389
+ collection_name=self.collection_name,
390
+ query=query_embedding,
391
+ using="dense", # Specify which named vector to use
392
+ limit=top_k * 2, # Get more results for reranking
393
+ query_filter=query_filter,
394
+ ).points
395
+
396
+ # Convert to SearchResult objects
397
+ search_results = [SearchResult.from_scored_point(r) for r in results]
398
+
399
+ # For now, return top_k results
400
+ # In production, implement RRF combining dense and sparse results
401
+ logger.info(f"Hybrid search returned {len(search_results[:top_k])} results")
402
+ return search_results[:top_k]
403
+
404
+ except Exception as e:
405
+ logger.error(f"Hybrid search failed: {e}")
406
+ raise
407
+
408
+ def _build_filter(self, filters: Dict) -> Filter:
409
+ """
410
+ Build Qdrant filter from dictionary.
411
+
412
+ Supports:
413
+ - disease_name: str
414
+ - icd_codes: List[str]
415
+ - anatomical_structures: List[str]
416
+ - document_title: str
417
+ - parent_section: str
418
+
419
+ Args:
420
+ filters: Dictionary of filter conditions
421
+
422
+ Returns:
423
+ Qdrant Filter object
424
+ """
425
+ conditions = []
426
+
427
+ # Disease name filter
428
+ if "disease_name" in filters:
429
+ conditions.append(
430
+ FieldCondition(
431
+ key="metadata.disease_name",
432
+ match=MatchValue(value=filters["disease_name"]),
433
+ )
434
+ )
435
+
436
+ # ICD codes filter (match any)
437
+ if "icd_codes" in filters:
438
+ icd_list = filters["icd_codes"]
439
+ if isinstance(icd_list, str):
440
+ icd_list = [icd_list]
441
+ conditions.append(
442
+ FieldCondition(
443
+ key="metadata.icd_codes",
444
+ match=MatchAny(any=icd_list),
445
+ )
446
+ )
447
+
448
+ # Anatomical structures filter
449
+ if "anatomical_structures" in filters:
450
+ structures = filters["anatomical_structures"]
451
+ if isinstance(structures, str):
452
+ structures = [structures]
453
+ conditions.append(
454
+ FieldCondition(
455
+ key="metadata.anatomical_structures",
456
+ match=MatchAny(any=structures),
457
+ )
458
+ )
459
+
460
+ # Document title filter
461
+ if "document_title" in filters:
462
+ conditions.append(
463
+ FieldCondition(
464
+ key="document_title",
465
+ match=MatchValue(value=filters["document_title"]),
466
+ )
467
+ )
468
+
469
+ # Parent section filter
470
+ if "parent_section" in filters:
471
+ conditions.append(
472
+ FieldCondition(
473
+ key="parent_section",
474
+ match=MatchValue(value=filters["parent_section"]),
475
+ )
476
+ )
477
+
478
+ # Token count range filter
479
+ if "min_tokens" in filters or "max_tokens" in filters:
480
+ range_filter = {}
481
+ if "min_tokens" in filters:
482
+ range_filter["gte"] = filters["min_tokens"]
483
+ if "max_tokens" in filters:
484
+ range_filter["lte"] = filters["max_tokens"]
485
+
486
+ conditions.append(
487
+ FieldCondition(
488
+ key="token_count",
489
+ range=Range(**range_filter),
490
+ )
491
+ )
492
+
493
+ return Filter(must=conditions) if conditions else None
494
+
495
+ def get_collection_info(self) -> Dict:
496
+ """
497
+ Get information about the collection.
498
+
499
+ Returns:
500
+ Dictionary with collection statistics
501
+ """
502
+ try:
503
+ info = self.client.get_collection(self.collection_name)
504
+
505
+ return {
506
+ "name": self.collection_name,
507
+ "vectors_count": getattr(info, "vectors_count", 0),
508
+ "points_count": info.points_count,
509
+ "status": info.status,
510
+ "optimizer_status": info.optimizer_status,
511
+ "indexed_vectors_count": getattr(info, "indexed_vectors_count", 0),
512
+ }
513
+
514
+ except Exception as e:
515
+ logger.error(f"Failed to get collection info: {e}")
516
+ raise
517
+
518
+ def delete_collection(self) -> bool:
519
+ """
520
+ Delete the collection.
521
+
522
+ Returns:
523
+ True if successful
524
+ """
525
+ try:
526
+ result = self.client.delete_collection(self.collection_name)
527
+ self.console.print(
528
+ f"[yellow]Deleted collection: {self.collection_name}[/yellow]"
529
+ )
530
+ logger.info(f"Deleted collection: {self.collection_name}")
531
+ return result
532
+
533
+ except Exception as e:
534
+ logger.error(f"Failed to delete collection: {e}")
535
+ raise
536
+
537
+ def count_documents(self) -> int:
538
+ """
539
+ Count total documents in collection.
540
+
541
+ Returns:
542
+ Number of documents
543
+ """
544
+ try:
545
+ info = self.client.get_collection(self.collection_name)
546
+ return info.points_count or 0
547
+
548
+ except Exception as e:
549
+ logger.error(f"Failed to count documents: {e}")
550
+ return 0
551
+
552
+ def get_document_by_id(self, doc_id: str) -> Optional[SearchResult]:
553
+ """
554
+ Retrieve a specific document by ID.
555
+
556
+ Args:
557
+ doc_id: Document ID (chunk_id)
558
+
559
+ Returns:
560
+ SearchResult if found, None otherwise
561
+ """
562
+ try:
563
+ points = self.client.retrieve(
564
+ collection_name=self.collection_name,
565
+ ids=[doc_id],
566
+ )
567
+
568
+ if not points:
569
+ return None
570
+
571
+ point = points[0]
572
+ payload = point.payload or {}
573
+
574
+ return SearchResult(
575
+ id=str(point.id),
576
+ score=1.0, # No score for direct retrieval
577
+ chunk_id=payload.get("chunk_id", ""),
578
+ content=payload.get("content", ""),
579
+ parent_section=payload.get("parent_section", ""),
580
+ document_title=payload.get("document_title", ""),
581
+ source_url=payload.get("source_url", ""),
582
+ metadata=payload.get("metadata", {}),
583
+ )
584
+
585
+ except Exception as e:
586
+ logger.error(f"Failed to get document by ID: {e}")
587
+ return None
tests/README.md ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tests
2
+
3
+ Comprehensive test suite for the EyeWiki RAG system.
4
+
5
+ ## Installation
6
+
7
+ Install test dependencies:
8
+
9
+ ```bash
10
+ pip install pytest pytest-cov pytest-mock requests
11
+ ```
12
+
13
+ ## Running Tests
14
+
15
+ ### Run all tests:
16
+ ```bash
17
+ pytest
18
+ ```
19
+
20
+ ### Run with verbose output:
21
+ ```bash
22
+ pytest -v
23
+ ```
24
+
25
+ ### Run specific test file:
26
+ ```bash
27
+ pytest tests/test_components.py -v
28
+ ```
29
+
30
+ ### Run specific test:
31
+ ```bash
32
+ pytest tests/test_components.py::test_chunk_respects_headers -v
33
+ ```
34
+
35
+ ### Run tests by marker:
36
+ ```bash
37
+ # Run only unit tests
38
+ pytest -m unit
39
+
40
+ # Run only integration tests
41
+ pytest -m integration
42
+
43
+ # Run only API tests
44
+ pytest -m api
45
+ ```
46
+
47
+ ### Run with coverage:
48
+ ```bash
49
+ pytest --cov=src --cov-report=html
50
+ ```
51
+
52
+ This will generate a coverage report in `htmlcov/index.html`.
53
+
54
+ ## Test Categories
55
+
56
+ ### Unit Tests (`@pytest.mark.unit`)
57
+ - Fast, isolated tests
58
+ - Mock external dependencies
59
+ - Test individual components
60
+
61
+ ### Integration Tests (`@pytest.mark.integration`)
62
+ - Test multiple components together
63
+ - May be slower
64
+ - May require real dependencies
65
+
66
+ ### API Tests (`@pytest.mark.api`)
67
+ - Test FastAPI endpoints
68
+ - Require server components
69
+ - Use TestClient
70
+
71
+ ## Test Structure
72
+
73
+ ### Chunker Tests
74
+ - `test_chunk_respects_headers()` - Verifies markdown header handling
75
+ - `test_chunk_size_limits()` - Checks chunk size constraints
76
+ - `test_metadata_preserved()` - Ensures metadata propagation
77
+
78
+ ### Retriever Tests
79
+ - `test_retrieval_returns_results()` - Basic retrieval functionality
80
+ - `test_hybrid_search_combines_scores()` - Score combination logic
81
+ - `test_filters_work()` - Metadata filtering
82
+
83
+ ### Reranker Tests
84
+ - `test_reranking_changes_order()` - Verifies reranking effect
85
+ - `test_top_k_respected()` - Checks top_k parameter
86
+
87
+ ### Query Engine Tests
88
+ - `test_full_query_pipeline()` - End-to-end query flow
89
+ - `test_sources_included()` - Source citation functionality
90
+ - `test_disclaimer_present()` - Medical disclaimer inclusion
91
+ - `test_streaming_query()` - Streaming response
92
+
93
+ ### API Tests
94
+ - `test_health_endpoint()` - Health check endpoint
95
+ - `test_query_endpoint()` - Main query endpoint
96
+ - `test_query_endpoint_validation()` - Input validation
97
+
98
+ ### Metadata Tests
99
+ - `test_icd_code_extraction()` - ICD-10 code extraction
100
+ - `test_anatomical_term_extraction()` - Anatomical term detection
101
+ - `test_medication_extraction()` - Medication identification
102
+
103
+ ## Fixtures
104
+
105
+ Reusable test fixtures are defined in `test_components.py`:
106
+
107
+ - `semantic_chunker` - SemanticChunker instance
108
+ - `metadata_extractor` - MetadataExtractor instance
109
+ - `sample_chunks` - Sample ChunkNode objects
110
+ - `mock_retriever` - Mocked HybridRetriever
111
+ - `mock_reranker` - Mocked CrossEncoderReranker
112
+ - `mock_ollama_client` - Mocked OllamaClient
113
+ - `query_engine` - Fully configured QueryEngine with mocks
114
+ - `test_client` - FastAPI TestClient
115
+
116
+ ## Writing New Tests
117
+
118
+ ### Example unit test:
119
+ ```python
120
+ @pytest.mark.unit
121
+ def test_my_component(my_fixture):
122
+ """Test description."""
123
+ result = my_fixture.some_method()
124
+ assert result == expected_value
125
+ ```
126
+
127
+ ### Example integration test:
128
+ ```python
129
+ @pytest.mark.integration
130
+ def test_component_interaction():
131
+ """Test multiple components together."""
132
+ # Setup
133
+ component_a = ComponentA()
134
+ component_b = ComponentB(component_a)
135
+
136
+ # Test
137
+ result = component_b.process()
138
+
139
+ # Assert
140
+ assert result.is_valid()
141
+ ```
142
+
143
+ ### Example API test:
144
+ ```python
145
+ @pytest.mark.api
146
+ def test_my_endpoint(test_client):
147
+ """Test API endpoint."""
148
+ response = test_client.get("/my-endpoint")
149
+ assert response.status_code == 200
150
+ assert "expected_field" in response.json()
151
+ ```
152
+
153
+ ## Continuous Integration
154
+
155
+ These tests are designed to run in CI/CD pipelines. Mock external dependencies (Ollama, Qdrant) to ensure tests run in any environment.
156
+
157
+ ## Troubleshooting
158
+
159
+ ### Import Errors
160
+ Make sure the project root is in PYTHONPATH:
161
+ ```bash
162
+ export PYTHONPATH=/path/to/eyewiki-rag:$PYTHONPATH
163
+ ```
164
+
165
+ ### Mock Issues
166
+ If mocks aren't working properly, check that you're using the correct spec:
167
+ ```python
168
+ mock = Mock(spec=RealClass)
169
+ ```
170
+
171
+ ### API Tests Failing
172
+ API tests may fail if the application isn't properly initialized. Use mocking to isolate components.
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pytest configuration and shared fixtures.
3
+ """
4
+
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ # Add project root to Python path
9
+ project_root = Path(__file__).parent.parent
10
+ sys.path.insert(0, str(project_root))
11
+
12
+
13
def pytest_configure(config):
    """Register this suite's custom markers with pytest.

    Declaring markers here keeps ``pytest --strict-markers`` (and marker
    warnings) happy for the integration/api/unit markers used in the tests.
    """
    marker_definitions = (
        "integration: mark test as integration test (may be slow)",
        "api: mark test as API test (requires server components)",
        "unit: mark test as unit test (fast, isolated)",
    )
    for definition in marker_definitions:
        config.addinivalue_line("markers", definition)
tests/test_components.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comprehensive tests for EyeWiki RAG components.
3
+
4
+ Run with:
5
+ pytest tests/test_components.py -v
6
+ pytest tests/test_components.py::test_chunk_respects_headers -v
7
+ """
8
+
9
+ import pytest
10
+ from pathlib import Path
11
+ from unittest.mock import Mock, patch, MagicMock
12
+ from typing import List
13
+
14
+ from src.processing.chunker import ChunkNode, SemanticChunker
15
+ from src.processing.metadata_extractor import MetadataExtractor
16
+ from src.rag.retriever import HybridRetriever, RetrievalResult
17
+ from src.rag.reranker import CrossEncoderReranker
18
+ from src.rag.query_engine import EyeWikiQueryEngine, QueryResponse, SourceInfo
19
+
20
+
21
+ # ============================================================================
22
+ # Test Data
23
+ # ============================================================================
24
+
25
+ SAMPLE_MARKDOWN = """# Glaucoma
26
+
27
+ ## Overview
28
+
29
+ Glaucoma is a group of eye conditions that damage the optic nerve.
30
+
31
+ ## Symptoms
32
+
33
+ Common symptoms include:
34
+ - Vision loss
35
+ - Eye pain
36
+ - Halos around lights
37
+
38
+ ## Treatment
39
+
40
+ Treatment options include:
41
+ - Medications (IOP-lowering drops)
42
+ - Laser procedures
43
+ - Surgery
44
+
45
+ ### Medications
46
+
47
+ Beta-blockers and prostaglandin analogs are commonly used.
48
+
49
+ ### Surgery
50
+
51
+ Trabeculectomy is a common surgical procedure.
52
+ """
53
+
54
+ SAMPLE_METADATA = {
55
+ "title": "Glaucoma",
56
+ "url": "https://eyewiki.aao.org/Glaucoma",
57
+ "source": "eyewiki",
58
+ }
59
+
60
+
61
+ # ============================================================================
62
+ # Fixtures
63
+ # ============================================================================
64
+
65
+ @pytest.fixture
66
+ def semantic_chunker():
67
+ """Create a SemanticChunker instance."""
68
+ return SemanticChunker(
69
+ chunk_size=200,
70
+ chunk_overlap=20,
71
+ min_chunk_size=50,
72
+ )
73
+
74
+
75
+ @pytest.fixture
76
+ def metadata_extractor():
77
+ """Create a MetadataExtractor instance."""
78
+ return MetadataExtractor()
79
+
80
+
81
+ @pytest.fixture
82
+ def sample_chunks():
83
+ """Create sample retrieval results for testing."""
84
+ return [
85
+ ChunkNode(
86
+ id="chunk_1",
87
+ content="Glaucoma is characterized by elevated intraocular pressure (IOP).",
88
+ document_title="Glaucoma",
89
+ source_url="https://eyewiki.aao.org/Glaucoma",
90
+ parent_section="Overview",
91
+ metadata={"icd_codes": ["H40.1"], "anatomical_terms": ["optic nerve"]},
92
+ chunk_index=0,
93
+ total_chunks=5,
94
+ ),
95
+ ChunkNode(
96
+ id="chunk_2",
97
+ content="Treatment includes beta-blockers and prostaglandin analogs.",
98
+ document_title="Glaucoma",
99
+ source_url="https://eyewiki.aao.org/Glaucoma",
100
+ parent_section="Treatment",
101
+ metadata={"medications": ["beta-blockers", "prostaglandin analogs"]},
102
+ chunk_index=1,
103
+ total_chunks=5,
104
+ ),
105
+ ChunkNode(
106
+ id="chunk_3",
107
+ content="Diabetic retinopathy affects the retinal blood vessels.",
108
+ document_title="Diabetic Retinopathy",
109
+ source_url="https://eyewiki.aao.org/Diabetic_Retinopathy",
110
+ parent_section="Overview",
111
+ metadata={"icd_codes": ["E11.3"], "anatomical_terms": ["retina"]},
112
+ chunk_index=0,
113
+ total_chunks=3,
114
+ ),
115
+ ]
116
+
117
+
118
+ @pytest.fixture
119
+ def mock_retriever(sample_chunks):
120
+ """Create a mock HybridRetriever."""
121
+ retriever = Mock(spec=HybridRetriever)
122
+
123
+ # Convert ChunkNodes to RetrievalResults
124
+ retrieval_results = [
125
+ RetrievalResult(
126
+ id=chunk.id,
127
+ content=chunk.content,
128
+ document_title=chunk.document_title,
129
+ source_url=chunk.source_url,
130
+ section=chunk.parent_section,
131
+ metadata=chunk.metadata,
132
+ score=0.9 - (i * 0.1), # Decreasing scores
133
+ )
134
+ for i, chunk in enumerate(sample_chunks)
135
+ ]
136
+
137
+ retriever.retrieve.return_value = retrieval_results
138
+ return retriever
139
+
140
+
141
+ @pytest.fixture
142
+ def mock_reranker():
143
+ """Create a mock CrossEncoderReranker."""
144
+ reranker = Mock(spec=CrossEncoderReranker)
145
+
146
+ def rerank_func(query: str, documents: List[RetrievalResult], top_k: int):
147
+ # Reverse order to simulate reranking
148
+ reranked = list(reversed(documents[:top_k]))
149
+ # Update scores
150
+ for i, doc in enumerate(reranked):
151
+ doc.score = 0.95 - (i * 0.05)
152
+ return reranked
153
+
154
+ reranker.rerank.side_effect = rerank_func
155
+ return reranker
156
+
157
+
158
+ @pytest.fixture
159
+ def mock_ollama_client():
160
+ """Create a mock OllamaClient."""
161
+ client = Mock()
162
+ client.generate.return_value = (
163
+ "Glaucoma is a group of eye diseases that damage the optic nerve. "
164
+ "It is often associated with elevated intraocular pressure (IOP). "
165
+ "[Source: Glaucoma]"
166
+ )
167
+ client.stream_generate.return_value = iter(["Glaucoma ", "is ", "a disease."])
168
+ client.embed_text.return_value = [0.1] * 768
169
+ return client
170
+
171
+
172
+ @pytest.fixture
173
+ def query_engine(mock_retriever, mock_reranker, mock_ollama_client, tmp_path):
174
+ """Create a QueryEngine instance with mocked dependencies."""
175
+ # Create temporary prompt files
176
+ system_prompt = tmp_path / "system_prompt.txt"
177
+ system_prompt.write_text("You are an expert ophthalmology assistant.")
178
+
179
+ query_prompt = tmp_path / "query_prompt.txt"
180
+ query_prompt.write_text("Context: {context}\n\nQuestion: {question}\n\nAnswer:")
181
+
182
+ disclaimer = tmp_path / "disclaimer.txt"
183
+ disclaimer.write_text("Medical disclaimer text.")
184
+
185
+ return EyeWikiQueryEngine(
186
+ retriever=mock_retriever,
187
+ reranker=mock_reranker,
188
+ llm_client=mock_ollama_client,
189
+ system_prompt_path=system_prompt,
190
+ query_prompt_path=query_prompt,
191
+ disclaimer_path=disclaimer,
192
+ max_context_tokens=4000,
193
+ retrieval_k=20,
194
+ rerank_k=5,
195
+ )
196
+
197
+
198
+ # ============================================================================
199
+ # Chunker Tests
200
+ # ============================================================================
201
+
202
+ def test_chunk_respects_headers(semantic_chunker):
203
+ """Test that chunker respects markdown headers."""
204
+ chunks = semantic_chunker.chunk_document(
205
+ markdown_content=SAMPLE_MARKDOWN,
206
+ metadata=SAMPLE_METADATA,
207
+ )
208
+
209
+ # Should have multiple chunks based on headers
210
+ assert len(chunks) > 0
211
+
212
+ # Check that parent sections are correctly identified
213
+ sections = {chunk.parent_section for chunk in chunks}
214
+ assert "Overview" in sections or "Symptoms" in sections or "Treatment" in sections
215
+
216
+ # Verify each chunk has required fields
217
+ for chunk in chunks:
218
+ assert chunk.content
219
+ assert chunk.document_title == "Glaucoma"
220
+ assert chunk.source_url == SAMPLE_METADATA["url"]
221
+ assert chunk.id
222
+ assert isinstance(chunk.chunk_index, int)
223
+ assert isinstance(chunk.total_chunks, int)
224
+
225
+
226
+ def test_chunk_size_limits(semantic_chunker):
227
+ """Test that chunks respect size limits."""
228
+ # Create a very long section
229
+ long_text = "This is a test sentence. " * 200 # Very long text
230
+ long_markdown = f"# Test\n\n## Section\n\n{long_text}"
231
+
232
+ chunks = semantic_chunker.chunk_document(
233
+ markdown_content=long_markdown,
234
+ metadata=SAMPLE_METADATA,
235
+ )
236
+
237
+ # All chunks should respect min size
238
+ for chunk in chunks:
239
+ # Token estimation: len(text) // 4
240
+ estimated_tokens = len(chunk.content) // 4
241
+ # Should not be too small (unless it's the last chunk)
242
+ if chunk.chunk_index < chunk.total_chunks - 1:
243
+ assert estimated_tokens >= semantic_chunker.min_chunk_size
244
+
245
+ # Should have created multiple chunks for long text
246
+ assert len(chunks) > 1
247
+
248
+
249
+ def test_metadata_preserved(semantic_chunker):
250
+ """Test that metadata is preserved in chunks."""
251
+ custom_metadata = {
252
+ "title": "Test Document",
253
+ "url": "https://example.com/test",
254
+ "custom_field": "custom_value",
255
+ }
256
+
257
+ chunks = semantic_chunker.chunk_document(
258
+ markdown_content=SAMPLE_MARKDOWN,
259
+ metadata=custom_metadata,
260
+ )
261
+
262
+ # All chunks should have the same base metadata
263
+ for chunk in chunks:
264
+ assert chunk.document_title == custom_metadata["title"]
265
+ assert chunk.source_url == custom_metadata["url"]
266
+
267
+
268
+ # ============================================================================
269
+ # Retriever Tests
270
+ # ============================================================================
271
+
272
+ def test_retrieval_returns_results(mock_retriever):
273
+ """Test that retriever returns results."""
274
+ query = "What is glaucoma?"
275
+ results = mock_retriever.retrieve(query=query, top_k=10)
276
+
277
+ assert len(results) > 0
278
+ assert all(isinstance(r, RetrievalResult) for r in results)
279
+
280
+ # Verify result structure
281
+ for result in results:
282
+ assert result.id
283
+ assert result.content
284
+ assert result.document_title
285
+ assert result.source_url
286
+ assert 0 <= result.score <= 1
287
+
288
+
289
+ def test_hybrid_search_combines_scores(mock_retriever):
290
+ """Test that hybrid search returns combined scores."""
291
+ query = "glaucoma treatment"
292
+ results = mock_retriever.retrieve(query=query, top_k=5)
293
+
294
+ # Scores should be in descending order
295
+ scores = [r.score for r in results]
296
+ assert scores == sorted(scores, reverse=True)
297
+
298
+ # All scores should be valid
299
+ assert all(0 <= score <= 1 for score in scores)
300
+
301
+
302
+ def test_filters_work(mock_retriever):
303
+ """Test that metadata filters work."""
304
+ # Add filter functionality to mock
305
+ def retrieve_with_filter(query: str, top_k: int, filters: dict = None):
306
+ results = mock_retriever.retrieve(query=query, top_k=top_k)
307
+
308
+ if filters:
309
+ # Simple filter implementation for testing
310
+ filtered = []
311
+ for r in results:
312
+ if "disease_name" in filters:
313
+ if filters["disease_name"] in r.document_title:
314
+ filtered.append(r)
315
+ else:
316
+ filtered.append(r)
317
+ return filtered
318
+ return results
319
+
320
+ mock_retriever.retrieve.side_effect = retrieve_with_filter
321
+
322
+ # Test with filter
323
+ results = mock_retriever.retrieve(
324
+ query="treatment",
325
+ top_k=10,
326
+ filters={"disease_name": "Glaucoma"}
327
+ )
328
+
329
+ # All results should match filter
330
+ assert all("Glaucoma" in r.document_title for r in results)
331
+
332
+
333
+ # ============================================================================
334
+ # Reranker Tests
335
+ # ============================================================================
336
+
337
def test_reranking_changes_order(mock_reranker, sample_chunks):
    """Reranking should reorder candidates relative to the retrieval order."""
    # Give every candidate an identical retrieval score so that any change in
    # ordering is attributable to the reranker alone.
    candidates = [
        RetrievalResult(
            id=chunk.id,
            content=chunk.content,
            document_title=chunk.document_title,
            source_url=chunk.source_url,
            section=chunk.parent_section,
            metadata=chunk.metadata,
            score=0.5,
        )
        for chunk in sample_chunks
    ]

    before = [c.id for c in candidates]

    reordered = mock_reranker.rerank(
        query="What is glaucoma?",
        documents=candidates,
        top_k=3,
    )
    after = [r.id for r in reordered]

    # The mock reranker reverses its input, so the order must differ.
    assert after != before
365
+
366
+
367
def test_top_k_respected(mock_reranker, sample_chunks):
    """The reranker must truncate its output to exactly top_k entries."""
    candidates = [
        RetrievalResult(
            id=chunk.id,
            content=chunk.content,
            document_title=chunk.document_title,
            source_url=chunk.source_url,
            section=chunk.parent_section,
            metadata=chunk.metadata,
            score=0.5,
        )
        for chunk in sample_chunks
    ]

    limit = 2
    kept = mock_reranker.rerank(
        query="treatment options",
        documents=candidates,
        top_k=limit,
    )

    # Exactly `limit` results survive reranking — no more, no fewer.
    assert len(kept) == limit
391
+
392
+
393
+ # ============================================================================
394
+ # Query Engine Tests
395
+ # ============================================================================
396
+
397
def test_full_query_pipeline(query_engine):
    """Exercise retrieval -> rerank -> generation end to end via the engine."""
    question = "What is glaucoma?"

    response = query_engine.query(question=question, include_sources=True)

    # The response object must be fully populated.
    assert isinstance(response, QueryResponse)
    assert response.answer
    assert response.query == question
    assert 0 <= response.confidence <= 1
    assert response.disclaimer
412
+
413
+
414
def test_sources_included(query_engine):
    """A query with include_sources=True must attach well-formed citations."""
    response = query_engine.query(
        question="What is glaucoma?",
        include_sources=True,
    )

    assert len(response.sources) > 0

    for src in response.sources:
        # Each citation needs a title, a URL, and a normalized relevance score.
        assert isinstance(src, SourceInfo)
        assert src.title
        assert src.url
        assert 0 <= src.relevance_score <= 1
430
+
431
+
432
+ def test_disclaimer_present(query_engine):
433
+ """Test that medical disclaimer is present."""
434
+ response = query_engine.query(
435
+ question="How is glaucoma treated?",
436
+ include_sources=True,
437
+ )
438
+
439
+ # Disclaimer should be present
440
+ assert response.disclaimer
441
+ assert len(response.disclaimer) > 0
442
+
443
+
444
+ def test_query_without_sources(query_engine):
445
+ """Test query with sources disabled."""
446
+ response = query_engine.query(
447
+ question="What is glaucoma?",
448
+ include_sources=False,
449
+ )
450
+
451
+ # Should still have answer
452
+ assert response.answer
453
+
454
+ # Sources should be empty
455
+ assert len(response.sources) == 0
456
+
457
+
458
+ def test_streaming_query(query_engine):
459
+ """Test streaming query functionality."""
460
+ chunks = list(query_engine.stream_query(
461
+ question="What is glaucoma?",
462
+ ))
463
+
464
+ # Should have received chunks
465
+ assert len(chunks) > 0
466
+
467
+ # All chunks should be strings
468
+ assert all(isinstance(chunk, str) for chunk in chunks)
469
+
470
+
471
+ def test_confidence_calculation(query_engine):
472
+ """Test confidence score calculation."""
473
+ response = query_engine.query(
474
+ question="What is glaucoma?",
475
+ include_sources=True,
476
+ )
477
+
478
+ # Confidence should be calculated
479
+ assert response.confidence is not None
480
+ assert 0 <= response.confidence <= 1
481
+
482
+ # With high-scoring retrieval results, confidence should be high
483
+ # (Our mock returns scores like 0.9, 0.8, 0.7)
484
+ assert response.confidence > 0.5
485
+
486
+
487
+ def test_empty_retrieval_results(query_engine, mock_retriever):
488
+ """Test handling of empty retrieval results."""
489
+ # Mock retriever to return empty list
490
+ mock_retriever.retrieve.return_value = []
491
+
492
+ response = query_engine.query(
493
+ question="What is xyzabc?", # Non-existent topic
494
+ include_sources=True,
495
+ )
496
+
497
+ # Should still return a response
498
+ assert response.answer
499
+ assert "couldn't find" in response.answer.lower() or "no results" in response.answer.lower()
500
+ assert len(response.sources) == 0
501
+ assert response.confidence == 0.0
502
+
503
+
504
+ # ============================================================================
505
+ # API Tests
506
+ # ============================================================================
507
+
508
@pytest.fixture
def test_client():
    """Build a FastAPI TestClient bound to the application instance."""
    # Imported lazily so collecting tests doesn't require the API stack.
    from fastapi.testclient import TestClient
    from src.api.main import app

    client = TestClient(app)
    return client
515
+
516
+
517
+ def test_health_endpoint(test_client):
518
+ """Test the health check endpoint."""
519
+ response = test_client.get("/health")
520
+
521
+ # Should return 200 or 503 depending on initialization
522
+ assert response.status_code in [200, 503]
523
+
524
+ # Should have JSON response
525
+ data = response.json()
526
+ assert "status" in data
527
+ assert "timestamp" in data
528
+
529
+
530
+ def test_root_endpoint(test_client):
531
+ """Test the root endpoint."""
532
+ response = test_client.get("/")
533
+
534
+ assert response.status_code == 200
535
+ data = response.json()
536
+
537
+ assert "name" in data
538
+ assert "version" in data
539
+ assert "endpoints" in data
540
+
541
+
542
+ def test_query_endpoint(test_client):
543
+ """Test the query endpoint."""
544
+ # Note: This will likely fail if system is not fully initialized
545
+ # In real testing, you'd mock the app_state
546
+
547
+ response = test_client.post(
548
+ "/query",
549
+ json={
550
+ "question": "What is glaucoma?",
551
+ "include_sources": True,
552
+ }
553
+ )
554
+
555
+ # Should return 200 if initialized, 503 if not
556
+ assert response.status_code in [200, 503]
557
+
558
+ if response.status_code == 200:
559
+ data = response.json()
560
+ assert "answer" in data
561
+ assert "query" in data
562
+ assert "confidence" in data
563
+ assert "disclaimer" in data
564
+
565
+
566
+ def test_query_endpoint_validation(test_client):
567
+ """Test query endpoint input validation."""
568
+ # Test with invalid input
569
+ response = test_client.post(
570
+ "/query",
571
+ json={
572
+ "question": "", # Empty question
573
+ }
574
+ )
575
+
576
+ # Should return validation error
577
+ assert response.status_code == 422 # Unprocessable Entity
578
+
579
+
580
+ def test_stats_endpoint(test_client):
581
+ """Test the stats endpoint."""
582
+ response = test_client.get("/stats")
583
+
584
+ # Should return 200 if initialized, 503 if not
585
+ assert response.status_code in [200, 503, 404]
586
+
587
+ if response.status_code == 200:
588
+ data = response.json()
589
+ assert "collection_info" in data
590
+ assert "pipeline_config" in data
591
+ assert "documents_indexed" in data
592
+
593
+
594
+ # ============================================================================
595
+ # Metadata Extractor Tests
596
+ # ============================================================================
597
+
598
+ def test_icd_code_extraction(metadata_extractor):
599
+ """Test ICD-10 code extraction."""
600
+ text = "Patient diagnosed with H40.1 (Primary open-angle glaucoma) and E11.3 (Type 2 diabetes with ophthalmic complications)."
601
+
602
+ icd_codes = metadata_extractor.extract_icd_codes(text)
603
+
604
+ assert "H40.1" in icd_codes
605
+ assert "E11.3" in icd_codes
606
+
607
+
608
+ def test_anatomical_term_extraction(metadata_extractor):
609
+ """Test anatomical term extraction."""
610
+ text = "The optic nerve and retina are affected. The cornea appears normal."
611
+
612
+ terms = metadata_extractor.extract_anatomical_terms(text)
613
+
614
+ assert "optic nerve" in terms
615
+ assert "retina" in terms
616
+ assert "cornea" in terms
617
+
618
+
619
+ def test_medication_extraction(metadata_extractor):
620
+ """Test medication extraction."""
621
+ text = "Prescribed latanoprost and timolol for IOP reduction."
622
+
623
+ medications = metadata_extractor.extract_medications(text)
624
+
625
+ assert "latanoprost" in medications or "timolol" in medications
626
+
627
+
628
+ def test_full_metadata_extraction(metadata_extractor):
629
+ """Test full metadata extraction."""
630
+ text = """
631
+ Patient with H40.1 primary open-angle glaucoma affecting the optic nerve.
632
+ Prescribed latanoprost drops. Vision loss and eye pain reported.
633
+ """
634
+
635
+ metadata = metadata_extractor.extract(text, existing_metadata={})
636
+
637
+ # Should extract various metadata
638
+ assert "icd_codes" in metadata
639
+ assert "anatomical_terms" in metadata
640
+ assert "medications" in metadata
641
+ assert "symptoms" in metadata
642
+
643
+
644
+ # ============================================================================
645
+ # Integration Tests
646
+ # ============================================================================
647
+
648
def test_end_to_end_chunk_to_query():
    """Wire chunker output through mocked retriever/reranker/LLM and query."""
    # Step 1: split the sample document into chunks.
    chunker = SemanticChunker(chunk_size=200, chunk_overlap=20)
    chunks = chunker.chunk_document(
        markdown_content=SAMPLE_MARKDOWN,
        metadata=SAMPLE_METADATA,
    )
    assert len(chunks) > 0

    # Step 2: wrap the first few chunks as retrieval hits.
    hits = [
        RetrievalResult(
            id=chunk.id,
            content=chunk.content,
            document_title=chunk.document_title,
            source_url=chunk.source_url,
            section=chunk.parent_section,
            metadata=chunk.metadata,
            score=0.8,
        )
        for chunk in chunks[:3]
    ]

    # Steps 3-5: stub out the heavyweight components.
    reranker = Mock(spec=CrossEncoderReranker)
    reranker.rerank.return_value = hits[:2]

    llm = Mock()
    llm.generate.return_value = "Glaucoma is an eye disease."

    retriever = Mock(spec=HybridRetriever)
    retriever.retrieve.return_value = hits

    # Step 6: assemble the engine around the stubs.
    engine = EyeWikiQueryEngine(
        retriever=retriever,
        reranker=reranker,
        llm_client=llm,
        max_context_tokens=4000,
        retrieval_k=20,
        rerank_k=5,
    )

    # Step 7: run a query and sanity-check the response.
    response = engine.query("What is glaucoma?")
    assert response.answer
    assert response.confidence > 0
tests/test_questions.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "q1",
4
+ "question": "What are the main symptoms of glaucoma?",
5
+ "expected_topics": [
6
+ "vision loss",
7
+ "peripheral vision",
8
+ "eye pressure",
9
+ "optic nerve damage",
10
+ "blind spots"
11
+ ],
12
+ "expected_sources": [
13
+ "Glaucoma",
14
+ "Primary Open-Angle Glaucoma"
15
+ ],
16
+ "category": "symptoms"
17
+ },
18
+ {
19
+ "id": "q2",
20
+ "question": "How is diabetic retinopathy treated?",
21
+ "expected_topics": [
22
+ "laser treatment",
23
+ "anti-VEGF",
24
+ "photocoagulation",
25
+ "vitrectomy",
26
+ "blood sugar control"
27
+ ],
28
+ "expected_sources": [
29
+ "Diabetic Retinopathy",
30
+ "Proliferative Diabetic Retinopathy"
31
+ ],
32
+ "category": "treatment"
33
+ },
34
+ {
35
+ "id": "q3",
36
+ "question": "What causes age-related macular degeneration?",
37
+ "expected_topics": [
38
+ "aging",
39
+ "macula",
40
+ "drusen",
41
+ "photoreceptor",
42
+ "central vision"
43
+ ],
44
+ "expected_sources": [
45
+ "Age-Related Macular Degeneration",
46
+ "AMD",
47
+ "Macular Degeneration"
48
+ ],
49
+ "category": "etiology"
50
+ },
51
+ {
52
+ "id": "q4",
53
+ "question": "What is the difference between open-angle and angle-closure glaucoma?",
54
+ "expected_topics": [
55
+ "drainage angle",
56
+ "trabecular meshwork",
57
+ "acute",
58
+ "chronic",
59
+ "iridotomy"
60
+ ],
61
+ "expected_sources": [
62
+ "Glaucoma",
63
+ "Primary Open-Angle Glaucoma",
64
+ "Angle-Closure Glaucoma"
65
+ ],
66
+ "category": "classification"
67
+ },
68
+ {
69
+ "id": "q5",
70
+ "question": "What are the risk factors for cataracts?",
71
+ "expected_topics": [
72
+ "age",
73
+ "diabetes",
74
+ "UV exposure",
75
+ "smoking",
76
+ "steroid"
77
+ ],
78
+ "expected_sources": [
79
+ "Cataract",
80
+ "Age-Related Cataract"
81
+ ],
82
+ "category": "risk_factors"
83
+ },
84
+ {
85
+ "id": "q6",
86
+ "question": "How is retinal detachment diagnosed?",
87
+ "expected_topics": [
88
+ "dilated eye exam",
89
+ "ophthalmoscopy",
90
+ "ultrasound",
91
+ "floaters",
92
+ "flashes"
93
+ ],
94
+ "expected_sources": [
95
+ "Retinal Detachment",
96
+ "Rhegmatogenous Retinal Detachment"
97
+ ],
98
+ "category": "diagnosis"
99
+ },
100
+ {
101
+ "id": "q7",
102
+ "question": "What medications are used to lower intraocular pressure?",
103
+ "expected_topics": [
104
+ "prostaglandin analogs",
105
+ "beta-blockers",
106
+ "alpha agonists",
107
+ "carbonic anhydrase inhibitors",
108
+ "latanoprost",
109
+ "timolol"
110
+ ],
111
+ "expected_sources": [
112
+ "Glaucoma",
113
+ "Medical Therapy for Glaucoma"
114
+ ],
115
+ "category": "pharmacology"
116
+ },
117
+ {
118
+ "id": "q8",
119
+ "question": "What is keratoconus and how is it managed?",
120
+ "expected_topics": [
121
+ "cornea",
122
+ "thinning",
123
+ "cone-shaped",
124
+ "corneal crosslinking",
125
+ "contact lenses"
126
+ ],
127
+ "expected_sources": [
128
+ "Keratoconus"
129
+ ],
130
+ "category": "corneal_disease"
131
+ },
132
+ {
133
+ "id": "q9",
134
+ "question": "What are the complications of cataract surgery?",
135
+ "expected_topics": [
136
+ "posterior capsule opacification",
137
+ "infection",
138
+ "endophthalmitis",
139
+ "cystoid macular edema",
140
+ "retinal detachment"
141
+ ],
142
+ "expected_sources": [
143
+ "Cataract Surgery",
144
+ "Phacoemulsification"
145
+ ],
146
+ "category": "complications"
147
+ },
148
+ {
149
+ "id": "q10",
150
+ "question": "How does dry eye syndrome present?",
151
+ "expected_topics": [
152
+ "burning",
153
+ "irritation",
154
+ "tear film",
155
+ "meibomian gland",
156
+ "artificial tears"
157
+ ],
158
+ "expected_sources": [
159
+ "Dry Eye",
160
+ "Dry Eye Syndrome"
161
+ ],
162
+ "category": "symptoms"
163
+ },
164
+ {
165
+ "id": "q11",
166
+ "question": "What is the pathophysiology of uveitis?",
167
+ "expected_topics": [
168
+ "inflammation",
169
+ "uvea",
170
+ "anterior",
171
+ "posterior",
172
+ "immune-mediated"
173
+ ],
174
+ "expected_sources": [
175
+ "Uveitis",
176
+ "Anterior Uveitis"
177
+ ],
178
+ "category": "pathophysiology"
179
+ },
180
+ {
181
+ "id": "q12",
182
+ "question": "What imaging modalities are used for macular disease?",
183
+ "expected_topics": [
184
+ "OCT",
185
+ "optical coherence tomography",
186
+ "fluorescein angiography",
187
+ "fundus photography",
188
+ "angiography"
189
+ ],
190
+ "expected_sources": [
191
+ "Macular Degeneration",
192
+ "OCT",
193
+ "Optical Coherence Tomography"
194
+ ],
195
+ "category": "imaging"
196
+ },
197
+ {
198
+ "id": "q13",
199
+ "question": "What is optic neuritis and what are its causes?",
200
+ "expected_topics": [
201
+ "optic nerve inflammation",
202
+ "vision loss",
203
+ "pain with eye movement",
204
+ "multiple sclerosis",
205
+ "demyelination"
206
+ ],
207
+ "expected_sources": [
208
+ "Optic Neuritis"
209
+ ],
210
+ "category": "neuro_ophthalmology"
211
+ },
212
+ {
213
+ "id": "q14",
214
+ "question": "How is proliferative diabetic retinopathy different from non-proliferative?",
215
+ "expected_topics": [
216
+ "neovascularization",
217
+ "microaneurysms",
218
+ "hemorrhages",
219
+ "vitreous hemorrhage",
220
+ "retinal ischemia"
221
+ ],
222
+ "expected_sources": [
223
+ "Diabetic Retinopathy",
224
+ "Proliferative Diabetic Retinopathy",
225
+ "Non-Proliferative Diabetic Retinopathy"
226
+ ],
227
+ "category": "classification"
228
+ },
229
+ {
230
+ "id": "q15",
231
+ "question": "What are the signs of papilledema?",
232
+ "expected_topics": [
233
+ "optic disc swelling",
234
+ "increased intracranial pressure",
235
+ "headache",
236
+ "blurred vision",
237
+ "nausea"
238
+ ],
239
+ "expected_sources": [
240
+ "Papilledema",
241
+ "Optic Disc Edema"
242
+ ],
243
+ "category": "diagnosis"
244
+ }
245
+ ]