Commit ·
702ea87
1
Parent(s): 1c61c6e
Initial deployment
Browse files- .gitignore +162 -0
- CHANGELOG.md +95 -0
- DOCKER.md +574 -0
- Makefile +94 -0
- README.md +1024 -9
- config/__init__.py +1 -0
- config/settings.py +229 -0
- data/processed/.gitkeep +0 -0
- data/raw/.gitkeep +0 -0
- data/vectorstore/.gitkeep +0 -0
- deployment_readme.md +151 -0
- docker-compose.yml +80 -0
- plan/implementation_plan.md +62 -0
- prompts/medical_disclaimer.txt +1 -0
- prompts/query_prompt.txt +21 -0
- prompts/system_prompt.txt +87 -0
- pytest.ini +26 -0
- requirements.txt +47 -0
- scripts/build_index.py +683 -0
- scripts/evaluate.py +696 -0
- scripts/run_server.py +479 -0
- scripts/scrape_eyewiki.py +278 -0
- src/__init__.py +0 -0
- src/api/__init__.py +5 -0
- src/api/gradio_ui.py +548 -0
- src/api/main.py +627 -0
- src/llm/__init__.py +0 -0
- src/llm/llm_client.py +66 -0
- src/llm/ollama_client.py +512 -0
- src/llm/openai_client.py +187 -0
- src/llm/sentence_transformer_client.py +161 -0
- src/processing/__init__.py +0 -0
- src/processing/chunker.py +423 -0
- src/processing/metadata_extractor.py +433 -0
- src/rag/__init__.py +0 -0
- src/rag/query_engine.py +537 -0
- src/rag/reranker.py +293 -0
- src/rag/retriever.py +483 -0
- src/scraper/__init__.py +0 -0
- src/scraper/eyewiki_crawler.py +489 -0
- src/vectorstore/__init__.py +0 -0
- src/vectorstore/qdrant_store.py +587 -0
- tests/README.md +172 -0
- tests/__init__.py +0 -0
- tests/conftest.py +24 -0
- tests/test_components.py +699 -0
- tests/test_questions.json +245 -0
.gitignore
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
PIPFILE.lock
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
*.manifest
|
| 31 |
+
*.spec
|
| 32 |
+
|
| 33 |
+
# Installer logs
|
| 34 |
+
pip-log.txt
|
| 35 |
+
pip-delete-this-directory.txt
|
| 36 |
+
|
| 37 |
+
# Unit test / coverage reports
|
| 38 |
+
htmlcov/
|
| 39 |
+
.tox/
|
| 40 |
+
.nox/
|
| 41 |
+
.coverage
|
| 42 |
+
.coverage.*
|
| 43 |
+
.cache
|
| 44 |
+
nosetests.xml
|
| 45 |
+
coverage.xml
|
| 46 |
+
*.cover
|
| 47 |
+
*.py,cover
|
| 48 |
+
.hypothesis/
|
| 49 |
+
.pytest_cache/
|
| 50 |
+
cover/
|
| 51 |
+
|
| 52 |
+
# Translations
|
| 53 |
+
*.mo
|
| 54 |
+
*.pot
|
| 55 |
+
|
| 56 |
+
# Django stuff:
|
| 57 |
+
*.log
|
| 58 |
+
local_settings.py
|
| 59 |
+
db.sqlite3
|
| 60 |
+
db.sqlite3-journal
|
| 61 |
+
|
| 62 |
+
# Flask stuff:
|
| 63 |
+
instance/
|
| 64 |
+
.webassets-cache
|
| 65 |
+
|
| 66 |
+
# Scrapy stuff:
|
| 67 |
+
.scrapy
|
| 68 |
+
|
| 69 |
+
# Sphinx documentation
|
| 70 |
+
docs/_build/
|
| 71 |
+
|
| 72 |
+
# PyBuilder
|
| 73 |
+
.pybuilder/
|
| 74 |
+
target/
|
| 75 |
+
|
| 76 |
+
# Jupyter Notebook
|
| 77 |
+
.ipynb_checkpoints
|
| 78 |
+
|
| 79 |
+
# IPython
|
| 80 |
+
profile_default/
|
| 81 |
+
ipython_config.py
|
| 82 |
+
|
| 83 |
+
# pyenv
|
| 84 |
+
.python-version
|
| 85 |
+
|
| 86 |
+
# pipenv
|
| 87 |
+
Pipfile.lock
|
| 88 |
+
|
| 89 |
+
# poetry
|
| 90 |
+
poetry.lock
|
| 91 |
+
|
| 92 |
+
# pdm
|
| 93 |
+
.pdm.toml
|
| 94 |
+
|
| 95 |
+
# PEP 582
|
| 96 |
+
__pypackages__/
|
| 97 |
+
|
| 98 |
+
# Celery stuff
|
| 99 |
+
celerybeat-schedule
|
| 100 |
+
celerybeat.pid
|
| 101 |
+
|
| 102 |
+
# SageMath parsed files
|
| 103 |
+
*.sage.py
|
| 104 |
+
|
| 105 |
+
# Environments
|
| 106 |
+
.env
|
| 107 |
+
.venv
|
| 108 |
+
env/
|
| 109 |
+
venv/
|
| 110 |
+
ENV/
|
| 111 |
+
env.bak/
|
| 112 |
+
venv.bak/
|
| 113 |
+
|
| 114 |
+
# Spyder project settings
|
| 115 |
+
.spyderproject
|
| 116 |
+
.spyproject
|
| 117 |
+
|
| 118 |
+
# Rope project settings
|
| 119 |
+
.ropeproject
|
| 120 |
+
|
| 121 |
+
# mkdocs documentation
|
| 122 |
+
/site
|
| 123 |
+
|
| 124 |
+
# mypy
|
| 125 |
+
.mypy_cache/
|
| 126 |
+
.dmypy.json
|
| 127 |
+
dmypy.json
|
| 128 |
+
|
| 129 |
+
# Pyre type checker
|
| 130 |
+
.pyre/
|
| 131 |
+
|
| 132 |
+
# pytype static type analyzer
|
| 133 |
+
.pytype/
|
| 134 |
+
|
| 135 |
+
# Cython debug symbols
|
| 136 |
+
cython_debug/
|
| 137 |
+
|
| 138 |
+
# IDE
|
| 139 |
+
.vscode/
|
| 140 |
+
.idea/
|
| 141 |
+
*.swp
|
| 142 |
+
*.swo
|
| 143 |
+
*~
|
| 144 |
+
.DS_Store
|
| 145 |
+
|
| 146 |
+
# Project specific
|
| 147 |
+
data/raw/*
|
| 148 |
+
!data/raw/.gitkeep
|
| 149 |
+
data/processed/*
|
| 150 |
+
!data/processed/.gitkeep
|
| 151 |
+
data/vectorstore/*
|
| 152 |
+
!data/vectorstore/.gitkeep
|
| 153 |
+
|
| 154 |
+
# Model files
|
| 155 |
+
*.bin
|
| 156 |
+
*.onnx
|
| 157 |
+
*.pt
|
| 158 |
+
*.pth
|
| 159 |
+
|
| 160 |
+
# Logs
|
| 161 |
+
logs/
|
| 162 |
+
*.log
|
CHANGELOG.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Changelog
|
| 2 |
+
|
| 3 |
+
## [2.0.0] - 2026-01-05
|
| 4 |
+
|
| 5 |
+
### Major Improvements
|
| 6 |
+
|
| 7 |
+
#### Gradio UI Enhancements
|
| 8 |
+
- **Fixed HTML rendering issue**: Changed from HTML badges to clean emoji-based confidence indicators
|
| 9 |
+
- High Confidence: ✅ (≥70%)
|
| 10 |
+
- Medium Confidence: ⚠️ (50-69%)
|
| 11 |
+
- Low Confidence: ⚡ (<50%)
|
| 12 |
+
- **Improved message formatting**: Removed raw HTML display in chat interface
|
| 13 |
+
- **Cleaner disclaimers**: Updated medical disclaimer to be more concise
|
| 14 |
+
|
| 15 |
+
#### Content Updates
|
| 16 |
+
- **Removed "educational purposes" language** across all files:
|
| 17 |
+
- Updated system prompts
|
| 18 |
+
- Updated medical disclaimers
|
| 19 |
+
- Updated README
|
| 20 |
+
- Updated UI text
|
| 21 |
+
- **Streamlined medical disclaimers**: More professional, less verbose
|
| 22 |
+
|
| 23 |
+
#### Bug Fixes
|
| 24 |
+
- **Fixed Ollama GPU support**: Configured Ollama to use RTX 5090 GPU instead of CPU
|
| 25 |
+
- Added GPU initialization script
|
| 26 |
+
- Set proper CUDA environment variables
|
| 27 |
+
- Verified VRAM usage (4.79 GB on GPU)
|
| 28 |
+
- Performance improvement: ~10-50x faster inference
|
| 29 |
+
|
| 30 |
+
- **Fixed Qdrant API compatibility**: Updated to qdrant-client v1.16.1 API
|
| 31 |
+
- Changed from `client.search()` to `client.query_points()`
|
| 32 |
+
- Added `using="dense"` parameter for named vectors
|
| 33 |
+
- Fixed both search and hybrid_search methods
|
| 34 |
+
|
| 35 |
+
- **Fixed Pydantic validation errors**:
|
| 36 |
+
- Removed `ge=0.0` constraint from `RetrievalResult.score` (cross-encoder scores can be negative)
|
| 37 |
+
- Removed `ge=0.0, le=1.0` constraints from `SourceInfo.relevance_score`
|
| 38 |
+
|
| 39 |
+
- **Fixed QdrantStoreManager initialization**:
|
| 40 |
+
- Changed `vector_size` → `embedding_dim`
|
| 41 |
+
- Changed `qdrant_path` → `path`
|
| 42 |
+
- Use `embedding_client.embedding_dim` instead of non-existent settings attribute
|
| 43 |
+
|
| 44 |
+
- **Added missing Settings attributes**:
|
| 45 |
+
- `ollama_timeout` (default: 30)
|
| 46 |
+
- `reranker_model` (default: "cross-encoder/ms-marco-MiniLM-L-6-v2")
|
| 47 |
+
- `max_context_tokens` (default: 4096)
|
| 48 |
+
|
| 49 |
+
- **Fixed OllamaClient embedding model verification**:
|
| 50 |
+
- Skip embedding model verification when `embedding_model=None`
|
| 51 |
+
- Prevents false errors when using SentenceTransformerClient for embeddings
|
| 52 |
+
|
| 53 |
+
#### Code Cleanup
|
| 54 |
+
- Removed unnecessary comments and annotations
|
| 55 |
+
- Cleaned up fix-related comments
|
| 56 |
+
- Improved code documentation
|
| 57 |
+
- Removed redundant validation constraints
|
| 58 |
+
|
| 59 |
+
### Technical Details
|
| 60 |
+
|
| 61 |
+
#### Performance
|
| 62 |
+
- **GPU Acceleration**: Full GPU support for Ollama (RTX 5090)
|
| 63 |
+
- **Model Loading**: 4.79 GB VRAM usage confirmed
|
| 64 |
+
- **Faster Inference**: Significant speedup from CPU to GPU
|
| 65 |
+
|
| 66 |
+
#### API Changes
|
| 67 |
+
- Qdrant API updated to v1.16.1 syntax
|
| 68 |
+
- Improved error handling for cross-encoder scores
|
| 69 |
+
- Better validation for unbounded reranker scores
|
| 70 |
+
|
| 71 |
+
#### Configuration
|
| 72 |
+
- New environment variables for Ollama GPU support:
|
| 73 |
+
- `CUDA_VISIBLE_DEVICES=0`
|
| 74 |
+
- `OLLAMA_NUM_PARALLEL=1`
|
| 75 |
+
- `OLLAMA_MAX_LOADED_MODELS=1`
|
| 76 |
+
|
| 77 |
+
### Files Modified
|
| 78 |
+
- `src/api/gradio_ui.py` - UI improvements and HTML rendering fix
|
| 79 |
+
- `src/api/main.py` - Fixed initialization parameters
|
| 80 |
+
- `src/rag/query_engine.py` - Updated disclaimers and validation
|
| 81 |
+
- `src/rag/retriever.py` - Removed score constraints
|
| 82 |
+
- `src/vectorstore/qdrant_store.py` - Updated Qdrant API calls
|
| 83 |
+
- `src/llm/ollama_client.py` - Fixed embedding model handling
|
| 84 |
+
- `config/settings.py` - Added missing configuration fields
|
| 85 |
+
- `prompts/medical_disclaimer.txt` - Removed educational language
|
| 86 |
+
- `prompts/system_prompt.txt` - Streamlined instructions
|
| 87 |
+
- `README.md` - Updated disclaimers and documentation
|
| 88 |
+
|
| 89 |
+
### Breaking Changes
|
| 90 |
+
- None - all changes are backward compatible
|
| 91 |
+
|
| 92 |
+
### Upgrade Notes
|
| 93 |
+
1. Restart Ollama with GPU support using provided script
|
| 94 |
+
2. Clear Python cache if experiencing import issues
|
| 95 |
+
3. Verify GPU usage with `curl -s http://localhost:11434/api/ps | python3 -m json.tool`
|
DOCKER.md
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Docker Deployment Guide
|
| 2 |
+
|
| 3 |
+
Complete guide for deploying EyeWiki RAG using Docker.
|
| 4 |
+
|
| 5 |
+
## 📋 Table of Contents
|
| 6 |
+
- [Prerequisites](#prerequisites)
|
| 7 |
+
- [Quick Start](#quick-start)
|
| 8 |
+
- [Architecture](#architecture)
|
| 9 |
+
- [Configuration](#configuration)
|
| 10 |
+
- [Operations](#operations)
|
| 11 |
+
- [Troubleshooting](#troubleshooting)
|
| 12 |
+
- [Production](#production)
|
| 13 |
+
|
| 14 |
+
## Prerequisites
|
| 15 |
+
|
| 16 |
+
### Required Software
|
| 17 |
+
- **Docker** 20.10+ ([Install Docker](https://docs.docker.com/get-docker/))
|
| 18 |
+
- **Docker Compose** 2.0+ ([Install Compose](https://docs.docker.com/compose/install/))
|
| 19 |
+
- **Ollama** running on host ([Install Ollama](https://ollama.ai/download))
|
| 20 |
+
|
| 21 |
+
### System Requirements
|
| 22 |
+
- 8GB+ RAM allocated to Docker
|
| 23 |
+
- 20GB+ disk space
|
| 24 |
+
- CPU: 4+ cores recommended
|
| 25 |
+
- GPU: Optional, for faster processing
|
| 26 |
+
|
| 27 |
+
### Verify Installation
|
| 28 |
+
```bash
|
| 29 |
+
docker --version
|
| 30 |
+
docker-compose --version
|
| 31 |
+
ollama --version
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Quick Start
|
| 35 |
+
|
| 36 |
+
### 1. Prepare Ollama (Host Machine)
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
# Start Ollama service
|
| 40 |
+
ollama serve
|
| 41 |
+
|
| 42 |
+
# Pull required models
|
| 43 |
+
ollama pull nomic-embed-text # ~270MB
|
| 44 |
+
ollama pull mistral # ~4.1GB
|
| 45 |
+
|
| 46 |
+
# Verify models
|
| 47 |
+
ollama list
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### 2. Build and Start Services
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
# Clone repository
|
| 54 |
+
git clone <repo-url>
|
| 55 |
+
cd eyewiki-rag
|
| 56 |
+
|
| 57 |
+
# Build images
|
| 58 |
+
docker-compose build
|
| 59 |
+
|
| 60 |
+
# Start services
|
| 61 |
+
docker-compose up -d
|
| 62 |
+
|
| 63 |
+
# Check status
|
| 64 |
+
docker-compose ps
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### 3. Verify Services
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
# Check API health
|
| 71 |
+
curl http://localhost:8000/health
|
| 72 |
+
|
| 73 |
+
# Check Qdrant
|
| 74 |
+
curl http://localhost:6333/
|
| 75 |
+
|
| 76 |
+
# View logs
|
| 77 |
+
docker-compose logs -f
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### 4. Access Services
|
| 81 |
+
|
| 82 |
+
- **API**: http://localhost:8000
|
| 83 |
+
- **Gradio UI**: http://localhost:8000/ui
|
| 84 |
+
- **API Docs**: http://localhost:8000/docs
|
| 85 |
+
- **Qdrant Dashboard**: http://localhost:6333/dashboard
|
| 86 |
+
|
| 87 |
+
## Architecture
|
| 88 |
+
|
| 89 |
+
### Container Network
|
| 90 |
+
|
| 91 |
+
```
|
| 92 |
+
┌─────────────────────────────────────────────┐
|
| 93 |
+
│ Host Machine │
|
| 94 |
+
│ ┌──────────────────────────────────────┐ │
|
| 95 |
+
│ │ Ollama (GPU Access) │ │
|
| 96 |
+
│ │ - Port: 11434 │ │
|
| 97 |
+
│ │ - Models: mistral, nomic-embed │ │
|
| 98 |
+
│ └────────────┬─────────────────────────┘ │
|
| 99 |
+
│ │ │
|
| 100 |
+
│ ┌────────────▼─────────────────────────┐ │
|
| 101 |
+
│ │ Docker Network │ │
|
| 102 |
+
│ │ ┌─────────────────────────────────┐ │ │
|
| 103 |
+
│ │ │ eyewiki-rag (API Server) │ │ │
|
| 104 |
+
│ │ │ - Port: 8000 │ │ │
|
| 105 |
+
│ │ │ - Connects to Ollama via │ │ │
|
| 106 |
+
│ │ │ host.docker.internal │ │ │
|
| 107 |
+
│ │ └─────────────┬───────────────────┘ │ │
|
| 108 |
+
│ │ │ │ │
|
| 109 |
+
│ │ ┌─────────────▼───────────────────┐ │ │
|
| 110 |
+
│ │ │ qdrant (Vector DB) │ │ │
|
| 111 |
+
│ │ │ - Ports: 6333, 6334 │ │ │
|
| 112 |
+
│ │ │ - Persistent volume │ │ │
|
| 113 |
+
│ │ └─────────────────────────────────┘ │ │
|
| 114 |
+
│ └──────────────────────────────────────┘ │
|
| 115 |
+
└─────────────────────────────────────────────┘
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### Data Flow
|
| 119 |
+
|
| 120 |
+
1. **User Request** → API Container (port 8000)
|
| 121 |
+
2. **Query Engine** → Qdrant Container (vector search)
|
| 122 |
+
3. **Embedding** → Ollama on Host (via host.docker.internal)
|
| 123 |
+
4. **LLM Generation** → Ollama on Host
|
| 124 |
+
5. **Response** → User
|
| 125 |
+
|
| 126 |
+
### Volumes
|
| 127 |
+
|
| 128 |
+
| Volume | Path | Purpose |
|
| 129 |
+
|--------|------|---------|
|
| 130 |
+
| `./data/raw` | `/app/data/raw` | Scraped content |
|
| 131 |
+
| `./data/processed` | `/app/data/processed` | Chunked documents |
|
| 132 |
+
| `qdrant_data` | `/app/data/qdrant` | Vector database |
|
| 133 |
+
| `./prompts` | `/app/prompts` | Customizable prompts |
|
| 134 |
+
|
| 135 |
+
## Configuration
|
| 136 |
+
|
| 137 |
+
### Environment Variables
|
| 138 |
+
|
| 139 |
+
Edit `docker-compose.yml`:
|
| 140 |
+
|
| 141 |
+
```yaml
|
| 142 |
+
environment:
|
| 143 |
+
# Ollama Configuration
|
| 144 |
+
- OLLAMA_BASE_URL=http://host.docker.internal:11434
|
| 145 |
+
- LLM_MODEL=mistral
|
| 146 |
+
- EMBEDDING_MODEL=nomic-embed-text
|
| 147 |
+
- OLLAMA_TIMEOUT=120
|
| 148 |
+
|
| 149 |
+
# Qdrant Configuration
|
| 150 |
+
- QDRANT_HOST=qdrant
|
| 151 |
+
- QDRANT_PORT=6333
|
| 152 |
+
- QDRANT_COLLECTION_NAME=eyewiki_rag
|
| 153 |
+
- QDRANT_PATH=/app/data/qdrant
|
| 154 |
+
|
| 155 |
+
# Processing Configuration
|
| 156 |
+
- CHUNK_SIZE=512
|
| 157 |
+
- CHUNK_OVERLAP=50
|
| 158 |
+
- MIN_CHUNK_SIZE=100
|
| 159 |
+
- MAX_CONTEXT_TOKENS=4000
|
| 160 |
+
|
| 161 |
+
# Retrieval Configuration
|
| 162 |
+
- RETRIEVAL_K=20
|
| 163 |
+
- RERANK_K=5
|
| 164 |
+
- RERANKER_MODEL=ms-marco-MiniLM-L-6-v2
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
### Custom Prompts
|
| 168 |
+
|
| 169 |
+
Edit files in `./prompts/` directory (mounted into container):
|
| 170 |
+
- `system_prompt.txt`
|
| 171 |
+
- `query_prompt.txt`
|
| 172 |
+
- `medical_disclaimer.txt`
|
| 173 |
+
|
| 174 |
+
Changes take effect on container restart.
|
| 175 |
+
|
| 176 |
+
### Resource Limits
|
| 177 |
+
|
| 178 |
+
Add to service in `docker-compose.yml`:
|
| 179 |
+
|
| 180 |
+
```yaml
|
| 181 |
+
deploy:
|
| 182 |
+
resources:
|
| 183 |
+
limits:
|
| 184 |
+
cpus: '4'
|
| 185 |
+
memory: 8G
|
| 186 |
+
reservations:
|
| 187 |
+
cpus: '2'
|
| 188 |
+
memory: 4G
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
## Operations
|
| 192 |
+
|
| 193 |
+
### Makefile Commands
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
# Service Management
|
| 197 |
+
make up # Start services
|
| 198 |
+
make down # Stop services
|
| 199 |
+
make restart # Restart services
|
| 200 |
+
make ps # Show status
|
| 201 |
+
make logs # View all logs
|
| 202 |
+
make logs-api # View API logs only
|
| 203 |
+
make logs-qdrant # View Qdrant logs only
|
| 204 |
+
|
| 205 |
+
# Health & Monitoring
|
| 206 |
+
make health # Check service health
|
| 207 |
+
make stats # Show resource usage
|
| 208 |
+
|
| 209 |
+
# Data Operations
|
| 210 |
+
make scrape # Run scraper
|
| 211 |
+
make build-index # Build vector index
|
| 212 |
+
make evaluate # Run evaluation
|
| 213 |
+
make test # Run tests
|
| 214 |
+
|
| 215 |
+
# Maintenance
|
| 216 |
+
make clean # Remove containers & volumes
|
| 217 |
+
make rebuild # Clean rebuild
|
| 218 |
+
make backup-qdrant # Backup vector DB
|
| 219 |
+
make restore-qdrant # Restore from backup
|
| 220 |
+
|
| 221 |
+
# Development
|
| 222 |
+
make exec-api # Bash into API container
|
| 223 |
+
make exec-qdrant # Shell into Qdrant container
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
### Manual Commands
|
| 227 |
+
|
| 228 |
+
#### Start Services
|
| 229 |
+
```bash
|
| 230 |
+
docker-compose up -d
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
#### Stop Services
|
| 234 |
+
```bash
|
| 235 |
+
docker-compose down
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
#### View Logs
|
| 239 |
+
```bash
|
| 240 |
+
# All services
|
| 241 |
+
docker-compose logs -f
|
| 242 |
+
|
| 243 |
+
# Specific service
|
| 244 |
+
docker-compose logs -f eyewiki-rag
|
| 245 |
+
docker-compose logs -f qdrant
|
| 246 |
+
|
| 247 |
+
# Last N lines
|
| 248 |
+
docker-compose logs --tail=100 -f
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
#### Execute Commands in Container
|
| 252 |
+
```bash
|
| 253 |
+
# Run scraper
|
| 254 |
+
docker-compose exec eyewiki-rag \
|
| 255 |
+
python scripts/scrape_eyewiki.py --max-pages 100
|
| 256 |
+
|
| 257 |
+
# Build index
|
| 258 |
+
docker-compose exec eyewiki-rag \
|
| 259 |
+
python scripts/build_index.py --index-vectors
|
| 260 |
+
|
| 261 |
+
# Run evaluation
|
| 262 |
+
docker-compose exec eyewiki-rag \
|
| 263 |
+
python scripts/evaluate.py -v
|
| 264 |
+
|
| 265 |
+
# Run tests
|
| 266 |
+
docker-compose exec eyewiki-rag pytest tests/ -v
|
| 267 |
+
|
| 268 |
+
# Interactive shell
|
| 269 |
+
docker-compose exec eyewiki-rag bash
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
#### Inspect Services
|
| 273 |
+
```bash
|
| 274 |
+
# Container status
|
| 275 |
+
docker-compose ps
|
| 276 |
+
|
| 277 |
+
# Resource usage
|
| 278 |
+
docker stats eyewiki-rag-api eyewiki-qdrant
|
| 279 |
+
|
| 280 |
+
# Network info
|
| 281 |
+
docker network inspect eyewiki-network
|
| 282 |
+
|
| 283 |
+
# Volume info
|
| 284 |
+
docker volume ls
|
| 285 |
+
docker volume inspect eyewiki-rag_qdrant_data
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
### Data Management
|
| 289 |
+
|
| 290 |
+
#### Backup Qdrant
|
| 291 |
+
```bash
|
| 292 |
+
# Using Makefile
|
| 293 |
+
make backup-qdrant
|
| 294 |
+
|
| 295 |
+
# Manual
|
| 296 |
+
docker-compose exec qdrant tar -czf /tmp/backup.tar.gz /qdrant/storage
|
| 297 |
+
docker cp eyewiki-qdrant:/tmp/backup.tar.gz ./backups/qdrant-$(date +%Y%m%d).tar.gz
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
#### Restore Qdrant
|
| 301 |
+
```bash
|
| 302 |
+
# Stop services
|
| 303 |
+
docker-compose down
|
| 304 |
+
|
| 305 |
+
# Restore backup
|
| 306 |
+
docker-compose up -d qdrant
|
| 307 |
+
docker cp ./backups/qdrant-20241209.tar.gz eyewiki-qdrant:/tmp/backup.tar.gz
|
| 308 |
+
docker-compose exec qdrant tar -xzf /tmp/backup.tar.gz -C /
|
| 309 |
+
|
| 310 |
+
# Restart all services
|
| 311 |
+
docker-compose up -d
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
#### Clear Data
|
| 315 |
+
```bash
|
| 316 |
+
# Remove all data and volumes
|
| 317 |
+
docker-compose down -v
|
| 318 |
+
|
| 319 |
+
# Remove only processed data
|
| 320 |
+
rm -rf data/processed/*
|
| 321 |
+
rm -rf data/qdrant/*
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+
## Troubleshooting
|
| 325 |
+
|
| 326 |
+
### Cannot Connect to Ollama
|
| 327 |
+
|
| 328 |
+
**Symptoms:**
|
| 329 |
+
- `ConnectionError: Failed to connect to Ollama`
|
| 330 |
+
- 503 errors on API startup
|
| 331 |
+
|
| 332 |
+
**Solutions:**
|
| 333 |
+
|
| 334 |
+
1. **Verify Ollama is running:**
|
| 335 |
+
```bash
|
| 336 |
+
curl http://localhost:11434/api/tags
|
| 337 |
+
```
|
| 338 |
+
|
| 339 |
+
2. **On Linux, add to docker-compose.yml:**
|
| 340 |
+
```yaml
|
| 341 |
+
extra_hosts:
|
| 342 |
+
- "host.docker.internal:host-gateway"
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
3. **Use host IP instead:**
|
| 346 |
+
```bash
|
| 347 |
+
# Get host IP
|
| 348 |
+
ip addr show docker0 | grep inet
|
| 349 |
+
|
| 350 |
+
# Update OLLAMA_BASE_URL
|
| 351 |
+
OLLAMA_BASE_URL=http://172.17.0.1:11434
|
| 352 |
+
```
|
| 353 |
+
|
| 354 |
+
### Qdrant Permission Errors
|
| 355 |
+
|
| 356 |
+
**Symptoms:**
|
| 357 |
+
- Permission denied errors in Qdrant logs
|
| 358 |
+
- Cannot write to volume
|
| 359 |
+
|
| 360 |
+
**Solution:**
|
| 361 |
+
```bash
|
| 362 |
+
# Fix permissions
|
| 363 |
+
sudo chown -R 1000:1000 data/qdrant/
|
| 364 |
+
|
| 365 |
+
# Or recreate volume
|
| 366 |
+
docker-compose down -v
|
| 367 |
+
docker-compose up -d
|
| 368 |
+
```
|
| 369 |
+
|
| 370 |
+
### Out of Memory
|
| 371 |
+
|
| 372 |
+
**Symptoms:**
|
| 373 |
+
- Container killed (exit code 137)
|
| 374 |
+
- Slow performance
|
| 375 |
+
|
| 376 |
+
**Solutions:**
|
| 377 |
+
|
| 378 |
+
1. **Increase Docker memory:**
|
| 379 |
+
- Docker Desktop: Settings → Resources → Memory → 8GB+
|
| 380 |
+
|
| 381 |
+
2. **Add resource limits:**
|
| 382 |
+
```yaml
|
| 383 |
+
deploy:
|
| 384 |
+
resources:
|
| 385 |
+
limits:
|
| 386 |
+
memory: 8G
|
| 387 |
+
```
|
| 388 |
+
|
| 389 |
+
3. **Use smaller models:**
|
| 390 |
+
```bash
|
| 391 |
+
ollama pull llama3.2:3b # Instead of mistral
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
### Port Already in Use
|
| 395 |
+
|
| 396 |
+
**Symptoms:**
|
| 397 |
+
- `Bind for 0.0.0.0:8000 failed: port is already allocated`
|
| 398 |
+
|
| 399 |
+
**Solutions:**
|
| 400 |
+
|
| 401 |
+
1. **Find and kill process:**
|
| 402 |
+
```bash
|
| 403 |
+
lsof -i :8000
|
| 404 |
+
kill <PID>
|
| 405 |
+
```
|
| 406 |
+
|
| 407 |
+
2. **Change port in docker-compose.yml:**
|
| 408 |
+
```yaml
|
| 409 |
+
ports:
|
| 410 |
+
- "8080:8000" # Use 8080 instead
|
| 411 |
+
```
|
| 412 |
+
|
| 413 |
+
### Slow Performance
|
| 414 |
+
|
| 415 |
+
**Solutions:**
|
| 416 |
+
|
| 417 |
+
1. **Reduce batch sizes:**
|
| 418 |
+
```yaml
|
| 419 |
+
environment:
|
| 420 |
+
- RETRIEVAL_K=10 # Instead of 20
|
| 421 |
+
- RERANK_K=3 # Instead of 5
|
| 422 |
+
```
|
| 423 |
+
|
| 424 |
+
2. **Allocate more resources:**
|
| 425 |
+
```yaml
|
| 426 |
+
deploy:
|
| 427 |
+
resources:
|
| 428 |
+
limits:
|
| 429 |
+
cpus: '4'
|
| 430 |
+
memory: 8G
|
| 431 |
+
```
|
| 432 |
+
|
| 433 |
+
3. **Use GPU for Ollama** (on host)
|
| 434 |
+
|
| 435 |
+
## Production
|
| 436 |
+
|
| 437 |
+
### Production Configuration
|
| 438 |
+
|
| 439 |
+
Create `docker-compose.prod.yml`:
|
| 440 |
+
|
| 441 |
+
```yaml
|
| 442 |
+
version: '3.8'
|
| 443 |
+
|
| 444 |
+
services:
|
| 445 |
+
eyewiki-rag:
|
| 446 |
+
restart: always
|
| 447 |
+
deploy:
|
| 448 |
+
resources:
|
| 449 |
+
limits:
|
| 450 |
+
cpus: '4'
|
| 451 |
+
memory: 8G
|
| 452 |
+
reservations:
|
| 453 |
+
cpus: '2'
|
| 454 |
+
memory: 4G
|
| 455 |
+
logging:
|
| 456 |
+
driver: "json-file"
|
| 457 |
+
options:
|
| 458 |
+
max-size: "100m"
|
| 459 |
+
max-file: "5"
|
| 460 |
+
environment:
|
| 461 |
+
- LOG_LEVEL=WARNING
|
| 462 |
+
healthcheck:
|
| 463 |
+
interval: 30s
|
| 464 |
+
timeout: 10s
|
| 465 |
+
retries: 3
|
| 466 |
+
start_period: 60s
|
| 467 |
+
|
| 468 |
+
qdrant:
|
| 469 |
+
restart: always
|
| 470 |
+
deploy:
|
| 471 |
+
resources:
|
| 472 |
+
limits:
|
| 473 |
+
cpus: '2'
|
| 474 |
+
memory: 4G
|
| 475 |
+
logging:
|
| 476 |
+
driver: "json-file"
|
| 477 |
+
options:
|
| 478 |
+
max-size: "50m"
|
| 479 |
+
max-file: "3"
|
| 480 |
+
```
|
| 481 |
+
|
| 482 |
+
### Start Production
|
| 483 |
+
|
| 484 |
+
```bash
|
| 485 |
+
# Use production config
|
| 486 |
+
docker-compose -f docker-compose.yml -f docker-compose.prod.yml up -d
|
| 487 |
+
|
| 488 |
+
# Or use Makefile
|
| 489 |
+
make prod
|
| 490 |
+
```
|
| 491 |
+
|
| 492 |
+
### Monitoring
|
| 493 |
+
|
| 494 |
+
```bash
|
| 495 |
+
# Watch container status
|
| 496 |
+
watch docker-compose ps
|
| 497 |
+
|
| 498 |
+
# Monitor resources
|
| 499 |
+
docker stats --no-stream eyewiki-rag-api eyewiki-qdrant
|
| 500 |
+
|
| 501 |
+
# Check logs
|
| 502 |
+
docker-compose logs --tail=100 -f
|
| 503 |
+
|
| 504 |
+
# Test health endpoints
|
| 505 |
+
watch curl -s http://localhost:8000/health
|
| 506 |
+
```
|
| 507 |
+
|
| 508 |
+
### Backup Strategy
|
| 509 |
+
|
| 510 |
+
```bash
|
| 511 |
+
# Daily backup script (add to cron)
|
| 512 |
+
#!/bin/bash
|
| 513 |
+
BACKUP_DIR="/backups/eyewiki-rag"
|
| 514 |
+
DATE=$(date +%Y%m%d)
|
| 515 |
+
|
| 516 |
+
# Backup Qdrant
|
| 517 |
+
make backup-qdrant
|
| 518 |
+
|
| 519 |
+
# Backup configuration
|
| 520 |
+
tar -czf $BACKUP_DIR/config-$DATE.tar.gz \
|
| 521 |
+
docker-compose.yml prompts/ data/raw/
|
| 522 |
+
|
| 523 |
+
# Keep last 7 days
|
| 524 |
+
find $BACKUP_DIR -name "*.tar.gz" -mtime +7 -delete
|
| 525 |
+
```
|
| 526 |
+
|
| 527 |
+
### Update Strategy
|
| 528 |
+
|
| 529 |
+
```bash
|
| 530 |
+
# 1. Backup current state
|
| 531 |
+
make backup-qdrant
|
| 532 |
+
|
| 533 |
+
# 2. Pull latest code
|
| 534 |
+
git pull origin main
|
| 535 |
+
|
| 536 |
+
# 3. Rebuild images
|
| 537 |
+
docker-compose build --no-cache
|
| 538 |
+
|
| 539 |
+
# 4. Restart services with zero downtime
|
| 540 |
+
docker-compose up -d --no-deps --build eyewiki-rag
|
| 541 |
+
|
| 542 |
+
# 5. Verify health
|
| 543 |
+
make health
|
| 544 |
+
```
|
| 545 |
+
|
| 546 |
+
## Best Practices
|
| 547 |
+
|
| 548 |
+
### Security
|
| 549 |
+
- Use environment files for secrets
|
| 550 |
+
- Don't expose unnecessary ports
|
| 551 |
+
- Run as non-root user (add to Dockerfile)
|
| 552 |
+
- Keep base images updated
|
| 553 |
+
- Use Docker secrets for production
|
| 554 |
+
|
| 555 |
+
### Performance
|
| 556 |
+
- Allocate sufficient memory (8GB+)
|
| 557 |
+
- Use volume for Qdrant data
|
| 558 |
+
- Monitor resource usage
|
| 559 |
+
- Scale horizontally if needed
|
| 560 |
+
|
| 561 |
+
### Maintenance
|
| 562 |
+
- Regular backups
|
| 563 |
+
- Monitor logs for errors
|
| 564 |
+
- Update dependencies
|
| 565 |
+
- Prune unused images/volumes
|
| 566 |
+
|
| 567 |
+
### Development
|
| 568 |
+
- Use `docker-compose.override.yml` for local config
|
| 569 |
+
- Mount source code as volume for hot reload
|
| 570 |
+
- Keep production and development configs separate
|
| 571 |
+
|
| 572 |
+
---
|
| 573 |
+
|
| 574 |
+
For more information, see the main [README.md](README.md).
|
Makefile
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# EyeWiki RAG System - Makefile for Docker operations

.PHONY: help build up down restart logs ps clean test

help: ## Show this help message
	@echo "EyeWiki RAG System - Docker Commands"
	@echo ""
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'

build: ## Build Docker images
	docker-compose build

up: ## Start all services
	docker-compose up -d
	@echo "Services starting..."
	@echo "API: http://localhost:8000"
	@echo "Gradio UI: http://localhost:8000/ui"
	@echo "API Docs: http://localhost:8000/docs"
	@echo "Qdrant: http://localhost:6333/dashboard"

down: ## Stop all services
	docker-compose down

restart: ## Restart all services
	docker-compose restart

logs: ## View logs from all services
	docker-compose logs -f

logs-api: ## View API logs only
	docker-compose logs -f eyewiki-rag

logs-qdrant: ## View Qdrant logs only
	docker-compose logs -f qdrant

ps: ## Show running containers
	docker-compose ps

health: ## Check health of services
	@echo "Checking Qdrant..."
	@curl -s http://localhost:6333/healthz || echo "Qdrant not healthy"
	@printf "\nChecking API...\n"
	@curl -s http://localhost:8000/health | python -m json.tool || echo "API not healthy"

exec-api: ## Execute bash in API container
	docker-compose exec eyewiki-rag bash

exec-qdrant: ## Execute bash in Qdrant container
	docker-compose exec qdrant /bin/sh

clean: ## Remove all containers, volumes, and images
	docker-compose down -v
	# Compose v1 names the image with an underscore, v2 with a hyphen;
	# try both so clean works regardless of Compose version.
	docker rmi eyewiki-rag_eyewiki-rag 2>/dev/null || true
	docker rmi eyewiki-rag-eyewiki-rag 2>/dev/null || true

clean-volumes: ## Remove only volumes (keeps images)
	docker-compose down -v

rebuild: clean build up ## Clean rebuild and start

test: ## Run tests in container
	docker-compose exec eyewiki-rag pytest tests/ -v

scrape: ## Run scraper in container (example: make scrape ARGS="--max-pages 50")
	docker-compose exec eyewiki-rag python scripts/scrape_eyewiki.py $(ARGS)

build-index: ## Build vector index in container
	docker-compose exec eyewiki-rag python scripts/build_index.py --index-vectors

evaluate: ## Run evaluation in container
	docker-compose exec eyewiki-rag python scripts/evaluate.py

stats: ## Show system statistics
	@echo "Docker stats:"
	docker stats --no-stream eyewiki-rag-api eyewiki-qdrant
	@printf "\nDisk usage:\n"
	docker system df

backup-qdrant: ## Backup Qdrant data
	# Ensure the destination exists; docker cp fails otherwise on a fresh checkout.
	mkdir -p ./backups
	docker-compose exec qdrant tar -czf /tmp/qdrant-backup.tar.gz /qdrant/storage
	docker cp eyewiki-qdrant:/tmp/qdrant-backup.tar.gz ./backups/qdrant-backup-$$(date +%Y%m%d-%H%M%S).tar.gz
	@echo "Backup saved to ./backups/"

restore-qdrant: ## Restore Qdrant data (usage: make restore-qdrant BACKUP=backups/file.tar.gz)
	docker cp $(BACKUP) eyewiki-qdrant:/tmp/qdrant-backup.tar.gz
	docker-compose exec qdrant tar -xzf /tmp/qdrant-backup.tar.gz -C /

prod: ## Start in production mode (detached, with restart policy)
	docker-compose up -d --remove-orphans
	@echo "Production services started"

dev: ## Start in development mode (with logs)
	docker-compose up

.DEFAULT_GOAL := help
|
README.md
CHANGED
|
@@ -1,11 +1,1026 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
| 1 |
+
# 🏥 EyeWiki RAG System
|
| 2 |
+
|
| 3 |
+
[](https://www.python.org/downloads/)
|
| 4 |
+
[](https://opensource.org/licenses/MIT)
|
| 5 |
+
|
| 6 |
+
A production-ready Retrieval-Augmented Generation (RAG) system for ophthalmology knowledge, powered by EyeWiki content and local LLMs.
|
| 7 |
+
|
| 8 |
+
## 📋 Overview
|
| 9 |
+
|
| 10 |
+
The EyeWiki RAG system provides intelligent question-answering capabilities for ophthalmology topics by combining:
|
| 11 |
+
- **Web scraping** of authoritative EyeWiki content
|
| 12 |
+
- **Semantic search** with hybrid retrieval (dense + sparse)
|
| 13 |
+
- **Cross-encoder reranking** for precision
|
| 14 |
+
- **Local LLM inference** via Ollama for privacy and control
|
| 15 |
+
- **RESTful API** with interactive web UI
|
| 16 |
+
|
| 17 |
+
Built for medical professionals, researchers, and students seeking quick, evidence-based answers to ophthalmology questions.
|
| 18 |
+
|
| 19 |
+
## ✨ Features
|
| 20 |
+
|
| 21 |
+
### Core Capabilities
|
| 22 |
+
- 🔍 **Intelligent Retrieval**: Hybrid search combining dense embeddings and sparse BM25
|
| 23 |
+
- 🎯 **Precise Reranking**: Cross-encoder models for relevance scoring
|
| 24 |
+
- 🏠 **Local Processing**: All data stays on your machine (HIPAA-friendly)
|
| 25 |
+
- 📚 **Source Citations**: Every answer includes EyeWiki article references
|
| 26 |
+
- ⚡ **Streaming Responses**: Real-time answer generation
|
| 27 |
+
- 🌐 **Web Interface**: Beautiful Gradio UI for easy interaction
|
| 28 |
+
- 🔌 **REST API**: Programmatic access with FastAPI
|
| 29 |
+
- ✅ **Comprehensive Testing**: 25+ pytest tests with mocking
|
| 30 |
+
|
| 31 |
+
### Technical Highlights
|
| 32 |
+
- **Polite Web Scraping**: Respects robots.txt and implements rate limiting
|
| 33 |
+
- **Smart Chunking**: Hierarchical markdown splitting with section awareness
|
| 34 |
+
- **Metadata Extraction**: Automatic ICD-10 codes, anatomical terms, medications
|
| 35 |
+
- **Vector Store**: Local Qdrant with payload indexing
|
| 36 |
+
- **Medical Disclaimer**: Automatic inclusion in all responses
|
| 37 |
+
|
| 38 |
+
## 🏗️ Architecture
|
| 39 |
+
|
| 40 |
+
```
|
| 41 |
+
┌─────────────────────────────────────────────────────────────────┐
|
| 42 |
+
│ User Interface │
|
| 43 |
+
│ ┌──────────────────┐ ┌─────────────────────────┐ │
|
| 44 |
+
│ │ Gradio Web UI │ │ REST API (FastAPI) │ │
|
| 45 |
+
│ │ - Chat interface│ │ - /query │ │
|
| 46 |
+
│ │ - Examples │ │ - /query/stream │ │
|
| 47 |
+
│ │ - Source display│ │ - /health, /stats │ │
|
| 48 |
+
│ └────────┬─────────┘ └───────────┬─────────────┘ │
|
| 49 |
+
└───────────┼────────────────────────────────────┼────────────────┘
|
| 50 |
+
│ │
|
| 51 |
+
└────────────────┬───────────────────┘
|
| 52 |
+
▼
|
| 53 |
+
┌────────────────────────────────────────┐
|
| 54 |
+
│ Query Engine (Orchestrator) │
|
| 55 |
+
│ - Context assembly │
|
| 56 |
+
│ - Prompt formatting │
|
| 57 |
+
│ - Source diversity │
|
| 58 |
+
└──┬────────────────────────┬────────────┘
|
| 59 |
+
│ │
|
| 60 |
+
┌───────▼──────┐ ┌──────▼──────────┐
|
| 61 |
+
│ Retriever │ │ Ollama Client │
|
| 62 |
+
│ (Hybrid) │ │ - LLM (Mistral)│
|
| 63 |
+
│ Dense: 0.7 │ │ │
|
| 64 |
+
│ Sparse: 0.3 │ │ Sentence- │
|
| 65 |
+
└──┬───────────┘ │ Transformers │
|
| 66 |
+
│ │ - Embeddings │
|
| 67 |
+
│ │ (all-mpnet) │
|
| 68 |
+
│ └─────────────────┘
|
| 69 |
+
│
|
| 70 |
+
┌───────▼──────────┐
|
| 71 |
+
│ Reranker │
|
| 72 |
+
│ (CrossEncoder) │
|
| 73 |
+
│ ms-marco-MiniLM │
|
| 74 |
+
└──┬───────────────┘
|
| 75 |
+
│
|
| 76 |
+
▼
|
| 77 |
+
┌────────────────────────────────────┐
|
| 78 |
+
│ Qdrant Vector Store │
|
| 79 |
+
│ - Dense vectors (768-dim) │
|
| 80 |
+
│ - Sparse vectors (BM25) │
|
| 81 |
+
│ - Metadata filtering │
|
| 82 |
+
│ - Local storage │
|
| 83 |
+
└────────────────────────────────────┘
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
**Data Flow:**
|
| 87 |
+
1. **Scraping** → EyeWiki → Raw Markdown
|
| 88 |
+
2. **Processing** → Chunking → Metadata Extraction → JSON
|
| 89 |
+
3. **Indexing** → Embeddings → Vector Store
|
| 90 |
+
4. **Query** → Retrieval → Reranking → LLM → Response
|
| 91 |
+
|
| 92 |
+
## 📁 Project Structure
|
| 93 |
+
|
| 94 |
+
```
|
| 95 |
+
eyewiki-rag/
|
| 96 |
+
├── src/
|
| 97 |
+
│ ├── scraper/ # Web scraping (crawl4ai)
|
| 98 |
+
│ │ └── eyewiki_crawler.py
|
| 99 |
+
│ ├── processing/ # Document processing
|
| 100 |
+
│ │ ├── chunker.py # Semantic chunking
|
| 101 |
+
│ │ └── metadata_extractor.py # Medical metadata
|
| 102 |
+
│ ├── vectorstore/ # Vector database
|
| 103 |
+
│ │ └── qdrant_store.py
|
| 104 |
+
│ ├── rag/ # RAG components
|
| 105 |
+
│ │ ├── retriever.py # Hybrid retrieval
|
| 106 |
+
│ │ ├── reranker.py # Cross-encoder reranking
|
| 107 |
+
│ │ └── query_engine.py # Main orchestrator
|
| 108 |
+
│ ├── llm/ # LLM integration
|
| 109 |
+
│ │ ├── ollama_client.py # Ollama for LLM generation
|
| 110 |
+
│ │ └── sentence_transformer_client.py # Stable embeddings
|
| 111 |
+
│ ├── api/ # FastAPI server
|
| 112 |
+
│ │ ├── main.py # API endpoints
|
| 113 |
+
│ │ └── gradio_ui.py # Web interface
|
| 114 |
+
│ └── config/ # Configuration
|
| 115 |
+
│ └── settings.py
|
| 116 |
+
├── prompts/ # Customizable prompts
|
| 117 |
+
│ ├── system_prompt.txt
|
| 118 |
+
│ ├── query_prompt.txt
|
| 119 |
+
│ └── medical_disclaimer.txt
|
| 120 |
+
├── scripts/ # Utility scripts
|
| 121 |
+
│ ├── scrape_eyewiki.py # Web scraping
|
| 122 |
+
│ ├── build_index.py # Index building
|
| 123 |
+
│ ├── run_server.py # Server startup
|
| 124 |
+
│ └── evaluate.py # System evaluation
|
| 125 |
+
├── tests/ # Comprehensive test suite
|
| 126 |
+
│ ├── test_components.py # Component tests
|
| 127 |
+
│ ├── test_questions.json # Evaluation questions
|
| 128 |
+
│ └── conftest.py
|
| 129 |
+
├── data/ # Data storage (gitignored)
|
| 130 |
+
│ ├── raw/ # Scraped content
|
| 131 |
+
│ ├── processed/ # Chunked documents
|
| 132 |
+
│ └── qdrant/ # Vector database
|
| 133 |
+
└── requirements.txt # Python dependencies
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
## 📋 Prerequisites
|
| 137 |
+
|
| 138 |
+
### Required
|
| 139 |
+
- **Python 3.10+** (tested on 3.10, 3.11)
|
| 140 |
+
- **Ollama** (for local LLM text generation only)
|
| 141 |
+
- Install: https://ollama.ai/download
|
| 142 |
+
- Note: Embeddings now use sentence-transformers (more stable)
|
| 143 |
+
- **8GB+ RAM** (16GB recommended for larger datasets)
|
| 144 |
+
- **10GB+ disk space** (for models and vector store)
|
| 145 |
+
|
| 146 |
+
### Optional
|
| 147 |
+
- **CUDA-capable GPU** (for faster embedding generation with sentence-transformers)
|
| 148 |
+
- **Docker** (if running Qdrant in container)
|
| 149 |
+
|
| 150 |
+
### System Requirements by Component
|
| 151 |
+
| Component | RAM | CPU | GPU | Disk |
|
| 152 |
+
|-----------|-----|-----|-----|------|
|
| 153 |
+
| Scraping | 2GB | 2 cores | No | 500MB |
|
| 154 |
+
| Processing | 4GB | 4 cores | No | 2GB |
|
| 155 |
+
| Indexing | 8GB | 4 cores | Optional | 5GB |
|
| 156 |
+
| API Server | 4GB | 2 cores | Optional | 100MB |
|
| 157 |
+
|
| 158 |
+
## 🚀 Quick Start
|
| 159 |
+
|
| 160 |
+
### Step 1: Installation
|
| 161 |
+
|
| 162 |
+
```bash
|
| 163 |
+
# Clone repository
|
| 164 |
+
git clone <repository-url>
|
| 165 |
+
cd eyewiki-rag
|
| 166 |
+
|
| 167 |
+
# Create virtual environment
|
| 168 |
+
python -m venv venv
|
| 169 |
+
source venv/bin/activate # Windows: venv\Scripts\activate
|
| 170 |
+
|
| 171 |
+
# Install Python dependencies
|
| 172 |
+
pip install -r requirements.txt
|
| 173 |
+
|
| 174 |
+
# Install system dependencies for Playwright (Linux/WSL only)
|
| 175 |
+
# This installs required shared libraries (libnss3, libnspr4, etc.)
|
| 176 |
+
python -m playwright install-deps
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
### Step 2: Install Ollama and LLM Model
|
| 180 |
+
|
| 181 |
+
```bash
|
| 182 |
+
# Install Ollama from https://ollama.ai/download
|
| 183 |
+
# Then pull required LLM model:
|
| 184 |
+
|
| 185 |
+
ollama pull mistral # LLM model (4.1GB)
|
| 186 |
+
# or use smaller alternative:
|
| 187 |
+
ollama pull llama3.2:3b # Smaller LLM (2GB)
|
| 188 |
+
|
| 189 |
+
# Note: Embedding model (sentence-transformers) will be auto-downloaded
|
| 190 |
+
# when you first run build_index.py (no Ollama needed for embeddings!)
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
### Step 3: Scrape EyeWiki
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
# Quick test (50 pages, ~5 minutes)
|
| 197 |
+
python scripts/scrape_eyewiki.py --max-pages 50
|
| 198 |
+
|
| 199 |
+
# Full crawl (1000+ pages, ~2 hours)
|
| 200 |
+
python scripts/scrape_eyewiki.py --max-pages 1000
|
| 201 |
+
|
| 202 |
+
# Resume from checkpoint
|
| 203 |
+
python scripts/scrape_eyewiki.py --resume
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
**Output:** `data/raw/*.json` (markdown files with metadata)
|
| 207 |
+
|
| 208 |
+
### Step 4: Build Vector Index
|
| 209 |
+
|
| 210 |
+
```bash
|
| 211 |
+
# Process documents and build vector index
|
| 212 |
+
python scripts/build_index.py --index-vectors
|
| 213 |
+
|
| 214 |
+
# This will:
|
| 215 |
+
# 1. Chunk documents (data/processed/)
|
| 216 |
+
# 2. Extract metadata
|
| 217 |
+
# 3. Generate embeddings using sentence-transformers (all-mpnet-base-v2)
|
| 218 |
+
# 4. Build Qdrant index (data/vectorstore/ by default; configurable via QDRANT_PATH)
|
| 219 |
+
|
| 220 |
+
# Optional: Use different embedding model
|
| 221 |
+
python scripts/build_index.py --index-vectors --embedding-model "BAAI/bge-base-en-v1.5"
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
**Time:** ~10-30 minutes depending on dataset size
|
| 225 |
+
**Note:** First run will download the embedding model (~400MB for all-mpnet-base-v2)
|
| 226 |
+
|
| 227 |
+
### Step 5: Start Server
|
| 228 |
+
|
| 229 |
+
```bash
|
| 230 |
+
# Run with pre-flight checks
|
| 231 |
+
python scripts/run_server.py
|
| 232 |
+
|
| 233 |
+
# Development mode with hot reload
|
| 234 |
+
python scripts/run_server.py --reload
|
| 235 |
+
|
| 236 |
+
# Custom port
|
| 237 |
+
python scripts/run_server.py --port 8080
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
### Step 6: Access the System
|
| 241 |
+
|
| 242 |
+
**Web Interface:** http://localhost:8000/ui
|
| 243 |
+
- Beautiful chat interface
|
| 244 |
+
- Example questions
|
| 245 |
+
- Source citations
|
| 246 |
+
- Settings sidebar
|
| 247 |
+
|
| 248 |
+
**API Docs:** http://localhost:8000/docs
|
| 249 |
+
- Swagger UI
|
| 250 |
+
- Interactive testing
|
| 251 |
+
- Full API documentation
|
| 252 |
+
|
| 253 |
+
**Health Check:** http://localhost:8000/health
|
| 254 |
+
|
| 255 |
+
### Example Query
|
| 256 |
+
|
| 257 |
+
```bash
|
| 258 |
+
curl -X POST http://localhost:8000/query \
|
| 259 |
+
-H "Content-Type: application/json" \
|
| 260 |
+
-d '{
|
| 261 |
+
"question": "What are the symptoms of glaucoma?",
|
| 262 |
+
"include_sources": true
|
| 263 |
+
}'
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
## 🐳 Docker Deployment
|
| 267 |
+
|
| 268 |
+
### Prerequisites
|
| 269 |
+
- **Docker** and **Docker Compose** installed
|
| 270 |
+
- **Ollama** running on host machine (for GPU access)
|
| 271 |
+
- **8GB+ RAM** allocated to Docker
|
| 272 |
+
|
| 273 |
+
### Quick Start with Docker
|
| 274 |
+
|
| 275 |
+
```bash
|
| 276 |
+
# 1. Ensure Ollama is running on host
|
| 277 |
+
ollama serve
|
| 278 |
+
|
| 279 |
+
# 2. Pull required models (on host)
|
| 280 |
+
ollama pull nomic-embed-text  # optional — embeddings default to sentence-transformers; only needed if you switch embeddings back to Ollama
|
| 281 |
+
ollama pull mistral
|
| 282 |
+
|
| 283 |
+
# 3. Build and start services
|
| 284 |
+
docker-compose up -d
|
| 285 |
+
|
| 286 |
+
# 4. Check status
|
| 287 |
+
docker-compose ps
|
| 288 |
+
|
| 289 |
+
# 5. View logs
|
| 290 |
+
docker-compose logs -f eyewiki-rag
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
**Access:**
|
| 294 |
+
- API: http://localhost:8000
|
| 295 |
+
- Gradio UI: http://localhost:8000/ui
|
| 296 |
+
- API Docs: http://localhost:8000/docs
|
| 297 |
+
- Qdrant Dashboard: http://localhost:6333/dashboard
|
| 298 |
+
|
| 299 |
+
### Using Makefile Commands
|
| 300 |
+
|
| 301 |
+
```bash
|
| 302 |
+
# Start services
|
| 303 |
+
make up
|
| 304 |
+
|
| 305 |
+
# View logs
|
| 306 |
+
make logs
|
| 307 |
+
|
| 308 |
+
# Check health
|
| 309 |
+
make health
|
| 310 |
+
|
| 311 |
+
# Run scraper in container
|
| 312 |
+
make scrape ARGS="--max-pages 50"
|
| 313 |
+
|
| 314 |
+
# Build index
|
| 315 |
+
make build-index
|
| 316 |
+
|
| 317 |
+
# Run evaluation
|
| 318 |
+
make evaluate
|
| 319 |
+
|
| 320 |
+
# Stop services
|
| 321 |
+
make down
|
| 322 |
+
|
| 323 |
+
# Clean everything
|
| 324 |
+
make clean
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
### Docker Compose Services
|
| 328 |
+
|
| 329 |
+
**eyewiki-rag** (API Server)
|
| 330 |
+
- Built from Dockerfile
|
| 331 |
+
- Exposes port 8000
|
| 332 |
+
- Connects to Ollama on host via `host.docker.internal`
|
| 333 |
+
- Connects to Qdrant container
|
| 334 |
+
- Mounts data volumes for persistence
|
| 335 |
+
|
| 336 |
+
**qdrant** (Vector Database)
|
| 337 |
+
- Official Qdrant image
|
| 338 |
+
- Exposes ports 6333 (REST) and 6334 (gRPC)
|
| 339 |
+
- Persistent volume for vector storage
|
| 340 |
+
- Health checks enabled
|
| 341 |
+
|
| 342 |
+
### Volume Management
|
| 343 |
+
|
| 344 |
+
**Persistent volumes:**
|
| 345 |
+
- `./data/raw` - Scraped content
|
| 346 |
+
- `./data/processed` - Chunked documents
|
| 347 |
+
- `qdrant_data` - Vector database (Docker volume)
|
| 348 |
+
- `./prompts` - Customizable prompts
|
| 349 |
+
|
| 350 |
+
**Backup Qdrant data:**
|
| 351 |
+
```bash
|
| 352 |
+
make backup-qdrant
|
| 353 |
+
# Saves to ./backups/qdrant-backup-YYYYMMDD-HHMMSS.tar.gz
|
| 354 |
+
```
|
| 355 |
+
|
| 356 |
+
**Restore Qdrant data:**
|
| 357 |
+
```bash
|
| 358 |
+
make restore-qdrant BACKUP=backups/qdrant-backup-20241209-120000.tar.gz
|
| 359 |
+
```
|
| 360 |
+
|
| 361 |
+
### Configuration via Environment Variables
|
| 362 |
+
|
| 363 |
+
Edit `docker-compose.yml` to customize:
|
| 364 |
+
|
| 365 |
+
```yaml
|
| 366 |
+
environment:
|
| 367 |
+
# Ollama settings
|
| 368 |
+
- OLLAMA_BASE_URL=http://host.docker.internal:11434
|
| 369 |
+
- LLM_MODEL=mistral
|
| 370 |
+
- EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2  # project default; use nomic-embed-text only if Ollama embeddings are re-enabled
|
| 371 |
+
|
| 372 |
+
# Qdrant settings
|
| 373 |
+
- QDRANT_HOST=qdrant
|
| 374 |
+
- QDRANT_PORT=6333
|
| 375 |
+
|
| 376 |
+
# Processing settings
|
| 377 |
+
- CHUNK_SIZE=512
|
| 378 |
+
- RETRIEVAL_K=20
|
| 379 |
+
- RERANK_K=5
|
| 380 |
+
```
|
| 381 |
+
|
| 382 |
+
### Running Scripts in Container
|
| 383 |
+
|
| 384 |
+
```bash
|
| 385 |
+
# Scrape EyeWiki
|
| 386 |
+
docker-compose exec eyewiki-rag python scripts/scrape_eyewiki.py --max-pages 100
|
| 387 |
+
|
| 388 |
+
# Build index
|
| 389 |
+
docker-compose exec eyewiki-rag python scripts/build_index.py --index-vectors
|
| 390 |
+
|
| 391 |
+
# Run evaluation
|
| 392 |
+
docker-compose exec eyewiki-rag python scripts/evaluate.py
|
| 393 |
+
|
| 394 |
+
# Run tests
|
| 395 |
+
docker-compose exec eyewiki-rag pytest tests/ -v
|
| 396 |
+
```
|
| 397 |
+
|
| 398 |
+
### Production Deployment
|
| 399 |
+
|
| 400 |
+
```bash
|
| 401 |
+
# Start in production mode
|
| 402 |
+
make prod
|
| 403 |
+
|
| 404 |
+
# Or manually:
|
| 405 |
+
docker-compose up -d --remove-orphans
|
| 406 |
+
|
| 407 |
+
# Monitor with healthchecks
|
| 408 |
+
watch docker-compose ps
|
| 409 |
+
|
| 410 |
+
# View metrics
|
| 411 |
+
docker stats eyewiki-rag-api eyewiki-qdrant
|
| 412 |
+
```
|
| 413 |
+
|
| 414 |
+
### Troubleshooting Docker
|
| 415 |
+
|
| 416 |
+
**Problem:** Cannot connect to Ollama
|
| 417 |
+
|
| 418 |
+
**Solution:**
|
| 419 |
+
```bash
|
| 420 |
+
# Linux: Use host.docker.internal
|
| 421 |
+
# If not working, use host IP:
|
| 422 |
+
docker network inspect eyewiki-network
|
| 423 |
+
# Update OLLAMA_BASE_URL to http://<host-ip>:11434
|
| 424 |
+
|
| 425 |
+
# Or on Linux, add to docker-compose.yml:
|
| 426 |
+
extra_hosts:
|
| 427 |
+
- "host.docker.internal:host-gateway"
|
| 428 |
+
```
|
| 429 |
+
|
| 430 |
+
**Problem:** Qdrant volume permission issues
|
| 431 |
+
|
| 432 |
+
**Solution:**
|
| 433 |
+
```bash
|
| 434 |
+
# Fix permissions
|
| 435 |
+
sudo chown -R 1000:1000 data/qdrant/
|
| 436 |
+
```
|
| 437 |
+
|
| 438 |
+
**Problem:** Out of memory
|
| 439 |
+
|
| 440 |
+
**Solution:**
|
| 441 |
+
```bash
|
| 442 |
+
# Increase Docker memory limit in Docker Desktop
|
| 443 |
+
# Or in docker-compose.yml, add:
|
| 444 |
+
deploy:
|
| 445 |
+
resources:
|
| 446 |
+
limits:
|
| 447 |
+
memory: 8G
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
### Docker Image Sizes
|
| 451 |
+
|
| 452 |
+
| Image | Size | Purpose |
|
| 453 |
+
|-------|------|---------|
|
| 454 |
+
| eyewiki-rag | ~2.5GB | API server with dependencies |
|
| 455 |
+
| qdrant/qdrant | ~200MB | Vector database |
|
| 456 |
+
| **Total** | ~2.7GB | Both services |
|
| 457 |
+
|
| 458 |
+
**Note:** Ollama models (~4-5GB) run on host for GPU access.
|
| 459 |
+
|
| 460 |
+
## ⚙️ Configuration
|
| 461 |
+
|
| 462 |
+
Configuration via `config/settings.py` (uses pydantic-settings):
|
| 463 |
+
|
| 464 |
+
| Parameter | Default | Description |
|
| 465 |
+
|-----------|---------|-------------|
|
| 466 |
+
| **LLM Settings** |
|
| 467 |
+
| `llm_model` | `mistral` | Ollama LLM model name |
|
| 468 |
+
| `ollama_base_url` | `http://localhost:11434` | Ollama API URL |
|
| 469 |
+
| `llm_temperature` | `0.7` | LLM sampling temperature |
|
| 470 |
+
| `llm_max_tokens` | `2048` | Max tokens for LLM response |
|
| 471 |
+
| **Embedding Settings** |
|
| 472 |
+
| `embedding_model` | `all-mpnet-base-v2` | Sentence-transformers model |
|
| 473 |
+
| **Vector Store** |
|
| 474 |
+
| `qdrant_collection_name` | `eyewiki_rag` | Collection name |
|
| 475 |
+
| `qdrant_path` | `./data/vectorstore` | Local storage path |
|
| 476 |
+
| `qdrant_url` | `None` | Remote Qdrant URL (optional) |
|
| 477 |
+
| **Chunking** |
|
| 478 |
+
| `chunk_size` | `512` | Max tokens per chunk |
|
| 479 |
+
| `chunk_overlap` | `50` | Overlap between chunks |
|
| 480 |
+
| `min_chunk_size` | `100` | Minimum chunk size |
|
| 481 |
+
| **Retrieval** |
|
| 482 |
+
| `top_k` | `10` | Initial retrieval count |
|
| 483 |
+
| `rerank_top_k` | `5` | After reranking |
|
| 484 |
+
| `similarity_threshold` | `0.7` | Minimum similarity score |
|
| 485 |
+
| **Scraper** |
|
| 486 |
+
| `scraper_delay` | `1.0` | Delay between requests (seconds) |
|
| 487 |
+
| `scraper_timeout` | `30` | Request timeout (seconds) |
|
| 488 |
+
|
| 489 |
+
### Environment Variables
|
| 490 |
+
|
| 491 |
+
Create `.env` file to override defaults (see `.env.example`):
|
| 492 |
+
|
| 493 |
+
```env
|
| 494 |
+
# Ollama Configuration (for LLM only)
|
| 495 |
+
OLLAMA_BASE_URL=http://localhost:11434
|
| 496 |
+
LLM_MODEL=mistral
|
| 497 |
+
LLM_TEMPERATURE=0.7
|
| 498 |
+
LLM_MAX_TOKENS=2048
|
| 499 |
+
|
| 500 |
+
# Embedding Configuration (sentence-transformers)
|
| 501 |
+
EMBEDDING_MODEL=sentence-transformers/all-mpnet-base-v2
|
| 502 |
+
|
| 503 |
+
# Qdrant Vector Store
|
| 504 |
+
QDRANT_COLLECTION_NAME=eyewiki_rag
|
| 505 |
+
QDRANT_PATH=./data/vectorstore
|
| 506 |
+
# QDRANT_URL=http://localhost:6333 # For remote Qdrant
|
| 507 |
+
# QDRANT_API_KEY=your-key # For Qdrant Cloud
|
| 508 |
+
|
| 509 |
+
# Document Processing
|
| 510 |
+
CHUNK_SIZE=512
|
| 511 |
+
CHUNK_OVERLAP=50
|
| 512 |
+
MIN_CHUNK_SIZE=100
|
| 513 |
+
|
| 514 |
+
# RAG Retrieval
|
| 515 |
+
TOP_K=10
|
| 516 |
+
RERANK_TOP_K=5
|
| 517 |
+
SIMILARITY_THRESHOLD=0.7
|
| 518 |
+
|
| 519 |
+
# Web Scraper
|
| 520 |
+
SCRAPER_DELAY=1.0
|
| 521 |
+
SCRAPER_TIMEOUT=30
|
| 522 |
+
|
| 523 |
+
# API Server
|
| 524 |
+
API_HOST=0.0.0.0
|
| 525 |
+
API_PORT=8000
|
| 526 |
+
API_WORKERS=4
|
| 527 |
+
|
| 528 |
+
# Gradio UI
|
| 529 |
+
GRADIO_HOST=0.0.0.0
|
| 530 |
+
GRADIO_PORT=7860
|
| 531 |
+
GRADIO_SHARE=false
|
| 532 |
+
|
| 533 |
+
# Data Paths
|
| 534 |
+
DATA_RAW_PATH=./data/raw
|
| 535 |
+
DATA_PROCESSED_PATH=./data/processed
|
| 536 |
+
|
| 537 |
+
# Logging
|
| 538 |
+
LOG_LEVEL=INFO
|
| 539 |
+
LOG_FILE=logs/eyewiki_rag.log
|
| 540 |
+
```
|
| 541 |
+
|
| 542 |
+
### Customizing Prompts
|
| 543 |
+
|
| 544 |
+
Edit files in `prompts/` directory:
|
| 545 |
+
- `system_prompt.txt` - System instructions for LLM
|
| 546 |
+
- `query_prompt.txt` - Query template with `{context}` and `{question}` placeholders
|
| 547 |
+
- `medical_disclaimer.txt` - Medical disclaimer text
|
| 548 |
+
|
| 549 |
+
## 📡 API Documentation
|
| 550 |
+
|
| 551 |
+
### Endpoints
|
| 552 |
+
|
| 553 |
+
#### `GET /`
|
| 554 |
+
Root endpoint with API information
|
| 555 |
+
|
| 556 |
+
#### `GET /health`
|
| 557 |
+
Health check endpoint
|
| 558 |
+
|
| 559 |
+
**Response:**
|
| 560 |
+
```json
|
| 561 |
+
{
|
| 562 |
+
"status": "healthy",
|
| 563 |
+
"ollama": {"status": "healthy", "models": {...}},
|
| 564 |
+
"qdrant": {"status": "healthy", "vectors_count": 1234},
|
| 565 |
+
"query_engine": {"status": "initialized"},
|
| 566 |
+
"timestamp": 1702134567.89
|
| 567 |
+
}
|
| 568 |
+
```
|
| 569 |
+
|
| 570 |
+
#### `POST /query`
|
| 571 |
+
Main query endpoint
|
| 572 |
+
|
| 573 |
+
**Request:**
|
| 574 |
+
```json
|
| 575 |
+
{
|
| 576 |
+
"question": "What is glaucoma?",
|
| 577 |
+
"include_sources": true,
|
| 578 |
+
"filters": {"disease_name": "Glaucoma"} // optional
|
| 579 |
+
}
|
| 580 |
+
```
|
| 581 |
+
|
| 582 |
+
**Response:**
|
| 583 |
+
```json
|
| 584 |
+
{
|
| 585 |
+
"answer": "Glaucoma is a group of eye diseases...",
|
| 586 |
+
"sources": [
|
| 587 |
+
{
|
| 588 |
+
"title": "Primary Open-Angle Glaucoma",
|
| 589 |
+
"url": "https://eyewiki.aao.org/...",
|
| 590 |
+
"section": "Overview",
|
| 591 |
+
"relevance_score": 0.89
|
| 592 |
+
}
|
| 593 |
+
],
|
| 594 |
+
"confidence": 0.85,
|
| 595 |
+
"disclaimer": "Medical disclaimer text...",
|
| 596 |
+
"query": "What is glaucoma?"
|
| 597 |
+
}
|
| 598 |
+
```
|
| 599 |
+
|
| 600 |
+
#### `POST /query/stream`
|
| 601 |
+
Streaming query with Server-Sent Events
|
| 602 |
+
|
| 603 |
+
**Request:**
|
| 604 |
+
```json
|
| 605 |
+
{
|
| 606 |
+
"question": "What is glaucoma?",
|
| 607 |
+
"filters": {} // optional
|
| 608 |
+
}
|
| 609 |
+
```
|
| 610 |
+
|
| 611 |
+
**Response:** SSE stream
|
| 612 |
+
```
|
| 613 |
+
data: Glaucoma
|
| 614 |
+
data: is
|
| 615 |
+
data: a group of eye diseases...
|
| 616 |
+
```
|
| 617 |
+
|
| 618 |
+
#### `GET /stats`
|
| 619 |
+
Index and pipeline statistics
|
| 620 |
+
|
| 621 |
+
**Response:**
|
| 622 |
+
```json
|
| 623 |
+
{
|
| 624 |
+
"collection_info": {
|
| 625 |
+
"name": "eyewiki_rag",
|
| 626 |
+
"vectors_count": 1234
|
| 627 |
+
},
|
| 628 |
+
"pipeline_config": {
|
| 629 |
+
"retrieval_k": 20,
|
| 630 |
+
"rerank_k": 5,
|
| 631 |
+
"llm_model": "mistral"
|
| 632 |
+
},
|
| 633 |
+
"documents_indexed": 1234,
|
| 634 |
+
"timestamp": 1702134567.89
|
| 635 |
+
}
|
| 636 |
+
```
|
| 637 |
+
|
| 638 |
+
### Python Client Example
|
| 639 |
+
|
| 640 |
+
```python
|
| 641 |
+
import requests
|
| 642 |
+
|
| 643 |
+
# Query the API
|
| 644 |
+
response = requests.post(
|
| 645 |
+
"http://localhost:8000/query",
|
| 646 |
+
json={
|
| 647 |
+
"question": "What causes diabetic retinopathy?",
|
| 648 |
+
"include_sources": True
|
| 649 |
+
}
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
result = response.json()
|
| 653 |
+
print(f"Answer: {result['answer']}")
|
| 654 |
+
print(f"Confidence: {result['confidence']:.2%}")
|
| 655 |
+
print(f"Sources: {len(result['sources'])}")
|
| 656 |
+
```
|
| 657 |
+
|
| 658 |
+
### Streaming Example
|
| 659 |
+
|
| 660 |
+
```python
|
| 661 |
+
import requests
|
| 662 |
+
|
| 663 |
+
response = requests.post(
|
| 664 |
+
"http://localhost:8000/query/stream",
|
| 665 |
+
json={"question": "What is glaucoma?"},
|
| 666 |
+
stream=True
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
for line in response.iter_lines():
|
| 670 |
+
if line.startswith(b"data: "):
|
| 671 |
+
chunk = line[6:].decode()
|
| 672 |
+
print(chunk, end="", flush=True)
|
| 673 |
+
```
|
| 674 |
+
|
| 675 |
+
## 🧪 Development
|
| 676 |
+
|
| 677 |
+
### Running Tests
|
| 678 |
+
|
| 679 |
+
```bash
|
| 680 |
+
# Run all tests
|
| 681 |
+
pytest
|
| 682 |
+
|
| 683 |
+
# Run with coverage
|
| 684 |
+
pytest --cov=src --cov-report=html
|
| 685 |
+
|
| 686 |
+
# Run specific test file
|
| 687 |
+
pytest tests/test_components.py -v
|
| 688 |
+
|
| 689 |
+
# Run specific test
|
| 690 |
+
pytest tests/test_components.py::test_chunk_respects_headers -v
|
| 691 |
+
|
| 692 |
+
# Run by marker
|
| 693 |
+
pytest -m unit # Fast unit tests
|
| 694 |
+
pytest -m api # API tests
|
| 695 |
+
```
|
| 696 |
+
|
| 697 |
+
### Code Quality
|
| 698 |
+
|
| 699 |
+
```bash
|
| 700 |
+
# Format code
|
| 701 |
+
black src/ scripts/ tests/
|
| 702 |
+
isort src/ scripts/ tests/
|
| 703 |
+
|
| 704 |
+
# Lint
|
| 705 |
+
flake8 src/
|
| 706 |
+
pylint src/
|
| 707 |
+
|
| 708 |
+
# Type checking
|
| 709 |
+
mypy src/
|
| 710 |
+
```
|
| 711 |
+
|
| 712 |
+
### Evaluation
|
| 713 |
+
|
| 714 |
+
Run system evaluation on test questions:
|
| 715 |
+
|
| 716 |
+
```bash
|
| 717 |
+
# Run evaluation
|
| 718 |
+
python scripts/evaluate.py
|
| 719 |
+
|
| 720 |
+
# With custom questions
|
| 721 |
+
python scripts/evaluate.py --questions tests/custom_questions.json
|
| 722 |
+
|
| 723 |
+
# Save results
|
| 724 |
+
python scripts/evaluate.py --output results/eval.json
|
| 725 |
+
|
| 726 |
+
# Verbose output
|
| 727 |
+
python scripts/evaluate.py -v
|
| 728 |
+
```
|
| 729 |
+
|
| 730 |
+
**Metrics:**
|
| 731 |
+
- Retrieval Recall
|
| 732 |
+
- Answer Relevance
|
| 733 |
+
- Citation Precision/Recall/F1
|
| 734 |
+
- Performance by category
|
| 735 |
+
|
| 736 |
+
## 🔧 Troubleshooting
|
| 737 |
+
|
| 738 |
+
### Ollama Issues
|
| 739 |
+
|
| 740 |
+
**Problem:** "Connection refused" to Ollama
|
| 741 |
+
|
| 742 |
+
**Solution:**
|
| 743 |
+
```bash
|
| 744 |
+
# Check if Ollama is running
|
| 745 |
+
curl http://localhost:11434/api/tags
|
| 746 |
+
|
| 747 |
+
# Start Ollama
|
| 748 |
+
ollama serve
|
| 749 |
+
|
| 750 |
+
# Verify models are installed
|
| 751 |
+
ollama list
|
| 752 |
+
```
|
| 753 |
+
|
| 754 |
+
**Problem:** "Model not found"
|
| 755 |
+
|
| 756 |
+
**Solution:**
|
| 757 |
+
```bash
|
| 758 |
+
# Pull required models
|
| 759 |
+
ollama pull nomic-embed-text
|
| 760 |
+
ollama pull mistral
|
| 761 |
+
|
| 762 |
+
# List available models
|
| 763 |
+
ollama list
|
| 764 |
+
```
|
| 765 |
+
|
| 766 |
+
### Vector Store Issues
|
| 767 |
+
|
| 768 |
+
**Problem:** "Collection not found"
|
| 769 |
+
|
| 770 |
+
**Solution:**
|
| 771 |
+
```bash
|
| 772 |
+
# Rebuild the index
|
| 773 |
+
python scripts/build_index.py --index-vectors --recreate-collection
|
| 774 |
+
|
| 775 |
+
# Check Qdrant data directory
|
| 776 |
+
ls -la data/qdrant/
|
| 777 |
+
```
|
| 778 |
+
|
| 779 |
+
**Problem:** "Out of memory during indexing"
|
| 780 |
+
|
| 781 |
+
**Solution:**
|
| 782 |
+
```bash
|
| 783 |
+
# Use smaller batch size
|
| 784 |
+
python scripts/build_index.py --index-vectors --embedding-batch-size 16
|
| 785 |
+
|
| 786 |
+
# Or process in stages
|
| 787 |
+
python scripts/build_index.py # Process only (no indexing)
|
| 788 |
+
python scripts/build_index.py --index-only # Index separately
|
| 789 |
+
```
|
| 790 |
+
|
| 791 |
+
### Scraping Issues
|
| 792 |
+
|
| 793 |
+
**Problem:** "Rate limited by EyeWiki"
|
| 794 |
+
|
| 795 |
+
**Solution:**
|
| 796 |
+
```bash
|
| 797 |
+
# Increase delay between requests
|
| 798 |
+
python scripts/scrape_eyewiki.py --delay 5.0
|
| 799 |
+
|
| 800 |
+
# Resume from checkpoint if interrupted
|
| 801 |
+
python scripts/scrape_eyewiki.py --resume
|
| 802 |
+
```
|
| 803 |
+
|
| 804 |
+
**Problem:** "Timeout during scraping"
|
| 805 |
+
|
| 806 |
+
**Solution:**
|
| 807 |
+
```bash
|
| 808 |
+
# Increase timeout
|
| 809 |
+
python scripts/scrape_eyewiki.py --timeout 60
|
| 810 |
+
```
|
| 811 |
+
|
| 812 |
+
**Problem:** "error while loading shared libraries: libnspr4.so" or browser crashes
|
| 813 |
+
|
| 814 |
+
**Solution:**
|
| 815 |
+
```bash
|
| 816 |
+
# Install Playwright system dependencies (Linux/WSL)
|
| 817 |
+
python -m playwright install-deps
|
| 818 |
+
|
| 819 |
+
# Or manually install required libraries
|
| 820 |
+
sudo apt-get update && sudo apt-get install -y \
|
| 821 |
+
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 \
|
| 822 |
+
libcups2 libdrm2 libdbus-1-3 libxkbcommon0 \
|
| 823 |
+
libatspi2.0-0 libxcomposite1 libxdamage1 \
|
| 824 |
+
libxfixes3 libxrandr2 libgbm1 libasound2
|
| 825 |
+
```
|
| 826 |
+
|
| 827 |
+
**Problem:** "Executable doesn't exist" - Chromium browser not found
|
| 828 |
+
|
| 829 |
+
**Solution:**
|
| 830 |
+
```bash
|
| 831 |
+
# Install Playwright browsers
|
| 832 |
+
playwright install chromium
|
| 833 |
+
|
| 834 |
+
# Or install all browsers
|
| 835 |
+
playwright install
|
| 836 |
+
```
|
| 837 |
+
|
| 838 |
+
### API Server Issues
|
| 839 |
+
|
| 840 |
+
**Problem:** "Pre-flight checks failed"
|
| 841 |
+
|
| 842 |
+
**Solution:**
|
| 843 |
+
1. Check Ollama is running: `ollama serve`
|
| 844 |
+
2. Verify models: `ollama list`
|
| 845 |
+
3. Check vector store: `ls data/qdrant/`
|
| 846 |
+
4. View logs for specific error
|
| 847 |
+
|
| 848 |
+
**Problem:** "Gradio UI not loading"
|
| 849 |
+
|
| 850 |
+
**Solution:**
|
| 851 |
+
```bash
|
| 852 |
+
# Check if port is in use
|
| 853 |
+
lsof -i :8000
|
| 854 |
+
|
| 855 |
+
# Use different port
|
| 856 |
+
python scripts/run_server.py --port 8080
|
| 857 |
+
|
| 858 |
+
# Skip checks for testing
|
| 859 |
+
python scripts/run_server.py --skip-checks
|
| 860 |
+
```
|
| 861 |
+
|
| 862 |
+
### Performance Issues
|
| 863 |
+
|
| 864 |
+
**Problem:** "Slow query responses"
|
| 865 |
+
|
| 866 |
+
**Solution:**
|
| 867 |
+
1. Use GPU for embeddings (if available)
|
| 868 |
+
2. Reduce `retrieval_k` and `rerank_k` in config
|
| 869 |
+
3. Decrease `max_context_tokens`
|
| 870 |
+
4. Use smaller LLM model (llama3.2:3b instead of mistral)
|
| 871 |
+
|
| 872 |
+
**Problem:** "High memory usage"
|
| 873 |
+
|
| 874 |
+
**Solution:**
|
| 875 |
+
```bash
|
| 876 |
+
# Use smaller models
|
| 877 |
+
ollama pull llama3.2:3b # Only 2GB
|
| 878 |
+
|
| 879 |
+
# Reduce batch sizes in config
|
| 880 |
+
# Edit config/settings.py:
|
| 881 |
+
# chunk_size = 256 (instead of 512)
|
| 882 |
+
# retrieval_k = 10 (instead of 20)
|
| 883 |
+
```
|
| 884 |
+
|
| 885 |
+
### Common Error Messages
|
| 886 |
+
|
| 887 |
+
| Error | Cause | Solution |
|
| 888 |
+
|-------|-------|----------|
|
| 889 |
+
| `ConnectionError: Ollama` | Ollama not running | `ollama serve` |
|
| 890 |
+
| `Collection 'eyewiki_rag' not found` | Index not built | `python scripts/build_index.py --index-vectors` |
|
| 891 |
+
| `Model 'mistral' not found` | Model not pulled | `ollama pull mistral` |
|
| 892 |
+
| `503 Service Unavailable` | System not initialized | Check logs, verify dependencies |
|
| 893 |
+
| `422 Validation Error` | Invalid request format | Check API docs |
|
| 894 |
+
|
| 895 |
+
## 📊 Performance Benchmarks
|
| 896 |
+
|
| 897 |
+
Typical performance on a modern laptop (16GB RAM, M1/M2 or equivalent):
|
| 898 |
+
|
| 899 |
+
| Operation | Time | Notes |
|
| 900 |
+
|-----------|------|-------|
|
| 901 |
+
| Scraping (100 pages) | ~5-10 min | Network dependent |
|
| 902 |
+
| Processing | ~2-5 min | 100 documents |
|
| 903 |
+
| Embedding generation | ~5-10 min | 100 documents |
|
| 904 |
+
| Index building | ~3-5 min | 100 documents |
|
| 905 |
+
| Query (no streaming) | ~2-5s | Includes retrieval + LLM |
|
| 906 |
+
| Query (streaming) | ~0.5s first token | Then ~50 tokens/s |
|
| 907 |
+
|
| 908 |
+
## 📚 Additional Resources
|
| 909 |
+
|
| 910 |
+
### Documentation
|
| 911 |
+
- [EyeWiki](https://eyewiki.aao.org/) - Source of medical content
|
| 912 |
+
- [Ollama Documentation](https://github.com/ollama/ollama/blob/main/docs/README.md)
|
| 913 |
+
- [Qdrant Documentation](https://qdrant.tech/documentation/)
|
| 914 |
+
- [FastAPI Documentation](https://fastapi.tiangolo.com/)
|
| 915 |
+
|
| 916 |
+
### Related Projects
|
| 917 |
+
- [LlamaIndex](https://www.llamaindex.ai/) - Data framework for LLM applications
|
| 918 |
+
- [LangChain](https://www.langchain.com/) - Framework for developing LLM applications
|
| 919 |
+
- [Haystack](https://haystack.deepset.ai/) - End-to-end NLP framework
|
| 920 |
+
|
| 921 |
+
### Papers & Resources
|
| 922 |
+
- [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401)
|
| 923 |
+
- [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906)
|
| 924 |
+
|
| 925 |
+
## ⚠️ Medical Disclaimer
|
| 926 |
+
|
| 927 |
+
**IMPORTANT:** This system provides information from EyeWiki, a resource of the American Academy of Ophthalmology (AAO).
|
| 928 |
+
|
| 929 |
+
The information provided by this system:
|
| 930 |
+
- Is not a substitute for professional medical advice, diagnosis, or treatment
|
| 931 |
+
- May contain errors due to AI limitations
|
| 932 |
+
- Should be verified with authoritative sources before clinical use
|
| 933 |
+
|
| 934 |
+
Always consult with a qualified ophthalmologist or eye care professional for medical concerns. This system should not be used for:
|
| 935 |
+
- Clinical decision-making
|
| 936 |
+
- Patient diagnosis
|
| 937 |
+
- Treatment recommendations
|
| 938 |
+
- Emergency medical situations
|
| 939 |
+
|
| 940 |
+
## 📄 License
|
| 941 |
+
|
| 942 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 943 |
+
|
| 944 |
+
### Third-Party Licenses
|
| 945 |
+
- **EyeWiki Content**: © American Academy of Ophthalmology - Used under fair use for research purposes
|
| 946 |
+
- **Ollama**: Apache 2.0 License
|
| 947 |
+
- **Qdrant**: Apache 2.0 License
|
| 948 |
+
- **FastAPI**: MIT License
|
| 949 |
+
- **Gradio**: Apache 2.0 License
|
| 950 |
+
|
| 951 |
+
## 🙏 Attribution
|
| 952 |
+
|
| 953 |
+
### EyeWiki & AAO
|
| 954 |
+
This project uses content from [EyeWiki](https://eyewiki.aao.org/), the collaborative online encyclopedia of ophthalmology created and maintained by the [American Academy of Ophthalmology (AAO)](https://www.aao.org/).
|
| 955 |
+
|
| 956 |
+
**Citation:**
|
| 957 |
+
> American Academy of Ophthalmology. EyeWiki. Available at: https://eyewiki.aao.org/. Accessed [Date].
|
| 958 |
+
|
| 959 |
+
### Models & Libraries
|
| 960 |
+
- **nomic-embed-text**: [Nomic AI](https://www.nomic.ai/)
|
| 961 |
+
- **mistral**: [Mistral AI](https://mistral.ai/)
|
| 962 |
+
- **sentence-transformers**: [UKPLab](https://www.ukp.tu-darmstadt.de/)
|
| 963 |
+
- **crawl4ai**: [Web scraping framework](https://github.com/unclecode/crawl4ai)
|
| 964 |
+
|
| 965 |
+
## 🤝 Contributing
|
| 966 |
+
|
| 967 |
+
Contributions are welcome! Here's how you can help:
|
| 968 |
+
|
| 969 |
+
### Areas for Contribution
|
| 970 |
+
- 🐛 Bug fixes
|
| 971 |
+
- ✨ New features
|
| 972 |
+
- 📝 Documentation improvements
|
| 973 |
+
- 🧪 Test coverage
|
| 974 |
+
- 🎨 UI/UX enhancements
|
| 975 |
+
- 🌍 Internationalization
|
| 976 |
+
|
| 977 |
+
### Development Workflow
|
| 978 |
+
1. Fork the repository
|
| 979 |
+
2. Create a feature branch: `git checkout -b feature/amazing-feature`
|
| 980 |
+
3. Make your changes
|
| 981 |
+
4. Run tests: `pytest`
|
| 982 |
+
5. Commit: `git commit -m 'Add amazing feature'`
|
| 983 |
+
6. Push: `git push origin feature/amazing-feature`
|
| 984 |
+
7. Open a Pull Request
|
| 985 |
+
|
| 986 |
+
### Code Style
|
| 987 |
+
- Follow PEP 8
|
| 988 |
+
- Use Black for formatting
|
| 989 |
+
- Add type hints
|
| 990 |
+
- Write docstrings
|
| 991 |
+
- Include tests for new features
|
| 992 |
+
|
| 993 |
+
## 📞 Support
|
| 994 |
+
|
| 995 |
+
- **Issues**: [GitHub Issues](https://github.com/your-repo/issues)
|
| 996 |
+
- **Discussions**: [GitHub Discussions](https://github.com/your-repo/discussions)
|
| 997 |
+
- **Email**: your-email@example.com
|
| 998 |
+
|
| 999 |
+
## 🗺️ Roadmap
|
| 1000 |
+
|
| 1001 |
+
### Planned Features
|
| 1002 |
+
- [ ] Multi-language support
|
| 1003 |
+
- [ ] PDF document upload
|
| 1004 |
+
- [ ] Advanced filtering (date, author, etc.)
|
| 1005 |
+
- [ ] Conversation history
|
| 1006 |
+
- [ ] Feedback mechanism
|
| 1007 |
+
- [ ] Export answers to PDF
|
| 1008 |
+
- [ ] Mobile-responsive UI
|
| 1009 |
+
- [ ] Docker deployment
|
| 1010 |
+
- [ ] Cloud deployment guide (AWS, GCP, Azure)
|
| 1011 |
+
- [ ] Integration with medical record systems
|
| 1012 |
+
|
| 1013 |
+
### Future Improvements
|
| 1014 |
+
- [ ] Support for images in articles
|
| 1015 |
+
- [ ] Better handling of tables and diagrams
|
| 1016 |
+
- [ ] Citation formatting options (APA, MLA, etc.)
|
| 1017 |
+
- [ ] Multi-modal retrieval (text + images)
|
| 1018 |
+
- [ ] Custom model fine-tuning
|
| 1019 |
+
|
| 1020 |
+
## ⭐ Star History
|
| 1021 |
+
|
| 1022 |
+
If you find this project helpful, please consider giving it a star!
|
| 1023 |
+
|
| 1024 |
---
|
| 1025 |
|
| 1026 |
+
**Built with ❤️ for the ophthalmology community**
|
config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Configuration package
|
config/settings.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration settings for EyeWiki RAG system."""
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import Field
|
| 8 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LLMProvider(str, Enum):
    """Supported LLM providers.

    Subclasses ``str`` so members compare equal to their plain string
    values (e.g. ``LLMProvider.OLLAMA == "ollama"``), which lets
    pydantic parse the ``LLM_PROVIDER`` environment variable directly
    into this enum.
    """

    OLLAMA = "ollama"  # local Ollama server
    OPENAI = "openai"  # any OpenAI-compatible API (OpenAI, Groq, DeepSeek, ...)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Settings(BaseSettings):
    """Central application configuration for the EyeWiki RAG system.

    Values are resolved by pydantic-settings from (in order of precedence)
    process environment variables and a local ``.env`` file; anything not
    supplied falls back to the declared defaults. Unknown environment
    variables are silently ignored (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # --- LLM provider selection -------------------------------------------
    llm_provider: LLMProvider = Field(
        default=LLMProvider.OLLAMA,
        description="LLM provider to use: 'ollama' for local Ollama, 'openai' for OpenAI-compatible APIs (Groq, DeepSeek, OpenAI)",
    )

    # --- Ollama (local provider) ------------------------------------------
    ollama_base_url: str = Field(
        default="http://localhost:11434",
        description="Base URL for Ollama API",
    )
    ollama_timeout: int = Field(
        default=30,
        gt=0,
        description="Request timeout for Ollama API in seconds",
    )
    embedding_model: str = Field(
        default="nomic-embed-text",
        description="Ollama embedding model name",
    )
    llm_model: str = Field(
        default="mistral",
        description="Ollama LLM model name",
    )
    llm_temperature: float = Field(
        default=0.7,
        ge=0.0,
        le=2.0,
        description="LLM temperature for response generation",
    )
    llm_max_tokens: int = Field(
        default=2048,
        gt=0,
        description="Maximum tokens for LLM response",
    )

    # --- OpenAI-compatible APIs (Groq, DeepSeek, OpenAI, ...) -------------
    openai_api_key: Optional[str] = Field(
        default=None,
        description="API key for OpenAI-compatible provider",
    )
    openai_base_url: Optional[str] = Field(
        default=None,
        description="Base URL for OpenAI-compatible API (e.g., https://api.groq.com/openai/v1 for Groq)",
    )
    openai_model: str = Field(
        default="llama-3.3-70b-versatile",
        description="Model name for OpenAI-compatible provider (e.g., llama-3.3-70b-versatile for Groq)",
    )

    # --- Qdrant vector store ----------------------------------------------
    # Either a local on-disk path (qdrant_path) or a remote server
    # (qdrant_url + qdrant_api_key) may be used.
    qdrant_path: str = Field(
        default="./data/vectorstore",
        description="Path to Qdrant vector database",
    )
    qdrant_collection_name: str = Field(
        default="eyewiki_rag",
        description="Qdrant collection name",
    )
    qdrant_url: Optional[str] = Field(
        default=None,
        description="Qdrant server URL (for remote Qdrant)",
    )
    qdrant_api_key: Optional[str] = Field(
        default=None,
        description="Qdrant API key (for Qdrant Cloud)",
    )

    # --- Document processing ----------------------------------------------
    chunk_size: int = Field(
        default=512,
        gt=0,
        description="Size of text chunks for processing",
    )
    chunk_overlap: int = Field(
        default=50,
        ge=0,
        description="Overlap between consecutive chunks",
    )
    min_chunk_size: int = Field(
        default=100,
        gt=0,
        description="Minimum chunk size in tokens (skip smaller chunks)",
    )

    # --- Retrieval / RAG pipeline -----------------------------------------
    top_k: int = Field(
        default=10,
        gt=0,
        description="Number of documents to retrieve",
    )
    rerank_top_k: int = Field(
        default=5,
        gt=0,
        description="Number of documents after reranking",
    )
    similarity_threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Minimum similarity score for retrieval",
    )
    reranker_model: str = Field(
        default="cross-encoder/ms-marco-MiniLM-L-6-v2",
        description="Cross-encoder model for reranking",
    )
    max_context_tokens: int = Field(
        default=4096,
        gt=0,
        description="Maximum tokens for context in LLM prompt",
    )

    # --- Scraper -----------------------------------------------------------
    scraper_delay: float = Field(
        default=1.0,
        ge=0.0,
        description="Delay between scraping requests in seconds",
    )
    scraper_max_pages: Optional[int] = Field(
        default=None,
        description="Maximum number of pages to scrape (None for unlimited)",
    )
    scraper_timeout: int = Field(
        default=30,
        gt=0,
        description="Request timeout in seconds",
    )

    # --- FastAPI server -----------------------------------------------------
    api_host: str = Field(
        default="0.0.0.0",
        description="API server host",
    )
    api_port: int = Field(
        default=8000,
        gt=0,
        le=65535,
        description="API server port",
    )
    api_workers: int = Field(
        default=4,
        gt=0,
        description="Number of API workers",
    )

    # --- Gradio UI ----------------------------------------------------------
    gradio_host: str = Field(
        default="0.0.0.0",
        description="Gradio UI host",
    )
    gradio_port: int = Field(
        default=7860,
        gt=0,
        le=65535,
        description="Gradio UI port",
    )
    gradio_share: bool = Field(
        default=False,
        description="Create public Gradio share link",
    )

    # --- Data locations -----------------------------------------------------
    data_raw_path: str = Field(
        default="./data/raw",
        description="Path to raw scraped data",
    )
    data_processed_path: str = Field(
        default="./data/processed",
        description="Path to processed documents",
    )

    # --- Logging ------------------------------------------------------------
    log_level: str = Field(
        default="INFO",
        description="Logging level",
    )
    log_file: Optional[str] = Field(
        default="logs/eyewiki_rag.log",
        description="Log file path",
    )

    def get_data_paths(self) -> dict[str, Path]:
        """Return the configured data directories as Path objects, keyed by role."""
        locations = {
            "raw": self.data_raw_path,
            "processed": self.data_processed_path,
            "vectorstore": self.qdrant_path,
        }
        return {role: Path(location) for role, location in locations.items()}

    def ensure_data_directories(self) -> None:
        """Create every data directory (plus the log directory) if missing."""
        targets = list(self.get_data_paths().values())
        # The log file's parent directory must exist before logging starts.
        if self.log_file:
            targets.append(Path(self.log_file).parent)
        for directory in targets:
            directory.mkdir(parents=True, exist_ok=True)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# Module-level singleton: importing `settings` from this module resolves the
# environment / .env file once, at first import time.
settings = Settings()
|
data/processed/.gitkeep
ADDED
|
File without changes
|
data/raw/.gitkeep
ADDED
|
File without changes
|
data/vectorstore/.gitkeep
ADDED
|
File without changes
|
deployment_readme.md
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deployment Guide - EyeWiki RAG on Free Hosting
|
| 2 |
+
|
| 3 |
+
This guide covers deploying the EyeWiki RAG system using free/cheap cloud services:
|
| 4 |
+
|
| 5 |
+
- **App Hosting**: Hugging Face Spaces (Docker SDK)
|
| 6 |
+
- **Vector Database**: Qdrant Cloud (Free Tier)
|
| 7 |
+
- **LLM Provider**: Groq (Free Tier) or any OpenAI-compatible API
|
| 8 |
+
|
| 9 |
+
## Prerequisites
|
| 10 |
+
|
| 11 |
+
- A [Hugging Face](https://huggingface.co) account
|
| 12 |
+
- A [Qdrant Cloud](https://cloud.qdrant.io) account
|
| 13 |
+
- A [Groq](https://console.groq.com) account (or other OpenAI-compatible provider)
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## Step 1: Set Up Qdrant Cloud
|
| 18 |
+
|
| 19 |
+
1. Go to [Qdrant Cloud](https://cloud.qdrant.io) and create a free cluster.
|
| 20 |
+
2. Once created, note down:
|
| 21 |
+
- **Cluster URL** (e.g., `https://abc123-xyz.aws.cloud.qdrant.io:6333`)
|
| 22 |
+
- **API Key** (from the cluster dashboard)
|
| 23 |
+
3. You will need to index your data into the Qdrant Cloud cluster. You can do this locally:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
export QDRANT_URL="https://your-cluster-url:6333"
|
| 27 |
+
export QDRANT_API_KEY="your-qdrant-api-key"
|
| 28 |
+
python scripts/build_index.py --index-vectors
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Step 2: Get a Groq API Key
|
| 32 |
+
|
| 33 |
+
1. Go to [Groq Console](https://console.groq.com) and sign up.
|
| 34 |
+
2. Create an API key from the dashboard.
|
| 35 |
+
3. Note down the API key.
|
| 36 |
+
|
| 37 |
+
## Step 3: Deploy to Hugging Face Spaces
|
| 38 |
+
|
| 39 |
+
### Option A: Via the HF Web UI
|
| 40 |
+
|
| 41 |
+
1. Go to [Hugging Face Spaces](https://huggingface.co/spaces) and click **Create new Space**.
|
| 42 |
+
2. Choose **Docker** as the SDK.
|
| 43 |
+
3. Upload the project files (or connect a Git repo).
|
| 44 |
+
4. In the Space **Settings > Variables and secrets**, add:
|
| 45 |
+
|
| 46 |
+
| Variable | Value |
|
| 47 |
+
|-------------------|-----------------------------------------------|
|
| 48 |
+
| `LLM_PROVIDER` | `openai` |
|
| 49 |
+
| `OPENAI_API_KEY` | `gsk_your_groq_api_key` |
|
| 50 |
+
| `OPENAI_BASE_URL` | `https://api.groq.com/openai/v1` |
|
| 51 |
+
| `OPENAI_MODEL` | `llama-3.3-70b-versatile` |
|
| 52 |
+
| `QDRANT_URL` | `https://your-cluster.cloud.qdrant.io:6333` |
|
| 53 |
+
| `QDRANT_API_KEY` | `your_qdrant_api_key` |
|
| 54 |
+
|
| 55 |
+
5. The Space will build using `Dockerfile.deploy` and start automatically.
|
| 56 |
+
|
| 57 |
+
### Option B: Via the HF CLI
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# Install HF CLI
|
| 61 |
+
pip install huggingface_hub
|
| 62 |
+
|
| 63 |
+
# Login
|
| 64 |
+
huggingface-cli login
|
| 65 |
+
|
| 66 |
+
# Create Space
|
| 67 |
+
huggingface-cli repo create eyewiki-rag --type space --space-sdk docker
|
| 68 |
+
|
| 69 |
+
# Clone and push
|
| 70 |
+
git clone https://huggingface.co/spaces/YOUR_USERNAME/eyewiki-rag
|
| 71 |
+
cd eyewiki-rag
|
| 72 |
+
# Copy project files here, then:
|
| 73 |
+
cp /path/to/project/Dockerfile.deploy ./Dockerfile
|
| 74 |
+
git add . && git commit -m "Initial deployment" && git push
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
Then add the environment variables via the web UI (Settings > Variables and secrets).
|
| 78 |
+
|
| 79 |
+
## Step 4: Verify Deployment
|
| 80 |
+
|
| 81 |
+
Once the Space is running:
|
| 82 |
+
|
| 83 |
+
1. Visit your Space URL (e.g., `https://your-username-eyewiki-rag.hf.space`)
|
| 84 |
+
2. Check the health endpoint: `https://your-username-eyewiki-rag.hf.space/health`
|
| 85 |
+
3. Try the Gradio UI: `https://your-username-eyewiki-rag.hf.space/ui`
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## Environment Variables Reference
|
| 90 |
+
|
| 91 |
+
| Variable | Required | Default | Description |
|
| 92 |
+
|---------------------------|----------|----------------------------|----------------------------------------------------|
|
| 93 |
+
| `LLM_PROVIDER` | No | `ollama` | LLM provider: `ollama` or `openai` |
|
| 94 |
+
| `OPENAI_API_KEY` | If openai| - | API key for OpenAI-compatible provider |
|
| 95 |
+
| `OPENAI_BASE_URL` | No | `https://api.openai.com/v1`| Base URL for OpenAI-compatible API |
|
| 96 |
+
| `OPENAI_MODEL` | No | `llama-3.3-70b-versatile` | Model name for the provider |
|
| 97 |
+
| `OLLAMA_BASE_URL` | No | `http://localhost:11434` | Ollama API URL (only for ollama provider) |
|
| 98 |
+
| `LLM_MODEL` | No | `mistral` | Ollama model name (only for ollama provider) |
|
| 99 |
+
| `QDRANT_URL` | No | - | Qdrant Cloud cluster URL |
|
| 100 |
+
| `QDRANT_API_KEY` | No | - | Qdrant Cloud API key |
|
| 101 |
+
| `QDRANT_PATH` | No | `./data/vectorstore` | Local Qdrant path (if not using cloud) |
|
| 102 |
+
| `QDRANT_COLLECTION_NAME` | No | `eyewiki_rag` | Qdrant collection name |
|
| 103 |
+
| `EMBEDDING_MODEL` | No | `nomic-embed-text` | Embedding model name (provider-dependent) |
|
| 104 |
+
| `API_PORT` | No | `8000` | API server port |
|
| 105 |
+
|
| 106 |
+
---
|
| 107 |
+
|
| 108 |
+
## Provider Examples
|
| 109 |
+
|
| 110 |
+
### Groq (Free Tier)
|
| 111 |
+
|
| 112 |
+
```env
|
| 113 |
+
LLM_PROVIDER=openai
|
| 114 |
+
OPENAI_API_KEY=gsk_your_key_here
|
| 115 |
+
OPENAI_BASE_URL=https://api.groq.com/openai/v1
|
| 116 |
+
OPENAI_MODEL=llama-3.3-70b-versatile
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### OpenAI
|
| 120 |
+
|
| 121 |
+
```env
|
| 122 |
+
LLM_PROVIDER=openai
|
| 123 |
+
OPENAI_API_KEY=sk-your_key_here
|
| 124 |
+
OPENAI_MODEL=gpt-4o-mini
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### DeepSeek
|
| 128 |
+
|
| 129 |
+
```env
|
| 130 |
+
LLM_PROVIDER=openai
|
| 131 |
+
OPENAI_API_KEY=your_key_here
|
| 132 |
+
OPENAI_BASE_URL=https://api.deepseek.com/v1
|
| 133 |
+
OPENAI_MODEL=deepseek-chat
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### Local Ollama (Default)
|
| 137 |
+
|
| 138 |
+
```env
|
| 139 |
+
LLM_PROVIDER=ollama
|
| 140 |
+
OLLAMA_BASE_URL=http://localhost:11434
|
| 141 |
+
LLM_MODEL=mistral
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
## Troubleshooting
|
| 147 |
+
|
| 148 |
+
- **Space fails to build**: Check that `Dockerfile.deploy` is renamed to `Dockerfile` in the Space repo.
|
| 149 |
+
- **Model download slow on startup**: The embedding model (`all-mpnet-base-v2`) downloads on first run. Subsequent restarts use the cached version.
|
| 150 |
+
- **Qdrant connection errors**: Verify your `QDRANT_URL` includes the port (`:6333`) and the API key is correct.
|
| 151 |
+
- **LLM errors**: Check that your API key is valid and the model name is supported by your provider.
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EyeWiki RAG System - Docker Compose Configuration
# NOTE(review): the top-level `version` key is obsolete in Docker Compose v2
# and is ignored there — harmless, but safe to drop. Confirm target Compose
# version before removing.
version: '3.8'

services:
  # Qdrant vector database (REST on 6333, gRPC on 6334)
  qdrant:
    image: qdrant/qdrant:latest
    container_name: eyewiki-qdrant
    ports:
      - "6333:6333" # REST API
      - "6334:6334" # gRPC (optional)
    volumes:
      - qdrant_data:/qdrant/storage
    environment:
      - QDRANT__SERVICE__GRPC_PORT=6334
    networks:
      - eyewiki-network
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:6333/"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

  # EyeWiki RAG API (FastAPI app; waits for Qdrant to be healthy)
  eyewiki-rag:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: eyewiki-rag-api
    ports:
      - "8000:8000"
    volumes:
      # Mount data directories for persistence
      - ./data/raw:/app/data/raw
      - ./data/processed:/app/data/processed
      - qdrant_data:/app/data/qdrant
      # Mount prompts for easy customization
      - ./prompts:/app/prompts
    environment:
      # Ollama on host (access via host.docker.internal)
      - OLLAMA_BASE_URL=http://host.docker.internal:11434
      - LLM_MODEL=mistral
      - EMBEDDING_MODEL=nomic-embed-text

      # Qdrant service
      # NOTE(review): config/settings.py declares `qdrant_url`, not
      # qdrant_host/qdrant_port — QDRANT_HOST and QDRANT_PORT have no matching
      # field and are dropped by the settings loader (extra="ignore"). Verify
      # these are read elsewhere, or replace with QDRANT_URL.
      - QDRANT_HOST=qdrant
      - QDRANT_PORT=6333
      - QDRANT_COLLECTION_NAME=eyewiki_rag
      - QDRANT_PATH=/app/data/qdrant

      # Processing settings
      - CHUNK_SIZE=512
      - CHUNK_OVERLAP=50
      - MAX_CONTEXT_TOKENS=4000

      # Retrieval settings
      # NOTE(review): config/settings.py names these fields `top_k` and
      # `rerank_top_k`; RETRIEVAL_K / RERANK_K have no matching field and are
      # silently ignored (extra="ignore"). Confirm intended variable names.
      - RETRIEVAL_K=20
      - RERANK_K=5
    networks:
      - eyewiki-network
    depends_on:
      qdrant:
        condition: service_healthy
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s

networks:
  eyewiki-network:
    driver: bridge

volumes:
  qdrant_data:
    driver: local
plan/implementation_plan.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Implementation Plan - EyeWiki RAG Deployment
|
| 2 |
+
|
| 3 |
+
This plan outlines the steps to prepare the EyeWiki RAG application for deployment on free/cheap hosting providers (specifically Hugging Face Spaces + Groq + Qdrant Cloud), by decoupling the local Ollama dependency.
|
| 4 |
+
|
| 5 |
+
## User Review Required
|
| 6 |
+
|
| 7 |
+
> [!IMPORTANT]
|
| 8 |
+
> **LLM Provider Switch**: The deployment will support switching from local Ollama to "OpenAI-compatible" APIs (like Groq, DeepSeek, or OpenAI itself). This requires an API key for the chosen provider.
|
| 9 |
+
|
| 10 |
+
> [!NOTE]
|
| 11 |
+
> **Hosting Choice**: The recommended "free" stack is **Hugging Face Spaces (Docker)** for the app, **Qdrant Cloud (Free Tier)** for the vector DB, and **Groq (Free Tier)** for the LLM.
|
| 12 |
+
|
| 13 |
+
## Proposed Changes
|
| 14 |
+
|
| 15 |
+
### Configuration
|
| 16 |
+
#### [MODIFY] [settings.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/config/settings.py)
|
| 17 |
+
- Add `llm_provider` field (enum: "ollama", "openai").
|
| 18 |
+
- Add `openai_api_key`, `openai_base_url`, `openai_model` fields.
|
| 19 |
+
|
| 20 |
+
### LLM Abstraction
|
| 21 |
+
#### [NEW] [llm_client.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/llm/llm_client.py)
|
| 22 |
+
- Define `LLMClient` abstract base class/protocol with `generate` and `stream_generate` methods.
|
| 23 |
+
|
| 24 |
+
#### [MODIFY] [ollama_client.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/llm/ollama_client.py)
|
| 25 |
+
- Implement `LLMClient` interface.
|
| 26 |
+
|
| 27 |
+
#### [NEW] [openai_client.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/llm/openai_client.py)
|
| 28 |
+
- Implement `LLMClient` using `openai` python package.
|
| 29 |
+
- Support standard OpenAI API and compatible endpoints (Groq).
|
| 30 |
+
|
| 31 |
+
### Application Logic
|
| 32 |
+
#### [MODIFY] [query_engine.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/rag/query_engine.py)
|
| 33 |
+
- Update type hints to use abstract `LLMClient`.
|
| 34 |
+
|
| 35 |
+
#### [MODIFY] [main.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/src/api/main.py)
|
| 36 |
+
- Instantiate appropriate client based on `settings.llm_provider`.
|
| 37 |
+
- Update lifecycle events.
|
| 38 |
+
|
| 39 |
+
#### [MODIFY] [run_server.py](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/scripts/run_server.py)
|
| 40 |
+
- Modify pre-flight checks to only check Ollama when `llm_provider` is set to `ollama`.
|
| 41 |
+
- Add checks for API keys if provider is OpenAI/Groq.
|
| 42 |
+
|
| 43 |
+
### Deployment Configuration
|
| 44 |
+
#### [NEW] [Dockerfile.deploy](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/Dockerfile.deploy)
|
| 45 |
+
- Optimized Dockerfile for Hugging Face Spaces (non-root user, specific cache directories).
|
| 46 |
+
|
| 47 |
+
#### [NEW] [deployment_readme.md](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/deployment_readme.md)
|
| 48 |
+
- Step-by-step guide for deploying to HF Spaces and setting up Qdrant Cloud.
|
| 49 |
+
|
| 50 |
+
#### [MODIFY] [requirements.txt](file:///home/obum/Projects/Care1/research-and-ml/eyewiki-rag/requirements.txt)
|
| 51 |
+
- Add `openai>=1.0.0`.
|
| 52 |
+
|
| 53 |
+
## Verification Plan
|
| 54 |
+
|
| 55 |
+
### Automated Tests
|
| 56 |
+
- Run existing tests to ensure no regression: `pytest tests/`
|
| 57 |
+
- *Note:* New client tests would require mocking OpenAI API, which might be out of scope for a "test deployment", but we will verify the code compiles and runs.
|
| 58 |
+
|
| 59 |
+
### Manual Verification
|
| 60 |
+
1. **Local Test (Ollama)**: Run server with `LLM_PROVIDER=ollama` and verify standard functionality.
|
| 61 |
+
2. **Local Test (Mock/Groq)**: Run server with `LLM_PROVIDER=openai` and a valid API key (or mock) to verify the switch works.
|
| 62 |
+
3. **Deployment Build**: Build the `Dockerfile.deploy` locally to ensure it builds correctly.
|
prompts/medical_disclaimer.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
**Medical Disclaimer:** This information is sourced from EyeWiki, a resource of the American Academy of Ophthalmology (AAO). It is not a substitute for professional medical advice, diagnosis, or treatment. AI systems can make errors. Always consult with a qualified ophthalmologist or eye care professional for medical concerns and verify any critical information with authoritative sources.
|
prompts/query_prompt.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are answering a question using information from the EyeWiki medical knowledge base.
|
| 2 |
+
|
| 3 |
+
CONTEXT FROM EYEWIKI:
|
| 4 |
+
{context}
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
QUESTION: {question}
|
| 9 |
+
|
| 10 |
+
INSTRUCTIONS:
|
| 11 |
+
1. Answer the question using ONLY the information provided in the context above
|
| 12 |
+
2. Cite sources for all claims using the format: [Source: Article Title]
|
| 13 |
+
3. If the context does not contain enough information to fully answer the question, clearly state: "The provided sources do not contain sufficient information about [specific aspect]"
|
| 14 |
+
4. Organize your answer with:
|
| 15 |
+
- A direct answer to the question (1-2 sentences)
|
| 16 |
+
- Supporting details from the sources with citations
|
| 17 |
+
- Any relevant additional context from the sources
|
| 18 |
+
5. Use clear medical terminology with explanations for technical terms
|
| 19 |
+
6. Do NOT make up or infer information beyond what is explicitly stated in the context
|
| 20 |
+
|
| 21 |
+
ANSWER:
|
prompts/system_prompt.txt
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an expert ophthalmology knowledge assistant powered by the EyeWiki medical database. Your role is to provide accurate, evidence-based information about eye diseases, conditions, treatments, and procedures.
|
| 2 |
+
|
| 3 |
+
CRITICAL GUIDELINES:
|
| 4 |
+
|
| 5 |
+
1. CONTEXT-ONLY RESPONSES
|
| 6 |
+
- Base ALL answers strictly on the provided context from EyeWiki articles
|
| 7 |
+
- NEVER make up, infer, or add information that is not explicitly in the context
|
| 8 |
+
- If the context does not contain enough information to answer a question, clearly state this
|
| 9 |
+
- Do not use general knowledge or information from other sources
|
| 10 |
+
|
| 11 |
+
2. SOURCE CITATION
|
| 12 |
+
- Always cite the specific EyeWiki article when referencing information
|
| 13 |
+
- Use the format: [Source: Article Title] or "According to [Article Title]..."
|
| 14 |
+
- When multiple sources support a point, cite all relevant sources
|
| 15 |
+
- Include section names when specific information comes from a particular section
|
| 16 |
+
|
| 17 |
+
3. RESPONSE STRUCTURE
|
| 18 |
+
Your answers should follow this format:
|
| 19 |
+
|
| 20 |
+
a) Direct Answer
|
| 21 |
+
- Begin with a clear, concise answer to the specific question
|
| 22 |
+
- Use 1-2 sentences to address the core query
|
| 23 |
+
|
| 24 |
+
b) Supporting Details
|
| 25 |
+
- Provide relevant details, definitions, and explanations from the sources
|
| 26 |
+
- Use proper medical terminology, but include clear explanations for complex terms
|
| 27 |
+
- Organize information logically (e.g., causes, symptoms, diagnosis, treatment)
|
| 28 |
+
|
| 29 |
+
c) Additional Context (when appropriate)
|
| 30 |
+
- Include related information that provides valuable context
|
| 31 |
+
- Mention important considerations, risk factors, or variations
|
| 32 |
+
- Connect concepts to help understanding
|
| 33 |
+
|
| 34 |
+
d) Limitations
|
| 35 |
+
- If the context is incomplete, specify what information is missing
|
| 36 |
+
- Acknowledge when a question requires clinical judgment or patient-specific evaluation
|
| 37 |
+
|
| 38 |
+
4. MEDICAL TERMINOLOGY
|
| 39 |
+
- Use accurate medical terminology as it appears in the sources
|
| 40 |
+
- Immediately follow technical terms with clear explanations in parentheses
|
| 41 |
+
- Example: "trabecular meshwork (the eye's drainage system)"
|
| 42 |
+
- Balance professional precision with accessibility
|
| 43 |
+
|
| 44 |
+
5. UNCERTAINTY AND LIMITATIONS
|
| 45 |
+
When you cannot fully answer a question:
|
| 46 |
+
- Explicitly state: "The provided sources do not contain sufficient information about..."
|
| 47 |
+
- Offer what partial information IS available
|
| 48 |
+
- Suggest what type of information would be needed for a complete answer
|
| 49 |
+
- NEVER guess or extrapolate beyond what the sources explicitly state
|
| 50 |
+
|
| 51 |
+
6. CLINICAL CONSULTATION REMINDER
|
| 52 |
+
- For questions about specific symptoms, diagnosis, or treatment decisions, remind users to consult a qualified eye care professional
|
| 53 |
+
- Emphasize that individual cases vary and require professional medical evaluation
|
| 54 |
+
- Do not provide specific medical advice for individual situations
|
| 55 |
+
|
| 56 |
+
7. RESPONSE QUALITY
|
| 57 |
+
- Be thorough but concise - avoid unnecessary verbosity
|
| 58 |
+
- Use clear section headers for longer responses
|
| 59 |
+
- Present information in a logical, easy-to-follow structure
|
| 60 |
+
- Use bullet points or numbered lists when appropriate for clarity
|
| 61 |
+
- Maintain a professional yet approachable tone
|
| 62 |
+
|
| 63 |
+
8. ACCURACY PRIORITIES
|
| 64 |
+
- Accuracy is more important than completeness
|
| 65 |
+
- It is better to say "I don't have enough information" than to speculate
|
| 66 |
+
- When sources conflict or present multiple perspectives, present all views and cite each
|
| 67 |
+
- Distinguish between established facts and areas of ongoing research or debate
|
| 68 |
+
|
| 69 |
+
EXAMPLE RESPONSE PATTERNS:
|
| 70 |
+
|
| 71 |
+
Good Response:
|
| 72 |
+
"Primary open-angle glaucoma (POAG) is characterized by progressive optic nerve damage and visual field loss [Source: Primary Open-Angle Glaucoma]. The primary risk factor is elevated intraocular pressure (IOP), which occurs when the eye's drainage system (trabecular meshwork) becomes less efficient at draining aqueous humor [Source: Glaucoma Pathophysiology]..."
|
| 73 |
+
|
| 74 |
+
Poor Response:
|
| 75 |
+
"Glaucoma is usually treated with eye drops, and most patients do well with treatment."
|
| 76 |
+
(No citations, no source verification, making general claims)
|
| 77 |
+
|
| 78 |
+
When Uncertain:
|
| 79 |
+
"The provided sources discuss glaucoma treatment options including medications and surgery [Source: Glaucoma Management], but do not contain specific information about the long-term success rates you're asking about. For detailed statistics on treatment outcomes, you would need additional clinical research data."
|
| 80 |
+
|
| 81 |
+
REMEMBER:
|
| 82 |
+
- You are a knowledge assistant, not a medical professional
|
| 83 |
+
- Your purpose is to provide information, not to diagnose or prescribe
|
| 84 |
+
- Every piece of information should be traceable to the provided sources
|
| 85 |
+
- Professional consultation is irreplaceable for medical care
|
| 86 |
+
|
| 87 |
+
Maintain these standards in every response to ensure users receive accurate, well-sourced, and appropriately contextualized medical information.
|
pytest.ini
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
# Pytest configuration file
|
| 3 |
+
|
| 4 |
+
# Test discovery patterns
|
| 5 |
+
python_files = test_*.py
|
| 6 |
+
python_classes = Test*
|
| 7 |
+
python_functions = test_*
|
| 8 |
+
|
| 9 |
+
# Test paths
|
| 10 |
+
testpaths = tests
|
| 11 |
+
|
| 12 |
+
# Output options
|
| 13 |
+
addopts =
|
| 14 |
+
-v
|
| 15 |
+
--strict-markers
|
| 16 |
+
--tb=short
|
| 17 |
+
--disable-warnings
|
| 18 |
+
|
| 19 |
+
# Markers
|
| 20 |
+
markers =
|
| 21 |
+
unit: Unit tests (fast, isolated)
|
| 22 |
+
integration: Integration tests (may be slow)
|
| 23 |
+
api: API tests (requires server components)
|
| 24 |
+
|
| 25 |
+
# Minimum pytest version (pytest's `minversion` option refers to pytest itself,
# not Python; 7.4 matches the pytest>=7.4.0 pin in requirements.txt)
minversion = 7.4
|
requirements.txt
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Web Scraping
|
| 2 |
+
crawl4ai>=0.3.0
|
| 3 |
+
beautifulsoup4>=4.12.0
|
| 4 |
+
markdownify>=0.11.0
|
| 5 |
+
|
| 6 |
+
# RAG Framework
|
| 7 |
+
llama-index>=0.10.0
|
| 8 |
+
llama-index-vector-stores-qdrant>=0.2.0
|
| 9 |
+
llama-index-embeddings-ollama>=0.1.0
|
| 10 |
+
llama-index-llms-ollama>=0.1.0
|
| 11 |
+
|
| 12 |
+
# Vector Storage
|
| 13 |
+
qdrant-client>=1.7.0
|
| 14 |
+
|
| 15 |
+
# Embeddings & Reranking
|
| 16 |
+
sentence-transformers>=2.2.0 # For stable embeddings and cross-encoder reranking
|
| 17 |
+
torch>=2.0.0 # Required by sentence-transformers
|
| 18 |
+
|
| 19 |
+
# API Server
|
| 20 |
+
fastapi>=0.104.0
|
| 21 |
+
uvicorn[standard]>=0.24.0
|
| 22 |
+
|
| 23 |
+
# UI
|
| 24 |
+
gradio>=4.0.0
|
| 25 |
+
|
| 26 |
+
# Configuration
|
| 27 |
+
python-dotenv>=1.0.0
|
| 28 |
+
pydantic>=2.0.0
|
| 29 |
+
pydantic-settings>=2.0.0
|
| 30 |
+
|
| 31 |
+
# CLI Output & Progress
|
| 32 |
+
rich>=13.0.0
|
| 33 |
+
tqdm>=4.66.0
|
| 34 |
+
|
| 35 |
+
# OpenAI-compatible API
|
| 36 |
+
openai>=1.0.0
|
| 37 |
+
|
| 38 |
+
# Utilities
|
| 39 |
+
requests>=2.31.0
|
| 40 |
+
aiohttp>=3.9.0
|
| 41 |
+
|
| 42 |
+
# Development
|
| 43 |
+
pytest>=7.4.0
|
| 44 |
+
pytest-asyncio>=0.21.0
|
| 45 |
+
black>=23.11.0
|
| 46 |
+
isort>=5.12.0
|
| 47 |
+
flake8>=6.1.0
|
scripts/build_index.py
ADDED
|
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Build index by processing raw markdown files into semantic chunks with metadata."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import sys
|
| 7 |
+
import traceback
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
# Add parent directory to path
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 15 |
+
|
| 16 |
+
from src.processing.chunker import SemanticChunker, ChunkNode
|
| 17 |
+
from src.processing.metadata_extractor import MetadataExtractor
|
| 18 |
+
from src.vectorstore.qdrant_store import QdrantStoreManager
|
| 19 |
+
from src.llm.sentence_transformer_client import SentenceTransformerClient
|
| 20 |
+
from config.settings import settings
|
| 21 |
+
from rich.console import Console
|
| 22 |
+
from rich.panel import Panel
|
| 23 |
+
from rich.table import Table
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def parse_args():
    """Parse command line arguments for the index-building pipeline.

    Returns:
        argparse.Namespace with input/output directories, chunking
        parameters, and vector-indexing options. Path and chunking
        defaults fall back to ``config.settings`` at runtime (shown in
        the generated ``--help`` text via the f-string help messages).
    """
    # RawDescriptionHelpFormatter preserves the hand-formatted examples
    # in the epilog instead of re-wrapping them.
    parser = argparse.ArgumentParser(
        description="Process raw EyeWiki markdown into semantic chunks with medical metadata",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Just process files (no vector indexing)
  python scripts/build_index.py

  # Process AND build vector index
  python scripts/build_index.py --index-vectors

  # Only build vector index from existing processed files
  python scripts/build_index.py --index-only

  # Process with custom directories
  python scripts/build_index.py --input-dir ./my_raw --output-dir ./my_processed

  # Force rebuild with fresh Qdrant collection
  python scripts/build_index.py --rebuild --index-vectors --recreate-collection

  # Process only files matching pattern
  python scripts/build_index.py --pattern "Glaucoma*.md" --index-vectors

  # Custom chunking and embedding parameters
  python scripts/build_index.py --chunk-size 1024 --embedding-batch-size 64 --index-vectors
""",
    )

    # --- Input/output locations (None means "use settings default") ---
    parser.add_argument(
        "--input-dir",
        type=str,
        default=None,
        help=f"Input directory with raw markdown files (default: {settings.data_raw_path})",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help=f"Output directory for processed chunks (default: {settings.data_processed_path})",
    )

    # --- Processing behavior ---
    parser.add_argument(
        "--rebuild",
        action="store_true",
        help="Force rebuild even if output files exist",
    )

    parser.add_argument(
        "--pattern",
        type=str,
        default="*.md",
        help="Glob pattern for files to process (default: *.md)",
    )

    # --- Chunking parameters (None means "use settings default") ---
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=None,
        help=f"Chunk size in tokens (default: {settings.chunk_size})",
    )

    parser.add_argument(
        "--chunk-overlap",
        type=int,
        default=None,
        help=f"Chunk overlap in tokens (default: {settings.chunk_overlap})",
    )

    parser.add_argument(
        "--min-chunk-size",
        type=int,
        default=None,
        help=f"Minimum chunk size in tokens (default: {settings.min_chunk_size})",
    )

    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose output with detailed error messages",
    )

    # --- Vector-index stage toggles ---
    parser.add_argument(
        "--index-vectors",
        action="store_true",
        help="Build vector index in Qdrant after processing",
    )

    parser.add_argument(
        "--index-only",
        action="store_true",
        help="Skip processing, only build vector index from existing processed files",
    )

    parser.add_argument(
        "--recreate-collection",
        action="store_true",
        help="Recreate Qdrant collection (deletes existing data)",
    )

    # --- Embedding parameters ---
    parser.add_argument(
        "--embedding-batch-size",
        type=int,
        default=32,
        help="Batch size for embedding generation (default: 32)",
    )

    parser.add_argument(
        "--embedding-model",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Sentence transformer model name (default: all-mpnet-base-v2)",
    )

    return parser.parse_args()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def print_banner(console: Console):
    """Print the welcome banner describing the processing pipeline.

    Args:
        console: Rich console used for styled terminal output.
    """
    # Fix: the pipeline arrows were mojibake ("�") from a bad encoding
    # round-trip; restored as real arrow characters.
    banner = """
    [bold cyan]EyeWiki Index Builder[/bold cyan]
    [dim]Processing pipeline: Markdown → Metadata Extraction → Semantic Chunking → JSON[/dim]
    """
    console.print(Panel(banner, border_style="cyan"))
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def load_markdown_file(md_file: Path) -> tuple[str, Dict]:
    """Read an article's markdown body plus its sidecar JSON metadata.

    The metadata is expected in a ``.json`` file with the same stem,
    sitting next to the markdown file.

    Args:
        md_file: Path to the markdown file.

    Returns:
        Tuple of (content, metadata).

    Raises:
        FileNotFoundError: If the sidecar JSON metadata file is missing.
        ValueError: If the markdown is empty or the metadata is not a dict.
    """
    content = md_file.read_text(encoding="utf-8")
    if not content.strip():
        raise ValueError("Empty markdown content")

    sidecar = md_file.with_suffix(".json")
    if not sidecar.exists():
        raise FileNotFoundError(f"Metadata file not found: {sidecar}")

    metadata = json.loads(sidecar.read_text(encoding="utf-8"))
    if not isinstance(metadata, dict):
        raise ValueError("Invalid metadata format (must be dict)")

    return content, metadata
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def process_file(
    md_file: Path,
    output_dir: Path,
    chunker: SemanticChunker,
    extractor: MetadataExtractor,
    rebuild: bool = False,
    verbose: bool = False,
) -> Dict:
    """Run one markdown file through the full processing pipeline.

    Pipeline: load markdown + metadata -> extract medical metadata ->
    semantic chunking -> persist chunks as ``<stem>_chunks.json``.

    Args:
        md_file: Path to the markdown file.
        output_dir: Directory where the chunk JSON file is written.
        chunker: SemanticChunker instance.
        extractor: MetadataExtractor instance.
        rebuild: Overwrite an existing output file instead of skipping.
        verbose: Attach full tracebacks to error results.

    Returns:
        Result dict with ``status`` ("success"/"skipped"/"error"),
        chunk/token counts, and an ``error`` message when applicable.
    """
    outcome = {
        "file": md_file.name,
        "status": "pending",
        "chunks_created": 0,
        "total_tokens": 0,
        "error": None,
    }

    target = output_dir / f"{md_file.stem}_chunks.json"

    # Idempotency guard: don't redo work unless explicitly requested.
    if target.exists() and not rebuild:
        outcome["status"] = "skipped"
        outcome["error"] = "Output already exists (use --rebuild to force)"
        return outcome

    try:
        # Load raw article + sidecar metadata.
        content, metadata = load_markdown_file(md_file)

        # Enrich with extracted medical metadata, then chunk.
        enriched = extractor.extract(content, metadata)
        chunks = chunker.chunk_document(content, enriched)

        if not chunks:
            outcome["status"] = "skipped"
            outcome["error"] = "No chunks created (content too small or filtered)"
            return outcome

        # Persist chunks as a JSON array of dicts.
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as handle:
            json.dump(
                [chunk.to_dict() for chunk in chunks],
                handle,
                indent=2,
                ensure_ascii=False,
            )

        outcome["status"] = "success"
        outcome["chunks_created"] = len(chunks)
        outcome["total_tokens"] = sum(chunk.token_count for chunk in chunks)

    except Exception as exc:  # per-file isolation: errors are reported, not raised
        outcome["status"] = "error"
        # Same precedence as the original chained except clauses:
        # FileNotFoundError first, then ValueError, then everything else.
        if isinstance(exc, FileNotFoundError):
            outcome["error"] = f"File not found: {exc}"
        elif isinstance(exc, ValueError):
            outcome["error"] = f"Invalid data: {exc}"
        else:
            outcome["error"] = f"Unexpected error: {exc}"
        if verbose:
            outcome["traceback"] = traceback.format_exc()

    return outcome
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def print_statistics(results: List[Dict], console: Console):
    """Render a summary table of processing results plus error/skip details.

    Args:
        results: Per-file result dicts produced by ``process_file``.
        console: Rich console for styled output.
    """
    # Tally outcomes (bool sums count matching entries).
    file_count = len(results)
    ok = sum(entry["status"] == "success" for entry in results)
    skipped_count = sum(entry["status"] == "skipped" for entry in results)
    failed = sum(entry["status"] == "error" for entry in results)

    chunk_total = sum(entry["chunks_created"] for entry in results)
    token_total = sum(entry["total_tokens"] for entry in results)

    # Guard against division by zero when nothing succeeded.
    mean_chunks = chunk_total / ok if ok > 0 else 0
    mean_tokens_per_chunk = token_total / chunk_total if chunk_total > 0 else 0
    mean_tokens_per_doc = token_total / ok if ok > 0 else 0

    stats = Table(title="Processing Statistics", border_style="green")
    stats.add_column("Metric", style="cyan", justify="left")
    stats.add_column("Value", style="white", justify="right")

    # Empty label/value pairs render as visual separators.
    for label, value in (
        ("Total Files", f"{file_count:,}"),
        ("Successfully Processed", f"{ok:,}"),
        ("Skipped", f"{skipped_count:,}"),
        ("Errors", f"{failed:,}"),
        ("", ""),
        ("Total Chunks Created", f"{chunk_total:,}"),
        ("Total Tokens", f"{token_total:,}"),
        ("", ""),
        ("Avg Chunks per Document", f"{mean_chunks:.1f}"),
        ("Avg Tokens per Chunk", f"{mean_tokens_per_chunk:.1f}"),
        ("Avg Tokens per Document", f"{mean_tokens_per_doc:.1f}"),
    ):
        stats.add_row(label, value)

    console.print("\n")
    console.print(stats)

    # Detail the first ten failures, then summarize the remainder.
    failures = [entry for entry in results if entry["status"] == "error"]
    if failures:
        console.print("\n[yellow]Error Details:[/yellow]")
        for idx, entry in enumerate(failures[:10], 1):
            console.print(f"  {idx}. [red]{entry['file']}[/red]")
            console.print(f"     [dim]{entry['error']}[/dim]")
            if "traceback" in entry:
                console.print(f"     [dim]{entry['traceback']}[/dim]")

        if len(failures) > 10:
            console.print(f"  [dim]... and {len(failures) - 10} more errors[/dim]")

    # Only list skipped files when there are few enough to be useful.
    skips = [entry for entry in results if entry["status"] == "skipped"]
    if skips and len(skips) <= 5:
        console.print("\n[yellow]Skipped Files:[/yellow]")
        for idx, entry in enumerate(skips, 1):
            console.print(f"  {idx}. {entry['file']}: {entry['error']}")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def load_processed_chunks(processed_dir: Path, console: Console) -> List[ChunkNode]:
    """
    Load every previously processed chunk from *_chunks.json files.

    Args:
        processed_dir: Directory containing processed chunk JSON files
        console: Rich console for progress/status output

    Returns:
        All ChunkNode objects found (empty list when no files exist)
    """
    chunk_files = list(processed_dir.glob("*_chunks.json"))

    if not chunk_files:
        console.print(f"[yellow]No processed chunk files found in {processed_dir}[/yellow]")
        return []

    loaded: List[ChunkNode] = []

    console.print(f"\n[cyan]Loading processed chunks from {len(chunk_files)} files...[/cyan]")

    with tqdm(chunk_files, desc="Loading chunks", unit="file") as progress:
        for path in progress:
            try:
                with open(path, "r", encoding="utf-8") as handle:
                    raw_chunks = json.load(handle)

                # Rehydrate each serialized dict back into a ChunkNode
                loaded.extend(ChunkNode.from_dict(entry) for entry in raw_chunks)

                progress.set_postfix({"total_chunks": len(loaded)})

            except Exception as e:
                # Best-effort: report the bad file and keep loading the rest
                console.print(f"[red]Error loading {path.name}: {e}[/red]")

    console.print(f"[green]✓[/green] Loaded {len(loaded):,} chunks")
    return loaded
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def build_vector_index(
    chunks: List[ChunkNode],
    embedding_client: SentenceTransformerClient,
    qdrant_manager: QdrantStoreManager,
    batch_size: int,
    console: Console,
) -> Dict:
    """
    Embed all chunks and upsert them into the Qdrant collection.

    Args:
        chunks: ChunkNode objects to index
        embedding_client: SentenceTransformerClient producing dense embeddings
        qdrant_manager: QdrantStoreManager handling vector storage
        batch_size: Batch size for embedding generation
        console: Rich console for status output

    Returns:
        Statistics dict with chunks_indexed, time_taken, chunks_per_second,
        and collection_info.
    """
    if not chunks:
        console.print("[yellow]No chunks to index[/yellow]")
        return {"chunks_indexed": 0, "time_taken": 0}

    console.print(f"\n[bold cyan]Building Vector Index[/bold cyan]")
    console.print(f"Chunks to index: {len(chunks):,}")
    console.print(f"Embedding batch size: {batch_size}")

    import time
    started_at = time.time()

    # Raw chunk text is what gets embedded
    texts = [chunk.content for chunk in chunks]

    console.print("\n[cyan]Generating embeddings...[/cyan]")
    try:
        embeddings = embedding_client.embed_batch(
            texts=texts,
            batch_size=batch_size,
            show_progress=True,
        )
    except Exception as e:
        # Embedding failure is fatal for indexing; surface it to the caller
        console.print(f"[red]Failed to generate embeddings: {e}[/red]")
        raise

    console.print("\n[cyan]Inserting into Qdrant...[/cyan]")
    try:
        num_added = qdrant_manager.add_documents(
            chunks=chunks,
            dense_embeddings=embeddings,
        )
    except Exception as e:
        console.print(f"[red]Failed to insert into Qdrant: {e}[/red]")
        raise

    elapsed = time.time() - started_at

    # Collection info is informational only; tolerate failures here
    try:
        collection_info = qdrant_manager.get_collection_info()
    except Exception as e:
        console.print(f"[yellow]Could not get collection info: {e}[/yellow]")
        collection_info = {}

    return {
        "chunks_indexed": num_added,
        "time_taken": elapsed,
        "chunks_per_second": num_added / elapsed if elapsed > 0 else 0,
        "collection_info": collection_info,
    }
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def print_index_statistics(stats: Dict, console: Console):
    """
    Render vector-indexing statistics as a Rich table.

    Args:
        stats: Statistics dictionary from build_vector_index()
        console: Rich console for output
    """
    report = Table(title="Vector Index Statistics", border_style="green")
    report.add_column("Metric", style="cyan", justify="left")
    report.add_column("Value", style="white", justify="right")

    for label, value in [
        ("Chunks Indexed", f"{stats['chunks_indexed']:,}"),
        ("Time Taken", f"{stats['time_taken']:.1f}s"),
        ("Chunks/Second", f"{stats['chunks_per_second']:.1f}"),
    ]:
        report.add_row(label, value)

    info = stats.get("collection_info")
    if info:
        report.add_row("", "")  # Separator
        report.add_row("Collection Name", info.get("name", "N/A"))
        report.add_row("Total Vectors", f"{info.get('vectors_count', 0):,}")
        report.add_row("Total Points", f"{info.get('points_count', 0):,}")
        report.add_row("Status", info.get("status", "N/A"))

    console.print("\n")
    console.print(report)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def main():
    """Main entry point for index building.

    Orchestrates up to two phases based on CLI flags:
      1. Processing: chunk raw markdown files and extract metadata
         (skipped with --index-only).
      2. Vector indexing: embed processed chunks and upsert into Qdrant
         (enabled with --index-vectors or --index-only).

    Returns:
        Process exit code: 0 on success, 1 on failure, 130 on Ctrl+C.
    """
    args = parse_args()
    console = Console()

    # Print banner
    print_banner(console)

    # Prepare directories; CLI args override configured defaults
    input_dir = Path(args.input_dir) if args.input_dir else Path(settings.data_raw_path)
    output_dir = Path(args.output_dir) if args.output_dir else Path(settings.data_processed_path)

    # Check mode: --index-only implies indexing without (re)processing
    index_only = args.index_only
    should_index = args.index_vectors or args.index_only

    # Print mode
    if index_only:
        console.print("[cyan]Mode:[/cyan] Index only (skip processing)")
    elif should_index:
        console.print("[cyan]Mode:[/cyan] Process and build vector index")
    else:
        console.print("[cyan]Mode:[/cyan] Process only (no vector indexing)")

    # Validate input directory (only needed if not index-only)
    if not index_only and not input_dir.exists():
        console.print(f"[bold red]Error: Input directory does not exist: {input_dir}[/bold red]")
        return 1

    # Validate output directory exists (needed for index-only, since it is
    # the source of previously processed chunks)
    if index_only and not output_dir.exists():
        console.print(f"[bold red]Error: Output directory does not exist: {output_dir}[/bold red]")
        console.print("[yellow]Please run processing first without --index-only[/yellow]")
        return 1

    # Print configuration
    if not index_only:
        # Find all markdown files
        md_files = list(input_dir.glob(args.pattern))

        if not md_files:
            # Nothing to do is not an error
            console.print(f"[yellow]No files matching pattern '{args.pattern}' found in {input_dir}[/yellow]")
            return 0

        console.print(f"[cyan]Input directory:[/cyan] {input_dir}")
        console.print(f"[cyan]Output directory:[/cyan] {output_dir}")
        console.print(f"[cyan]Files found:[/cyan] {len(md_files)}")
        console.print(f"[cyan]Pattern:[/cyan] {args.pattern}")
        console.print(f"[cyan]Rebuild mode:[/cyan] {'Yes' if args.rebuild else 'No'}")
    else:
        console.print(f"[cyan]Processed directory:[/cyan] {output_dir}")

    # Initialize components (only if processing)
    results = []

    if not index_only:
        # CLI chunking parameters fall back to settings when not given
        chunker = SemanticChunker(
            chunk_size=args.chunk_size if args.chunk_size is not None else settings.chunk_size,
            chunk_overlap=args.chunk_overlap if args.chunk_overlap is not None else settings.chunk_overlap,
            min_chunk_size=args.min_chunk_size if args.min_chunk_size is not None else settings.min_chunk_size,
        )

        extractor = MetadataExtractor()

        console.print(f"[cyan]Chunk size:[/cyan] {chunker.chunk_size} tokens")
        console.print(f"[cyan]Chunk overlap:[/cyan] {chunker.chunk_overlap} tokens")
        console.print(f"[cyan]Min chunk size:[/cyan] {chunker.min_chunk_size} tokens")
        console.print()

        # Process files with progress bar
        console.print("[bold cyan]Processing Files...[/bold cyan]\n")

        with tqdm(
            total=len(md_files),
            desc="Processing",
            unit="file",
            ncols=100,
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
        ) as pbar:

            for md_file in md_files:
                # Update progress bar description (truncate long names)
                pbar.set_description(f"Processing {md_file.name[:30]:30}")

                # Process file; process_file returns a per-file status dict
                result = process_file(
                    md_file=md_file,
                    output_dir=output_dir,
                    chunker=chunker,
                    extractor=extractor,
                    rebuild=args.rebuild,
                    verbose=args.verbose,
                )

                results.append(result)

                # Update progress bar postfix with running stats
                successful = sum(1 for r in results if r["status"] == "success")
                chunks = sum(r["chunks_created"] for r in results)
                pbar.set_postfix({"success": successful, "chunks": chunks})

                pbar.update(1)

        # Print statistics
        print_statistics(results, console)

        # Check processing status
        successful = sum(1 for r in results if r["status"] == "success")
        errors = sum(1 for r in results if r["status"] == "error")

        console.print()
        if errors == 0 and successful > 0:
            console.print("[bold green]Processing completed successfully![/bold green]")
            console.print(f"[green]Processed files saved to: {output_dir}[/green]")
        elif successful > 0:
            console.print("[bold yellow]Processing completed with some errors.[/bold yellow]")
            console.print(f"[yellow]Processed files saved to: {output_dir}[/yellow]")
        else:
            console.print("[bold red]Processing failed - no files were processed successfully.[/bold red]")
            # Only abort here when indexing was not requested; otherwise the
            # indexing phase may still work from previously processed chunks
            if not should_index:
                return 1

    # Vector indexing phase
    if should_index:
        try:
            # Initialize embedding client with sentence-transformers
            console.print("\n[bold cyan]Initializing Sentence Transformers Client...[/bold cyan]")
            try:
                embedding_client = SentenceTransformerClient(model_name=args.embedding_model)
                model_info = embedding_client.get_model_info()
                console.print(f"[green]✓[/green] Loaded model: {model_info['model_name']}")
                console.print(f"[green]✓[/green] Device: {model_info['device']}")
                console.print(f"[green]✓[/green] Embedding dimension: {model_info['embedding_dim']}")
            except Exception as e:
                console.print(f"[bold red]Failed to initialize Sentence Transformers: {e}[/bold red]")
                console.print("[yellow]Install sentence-transformers: pip install sentence-transformers torch[/yellow]")
                return 1

            # Initialize Qdrant store
            console.print("\n[bold cyan]Initializing Qdrant Store...[/bold cyan]")
            try:
                qdrant_manager = QdrantStoreManager()
                qdrant_manager.initialize_collection(recreate=args.recreate_collection)
            except Exception as e:
                console.print(f"[bold red]Failed to initialize Qdrant: {e}[/bold red]")
                return 1

            # Load processed chunks from the output directory
            chunks = load_processed_chunks(output_dir, console)

            if not chunks:
                console.print("[yellow]No chunks to index. Please process documents first.[/yellow]")
                return 0

            # Build vector index
            try:
                index_stats = build_vector_index(
                    chunks=chunks,
                    embedding_client=embedding_client,
                    qdrant_manager=qdrant_manager,
                    batch_size=args.embedding_batch_size,
                    console=console,
                )

                # Print index statistics
                print_index_statistics(index_stats, console)

                console.print("\n[bold green]Vector indexing completed successfully![/bold green]")

            except Exception as e:
                console.print(f"\n[bold red]Vector indexing failed: {e}[/bold red]")
                if args.verbose:
                    traceback.print_exc()
                return 1

        except KeyboardInterrupt:
            # 130 is the conventional exit code for SIGINT
            console.print("\n[yellow]Indexing interrupted by user (Ctrl+C)[/yellow]")
            return 130

    return 0
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
if __name__ == "__main__":
    # Script entry point: translate main()'s return value into an exit code,
    # mapping Ctrl+C to the conventional SIGINT code (130) and any unhandled
    # exception to a fatal-error exit (1) with a traceback for debugging.
    try:
        exit_code = main()
        sys.exit(exit_code)
    except KeyboardInterrupt:
        Console().print("\n[yellow]Process interrupted by user (Ctrl+C)[/yellow]")
        sys.exit(130)
    except Exception as e:
        err_console = Console()
        err_console.print(f"\n[bold red]Fatal error: {e}[/bold red]")
        traceback.print_exc()
        sys.exit(1)
|
scripts/evaluate.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Evaluation script for EyeWiki RAG system.
|
| 4 |
+
|
| 5 |
+
Evaluates the system on a set of test questions and measures:
|
| 6 |
+
- Retrieval recall (relevant sources retrieved)
|
| 7 |
+
- Answer relevance (expected topics covered)
|
| 8 |
+
- Source citation accuracy
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python scripts/evaluate.py
|
| 12 |
+
python scripts/evaluate.py --questions tests/custom_questions.json
|
| 13 |
+
python scripts/evaluate.py --output results/eval_results.json
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
import sys
|
| 19 |
+
import time
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Any
|
| 22 |
+
|
| 23 |
+
from rich.console import Console
|
| 24 |
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
|
| 25 |
+
from rich.table import Table
|
| 26 |
+
from rich.panel import Panel
|
| 27 |
+
|
| 28 |
+
# Add project root to path
|
| 29 |
+
project_root = Path(__file__).parent.parent
|
| 30 |
+
sys.path.insert(0, str(project_root))
|
| 31 |
+
|
| 32 |
+
from config.settings import Settings
|
| 33 |
+
from src.llm.ollama_client import OllamaClient
|
| 34 |
+
from src.rag.query_engine import EyeWikiQueryEngine
|
| 35 |
+
from src.rag.reranker import CrossEncoderReranker
|
| 36 |
+
from src.rag.retriever import HybridRetriever
|
| 37 |
+
from src.vectorstore.qdrant_store import QdrantStoreManager
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
console = Console()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ============================================================================
|
| 44 |
+
# Evaluation Metrics
|
| 45 |
+
# ============================================================================
|
| 46 |
+
|
| 47 |
+
def calculate_retrieval_recall(
    retrieved_sources: List[str],
    expected_sources: List[str],
) -> float:
    """
    Compute retrieval recall: fraction of expected sources actually retrieved.

    Matching is case-insensitive and allows substring overlap in either
    direction (e.g. "glaucoma" matches "Primary Open-Angle Glaucoma").

    Args:
        retrieved_sources: Source titles returned by the retriever
        expected_sources: Ground-truth source titles

    Returns:
        Recall in [0, 1]; 1.0 when nothing was expected
    """
    if not expected_sources:
        # Vacuously perfect: there was nothing to retrieve
        return 1.0

    retrieved_norm = {title.lower() for title in retrieved_sources}
    expected_norm = {title.lower() for title in expected_sources}

    # An expected source counts as found if any retrieved title overlaps it
    hits = sum(
        1
        for wanted in expected_norm
        if any(wanted in got or got in wanted for got in retrieved_norm)
    )

    return hits / len(expected_sources)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def calculate_answer_relevance(
    answer: str,
    expected_topics: List[str],
) -> float:
    """
    Compute answer relevance as the fraction of expected topics mentioned.

    A topic counts as covered when its keyword appears (case-insensitively)
    anywhere in the answer text.

    Args:
        answer: Generated answer text
        expected_topics: Topic keywords the answer should mention

    Returns:
        Relevance in [0, 1]; 1.0 when no topics were expected
    """
    if not expected_topics:
        # Nothing required, so full marks
        return 1.0

    haystack = answer.lower()
    covered = [topic for topic in expected_topics if topic.lower() in haystack]

    return len(covered) / len(expected_topics)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def calculate_citation_accuracy(
    answer: str,
    cited_sources: List[str],
    expected_sources: List[str],
) -> Dict[str, float]:
    """
    Compute citation precision/recall/F1 plus an explicit-citation flag.

    Source matching is case-insensitive with substring overlap allowed in
    either direction. When either source list is empty, all scores are 0.

    Args:
        answer: Generated answer text
        cited_sources: Sources the system returned
        expected_sources: Ground-truth sources

    Returns:
        Dict with has_explicit_citations, precision, recall, and f1
    """
    # Heuristic: the answer cites explicitly if it uses either marker phrase
    has_citations = ("[Source:" in answer) or ("According to" in answer)

    precision = recall = f1 = 0.0
    if cited_sources and expected_sources:
        cited_norm = {s.lower() for s in cited_sources}
        expected_norm = {s.lower() for s in expected_sources}

        # A cited source is a true positive when it overlaps any expected one
        true_positives = sum(
            1
            for cited in cited_norm
            if any(exp in cited or cited in exp for exp in expected_norm)
        )

        precision = true_positives / len(cited_sources)
        recall = true_positives / len(expected_sources)

        denom = precision + recall
        if denom > 0:
            f1 = 2 * (precision * recall) / denom

    return {
        "has_explicit_citations": has_citations,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ============================================================================
|
| 166 |
+
# Question Evaluation
|
| 167 |
+
# ============================================================================
|
| 168 |
+
|
| 169 |
+
def evaluate_question(
    question_data: Dict[str, Any],
    query_engine: EyeWikiQueryEngine,
) -> Dict[str, Any]:
    """
    Run one test question through the query engine and score the result.

    Any exception raised during querying or scoring is captured and reported
    as a failed result rather than propagated.

    Args:
        question_data: Question dict with id, question, expected_topics,
            expected_sources, and optional category
        query_engine: Query engine instance

    Returns:
        Result dict; success=True with metrics/details, or success=False
        with an error message.
    """
    qid = question_data["id"]
    question = question_data["question"]
    expected_topics = question_data["expected_topics"]
    expected_sources = question_data["expected_sources"]
    category = question_data.get("category", "unknown")

    started = time.time()
    try:
        response = query_engine.query(
            question=question,
            include_sources=True,
        )
        elapsed = time.time() - started

        retrieved_titles = [src.title for src in response.sources]
        answer_lower = response.answer.lower()

        # Core metrics
        recall_score = calculate_retrieval_recall(retrieved_titles, expected_sources)
        relevance_score = calculate_answer_relevance(response.answer, expected_topics)
        citation_metrics = calculate_citation_accuracy(
            response.answer, retrieved_titles, expected_sources
        )

        # Topic coverage breakdown
        topics_found = [t for t in expected_topics if t.lower() in answer_lower]
        topics_missing = [t for t in expected_topics if t.lower() not in answer_lower]

        # Which expected sources showed up (substring match either direction)
        sources_retrieved = []
        sources_missing = []
        for expected in expected_sources:
            exp_lower = expected.lower()
            matched = any(
                exp_lower in got.lower() or got.lower() in exp_lower
                for got in retrieved_titles
            )
            (sources_retrieved if matched else sources_missing).append(expected)

        return {
            "id": qid,
            "question": question,
            "category": category,
            "answer": response.answer,
            "confidence": response.confidence,
            "query_time": elapsed,
            "metrics": {
                "retrieval_recall": recall_score,
                "answer_relevance": relevance_score,
                "citation_precision": citation_metrics["precision"],
                "citation_recall": citation_metrics["recall"],
                "citation_f1": citation_metrics["f1"],
            },
            "details": {
                "retrieved_sources": retrieved_titles,
                "expected_sources": expected_sources,
                "sources_retrieved": sources_retrieved,
                "sources_missing": sources_missing,
                "topics_found": topics_found,
                "topics_missing": topics_missing,
                "has_explicit_citations": citation_metrics["has_explicit_citations"],
            },
            "success": True,
        }

    except Exception as e:
        # Record the failure so aggregate reporting can count it
        return {
            "id": qid,
            "question": question,
            "category": category,
            "error": str(e),
            "query_time": time.time() - started,
            "success": False,
        }
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ============================================================================
|
| 277 |
+
# Aggregate Analysis
|
| 278 |
+
# ============================================================================
|
| 279 |
+
|
| 280 |
+
def calculate_aggregate_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Aggregate per-question evaluation results into summary metrics.

    Failed results are excluded from metric averages but counted in the
    total/failed tallies.

    Args:
        results: Per-question result dicts from evaluate_question()

    Returns:
        Dict with totals, averaged metrics, citation rate, and a
        per-category breakdown; {"error": ...} when nothing succeeded.
    """
    ok = [r for r in results if r["success"]]

    if not ok:
        return {"error": "No successful evaluations"}

    n = len(ok)

    def metric_mean(key: str) -> float:
        # Mean of a nested metrics[key] value across successful results
        return sum(r["metrics"][key] for r in ok) / n

    # How many answers contained explicit citations
    cited_count = sum(1 for r in ok if r["details"]["has_explicit_citations"])

    # Per-category running sums, normalized to means afterwards
    categories: Dict[str, Dict[str, Any]] = {}
    for r in ok:
        bucket = categories.setdefault(
            r["category"],
            {"count": 0, "retrieval_recall": 0, "answer_relevance": 0},
        )
        bucket["count"] += 1
        bucket["retrieval_recall"] += r["metrics"]["retrieval_recall"]
        bucket["answer_relevance"] += r["metrics"]["answer_relevance"]

    for bucket in categories.values():
        bucket["retrieval_recall"] /= bucket["count"]
        bucket["answer_relevance"] /= bucket["count"]

    return {
        "total_questions": len(results),
        "successful": n,
        "failed": len(results) - n,
        "metrics": {
            "retrieval_recall": metric_mean("retrieval_recall"),
            "answer_relevance": metric_mean("answer_relevance"),
            "citation_precision": metric_mean("citation_precision"),
            "citation_recall": metric_mean("citation_recall"),
            "citation_f1": metric_mean("citation_f1"),
            "avg_confidence": sum(r["confidence"] for r in ok) / n,
            "avg_query_time": sum(r["query_time"] for r in ok) / n,
            "citation_rate": cited_count / n,
        },
        "by_category": categories,
    }
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# ============================================================================
|
| 368 |
+
# Output Functions
|
| 369 |
+
# ============================================================================
|
| 370 |
+
|
| 371 |
+
def print_question_result(result: Dict[str, Any]):
    """Print a formatted result summary for a single evaluated question.

    Shows a PASS/PARTIAL/FAIL status derived from the mean of retrieval
    recall and answer relevance, a small metrics table, and any expected
    topics or sources missing from the answer.

    Args:
        result: Per-question result dict produced by evaluate_question().
    """
    if not result["success"]:
        # Print the header and the error on separate lines. (The original
        # passed both strings to a single console.print() call, which rich
        # joins with a space onto one line.)
        console.print(f"\n[red]✗ {result['id']}: {result['question']}[/red]")
        console.print(f"[red]Error: {result['error']}[/red]")
        return

    # Create metrics table (borderless, label/value pairs)
    table = Table(show_header=False, box=None, padding=(0, 1))
    table.add_column(style="cyan")
    table.add_column(style="yellow")

    metrics = result["metrics"]
    table.add_row("Retrieval Recall", f"{metrics['retrieval_recall']:.2%}")
    table.add_row("Answer Relevance", f"{metrics['answer_relevance']:.2%}")
    table.add_row("Citation F1", f"{metrics['citation_f1']:.2%}")
    table.add_row("Confidence", f"{result['confidence']:.2%}")
    table.add_row("Query Time", f"{result['query_time']:.2f}s")

    # Determine overall status from the mean of the two primary metrics
    avg_score = (metrics["retrieval_recall"] + metrics["answer_relevance"]) / 2
    if avg_score >= 0.8:
        status = "[green]✓ PASS[/green]"
    elif avg_score >= 0.6:
        status = "[yellow]~ PARTIAL[/yellow]"
    else:
        status = "[red]✗ FAIL[/red]"

    console.print(f"\n{status} [bold]{result['id']}:[/bold] {result['question']}")
    console.print(table)

    # Print expected items the answer failed to cover
    details = result["details"]
    if details["topics_missing"]:
        console.print(
            f"  [dim]Missing topics: {', '.join(details['topics_missing'])}[/dim]"
        )
    if details["sources_missing"]:
        console.print(
            f"  [dim]Missing sources: {', '.join(details['sources_missing'])}[/dim]"
        )
+
|
| 416 |
+
def print_aggregate_results(aggregate: Dict[str, Any]):
    """Print the aggregate evaluation summary.

    Renders an overall metrics table with letter grades, run statistics,
    and a per-category performance breakdown.

    Args:
        aggregate: Aggregate dict produced by calculate_aggregate_metrics().
    """
    console.print("\n")
    console.print(
        Panel.fit(
            "[bold cyan]Evaluation Summary[/bold cyan]",
            border_style="cyan",
        )
    )

    # Overall metrics table
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Metric", style="cyan")
    table.add_column("Score", style="yellow", justify="right")
    table.add_column("Grade", style="green", justify="center")

    metrics = aggregate["metrics"]

    def get_grade(score: float) -> str:
        # Simple letter grade on 0.1-wide bands.
        if score >= 0.9:
            return "[green]A[/green]"
        elif score >= 0.8:
            return "[green]B[/green]"
        elif score >= 0.7:
            return "[yellow]C[/yellow]"
        elif score >= 0.6:
            return "[yellow]D[/yellow]"
        else:
            return "[red]F[/red]"

    # (label, metric key) pairs keep the five rows DRY.
    for label, key in (
        ("Retrieval Recall", "retrieval_recall"),
        ("Answer Relevance", "answer_relevance"),
        ("Citation Precision", "citation_precision"),
        ("Citation Recall", "citation_recall"),
        ("Citation F1", "citation_f1"),
    ):
        table.add_row(label, f"{metrics[key]:.2%}", get_grade(metrics[key]))

    console.print(table)

    # Statistics — one console.print() per line. The original passed all
    # six lines as positional args to a single call, which rich joins
    # with spaces onto one long line.
    console.print(f"\n[bold]Statistics:[/bold]")
    console.print(f"  Total Questions: {aggregate['total_questions']}")
    console.print(f"  Successful: [green]{aggregate['successful']}[/green]")
    console.print(f"  Failed: [red]{aggregate['failed']}[/red]")
    console.print(f"  Avg Confidence: {metrics['avg_confidence']:.2%}")
    console.print(f"  Avg Query Time: {metrics['avg_query_time']:.2f}s")
    console.print(f"  Citation Rate: {metrics['citation_rate']:.2%}")

    # Category breakdown
    if aggregate["by_category"]:
        console.print(f"\n[bold]Performance by Category:[/bold]")
        cat_table = Table(show_header=True, header_style="bold magenta")
        cat_table.add_column("Category", style="cyan")
        cat_table.add_column("Count", justify="right")
        cat_table.add_column("Retrieval", justify="right")
        cat_table.add_column("Relevance", justify="right")

        for category, data in sorted(aggregate["by_category"].items()):
            cat_table.add_row(
                category,
                str(data["count"]),
                f"{data['retrieval_recall']:.2%}",
                f"{data['answer_relevance']:.2%}",
            )

        console.print(cat_table)
|
| 505 |
+
# ============================================================================
|
| 506 |
+
# Main Evaluation
|
| 507 |
+
# ============================================================================
|
| 508 |
+
|
| 509 |
+
def load_test_questions(questions_file: Path) -> List[Dict[str, Any]]:
    """Load test questions from a JSON file.

    Exits the process with status 1 if the file does not exist.

    Args:
        questions_file: Path to the JSON file containing the question list.

    Returns:
        The parsed list of question dicts.
    """
    if not questions_file.exists():
        console.print(f"[red]Error: Questions file not found: {questions_file}[/red]")
        sys.exit(1)

    # Explicit UTF-8: question text may contain non-ASCII characters and
    # the platform's default locale encoding is not guaranteed to be UTF-8.
    with open(questions_file, "r", encoding="utf-8") as f:
        questions = json.load(f)

    console.print(f"[green]✓[/green] Loaded {len(questions)} test questions")
    return questions
|
| 521 |
+
|
| 522 |
+
def initialize_system() -> EyeWikiQueryEngine:
    """Initialize the RAG system.

    Wires together the Ollama client, Qdrant store manager, hybrid
    retriever, cross-encoder reranker, and query engine from application
    settings, loading prompt files from the prompts/ directory when they
    exist.

    Returns:
        A configured EyeWikiQueryEngine.
    """
    console.print("[bold]Initializing RAG system...[/bold]")

    # Load settings (reads environment/.env via the Settings class)
    settings = Settings()

    # Initialize components
    # LLM + embedding access through the Ollama server.
    ollama_client = OllamaClient(
        base_url=settings.ollama_base_url,
        llm_model=settings.llm_model,
        embedding_model=settings.embedding_model,
    )

    # Vector store backed by the configured Qdrant path/collection.
    qdrant_manager = QdrantStoreManager(
        collection_name=settings.qdrant_collection_name,
        qdrant_path=settings.qdrant_path,
        vector_size=settings.embedding_dim,
    )

    retriever = HybridRetriever(
        qdrant_manager=qdrant_manager,
        ollama_client=ollama_client,
    )

    reranker = CrossEncoderReranker(
        model_name=settings.reranker_model,
    )

    # Load prompts — each path is passed as None when the file is missing
    # (presumably the engine then falls back to built-in defaults; confirm
    # against EyeWikiQueryEngine).
    prompts_dir = project_root / "prompts"
    system_prompt_path = prompts_dir / "system_prompt.txt"
    query_prompt_path = prompts_dir / "query_prompt.txt"
    disclaimer_path = prompts_dir / "medical_disclaimer.txt"

    query_engine = EyeWikiQueryEngine(
        retriever=retriever,
        reranker=reranker,
        llm_client=ollama_client,
        system_prompt_path=system_prompt_path if system_prompt_path.exists() else None,
        query_prompt_path=query_prompt_path if query_prompt_path.exists() else None,
        disclaimer_path=disclaimer_path if disclaimer_path.exists() else None,
        max_context_tokens=settings.max_context_tokens,
        retrieval_k=20,  # hard-coded for evaluation runs
        rerank_k=5,      # hard-coded for evaluation runs
    )

    console.print("[green]✓[/green] System initialized\n")
    return query_engine
| 572 |
+
|
| 573 |
+
def run_evaluation(
    questions_file: Path,
    output_file: Path = None,
    verbose: bool = False,
):
    """
    Run evaluation on test questions.

    Loads the question set, initializes the RAG system, evaluates each
    question with a progress bar, prints per-question and aggregate
    results, and optionally saves everything as JSON.

    Args:
        questions_file: Path to test questions JSON
        output_file: Optional path to save results
        verbose: Print detailed results as each question completes
    """
    console.print(
        Panel.fit(
            "[bold blue]EyeWiki RAG Evaluation[/bold blue]",
            border_style="blue",
        )
    )

    # Load questions
    questions = load_test_questions(questions_file)

    # Initialize system
    query_engine = initialize_system()

    # Evaluate questions
    results = []
    console.print("[bold]Evaluating questions...[/bold]\n")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:

        task = progress.add_task("Processing...", total=len(questions))

        for question_data in questions:
            result = evaluate_question(question_data, query_engine)
            results.append(result)

            if verbose:
                print_question_result(result)

            progress.update(task, advance=1)

    # Calculate aggregate metrics
    aggregate = calculate_aggregate_metrics(results)

    # Print per-question results (already shown above in verbose mode)
    if not verbose:
        console.print("\n[bold]Per-Question Results:[/bold]")
        for result in results:
            print_question_result(result)

    print_aggregate_results(aggregate)

    # Save results
    if output_file:
        output_data = {
            "results": results,
            "aggregate": aggregate,
            "timestamp": time.time(),
        }

        output_file.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8 so saved results are portable regardless of the
        # platform's default locale encoding.
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)

        console.print(f"\n[green]✓[/green] Results saved to {output_file}")
| 648 |
+
|
| 649 |
+
def main():
    """CLI entry point: parse arguments and run the evaluation."""
    parser = argparse.ArgumentParser(
        description="Evaluate EyeWiki RAG system on test questions"
    )

    # Declarative (flags, kwargs) option table instead of repeated
    # add_argument blocks.
    for flags, kwargs in (
        (
            ("--questions",),
            dict(
                type=Path,
                default=project_root / "tests" / "test_questions.json",
                help="Path to test questions JSON file",
            ),
        ),
        (
            ("--output",),
            dict(
                type=Path,
                default=None,
                help="Path to save evaluation results (JSON)",
            ),
        ),
        (
            ("-v", "--verbose"),
            dict(
                action="store_true",
                help="Print detailed results for each question",
            ),
        ),
    ):
        parser.add_argument(*flags, **kwargs)

    args = parser.parse_args()

    try:
        run_evaluation(
            questions_file=args.questions,
            output_file=args.output,
            verbose=args.verbose,
        )
    except KeyboardInterrupt:
        console.print("\n[yellow]Evaluation interrupted by user[/yellow]")
        sys.exit(1)
    except Exception as exc:
        console.print(f"\n[red]Error: {exc}[/red]")
        import traceback

        traceback.print_exc()
        sys.exit(1)
| 694 |
+
|
| 695 |
+
if __name__ == "__main__":
|
| 696 |
+
main()
|
scripts/run_server.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Server startup script with pre-flight checks.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/run_server.py
|
| 7 |
+
python scripts/run_server.py --port 8080 --reload
|
| 8 |
+
python scripts/run_server.py --host 0.0.0.0 --port 8000
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import requests
|
| 17 |
+
from rich.console import Console
|
| 18 |
+
from rich.panel import Panel
|
| 19 |
+
from rich.table import Table
|
| 20 |
+
|
| 21 |
+
# Add project root to path
|
| 22 |
+
project_root = Path(__file__).parent.parent
|
| 23 |
+
sys.path.insert(0, str(project_root))
|
| 24 |
+
|
| 25 |
+
from config.settings import LLMProvider, Settings
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
console = Console()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def parse_args():
    """Parse command line arguments for the server launcher."""
    parser = argparse.ArgumentParser(
        description="Start EyeWiki RAG API server with pre-flight checks"
    )

    # One (flags, kwargs) spec per option keeps the definitions in a table.
    option_specs = (
        (
            ("--host",),
            dict(type=str, default="0.0.0.0", help="Host to bind (default: 0.0.0.0)"),
        ),
        (
            ("--port",),
            dict(type=int, default=8000, help="Port number (default: 8000)"),
        ),
        (
            ("--reload",),
            dict(action="store_true", help="Enable hot reload for development"),
        ),
        (
            ("--skip-checks",),
            dict(action="store_true", help="Skip pre-flight checks (not recommended)"),
        ),
    )
    for flags, kwargs in option_specs:
        parser.add_argument(*flags, **kwargs)

    return parser.parse_args()
| 65 |
+
|
| 66 |
+
def print_header():
    """Print the styled welcome banner for the server launcher."""
    console.print()
    console.print(
        Panel.fit(
            "[bold blue]EyeWiki RAG API Server[/bold blue]\n"
            "[dim]Retrieval-Augmented Generation for Medical Knowledge[/dim]",
            border_style="blue",
        )
    )
    console.print()
| 78 |
+
|
| 79 |
+
def check_ollama(settings: Settings) -> bool:
    """
    Check if Ollama is running and has required models.

    Queries the Ollama /api/tags endpoint and verifies every required
    model appears among the locally available ones.

    Args:
        settings: Application settings

    Returns:
        True if check passed, False otherwise
    """
    console.print("[bold cyan]1. Checking Ollama service...[/bold cyan]")

    try:
        # Check if Ollama is running
        response = requests.get(f"{settings.ollama_base_url}/api/tags", timeout=5)
        response.raise_for_status()

        models_data = response.json()
        available_models = [model["name"] for model in models_data.get("models", [])]

        # Check for required LLM model (embedding model is sentence-transformers, not Ollama)
        required_models = {
            "LLM": settings.llm_model,
        }

        def _is_available(model_name: str) -> bool:
            # Loose match: the installed model may be listed with or
            # without a tag suffix (e.g. "llama3" vs "llama3:latest").
            return any(
                model_name in model or model in model_name
                for model in available_models
            )

        # Compute the missing set once — the original re-ran the same
        # containment scan for the status table and again for the
        # error-hint loop.
        missing = {
            model_type: model_name
            for model_type, model_name in required_models.items()
            if not _is_available(model_name)
        }

        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Model Type", style="cyan")
        table.add_column("Required Model", style="yellow")
        table.add_column("Status", style="green")

        for model_type, model_name in required_models.items():
            if model_type in missing:
                status = "[red]✗ Missing[/red]"
            else:
                status = "[green]✓ Found[/green]"
            table.add_row(model_type, model_name, status)

        console.print(table)

        if missing:
            console.print(
                "\n[red]Error:[/red] Some required models are missing. "
                "Pull them with:"
            )
            for model_name in missing.values():
                console.print(f"  [yellow]ollama pull {model_name}[/yellow]")
            console.print()
            return False

        console.print("[green]✓ Ollama is running with all required models[/green]\n")
        return True

    except requests.RequestException as e:
        console.print(f"[red]✗ Failed to connect to Ollama:[/red] {e}")
        console.print(
            f"\nMake sure Ollama is running at [yellow]{settings.ollama_base_url}[/yellow]"
        )
        console.print("Start it with: [yellow]ollama serve[/yellow]\n")
        return False
| 149 |
+
|
| 150 |
+
def check_openai_config(settings: Settings) -> bool:
    """
    Check if OpenAI-compatible API is configured with required API key.

    Only validates local configuration (key present, base URL, model
    name); it does not make a network call to verify the key works.

    Args:
        settings: Application settings

    Returns:
        True if check passed, False otherwise
    """
    console.print("[bold cyan]1. Checking OpenAI-compatible API configuration...[/bold cyan]")

    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Property", style="cyan")
    table.add_column("Value", style="yellow")
    table.add_column("Status", style="green")

    # Check API key — only the first 8 characters are displayed so the
    # full secret never lands in terminal scrollback or logs.
    has_key = bool(settings.openai_api_key)
    key_display = f"{settings.openai_api_key[:8]}..." if has_key else "(not set)"
    key_status = "[green]✓ Set[/green]" if has_key else "[red]✗ Missing[/red]"
    table.add_row("API Key", key_display, key_status)

    # Show base URL (empty/None means the provider's default endpoint)
    base_url = settings.openai_base_url or "(OpenAI default)"
    table.add_row("Base URL", base_url, "[green]✓[/green]")

    # Show model
    table.add_row("Model", settings.openai_model, "[green]✓[/green]")

    console.print(table)

    if not has_key:
        console.print(
            "\n[red]Error:[/red] API key is required for OpenAI-compatible provider."
        )
        console.print(
            "Set the [yellow]OPENAI_API_KEY[/yellow] environment variable or add it to your [yellow].env[/yellow] file.\n"
        )
        return False

    console.print("[green]✓ OpenAI-compatible API configuration looks good[/green]\n")
    return True
| 194 |
+
|
| 195 |
+
def check_vector_store(settings: Settings) -> bool:
    """
    Verify that the Qdrant vector store exists and contains documents.

    Checks, in order: the on-disk Qdrant directory exists, the configured
    collection exists, and the collection holds at least one point.

    Args:
        settings: Application settings

    Returns:
        True if check passed, False otherwise
    """
    console.print("[bold cyan]2. Checking vector store...[/bold cyan]")

    store_path = Path(settings.qdrant_path)
    collection = settings.qdrant_collection_name

    # The on-disk Qdrant directory must exist before we try to open it.
    if not store_path.exists():
        console.print(f"[red]✗ Qdrant directory not found:[/red] {store_path}")
        console.print(
            "\nRun the indexing pipeline first:\n"
            "  [yellow]python scripts/build_index.py --index-vectors[/yellow]\n"
        )
        return False

    try:
        from qdrant_client import QdrantClient

        client = QdrantClient(path=str(store_path))

        # The target collection must be present.
        existing = [col.name for col in client.get_collections().collections]
        if collection not in existing:
            console.print(
                f"[red]✗ Collection '{collection}' not found[/red]\n"
                f"Available collections: {existing}"
            )
            console.print(
                "\nRun the indexing pipeline first:\n"
                "  [yellow]python scripts/build_index.py --index-vectors[/yellow]\n"
            )
            return False

        # An empty collection means indexing never ran (or was wiped).
        num_points = client.get_collection(collection).points_count
        if num_points == 0:
            console.print(
                f"[yellow]⚠ Collection '{collection}' exists but is empty[/yellow]"
            )
            console.print(
                "\nRun the indexing pipeline:\n"
                "  [yellow]python scripts/build_index.py --index-vectors[/yellow]\n"
            )
            return False

        # Print stats
        stats = Table(show_header=True, header_style="bold magenta")
        stats.add_column("Property", style="cyan")
        stats.add_column("Value", style="yellow")
        stats.add_row("Collection", collection)
        stats.add_row("Location", str(store_path))
        stats.add_row("Documents", f"{num_points:,}")

        console.print(stats)
        console.print("[green]✓ Vector store is ready[/green]\n")
        return True

    except Exception as e:
        console.print(f"[red]✗ Failed to access vector store:[/red] {e}\n")
        return False
| 271 |
+
|
| 272 |
+
def check_required_files() -> bool:
    """
    Verify that all prompt files required by the API are present.

    Returns:
        True if all files exist, False otherwise
    """
    console.print("[bold cyan]3. Checking required files...[/bold cyan]")

    required_files = {
        "System Prompt": project_root / "prompts" / "system_prompt.txt",
        "Query Prompt": project_root / "prompts" / "query_prompt.txt",
        "Medical Disclaimer": project_root / "prompts" / "medical_disclaimer.txt",
    }

    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("File", style="cyan")
    table.add_column("Path", style="yellow")
    table.add_column("Status", style="green")

    # Build the status table while tracking whether anything is missing.
    any_missing = False
    for label, file_path in required_files.items():
        if file_path.exists():
            status = "[green]✓ Found[/green]"
        else:
            status = "[red]✗ Missing[/red]"
            any_missing = True
        table.add_row(label, str(file_path.relative_to(project_root)), status)

    console.print(table)

    if any_missing:
        console.print(
            "\n[red]Error:[/red] Some required files are missing.\n"
            "Make sure all prompt files are in the [yellow]prompts/[/yellow] directory.\n"
        )
        return False

    console.print("[green]✓ All required files found[/green]\n")
    return True
| 313 |
+
|
| 314 |
+
def run_preflight_checks(skip_checks: bool = False) -> bool:
    """
    Run all pre-flight checks.

    Args:
        skip_checks: Skip all checks if True

    Returns:
        True if all checks passed, False otherwise
    """
    if skip_checks:
        console.print("[yellow]⚠ Skipping pre-flight checks[/yellow]\n")
        return True

    console.print("[bold yellow]Running Pre-flight Checks...[/bold yellow]\n")

    # Load settings
    try:
        settings = Settings()
    except Exception as e:
        console.print(f"[red]✗ Failed to load settings:[/red] {e}\n")
        return False

    console.print(f"[dim]LLM Provider: {settings.llm_provider.value}[/dim]\n")

    # Check the configured LLM provider (Ollama or OpenAI-compatible).
    if settings.llm_provider == LLMProvider.OLLAMA:
        llm_ok = check_ollama(settings)
    else:
        llm_ok = check_openai_config(settings)

    # Run every remaining check eagerly (no short-circuit) so the user
    # sees all failures in one pass.
    store_ok = check_vector_store(settings)
    files_ok = check_required_files()

    if not (llm_ok and store_ok and files_ok):
        console.print("[bold red]✗ Pre-flight checks failed[/bold red]")
        console.print("Fix the issues above and try again.\n")
        return False

    console.print("[bold green]✓ All pre-flight checks passed![/bold green]\n")
    return True
| 360 |
+
|
| 361 |
+
def print_access_urls(host: str, port: int):
    """
    Print access URLs for the server.

    Args:
        host: Server host
        port: Server port
    """
    # Determine display host: a 0.0.0.0/127.0.0.1 bind is reachable
    # locally as "localhost"; any other host is shown as-is.
    display_host = "localhost" if host in ["0.0.0.0", "127.0.0.1"] else host

    table = Table(
        show_header=True,
        header_style="bold magenta",
        title="[bold green]Server Access URLs[/bold green]",
        title_style="bold green",
    )
    table.add_column("Service", style="cyan", width=20)
    table.add_column("URL", style="yellow")
    table.add_column("Description", style="dim")

    # (service, url, description) rows for the access table.
    urls = [
        ("API Root", f"http://{display_host}:{port}", "API information"),
        ("Health Check", f"http://{display_host}:{port}/health", "Service health status"),
        (
            "Interactive Docs",
            f"http://{display_host}:{port}/docs",
            "Swagger UI documentation",
        ),
        ("ReDoc", f"http://{display_host}:{port}/redoc", "Alternative API docs"),
        (
            "Gradio UI",
            f"http://{display_host}:{port}/ui",
            "Web chat interface",
        ),
    ]

    for service, url, description in urls:
        table.add_row(service, url, description)

    console.print()
    console.print(table)
    console.print()

    # Print quick start commands (copy-pasteable curl examples)
    console.print("[bold cyan]Quick Test Commands:[/bold cyan]")
    console.print(
        f"  [dim]# Test health endpoint[/dim]\n"
        f"  [yellow]curl http://{display_host}:{port}/health[/yellow]\n"
    )
    console.print(
        f"  [dim]# Query the API[/dim]\n"
        f"  [yellow]curl -X POST http://{display_host}:{port}/query \\[/yellow]\n"
        f'  [yellow]  -H "Content-Type: application/json" \\[/yellow]\n'
        f'  [yellow]  -d \'{{"question": "What is glaucoma?"}}\' [/yellow]\n'
    )
| 418 |
+
|
| 419 |
+
def start_server(host: str, port: int, reload: bool):
    """
    Start the uvicorn server.

    Blocks until the server exits; exits the process with status 1 when
    uvicorn is not installed or startup fails unexpectedly.

    Args:
        host: Server host
        port: Server port
        reload: Enable hot reload
    """
    console.print("[bold green]Starting server...[/bold green]\n")

    # Print URLs before starting (uvicorn.run blocks afterwards)
    print_access_urls(host, port)

    # Import uvicorn here to avoid import errors if not installed
    try:
        import uvicorn
    except ImportError:
        console.print("[red]Error:[/red] uvicorn is not installed")
        console.print("Install it with: [yellow]pip install uvicorn[/yellow]\n")
        sys.exit(1)

    # Start server
    try:
        # NOTE(review): both strings go to a single console.print() call,
        # so rich renders them space-separated on one line — confirm this
        # is the intended layout.
        console.print(
            f"[dim]Server listening on {host}:{port}[/dim]",
            f"[dim](Press CTRL+C to stop)[/dim]\n",
        )

        uvicorn.run(
            "src.api.main:app",  # app given as an import string, not an object
            host=host,
            port=port,
            reload=reload,
            log_level="info",
        )

    except KeyboardInterrupt:
        console.print("\n\n[yellow]Server stopped by user[/yellow]")
    except Exception as e:
        console.print(f"\n[red]Error starting server:[/red] {e}")
        sys.exit(1)
| 462 |
+
|
| 463 |
+
def main():
    """Script entry point: banner, pre-flight checks, then server start."""
    args = parse_args()
    print_header()

    # Abort early when any pre-flight check fails (unless checks are skipped).
    checks_ok = run_preflight_checks(skip_checks=args.skip_checks)
    if not checks_ok:
        console.print("[red]Startup aborted due to failed checks[/red]\n")
        sys.exit(1)

    start_server(host=args.host, port=args.port, reload=args.reload)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
if __name__ == "__main__":
    # Allow running directly: python scripts/run_server.py
    main()
|
scripts/scrape_eyewiki.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""CLI script to run the EyeWiki crawler."""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import asyncio
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add parent directory to path
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 12 |
+
|
| 13 |
+
from src.scraper.eyewiki_crawler import EyeWikiCrawler
|
| 14 |
+
from config.settings import settings
|
| 15 |
+
from rich.console import Console
|
| 16 |
+
from rich.panel import Panel
|
| 17 |
+
from rich.table import Table
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace with crawl options: max_pages, depth,
        output_dir, resume, delay, timeout, start_urls, checkpoint_file.
    """
    parser = argparse.ArgumentParser(
        description="Crawl EyeWiki medical articles",
        # Raw formatter so the epilog examples keep their line breaks.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Crawl up to 100 pages with default settings
  python scripts/scrape_eyewiki.py --max-pages 100

  # Resume previous crawl
  python scripts/scrape_eyewiki.py --resume

  # Crawl with depth 3 to custom directory
  python scripts/scrape_eyewiki.py --depth 3 --output-dir ./my_data

  # Full crawl (no page limit)
  python scripts/scrape_eyewiki.py
        """,
    )

    # None means "no limit"; the crawler stops only when the frontier empties.
    parser.add_argument(
        "--max-pages",
        type=int,
        default=None,
        help="Maximum number of pages to crawl (default: unlimited)",
    )

    parser.add_argument(
        "--depth",
        type=int,
        default=2,
        help="Maximum crawl depth (default: 2)",
    )

    # Defaults below fall back to config.settings values at crawler init time.
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help=f"Output directory for scraped articles (default: {settings.data_raw_path})",
    )

    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from previous checkpoint if available",
    )

    parser.add_argument(
        "--delay",
        type=float,
        default=None,
        help=f"Delay between requests in seconds (default: {settings.scraper_delay})",
    )

    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help=f"Request timeout in seconds (default: {settings.scraper_timeout})",
    )

    parser.add_argument(
        "--start-urls",
        type=str,
        nargs="+",
        default=None,
        help="Starting URLs for crawl (default: EyeWiki main page and disease category)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default=None,
        help="Custom checkpoint file path (default: output_dir/crawler_checkpoint.json)",
    )

    return parser.parse_args()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def print_banner(console: Console):
    """Print a styled welcome banner inside a cyan-bordered panel."""
    # NOTE(review): the banner's leading whitespace is rendered verbatim by
    # rich, so the indentation of this literal is intentional.
    banner = """
    [bold cyan]EyeWiki Medical Article Crawler[/bold cyan]
    [dim]Powered by crawl4ai[/dim]
    """
    console.print(Panel(banner, border_style="cyan"))
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def print_configuration(console: Console, args, crawler: EyeWikiCrawler):
    """Print crawler configuration as a two-column table."""
    table = Table(title="Crawler Configuration", show_header=False, border_style="blue")
    table.add_column("Setting", style="cyan")
    table.add_column("Value", style="white")

    # Effective values come from the crawler (CLI flags already merged with
    # config defaults); args is consulted only for pure-CLI options.
    config_rows = [
        ("Output Directory", str(crawler.output_dir)),
        ("Max Pages", str(args.max_pages) if args.max_pages else "Unlimited"),
        ("Depth", str(args.depth)),
        ("Delay", f"{crawler.delay}s"),
        ("Timeout", f"{crawler.timeout}s"),
        ("Checkpoint File", str(crawler.checkpoint_file)),
        ("Resume Mode", "Yes" if args.resume else "No"),
    ]
    for setting, value in config_rows:
        table.add_row(setting, value)

    console.print(table)
    console.print()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def print_summary(console: Console, crawler: EyeWikiCrawler, elapsed_time: float):
    """Print crawl summary statistics, sample failures, and a final status."""
    console.print("\n")

    # Derived metrics; guard against division by zero on an empty/instant run.
    saved = crawler.articles_saved
    visited = len(crawler.visited_urls)
    failed = len(crawler.failed_urls)
    pages_per_minute = (saved / elapsed_time * 60) if elapsed_time > 0 else 0
    success_rate = (saved / visited * 100) if crawler.visited_urls else 0

    table = Table(title="Crawl Summary", border_style="green", show_header=True)
    table.add_column("Metric", style="cyan", justify="left")
    table.add_column("Value", style="white", justify="right")

    for metric, value in (
        ("Articles Saved", f"{saved:,}"),
        ("URLs Visited", f"{visited:,}"),
        ("URLs Failed", f"{failed:,}"),
        ("URLs Remaining", f"{len(crawler.to_crawl):,}"),
        ("Success Rate", f"{success_rate:.1f}%"),
        ("Time Elapsed", f"{elapsed_time:.1f}s"),
        ("Pages/Minute", f"{pages_per_minute:.1f}"),
    ):
        table.add_row(metric, value)

    console.print(table)

    # Show only the first five failures so the log stays readable.
    if crawler.failed_urls:
        console.print("\n[yellow]Failed URLs:[/yellow]")
        for i, (url, error) in enumerate(list(crawler.failed_urls.items())[:5], 1):
            console.print(f"  {i}. [red]{url}[/red]")
            console.print(f"     [dim]{error}[/dim]")

        if failed > 5:
            console.print(f"  [dim]... and {failed - 5} more[/dim]")

    # Final status line.
    console.print()
    if saved > 0:
        console.print("[bold green]Crawl completed successfully![/bold green]")
        console.print(f"[green]Articles saved to: {crawler.output_dir}[/green]")
    else:
        console.print("[bold yellow]No articles were saved.[/bold yellow]")
        console.print("[yellow]Check the logs above for errors.[/yellow]")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
async def main() -> int:
    """Main entry point for the crawler script.

    Orchestrates argument parsing, crawler construction, the crawl itself,
    and summary reporting.

    Returns:
        Process exit code: 0 on success, 1 on error, 130 on Ctrl+C.
    """
    # Parse arguments
    args = parse_args()

    # Initialize console
    console = Console()

    # Print banner
    print_banner(console)

    # Prepare output directory (created eagerly so the crawler can write)
    output_dir = Path(args.output_dir) if args.output_dir else Path(settings.data_raw_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare checkpoint file; None lets the crawler pick its default path
    checkpoint_file = None
    if args.checkpoint_file:
        checkpoint_file = Path(args.checkpoint_file)

    # If not resuming and a custom checkpoint exists, warn the user.
    # NOTE(review): only a --checkpoint-file path is checked here; a default
    # checkpoint inside output_dir is not — confirm whether that is intended.
    if not args.resume and checkpoint_file and checkpoint_file.exists():
        console.print("[yellow]Warning: Checkpoint file exists![/yellow]")
        console.print(f"[yellow]File: {checkpoint_file}[/yellow]")
        console.print("[yellow]Use --resume to continue from checkpoint, or it will be overwritten.[/yellow]")
        console.print()

    # Initialize crawler; CLI flags override config defaults when provided
    try:
        crawler = EyeWikiCrawler(
            base_url="https://eyewiki.org",
            output_dir=output_dir,
            checkpoint_file=checkpoint_file,
            delay=args.delay if args.delay is not None else settings.scraper_delay,
            timeout=args.timeout if args.timeout is not None else settings.scraper_timeout,
        )
    except Exception as e:
        console.print(f"[bold red]Error initializing crawler: {e}[/bold red]")
        return 1

    # Print configuration
    print_configuration(console, args, crawler)

    # Prepare start URLs; when resuming, the checkpoint supplies the frontier
    start_urls = args.start_urls
    if not start_urls and not args.resume:
        # Start with popular medical articles that link to many other articles
        start_urls = [
            "https://eyewiki.org/Category:Articles"
        ]
        console.print("[blue]Using default start URLs (seed articles):[/blue]")
        for url in start_urls:
            console.print(f"  - {url}")
        console.print()

    # Start crawling
    start_time = time.time()

    try:
        await crawler.crawl(
            max_pages=args.max_pages,
            depth=args.depth,
            start_urls=start_urls,
        )

        elapsed_time = time.time() - start_time

        # Print summary
        print_summary(console, crawler, elapsed_time)

        return 0

    except KeyboardInterrupt:
        elapsed_time = time.time() - start_time
        console.print("\n[yellow]Crawl interrupted by user (Ctrl+C)[/yellow]")
        console.print("[yellow]Saving checkpoint...[/yellow]")

        # Crawler already saves checkpoint in its exception handler
        # (presumably — confirm in EyeWikiCrawler); just print summary here
        print_summary(console, crawler, elapsed_time)

        console.print("\n[blue]You can resume with:[/blue]")
        console.print(f"[blue]  python scripts/scrape_eyewiki.py --resume[/blue]")

        return 130  # Standard exit code for SIGINT

    except Exception as e:
        elapsed_time = time.time() - start_time
        console.print(f"\n[bold red]Unexpected error: {e}[/bold red]")

        # Print summary of what was accomplished
        print_summary(console, crawler, elapsed_time)

        return 1
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
    # Run the async entry point and exit with its status code.
    try:
        sys.exit(asyncio.run(main()))
    except Exception as e:
        # Last-resort handler; SystemExit raised by sys.exit() above is not
        # an Exception subclass and therefore passes through untouched.
        Console().print(f"[bold red]Fatal error: {e}[/bold red]")
        sys.exit(1)
|
src/__init__.py
ADDED
|
File without changes
|
src/api/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API module for EyeWiki RAG system."""
|
| 2 |
+
|
| 3 |
+
from src.api.main import app
|
| 4 |
+
|
| 5 |
+
__all__ = ["app"]
|
src/api/gradio_ui.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio UI for EyeWiki RAG system."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List, Dict
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
from src.rag.query_engine import QueryResponse
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ============================================================================
|
| 15 |
+
# Example Questions
|
| 16 |
+
# ============================================================================
|
| 17 |
+
|
| 18 |
+
# Curated example questions surfaced in the UI so users can try the system
# with one click; each string is loaded verbatim into the question textbox.
EXAMPLE_QUESTIONS = [
    "What are the symptoms of glaucoma?",
    "How is diabetic retinopathy treated?",
    "What causes macular degeneration?",
    "What is the difference between open-angle and angle-closure glaucoma?",
    "What are the risk factors for cataracts?",
    "How is retinal detachment diagnosed?",
]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ============================================================================
|
| 29 |
+
# Styling
|
| 30 |
+
# ============================================================================
|
| 31 |
+
|
| 32 |
+
CUSTOM_CSS = """
|
| 33 |
+
/* Main container */
|
| 34 |
+
.gradio-container {
|
| 35 |
+
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
| 36 |
+
max-width: 1400px;
|
| 37 |
+
margin: 0 auto;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
/* Header */
|
| 41 |
+
.header {
|
| 42 |
+
background: linear-gradient(135deg, #1e3a8a 0%, #3b82f6 100%);
|
| 43 |
+
color: white;
|
| 44 |
+
padding: 2rem;
|
| 45 |
+
border-radius: 12px;
|
| 46 |
+
margin-bottom: 2rem;
|
| 47 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
.header h1 {
|
| 51 |
+
margin: 0 0 0.5rem 0;
|
| 52 |
+
font-size: 2rem;
|
| 53 |
+
font-weight: 700;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
.header p {
|
| 57 |
+
margin: 0;
|
| 58 |
+
font-size: 1rem;
|
| 59 |
+
opacity: 0.95;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/* Chat interface */
|
| 63 |
+
.chatbot {
|
| 64 |
+
border: 1px solid #e5e7eb;
|
| 65 |
+
border-radius: 8px;
|
| 66 |
+
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
/* Text input */
|
| 70 |
+
.input-text textarea {
|
| 71 |
+
border: 2px solid #e5e7eb;
|
| 72 |
+
border-radius: 8px;
|
| 73 |
+
font-size: 1rem;
|
| 74 |
+
padding: 0.75rem;
|
| 75 |
+
transition: border-color 0.2s;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
.input-text textarea:focus {
|
| 79 |
+
border-color: #3b82f6;
|
| 80 |
+
outline: none;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/* Buttons */
|
| 84 |
+
.primary-button {
|
| 85 |
+
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
|
| 86 |
+
color: white;
|
| 87 |
+
border: none;
|
| 88 |
+
border-radius: 8px;
|
| 89 |
+
padding: 0.75rem 1.5rem;
|
| 90 |
+
font-weight: 600;
|
| 91 |
+
cursor: pointer;
|
| 92 |
+
transition: transform 0.1s, box-shadow 0.2s;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
.primary-button:hover {
|
| 96 |
+
transform: translateY(-1px);
|
| 97 |
+
box-shadow: 0 4px 8px rgba(59, 130, 246, 0.3);
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.secondary-button {
|
| 101 |
+
background: white;
|
| 102 |
+
color: #374151;
|
| 103 |
+
border: 1px solid #d1d5db;
|
| 104 |
+
border-radius: 8px;
|
| 105 |
+
padding: 0.5rem 1rem;
|
| 106 |
+
font-weight: 500;
|
| 107 |
+
cursor: pointer;
|
| 108 |
+
transition: background 0.2s;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.secondary-button:hover {
|
| 112 |
+
background: #f9fafb;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
/* Sources accordion */
|
| 116 |
+
.accordion {
|
| 117 |
+
border: 1px solid #e5e7eb;
|
| 118 |
+
border-radius: 8px;
|
| 119 |
+
margin-top: 1rem;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
/* Disclaimer */
|
| 123 |
+
.disclaimer {
|
| 124 |
+
background: #fef3c7;
|
| 125 |
+
border-left: 4px solid #f59e0b;
|
| 126 |
+
padding: 1rem;
|
| 127 |
+
border-radius: 8px;
|
| 128 |
+
margin-top: 2rem;
|
| 129 |
+
font-size: 0.875rem;
|
| 130 |
+
line-height: 1.5;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
.disclaimer strong {
|
| 134 |
+
color: #92400e;
|
| 135 |
+
font-weight: 700;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
/* Settings sidebar */
|
| 139 |
+
.settings {
|
| 140 |
+
background: #f9fafb;
|
| 141 |
+
border: 1px solid #e5e7eb;
|
| 142 |
+
border-radius: 8px;
|
| 143 |
+
padding: 1rem;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
/* Example questions */
|
| 147 |
+
.examples {
|
| 148 |
+
background: white;
|
| 149 |
+
border: 1px solid #e5e7eb;
|
| 150 |
+
border-radius: 8px;
|
| 151 |
+
padding: 1rem;
|
| 152 |
+
margin-bottom: 1rem;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
.example-btn {
|
| 156 |
+
display: block;
|
| 157 |
+
width: 100%;
|
| 158 |
+
text-align: left;
|
| 159 |
+
padding: 0.75rem;
|
| 160 |
+
margin-bottom: 0.5rem;
|
| 161 |
+
background: white;
|
| 162 |
+
border: 1px solid #e5e7eb;
|
| 163 |
+
border-radius: 6px;
|
| 164 |
+
cursor: pointer;
|
| 165 |
+
transition: all 0.2s;
|
| 166 |
+
font-size: 0.875rem;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
.example-btn:hover {
|
| 170 |
+
background: #f0f9ff;
|
| 171 |
+
border-color: #3b82f6;
|
| 172 |
+
transform: translateX(4px);
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
/* Confidence indicator */
|
| 176 |
+
.confidence-high {
|
| 177 |
+
color: #059669;
|
| 178 |
+
font-weight: 600;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
.confidence-medium {
|
| 182 |
+
color: #d97706;
|
| 183 |
+
font-weight: 600;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
.confidence-low {
|
| 187 |
+
color: #dc2626;
|
| 188 |
+
font-weight: 600;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
/* Source cards */
|
| 192 |
+
.source-card {
|
| 193 |
+
background: white;
|
| 194 |
+
border: 1px solid #e5e7eb;
|
| 195 |
+
border-radius: 6px;
|
| 196 |
+
padding: 0.75rem;
|
| 197 |
+
margin-bottom: 0.5rem;
|
| 198 |
+
transition: box-shadow 0.2s;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
.source-card:hover {
|
| 202 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
.source-title {
|
| 206 |
+
font-weight: 600;
|
| 207 |
+
color: #1e40af;
|
| 208 |
+
margin-bottom: 0.25rem;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
.source-score {
|
| 212 |
+
font-size: 0.75rem;
|
| 213 |
+
color: #6b7280;
|
| 214 |
+
}
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# ============================================================================
|
| 219 |
+
# Formatting Functions
|
| 220 |
+
# ============================================================================
|
| 221 |
+
|
| 222 |
+
def format_sources_html(response: "QueryResponse", max_sources: int = 5) -> str:
    """
    Format retrieval sources as an HTML block of source cards.

    Args:
        response: Query response whose ``sources`` items expose
            ``title``, ``url``, ``section`` and ``relevance_score``.
        max_sources: Maximum number of sources to display.

    Returns:
        HTML string with one card per source, or a placeholder
        paragraph when no sources are available.
    """
    # Local import keeps the module's import surface unchanged.
    import html as _html

    if not response.sources:
        return "<p style='color: #6b7280; font-style: italic;'>No sources available.</p>"

    html_parts = []

    for i, source in enumerate(response.sources[:max_sources], 1):
        score_pct = int(source.relevance_score * 100)
        # Map relevance to a CSS class defined in CUSTOM_CSS.
        if source.relevance_score >= 0.7:
            score_class = "confidence-high"
        elif source.relevance_score >= 0.5:
            score_class = "confidence-medium"
        else:
            score_class = "confidence-low"

        # Escape article-provided text so stray markup in titles/sections
        # cannot break (or inject into) the rendered page.
        title = _html.escape(str(source.title))
        url = _html.escape(str(source.url), quote=True)
        section = _html.escape(str(source.section)) if source.section else ""

        section_html = (
            f'<div style="font-size: 0.875rem; color: #6b7280; margin-bottom: 0.25rem;">Section: {section}</div>'
            if section
            else ''
        )

        html = f"""
    <div class="source-card">
        <div class="source-title">
            {i}. <a href="{url}" target="_blank" style="text-decoration: none;">
                {title}
            </a>
        </div>
        {section_html}
        <div class="source-score">
            Relevance: <span class="{score_class}">{score_pct}%</span>
        </div>
    </div>
    """
        html_parts.append(html)

    return "\n".join(html_parts)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def format_confidence_text(confidence: float) -> str:
    """
    Render a confidence score as an emoji-tagged label.

    Args:
        confidence: Confidence score in the range 0-1.

    Returns:
        String such as "✅ High Confidence (85%)".
    """
    # Thresholds mirror the CSS confidence classes used for sources.
    if confidence >= 0.7:
        badge = ("✅", "High Confidence")
    elif confidence >= 0.5:
        badge = ("⚠️", "Medium Confidence")
    else:
        badge = ("⚡", "Low Confidence")

    emoji, label = badge
    return f"{emoji} {label} ({int(confidence * 100)}%)"
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ============================================================================
|
| 295 |
+
# Chat Interface Functions
|
| 296 |
+
# ============================================================================
|
| 297 |
+
|
| 298 |
+
def process_question(
    question: str,
    history: List[Dict[str, str]],
    include_sources: bool,
    max_sources: int,
    query_engine_getter,
) -> tuple[List[Dict[str, str]], str]:
    """
    Process a user question and update chat history.

    Args:
        question: User's question.
        history: Chat history (list of message dicts with 'role' and 'content');
            mutated in place and also returned.
        include_sources: Whether to include sources.
        max_sources: Maximum number of sources to show.
        query_engine_getter: Callable that returns the query engine instance,
            or a falsy value while the engine is still initializing.

    Returns:
        Updated history and sources HTML (empty string when no sources).
    """
    # Ignore blank submissions without touching the history.
    if not question or not question.strip():
        return history, ""

    # The engine may still be warming up in a background thread at startup.
    # (Removed leftover debug print of the engine object.)
    query_engine = query_engine_getter()
    if not query_engine:
        error_msg = "System is still initializing. Please wait a moment and try again."
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": error_msg})
        return history, ""

    try:
        # Query the engine
        response = query_engine.query(
            question=question,
            include_sources=include_sources,
        )

        # Prefix the answer with a confidence badge.
        confidence_text = format_confidence_text(response.confidence)
        answer = f"**{confidence_text}**\n\n{response.answer}"

        # Append the disclaimer unless it contains "educational" boilerplate.
        if response.disclaimer and not any(word in response.disclaimer.lower() for word in ['educational', 'education']):
            answer += f"\n\n---\n\n{response.disclaimer}"

        # Update history with message dicts (Gradio "messages" format).
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": answer})

        # Render source cards only when the caller asked for them.
        sources_html = format_sources_html(response, max_sources) if include_sources else ""

        return history, sources_html

    except Exception as e:
        # Surface the failure in-chat instead of crashing the UI callback.
        logger.error(f"Error processing question: {e}", exc_info=True)
        error_msg = f"Sorry, I encountered an error processing your question: {str(e)}"
        history.append({"role": "user", "content": question})
        history.append({"role": "assistant", "content": error_msg})
        return history, ""
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def clear_chat() -> tuple[List, str]:
    """
    Reset the conversation.

    Returns:
        A fresh empty history list and an empty sources-HTML string.
    """
    empty_history: List = []
    return empty_history, ""
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def load_example(example: str) -> str:
    """
    Load an example question.

    Args:
        example: Example question text.

    Returns:
        The example question, unchanged (identity pass-through; presumably
        wired by the UI to copy the value into the question textbox).
    """
    selected = example
    return selected
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# ============================================================================
|
| 386 |
+
# Gradio Interface
|
| 387 |
+
# ============================================================================
|
| 388 |
+
|
| 389 |
+
def create_gradio_interface(query_engine_getter) -> gr.Blocks:
    """
    Create Gradio interface for EyeWiki RAG.

    Args:
        query_engine_getter: Callable that returns the query engine instance.
            Deferred (rather than passing the engine directly) so the UI can
            be built before application startup has finished initializing it.

    Returns:
        Gradio Blocks interface
    """
    with gr.Blocks(
        css=CUSTOM_CSS,
        title="EyeWiki Medical Assistant",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
            neutral_hue="slate",
        ),
    ) as interface:

        # Header
        gr.HTML("""
        <div class="header">
            <h1>🏥 EyeWiki Medical Assistant</h1>
            <p>Ask questions about ophthalmology conditions, treatments, and procedures</p>
        </div>
        """)

        with gr.Row():
            # Main content (left side)
            with gr.Column(scale=3):

                # Chat interface.
                # type="messages" is required because process_question appends
                # {"role": ..., "content": ...} dicts to the history; without it
                # Chatbot expects the legacy (user, bot) tuple format.
                chatbot = gr.Chatbot(
                    label="Conversation",
                    type="messages",
                    height=500,
                    elem_classes=["chatbot"],
                    show_label=False,
                    avatar_images=(None, "🏥"),
                )

                # Input
                with gr.Row():
                    question_input = gr.Textbox(
                        placeholder="Ask a question about eye health...",
                        label="Your Question",
                        lines=2,
                        elem_classes=["input-text"],
                        scale=4,
                    )

                with gr.Row():
                    submit_btn = gr.Button(
                        "Send",
                        variant="primary",
                        elem_classes=["primary-button"],
                        scale=1,
                    )
                    clear_btn = gr.Button(
                        "Clear",
                        elem_classes=["secondary-button"],
                        scale=1,
                    )

                # Sources accordion (populated after each answered question)
                with gr.Accordion("📚 Sources", open=False, elem_classes=["accordion"]):
                    sources_display = gr.HTML(
                        value="<p style='color: #6b7280; font-style: italic;'>Sources will appear here after asking a question.</p>"
                    )

                # Medical disclaimer
                gr.HTML("""
                <div class="disclaimer">
                    <strong>⚠️ Medical Disclaimer:</strong> This information is sourced from EyeWiki,
                    a resource of the American Academy of Ophthalmology (AAO). It is not a substitute
                    for professional medical advice, diagnosis, or treatment. AI systems can make errors.
                    Always consult with a qualified ophthalmologist or eye care professional for medical
                    concerns and verify any critical information with authoritative sources.
                </div>
                """)

            # Sidebar (right side)
            with gr.Column(scale=1, elem_classes=["settings"]):

                gr.Markdown("### ⚙️ Settings")

                include_sources = gr.Checkbox(
                    label="Show sources",
                    value=True,
                    info="Include source citations in responses"
                )

                max_sources = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Max sources",
                    info="Maximum number of sources to display"
                )

                gr.Markdown("---")
                gr.Markdown("### 💡 Example Questions")

                # Example buttons (one per canned question)
                example_buttons = []
                for example in EXAMPLE_QUESTIONS:
                    btn = gr.Button(
                        example,
                        elem_classes=["example-btn"],
                        size="sm",
                    )
                    example_buttons.append(btn)

                gr.Markdown("---")
                gr.Markdown("""
                ### 📖 About

                **EyeWiki RAG System** - Powered by:
                - Hybrid retrieval (semantic + keyword search)
                - Cross-encoder reranking for precision
                - Local LLM inference (GPU-accelerated)
                - EyeWiki knowledge base (AAO)

                All processing happens locally on your machine.
                """)

        # Event handlers: run the query, then clear the input box.
        submit_btn.click(
            fn=lambda q, h, inc, max_s: process_question(q, h, inc, max_s, query_engine_getter),
            inputs=[question_input, chatbot, include_sources, max_sources],
            outputs=[chatbot, sources_display],
        ).then(
            fn=lambda: "",
            outputs=[question_input],
        )

        # Pressing Enter in the textbox behaves like the Send button.
        question_input.submit(
            fn=lambda q, h, inc, max_s: process_question(q, h, inc, max_s, query_engine_getter),
            inputs=[question_input, chatbot, include_sources, max_sources],
            outputs=[chatbot, sources_display],
        ).then(
            fn=lambda: "",
            outputs=[question_input],
        )

        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, sources_display],
        )

        # Example button handlers: clicking copies the example into the input.
        for btn in example_buttons:
            btn.click(
                fn=load_example,
                inputs=[btn],
                outputs=[question_input],
            )

    return interface
|
src/api/main.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application for EyeWiki RAG system."""
|
| 2 |
+
|
| 3 |
+
import logging
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import gradio as gr
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from config.settings import LLMProvider, Settings
from src.api.gradio_ui import create_gradio_interface
from src.llm.llm_client import LLMClient
from src.llm.ollama_client import OllamaClient
from src.llm.openai_client import OpenAIClient
from src.llm.sentence_transformer_client import SentenceTransformerClient
from src.rag.query_engine import EyeWikiQueryEngine, QueryResponse
from src.rag.reranker import CrossEncoderReranker
from src.rag.retriever import HybridRetriever
from src.vectorstore.qdrant_store import QdrantStoreManager
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Configure logging for the whole process (timestamped, module-qualified lines).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger used by the lifespan handler, middleware, and endpoints.
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ============================================================================
|
| 36 |
+
# Request/Response Models
|
| 37 |
+
# ============================================================================
|
| 38 |
+
|
| 39 |
+
class QueryRequest(BaseModel):
    """
    Request model for the POST /query endpoint.

    Attributes:
        question: User's question (minimum 3 characters, enforced by Field)
        include_sources: Whether to include source information in the response
        filters: Optional metadata filters (disease_name, icd_codes, etc.)
    """
    question: str = Field(..., min_length=3, description="User's question")
    include_sources: bool = Field(default=True, description="Include source documents")
    filters: Optional[dict] = Field(default=None, description="Metadata filters")
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class StreamQueryRequest(BaseModel):
    """
    Request model for the POST /query/stream (SSE) endpoint.

    Attributes:
        question: User's question (minimum 3 characters, enforced by Field)
        filters: Optional metadata filters
    """
    question: str = Field(..., min_length=3, description="User's question")
    filters: Optional[dict] = Field(default=None, description="Metadata filters")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class HealthResponse(BaseModel):
    """
    Response model for the GET /health check.

    Attributes:
        status: Overall status ("healthy", "degraded", or "unhealthy";
            see health_check() for how it is derived)
        llm: LLM service status dict (provider, model, detail)
        qdrant: Qdrant service status dict (collection, vectors_count, detail)
        query_engine: Query engine initialization status dict
        timestamp: Unix timestamp of the check
    """
    status: str = Field(..., description="Overall status")
    llm: dict = Field(..., description="LLM service status")
    qdrant: dict = Field(..., description="Qdrant service status")
    query_engine: dict = Field(..., description="Query engine status")
    timestamp: float = Field(..., description="Unix timestamp")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class StatsResponse(BaseModel):
    """
    Response model for the GET /stats endpoint.

    Attributes:
        collection_info: Qdrant collection information
        pipeline_config: Query engine pipeline configuration
        documents_indexed: Number of indexed vectors in the collection
        timestamp: Unix timestamp when the stats were gathered
    """
    collection_info: dict = Field(..., description="Collection information")
    pipeline_config: dict = Field(..., description="Pipeline configuration")
    documents_indexed: int = Field(..., description="Number of indexed documents")
    timestamp: float = Field(..., description="Unix timestamp")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ErrorResponse(BaseModel):
    """
    Error response model (documented shape of error payloads).

    Attributes:
        error: Error message
        detail: Optional detailed error information
        timestamp: Unix timestamp of the error
    """
    error: str = Field(..., description="Error message")
    detail: Optional[str] = Field(default=None, description="Error details")
    timestamp: float = Field(..., description="Unix timestamp")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ============================================================================
|
| 114 |
+
# Global State
|
| 115 |
+
# ============================================================================
|
| 116 |
+
|
| 117 |
+
class AppState:
    """Application state container.

    Holds the RAG stack singletons that lifespan() builds during startup
    (settings, LLM/embedding clients, Qdrant manager, retriever, reranker,
    query engine). `initialized` / `initialization_error` record whether
    startup succeeded so endpoints can return 503 instead of crashing.
    """

    def __init__(self):
        # Every component starts unset; lifespan() populates them in order.
        self.settings: Optional[Settings] = None
        self.llm_client: Optional[LLMClient] = None
        self.embedding_client: Optional[SentenceTransformerClient] = None
        self.qdrant_manager: Optional[QdrantStoreManager] = None
        self.retriever: Optional[HybridRetriever] = None
        self.reranker: Optional[CrossEncoderReranker] = None
        self.query_engine: Optional[EyeWikiQueryEngine] = None
        self.initialized: bool = False
        self.initialization_error: Optional[str] = None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# Process-wide singleton; shared by the lifespan handler, all endpoints,
# and the Gradio UI (via the query_engine_getter lambda at the bottom).
app_state = AppState()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ============================================================================
|
| 136 |
+
# Lifecycle Management
|
| 137 |
+
# ============================================================================
|
| 138 |
+
|
| 139 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.

    Handles startup and shutdown events. On startup the full RAG stack is
    built into the module-level `app_state`, in dependency order:
    settings -> LLM client -> embedding client -> Qdrant -> retriever ->
    reranker -> prompt files -> query engine. Any failure is recorded in
    `app_state.initialization_error` instead of being raised, so the server
    still starts and endpoints can report 503 via check_initialization().
    """
    # Startup
    logger.info("Starting EyeWiki RAG API...")

    try:
        # Load settings
        logger.info("Loading settings...")
        app_state.settings = Settings()

        # Initialize LLM client based on provider
        logger.info(f"Initializing LLM client (provider: {app_state.settings.llm_provider.value})...")
        if app_state.settings.llm_provider == LLMProvider.OPENAI:
            app_state.llm_client = OpenAIClient(
                api_key=app_state.settings.openai_api_key,
                base_url=app_state.settings.openai_base_url,
                model=app_state.settings.openai_model,
            )
        else:
            app_state.llm_client = OllamaClient(
                base_url=app_state.settings.ollama_base_url,
                embedding_model=None,  # We use SentenceTransformerClient for embeddings
                llm_model=app_state.settings.llm_model,
                timeout=app_state.settings.ollama_timeout,
            )

        # Initialize embedding client (sentence-transformers for stable embeddings)
        logger.info("Initializing embedding client...")
        app_state.embedding_client = SentenceTransformerClient(
            model_name=app_state.settings.embedding_model,
        )
        logger.info(f"Embedding model loaded: {app_state.settings.embedding_model}")

        # Initialize Qdrant manager (embedding dim must match the loaded model)
        logger.info("Initializing Qdrant manager...")
        app_state.qdrant_manager = QdrantStoreManager(
            collection_name=app_state.settings.qdrant_collection_name,
            path=app_state.settings.qdrant_path,
            embedding_dim=app_state.embedding_client.embedding_dim,
        )

        # Verify collection exists; the index must be built offline first.
        collection_info = app_state.qdrant_manager.get_collection_info()
        if not collection_info:
            raise RuntimeError(
                f"Qdrant collection '{app_state.settings.qdrant_collection_name}' not found. "
                "Please run 'python scripts/build_index.py --index-vectors' first."
            )

        logger.info(
            f"Qdrant collection loaded: {collection_info['vectors_count']} vectors"
        )

        # Initialize retriever
        logger.info("Initializing retriever...")
        app_state.retriever = HybridRetriever(
            qdrant_manager=app_state.qdrant_manager,
            embedding_client=app_state.embedding_client,
        )

        # Initialize reranker
        logger.info("Initializing reranker...")
        app_state.reranker = CrossEncoderReranker(
            model_name=app_state.settings.reranker_model,
        )

        # Load prompt files relative to the repository root (src/api/main.py
        # -> three parents up).
        project_root = Path(__file__).parent.parent.parent
        prompts_dir = project_root / "prompts"

        system_prompt_path = prompts_dir / "system_prompt.txt"
        query_prompt_path = prompts_dir / "query_prompt.txt"
        disclaimer_path = prompts_dir / "medical_disclaimer.txt"

        # Verify prompts exist; a missing prompt is passed as None
        # (presumably the engine substitutes a default - TODO confirm).
        if not system_prompt_path.exists():
            logger.warning(f"System prompt not found: {system_prompt_path}")
            system_prompt_path = None

        if not query_prompt_path.exists():
            logger.warning(f"Query prompt not found: {query_prompt_path}")
            query_prompt_path = None

        if not disclaimer_path.exists():
            logger.warning(f"Disclaimer not found: {disclaimer_path}")
            disclaimer_path = None

        # Initialize query engine
        logger.info("Initializing query engine...")
        app_state.query_engine = EyeWikiQueryEngine(
            retriever=app_state.retriever,
            reranker=app_state.reranker,
            llm_client=app_state.llm_client,
            system_prompt_path=system_prompt_path,
            query_prompt_path=query_prompt_path,
            disclaimer_path=disclaimer_path,
            max_context_tokens=app_state.settings.max_context_tokens,
            retrieval_k=20,  # candidates fetched by hybrid retrieval
            rerank_k=5,      # candidates kept after cross-encoder reranking
        )

        app_state.initialized = True
        logger.info("EyeWiki RAG API started successfully")
        logger.info("Gradio UI available at /ui")

    except Exception as e:
        error_msg = f"Failed to initialize application: {e}"
        logger.error(error_msg, exc_info=True)
        app_state.initialization_error = error_msg
        # Don't raise - allow app to start but endpoints will return errors

    yield

    # Shutdown
    logger.info("Shutting down EyeWiki RAG API...")

    # Cleanup Qdrant client (best-effort; a failed close is only logged)
    if app_state.qdrant_manager:
        try:
            app_state.qdrant_manager.close()
            logger.info("Qdrant client closed")
        except Exception as e:
            logger.error(f"Error closing Qdrant client: {e}")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# ============================================================================
|
| 270 |
+
# FastAPI App
|
| 271 |
+
# ============================================================================
|
| 272 |
+
|
| 273 |
+
# FastAPI application; the lifespan context (above) wires startup/shutdown.
app = FastAPI(
    title="EyeWiki RAG API",
    description="Retrieval-Augmented Generation API for EyeWiki medical knowledge base",
    version="1.0.0",
    lifespan=lifespan,
)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ============================================================================
|
| 282 |
+
# Middleware
|
| 283 |
+
# ============================================================================
|
| 284 |
+
|
| 285 |
+
# CORS middleware for local development
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - restrict allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@app.middleware("http")
|
| 296 |
+
async def log_requests(request: Request, call_next):
|
| 297 |
+
"""
|
| 298 |
+
Request logging middleware.
|
| 299 |
+
|
| 300 |
+
Logs all incoming requests with timing information.
|
| 301 |
+
"""
|
| 302 |
+
start_time = time.time()
|
| 303 |
+
|
| 304 |
+
# Log request
|
| 305 |
+
logger.info(
|
| 306 |
+
f"Request: {request.method} {request.url.path} "
|
| 307 |
+
f"from {request.client.host if request.client else 'unknown'}"
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Process request
|
| 311 |
+
response = await call_next(request)
|
| 312 |
+
|
| 313 |
+
# Log response
|
| 314 |
+
duration = time.time() - start_time
|
| 315 |
+
logger.info(
|
| 316 |
+
f"Response: {response.status_code} "
|
| 317 |
+
f"in {duration:.3f}s"
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
return response
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# ============================================================================
|
| 324 |
+
# Helper Functions
|
| 325 |
+
# ============================================================================
|
| 326 |
+
|
| 327 |
+
def check_initialization():
    """
    Check if application is initialized.

    Raises:
        HTTPException: 503 when startup has not completed successfully,
            carrying the recorded initialization error if one exists.
    """
    if app_state.initialized:
        return
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail=app_state.initialization_error or "Application not initialized",
    )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# ============================================================================
|
| 343 |
+
# Endpoints
|
| 344 |
+
# ============================================================================
|
| 345 |
+
|
| 346 |
+
@app.get("/")
|
| 347 |
+
async def root():
|
| 348 |
+
"""
|
| 349 |
+
Root endpoint.
|
| 350 |
+
|
| 351 |
+
Returns:
|
| 352 |
+
Welcome message with API information
|
| 353 |
+
"""
|
| 354 |
+
return {
|
| 355 |
+
"name": "EyeWiki RAG API",
|
| 356 |
+
"version": "1.0.0",
|
| 357 |
+
"description": "Retrieval-Augmented Generation API for EyeWiki medical knowledge base",
|
| 358 |
+
"endpoints": {
|
| 359 |
+
"health": "GET /health",
|
| 360 |
+
"query": "POST /query",
|
| 361 |
+
"stream": "POST /query/stream",
|
| 362 |
+
"stats": "GET /stats",
|
| 363 |
+
"docs": "GET /docs",
|
| 364 |
+
}
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@app.get("/health", response_model=HealthResponse)
|
| 369 |
+
async def health_check():
|
| 370 |
+
"""
|
| 371 |
+
Health check endpoint.
|
| 372 |
+
|
| 373 |
+
Checks status of:
|
| 374 |
+
- Ollama service
|
| 375 |
+
- Qdrant service
|
| 376 |
+
- Query engine initialization
|
| 377 |
+
|
| 378 |
+
Returns:
|
| 379 |
+
HealthResponse with service statuses
|
| 380 |
+
"""
|
| 381 |
+
timestamp = time.time()
|
| 382 |
+
|
| 383 |
+
# Check LLM provider
|
| 384 |
+
llm_status = {"status": "unknown", "detail": None}
|
| 385 |
+
if app_state.llm_client:
|
| 386 |
+
provider = app_state.settings.llm_provider.value if app_state.settings else "unknown"
|
| 387 |
+
llm_status["provider"] = provider
|
| 388 |
+
try:
|
| 389 |
+
if isinstance(app_state.llm_client, OllamaClient):
|
| 390 |
+
health_ok = app_state.llm_client.check_health()
|
| 391 |
+
llm_status["status"] = "healthy" if health_ok else "unhealthy"
|
| 392 |
+
llm_status["model"] = app_state.llm_client.llm_model
|
| 393 |
+
else:
|
| 394 |
+
# For OpenAI-compatible clients, assume healthy if initialized
|
| 395 |
+
llm_status["status"] = "healthy"
|
| 396 |
+
llm_status["model"] = app_state.llm_client.llm_model
|
| 397 |
+
except Exception as e:
|
| 398 |
+
llm_status = {"status": "unhealthy", "detail": str(e), "provider": provider}
|
| 399 |
+
else:
|
| 400 |
+
llm_status = {"status": "not_initialized", "detail": "Client not created"}
|
| 401 |
+
|
| 402 |
+
# Check Qdrant
|
| 403 |
+
qdrant_status = {"status": "unknown", "detail": None}
|
| 404 |
+
if app_state.qdrant_manager:
|
| 405 |
+
try:
|
| 406 |
+
info = app_state.qdrant_manager.get_collection_info()
|
| 407 |
+
if info:
|
| 408 |
+
qdrant_status = {
|
| 409 |
+
"status": "healthy",
|
| 410 |
+
"collection": info["name"],
|
| 411 |
+
"vectors_count": info["vectors_count"],
|
| 412 |
+
}
|
| 413 |
+
else:
|
| 414 |
+
qdrant_status = {
|
| 415 |
+
"status": "unhealthy",
|
| 416 |
+
"detail": "Collection not found"
|
| 417 |
+
}
|
| 418 |
+
except Exception as e:
|
| 419 |
+
qdrant_status = {"status": "unhealthy", "detail": str(e)}
|
| 420 |
+
else:
|
| 421 |
+
qdrant_status = {"status": "not_initialized", "detail": "Manager not created"}
|
| 422 |
+
|
| 423 |
+
# Check query engine
|
| 424 |
+
query_engine_status = {
|
| 425 |
+
"status": "initialized" if app_state.initialized else "not_initialized",
|
| 426 |
+
"error": app_state.initialization_error,
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
# Overall status
|
| 430 |
+
overall_status = "healthy"
|
| 431 |
+
if not app_state.initialized:
|
| 432 |
+
overall_status = "unhealthy"
|
| 433 |
+
elif llm_status["status"] != "healthy" or qdrant_status["status"] != "healthy":
|
| 434 |
+
overall_status = "degraded"
|
| 435 |
+
|
| 436 |
+
return HealthResponse(
|
| 437 |
+
status=overall_status,
|
| 438 |
+
llm=llm_status,
|
| 439 |
+
qdrant=qdrant_status,
|
| 440 |
+
query_engine=query_engine_status,
|
| 441 |
+
timestamp=timestamp,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
@app.post("/query", response_model=QueryResponse)
|
| 446 |
+
async def query(request: QueryRequest):
|
| 447 |
+
"""
|
| 448 |
+
Main query endpoint.
|
| 449 |
+
|
| 450 |
+
Processes a question using the full RAG pipeline:
|
| 451 |
+
1. Retrieval (hybrid search)
|
| 452 |
+
2. Reranking (cross-encoder)
|
| 453 |
+
3. Context assembly
|
| 454 |
+
4. LLM generation
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
request: QueryRequest with question and options
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
QueryResponse with answer, sources, and disclaimer
|
| 461 |
+
|
| 462 |
+
Raises:
|
| 463 |
+
HTTPException: If service unavailable or query fails
|
| 464 |
+
"""
|
| 465 |
+
check_initialization()
|
| 466 |
+
|
| 467 |
+
try:
|
| 468 |
+
logger.info(f"Processing query: '{request.question}'")
|
| 469 |
+
|
| 470 |
+
response = app_state.query_engine.query(
|
| 471 |
+
question=request.question,
|
| 472 |
+
include_sources=request.include_sources,
|
| 473 |
+
filters=request.filters,
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
logger.info(
|
| 477 |
+
f"Query complete: {len(response.sources)} sources, "
|
| 478 |
+
f"confidence: {response.confidence:.2f}"
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
return response
|
| 482 |
+
|
| 483 |
+
except Exception as e:
|
| 484 |
+
logger.error(f"Error processing query: {e}", exc_info=True)
|
| 485 |
+
raise HTTPException(
|
| 486 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 487 |
+
detail=f"Error processing query: {str(e)}"
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
@app.post("/query/stream")
|
| 492 |
+
async def stream_query(request: StreamQueryRequest):
|
| 493 |
+
"""
|
| 494 |
+
Streaming query endpoint.
|
| 495 |
+
|
| 496 |
+
Returns answer as Server-Sent Events (SSE) for real-time streaming.
|
| 497 |
+
|
| 498 |
+
Args:
|
| 499 |
+
request: StreamQueryRequest with question and options
|
| 500 |
+
|
| 501 |
+
Returns:
|
| 502 |
+
StreamingResponse with SSE
|
| 503 |
+
|
| 504 |
+
Raises:
|
| 505 |
+
HTTPException: If service unavailable or query fails
|
| 506 |
+
"""
|
| 507 |
+
check_initialization()
|
| 508 |
+
|
| 509 |
+
async def generate():
|
| 510 |
+
"""Generate SSE stream."""
|
| 511 |
+
try:
|
| 512 |
+
logger.info(f"Processing streaming query: '{request.question}'")
|
| 513 |
+
|
| 514 |
+
# Stream answer chunks
|
| 515 |
+
for chunk in app_state.query_engine.stream_query(
|
| 516 |
+
question=request.question,
|
| 517 |
+
filters=request.filters,
|
| 518 |
+
):
|
| 519 |
+
# SSE format: data: <content>\n\n
|
| 520 |
+
yield f"data: {chunk}\n\n"
|
| 521 |
+
|
| 522 |
+
logger.info("Streaming query complete")
|
| 523 |
+
|
| 524 |
+
except Exception as e:
|
| 525 |
+
logger.error(f"Error in streaming query: {e}", exc_info=True)
|
| 526 |
+
yield f"data: [ERROR] {str(e)}\n\n"
|
| 527 |
+
|
| 528 |
+
return StreamingResponse(
|
| 529 |
+
generate(),
|
| 530 |
+
media_type="text/event-stream",
|
| 531 |
+
headers={
|
| 532 |
+
"Cache-Control": "no-cache",
|
| 533 |
+
"Connection": "keep-alive",
|
| 534 |
+
}
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
@app.get("/stats", response_model=StatsResponse)
|
| 539 |
+
async def get_stats():
|
| 540 |
+
"""
|
| 541 |
+
Get index and pipeline statistics.
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
StatsResponse with collection info and pipeline config
|
| 545 |
+
|
| 546 |
+
Raises:
|
| 547 |
+
HTTPException: If service unavailable or stats retrieval fails
|
| 548 |
+
"""
|
| 549 |
+
check_initialization()
|
| 550 |
+
|
| 551 |
+
try:
|
| 552 |
+
# Get collection info
|
| 553 |
+
collection_info = app_state.qdrant_manager.get_collection_info()
|
| 554 |
+
if not collection_info:
|
| 555 |
+
raise HTTPException(
|
| 556 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 557 |
+
detail="Collection not found"
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
# Get pipeline config
|
| 561 |
+
pipeline_config = app_state.query_engine.get_pipeline_info()
|
| 562 |
+
|
| 563 |
+
return StatsResponse(
|
| 564 |
+
collection_info=collection_info,
|
| 565 |
+
pipeline_config=pipeline_config,
|
| 566 |
+
documents_indexed=collection_info.get("vectors_count", 0),
|
| 567 |
+
timestamp=time.time(),
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
except HTTPException:
|
| 571 |
+
raise
|
| 572 |
+
except Exception as e:
|
| 573 |
+
logger.error(f"Error retrieving stats: {e}", exc_info=True)
|
| 574 |
+
raise HTTPException(
|
| 575 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 576 |
+
detail=f"Error retrieving stats: {str(e)}"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# ============================================================================
|
| 581 |
+
# Error Handlers
|
| 582 |
+
# ============================================================================
|
| 583 |
+
|
| 584 |
+
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Handle HTTP exceptions.

    Starlette/FastAPI exception handlers must return a Response object;
    returning a bare dict fails at runtime when Starlette tries to invoke
    it as an ASGI response. Wrap the payload in a JSONResponse that carries
    the exception's own status code.

    Returns:
        JSONResponse error payload with the original status code
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": time.time(),
        },
    )
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """
    Handle general (unhandled) exceptions.

    Starlette/FastAPI requires exception handlers to return a Response
    object; the previous implementation returned a bare dict, which
    crashes the ASGI pipeline when the framework tries to send it.
    The payload is therefore wrapped in a JSONResponse with status 500.

    Returns:
        JSONResponse error payload with 500 status
    """
    logger.error(f"Unhandled exception: {exc}", exc_info=True)

    # Local import so the top-of-file import block stays untouched.
    from fastapi.responses import JSONResponse

    # NOTE(review): str(exc) may leak internal details to API clients —
    # consider a generic message in production. Kept for parity with the
    # original response body.
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "error": "Internal server error",
            "detail": str(exc),
            "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            "timestamp": time.time(),
        },
    )
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
# ============================================================================
# Mount Gradio UI
# ============================================================================

# Create and mount Gradio interface.
# The query engine is passed as a getter (lambda) rather than a direct
# reference: at import time app_state.query_engine is typically still None,
# and the late binding lets Gradio see the engine once startup completes.
gradio_interface = create_gradio_interface(
    query_engine_getter=lambda: app_state.query_engine
)
# mount_gradio_app returns the combined ASGI app; rebinding `app` makes the
# UI available under /ui alongside the existing API routes.
app = gr.mount_gradio_app(app, gradio_interface, path="/ui")
logger.info("Gradio UI mounted at /ui")
|
src/llm/__init__.py
ADDED
|
File without changes
|
src/llm/llm_client.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base class for LLM clients."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Generator, List, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LLMClient(ABC):
    """
    Common interface that every LLM backend must implement.

    Concrete providers (Ollama, OpenAI-compatible endpoints, ...) satisfy
    this contract so the RAG pipeline can swap them interchangeably.

    Implementations must additionally expose a ``llm_model`` attribute
    (str) identifying the model in use.
    """

    llm_model: str

    @abstractmethod
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """
        Produce a complete response for *prompt* in a single call.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            stop: Stop sequences

        Returns:
            Generated text
        """
        ...

    @abstractmethod
    def stream_generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Produce a response for *prompt* incrementally.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            stop: Stop sequences

        Yields:
            Generated text chunks
        """
        ...
|
src/llm/ollama_client.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Ollama client for embeddings and LLM inference."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
from typing import Generator, List, Optional
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
from rich.console import Console
|
| 9 |
+
|
| 10 |
+
from config.settings import settings
|
| 11 |
+
from src.llm.llm_client import LLMClient
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Configure logging
# NOTE(review): basicConfig() runs as a side effect of importing this module
# and configures the root logger globally — consider moving this to the
# application entrypoint.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class OllamaConnectionError(Exception):
    """Raised when cannot connect to Ollama."""
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class OllamaModelNotFoundError(Exception):
    """Raised when requested model is not available."""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class OllamaClient(LLMClient):
    """
    Client for interacting with Ollama for embeddings and LLM inference.

    Features:
    - Embedding generation (single and batch)
    - LLM text generation (streaming and non-streaming)
    - Health checks
    - Automatic retry with exponential backoff
    - Model verification
    """

    # Dimension of the fallback zero vector returned when embedding fails
    # or the input is empty. NOTE(review): assumes a 768-dim embedding
    # model — confirm against the configured embedding model.
    FALLBACK_EMBED_DIM = 768

    def __init__(
        self,
        base_url: Optional[str] = None,
        embedding_model: Optional[str] = None,
        llm_model: Optional[str] = None,
        timeout: int = 30,
        max_retries: int = 3,
    ):
        """
        Initialize Ollama client.

        Args:
            base_url: Ollama API base URL (default: from settings)
            embedding_model: Embedding model name (None to skip, or from settings if not provided)
            llm_model: LLM model name (default: from settings)
            timeout: Request timeout in seconds
            max_retries: Maximum number of retries for failed requests

        Raises:
            OllamaConnectionError: If the Ollama server is unreachable
            OllamaModelNotFoundError: If a configured model is not pulled
        """
        self.base_url = (base_url or settings.ollama_base_url).rstrip("/")
        self.embedding_model = embedding_model
        self.llm_model = llm_model or settings.llm_model
        self.timeout = timeout
        self.max_retries = max_retries

        self.console = Console()

        # Fail fast: verify connectivity and model availability at startup.
        self._initialize()

    def _initialize(self):
        """Verify server connectivity and that the configured models are pulled."""
        # Check if Ollama is running
        if not self.check_health():
            error_msg = (
                f"Cannot connect to Ollama at {self.base_url}. "
                "Please ensure Ollama is running."
            )
            logger.error(error_msg)
            raise OllamaConnectionError(error_msg)

        self.console.print(f"[green][/green] Connected to Ollama at {self.base_url}")

        # Verify embedding model (only if specified)
        if self.embedding_model and not self._check_model_exists(self.embedding_model):
            error_msg = (
                f"Embedding model '{self.embedding_model}' not found. "
                f"Please pull it with: ollama pull {self.embedding_model}"
            )
            logger.error(error_msg)
            raise OllamaModelNotFoundError(error_msg)

        # Get and log embedding model info
        if self.embedding_model:
            embed_info = self._get_model_info(self.embedding_model)
            if embed_info:
                self.console.print(
                    f"[green][/green] Embedding model: {self.embedding_model}"
                )
                logger.info(f"Embedding model info: {embed_info}")

        # Verify LLM model (always required)
        if not self._check_model_exists(self.llm_model):
            error_msg = (
                f"LLM model '{self.llm_model}' not found. "
                f"Please pull it with: ollama pull {self.llm_model}"
            )
            logger.error(error_msg)
            raise OllamaModelNotFoundError(error_msg)

        # Get and log LLM model info
        llm_info = self._get_model_info(self.llm_model)
        if llm_info:
            self.console.print(f"[green][/green] LLM model: {self.llm_model}")
            logger.info(f"LLM model info: {llm_info}")

    def check_health(self) -> bool:
        """
        Check if Ollama server is running and reachable.

        Returns:
            True if server is healthy, False otherwise
        """
        try:
            response = requests.get(
                f"{self.base_url}/api/tags", timeout=self.timeout
            )
            return response.status_code == 200
        except requests.exceptions.RequestException as e:
            logger.warning(f"Health check failed: {e}")
            return False

    def _check_model_exists(self, model_name: str) -> bool:
        """
        Check if a model exists in Ollama.

        Args:
            model_name: Name of the model to check

        Returns:
            True if model exists, False otherwise
        """
        try:
            response = requests.get(
                f"{self.base_url}/api/tags", timeout=self.timeout
            )
            if response.status_code == 200:
                data = response.json()
                models = [m["name"] for m in data.get("models", [])]
                # Accept an exact match, the implicit ":latest" tag, or any
                # explicitly tagged variant of the requested model.
                return (
                    model_name in models
                    or f"{model_name}:latest" in models
                    or any(m.startswith(f"{model_name}:") for m in models)
                )
        except requests.exceptions.RequestException as e:
            logger.error(f"Error checking model existence: {e}")

        return False

    def _get_model_info(self, model_name: str) -> Optional[dict]:
        """
        Get information about a model.

        Args:
            model_name: Name of the model

        Returns:
            Dictionary with model information or None
        """
        try:
            response = requests.post(
                f"{self.base_url}/api/show",
                json={"name": model_name},
                timeout=self.timeout,
            )
            if response.status_code == 200:
                return response.json()
        except requests.exceptions.RequestException as e:
            logger.warning(f"Could not get model info: {e}")

        return None

    def _retry_with_backoff(self, func, *args, **kwargs):
        """
        Retry a function with exponential backoff (1s, 2s, 4s, ...).

        Args:
            func: Function to retry
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Function result

        Raises:
            Last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self.max_retries):
            try:
                return func(*args, **kwargs)
            except requests.exceptions.RequestException as e:
                last_exception = e
                if attempt < self.max_retries - 1:
                    wait_time = 2**attempt
                    logger.warning(
                        f"Request failed (attempt {attempt + 1}/{self.max_retries}), "
                        f"retrying in {wait_time}s: {e}"
                    )
                    time.sleep(wait_time)
                else:
                    logger.error(f"All {self.max_retries} attempts failed")

        raise last_exception

    def embed_text(self, text: str, return_zero_on_failure: bool = False, max_chars: int = 2000) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Input text to embed
            return_zero_on_failure: If True, return zero vector instead of raising exception
            max_chars: Maximum characters to send to Ollama (default: 2000, safe limit for WSL2)

        Returns:
            Embedding vector as list of floats

        Raises:
            OllamaConnectionError: If request fails after retries (unless return_zero_on_failure=True)
        """
        # Handle empty text without a round-trip to the server
        if not text or not text.strip():
            logger.warning("Empty text provided for embedding, returning zero vector")
            return [0.0] * self.FALLBACK_EMBED_DIM

        # Truncate if too long to prevent context overflow
        original_length = len(text)
        if len(text) > max_chars:
            text = text[:max_chars]
            logger.debug(f"Truncated text from {original_length} to {max_chars} chars for embedding")

        def _embed():
            response = requests.post(
                f"{self.base_url}/api/embed",  # Correct endpoint for Ollama 0.13.2+
                json={"model": self.embedding_model, "input": text},  # Use 'input' not 'prompt'
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()
            # /api/embed returns an "embeddings" array (take the first entry);
            # older servers may return a single "embedding" field instead.
            return data["embeddings"][0] if "embeddings" in data else data["embedding"]

        try:
            return self._retry_with_backoff(_embed)
        except requests.exceptions.RequestException as e:
            if return_zero_on_failure:
                logger.warning(f"Failed to generate embedding (text length: {len(text)}), returning zero vector: {e}")
                return [0.0] * self.FALLBACK_EMBED_DIM
            else:
                logger.error(f"Failed to generate embedding: {e}")
                raise OllamaConnectionError(f"Embedding generation failed: {e}")

    def embed_batch(
        self, texts: List[str], batch_size: int = 1, show_progress: bool = True
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts sequentially.

        Note: batch_size parameter is kept for API compatibility but is ignored.
        Processing is always sequential to avoid overwhelming local Ollama instance.

        Args:
            texts: List of input texts
            batch_size: Ignored (kept for compatibility)
            show_progress: Show progress bar

        Returns:
            List of embedding vectors (failed items become zero vectors)
        """
        embeddings = []
        failed_count = 0
        zero_vector = [0.0] * self.FALLBACK_EMBED_DIM

        if show_progress:
            from tqdm import tqdm
            pbar = tqdm(total=len(texts), desc="Generating embeddings", unit="chunk")

        for i, text in enumerate(texts):
            try:
                if i > 0:
                    # Brief pause between requests so a local Ollama instance
                    # is not overwhelmed.
                    time.sleep(0.5)

                # return_zero_on_failure keeps a single failure from aborting
                # the whole batch.
                embedding = self.embed_text(text, return_zero_on_failure=True)
                embeddings.append(embedding)

                # A zero vector signals that embed_text swallowed a failure
                if embedding == zero_vector:
                    failed_count += 1

            except Exception as e:
                logger.error(f"Unexpected error embedding text {i}: {e}")
                # Fallback to zero vector (copy, so entries stay independent)
                embeddings.append(list(zero_vector))
                failed_count += 1

            if show_progress:
                pbar.update(1)

        if show_progress:
            pbar.close()

        success_count = len(texts) - failed_count
        logger.info(f"Generated {success_count}/{len(texts)} embeddings successfully")

        if failed_count > 0:
            logger.warning(f"{failed_count} chunks failed and were assigned zero vectors")

        return embeddings

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """
        Generate text using LLM (non-streaming).

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from settings)
            max_tokens: Maximum tokens to generate (default: from settings)
            stop: Stop sequences

        Returns:
            Generated text

        Raises:
            OllamaConnectionError: If generation fails
        """
        temperature = temperature if temperature is not None else settings.llm_temperature
        max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens

        def _generate():
            payload = {
                "model": self.llm_model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": temperature,
                    "num_predict": max_tokens,
                },
            }

            if system_prompt:
                payload["system"] = system_prompt

            if stop:
                payload["options"]["stop"] = stop

            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=self.timeout * 2,  # Longer timeout for generation
            )
            response.raise_for_status()
            data = response.json()
            return data["response"]

        try:
            return self._retry_with_backoff(_generate)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to generate text: {e}")
            raise OllamaConnectionError(f"Text generation failed: {e}")

    def stream_generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Generate text using LLM with streaming.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from settings)
            max_tokens: Maximum tokens to generate (default: from settings)
            stop: Stop sequences

        Yields:
            Generated text chunks

        Raises:
            OllamaConnectionError: If generation fails
        """
        # Hoisted out of the streaming loop (was re-imported per line).
        import json

        temperature = temperature if temperature is not None else settings.llm_temperature
        max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens

        payload = {
            "model": self.llm_model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
            },
        }

        if system_prompt:
            payload["system"] = system_prompt

        if stop:
            payload["options"]["stop"] = stop

        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                stream=True,
                timeout=self.timeout * 2,
            )
            response.raise_for_status()

            # Each line of the stream is a standalone JSON object
            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    if "response" in data:
                        yield data["response"]

                    # Stop once the server signals completion
                    if data.get("done", False):
                        break

        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to stream generate text: {e}")
            raise OllamaConnectionError(f"Streaming generation failed: {e}")

    def get_available_models(self) -> List[str]:
        """
        Get list of available models in Ollama.

        Returns:
            List of model names (empty list on failure)
        """
        try:
            response = requests.get(
                f"{self.base_url}/api/tags", timeout=self.timeout
            )
            if response.status_code == 200:
                data = response.json()
                return [m["name"] for m in data.get("models", [])]
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get available models: {e}")

        return []

    def pull_model(self, model_name: str) -> bool:
        """
        Pull a model from Ollama registry.

        Args:
            model_name: Name of model to pull

        Returns:
            True if successful

        Note:
            This is a blocking operation that may take a while
        """
        # Hoisted out of the streaming loop (was re-imported per line).
        import json

        try:
            self.console.print(f"[cyan]Pulling model: {model_name}...[/cyan]")

            response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                stream=True,
                timeout=None,  # No timeout for pulling
            )
            # Bug fix: previously a failed pull (e.g. unknown model, HTTP
            # error) was never detected and the method still returned True.
            response.raise_for_status()

            # Stream progress
            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    status = data.get("status", "")
                    if status:
                        self.console.print(f" {status}")

            self.console.print(f"[green][/green] Model pulled: {model_name}")
            return True

        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to pull model: {e}")
            self.console.print(f"[red][/red] Failed to pull model: {e}")
            return False
|
src/llm/openai_client.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenAI-compatible client for LLM inference (supports Groq, DeepSeek, OpenAI, etc.)."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Generator, List, Optional
|
| 5 |
+
|
| 6 |
+
from rich.console import Console
|
| 7 |
+
|
| 8 |
+
from config.settings import settings
|
| 9 |
+
from src.llm.llm_client import LLMClient
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Configure logging
# NOTE(review): basicConfig() runs as a side effect of importing this module
# and configures the root logger globally — consider moving this to the
# application entrypoint.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class OpenAIClientError(Exception):
    """Raised when an OpenAI-compatible API call fails."""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class OpenAIClient(LLMClient):
|
| 24 |
+
"""
|
| 25 |
+
Client for interacting with OpenAI-compatible APIs for LLM inference.
|
| 26 |
+
|
| 27 |
+
Supports:
|
| 28 |
+
- OpenAI (https://api.openai.com/v1)
|
| 29 |
+
- Groq (https://api.groq.com/openai/v1)
|
| 30 |
+
- DeepSeek (https://api.deepseek.com/v1)
|
| 31 |
+
- Any OpenAI-compatible endpoint
|
| 32 |
+
|
| 33 |
+
Features:
|
| 34 |
+
- Non-streaming and streaming text generation
|
| 35 |
+
- Configurable model, temperature, and max tokens
|
| 36 |
+
- Automatic retry via the openai SDK
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
api_key: Optional[str] = None,
|
| 42 |
+
base_url: Optional[str] = None,
|
| 43 |
+
model: Optional[str] = None,
|
| 44 |
+
temperature: Optional[float] = None,
|
| 45 |
+
max_tokens: Optional[int] = None,
|
| 46 |
+
):
|
| 47 |
+
"""
|
| 48 |
+
Initialize OpenAI-compatible client.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
api_key: API key (default: from settings)
|
| 52 |
+
base_url: Base URL for the API (default: from settings, or OpenAI default)
|
| 53 |
+
model: Model name (default: from settings)
|
| 54 |
+
temperature: Default temperature (default: from settings)
|
| 55 |
+
max_tokens: Default max tokens (default: from settings)
|
| 56 |
+
"""
|
| 57 |
+
try:
|
| 58 |
+
from openai import OpenAI
|
| 59 |
+
except ImportError:
|
| 60 |
+
raise ImportError(
|
| 61 |
+
"The 'openai' package is required for OpenAI-compatible providers. "
|
| 62 |
+
"Install it with: pip install openai>=1.0.0"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self._api_key = api_key or settings.openai_api_key
|
| 66 |
+
if not self._api_key:
|
| 67 |
+
raise OpenAIClientError(
|
| 68 |
+
"API key is required for OpenAI-compatible provider. "
|
| 69 |
+
"Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
self._base_url = base_url or settings.openai_base_url
|
| 73 |
+
self.llm_model = model or settings.openai_model
|
| 74 |
+
self._temperature = temperature if temperature is not None else settings.llm_temperature
|
| 75 |
+
self._max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens
|
| 76 |
+
|
| 77 |
+
self.console = Console()
|
| 78 |
+
|
| 79 |
+
# Initialize OpenAI client
|
| 80 |
+
client_kwargs = {"api_key": self._api_key}
|
| 81 |
+
if self._base_url:
|
| 82 |
+
client_kwargs["base_url"] = self._base_url
|
| 83 |
+
|
| 84 |
+
self._client = OpenAI(**client_kwargs)
|
| 85 |
+
|
| 86 |
+
# Log initialization
|
| 87 |
+
provider_name = self._base_url or "OpenAI (default)"
|
| 88 |
+
self.console.print(f"[green][/green] OpenAI-compatible client initialized")
|
| 89 |
+
self.console.print(f" Provider: {provider_name}")
|
| 90 |
+
self.console.print(f" Model: {self.llm_model}")
|
| 91 |
+
logger.info(f"OpenAI client initialized: provider={provider_name}, model={self.llm_model}")
|
| 92 |
+
|
| 93 |
+
def generate(
|
| 94 |
+
self,
|
| 95 |
+
prompt: str,
|
| 96 |
+
system_prompt: Optional[str] = None,
|
| 97 |
+
temperature: Optional[float] = None,
|
| 98 |
+
max_tokens: Optional[int] = None,
|
| 99 |
+
stop: Optional[List[str]] = None,
|
| 100 |
+
) -> str:
|
| 101 |
+
"""
|
| 102 |
+
Generate text using the OpenAI-compatible API (non-streaming).
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
prompt: User prompt
|
| 106 |
+
system_prompt: Optional system prompt
|
| 107 |
+
temperature: Sampling temperature (default: from init/settings)
|
| 108 |
+
max_tokens: Maximum tokens to generate (default: from init/settings)
|
| 109 |
+
stop: Stop sequences
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Generated text
|
| 113 |
+
|
| 114 |
+
Raises:
|
| 115 |
+
OpenAIClientError: If generation fails
|
| 116 |
+
"""
|
| 117 |
+
temperature = temperature if temperature is not None else self._temperature
|
| 118 |
+
max_tokens = max_tokens if max_tokens is not None else self._max_tokens
|
| 119 |
+
|
| 120 |
+
messages = []
|
| 121 |
+
if system_prompt:
|
| 122 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 123 |
+
messages.append({"role": "user", "content": prompt})
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
response = self._client.chat.completions.create(
|
| 127 |
+
model=self.llm_model,
|
| 128 |
+
messages=messages,
|
| 129 |
+
temperature=temperature,
|
| 130 |
+
max_tokens=max_tokens,
|
| 131 |
+
stop=stop,
|
| 132 |
+
)
|
| 133 |
+
return response.choices[0].message.content or ""
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"Failed to generate text via OpenAI-compatible API: {e}")
|
| 137 |
+
raise OpenAIClientError(f"Text generation failed: {e}")
|
| 138 |
+
|
| 139 |
+
def stream_generate(
    self,
    prompt: str,
    system_prompt: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[List[str]] = None,
) -> Generator[str, None, None]:
    """
    Generate text using the OpenAI-compatible API with streaming.

    Args:
        prompt: User prompt
        system_prompt: Optional system prompt
        temperature: Sampling temperature (default: from init/settings)
        max_tokens: Maximum tokens to generate (default: from init/settings)
        stop: Stop sequences

    Yields:
        Generated text chunks (empty/None deltas are filtered out)

    Raises:
        OpenAIClientError: If generation fails
    """
    # Fall back to instance-level defaults only when the caller passed None,
    # so explicit 0 / 0.0 values are respected.
    temperature = temperature if temperature is not None else self._temperature
    max_tokens = max_tokens if max_tokens is not None else self._max_tokens

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    try:
        stream = self._client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            stream=True,
        )

        for chunk in stream:
            # Some stream events (role headers, finish events) carry no
            # choices or an empty delta; only yield real text content.
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    except Exception as e:
        logger.error(f"Failed to stream generate text via OpenAI-compatible API: {e}")
        # Chain the original exception so the underlying cause stays visible.
        raise OpenAIClientError(f"Streaming generation failed: {e}") from e
|
src/llm/sentence_transformer_client.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sentence Transformers client for reliable embeddings."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Configure logging
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SentenceTransformerClient:
    """
    Client for generating embeddings using sentence-transformers.

    This is a drop-in replacement for OllamaClient embeddings with much better
    stability and performance. Uses HuggingFace models directly without any server.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: "str | None" = None,
    ):
        """
        Initialize the sentence transformer client.

        Args:
            model_name: HuggingFace model name
                Options:
                - "sentence-transformers/all-MiniLM-L6-v2" (384 dim, fast, general)
                - "sentence-transformers/all-mpnet-base-v2" (768 dim, better quality)
                - "BAAI/bge-small-en-v1.5" (384 dim, good for retrieval)
                - "BAAI/bge-base-en-v1.5" (768 dim, better quality)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)
        """
        self.model_name = model_name

        # Auto-detect device if not specified
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        logger.info(f"Loading embedding model: {model_name}")
        logger.info(f"Using device: {self.device}")

        # Load model (downloads from HuggingFace hub on first use)
        self.model = SentenceTransformer(model_name, device=self.device)

        # Get embedding dimension (used to size zero-vector fallbacks below)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        logger.info(f"Embedding dimension: {self.embedding_dim}")

    def embed_text(self, text: str, return_zero_on_failure: bool = False) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Input text to embed
            return_zero_on_failure: If True, return zero vector on error (for compatibility)

        Returns:
            Embedding vector as list of floats

        Raises:
            Exception: Re-raises the underlying encode error when
                return_zero_on_failure is False.
        """
        # Empty/whitespace-only input always maps to a zero vector rather
        # than raising, regardless of return_zero_on_failure.
        if not text or not text.strip():
            logger.warning("Empty text provided, returning zero vector")
            return [0.0] * self.embedding_dim

        try:
            embedding = self.model.encode(
                text,
                convert_to_numpy=True,
                show_progress_bar=False,
            )
            return embedding.tolist()

        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            if return_zero_on_failure:
                return [0.0] * self.embedding_dim
            raise

    def embed_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts efficiently.

        Tries a single batched encode first; on any failure it falls back to
        embedding texts one at a time, substituting zero vectors for the ones
        that still fail (so the output length always matches the input).

        Args:
            texts: List of input texts
            batch_size: Number of texts to process in parallel
            show_progress: Show progress bar

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        logger.info(f"Generating embeddings for {len(texts)} texts (batch_size={batch_size})")

        try:
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=show_progress,
                convert_to_numpy=True,
            )

            # Convert to list of lists
            embeddings_list = embeddings.tolist()

            logger.info(f"Successfully generated {len(embeddings_list)} embeddings")
            return embeddings_list

        except Exception as e:
            logger.error(f"Batch embedding failed: {e}")
            # Fallback to sequential processing
            logger.warning("Falling back to sequential processing")
            embeddings = []

            iterator = tqdm(texts, desc="Generating embeddings") if show_progress else texts
            for text in iterator:
                embedding = self.embed_text(text, return_zero_on_failure=True)
                embeddings.append(embedding)

            # Count zero-vector placeholders so callers can see data quality
            failed_count = sum(1 for emb in embeddings if emb == [0.0] * self.embedding_dim)
            if failed_count > 0:
                logger.warning(f"{failed_count} embeddings failed and were assigned zero vectors")

            return embeddings

    def get_model_info(self) -> dict:
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "device": self.device,
            "embedding_dim": self.embedding_dim,
            "max_seq_length": self.model.max_seq_length,
        }
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Convenience function to create client with settings
|
| 152 |
+
def create_embedding_client(
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
) -> SentenceTransformerClient:
    """
    Build a SentenceTransformerClient using the project's preferred defaults.

    The default model, all-mpnet-base-v2, emits 768-dimensional vectors —
    the same width as nomic-embed-text — while offering higher quality.

    Args:
        model_name: HuggingFace model identifier to load.

    Returns:
        A ready-to-use SentenceTransformerClient instance.
    """
    client = SentenceTransformerClient(model_name=model_name)
    return client
|
src/processing/__init__.py
ADDED
|
File without changes
|
src/processing/chunker.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Semantic chunker for processing markdown documents with hierarchical structure."""
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
from llama_index.core.node_parser import SentenceSplitter
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
from rich.console import Console
|
| 12 |
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
|
| 13 |
+
|
| 14 |
+
from config.settings import settings
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChunkNode(BaseModel):
    """
    A single semantic chunk of a source document.

    Each instance carries the chunk text plus enough provenance (section
    header, document title, source URL, position) to cite it later.

    Attributes:
        chunk_id: Unique identifier for the chunk
        content: The actual text content
        parent_section: The section header this chunk belongs to
        document_title: Original article title
        source_url: EyeWiki URL of the source document
        chunk_index: Position of chunk in the document (0-indexed)
        token_count: Approximate number of tokens in the chunk
        metadata: Additional metadata from the source document
    """

    chunk_id: str = Field(..., description="Unique identifier (hash-based)")
    content: str = Field(..., description="Text content of the chunk")
    parent_section: str = Field(default="", description="Parent section header")
    document_title: str = Field(default="", description="Original document title")
    source_url: str = Field(default="", description="Source URL")
    chunk_index: int = Field(..., ge=0, description="Position in document")
    token_count: int = Field(..., ge=0, description="Approximate token count")
    metadata: Dict = Field(default_factory=dict, description="Additional metadata")

    def to_dict(self) -> Dict:
        """Serialize this chunk to a plain dictionary."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: Dict) -> "ChunkNode":
        """Rebuild a ChunkNode from a dictionary produced by to_dict()."""
        return cls(**data)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SemanticChunker:
    """
    Hierarchical semantic chunker that respects markdown structure.

    Features:
    - Splits on ## headers first (sections)
    - Then splits large sections into semantic chunks
    - Preserves parent section context
    - Uses LlamaIndex SentenceSplitter for semantic splitting
    - Configurable chunk sizes and overlap
    """

    def __init__(
        self,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        min_chunk_size: int = 100,
    ):
        """
        Initialize the SemanticChunker.

        Args:
            chunk_size: Target chunk size in tokens (default: from settings)
            chunk_overlap: Overlap between chunks in tokens (default: from settings)
            min_chunk_size: Minimum chunk size to keep (default: 100 tokens)
        """
        # NOTE(review): `or` means a falsy explicit value (0) also falls back
        # to settings — presumably intentional, since 0 is not a usable size.
        self.chunk_size = chunk_size or settings.chunk_size
        self.chunk_overlap = chunk_overlap or settings.chunk_overlap
        self.min_chunk_size = min_chunk_size

        # Initialize LlamaIndex sentence splitter (does the actual semantic
        # splitting of oversized sections in _split_large_section)
        self.sentence_splitter = SentenceSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )

        self.console = Console()

    def _estimate_tokens(self, text: str) -> int:
        """
        Estimate token count for text.

        Uses a simple heuristic: ~4 characters per token.
        More accurate than word count for medical/technical text.

        Args:
            text: Input text

        Returns:
            Estimated token count
        """
        return len(text) // 4

    def _generate_chunk_id(self, content: str, chunk_index: int, source_url: str) -> str:
        """
        Generate unique chunk ID using hash.

        Args:
            content: Chunk content
            chunk_index: Index of chunk
            source_url: Source URL

        Returns:
            Unique chunk identifier (16 hex chars of a SHA-256 digest)
        """
        # Create a unique string combining content snippet, index, and source
        unique_string = f"{source_url}:{chunk_index}:{content[:100]}"
        return hashlib.sha256(unique_string.encode()).hexdigest()[:16]

    def _parse_markdown_sections(self, markdown: str) -> List[Tuple[str, str]]:
        """
        Parse markdown into sections based on ## headers.

        Args:
            markdown: Markdown content

        Returns:
            List of (header, content) tuples. Content that appears before the
            first ## header is returned with an empty-string header.
        """
        sections = []

        # Split by ## headers (h2)
        # Pattern matches: ## Header or ##Header
        pattern = r"^##\s+(.+?)$"
        lines = markdown.split("\n")

        current_header = ""
        current_content = []

        for line in lines:
            match = re.match(pattern, line)
            if match:
                # Save previous section if it has content
                if current_content:
                    sections.append((current_header, "\n".join(current_content)))

                # Start new section
                current_header = match.group(1).strip()
                current_content = [line]  # Include the header in content
            else:
                current_content.append(line)

        # Add final section
        if current_content:
            sections.append((current_header, "\n".join(current_content)))

        return sections

    def _split_large_section(self, text: str) -> List[str]:
        """
        Split large section into semantic chunks using LlamaIndex.

        Args:
            text: Section text to split

        Returns:
            List of text chunks
        """
        # Use LlamaIndex SentenceSplitter (configured with chunk_size/overlap
        # from __init__)
        chunks = self.sentence_splitter.split_text(text)
        return chunks

    def _clean_content(self, content: str) -> str:
        """
        Clean chunk content by removing excessive whitespace.

        Args:
            content: Raw content

        Returns:
            Cleaned content
        """
        # Remove excessive blank lines (more than 2 consecutive)
        content = re.sub(r"\n{3,}", "\n\n", content)

        # Remove leading/trailing whitespace
        content = content.strip()

        return content

    def chunk_document(
        self,
        markdown_content: str,
        metadata: Dict,
    ) -> List[ChunkNode]:
        """
        Chunk a markdown document with hierarchical structure.

        Process:
        1. Parse document into sections by ## headers
        2. For each section, check if it needs splitting
        3. If section is small enough, keep as single chunk
        4. If section is large, split into semantic chunks
        5. Preserve parent section context in each chunk

        Sections and sub-chunks smaller than min_chunk_size are silently
        dropped.

        Args:
            markdown_content: Markdown text content
            metadata: Document metadata (must include 'url' and 'title')

        Returns:
            List of ChunkNode objects
        """
        chunks = []
        chunk_index = 0

        # Extract metadata
        source_url = metadata.get("url", "")
        document_title = metadata.get("title", "Untitled")

        # Parse into sections
        sections = self._parse_markdown_sections(markdown_content)

        # If no sections found, treat entire document as one section
        if not sections or (len(sections) == 1 and not sections[0][0]):
            sections = [("", markdown_content)]

        for section_header, section_content in sections:
            # Clean section content
            section_content = self._clean_content(section_content)

            # Skip empty sections
            if not section_content:
                continue

            # Estimate tokens in section
            section_tokens = self._estimate_tokens(section_content)

            # If section is smaller than chunk size, keep as single chunk
            if section_tokens <= self.chunk_size:
                # Only create chunk if it meets minimum size
                if section_tokens >= self.min_chunk_size:
                    chunk_id = self._generate_chunk_id(
                        section_content, chunk_index, source_url
                    )

                    chunk = ChunkNode(
                        chunk_id=chunk_id,
                        content=section_content,
                        parent_section=section_header,
                        document_title=document_title,
                        source_url=source_url,
                        chunk_index=chunk_index,
                        token_count=section_tokens,
                        metadata=metadata,
                    )
                    chunks.append(chunk)
                    chunk_index += 1
            else:
                # Section is large, split into semantic chunks
                sub_chunks = self._split_large_section(section_content)

                for sub_chunk_content in sub_chunks:
                    sub_chunk_content = self._clean_content(sub_chunk_content)

                    # Skip if empty or too small
                    sub_chunk_tokens = self._estimate_tokens(sub_chunk_content)
                    if sub_chunk_tokens < self.min_chunk_size:
                        continue

                    chunk_id = self._generate_chunk_id(
                        sub_chunk_content, chunk_index, source_url
                    )

                    chunk = ChunkNode(
                        chunk_id=chunk_id,
                        content=sub_chunk_content,
                        parent_section=section_header,
                        document_title=document_title,
                        source_url=source_url,
                        chunk_index=chunk_index,
                        token_count=sub_chunk_tokens,
                        metadata=metadata,
                    )
                    chunks.append(chunk)
                    chunk_index += 1

        return chunks

    def chunk_directory(
        self,
        input_dir: Path,
        output_dir: Path,
        pattern: str = "*.md",
    ) -> Dict[str, int]:
        """
        Process all markdown files in a directory.

        For each .md file, looks for corresponding .json metadata file,
        chunks the document, and saves chunks to output directory
        (one "<stem>_chunks.json" per input file).

        Args:
            input_dir: Directory containing markdown files
            output_dir: Directory to save chunked outputs
            pattern: Glob pattern for files to process (default: "*.md")

        Returns:
            Dictionary with processing statistics
            (keys: processed, failed, skipped, total_chunks, total_tokens —
            the early-return path when no files match omits skipped/total_tokens)
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Find all markdown files
        md_files = list(input_dir.glob(pattern))

        if not md_files:
            self.console.print(f"[yellow]No files matching '{pattern}' found in {input_dir}[/yellow]")
            return {"processed": 0, "failed": 0, "total_chunks": 0}

        stats = {
            "processed": 0,
            "failed": 0,
            "skipped": 0,
            "total_chunks": 0,
            "total_tokens": 0,
        }

        self.console.print(f"\n[bold cyan]Chunking Documents[/bold cyan]")
        self.console.print(f"Input: {input_dir}")
        self.console.print(f"Output: {output_dir}")
        self.console.print(f"Files found: {len(md_files)}\n")

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=self.console,
        ) as progress:

            task = progress.add_task(
                "[cyan]Processing...",
                total=len(md_files),
            )

            for md_file in md_files:
                try:
                    # Look for corresponding JSON metadata file
                    json_file = md_file.with_suffix(".json")

                    if not json_file.exists():
                        self.console.print(
                            f"[yellow]Skipping {md_file.name}: No metadata file found[/yellow]"
                        )
                        stats["skipped"] += 1
                        progress.advance(task)
                        continue

                    # Read markdown content
                    with open(md_file, "r", encoding="utf-8") as f:
                        markdown_content = f.read()

                    # Read metadata
                    with open(json_file, "r", encoding="utf-8") as f:
                        metadata = json.load(f)

                    # Skip if markdown is too small
                    if self._estimate_tokens(markdown_content) < self.min_chunk_size:
                        self.console.print(
                            f"[yellow]Skipping {md_file.name}: Content too small[/yellow]"
                        )
                        stats["skipped"] += 1
                        progress.advance(task)
                        continue

                    # Chunk the document
                    chunks = self.chunk_document(markdown_content, metadata)

                    if not chunks:
                        self.console.print(
                            f"[yellow]Skipping {md_file.name}: No chunks created[/yellow]"
                        )
                        stats["skipped"] += 1
                        progress.advance(task)
                        continue

                    # Save chunks to output file
                    output_file = output_dir / f"{md_file.stem}_chunks.json"
                    with open(output_file, "w", encoding="utf-8") as f:
                        chunk_dicts = [chunk.to_dict() for chunk in chunks]
                        json.dump(chunk_dicts, f, indent=2, ensure_ascii=False)

                    # Update stats
                    stats["processed"] += 1
                    stats["total_chunks"] += len(chunks)
                    stats["total_tokens"] += sum(chunk.token_count for chunk in chunks)

                    progress.update(
                        task,
                        description=f"[cyan]Processing ({stats['processed']} done, {stats['total_chunks']} chunks): {md_file.name[:40]}...",
                    )
                    progress.advance(task)

                except Exception as e:
                    # Per-file failures are reported and counted but never
                    # abort the batch.
                    self.console.print(f"[red]Error processing {md_file.name}: {e}[/red]")
                    stats["failed"] += 1
                    progress.advance(task)

        # Print summary
        self.console.print("\n[bold cyan]Chunking Summary[/bold cyan]")
        self.console.print(f"Files processed: {stats['processed']}")
        self.console.print(f"Files skipped: {stats['skipped']}")
        self.console.print(f"Files failed: {stats['failed']}")
        self.console.print(f"Total chunks created: {stats['total_chunks']}")
        self.console.print(f"Total tokens: {stats['total_tokens']:,}")

        if stats["processed"] > 0:
            avg_chunks = stats["total_chunks"] / stats["processed"]
            avg_tokens = stats["total_tokens"] / stats["total_chunks"] if stats["total_chunks"] > 0 else 0
            self.console.print(f"Average chunks per document: {avg_chunks:.1f}")
            self.console.print(f"Average tokens per chunk: {avg_tokens:.1f}")

        return stats
|
src/processing/metadata_extractor.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Medical metadata extractor for EyeWiki articles."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Dict, List, Set
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MetadataExtractor:
    """
    Extract medical metadata from EyeWiki articles.

    Extracts:
    - Disease names
    - ICD-10 codes
    - Anatomical structures
    - Symptoms
    - Treatments (medications and procedures)
    - Categories
    """

    # Comprehensive list of eye anatomical structures
    ANATOMICAL_STRUCTURES = {
        # Major structures
        "cornea", "corneal", "sclera", "scleral", "retina", "retinal",
        "lens", "crystalline lens", "iris", "iridial", "pupil", "pupillary",
        "choroid", "choroidal", "vitreous", "vitreous humor",
        "optic nerve", "optic disc", "optic cup",

        # Anterior segment
        "anterior chamber", "posterior chamber", "anterior segment",
        "trabecular meshwork", "schlemm's canal", "ciliary body", "ciliary muscle",
        "zonules", "zonular", "aqueous humor", "aqueous",

        # Posterior segment
        "posterior segment", "macula", "macular", "fovea", "foveal",
        "retinal pigment epithelium", "rpe", "photoreceptors",
        "rods", "cones", "ganglion cells",

        # Retinal layers
        "inner limiting membrane", "nerve fiber layer", "ganglion cell layer",
        "inner plexiform layer", "inner nuclear layer", "outer plexiform layer",
        "outer nuclear layer", "external limiting membrane",
        "photoreceptor layer", "bruch's membrane",

        # Extraocular
        "eyelid", "eyelids", "conjunctiva", "conjunctival",
        "lacrimal gland", "tear film", "meibomian glands",
        "extraocular muscles", "rectus muscle", "oblique muscle",
        "orbit", "orbital", "optic chiasm",

        # Blood vessels
        "central retinal artery", "central retinal vein",
        "retinal vessels", "vascular", "vasculature",
        "choriocapillaris",

        # Angles and spaces
        "angle", "iridocorneal angle", "suprachoroidal space",
    }

    # Common ophthalmic medications
    MEDICATIONS = {
        # Glaucoma medications
        "latanoprost", "timolol", "dorzolamide", "brinzolamide",
        "brimonidine", "apraclonidine", "bimatoprost", "travoprost",
        "tafluprost", "pilocarpine", "carbachol",
        "acetazolamide", "methazolamide",

        # Anti-VEGF agents
        "bevacizumab", "ranibizumab", "aflibercept", "brolucizumab",
        "pegaptanib", "faricimab",

        # Steroids
        "prednisolone", "dexamethasone", "triamcinolone", "fluocinolone",
        "difluprednate", "fluorometholone", "loteprednol",
        "betamethasone", "hydrocortisone",

        # Antibiotics
        "moxifloxacin", "gatifloxacin", "ciprofloxacin", "ofloxacin",
        "levofloxacin", "tobramycin", "gentamicin", "erythromycin",
        "azithromycin", "bacitracin", "polymyxin", "neomycin",
        "vancomycin", "ceftazidime", "cefazolin",

        # Antivirals
        "acyclovir", "ganciclovir", "valganciclovir", "valacyclovir",
        "trifluridine", "foscarnet",

        # Anti-inflammatory
        "ketorolac", "diclofenac", "nepafenac", "bromfenac",
        "cyclosporine", "tacrolimus", "lifitegrast",

        # Mydriatics/Cycloplegics
        "tropicamide", "cyclopentolate", "atropine", "homatropine",
        "phenylephrine",

        # Other
        "mitomycin", "5-fluorouracil", "interferon",
        "methotrexate", "chlorambucil",
    }

    # Common ophthalmic procedures
    PROCEDURES = {
        # Cataract surgery
        "phacoemulsification", "phaco", "cataract extraction",
        "extracapsular cataract extraction", "ecce",
        "intracapsular cataract extraction", "icce",
        "iol implantation", "intraocular lens",

        # Glaucoma procedures
        "trabeculectomy", "tube shunt", "glaucoma drainage device",
        "ahmed valve", "baerveldt implant", "molteno implant",
        "selective laser trabeculoplasty", "slt", "argon laser trabeculoplasty", "alt",
        "laser peripheral iridotomy", "lpi", "iridotomy",
        "cyclophotocoagulation", "cyclocryotherapy",
        "minimally invasive glaucoma surgery", "migs",
        "trabectome", "istent", "kahook dual blade", "goniotomy",

        # Retinal procedures
        "vitrectomy", "pars plana vitrectomy", "ppv",
        "membrane peeling", "epiretinal membrane peeling",
        "endolaser", "photocoagulation", "panretinal photocoagulation", "prp",
        "focal laser", "grid laser",
        "pneumatic retinopexy", "scleral buckle",
        "silicone oil", "gas tamponade", "c3f8", "sf6",

        # Corneal procedures
        "penetrating keratoplasty", "pkp", "corneal transplant",
        "descemet stripping endothelial keratoplasty", "dsek", "dsaek",
        "descemet membrane endothelial keratoplasty", "dmek",
        "deep anterior lamellar keratoplasty", "dalk",
        "phototherapeutic keratectomy", "ptk",
        "corneal crosslinking", "cxl",

        # Refractive surgery
        "lasik", "prk", "photorefractive keratectomy",
        "smile", "lasek", "refractive lens exchange",
        "phakic iol", "icl",

        # Injections
        "intravitreal injection", "intravitreal",
        "subtenon injection", "retrobulbar block", "peribulbar block",

        # Laser procedures
        "yag laser capsulotomy", "laser capsulotomy",
        "laser iridotomy", "laser trabeculoplasty",

        # Other
        "enucleation", "evisceration", "exenteration",
        "orbital decompression", "ptosis repair", "blepharoplasty",
        "dacryocystorhinostomy", "dcr",
    }

    # Common ophthalmic symptoms
    SYMPTOMS = {
        # Visual symptoms
        "blurred vision", "blurring", "vision loss", "visual loss",
        "decreased vision", "blindness", "blind spot",
        "photophobia", "light sensitivity", "glare", "halos",
        "diplopia", "double vision", "metamorphopsia", "distortion",
        "scotoma", "floaters", "flashes", "photopsia",
        "night blindness", "nyctalopia", "color vision defect",
        "visual field defect", "peripheral vision loss",

        # Pain and discomfort
        "eye pain", "ocular pain", "pain", "foreign body sensation",
        "irritation", "burning", "stinging", "grittiness",
        "discomfort", "ache", "headache",

        # Discharge and tearing
        "discharge", "tearing", "epiphora", "watery eyes",
        "mucus", "crusting", "mattering",

        # Redness and inflammation
        "redness", "red eye", "injection", "hyperemia",
        "swelling", "edema", "chemosis", "inflammation",

        # Other
        "itching", "pruritus", "dryness", "dry eye",
        "eye strain", "asthenopia", "fatigue",
    }

    def __init__(self):
        """Initialize the metadata extractor and compile regex patterns."""
        # The ophthalmic branch (H00-H59, up to THREE decimal digits) must be
        # tried FIRST. Regex alternation is first-match-wins: with the generic
        # branch first, a code like "H35.321" would be truncated to "H35"
        # because the generic pattern only allows two decimal digits.
        self.icd_pattern = re.compile(
            r'\b[H][0-5]\d(?:\.\d{1,3})?\b|'  # Ophthalmic ICD-10 (H00-H59)
            r'\b[A-Z]\d{2}(?:\.\d{1,2})?\b'   # Generic ICD-10: H40.1, B99.9, etc.
        )

    def extract_icd_codes(self, text: str) -> List[str]:
        """
        Extract ICD-10 codes from text using regex.

        Patterns matched:
        - Ophthalmic codes: H00-H59 range, up to three decimal digits
        - Generic ICD-10 codes: A00, B99.9, etc.

        H-codes outside the H00-H59 range are discarded; everything else
        that looks like an ICD-10 code is kept.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique ICD-10 codes found
        """
        valid_codes = set()
        for code in self.icd_pattern.findall(text):
            if code.startswith('H'):
                # Validate ophthalmic H00-H59 range.
                try:
                    if 0 <= int(code[1:3]) <= 59:
                        valid_codes.add(code)
                except (ValueError, IndexError):
                    continue
            else:
                # Keep other valid ICD-10 codes.
                valid_codes.add(code)

        return sorted(valid_codes)

    def _find_terms(
        self, text: str, vocabulary: Set[str], allow_plural: bool = False
    ) -> List[str]:
        """
        Return sorted vocabulary terms that occur in text.

        Matching is case-insensitive and anchored on word boundaries so
        partial words do not match. When allow_plural is True, a trailing
        's' on the term is also accepted.

        Args:
            text: Input text to search
            vocabulary: Set of terms to look for
            allow_plural: Accept an optional trailing 's'

        Returns:
            Sorted list of unique terms found
        """
        text_lower = text.lower()
        suffix = r's?\b' if allow_plural else r'\b'
        found = {
            term
            for term in vocabulary
            if re.search(r'\b' + re.escape(term) + suffix, text_lower)
        }
        return sorted(found)

    def extract_anatomical_terms(self, text: str) -> List[str]:
        """
        Extract anatomical structure mentions from text.

        Case-insensitive matching against the predefined anatomical
        vocabulary; plurals are accepted.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique anatomical structures found
        """
        return self._find_terms(text, self.ANATOMICAL_STRUCTURES, allow_plural=True)

    def extract_medications(self, text: str) -> List[str]:
        """
        Extract medication mentions from text.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique medications found
        """
        return self._find_terms(text, self.MEDICATIONS)

    def extract_procedures(self, text: str) -> List[str]:
        """
        Extract procedure mentions from text.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique procedures found
        """
        return self._find_terms(text, self.PROCEDURES)

    def extract_symptoms(self, text: str) -> List[str]:
        """
        Extract symptom mentions from text.

        Args:
            text: Input text to search

        Returns:
            Sorted list of unique symptoms found
        """
        return self._find_terms(text, self.SYMPTOMS)

    def extract_disease_name(self, existing_metadata: Dict) -> str:
        """
        Extract the primary disease name from metadata.

        Tries multiple sources in order:
        1. Article title (with "Disease:"/"Condition:"/"Syndrome:" prefixes stripped)
        2. First category
        3. Last URL path segment (underscores replaced with spaces)

        Args:
            existing_metadata: Metadata dict with 'title', 'url', 'categories'

        Returns:
            Primary disease/condition name, or "Unknown" when nothing is available
        """
        title = existing_metadata.get("title", "")
        if title:
            cleaned = re.sub(r'^(Disease|Condition|Syndrome):\s*', '', title, flags=re.IGNORECASE)
            return cleaned.strip()

        categories = existing_metadata.get("categories", [])
        if categories:
            return categories[0].strip()

        url = existing_metadata.get("url", "")
        if url:
            match = re.search(r'/([^/]+)$', url)
            if match:
                # URL segments use underscores where titles use spaces.
                return match.group(1).replace('_', ' ').strip()

        return "Unknown"

    def extract(self, content: str, existing_metadata: Dict) -> Dict:
        """
        Extract comprehensive medical metadata from article content.

        Args:
            content: Article text content (markdown)
            existing_metadata: Existing metadata dict with basic info

        Returns:
            Enhanced metadata dictionary with medical information
        """
        enhanced_metadata = existing_metadata.copy()

        enhanced_metadata["disease_name"] = self.extract_disease_name(existing_metadata)
        enhanced_metadata["icd_codes"] = self.extract_icd_codes(content)
        enhanced_metadata["anatomical_structures"] = self.extract_anatomical_terms(content)
        enhanced_metadata["symptoms"] = self.extract_symptoms(content)

        medications = self.extract_medications(content)
        procedures = self.extract_procedures(content)
        enhanced_metadata["treatments"] = {
            "medications": medications,
            "procedures": procedures,
        }

        # Preserve existing categories; guarantee the key exists.
        if "categories" not in enhanced_metadata:
            enhanced_metadata["categories"] = []

        # Summary counts, useful for corpus-level QA of the extraction.
        enhanced_metadata["extraction_stats"] = {
            "icd_codes_found": len(enhanced_metadata["icd_codes"]),
            "anatomical_terms_found": len(enhanced_metadata["anatomical_structures"]),
            "symptoms_found": len(enhanced_metadata["symptoms"]),
            "medications_found": len(medications),
            "procedures_found": len(procedures),
        }

        return enhanced_metadata

    def extract_batch(self, documents: List[Dict]) -> List[Dict]:
        """
        Extract metadata from multiple documents.

        Args:
            documents: List of dicts with 'content' and 'metadata' keys

        Returns:
            List of enhanced metadata dictionaries
        """
        return [
            self.extract(doc.get("content", ""), doc.get("metadata", {}))
            for doc in documents
        ]

    def get_anatomical_vocabulary(self) -> Set[str]:
        """Get a copy of the full anatomical vocabulary set."""
        return self.ANATOMICAL_STRUCTURES.copy()

    def get_medication_vocabulary(self) -> Set[str]:
        """Get a copy of the full medication vocabulary set."""
        return self.MEDICATIONS.copy()

    def get_procedure_vocabulary(self) -> Set[str]:
        """Get a copy of the full procedure vocabulary set."""
        return self.PROCEDURES.copy()

    def get_symptom_vocabulary(self) -> Set[str]:
        """Get a copy of the full symptom vocabulary set."""
        return self.SYMPTOMS.copy()
|
src/rag/__init__.py
ADDED
|
File without changes
|
src/rag/query_engine.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Query engine orchestrating the full RAG pipeline."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Generator, List, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
from rich.console import Console
|
| 9 |
+
|
| 10 |
+
from src.rag.retriever import HybridRetriever, RetrievalResult
|
| 11 |
+
from src.rag.reranker import CrossEncoderReranker
|
| 12 |
+
from src.llm.llm_client import LLMClient
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Configure module-level logging.
# NOTE(review): logging.basicConfig() at import time configures the root logger
# for the entire process, which is intrusive for a library module — consider
# moving this call to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Default medical disclaimer attached to every QueryResponse. A custom text can
# be supplied via the disclaimer_path argument of EyeWikiQueryEngine.
MEDICAL_DISCLAIMER = (
    "**Medical Disclaimer:** This information is sourced from EyeWiki, a resource of the "
    "American Academy of Ophthalmology (AAO). It is not a substitute for professional "
    "medical advice, diagnosis, or treatment. AI systems can make errors. Always consult "
    "with a qualified ophthalmologist or eye care professional for medical concerns and "
    "verify any critical information with authoritative sources."
)
|
| 28 |
+
|
| 29 |
+
# Default system prompt used when no system_prompt_path is supplied to
# EyeWikiQueryEngine. The text itself is runtime behavior — do not edit
# casually; override via a prompt file instead.
DEFAULT_SYSTEM_PROMPT = """You are an expert ophthalmology assistant with comprehensive knowledge of eye diseases, treatments, and procedures.

Your role is to provide accurate, evidence-based information from the EyeWiki medical knowledge base.

Guidelines:
- Base your answers strictly on the provided context
- Cite sources using [Source: Title] format when referencing information
- If the context doesn't contain enough information, say so explicitly
- Use clear, precise medical terminology while remaining accessible
- Structure your responses logically with appropriate sections
- For treatment information, emphasize the importance of professional consultation
- Always maintain professional medical standards"""
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SourceInfo(BaseModel):
    """
    Information about a single source document cited in an answer.

    Attributes:
        title: Document title
        url: Source URL
        section: Section within the document (empty string when unknown)
        relevance_score: Relevance score from the cross-encoder reranker;
            NOT normalized — cross-encoder scores are unbounded
    """

    title: str = Field(..., description="Document title")
    url: str = Field(..., description="Source URL")
    section: str = Field(default="", description="Section within document")
    relevance_score: float = Field(..., description="Relevance score (cross-encoder, unbounded)")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class QueryResponse(BaseModel):
    """
    Structured response returned by the query engine.

    Attributes:
        answer: Generated answer text
        sources: Source documents used to build the answer (deduplicated by title)
        confidence: Confidence score in [0, 1], derived from reranker scores
        disclaimer: Medical disclaimer text attached to every response
        query: Original user query, echoed back
    """

    answer: str = Field(..., description="Generated answer")
    sources: List[SourceInfo] = Field(default_factory=list, description="Source documents")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    disclaimer: str = Field(default=MEDICAL_DISCLAIMER, description="Medical disclaimer")
    query: str = Field(..., description="Original query")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class EyeWikiQueryEngine:
|
| 81 |
+
"""
|
| 82 |
+
Query engine orchestrating the full RAG pipeline.
|
| 83 |
+
|
| 84 |
+
Pipeline:
|
| 85 |
+
1. Query � Retriever (hybrid search)
|
| 86 |
+
2. Results � Reranker (cross-encoder)
|
| 87 |
+
3. Top results � Context assembly
|
| 88 |
+
4. Context + Query � LLM generation
|
| 89 |
+
5. Response + Sources + Disclaimer
|
| 90 |
+
|
| 91 |
+
Features:
|
| 92 |
+
- Two-stage retrieval (fast + precise)
|
| 93 |
+
- Context assembly with token limits
|
| 94 |
+
- Source diversity prioritization
|
| 95 |
+
- Medical disclaimer inclusion
|
| 96 |
+
- Streaming and non-streaming modes
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
retriever: HybridRetriever,
|
| 102 |
+
reranker: CrossEncoderReranker,
|
| 103 |
+
llm_client: LLMClient,
|
| 104 |
+
system_prompt_path: Optional[Path] = None,
|
| 105 |
+
query_prompt_path: Optional[Path] = None,
|
| 106 |
+
disclaimer_path: Optional[Path] = None,
|
| 107 |
+
max_context_tokens: int = 4000,
|
| 108 |
+
retrieval_k: int = 20,
|
| 109 |
+
rerank_k: int = 5,
|
| 110 |
+
):
|
| 111 |
+
"""
|
| 112 |
+
Initialize query engine.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
retriever: HybridRetriever instance
|
| 116 |
+
reranker: CrossEncoderReranker instance
|
| 117 |
+
llm_client: LLMClient instance (OllamaClient or OpenAIClient)
|
| 118 |
+
system_prompt_path: Path to custom system prompt file
|
| 119 |
+
query_prompt_path: Path to custom query prompt template
|
| 120 |
+
disclaimer_path: Path to custom medical disclaimer file
|
| 121 |
+
max_context_tokens: Maximum tokens for context
|
| 122 |
+
retrieval_k: Number of documents to retrieve initially
|
| 123 |
+
rerank_k: Number of documents after reranking
|
| 124 |
+
"""
|
| 125 |
+
self.retriever = retriever
|
| 126 |
+
self.reranker = reranker
|
| 127 |
+
self.llm_client = llm_client
|
| 128 |
+
self.max_context_tokens = max_context_tokens
|
| 129 |
+
self.retrieval_k = retrieval_k
|
| 130 |
+
self.rerank_k = rerank_k
|
| 131 |
+
|
| 132 |
+
self.console = Console()
|
| 133 |
+
|
| 134 |
+
# Load system prompt
|
| 135 |
+
if system_prompt_path and system_prompt_path.exists():
|
| 136 |
+
with open(system_prompt_path, "r") as f:
|
| 137 |
+
self.system_prompt = f.read()
|
| 138 |
+
logger.info(f"Loaded system prompt from {system_prompt_path}")
|
| 139 |
+
else:
|
| 140 |
+
self.system_prompt = DEFAULT_SYSTEM_PROMPT
|
| 141 |
+
logger.info("Using default system prompt")
|
| 142 |
+
|
| 143 |
+
# Load query prompt template
|
| 144 |
+
if query_prompt_path and query_prompt_path.exists():
|
| 145 |
+
with open(query_prompt_path, "r") as f:
|
| 146 |
+
self.query_prompt_template = f.read()
|
| 147 |
+
logger.info(f"Loaded query prompt from {query_prompt_path}")
|
| 148 |
+
else:
|
| 149 |
+
self.query_prompt_template = None
|
| 150 |
+
logger.info("Using inline query prompt formatting")
|
| 151 |
+
|
| 152 |
+
# Load medical disclaimer
|
| 153 |
+
if disclaimer_path and disclaimer_path.exists():
|
| 154 |
+
with open(disclaimer_path, "r") as f:
|
| 155 |
+
self.medical_disclaimer = f.read().strip()
|
| 156 |
+
logger.info(f"Loaded medical disclaimer from {disclaimer_path}")
|
| 157 |
+
else:
|
| 158 |
+
self.medical_disclaimer = MEDICAL_DISCLAIMER
|
| 159 |
+
logger.info("Using default medical disclaimer")
|
| 160 |
+
|
| 161 |
+
def _estimate_tokens(self, text: str) -> int:
|
| 162 |
+
"""
|
| 163 |
+
Estimate token count for text.
|
| 164 |
+
|
| 165 |
+
Uses simple heuristic: ~4 characters per token.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
text: Input text
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Estimated token count
|
| 172 |
+
"""
|
| 173 |
+
return len(text) // 4
|
| 174 |
+
|
| 175 |
+
def _prioritize_diverse_sources(
|
| 176 |
+
self, results: List[RetrievalResult]
|
| 177 |
+
) -> List[RetrievalResult]:
|
| 178 |
+
"""
|
| 179 |
+
Prioritize results from diverse sources.
|
| 180 |
+
|
| 181 |
+
Ensures we don't just get multiple chunks from the same article.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
results: Sorted list of retrieval results
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
Reordered list prioritizing diversity
|
| 188 |
+
"""
|
| 189 |
+
seen_documents = set()
|
| 190 |
+
diverse_results = []
|
| 191 |
+
remaining_results = []
|
| 192 |
+
|
| 193 |
+
# First pass: one chunk per document
|
| 194 |
+
for result in results:
|
| 195 |
+
doc_title = result.document_title
|
| 196 |
+
if doc_title not in seen_documents:
|
| 197 |
+
diverse_results.append(result)
|
| 198 |
+
seen_documents.add(doc_title)
|
| 199 |
+
else:
|
| 200 |
+
remaining_results.append(result)
|
| 201 |
+
|
| 202 |
+
# Second pass: add remaining high-scoring chunks
|
| 203 |
+
diverse_results.extend(remaining_results)
|
| 204 |
+
|
| 205 |
+
return diverse_results
|
| 206 |
+
|
| 207 |
+
def _assemble_context(self, results: List[RetrievalResult]) -> str:
|
| 208 |
+
"""
|
| 209 |
+
Assemble context from retrieval results.
|
| 210 |
+
|
| 211 |
+
Features:
|
| 212 |
+
- Formats with section headers
|
| 213 |
+
- Limits to max_context_tokens
|
| 214 |
+
- Prioritizes diverse sources
|
| 215 |
+
- Includes source citations
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
results: List of retrieval results
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Formatted context string
|
| 222 |
+
"""
|
| 223 |
+
if not results:
|
| 224 |
+
return ""
|
| 225 |
+
|
| 226 |
+
# Prioritize diversity
|
| 227 |
+
diverse_results = self._prioritize_diverse_sources(results)
|
| 228 |
+
|
| 229 |
+
context_parts = []
|
| 230 |
+
total_tokens = 0
|
| 231 |
+
|
| 232 |
+
for i, result in enumerate(diverse_results, 1):
|
| 233 |
+
# Format context chunk
|
| 234 |
+
chunk_text = f"[Source {i}: {result.document_title}"
|
| 235 |
+
if result.section:
|
| 236 |
+
chunk_text += f" - {result.section}"
|
| 237 |
+
chunk_text += f"]\n{result.content}\n"
|
| 238 |
+
|
| 239 |
+
# Check token limit
|
| 240 |
+
chunk_tokens = self._estimate_tokens(chunk_text)
|
| 241 |
+
|
| 242 |
+
if total_tokens + chunk_tokens > self.max_context_tokens:
|
| 243 |
+
logger.info(
|
| 244 |
+
f"Reached context token limit ({self.max_context_tokens}), "
|
| 245 |
+
f"using {i-1} of {len(diverse_results)} chunks"
|
| 246 |
+
)
|
| 247 |
+
break
|
| 248 |
+
|
| 249 |
+
context_parts.append(chunk_text)
|
| 250 |
+
total_tokens += chunk_tokens
|
| 251 |
+
|
| 252 |
+
context = "\n".join(context_parts)
|
| 253 |
+
|
| 254 |
+
logger.info(
|
| 255 |
+
f"Assembled context: {len(context_parts)} chunks, "
|
| 256 |
+
f"~{total_tokens} tokens"
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return context
|
| 260 |
+
|
| 261 |
+
def _extract_sources(self, results: List[RetrievalResult]) -> List[SourceInfo]:
|
| 262 |
+
"""
|
| 263 |
+
Extract source information from results.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
results: List of retrieval results
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
List of SourceInfo objects
|
| 270 |
+
"""
|
| 271 |
+
sources = []
|
| 272 |
+
seen_titles = set()
|
| 273 |
+
|
| 274 |
+
for result in results:
|
| 275 |
+
# Deduplicate by title
|
| 276 |
+
if result.document_title not in seen_titles:
|
| 277 |
+
source = SourceInfo(
|
| 278 |
+
title=result.document_title,
|
| 279 |
+
url=result.source_url,
|
| 280 |
+
section=result.section,
|
| 281 |
+
relevance_score=result.score,
|
| 282 |
+
)
|
| 283 |
+
sources.append(source)
|
| 284 |
+
seen_titles.add(result.document_title)
|
| 285 |
+
|
| 286 |
+
return sources
|
| 287 |
+
|
| 288 |
+
def _calculate_confidence(self, results: List[RetrievalResult]) -> float:
|
| 289 |
+
"""
|
| 290 |
+
Calculate confidence score based on retrieval scores.
|
| 291 |
+
|
| 292 |
+
Uses average of top reranked scores.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
results: List of retrieval results
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
Confidence score (0-1)
|
| 299 |
+
"""
|
| 300 |
+
if not results:
|
| 301 |
+
return 0.0
|
| 302 |
+
|
| 303 |
+
# Use average of top scores
|
| 304 |
+
top_scores = [r.score for r in results[:self.rerank_k]]
|
| 305 |
+
|
| 306 |
+
if not top_scores:
|
| 307 |
+
return 0.0
|
| 308 |
+
|
| 309 |
+
avg_score = sum(top_scores) / len(top_scores)
|
| 310 |
+
|
| 311 |
+
# Normalize to 0-1 range (assuming scores are roughly 0-1)
|
| 312 |
+
confidence = min(max(avg_score, 0.0), 1.0)
|
| 313 |
+
|
| 314 |
+
return confidence
|
| 315 |
+
|
| 316 |
+
def _format_prompt(self, query: str, context: str) -> str:
|
| 317 |
+
"""
|
| 318 |
+
Format the prompt for LLM.
|
| 319 |
+
|
| 320 |
+
Uses query_prompt_template if loaded, otherwise uses default format.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
query: User query
|
| 324 |
+
context: Assembled context
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
Formatted prompt
|
| 328 |
+
"""
|
| 329 |
+
if self.query_prompt_template:
|
| 330 |
+
# Use template with placeholders
|
| 331 |
+
prompt = self.query_prompt_template.format(
|
| 332 |
+
context=context,
|
| 333 |
+
question=query
|
| 334 |
+
)
|
| 335 |
+
else:
|
| 336 |
+
# Default inline formatting
|
| 337 |
+
prompt = f"""Context from EyeWiki medical knowledge base:
|
| 338 |
+
|
| 339 |
+
{context}
|
| 340 |
+
|
| 341 |
+
---
|
| 342 |
+
|
| 343 |
+
Question: {query}
|
| 344 |
+
|
| 345 |
+
Please provide a comprehensive answer based on the context above. Structure your response clearly and cite sources where appropriate."""
|
| 346 |
+
|
| 347 |
+
return prompt
|
| 348 |
+
|
| 349 |
+
def query(
    self,
    question: str,
    include_sources: bool = True,
    filters: Optional[dict] = None,
) -> QueryResponse:
    """
    Run the full RAG pipeline for one question.

    Pipeline:
    1. Retrieve candidate chunks (retrieval_k)
    2. Rerank with cross-encoder (rerank_k)
    3. Assemble context within the token budget
    4. Generate answer with LLM (low temperature for factual output)
    5. Return response with sources and confidence

    Args:
        question: User question.
        include_sources: Include deduplicated source info in the response.
        filters: Optional metadata filters forwarded to the retriever.

    Returns:
        QueryResponse. On empty retrieval or LLM failure a fallback
        answer text is returned instead of raising.
    """
    logger.info(f"Processing query: '{question}'")

    # Step 1: Retrieve a wide candidate set for the reranker to refine.
    logger.info(f"Retrieving top {self.retrieval_k} candidates...")
    retrieval_results = self.retriever.retrieve(
        query=question,
        top_k=self.retrieval_k,
        filters=filters,
    )

    # Nothing retrieved: short-circuit with a canned answer, zero confidence.
    if not retrieval_results:
        logger.warning("No results found for query")
        return QueryResponse(
            answer="I couldn't find relevant information to answer this question in the EyeWiki knowledge base.",
            sources=[],
            confidence=0.0,
            query=question,
        )

    # Step 2: Rerank for precision (narrows to rerank_k results).
    logger.info(f"Reranking to top {self.rerank_k}...")
    reranked_results = self.reranker.rerank(
        query=question,
        documents=retrieval_results,
        top_k=self.rerank_k,
    )

    # Step 3: Assemble context (token-limited) from the reranked chunks.
    context = self._assemble_context(reranked_results)

    # Step 4: Generate answer.
    logger.info("Generating answer with LLM...")
    prompt = self._format_prompt(question, context)

    try:
        answer = self.llm_client.generate(
            prompt=prompt,
            system_prompt=self.system_prompt,
            temperature=0.1,  # Low temperature for factual responses
        )
    except Exception as e:
        # Degrade gracefully: surface an apology rather than propagate.
        logger.error(f"Error generating answer: {e}")
        answer = (
            "I encountered an error while generating the answer. "
            "Please try again or rephrase your question."
        )

    # Step 5: Extract sources (skipped entirely when not requested).
    sources = self._extract_sources(reranked_results) if include_sources else []

    # Step 6: Calculate confidence from the reranked scores.
    confidence = self._calculate_confidence(reranked_results)

    # Assemble the final response object.
    response = QueryResponse(
        answer=answer,
        sources=sources,
        confidence=confidence,
        query=question,
    )

    logger.info(
        f"Query complete: {len(sources)} sources, "
        f"confidence: {confidence:.2f}"
    )

    return response
|
| 440 |
+
|
| 441 |
+
def stream_query(
    self,
    question: str,
    filters: Optional[dict] = None,
) -> Generator[str, None, None]:
    """
    Query with a streaming response.

    Runs the same retrieve -> rerank -> assemble pipeline as query(),
    then yields LLM output chunks as they are generated instead of
    waiting for the complete answer. Unlike query(), no sources or
    confidence are produced.

    Args:
        question: User question.
        filters: Optional metadata filters forwarded to the retriever.

    Yields:
        Answer text chunks as they are generated; a single fallback
        message when nothing was retrieved, or an error marker if
        generation fails mid-stream.
    """
    logger.info(f"Processing streaming query: '{question}'")

    # Retrieval and reranking (same as query())
    retrieval_results = self.retriever.retrieve(
        query=question,
        top_k=self.retrieval_k,
        filters=filters,
    )

    # Nothing retrieved: emit one fallback chunk and stop.
    if not retrieval_results:
        yield "I couldn't find relevant information to answer this question."
        return

    reranked_results = self.reranker.rerank(
        query=question,
        documents=retrieval_results,
        top_k=self.rerank_k,
    )

    # Assemble token-limited context from the reranked chunks.
    context = self._assemble_context(reranked_results)

    # Build the LLM prompt.
    prompt = self._format_prompt(question, context)

    # Stream generation; chunks are forwarded as-is to the caller.
    try:
        for chunk in self.llm_client.stream_generate(
            prompt=prompt,
            system_prompt=self.system_prompt,
            temperature=0.1,
        ):
            yield chunk

    except Exception as e:
        # Mid-stream failure: append an error marker rather than raise,
        # so partial output already yielded remains usable.
        logger.error(f"Error in streaming generation: {e}")
        yield "\n\n[Error: Failed to generate response]"
|
| 495 |
+
|
| 496 |
+
def batch_query(
    self,
    questions: List[str],
    include_sources: bool = True,
) -> List[QueryResponse]:
    """
    Run the full query pipeline over several questions sequentially.

    Args:
        questions: Questions to answer, processed in order.
        include_sources: Whether each response carries source info.

    Returns:
        One QueryResponse per input question, in the same order.
    """
    return [
        self.query(question, include_sources=include_sources)
        for question in questions
    ]
|
| 518 |
+
|
| 519 |
+
def get_pipeline_info(self) -> dict:
    """
    Summarize the current pipeline configuration.

    Returns:
        Dict describing retrieval/rerank depths, the context token
        budget, retriever fusion settings, reranker model info, and
        the LLM model name.
    """
    retriever_config = {
        "dense_weight": self.retriever.dense_weight,
        "sparse_weight": self.retriever.sparse_weight,
        "term_expansion": self.retriever.enable_term_expansion,
    }
    return {
        "retrieval_k": self.retrieval_k,
        "rerank_k": self.rerank_k,
        "max_context_tokens": self.max_context_tokens,
        "retriever_config": retriever_config,
        "reranker_info": self.reranker.get_model_info(),
        "llm_model": self.llm_client.llm_model,
    }
|
src/rag/reranker.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cross-encoder reranker for improved retrieval relevance."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from sentence_transformers import CrossEncoder
|
| 8 |
+
from rich.console import Console
|
| 9 |
+
|
| 10 |
+
from src.rag.retriever import RetrievalResult
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Configure logging
|
| 14 |
+
logging.basicConfig(level=logging.INFO)
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CrossEncoderReranker:
    """
    Reranker using cross-encoder models for improved relevance.

    A cross-encoder scores each (query, document) pair jointly, which is
    slower than bi-encoder retrieval but more precise; it is applied to
    the retriever's candidate set to refine the final ranking.

    Features:
    - Uses sentence-transformers cross-encoder
    - Automatic GPU/CPU detection
    - Model caching for efficiency (shared across instances)
    - Preserves original retrieval scores in result metadata
    - Batch processing for speed
    """

    # Class-level cache shared by ALL instances, keyed by (model path, device):
    # avoids reloading the same model when several rerankers are constructed.
    _model_cache = {}

    # Short aliases mapped to full model checkpoints.
    AVAILABLE_MODELS = {
        "ms-marco-mini": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "ms-marco-base": "cross-encoder/ms-marco-MiniLM-L-12-v2",
        "medicalai": "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",  # Medical domain
    }

    def __init__(
        self,
        model_name: str = "ms-marco-mini",
        device: Optional[str] = None,
        max_length: int = 512,
    ):
        """
        Initialize cross-encoder reranker.

        Args:
            model_name: Key from AVAILABLE_MODELS, or a full model path
                for a custom checkpoint (then self.model_name == "custom").
            device: Device to use ('cuda', 'cpu', or None for auto-detect).
            max_length: Maximum tokenized sequence length per pair.

        Raises:
            Exception: Propagated from model loading (see _load_model).
        """
        # Resolve alias -> checkpoint path; unknown names are treated as
        # custom model paths.
        if model_name in self.AVAILABLE_MODELS:
            self.model_path = self.AVAILABLE_MODELS[model_name]
            self.model_name = model_name
        else:
            self.model_path = model_name
            self.model_name = "custom"

        self.max_length = max_length
        self.console = Console()

        # Auto-detect GPU when no device was requested explicitly.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Load eagerly so loading failures surface at construction time.
        self._load_model()

    def _load_model(self):
        """Load the cross-encoder, reusing a cached instance when possible.

        Raises:
            Exception: Re-raised after logging if loading fails.
        """
        cache_key = f"{self.model_path}_{self.device}"

        # Reuse an already-loaded model for this (path, device) pair.
        if cache_key in self._model_cache:
            self.model = self._model_cache[cache_key]
            logger.info(f"Loaded reranker model from cache: {self.model_name}")
            return

        # First load for this configuration: download/initialize the model.
        try:
            self.console.print(f"[cyan]Loading reranker model: {self.model_name}...[/cyan]")

            self.model = CrossEncoder(
                self.model_path,
                max_length=self.max_length,
                device=self.device,
            )

            # Publish into the shared cache for later instances.
            self._model_cache[cache_key] = self.model

            # NOTE(review): device index 0 is assumed for the GPU name below;
            # confirm if multi-GPU placement is ever used.
            device_info = f"GPU ({torch.cuda.get_device_name(0)})" if self.device == "cuda" else "CPU"
            # NOTE(review): the markup span below is empty — a status glyph
            # (e.g. a check mark) appears lost to encoding; confirm original.
            self.console.print(
                f"[green][/green] Loaded reranker model: {self.model_name} on {device_info}"
            )
            logger.info(
                f"Loaded cross-encoder model: {self.model_path} on {self.device}"
            )

        except Exception as e:
            logger.error(f"Failed to load reranker model: {e}")
            self.console.print(f"[red][/red] Failed to load reranker model: {e}")
            raise

    def score_pairs(self, query: str, documents: List[str]) -> List[float]:
        """
        Score query-document pairs with the cross-encoder.

        Args:
            query: Search query.
            documents: List of document texts.

        Returns:
            One relevance score per document, in input order (higher is
            better; raw cross-encoder scores may be negative). Returns an
            empty list for no documents, and all zeros if scoring fails.
        """
        if not documents:
            return []

        # The cross-encoder consumes (query, document) pairs jointly.
        pairs = [[query, doc] for doc in documents]

        try:
            # Batch-predict relevance scores as a numpy array.
            scores = self.model.predict(pairs, convert_to_numpy=True)

            # numpy array -> plain Python floats
            scores = scores.tolist()

            logger.debug(f"Scored {len(documents)} documents")

            return scores

        except Exception as e:
            # Degrade gracefully: neutral scores keep the pipeline running.
            logger.error(f"Error scoring pairs: {e}")
            return [0.0] * len(documents)

    def rerank(
        self,
        query: str,
        documents: List[RetrievalResult],
        top_k: Optional[int] = None,
    ) -> List[RetrievalResult]:
        """
        Rerank documents using the cross-encoder.

        The returned results use the cross-encoder score as their primary
        `score`; the original retrieval score is preserved in metadata
        under 'original_retrieval_score'. Input objects are not mutated.

        Args:
            query: Search query.
            documents: RetrievalResult objects from the retriever.
            top_k: Number of top results to return (None for all).

        Returns:
            New RetrievalResult objects sorted by reranker score (desc).
        """
        if not documents:
            logger.warning("No documents to rerank")
            return []

        # Extract document texts for scoring.
        doc_texts = [doc.content for doc in documents]

        # Score all documents in one batch.
        logger.info(f"Reranking {len(documents)} documents for query: '{query[:50]}...'")
        rerank_scores = self.score_pairs(query, doc_texts)

        # Rebuild results with the reranker score as the primary score.
        reranked_results = []
        for doc, rerank_score in zip(documents, rerank_scores):
            # Copy metadata so the input object is untouched; keep both
            # scores for later comparison/debugging.
            updated_metadata = doc.metadata.copy()
            updated_metadata["original_retrieval_score"] = doc.score
            updated_metadata["reranker_score"] = float(rerank_score)

            reranked_doc = RetrievalResult(
                content=doc.content,
                metadata=updated_metadata,
                score=float(rerank_score),  # Use reranker score as primary score
                source_url=doc.source_url,
                section=doc.section,
                chunk_id=doc.chunk_id,
                document_title=doc.document_title,
            )

            reranked_results.append(reranked_doc)

        # Sort by reranker score (descending).
        reranked_results.sort(key=lambda x: x.score, reverse=True)

        # Log how the top score moved relative to retrieval.
        if reranked_results:
            logger.info(
                f"Reranking complete. Top result score: {reranked_results[0].score:.4f} "
                f"(original: {reranked_results[0].metadata.get('original_retrieval_score', 0):.4f})"
            )

        # Truncate only after sorting so top_k means "best k overall".
        if top_k is not None:
            return reranked_results[:top_k]

        return reranked_results

    def rerank_with_comparison(
        self,
        query: str,
        documents: List[RetrievalResult],
        top_k: Optional[int] = None,
    ) -> List[Tuple[RetrievalResult, dict]]:
        """
        Rerank with a detailed before/after score comparison.

        Args:
            query: Search query.
            documents: RetrievalResult objects, in retrieval order.
            top_k: Number of top results to return (None for all).

        Returns:
            List of (RetrievalResult, comparison_dict) tuples, where
            comparison_dict contains:
            - original_score: Original retrieval score
            - reranker_score: Cross-encoder score
            - score_change: Difference (reranker - original)
            - original_rank / new_rank: 0-based positions before/after
            - rank_change: original_rank - new_rank (positive = moved up)
        """
        if not documents:
            return []

        # Remember each document's rank before reranking.
        # NOTE(review): keyed by chunk_id — assumes chunk_ids are unique
        # within the input; duplicates would collide. Confirm upstream.
        original_rankings = {doc.chunk_id: idx for idx, doc in enumerate(documents)}

        # Rerank the full set (top_k=None) so ranks are globally consistent.
        reranked_docs = self.rerank(query, documents, top_k=None)

        # Pair each reranked doc with its score/rank deltas.
        results_with_comparison = []

        for new_rank, doc in enumerate(reranked_docs):
            original_rank = original_rankings[doc.chunk_id]
            original_score = doc.metadata.get("original_retrieval_score", 0.0)
            reranker_score = doc.score

            comparison = {
                "original_score": original_score,
                "reranker_score": reranker_score,
                "score_change": reranker_score - original_score,
                "original_rank": original_rank,
                "new_rank": new_rank,
                "rank_change": original_rank - new_rank,  # Positive = moved up
            }

            results_with_comparison.append((doc, comparison))

        # Truncate after comparison so ranks reflect the full reranked set.
        if top_k is not None:
            return results_with_comparison[:top_k]

        return results_with_comparison

    def get_model_info(self) -> dict:
        """
        Get information about the loaded model.

        Returns:
            Dictionary with model name/path, device, max sequence length,
            and GPU availability details.
        """
        return {
            "model_name": self.model_name,
            "model_path": self.model_path,
            "device": self.device,
            "max_length": self.max_length,
            "gpu_available": torch.cuda.is_available(),
            "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
        }

    def clear_cache(self):
        """Clear the shared class-level model cache (affects all instances)."""
        self._model_cache.clear()
        logger.info("Cleared model cache")

    @classmethod
    def get_available_models(cls) -> dict:
        """
        Get dictionary of available models.

        Returns:
            A copy of the alias -> model-path mapping (safe to mutate).
        """
        return cls.AVAILABLE_MODELS.copy()
|
src/rag/retriever.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hybrid retriever combining dense and sparse search for optimal retrieval."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import re
|
| 5 |
+
from typing import Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
from rich.console import Console
|
| 9 |
+
|
| 10 |
+
from src.vectorstore.qdrant_store import QdrantStoreManager, SearchResult
|
| 11 |
+
from src.llm.sentence_transformer_client import SentenceTransformerClient
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(level=logging.INFO)
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RetrievalResult(BaseModel):
    """
    Pydantic model for a single retrieved chunk.

    Attributes:
        content: Retrieved text content
        metadata: Document metadata (disease, ICD codes, etc.)
        score: Relevance score (fusion score, or cross-encoder score
            after reranking — may be negative in the latter case)
        source_url: EyeWiki source URL
        section: Parent section header
        chunk_id: Unique chunk identifier
        document_title: Article title
    """

    content: str = Field(..., description="Retrieved text content")
    metadata: Dict = Field(default_factory=dict, description="Document metadata")
    score: float = Field(..., description="Relevance score (can be negative for cross-encoder)")
    source_url: str = Field(default="", description="EyeWiki source URL")
    section: str = Field(default="", description="Parent section header")
    chunk_id: str = Field(default="", description="Unique chunk identifier")
    document_title: str = Field(default="", description="Article title")

    @classmethod
    def from_search_result(cls, result: SearchResult) -> "RetrievalResult":
        """
        Convert a Qdrant SearchResult into a RetrievalResult.

        Note the field rename: SearchResult.parent_section maps onto
        RetrievalResult.section; all other fields copy across 1:1.

        Args:
            result: SearchResult from Qdrant

        Returns:
            RetrievalResult instance
        """
        return cls(
            content=result.content,
            metadata=result.metadata,
            score=result.score,
            source_url=result.source_url,
            section=result.parent_section,
            chunk_id=result.chunk_id,
            document_title=result.document_title,
        )
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class HybridRetriever:
|
| 64 |
+
"""
|
| 65 |
+
Hybrid retriever combining dense (semantic) and sparse (BM25) search.
|
| 66 |
+
|
| 67 |
+
Features:
|
| 68 |
+
- Dense vector search via embeddings (default weight: 0.7)
|
| 69 |
+
- Sparse BM25 keyword search (default weight: 0.3)
|
| 70 |
+
- Configurable fusion weights
|
| 71 |
+
- Query preprocessing
|
| 72 |
+
- Medical term expansion
|
| 73 |
+
- Metadata filtering
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
# Medical term synonyms and abbreviations for query expansion
|
| 77 |
+
MEDICAL_TERM_EXPANSIONS = {
|
| 78 |
+
# Common abbreviations
|
| 79 |
+
"iop": ["intraocular pressure", "iop"],
|
| 80 |
+
"amd": ["age-related macular degeneration", "amd"],
|
| 81 |
+
"armd": ["age-related macular degeneration", "armd"],
|
| 82 |
+
"dme": ["diabetic macular edema", "dme"],
|
| 83 |
+
"dr": ["diabetic retinopathy", "dr"],
|
| 84 |
+
"poag": ["primary open-angle glaucoma", "poag"],
|
| 85 |
+
"pacg": ["primary angle-closure glaucoma", "pacg"],
|
| 86 |
+
"rvo": ["retinal vein occlusion", "rvo"],
|
| 87 |
+
"rao": ["retinal artery occlusion", "rao"],
|
| 88 |
+
"crvo": ["central retinal vein occlusion", "crvo"],
|
| 89 |
+
"brvo": ["branch retinal vein occlusion", "brvo"],
|
| 90 |
+
"crao": ["central retinal artery occlusion", "crao"],
|
| 91 |
+
"vegf": ["vascular endothelial growth factor", "vegf"],
|
| 92 |
+
"oct": ["optical coherence tomography", "oct"],
|
| 93 |
+
"fa": ["fluorescein angiography", "fa"],
|
| 94 |
+
"icg": ["indocyanine green angiography", "icg"],
|
| 95 |
+
"erg": ["electroretinography", "erg"],
|
| 96 |
+
"vf": ["visual field", "vf"],
|
| 97 |
+
"va": ["visual acuity", "va"],
|
| 98 |
+
|
| 99 |
+
# Common synonyms
|
| 100 |
+
"retina": ["retina", "retinal"],
|
| 101 |
+
"cornea": ["cornea", "corneal"],
|
| 102 |
+
"glaucoma": ["glaucoma", "glaucomatous"],
|
| 103 |
+
"cataract": ["cataract", "lens opacity"],
|
| 104 |
+
"macula": ["macula", "macular"],
|
| 105 |
+
"optic nerve": ["optic nerve", "optic disc", "optic cup"],
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
def __init__(
    self,
    qdrant_manager: QdrantStoreManager,
    embedding_client: SentenceTransformerClient,
    dense_weight: float = 0.7,
    sparse_weight: float = 0.3,
    enable_term_expansion: bool = True,
):
    """
    Initialize hybrid retriever.

    If the two fusion weights do not sum to ~1.0 they are normalized
    so that they do, and a warning is logged.

    Args:
        qdrant_manager: QdrantStoreManager for vector search
        embedding_client: SentenceTransformerClient for query embeddings
        dense_weight: Weight for dense (semantic) search (0-1)
        sparse_weight: Weight for sparse (BM25) search (0-1)
        enable_term_expansion: Enable medical term expansion
    """
    self.qdrant_manager = qdrant_manager
    self.embedding_client = embedding_client
    self.dense_weight = dense_weight
    self.sparse_weight = sparse_weight
    self.enable_term_expansion = enable_term_expansion
    self.console = Console()

    # Normalize fusion weights when they do not already sum to ~1.0
    # (the tolerance band absorbs floating-point error).
    weight_sum = dense_weight + sparse_weight
    if weight_sum < 0.99 or weight_sum > 1.01:
        logger.warning(
            f"Weights sum to {weight_sum:.2f}, not 1.0. "
            "Normalizing weights."
        )
        self.dense_weight = dense_weight / weight_sum
        self.sparse_weight = sparse_weight / weight_sum

    logger.info(
        f"Initialized HybridRetriever (dense: {self.dense_weight:.2f}, "
        f"sparse: {self.sparse_weight:.2f})"
    )
|
| 148 |
+
|
| 149 |
+
def _preprocess_query(self, query: str) -> str:
|
| 150 |
+
"""
|
| 151 |
+
Preprocess query text.
|
| 152 |
+
|
| 153 |
+
- Convert to lowercase
|
| 154 |
+
- Remove excessive whitespace
|
| 155 |
+
- Normalize punctuation
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
query: Raw query string
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
Preprocessed query
|
| 162 |
+
"""
|
| 163 |
+
# Convert to lowercase
|
| 164 |
+
query = query.lower()
|
| 165 |
+
|
| 166 |
+
# Remove excessive whitespace
|
| 167 |
+
query = re.sub(r'\s+', ' ', query)
|
| 168 |
+
|
| 169 |
+
# Strip leading/trailing whitespace
|
| 170 |
+
query = query.strip()
|
| 171 |
+
|
| 172 |
+
return query
|
| 173 |
+
|
| 174 |
+
def _expand_medical_terms(self, query: str) -> str:
|
| 175 |
+
"""
|
| 176 |
+
Expand medical abbreviations and add synonyms.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
query: Preprocessed query
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
Expanded query with synonyms
|
| 183 |
+
"""
|
| 184 |
+
if not self.enable_term_expansion:
|
| 185 |
+
return query
|
| 186 |
+
|
| 187 |
+
expanded_terms = []
|
| 188 |
+
words = query.split()
|
| 189 |
+
|
| 190 |
+
for word in words:
|
| 191 |
+
# Check if word matches any abbreviation or term
|
| 192 |
+
if word in self.MEDICAL_TERM_EXPANSIONS:
|
| 193 |
+
# Add all expansions
|
| 194 |
+
expansions = self.MEDICAL_TERM_EXPANSIONS[word]
|
| 195 |
+
expanded_terms.extend(expansions)
|
| 196 |
+
else:
|
| 197 |
+
# Keep original word
|
| 198 |
+
expanded_terms.append(word)
|
| 199 |
+
|
| 200 |
+
# Join and deduplicate
|
| 201 |
+
expanded_query = " ".join(expanded_terms)
|
| 202 |
+
|
| 203 |
+
logger.debug(f"Query expansion: '{query}' � '{expanded_query}'")
|
| 204 |
+
|
| 205 |
+
return expanded_query
|
| 206 |
+
|
| 207 |
+
def _generate_query_embedding(self, query: str) -> List[float]:
|
| 208 |
+
"""
|
| 209 |
+
Generate embedding for query.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
query: Query text
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
Query embedding vector
|
| 216 |
+
"""
|
| 217 |
+
try:
|
| 218 |
+
embedding = self.embedding_client.embed_text(query)
|
| 219 |
+
return embedding
|
| 220 |
+
except Exception as e:
|
| 221 |
+
logger.error(f"Failed to generate query embedding: {e}")
|
| 222 |
+
raise
|
| 223 |
+
|
| 224 |
+
def _merge_results(
    self,
    dense_results: List[SearchResult],
    sparse_results: Optional[List[SearchResult]] = None,
) -> List[Tuple[RetrievalResult, float]]:
    """
    Merge dense and sparse results using weighted fusion.

    Combines raw scores as a weighted linear sum:
    ``dense_score * dense_weight + sparse_score * sparse_weight``,
    with a missing score counting as 0.0. (Note: this is a weighted
    score combination, not Reciprocal Rank Fusion — RRF would combine
    reciprocal ranks rather than raw scores.)

    Args:
        dense_results: Results from dense search
        sparse_results: Results from sparse search (if available)

    Returns:
        List of (RetrievalResult, combined_score) tuples, sorted by
        combined score, highest first
    """
    # If no sparse results, just use dense results
    if not sparse_results:
        results = []
        for result in dense_results:
            retrieval_result = RetrievalResult.from_search_result(result)
            # Apply dense weight to score so scores stay comparable
            # with the hybrid (dense + sparse) branch below.
            weighted_score = result.score * self.dense_weight
            results.append((retrieval_result, weighted_score))
        return results

    # Create score dictionaries keyed by chunk_id
    dense_scores = {r.chunk_id: r.score for r in dense_results}
    sparse_scores = {r.chunk_id: r.score for r in sparse_results}

    # Get all unique chunk_ids
    all_chunk_ids = set(dense_scores.keys()) | set(sparse_scores.keys())

    # Create lookup for full result objects; when a chunk appears in
    # both lists, the dense result object wins.
    result_lookup = {}
    for result in dense_results:
        result_lookup[result.chunk_id] = result
    for result in sparse_results:
        if result.chunk_id not in result_lookup:
            result_lookup[result.chunk_id] = result

    # Calculate weighted combined scores
    combined_results = []
    for chunk_id in all_chunk_ids:
        dense_score = dense_scores.get(chunk_id, 0.0)
        sparse_score = sparse_scores.get(chunk_id, 0.0)

        # Weighted combination (linear sum of weighted raw scores)
        combined_score = (
            dense_score * self.dense_weight + sparse_score * self.sparse_weight
        )

        result = result_lookup[chunk_id]
        retrieval_result = RetrievalResult.from_search_result(result)
        combined_results.append((retrieval_result, combined_score))

    # Sort by combined score (descending)
    combined_results.sort(key=lambda x: x[1], reverse=True)

    return combined_results
|
| 285 |
+
|
| 286 |
+
def retrieve_with_scores(
    self,
    query: str,
    top_k: int = 10,
    filters: Optional[Dict] = None,
) -> List[Tuple[RetrievalResult, float]]:
    """
    Retrieve documents together with their fusion scores.

    Pipeline: preprocess -> medical-term expansion -> embedding ->
    dense vector search (over-fetched) -> weighted merge -> truncate
    to ``top_k``.

    Args:
        query: Search query
        top_k: Number of results to return
        filters: Optional metadata filters

    Returns:
        List of (RetrievalResult, score) tuples, best match first
    """
    processed = self._preprocess_query(query)
    expanded = self._expand_medical_terms(processed)

    logger.info(f"Retrieving for query: '{query}'")
    logger.debug(f"Processed query: '{expanded}'")

    query_embedding = self._generate_query_embedding(expanded)

    # Over-fetch (2x) so the fusion step has candidates to work with.
    dense_results = self.qdrant_manager.search(
        query_embedding=query_embedding,
        top_k=top_k * 2,
        filters=filters,
    )
    logger.info(f"Dense search returned {len(dense_results)} results")

    # True hybrid search would additionally:
    #   1. Generate a sparse (BM25) vector for the query
    #   2. Run qdrant_manager.hybrid_search()
    #   3. Merge via rank fusion
    # Sparse search is not implemented yet, so dense-only for now.
    sparse_results = None

    merged = self._merge_results(dense_results, sparse_results)
    return merged[:top_k]
|
| 339 |
+
|
| 340 |
+
def retrieve(
    self,
    query: str,
    top_k: int = 10,
    filters: Optional[Dict] = None,
) -> List[RetrievalResult]:
    """
    Retrieve documents without their fusion scores.

    Thin wrapper around :meth:`retrieve_with_scores` that discards the
    score of each result.

    Args:
        query: Search query
        top_k: Number of results to return
        filters: Optional metadata filters

    Returns:
        List of RetrievalResult objects, best match first
    """
    scored = self.retrieve_with_scores(query, top_k, filters)
    return [doc for doc, _score in scored]
|
| 363 |
+
|
| 364 |
+
def retrieve_by_disease(
    self,
    query: str,
    disease_name: str,
    top_k: int = 10,
) -> List[RetrievalResult]:
    """
    Retrieve documents restricted to a single disease.

    Args:
        query: Search query
        disease_name: Disease name to filter by
        top_k: Number of results to return

    Returns:
        List of RetrievalResult objects
    """
    # Delegate to the generic path with a disease-name metadata filter.
    return self.retrieve(query, top_k, {"disease_name": disease_name})
|
| 383 |
+
|
| 384 |
+
def retrieve_by_icd_code(
    self,
    query: str,
    icd_codes: List[str],
    top_k: int = 10,
) -> List[RetrievalResult]:
    """
    Retrieve documents restricted to the given ICD codes.

    Args:
        query: Search query
        icd_codes: List of ICD codes to filter by
        top_k: Number of results to return

    Returns:
        List of RetrievalResult objects
    """
    # Delegate to the generic path with an ICD-code metadata filter.
    return self.retrieve(query, top_k, {"icd_codes": icd_codes})
|
| 403 |
+
|
| 404 |
+
def retrieve_by_anatomy(
    self,
    query: str,
    anatomical_structures: List[str],
    top_k: int = 10,
) -> List[RetrievalResult]:
    """
    Retrieve documents restricted to the given anatomical structures.

    Args:
        query: Search query
        anatomical_structures: List of anatomical terms
        top_k: Number of results to return

    Returns:
        List of RetrievalResult objects
    """
    # Delegate to the generic path with an anatomy metadata filter.
    return self.retrieve(
        query, top_k, {"anatomical_structures": anatomical_structures}
    )
|
| 423 |
+
|
| 424 |
+
def get_similar_sections(
    self,
    section_content: str,
    top_k: int = 5,
    filters: Optional[Dict] = None,
) -> List[RetrievalResult]:
    """
    Find sections whose content is similar to the given text.

    Useful for "related sections" or "see also" features: the section
    text itself is used as the search query.

    Args:
        section_content: Content to find similar sections for
        top_k: Number of results to return
        filters: Optional metadata filters

    Returns:
        List of RetrievalResult objects
    """
    return self.retrieve(section_content, top_k, filters)
|
| 445 |
+
|
| 446 |
+
def multi_query_retrieve(
    self,
    queries: List[str],
    top_k: int = 10,
    filters: Optional[Dict] = None,
    deduplicate: bool = True,
) -> List[RetrievalResult]:
    """
    Retrieve with several queries and combine the results.

    Useful for query decomposition or multi-faceted questions. Queries
    run in order; with ``deduplicate`` enabled, a chunk already seen
    for an earlier query is skipped.

    Args:
        queries: List of query strings
        top_k: Total number of results to return
        filters: Optional metadata filters
        deduplicate: Remove duplicate results

    Returns:
        List of RetrievalResult objects (at most ``top_k``)
    """
    combined: List[RetrievalResult] = []
    seen_ids = set()

    for single_query in queries:
        for hit in self.retrieve(single_query, top_k=top_k, filters=filters):
            if deduplicate:
                if hit.chunk_id in seen_ids:
                    continue
                seen_ids.add(hit.chunk_id)
            combined.append(hit)

    # Cap the combined list at top_k overall.
    return combined[:top_k]
|
src/scraper/__init__.py
ADDED
|
File without changes
|
src/scraper/eyewiki_crawler.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""EyeWiki crawler for medical article scraping using crawl4ai."""
|
| 2 |
+
|
| 3 |
+
import asyncio
import json
import re
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Optional, Set
from urllib.parse import urljoin, urlparse, parse_qs
from urllib.robotparser import RobotFileParser

import aiohttp
from bs4 import BeautifulSoup
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn

from config.settings import settings
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class EyeWikiCrawler:
|
| 23 |
+
"""
|
| 24 |
+
Asynchronous crawler for EyeWiki medical articles.
|
| 25 |
+
|
| 26 |
+
Features:
|
| 27 |
+
- Asynchronous crawling with crawl4ai
|
| 28 |
+
- Respects robots.txt
|
| 29 |
+
- Polite crawling with configurable delays
|
| 30 |
+
- Markdown content extraction
|
| 31 |
+
- Checkpointing for resume capability
|
| 32 |
+
- Progress tracking with rich console
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
    self,
    base_url: str = "https://eyewiki.org",
    output_dir: Optional[Path] = None,
    checkpoint_file: Optional[Path] = None,
    delay: float = 1.5,
    timeout: int = 30,
):
    """
    Initialize the EyeWiki crawler.

    Args:
        base_url: Base URL for EyeWiki
        output_dir: Directory to save scraped articles (defaults to
            ``settings.data_raw_path``; created if missing)
        checkpoint_file: Path to the checkpoint file (defaults to
            ``crawler_checkpoint.json`` inside the output directory)
        delay: Delay between requests in seconds
        timeout: Request timeout in seconds
    """
    self.base_url = base_url
    self.domain = urlparse(base_url).netloc
    self.delay = delay
    self.timeout = timeout

    # Output locations; create the article directory up front.
    self.output_dir = output_dir or Path(settings.data_raw_path)
    self.output_dir.mkdir(parents=True, exist_ok=True)
    self.checkpoint_file = checkpoint_file or (self.output_dir / "crawler_checkpoint.json")

    # Mutable crawl state.
    self.visited_urls: Set[str] = set()
    self.to_crawl: deque = deque()
    self.failed_urls: Dict[str, str] = {}
    self.articles_saved: int = 0

    # Rich console for logging.
    self.console = Console()

    # robots.txt handling.
    self.robot_parser = RobotFileParser()
    self.robot_parser.set_url(urljoin(base_url, "/robots.txt"))

    # URL patterns that are never treated as articles.
    self.skip_patterns = [
        r"/index\.php\?title=.*&action=",  # Edit, history, etc.
        r"/index\.php\?title=.*&diff=",  # Page diffs
        r"/index\.php\?title=.*&oldid=",  # Page history/revisions
        r"/index\.php\?title=.*&direction=",  # Page navigation
        r"/index\.php\?title=Special:",  # Special pages (login, create account, etc.)
        r"/Special:",  # Special pages
        r"/User:",  # User pages
        r"/User_talk:",  # User talk pages
        r"/Talk:",  # Talk pages
        r"/File:",  # File pages
        r"/Template:",  # Template pages
        r"/Help:",  # Help pages
        r"/MediaWiki:",  # MediaWiki pages
        r"#",  # Anchor links
    ]
|
| 93 |
+
|
| 94 |
+
def _is_valid_article_url(self, url: str) -> bool:
|
| 95 |
+
"""
|
| 96 |
+
Check if URL is a valid medical article.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
url: URL to check
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
True if valid article URL
|
| 103 |
+
"""
|
| 104 |
+
# Must be from eyewiki.org domain
|
| 105 |
+
if self.domain not in url:
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
# Skip patterns (these take precedence)
|
| 109 |
+
for pattern in self.skip_patterns:
|
| 110 |
+
if re.search(pattern, url):
|
| 111 |
+
return False
|
| 112 |
+
|
| 113 |
+
# Parse URL to check path
|
| 114 |
+
parsed = urlparse(url)
|
| 115 |
+
path = parsed.path.strip("/")
|
| 116 |
+
|
| 117 |
+
# Must be article-like URL
|
| 118 |
+
# EyeWiki articles can be:
|
| 119 |
+
# 1. Direct: /Article_Name (e.g., /Cataract)
|
| 120 |
+
# 2. Wiki-style: /wiki/Article_Name
|
| 121 |
+
# 3. Query-based: /w/index.php?title=Article_Name
|
| 122 |
+
|
| 123 |
+
# For query-based URLs, check if title parameter exists and is not a special page
|
| 124 |
+
if parsed.query and "title=" in parsed.query:
|
| 125 |
+
return True
|
| 126 |
+
|
| 127 |
+
# For direct URLs, check if path is non-empty and looks like an article
|
| 128 |
+
# (starts with capital letter, no file extension)
|
| 129 |
+
if path and not path.startswith("w/") and not "." in path:
|
| 130 |
+
# Path should look like an article name (capitalized, underscores/spaces)
|
| 131 |
+
if path[0].isupper() or path.startswith("wiki/"):
|
| 132 |
+
return True
|
| 133 |
+
|
| 134 |
+
return False
|
| 135 |
+
|
| 136 |
+
def _normalize_url(self, url: str) -> str:
|
| 137 |
+
"""
|
| 138 |
+
Normalize URL for consistent comparison.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
url: URL to normalize
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
Normalized URL
|
| 145 |
+
"""
|
| 146 |
+
# Remove fragment
|
| 147 |
+
url = url.split("#")[0]
|
| 148 |
+
# Remove trailing slash
|
| 149 |
+
url = url.rstrip("/")
|
| 150 |
+
return url
|
| 151 |
+
|
| 152 |
+
def _can_fetch(self, url: str) -> bool:
|
| 153 |
+
"""
|
| 154 |
+
Check if URL can be fetched according to robots.txt.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
url: URL to check
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
True if allowed to fetch
|
| 161 |
+
"""
|
| 162 |
+
try:
|
| 163 |
+
return self.robot_parser.can_fetch("*", url)
|
| 164 |
+
except Exception as e:
|
| 165 |
+
self.console.print(f"[yellow]Warning: Could not check robots.txt: {e}[/yellow]")
|
| 166 |
+
return True # Be permissive if robots.txt check fails
|
| 167 |
+
|
| 168 |
+
def _extract_links(self, html: str, current_url: str) -> Set[str]:
    """
    Extract valid article links from a page's HTML.

    Each anchor href is resolved against ``current_url``, normalized,
    and kept only if it passes the article-URL validity check.

    Args:
        html: HTML content
        current_url: Current page URL for resolving relative links

    Returns:
        Set of valid, normalized article URLs
    """
    soup = BeautifulSoup(html, "html.parser")
    candidates = (
        self._normalize_url(urljoin(current_url, anchor["href"]))
        for anchor in soup.find_all("a", href=True)
    )
    return {link for link in candidates if self._is_valid_article_url(link)}
|
| 192 |
+
|
| 193 |
+
def _extract_metadata(self, soup: BeautifulSoup, url: str) -> Dict:
    """
    Extract metadata from an article page.

    Args:
        soup: Parsed article HTML
        url: Article URL

    Returns:
        Dict with keys ``url``, ``title``, ``last_updated``,
        ``categories``, and ``scraped_at`` (timezone-aware UTC
        ISO-8601 timestamp).
    """
    metadata = {
        "url": url,
        "title": "",
        "last_updated": None,
        "categories": [],
        # Timezone-aware UTC; datetime.utcnow() is naive and
        # deprecated since Python 3.12.
        "scraped_at": datetime.now(timezone.utc).isoformat(),
    }

    # Extract title (MediaWiki page heading, falling back to any <h1>)
    title_tag = soup.find("h1", {"id": "firstHeading"}) or soup.find("h1")
    if title_tag:
        metadata["title"] = title_tag.get_text(strip=True)

    # Extract categories from Category: links
    category_links = soup.find_all("a", href=re.compile(r"/Category:"))
    metadata["categories"] = [link.get_text(strip=True) for link in category_links]

    # Extract last modified date (if available); raw footer text,
    # not parsed into a datetime.
    last_modified = soup.find("li", {"id": "footer-info-lastmod"})
    if last_modified:
        metadata["last_updated"] = last_modified.get_text(strip=True)

    return metadata
|
| 227 |
+
|
| 228 |
+
def save_article(self, content: Dict, filepath: Path) -> None:
    """
    Save an article's markdown content plus a JSON metadata sidecar.

    Args:
        content: Dictionary with 'markdown' and 'metadata' keys
        filepath: Base filepath (without extension)
    """
    # Markdown body -> <name>.md
    filepath.with_suffix(".md").write_text(content["markdown"], encoding="utf-8")

    # Metadata sidecar -> <name>.json
    payload = json.dumps(content["metadata"], indent=2, ensure_ascii=False)
    filepath.with_suffix(".json").write_text(payload, encoding="utf-8")

    self.articles_saved += 1
    self.console.print(f"[green][/green] Saved: {content['metadata'].get('title', 'Untitled')}")
|
| 248 |
+
|
| 249 |
+
def load_checkpoint(self) -> bool:
    """
    Load checkpoint data so a previous crawl can be resumed.

    Returns:
        True if the checkpoint existed and was loaded successfully.
    """
    if not self.checkpoint_file.exists():
        return False

    try:
        with open(self.checkpoint_file, "r") as f:
            state = json.load(f)

        # Restore the crawl state (all assignments stay inside the
        # try so a malformed checkpoint degrades to a fresh crawl).
        self.visited_urls = set(state.get("visited_urls", []))
        self.to_crawl = deque(state.get("to_crawl", []))
        self.failed_urls = state.get("failed_urls", {})
        self.articles_saved = state.get("articles_saved", 0)

        self.console.print(f"[blue]Loaded checkpoint:[/blue] {len(self.visited_urls)} visited, "
                           f"{len(self.to_crawl)} queued, {self.articles_saved} saved")
        return True
    except Exception as e:
        self.console.print(f"[red]Error loading checkpoint: {e}[/red]")
        return False
|
| 274 |
+
|
| 275 |
+
def save_checkpoint(self) -> None:
    """Persist the current crawl state so a later run can resume.

    Writes visited/queued/failed URLs and the saved-article count to
    ``self.checkpoint_file`` as JSON. Write errors are reported on the
    console but never raised, so checkpointing can't kill a crawl.
    """
    data = {
        "visited_urls": list(self.visited_urls),
        "to_crawl": list(self.to_crawl),
        "failed_urls": self.failed_urls,
        "articles_saved": self.articles_saved,
        # Timezone-aware UTC; datetime.utcnow() is naive and
        # deprecated since Python 3.12.
        "last_checkpoint": datetime.now(timezone.utc).isoformat(),
    }

    try:
        with open(self.checkpoint_file, "w") as f:
            json.dump(data, f, indent=2)
    except Exception as e:
        self.console.print(f"[red]Error saving checkpoint: {e}[/red]")
|
| 290 |
+
|
| 291 |
+
async def crawl_single_page(self, url: str) -> Optional[Dict]:
    """
    Crawl a single page and extract its content.

    Respects robots.txt, renders the page with crawl4ai's headless
    browser, and returns both the markdown conversion and parsed
    metadata plus discovered article links.

    Args:
        url: URL to crawl

    Returns:
        Dict with keys ``markdown``, ``metadata``, ``html``, and
        ``links`` (set of valid article URLs found on the page), or
        None when blocked by robots.txt or the crawl failed.
    """
    if not self._can_fetch(url):
        self.console.print(f"[yellow]Blocked by robots.txt:[/yellow] {url}")
        return None

    try:
        # Configure browser settings
        browser_config = BrowserConfig(
            headless=True,
            verbose=False,
        )

        # Configure crawler settings
        crawler_config = CrawlerRunConfig(
            cache_mode=CacheMode.BYPASS,
            page_timeout=self.timeout * 1000,  # Convert to milliseconds
            wait_for="body",
        )

        # Create crawler and run.
        # NOTE(review): a fresh AsyncWebCrawler (browser) is started per
        # page; reusing one instance across pages would likely be much
        # faster — confirm against crawl4ai's lifecycle docs.
        async with AsyncWebCrawler(config=browser_config) as crawler:
            result = await crawler.arun(
                url=url,
                config=crawler_config,
            )

            if not result.success:
                # NOTE(review): unsuccessful (but non-raising) crawls are
                # not recorded in self.failed_urls — only exceptions are.
                self.console.print(f"[red]Failed to crawl:[/red] {url}")
                return None

            # Parse HTML for metadata
            soup = BeautifulSoup(result.html, "html.parser")
            metadata = self._extract_metadata(soup, url)

            # Get markdown content
            markdown = result.markdown

            return {
                "markdown": markdown,
                "metadata": metadata,
                "html": result.html,
                "links": self._extract_links(result.html, url),
            }

    except Exception as e:
        # Record the failure so the crawl summary can report it.
        self.console.print(f"[red]Error crawling {url}:[/red] {e}")
        self.failed_urls[url] = str(e)
        return None
|
| 348 |
+
|
| 349 |
+
async def crawl(
    self,
    max_pages: Optional[int] = None,
    depth: int = 2,
    start_urls: Optional[list] = None,
) -> None:
    """
    Crawl EyeWiki breadth-first starting from the main page.

    Resumes from a checkpoint when one exists; otherwise seeds the
    queue with ``start_urls`` (or the base URL) at depth 0. Saves each
    valid article to disk, enqueues discovered links one level deeper,
    sleeps ``self.delay`` between requests, and checkpoints
    periodically plus once more on exit (including Ctrl-C).

    Args:
        max_pages: Maximum number of pages to crawl (None for unlimited)
        depth: Maximum depth to crawl
        start_urls: Optional list of starting URLs (defaults to base_url)
    """
    # Try to load checkpoint
    checkpoint_loaded = self.load_checkpoint()

    # Initialize robot parser (fetch and parse robots.txt)
    try:
        self.robot_parser.read()
        self.console.print("[green][/green] Loaded robots.txt")
    except Exception as e:
        self.console.print(f"[yellow]Warning: Could not load robots.txt: {e}[/yellow]")

    # Initialize queue if not loaded from checkpoint; entries are
    # (url, depth) pairs.
    if not checkpoint_loaded:
        if start_urls:
            self.to_crawl.extend([(url, 0) for url in start_urls])
        else:
            self.to_crawl.append((self.base_url, 0))

    self.console.print(f"\n[bold cyan]Starting EyeWiki Crawl[/bold cyan]")
    self.console.print(f"Max pages: {max_pages or 'unlimited'}")
    self.console.print(f"Max depth: {depth}")
    self.console.print(f"Delay: {self.delay}s\n")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        console=self.console,
    ) as progress:

        # Without max_pages there is no true total; 100 is a display
        # placeholder for the progress bar only.
        task = progress.add_task(
            "[cyan]Crawling...",
            total=max_pages if max_pages else 100,
        )

        try:
            while self.to_crawl:
                # Check max_pages limit
                if max_pages and self.articles_saved >= max_pages:
                    self.console.print(f"\n[yellow]Reached max_pages limit: {max_pages}[/yellow]")
                    break

                # Get next URL (FIFO -> breadth-first order)
                current_url, current_depth = self.to_crawl.popleft()

                # Skip if already visited
                if current_url in self.visited_urls:
                    continue

                # Check depth limit
                if current_depth > depth:
                    continue

                # Mark as visited (before fetching, so a failing URL is
                # not retried)
                self.visited_urls.add(current_url)

                # Update progress
                progress.update(
                    task,
                    completed=self.articles_saved,
                    description=f"[cyan]Crawling ({self.articles_saved} saved, {len(self.to_crawl)} queued): {current_url[:60]}...",
                )

                # Crawl the page
                result = await self.crawl_single_page(current_url)

                if result:
                    # Create filename from URL
                    parsed = urlparse(current_url)

                    # For URLs with query parameters (like index.php?title=Article_Name),
                    # extract the title parameter
                    if parsed.query:
                        query_params = parse_qs(parsed.query)
                        if 'title' in query_params:
                            # Use the title parameter as filename
                            filename = query_params['title'][0]
                        else:
                            # Fallback: use the entire query string
                            filename = parsed.query
                    else:
                        # Use path-based filename for clean URLs like /wiki/Article_Name
                        path_parts = parsed.path.strip("/").split("/")
                        filename = "_".join(path_parts[-2:]) if len(path_parts) > 1 else path_parts[-1]

                    # Clean filename: replace non-word characters, then
                    # collapse runs of dashes/whitespace to underscores.
                    # NOTE(review): distinct URLs can normalize to the
                    # same filename and silently overwrite each other.
                    filename = re.sub(r"[^\w\s-]", "_", filename)
                    filename = re.sub(r"[-\s]+", "_", filename)
                    filename = filename[:200]  # Limit length

                    # Save article
                    filepath = self.output_dir / filename
                    self.save_article(result, filepath)

                    # Add discovered links to queue, one level deeper
                    for link in result["links"]:
                        if link not in self.visited_urls:
                            self.to_crawl.append((link, current_depth + 1))

                # Polite delay
                await asyncio.sleep(self.delay)

                # Periodic checkpoint save (every 10 articles).
                # NOTE(review): this fires on every iteration while the
                # saved count sits at a multiple of 10 (e.g. a run of
                # skipped/failed pages at count 0) — harmless but noisy.
                if self.articles_saved % 10 == 0:
                    self.save_checkpoint()

        except KeyboardInterrupt:
            self.console.print("\n[yellow]Crawl interrupted by user[/yellow]")
        except Exception as e:
            self.console.print(f"\n[red]Error during crawl: {e}[/red]")
        finally:
            # Final checkpoint save (always runs, even on interrupt)
            self.save_checkpoint()

            # Print summary
            self.console.print("\n[bold cyan]Crawl Summary[/bold cyan]")
            self.console.print(f"Articles saved: {self.articles_saved}")
            self.console.print(f"URLs visited: {len(self.visited_urls)}")
            self.console.print(f"URLs failed: {len(self.failed_urls)}")
            self.console.print(f"URLs remaining: {len(self.to_crawl)}")

            if self.failed_urls:
                self.console.print("\n[yellow]Failed URLs:[/yellow]")
                for url, error in list(self.failed_urls.items())[:10]:
                    self.console.print(f"  - {url}: {error}")
                if len(self.failed_urls) > 10:
                    self.console.print(f"  ... and {len(self.failed_urls) - 10} more")
|
src/vectorstore/__init__.py
ADDED
|
File without changes
|
src/vectorstore/qdrant_store.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Qdrant vector store manager for EyeWiki RAG system."""
|
| 2 |
+
|
| 3 |
+
import uuid
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
from qdrant_client import QdrantClient
|
| 10 |
+
from qdrant_client.models import (
|
| 11 |
+
Distance,
|
| 12 |
+
VectorParams,
|
| 13 |
+
SparseVectorParams,
|
| 14 |
+
SparseIndexParams,
|
| 15 |
+
PointStruct,
|
| 16 |
+
Filter,
|
| 17 |
+
FieldCondition,
|
| 18 |
+
MatchValue,
|
| 19 |
+
MatchAny,
|
| 20 |
+
Range,
|
| 21 |
+
ScoredPoint,
|
| 22 |
+
)
|
| 23 |
+
from rich.console import Console
|
| 24 |
+
|
| 25 |
+
from config.settings import settings
|
| 26 |
+
from src.processing.chunker import ChunkNode
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Configure logging
|
| 30 |
+
logging.basicConfig(level=logging.INFO)
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class SearchResult(BaseModel):
    """
    Pydantic model for search results.

    Attributes:
        id: Unique identifier of the result
        score: Relevance score
        chunk_id: Chunk identifier
        content: Text content
        parent_section: Section header
        document_title: Article title
        source_url: EyeWiki URL
        metadata: Additional metadata
    """

    id: str = Field(..., description="Unique result identifier")
    # No ge=0.0 lower bound here: Qdrant cosine similarity can legitimately be
    # negative for dissimilar vectors, and a bound would make
    # from_scored_point() raise a ValidationError on valid results.
    score: float = Field(..., description="Relevance score")
    chunk_id: str = Field(..., description="Chunk identifier")
    content: str = Field(..., description="Text content")
    parent_section: str = Field(default="", description="Parent section header")
    document_title: str = Field(default="", description="Document title")
    source_url: str = Field(default="", description="Source URL")
    metadata: Dict = Field(default_factory=dict, description="Additional metadata")

    @classmethod
    def from_scored_point(cls, point: ScoredPoint) -> "SearchResult":
        """
        Create SearchResult from Qdrant ScoredPoint.

        Args:
            point: Qdrant scored point

        Returns:
            SearchResult instance
        """
        # Payload may be None if the point was retrieved without payload.
        payload = point.payload or {}

        return cls(
            id=str(point.id),
            score=point.score,
            chunk_id=payload.get("chunk_id", ""),
            content=payload.get("content", ""),
            parent_section=payload.get("parent_section", ""),
            document_title=payload.get("document_title", ""),
            source_url=payload.get("source_url", ""),
            metadata=payload.get("metadata", {}),
        )
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class QdrantStoreManager:
    """
    Qdrant vector store manager for EyeWiki documents.

    Features:
    - Local/persistent Qdrant storage
    - Dense vector search (semantic)
    - Sparse vector search (BM25)
    - Hybrid search combining both
    - Metadata filtering
    - Batched operations for efficiency
    """

    def __init__(
        self,
        collection_name: Optional[str] = None,
        path: Optional[str] = None,
        embedding_dim: int = 768,  # Default for nomic-embed-text
        batch_size: int = 100,
    ):
        """
        Initialize Qdrant store manager.

        Args:
            collection_name: Name of the collection (default: from settings)
            path: Path to Qdrant storage (default: from settings)
            embedding_dim: Dimension of dense embeddings
            batch_size: Batch size for bulk operations

        Raises:
            Exception: Propagated if the Qdrant client cannot be created
                (e.g. the storage path is locked by another process).
        """
        self.collection_name = collection_name or settings.qdrant_collection_name
        self.path = Path(path or settings.qdrant_path)
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size

        # Create storage directory
        self.path.mkdir(parents=True, exist_ok=True)

        # Initialize Qdrant client (local/persistent mode)
        try:
            self.client = QdrantClient(path=str(self.path))
            logger.info(f"Initialized Qdrant client at {self.path}")
        except Exception as e:
            logger.error(f"Failed to initialize Qdrant client: {e}")
            raise

        self.console = Console()

    def initialize_collection(self, recreate: bool = False) -> None:
        """
        Initialize the Qdrant collection with vector configurations.

        Creates collection with:
        - Dense vectors for semantic search (cosine similarity)
        - Sparse vectors for BM25/keyword search
        - Payload indexing for metadata filtering

        Args:
            recreate: If True, delete existing collection and recreate
        """
        try:
            # Check if collection exists
            collections = self.client.get_collections().collections
            collection_exists = any(c.name == self.collection_name for c in collections)

            if collection_exists:
                if recreate:
                    self.console.print(
                        f"[yellow]Deleting existing collection: {self.collection_name}[/yellow]"
                    )
                    self.client.delete_collection(self.collection_name)
                else:
                    # Existing collection is reused as-is; vector configs are
                    # NOT reconciled, so a dimension change requires recreate=True.
                    self.console.print(
                        f"[blue]Collection already exists: {self.collection_name}[/blue]"
                    )
                    return

            # Create collection with dense and sparse vector configurations
            self.console.print(f"[cyan]Creating collection: {self.collection_name}[/cyan]")

            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config={
                    # Dense vector for semantic search
                    "dense": VectorParams(
                        size=self.embedding_dim,
                        distance=Distance.COSINE,
                    ),
                },
                sparse_vectors_config={
                    # Sparse vector for BM25/keyword search
                    "sparse": SparseVectorParams(
                        index=SparseIndexParams(
                            on_disk=False,  # Keep in memory for speed
                        ),
                    ),
                },
            )

            # Create payload indexes for efficient filtering
            # Index on key metadata fields
            self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="document_title",
                field_schema="keyword",
            )

            self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="parent_section",
                field_schema="keyword",
            )

            self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="metadata.disease_name",
                field_schema="keyword",
            )

            self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="metadata.icd_codes",
                field_schema="keyword",
            )

            # "✓" restored here: the original string had empty [green][/green]
            # tags where the checkmark glyph was lost to encoding damage.
            self.console.print(
                f"[green]✓[/green] Collection created: {self.collection_name}"
            )
            logger.info(f"Created collection: {self.collection_name}")

        except Exception as e:
            logger.error(f"Failed to initialize collection: {e}")
            raise

    def add_documents(
        self,
        chunks: List[ChunkNode],
        dense_embeddings: List[List[float]],
        sparse_embeddings: Optional[List[Dict]] = None,
    ) -> int:
        """
        Add documents to the vector store with batched upserts.

        Point IDs are derived deterministically via uuid5(chunk_id), so
        re-adding the same chunks overwrites rather than duplicates them.

        Args:
            chunks: List of ChunkNode objects
            dense_embeddings: List of dense embedding vectors
            sparse_embeddings: Optional list of sparse vectors (for BM25)

        Returns:
            Number of documents successfully added

        Raises:
            ValueError: If chunks and embeddings length mismatch
        """
        if len(chunks) != len(dense_embeddings):
            raise ValueError(
                f"Chunks ({len(chunks)}) and embeddings ({len(dense_embeddings)}) "
                "must have same length"
            )

        if sparse_embeddings and len(sparse_embeddings) != len(chunks):
            raise ValueError(
                f"Chunks ({len(chunks)}) and sparse embeddings ({len(sparse_embeddings)}) "
                "must have same length"
            )

        total_added = 0

        try:
            # Process in batches to bound memory and request size
            for i in range(0, len(chunks), self.batch_size):
                batch_chunks = chunks[i : i + self.batch_size]
                batch_dense = dense_embeddings[i : i + self.batch_size]
                batch_sparse = (
                    sparse_embeddings[i : i + self.batch_size]
                    if sparse_embeddings
                    else None
                )

                # Create points for batch
                points = []
                for j, chunk in enumerate(batch_chunks):
                    # Prepare vector dict (named vectors: "dense" / "sparse")
                    vectors = {"dense": batch_dense[j]}

                    # Add sparse vector if available
                    if batch_sparse:
                        vectors["sparse"] = batch_sparse[j]

                    # Deterministic ID so upserts are idempotent per chunk_id
                    point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, chunk.chunk_id))
                    point = PointStruct(
                        id=point_id,
                        vector=vectors,
                        payload={
                            "chunk_id": chunk.chunk_id,
                            "content": chunk.content,
                            "parent_section": chunk.parent_section,
                            "document_title": chunk.document_title,
                            "source_url": chunk.source_url,
                            "chunk_index": chunk.chunk_index,
                            "token_count": chunk.token_count,
                            "metadata": chunk.metadata,
                        },
                    )
                    points.append(point)

                # Upsert batch
                self.client.upsert(
                    collection_name=self.collection_name,
                    points=points,
                )

                total_added += len(points)

                logger.info(
                    f"Uploaded batch {i // self.batch_size + 1}: "
                    f"{len(points)} points (total: {total_added})"
                )

            # "✓" restored (was lost to encoding damage in the original string)
            self.console.print(
                f"[green]✓[/green] Added {total_added} documents to {self.collection_name}"
            )
            return total_added

        except Exception as e:
            logger.error(f"Failed to add documents: {e}")
            raise

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filters: Optional[Dict] = None,
        score_threshold: Optional[float] = None,
    ) -> List[SearchResult]:
        """
        Search using dense vector (semantic search).

        Args:
            query_embedding: Dense query vector
            top_k: Number of results to return
            filters: Optional metadata filters (e.g., {"disease_name": "Glaucoma"})
            score_threshold: Minimum score threshold

        Returns:
            List of SearchResult objects
        """
        try:
            # Build filter conditions
            query_filter = self._build_filter(filters) if filters else None

            # Perform search
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                using="dense",  # Specify which named vector to use
                limit=top_k,
                query_filter=query_filter,
                score_threshold=score_threshold,
            ).points

            # Convert to SearchResult objects
            search_results = [SearchResult.from_scored_point(r) for r in results]

            logger.info(f"Dense search returned {len(search_results)} results")
            return search_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def hybrid_search(
        self,
        query_embedding: List[float],
        query_sparse: Optional[Dict] = None,
        top_k: int = 10,
        filters: Optional[Dict] = None,
    ) -> List[SearchResult]:
        """
        Hybrid search combining dense (semantic) and sparse (BM25) vectors.

        NOTE: the current implementation only queries the dense vector; the
        sparse vector is accepted but not yet fused (a proper implementation
        would run both searches and merge via Reciprocal Rank Fusion).

        Args:
            query_embedding: Dense query vector
            query_sparse: Sparse query vector for BM25
            top_k: Number of results to return
            filters: Optional metadata filters

        Returns:
            List of SearchResult objects with combined scores
        """
        try:
            # If no sparse vector provided, fall back to dense search
            if query_sparse is None:
                logger.warning("No sparse vector provided, using dense search only")
                return self.search(query_embedding, top_k, filters)

            # Build filter conditions
            query_filter = self._build_filter(filters) if filters else None

            # Perform hybrid search
            # Note: Qdrant supports multiple vectors in search, but for true hybrid
            # we'd need to do two separate searches and merge results
            # For simplicity, we'll use the query API with dense vector
            # In production, you'd want to implement proper RRF (Reciprocal Rank Fusion)

            results = self.client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                using="dense",  # Specify which named vector to use
                limit=top_k * 2,  # Get more results for reranking
                query_filter=query_filter,
            ).points

            # Convert to SearchResult objects
            search_results = [SearchResult.from_scored_point(r) for r in results]

            # For now, return top_k results
            # In production, implement RRF combining dense and sparse results
            logger.info(f"Hybrid search returned {len(search_results[:top_k])} results")
            return search_results[:top_k]

        except Exception as e:
            logger.error(f"Hybrid search failed: {e}")
            raise

    def _build_filter(self, filters: Dict) -> Optional[Filter]:
        """
        Build Qdrant filter from dictionary.

        Supports:
        - disease_name: str
        - icd_codes: List[str]
        - anatomical_structures: List[str]
        - document_title: str
        - parent_section: str
        - min_tokens / max_tokens: int (range on token_count)

        All conditions are combined with AND (Filter.must).

        Args:
            filters: Dictionary of filter conditions

        Returns:
            Qdrant Filter object, or None if no supported keys are present
        """
        conditions = []

        # Disease name filter
        if "disease_name" in filters:
            conditions.append(
                FieldCondition(
                    key="metadata.disease_name",
                    match=MatchValue(value=filters["disease_name"]),
                )
            )

        # ICD codes filter (match any)
        if "icd_codes" in filters:
            icd_list = filters["icd_codes"]
            if isinstance(icd_list, str):
                icd_list = [icd_list]
            conditions.append(
                FieldCondition(
                    key="metadata.icd_codes",
                    match=MatchAny(any=icd_list),
                )
            )

        # Anatomical structures filter
        if "anatomical_structures" in filters:
            structures = filters["anatomical_structures"]
            if isinstance(structures, str):
                structures = [structures]
            conditions.append(
                FieldCondition(
                    key="metadata.anatomical_structures",
                    match=MatchAny(any=structures),
                )
            )

        # Document title filter
        if "document_title" in filters:
            conditions.append(
                FieldCondition(
                    key="document_title",
                    match=MatchValue(value=filters["document_title"]),
                )
            )

        # Parent section filter
        if "parent_section" in filters:
            conditions.append(
                FieldCondition(
                    key="parent_section",
                    match=MatchValue(value=filters["parent_section"]),
                )
            )

        # Token count range filter
        if "min_tokens" in filters or "max_tokens" in filters:
            range_filter = {}
            if "min_tokens" in filters:
                range_filter["gte"] = filters["min_tokens"]
            if "max_tokens" in filters:
                range_filter["lte"] = filters["max_tokens"]

            conditions.append(
                FieldCondition(
                    key="token_count",
                    range=Range(**range_filter),
                )
            )

        return Filter(must=conditions) if conditions else None

    def get_collection_info(self) -> Dict:
        """
        Get information about the collection.

        Returns:
            Dictionary with collection statistics
        """
        try:
            info = self.client.get_collection(self.collection_name)

            return {
                "name": self.collection_name,
                # getattr guards against qdrant-client versions where these
                # attributes are absent on the collection info object
                "vectors_count": getattr(info, "vectors_count", 0),
                "points_count": info.points_count,
                "status": info.status,
                "optimizer_status": info.optimizer_status,
                "indexed_vectors_count": getattr(info, "indexed_vectors_count", 0),
            }

        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            raise

    def delete_collection(self) -> bool:
        """
        Delete the collection.

        Returns:
            True if successful
        """
        try:
            result = self.client.delete_collection(self.collection_name)
            self.console.print(
                f"[yellow]Deleted collection: {self.collection_name}[/yellow]"
            )
            logger.info(f"Deleted collection: {self.collection_name}")
            return result

        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")
            raise

    def count_documents(self) -> int:
        """
        Count total documents in collection.

        Returns:
            Number of documents (0 on error — best-effort, never raises)
        """
        try:
            info = self.client.get_collection(self.collection_name)
            return info.points_count or 0

        except Exception as e:
            logger.error(f"Failed to count documents: {e}")
            return 0

    def get_document_by_id(self, doc_id: str) -> Optional[SearchResult]:
        """
        Retrieve a specific document by ID.

        Accepts either the raw Qdrant point ID (a UUID string) or the
        original chunk_id. Points are stored under uuid5(chunk_id) by
        add_documents(), so a plain chunk_id is mapped through the same
        scheme here — previously such lookups silently returned None.

        Args:
            doc_id: Qdrant point UUID or chunk_id

        Returns:
            SearchResult if found, None otherwise
        """
        try:
            # Map a non-UUID doc_id (i.e. a chunk_id) to its stored point ID
            try:
                uuid.UUID(doc_id)
                point_id = doc_id
            except ValueError:
                point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, doc_id))

            points = self.client.retrieve(
                collection_name=self.collection_name,
                ids=[point_id],
            )

            if not points:
                return None

            point = points[0]
            payload = point.payload or {}

            return SearchResult(
                id=str(point.id),
                score=1.0,  # No score for direct retrieval
                chunk_id=payload.get("chunk_id", ""),
                content=payload.get("content", ""),
                parent_section=payload.get("parent_section", ""),
                document_title=payload.get("document_title", ""),
                source_url=payload.get("source_url", ""),
                metadata=payload.get("metadata", {}),
            )

        except Exception as e:
            logger.error(f"Failed to get document by ID: {e}")
            return None
tests/README.md
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tests
|
| 2 |
+
|
| 3 |
+
Comprehensive test suite for the EyeWiki RAG system.
|
| 4 |
+
|
| 5 |
+
## Installation
|
| 6 |
+
|
| 7 |
+
Install test dependencies:
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
pip install pytest pytest-cov pytest-mock requests
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
## Running Tests
|
| 14 |
+
|
| 15 |
+
### Run all tests:
|
| 16 |
+
```bash
|
| 17 |
+
pytest
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### Run with verbose output:
|
| 21 |
+
```bash
|
| 22 |
+
pytest -v
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Run specific test file:
|
| 26 |
+
```bash
|
| 27 |
+
pytest tests/test_components.py -v
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### Run specific test:
|
| 31 |
+
```bash
|
| 32 |
+
pytest tests/test_components.py::test_chunk_respects_headers -v
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### Run tests by marker:
|
| 36 |
+
```bash
|
| 37 |
+
# Run only unit tests
|
| 38 |
+
pytest -m unit
|
| 39 |
+
|
| 40 |
+
# Run only integration tests
|
| 41 |
+
pytest -m integration
|
| 42 |
+
|
| 43 |
+
# Run only API tests
|
| 44 |
+
pytest -m api
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Run with coverage:
|
| 48 |
+
```bash
|
| 49 |
+
pytest --cov=src --cov-report=html
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
This will generate a coverage report in `htmlcov/index.html`.
|
| 53 |
+
|
| 54 |
+
## Test Categories
|
| 55 |
+
|
| 56 |
+
### Unit Tests (`@pytest.mark.unit`)
|
| 57 |
+
- Fast, isolated tests
|
| 58 |
+
- Mock external dependencies
|
| 59 |
+
- Test individual components
|
| 60 |
+
|
| 61 |
+
### Integration Tests (`@pytest.mark.integration`)
|
| 62 |
+
- Test multiple components together
|
| 63 |
+
- May be slower
|
| 64 |
+
- May require real dependencies
|
| 65 |
+
|
| 66 |
+
### API Tests (`@pytest.mark.api`)
|
| 67 |
+
- Test FastAPI endpoints
|
| 68 |
+
- Require server components
|
| 69 |
+
- Use TestClient
|
| 70 |
+
|
| 71 |
+
## Test Structure
|
| 72 |
+
|
| 73 |
+
### Chunker Tests
|
| 74 |
+
- `test_chunk_respects_headers()` - Verifies markdown header handling
|
| 75 |
+
- `test_chunk_size_limits()` - Checks chunk size constraints
|
| 76 |
+
- `test_metadata_preserved()` - Ensures metadata propagation
|
| 77 |
+
|
| 78 |
+
### Retriever Tests
|
| 79 |
+
- `test_retrieval_returns_results()` - Basic retrieval functionality
|
| 80 |
+
- `test_hybrid_search_combines_scores()` - Score combination logic
|
| 81 |
+
- `test_filters_work()` - Metadata filtering
|
| 82 |
+
|
| 83 |
+
### Reranker Tests
|
| 84 |
+
- `test_reranking_changes_order()` - Verifies reranking effect
|
| 85 |
+
- `test_top_k_respected()` - Checks top_k parameter
|
| 86 |
+
|
| 87 |
+
### Query Engine Tests
|
| 88 |
+
- `test_full_query_pipeline()` - End-to-end query flow
|
| 89 |
+
- `test_sources_included()` - Source citation functionality
|
| 90 |
+
- `test_disclaimer_present()` - Medical disclaimer inclusion
|
| 91 |
+
- `test_streaming_query()` - Streaming response
|
| 92 |
+
|
| 93 |
+
### API Tests
|
| 94 |
+
- `test_health_endpoint()` - Health check endpoint
|
| 95 |
+
- `test_query_endpoint()` - Main query endpoint
|
| 96 |
+
- `test_query_endpoint_validation()` - Input validation
|
| 97 |
+
|
| 98 |
+
### Metadata Tests
|
| 99 |
+
- `test_icd_code_extraction()` - ICD-10 code extraction
|
| 100 |
+
- `test_anatomical_term_extraction()` - Anatomical term detection
|
| 101 |
+
- `test_medication_extraction()` - Medication identification
|
| 102 |
+
|
| 103 |
+
## Fixtures
|
| 104 |
+
|
| 105 |
+
Reusable test fixtures are defined in `test_components.py`:
|
| 106 |
+
|
| 107 |
+
- `semantic_chunker` - SemanticChunker instance
|
| 108 |
+
- `metadata_extractor` - MetadataExtractor instance
|
| 109 |
+
- `sample_chunks` - Sample ChunkNode objects
|
| 110 |
+
- `mock_retriever` - Mocked HybridRetriever
|
| 111 |
+
- `mock_reranker` - Mocked CrossEncoderReranker
|
| 112 |
+
- `mock_ollama_client` - Mocked OllamaClient
|
| 113 |
+
- `query_engine` - Fully configured QueryEngine with mocks
|
| 114 |
+
- `test_client` - FastAPI TestClient
|
| 115 |
+
|
| 116 |
+
## Writing New Tests
|
| 117 |
+
|
| 118 |
+
### Example unit test:
|
| 119 |
+
```python
|
| 120 |
+
@pytest.mark.unit
|
| 121 |
+
def test_my_component(my_fixture):
|
| 122 |
+
"""Test description."""
|
| 123 |
+
result = my_fixture.some_method()
|
| 124 |
+
assert result == expected_value
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Example integration test:
|
| 128 |
+
```python
|
| 129 |
+
@pytest.mark.integration
|
| 130 |
+
def test_component_interaction():
|
| 131 |
+
"""Test multiple components together."""
|
| 132 |
+
# Setup
|
| 133 |
+
component_a = ComponentA()
|
| 134 |
+
component_b = ComponentB(component_a)
|
| 135 |
+
|
| 136 |
+
# Test
|
| 137 |
+
result = component_b.process()
|
| 138 |
+
|
| 139 |
+
# Assert
|
| 140 |
+
assert result.is_valid()
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
### Example API test:
|
| 144 |
+
```python
|
| 145 |
+
@pytest.mark.api
|
| 146 |
+
def test_my_endpoint(test_client):
|
| 147 |
+
"""Test API endpoint."""
|
| 148 |
+
response = test_client.get("/my-endpoint")
|
| 149 |
+
assert response.status_code == 200
|
| 150 |
+
assert "expected_field" in response.json()
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## Continuous Integration
|
| 154 |
+
|
| 155 |
+
These tests are designed to run in CI/CD pipelines. Mock external dependencies (Ollama, Qdrant) to ensure tests run in any environment.
|
| 156 |
+
|
| 157 |
+
## Troubleshooting
|
| 158 |
+
|
| 159 |
+
### Import Errors
|
| 160 |
+
Make sure the project root is in PYTHONPATH:
|
| 161 |
+
```bash
|
| 162 |
+
export PYTHONPATH=/path/to/eyewiki-rag:$PYTHONPATH
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
### Mock Issues
|
| 166 |
+
If mocks aren't working properly, check that you're using the correct spec:
|
| 167 |
+
```python
|
| 168 |
+
mock = Mock(spec=RealClass)
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
### API Tests Failing
|
| 172 |
+
API tests may fail if the application isn't properly initialized. Use mocking to isolate components.
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pytest configuration and shared fixtures.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
# Add project root to Python path
|
| 9 |
+
project_root = Path(__file__).parent.parent
|
| 10 |
+
sys.path.insert(0, str(project_root))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def pytest_configure(config):
    """Register the project's custom pytest markers.

    Declaring markers here keeps ``--strict-markers`` runs clean and makes
    the markers discoverable via ``pytest --markers``.
    """
    marker_specs = (
        "integration: mark test as integration test (may be slow)",
        "api: mark test as API test (requires server components)",
        "unit: mark test as unit test (fast, isolated)",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)
|
tests/test_components.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive tests for EyeWiki RAG components.
|
| 3 |
+
|
| 4 |
+
Run with:
|
| 5 |
+
pytest tests/test_components.py -v
|
| 6 |
+
pytest tests/test_components.py::test_chunk_respects_headers -v
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import pytest
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from unittest.mock import Mock, patch, MagicMock
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
from src.processing.chunker import ChunkNode, SemanticChunker
|
| 15 |
+
from src.processing.metadata_extractor import MetadataExtractor
|
| 16 |
+
from src.rag.retriever import HybridRetriever, RetrievalResult
|
| 17 |
+
from src.rag.reranker import CrossEncoderReranker
|
| 18 |
+
from src.rag.query_engine import EyeWikiQueryEngine, QueryResponse, SourceInfo
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ============================================================================
|
| 22 |
+
# Test Data
|
| 23 |
+
# ============================================================================
|
| 24 |
+
|
| 25 |
+
SAMPLE_MARKDOWN = """# Glaucoma
|
| 26 |
+
|
| 27 |
+
## Overview
|
| 28 |
+
|
| 29 |
+
Glaucoma is a group of eye conditions that damage the optic nerve.
|
| 30 |
+
|
| 31 |
+
## Symptoms
|
| 32 |
+
|
| 33 |
+
Common symptoms include:
|
| 34 |
+
- Vision loss
|
| 35 |
+
- Eye pain
|
| 36 |
+
- Halos around lights
|
| 37 |
+
|
| 38 |
+
## Treatment
|
| 39 |
+
|
| 40 |
+
Treatment options include:
|
| 41 |
+
- Medications (IOP-lowering drops)
|
| 42 |
+
- Laser procedures
|
| 43 |
+
- Surgery
|
| 44 |
+
|
| 45 |
+
### Medications
|
| 46 |
+
|
| 47 |
+
Beta-blockers and prostaglandin analogs are commonly used.
|
| 48 |
+
|
| 49 |
+
### Surgery
|
| 50 |
+
|
| 51 |
+
Trabeculectomy is a common surgical procedure.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
# Minimal document-level metadata passed alongside SAMPLE_MARKDOWN; the
# chunker copies "title" and "url" into every chunk it produces.
SAMPLE_METADATA = {
    "title": "Glaucoma",
    "url": "https://eyewiki.aao.org/Glaucoma",
    "source": "eyewiki",
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ============================================================================
|
| 62 |
+
# Fixtures
|
| 63 |
+
# ============================================================================
|
| 64 |
+
|
| 65 |
+
@pytest.fixture
def semantic_chunker():
    """Provide a SemanticChunker with small windows so tests stay fast."""
    config = {"chunk_size": 200, "chunk_overlap": 20, "min_chunk_size": 50}
    return SemanticChunker(**config)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@pytest.fixture
def metadata_extractor():
    """Provide a default-configured MetadataExtractor."""
    extractor = MetadataExtractor()
    return extractor
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@pytest.fixture
def sample_chunks():
    """Build three ChunkNode fixtures spanning two EyeWiki articles.

    Two chunks belong to the Glaucoma article and one to Diabetic
    Retinopathy, so filter- and grouping-oriented tests have variety.
    """
    # (id, content, title, url, section, metadata, chunk_index, total_chunks)
    specs = [
        (
            "chunk_1",
            "Glaucoma is characterized by elevated intraocular pressure (IOP).",
            "Glaucoma",
            "https://eyewiki.aao.org/Glaucoma",
            "Overview",
            {"icd_codes": ["H40.1"], "anatomical_terms": ["optic nerve"]},
            0,
            5,
        ),
        (
            "chunk_2",
            "Treatment includes beta-blockers and prostaglandin analogs.",
            "Glaucoma",
            "https://eyewiki.aao.org/Glaucoma",
            "Treatment",
            {"medications": ["beta-blockers", "prostaglandin analogs"]},
            1,
            5,
        ),
        (
            "chunk_3",
            "Diabetic retinopathy affects the retinal blood vessels.",
            "Diabetic Retinopathy",
            "https://eyewiki.aao.org/Diabetic_Retinopathy",
            "Overview",
            {"icd_codes": ["E11.3"], "anatomical_terms": ["retina"]},
            0,
            3,
        ),
    ]
    return [
        ChunkNode(
            id=cid,
            content=text,
            document_title=title,
            source_url=url,
            parent_section=section,
            metadata=meta,
            chunk_index=idx,
            total_chunks=total,
        )
        for cid, text, title, url, section, meta, idx, total in specs
    ]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@pytest.fixture
def mock_retriever(sample_chunks):
    """Provide a mock HybridRetriever whose retrieve() returns the sample chunks."""
    retriever = Mock(spec=HybridRetriever)

    # Wrap each ChunkNode in a RetrievalResult with monotonically
    # decreasing scores (0.9, 0.8, 0.7, ...), mimicking ranked retrieval.
    canned = []
    for position, node in enumerate(sample_chunks):
        canned.append(
            RetrievalResult(
                id=node.id,
                content=node.content,
                document_title=node.document_title,
                source_url=node.source_url,
                section=node.parent_section,
                metadata=node.metadata,
                score=0.9 - (position * 0.1),
            )
        )

    retriever.retrieve.return_value = canned
    return retriever
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@pytest.fixture
def mock_reranker():
    """Provide a mock CrossEncoderReranker that visibly reorders its input."""
    reranker = Mock(spec=CrossEncoderReranker)

    def fake_rerank(query: str, documents: List[RetrievalResult], top_k: int):
        # Reversing the top_k slice simulates a reranker that disagrees
        # with the retrieval order, and fresh scores keep results ranked.
        selected = documents[:top_k][::-1]
        for rank, doc in enumerate(selected):
            doc.score = 0.95 - (rank * 0.05)
        return selected

    reranker.rerank.side_effect = fake_rerank
    return reranker
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@pytest.fixture
def mock_ollama_client():
    """Provide a mock OllamaClient with canned generation and embedding output."""
    client = Mock()
    canned_answer = (
        "Glaucoma is a group of eye diseases that damage the optic nerve. "
        "It is often associated with elevated intraocular pressure (IOP). "
        "[Source: Glaucoma]"
    )
    client.generate.return_value = canned_answer
    client.stream_generate.return_value = iter(["Glaucoma ", "is ", "a disease."])
    # 768 matches the embedding width the pipeline expects.
    client.embed_text.return_value = [0.1] * 768
    return client
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@pytest.fixture
def query_engine(mock_retriever, mock_reranker, mock_ollama_client, tmp_path):
    """Assemble an EyeWikiQueryEngine wired entirely to mocked dependencies.

    Prompt files are written to pytest's ``tmp_path`` so the engine can
    load them without touching the real ``prompts/`` directory.
    """
    def write_prompt(name: str, text: str):
        # Materialize a prompt file and hand back its path.
        path = tmp_path / name
        path.write_text(text)
        return path

    return EyeWikiQueryEngine(
        retriever=mock_retriever,
        reranker=mock_reranker,
        llm_client=mock_ollama_client,
        system_prompt_path=write_prompt(
            "system_prompt.txt", "You are an expert ophthalmology assistant."
        ),
        query_prompt_path=write_prompt(
            "query_prompt.txt", "Context: {context}\n\nQuestion: {question}\n\nAnswer:"
        ),
        disclaimer_path=write_prompt("disclaimer.txt", "Medical disclaimer text."),
        max_context_tokens=4000,
        retrieval_k=20,
        rerank_k=5,
    )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ============================================================================
|
| 199 |
+
# Chunker Tests
|
| 200 |
+
# ============================================================================
|
| 201 |
+
|
| 202 |
+
def test_chunk_respects_headers(semantic_chunker):
    """Chunker should split on markdown headers and label each chunk's section."""
    chunks = semantic_chunker.chunk_document(
        markdown_content=SAMPLE_MARKDOWN,
        metadata=SAMPLE_METADATA,
    )

    assert chunks, "expected at least one chunk"

    # At least one of the article's top-level sections must be recognized.
    section_names = {chunk.parent_section for chunk in chunks}
    assert section_names & {"Overview", "Symptoms", "Treatment"}

    # Every chunk carries the required identity fields.
    for chunk in chunks:
        assert chunk.content
        assert chunk.document_title == "Glaucoma"
        assert chunk.source_url == SAMPLE_METADATA["url"]
        assert chunk.id
        assert isinstance(chunk.chunk_index, int)
        assert isinstance(chunk.total_chunks, int)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def test_chunk_size_limits(semantic_chunker):
    """Oversized sections should be split into multiple size-bounded chunks."""
    # One very long section forces the chunker to split.
    long_markdown = "# Test\n\n## Section\n\n" + ("This is a test sentence. " * 200)

    chunks = semantic_chunker.chunk_document(
        markdown_content=long_markdown,
        metadata=SAMPLE_METADATA,
    )

    # Every chunk except possibly the final one must meet the minimum size.
    # The token estimate mirrors the chunker's heuristic: len(text) // 4.
    for chunk in chunks:
        is_last = chunk.chunk_index == chunk.total_chunks - 1
        if not is_last:
            assert len(chunk.content) // 4 >= semantic_chunker.min_chunk_size

    assert len(chunks) > 1
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def test_metadata_preserved(semantic_chunker):
    """Document-level title and URL must propagate into every chunk."""
    custom_metadata = {
        "title": "Test Document",
        "url": "https://example.com/test",
        "custom_field": "custom_value",
    }

    chunks = semantic_chunker.chunk_document(
        markdown_content=SAMPLE_MARKDOWN,
        metadata=custom_metadata,
    )

    assert all(c.document_title == custom_metadata["title"] for c in chunks)
    assert all(c.source_url == custom_metadata["url"] for c in chunks)
|
| 267 |
+
|
| 268 |
+
# ============================================================================
|
| 269 |
+
# Retriever Tests
|
| 270 |
+
# ============================================================================
|
| 271 |
+
|
| 272 |
+
def test_retrieval_returns_results(mock_retriever):
    """Retriever should return non-empty, well-formed RetrievalResult objects."""
    hits = mock_retriever.retrieve(query="What is glaucoma?", top_k=10)

    assert hits
    for hit in hits:
        assert isinstance(hit, RetrievalResult)
        # Identity and provenance fields are populated, score is a probability.
        assert hit.id
        assert hit.content
        assert hit.document_title
        assert hit.source_url
        assert 0 <= hit.score <= 1
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def test_hybrid_search_combines_scores(mock_retriever):
    """Hybrid retrieval scores must be valid probabilities sorted descending."""
    hits = mock_retriever.retrieve(query="glaucoma treatment", top_k=5)
    scores = [hit.score for hit in hits]

    # Pairwise check is equivalent to comparing against a reverse sort.
    assert all(a >= b for a, b in zip(scores, scores[1:]))
    assert all(0 <= s <= 1 for s in scores)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def test_filters_work(mock_retriever):
    """Test that metadata filters narrow results to matching documents.

    Bug fix: the previous version's side_effect called
    ``mock_retriever.retrieve(...)`` from inside itself. Once a Mock has a
    ``side_effect`` installed, every call re-invokes that side_effect, so the
    test recursed until RecursionError. We now capture the fixture's canned
    ``return_value`` BEFORE installing the side_effect and filter that list.
    """
    base_results = mock_retriever.retrieve.return_value

    def retrieve_with_filter(query: str, top_k: int, filters: dict = None):
        # Simple filter implementation for testing: keep only results whose
        # document title contains the requested disease name.
        results = list(base_results)[:top_k]
        if filters and "disease_name" in filters:
            return [r for r in results if filters["disease_name"] in r.document_title]
        return results

    mock_retriever.retrieve.side_effect = retrieve_with_filter

    # Test with filter
    results = mock_retriever.retrieve(
        query="treatment",
        top_k=10,
        filters={"disease_name": "Glaucoma"}
    )

    # All results should match filter
    assert all("Glaucoma" in r.document_title for r in results)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# ============================================================================
|
| 334 |
+
# Reranker Tests
|
| 335 |
+
# ============================================================================
|
| 336 |
+
|
| 337 |
+
def test_reranking_changes_order(mock_reranker, sample_chunks):
    """Reranking must produce an ordering different from the retrieval order."""
    # Give every candidate the same initial score so only order matters.
    candidates = [
        RetrievalResult(
            id=node.id,
            content=node.content,
            document_title=node.document_title,
            source_url=node.source_url,
            section=node.parent_section,
            metadata=node.metadata,
            score=0.5,
        )
        for node in sample_chunks
    ]

    order_before = [c.id for c in candidates]

    reranked = mock_reranker.rerank(
        query="What is glaucoma?",
        documents=candidates,
        top_k=3,
    )
    order_after = [r.id for r in reranked]

    # The mock reranker reverses its input, so the ordering must differ.
    assert order_after != order_before
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def test_top_k_respected(mock_reranker, sample_chunks):
    """Reranker must return exactly top_k documents."""
    candidates = [
        RetrievalResult(
            id=node.id,
            content=node.content,
            document_title=node.document_title,
            source_url=node.source_url,
            section=node.parent_section,
            metadata=node.metadata,
            score=0.5,
        )
        for node in sample_chunks
    ]

    requested = 2
    reranked = mock_reranker.rerank(
        query="treatment options",
        documents=candidates,
        top_k=requested,
    )

    assert len(reranked) == requested
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# ============================================================================
|
| 394 |
+
# Query Engine Tests
|
| 395 |
+
# ============================================================================
|
| 396 |
+
|
| 397 |
+
def test_full_query_pipeline(query_engine):
    """End-to-end query should yield a well-formed QueryResponse."""
    question = "What is glaucoma?"

    response = query_engine.query(question=question, include_sources=True)

    # Structure: answer text, echoed query, probability-range confidence,
    # and the mandatory disclaimer.
    assert isinstance(response, QueryResponse)
    assert response.answer
    assert response.query == question
    assert 0 <= response.confidence <= 1
    assert response.disclaimer
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def test_sources_included(query_engine):
    """Responses requested with sources must carry valid SourceInfo items."""
    response = query_engine.query(
        question="What is glaucoma?",
        include_sources=True,
    )

    assert response.sources, "expected at least one source"

    for src in response.sources:
        assert isinstance(src, SourceInfo)
        assert src.title
        assert src.url
        assert 0 <= src.relevance_score <= 1
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def test_disclaimer_present(query_engine):
    """Every answer must carry a non-empty medical disclaimer."""
    response = query_engine.query(
        question="How is glaucoma treated?",
        include_sources=True,
    )

    # A non-empty string is truthy, covering both original assertions.
    assert response.disclaimer
    assert len(response.disclaimer) > 0
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def test_query_without_sources(query_engine):
    """Disabling sources should suppress them without losing the answer."""
    response = query_engine.query(
        question="What is glaucoma?",
        include_sources=False,
    )

    # Answer is still produced; source list is empty.
    assert response.answer
    assert not response.sources
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def test_streaming_query(query_engine):
    """Streaming should yield a non-empty sequence of text chunks."""
    pieces = list(query_engine.stream_query(question="What is glaucoma?"))

    assert pieces, "expected at least one streamed chunk"
    for piece in pieces:
        assert isinstance(piece, str)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def test_confidence_calculation(query_engine):
    """Confidence must be a valid probability and high for strong matches."""
    response = query_engine.query(
        question="What is glaucoma?",
        include_sources=True,
    )

    assert response.confidence is not None
    assert 0 <= response.confidence <= 1

    # The mocked retrieval scores are high (0.9, 0.8, 0.7), so the derived
    # confidence should clear 0.5.
    assert response.confidence > 0.5
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def test_empty_retrieval_results(query_engine, mock_retriever):
    """An empty retrieval set should yield a graceful zero-confidence reply."""
    # Force the retriever (shared with the engine fixture) to find nothing.
    mock_retriever.retrieve.return_value = []

    response = query_engine.query(
        question="What is xyzabc?",  # Non-existent topic
        include_sources=True,
    )

    assert response.answer
    answer_text = response.answer.lower()
    assert "couldn't find" in answer_text or "no results" in answer_text
    assert not response.sources
    assert response.confidence == 0.0
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
# ============================================================================
|
| 505 |
+
# API Tests
|
| 506 |
+
# ============================================================================
|
| 507 |
+
|
| 508 |
+
@pytest.fixture
def test_client():
    """Provide a FastAPI TestClient bound to the application.

    Imports are kept inside the fixture so merely importing this module
    does not require FastAPI or the application package to be importable.
    """
    from fastapi.testclient import TestClient
    from src.api.main import app

    client = TestClient(app)
    return client
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def test_health_endpoint(test_client):
    """Health check must answer with a status payload."""
    response = test_client.get("/health")

    # 200 when the pipeline is initialized, 503 otherwise.
    assert response.status_code in (200, 503)

    payload = response.json()
    for key in ("status", "timestamp"):
        assert key in payload
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def test_root_endpoint(test_client):
    """Root endpoint must describe the service."""
    response = test_client.get("/")
    assert response.status_code == 200

    payload = response.json()
    for key in ("name", "version", "endpoints"):
        assert key in payload
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def test_query_endpoint(test_client):
    """Query endpoint should answer, or report unavailability cleanly."""
    # Without a fully initialized app_state this returns 503; stricter
    # tests would mock app_state rather than accept both codes.
    response = test_client.post(
        "/query",
        json={
            "question": "What is glaucoma?",
            "include_sources": True,
        },
    )

    assert response.status_code in (200, 503)

    if response.status_code == 200:
        payload = response.json()
        for key in ("answer", "query", "confidence", "disclaimer"):
            assert key in payload
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def test_query_endpoint_validation(test_client):
    """An empty question must be rejected by request validation."""
    response = test_client.post("/query", json={"question": ""})

    # FastAPI/pydantic validation failure.
    assert response.status_code == 422  # Unprocessable Entity
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def test_stats_endpoint(test_client):
    """Stats endpoint should expose index/pipeline info when available."""
    response = test_client.get("/stats")

    # 200 when initialized; 503 when the pipeline is down; 404 if the
    # route is not registered in this build.
    assert response.status_code in (200, 503, 404)

    if response.status_code == 200:
        payload = response.json()
        for key in ("collection_info", "pipeline_config", "documents_indexed"):
            assert key in payload
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
# ============================================================================
|
| 595 |
+
# Metadata Extractor Tests
|
| 596 |
+
# ============================================================================
|
| 597 |
+
|
| 598 |
+
def test_icd_code_extraction(metadata_extractor):
    """ICD-10 codes embedded in clinical text must be detected."""
    text = "Patient diagnosed with H40.1 (Primary open-angle glaucoma) and E11.3 (Type 2 diabetes with ophthalmic complications)."

    codes = metadata_extractor.extract_icd_codes(text)

    assert {"H40.1", "E11.3"} <= set(codes)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def test_anatomical_term_extraction(metadata_extractor):
    """Common ocular anatomy terms must be detected in free text."""
    text = "The optic nerve and retina are affected. The cornea appears normal."

    found = metadata_extractor.extract_anatomical_terms(text)

    assert {"optic nerve", "retina", "cornea"} <= set(found)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def test_medication_extraction(metadata_extractor):
    """At least one known glaucoma medication must be detected."""
    text = "Prescribed latanoprost and timolol for IOP reduction."

    medications = metadata_extractor.extract_medications(text)

    # Either drug counts: the extractor's lexicon may not cover both.
    assert "latanoprost" in medications or "timolol" in medications
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def test_full_metadata_extraction(metadata_extractor):
    """extract() must populate every major metadata category."""
    text = """
    Patient with H40.1 primary open-angle glaucoma affecting the optic nerve.
    Prescribed latanoprost drops. Vision loss and eye pain reported.
    """

    metadata = metadata_extractor.extract(text, existing_metadata={})

    for key in ("icd_codes", "anatomical_terms", "medications", "symptoms"):
        assert key in metadata
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
# ============================================================================
|
| 645 |
+
# Integration Tests
|
| 646 |
+
# ============================================================================
|
| 647 |
+
|
| 648 |
+
def test_end_to_end_chunk_to_query():
    """Chunk a document, then answer a query through mocked RAG stages."""
    # Chunk the sample article.
    chunker = SemanticChunker(chunk_size=200, chunk_overlap=20)
    chunks = chunker.chunk_document(
        markdown_content=SAMPLE_MARKDOWN,
        metadata=SAMPLE_METADATA,
    )
    assert chunks, "chunking produced nothing"

    # Wrap the first few chunks as retrieval hits.
    hits = [
        RetrievalResult(
            id=node.id,
            content=node.content,
            document_title=node.document_title,
            source_url=node.source_url,
            section=node.parent_section,
            metadata=node.metadata,
            score=0.8,
        )
        for node in chunks[:3]
    ]

    # Mock every downstream dependency: retriever, reranker, and LLM.
    retriever = Mock(spec=HybridRetriever)
    retriever.retrieve.return_value = hits

    reranker = Mock(spec=CrossEncoderReranker)
    reranker.rerank.return_value = hits[:2]

    llm = Mock()
    llm.generate.return_value = "Glaucoma is an eye disease."

    # Assemble the engine and run a query end to end.
    engine = EyeWikiQueryEngine(
        retriever=retriever,
        reranker=reranker,
        llm_client=llm,
        max_context_tokens=4000,
        retrieval_k=20,
        rerank_k=5,
    )

    response = engine.query("What is glaucoma?")

    assert response.answer
    assert response.confidence > 0
|
tests/test_questions.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"id": "q1",
|
| 4 |
+
"question": "What are the main symptoms of glaucoma?",
|
| 5 |
+
"expected_topics": [
|
| 6 |
+
"vision loss",
|
| 7 |
+
"peripheral vision",
|
| 8 |
+
"eye pressure",
|
| 9 |
+
"optic nerve damage",
|
| 10 |
+
"blind spots"
|
| 11 |
+
],
|
| 12 |
+
"expected_sources": [
|
| 13 |
+
"Glaucoma",
|
| 14 |
+
"Primary Open-Angle Glaucoma"
|
| 15 |
+
],
|
| 16 |
+
"category": "symptoms"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"id": "q2",
|
| 20 |
+
"question": "How is diabetic retinopathy treated?",
|
| 21 |
+
"expected_topics": [
|
| 22 |
+
"laser treatment",
|
| 23 |
+
"anti-VEGF",
|
| 24 |
+
"photocoagulation",
|
| 25 |
+
"vitrectomy",
|
| 26 |
+
"blood sugar control"
|
| 27 |
+
],
|
| 28 |
+
"expected_sources": [
|
| 29 |
+
"Diabetic Retinopathy",
|
| 30 |
+
"Proliferative Diabetic Retinopathy"
|
| 31 |
+
],
|
| 32 |
+
"category": "treatment"
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"id": "q3",
|
| 36 |
+
"question": "What causes age-related macular degeneration?",
|
| 37 |
+
"expected_topics": [
|
| 38 |
+
"aging",
|
| 39 |
+
"macula",
|
| 40 |
+
"drusen",
|
| 41 |
+
"photoreceptor",
|
| 42 |
+
"central vision"
|
| 43 |
+
],
|
| 44 |
+
"expected_sources": [
|
| 45 |
+
"Age-Related Macular Degeneration",
|
| 46 |
+
"AMD",
|
| 47 |
+
"Macular Degeneration"
|
| 48 |
+
],
|
| 49 |
+
"category": "etiology"
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"id": "q4",
|
| 53 |
+
"question": "What is the difference between open-angle and angle-closure glaucoma?",
|
| 54 |
+
"expected_topics": [
|
| 55 |
+
"drainage angle",
|
| 56 |
+
"trabecular meshwork",
|
| 57 |
+
"acute",
|
| 58 |
+
"chronic",
|
| 59 |
+
"iridotomy"
|
| 60 |
+
],
|
| 61 |
+
"expected_sources": [
|
| 62 |
+
"Glaucoma",
|
| 63 |
+
"Primary Open-Angle Glaucoma",
|
| 64 |
+
"Angle-Closure Glaucoma"
|
| 65 |
+
],
|
| 66 |
+
"category": "classification"
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"id": "q5",
|
| 70 |
+
"question": "What are the risk factors for cataracts?",
|
| 71 |
+
"expected_topics": [
|
| 72 |
+
"age",
|
| 73 |
+
"diabetes",
|
| 74 |
+
"UV exposure",
|
| 75 |
+
"smoking",
|
| 76 |
+
"steroid"
|
| 77 |
+
],
|
| 78 |
+
"expected_sources": [
|
| 79 |
+
"Cataract",
|
| 80 |
+
"Age-Related Cataract"
|
| 81 |
+
],
|
| 82 |
+
"category": "risk_factors"
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"id": "q6",
|
| 86 |
+
"question": "How is retinal detachment diagnosed?",
|
| 87 |
+
"expected_topics": [
|
| 88 |
+
"dilated eye exam",
|
| 89 |
+
"ophthalmoscopy",
|
| 90 |
+
"ultrasound",
|
| 91 |
+
"floaters",
|
| 92 |
+
"flashes"
|
| 93 |
+
],
|
| 94 |
+
"expected_sources": [
|
| 95 |
+
"Retinal Detachment",
|
| 96 |
+
"Rhegmatogenous Retinal Detachment"
|
| 97 |
+
],
|
| 98 |
+
"category": "diagnosis"
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"id": "q7",
|
| 102 |
+
"question": "What medications are used to lower intraocular pressure?",
|
| 103 |
+
"expected_topics": [
|
| 104 |
+
"prostaglandin analogs",
|
| 105 |
+
"beta-blockers",
|
| 106 |
+
"alpha agonists",
|
| 107 |
+
"carbonic anhydrase inhibitors",
|
| 108 |
+
"latanoprost",
|
| 109 |
+
"timolol"
|
| 110 |
+
],
|
| 111 |
+
"expected_sources": [
|
| 112 |
+
"Glaucoma",
|
| 113 |
+
"Medical Therapy for Glaucoma"
|
| 114 |
+
],
|
| 115 |
+
"category": "pharmacology"
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"id": "q8",
|
| 119 |
+
"question": "What is keratoconus and how is it managed?",
|
| 120 |
+
"expected_topics": [
|
| 121 |
+
"cornea",
|
| 122 |
+
"thinning",
|
| 123 |
+
"cone-shaped",
|
| 124 |
+
"corneal crosslinking",
|
| 125 |
+
"contact lenses"
|
| 126 |
+
],
|
| 127 |
+
"expected_sources": [
|
| 128 |
+
"Keratoconus"
|
| 129 |
+
],
|
| 130 |
+
"category": "corneal_disease"
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"id": "q9",
|
| 134 |
+
"question": "What are the complications of cataract surgery?",
|
| 135 |
+
"expected_topics": [
|
| 136 |
+
"posterior capsule opacification",
|
| 137 |
+
"infection",
|
| 138 |
+
"endophthalmitis",
|
| 139 |
+
"cystoid macular edema",
|
| 140 |
+
"retinal detachment"
|
| 141 |
+
],
|
| 142 |
+
"expected_sources": [
|
| 143 |
+
"Cataract Surgery",
|
| 144 |
+
"Phacoemulsification"
|
| 145 |
+
],
|
| 146 |
+
"category": "complications"
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"id": "q10",
|
| 150 |
+
"question": "How does dry eye syndrome present?",
|
| 151 |
+
"expected_topics": [
|
| 152 |
+
"burning",
|
| 153 |
+
"irritation",
|
| 154 |
+
"tear film",
|
| 155 |
+
"meibomian gland",
|
| 156 |
+
"artificial tears"
|
| 157 |
+
],
|
| 158 |
+
"expected_sources": [
|
| 159 |
+
"Dry Eye",
|
| 160 |
+
"Dry Eye Syndrome"
|
| 161 |
+
],
|
| 162 |
+
"category": "symptoms"
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"id": "q11",
|
| 166 |
+
"question": "What is the pathophysiology of uveitis?",
|
| 167 |
+
"expected_topics": [
|
| 168 |
+
"inflammation",
|
| 169 |
+
"uvea",
|
| 170 |
+
"anterior",
|
| 171 |
+
"posterior",
|
| 172 |
+
"immune-mediated"
|
| 173 |
+
],
|
| 174 |
+
"expected_sources": [
|
| 175 |
+
"Uveitis",
|
| 176 |
+
"Anterior Uveitis"
|
| 177 |
+
],
|
| 178 |
+
"category": "pathophysiology"
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"id": "q12",
|
| 182 |
+
"question": "What imaging modalities are used for macular disease?",
|
| 183 |
+
"expected_topics": [
|
| 184 |
+
"OCT",
|
| 185 |
+
"optical coherence tomography",
|
| 186 |
+
"fluorescein angiography",
|
| 187 |
+
"fundus photography",
|
| 188 |
+
"angiography"
|
| 189 |
+
],
|
| 190 |
+
"expected_sources": [
|
| 191 |
+
"Macular Degeneration",
|
| 192 |
+
"OCT",
|
| 193 |
+
"Optical Coherence Tomography"
|
| 194 |
+
],
|
| 195 |
+
"category": "imaging"
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"id": "q13",
|
| 199 |
+
"question": "What is optic neuritis and what are its causes?",
|
| 200 |
+
"expected_topics": [
|
| 201 |
+
"optic nerve inflammation",
|
| 202 |
+
"vision loss",
|
| 203 |
+
"pain with eye movement",
|
| 204 |
+
"multiple sclerosis",
|
| 205 |
+
"demyelination"
|
| 206 |
+
],
|
| 207 |
+
"expected_sources": [
|
| 208 |
+
"Optic Neuritis"
|
| 209 |
+
],
|
| 210 |
+
"category": "neuro_ophthalmology"
|
| 211 |
+
},
|
| 212 |
+
{
|
| 213 |
+
"id": "q14",
|
| 214 |
+
"question": "How is proliferative diabetic retinopathy different from non-proliferative?",
|
| 215 |
+
"expected_topics": [
|
| 216 |
+
"neovascularization",
|
| 217 |
+
"microaneurysms",
|
| 218 |
+
"hemorrhages",
|
| 219 |
+
"vitreous hemorrhage",
|
| 220 |
+
"retinal ischemia"
|
| 221 |
+
],
|
| 222 |
+
"expected_sources": [
|
| 223 |
+
"Diabetic Retinopathy",
|
| 224 |
+
"Proliferative Diabetic Retinopathy",
|
| 225 |
+
"Non-Proliferative Diabetic Retinopathy"
|
| 226 |
+
],
|
| 227 |
+
"category": "classification"
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"id": "q15",
|
| 231 |
+
"question": "What are the signs of papilledema?",
|
| 232 |
+
"expected_topics": [
|
| 233 |
+
"optic disc swelling",
|
| 234 |
+
"increased intracranial pressure",
|
| 235 |
+
"headache",
|
| 236 |
+
"blurred vision",
|
| 237 |
+
"nausea"
|
| 238 |
+
],
|
| 239 |
+
"expected_sources": [
|
| 240 |
+
"Papilledema",
|
| 241 |
+
"Optic Disc Edema"
|
| 242 |
+
],
|
| 243 |
+
"category": "diagnosis"
|
| 244 |
+
}
|
| 245 |
+
]
|